diff --git a/expr/binding_test.go b/expr/binding_test.go index c3f6537..2025241 100644 --- a/expr/binding_test.go +++ b/expr/binding_test.go @@ -12,7 +12,7 @@ import ( ) var ( - extReg = NewEmptyExtensionRegistry(&extensions.DefaultCollection) + extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) uPointRef = extReg.GetTypeAnchor(extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", Name: "point", diff --git a/expr/builder_test.go b/expr/builder_test.go index f517ec7..e62d4b4 100644 --- a/expr/builder_test.go +++ b/expr/builder_test.go @@ -14,7 +14,7 @@ import ( func TestExprBuilder(t *testing.T) { b := expr.ExprBuilder{ - Reg: expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection), + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct), } precomputedLiteral, _ := expr.NewLiteral(int32(3), false) diff --git a/expr/expressions_test.go b/expr/expressions_test.go index 9473872..a5a9fff 100644 --- a/expr/expressions_test.go +++ b/expr/expressions_test.go @@ -224,7 +224,7 @@ func TestExpressionsRoundtrip(t *testing.T) { } // get the extension set extSet := ext.GetExtensionSet(&plan) - reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError()) tests := []expr.Expression{ sampleNestedExpr(reg, substraitExtURI), } @@ -240,7 +240,7 @@ func TestExpressionsRoundtrip(t *testing.T) { func ExampleExpression_Visit() { const substraitExtURI = "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" var ( - exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(&ext.DefaultCollection), substraitExtURI) + exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(ext.GetDefaultCollectionWithNoError()), substraitExtURI) preVisit, postVisit expr.VisitFunc ) @@ -347,7 +347,7 @@ func TestRoundTripUsingTestData(t 
*testing.T) { require.NoError(t, err) require.NoError(t, protojson.Unmarshal(raw, &protoSchema)) baseSchema := types.NewNamedStructFromProto(&protoSchema) - reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError()) for _, tc := range tmp["cases"].([]any) { tt := tc.(map[string]any) t.Run(tt["name"].(string), func(t *testing.T) { @@ -403,7 +403,7 @@ func TestRoundTripExtendedExpression(t *testing.T) { var ex proto.ExtendedExpression require.NoError(t, protojson.Unmarshal(buf.Bytes(), &ex)) - result, err := expr.ExtendedFromProto(&ex, &ext.DefaultCollection) + result, err := expr.ExtendedFromProto(&ex, ext.GetDefaultCollectionWithNoError()) require.NoError(t, err) out := result.ToProto() diff --git a/extensions/extension_mgr.go b/extensions/extension_mgr.go index 1f46929..bbbea2b 100644 --- a/extensions/extension_mgr.go +++ b/extensions/extension_mgr.go @@ -3,15 +3,17 @@ package extensions import ( + "embed" "fmt" "io" "io/fs" "path" "sort" + "sync" "github.com/creasty/defaults" "github.com/goccy/go-yaml" - substrait "github.com/substrait-io/substrait" + "github.com/substrait-io/substrait" substraitgo "github.com/substrait-io/substrait-go/v3" "github.com/substrait-io/substrait-go/v3/proto/extensions" ) @@ -20,45 +22,65 @@ type AdvancedExtension = extensions.AdvancedExtension const SubstraitDefaultURIPrefix = "https://github.com/substrait-io/substrait/blob/main/extensions/" -// DefaultCollection is loaded with the default Substrait extension -// definitions. Not all files are currently parsable. -// Parser needs to enhanced to support all files -var DefaultCollection Collection +var ( + getDefaultCollectionOnce = sync.OnceValues[*Collection, error](loadDefaultCollection) + unsupportedExtensions = map[string]bool{ + "unknown.yaml": true, + } +) + +// GetDefaultCollectionWithNoError returns a Collection that is loaded with the default Substrait extension definitions. 
+// This version is provided for the ease of use of legacy code. Please use GetDefaultCollection instead. +func GetDefaultCollectionWithNoError() *Collection { + c, err := GetDefaultCollection() + if err != nil { + panic(err) + } + return c +} -func init() { +// GetDefaultCollection returns a Collection that is loaded with the default Substrait extension definitions. +func GetDefaultCollection() (*Collection, error) { + return getDefaultCollectionOnce() +} + +func loadDefaultCollection() (*Collection, error) { substraitFS := substrait.GetSubstraitExtensionsFS() entries, err := substraitFS.ReadDir("extensions") if err != nil { - return + return nil, err } + var defaultCollection Collection for _, ent := range entries { - f, err := substraitFS.Open(path.Join("extensions/", ent.Name())) - if err != nil { - panic(err) + err2 := loadExtensionFile(&defaultCollection, substraitFS, ent) + if err2 != nil { + return nil, err2 } - fileStat, err := f.Stat() + } + return &defaultCollection, nil +} + +func loadExtensionFile(collection *Collection, substraitFS embed.FS, ent fs.DirEntry) error { + f, err := substraitFS.Open(path.Join("extensions/", ent.Name())) + if err != nil { + return err + } + defer func() { + _ = f.Close() + }() + fileStat, err := f.Stat() + if err != nil { + return err + } + fileName := path.Base(fileStat.Name()) + if _, ok := unsupportedExtensions[fileName]; !ok { + err = collection.Load(SubstraitDefaultURIPrefix+ent.Name(), f) if err != nil { - panic(err) - } - fileName := path.Base(fileStat.Name()) - // Catch and ignore load error for a file - // Currently extension grammar is not fully parseable - // There is a parser fix planned, once that is done, - // we can panic instead of ignoring failed extension file load - defer func(f fs.File, fileName string) { - if r := recover(); r != nil { - fmt.Printf("Ignoring extension file:%s, Recovered from panic: %v\n", fileName, r) - } - if err1 := f.Close(); err1 != nil { - panic(err1) - } - }(f, fileName) - err1 := 
DefaultCollection.Load(SubstraitDefaultURIPrefix+ent.Name(), f) - if err1 != nil { - fmt.Printf("Ignoring extension file:%s err:%v, Skipping it \n", fileName, err1) + return err } } + return nil } // ID is the unique identifier for a substrait object diff --git a/extensions/extension_mgr_test.go b/extensions/extension_mgr_test.go index 8184402..0ef7bd7 100644 --- a/extensions/extension_mgr_test.go +++ b/extensions/extension_mgr_test.go @@ -277,11 +277,11 @@ func TestDefaultCollection(t *testing.T) { ) switch tt.typ { case scalarFunc: - variant, ok = extensions.DefaultCollection.GetScalarFunc(id) + variant, ok = extensions.GetDefaultCollectionWithNoError().GetScalarFunc(id) case aggFunc: - variant, ok = extensions.DefaultCollection.GetAggregateFunc(id) + variant, ok = extensions.GetDefaultCollectionWithNoError().GetAggregateFunc(id) case windowFunc: - variant, ok = extensions.DefaultCollection.GetWindowFunc(id) + variant, ok = extensions.GetDefaultCollectionWithNoError().GetWindowFunc(id) } require.True(t, ok) @@ -295,7 +295,7 @@ func TestDefaultCollection(t *testing.T) { }) } - et, ok := extensions.DefaultCollection.GetType(extensions.ID{ + et, ok := extensions.GetDefaultCollectionWithNoError().GetType(extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", Name: "point"}) assert.True(t, ok) assert.Equal(t, "point", et.Name) @@ -303,9 +303,10 @@ func TestDefaultCollection(t *testing.T) { } func TestCollection_GetAllScalarFunctions(t *testing.T) { - scalarFunctions := extensions.DefaultCollection.GetAllScalarFunctions() - aggregateFunctions := extensions.DefaultCollection.GetAllAggregateFunctions() - windowFunctions := extensions.DefaultCollection.GetAllWindowFunctions() + defaultExtensions := extensions.GetDefaultCollectionWithNoError() + scalarFunctions := defaultExtensions.GetAllScalarFunctions() + aggregateFunctions := defaultExtensions.GetAllAggregateFunctions() + windowFunctions := defaultExtensions.GetAllWindowFunctions() 
assert.GreaterOrEqual(t, len(scalarFunctions), 309) assert.GreaterOrEqual(t, len(aggregateFunctions), 62) assert.GreaterOrEqual(t, len(windowFunctions), 7) @@ -323,21 +324,20 @@ func TestCollection_GetAllScalarFunctions(t *testing.T) { for _, tt := range tests { t.Run(tt.signature, func(t *testing.T) { assert.True(t, tt.isScalar || tt.isAggregate || tt.isWindow) - c := extensions.DefaultCollection if tt.isScalar { - sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) + sf, ok := defaultExtensions.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) assert.True(t, ok) assert.Contains(t, scalarFunctions, sf) // verify that default nullability is set to MIRROR assert.Equal(t, extensions.MirrorNullability, sf.Nullability()) } if tt.isAggregate { - af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) + af, ok := defaultExtensions.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) assert.True(t, ok) assert.Contains(t, aggregateFunctions, af) } if tt.isWindow { - wf, ok := c.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) + wf, ok := defaultExtensions.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature}) assert.True(t, ok) assert.Contains(t, windowFunctions, wf) } diff --git a/extensions/variants_test.go b/extensions/variants_test.go index 868e33d..b279549 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -229,7 +229,7 @@ func TestMatchWithSyncParams(t *testing.T) { require.NotNil(t, testFile) assert.Len(t, testFile.TestCases, testFileInfo.numTests) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) for _, tc := range testFile.TestCases { t.Run(tc.FuncName, func(t *testing.T) { switch tc.FuncType { diff --git a/functions/dialect_test.go b/functions/dialect_test.go index dba1735..13ade37 100644 --- 
a/functions/dialect_test.go +++ b/functions/dialect_test.go @@ -16,7 +16,7 @@ import ( var gFunctionRegistry FunctionRegistry func TestMain(m *testing.M) { - gFunctionRegistry = NewFunctionRegistry(&extensions.DefaultCollection) + gFunctionRegistry = NewFunctionRegistry(extensions.GetDefaultCollectionWithNoError()) m.Run() } diff --git a/functions/local_functions_test.go b/functions/local_functions_test.go index 05f2821..c1068f7 100644 --- a/functions/local_functions_test.go +++ b/functions/local_functions_test.go @@ -94,7 +94,7 @@ add(120::i8, 10::i8) [overflow:SILENT] = assert.Len(t, testFile.TestCases, len(testResults)) require.GreaterOrEqual(t, len(testFile.TestCases), len(testResults)) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) for i, result := range testResults { tc := testFile.TestCases[i] t.Run(result.name, func(t *testing.T) { @@ -220,7 +220,7 @@ sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.000000715255737304 testCases := append(testFile.TestCases, testFile1.TestCases...) 
require.GreaterOrEqual(t, len(testCases), len(testResults)) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) for i, result := range testResults { tc := testCases[i] t.Run(result.name, func(t *testing.T) { diff --git a/plan/builders.go b/plan/builders.go index 1bc7910..0e3aae9 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -154,7 +154,7 @@ type Builder interface { const FETCH_COUNT_ALL_RECORDS = -1 func NewBuilderDefault() Builder { - return NewBuilder(&extensions.DefaultCollection) + return NewBuilder(extensions.GetDefaultCollectionWithNoError()) } func NewBuilder(c *extensions.Collection) Builder { diff --git a/plan/internal/helper_test.go b/plan/internal/helper_test.go index 5686a10..ef48306 100644 --- a/plan/internal/helper_test.go +++ b/plan/internal/helper_test.go @@ -41,7 +41,7 @@ func TestVirtualTableExpressionFromProto(t *testing.T) { literal1 := expr.NewPrimitiveLiteral(int32(1), false) expr1 := literal1.ToProto() - reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection) + reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError()) rows := &proto.Expression_Nested_Struct{Fields: []*proto.Expression{ expr1, }} diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index b8b3d58..82791a7 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -65,7 +65,7 @@ func TestBasicEmitPlan(t *testing.T) { protoPlan, err := p.ToProto() require.NoError(t, err) - roundTrip, err := plan.FromProto(protoPlan, &extensions.DefaultCollection) + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) require.NoError(t, err) assert.Equal(t, p, roundTrip) @@ -105,7 +105,7 @@ func TestEmitEmptyPlan(t *testing.T) { protoPlan, err := p.ToProto() require.NoError(t, err) - roundTrip, err := plan.FromProto(protoPlan, 
&extensions.DefaultCollection) + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) require.NoError(t, err) assert.Equal(t, p, roundTrip) @@ -169,7 +169,7 @@ func checkRoundTrip(t *testing.T, expectedJSON string, p *plan.Plan) { assert.Truef(t, proto.Equal(&expectedProto, protoPlan), "JSON expected: %s\ngot: %s", protojson.Format(&expectedProto), protojson.Format(protoPlan)) - roundTrip, err := plan.FromProto(&expectedProto, &extensions.DefaultCollection) + roundTrip, err := plan.FromProto(&expectedProto, extensions.GetDefaultCollectionWithNoError()) require.NoError(t, err) roundTripProto, err := roundTrip.ToProto() diff --git a/plan/plan_test.go b/plan/plan_test.go index 363367b..e74d6dc 100644 --- a/plan/plan_test.go +++ b/plan/plan_test.go @@ -13,7 +13,7 @@ import ( func TestRelFromProto(t *testing.T) { - registry := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + registry := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) literal5 := &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_I64{I64: 5}} exprLiteral5 := &proto.Expression{RexType: &proto.Expression_Literal_{Literal: literal5}} diff --git a/plan/relations_test.go b/plan/relations_test.go index b261cfc..08b385e 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -28,7 +28,7 @@ func createPrimitiveBool(value bool) expr.Expression { } func TestRelations_Copy(t *testing.T) { - extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection) + extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollectionWithNoError()) aggregateFnID := extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", Name: "avg", @@ -414,7 +414,7 @@ func TestRelations_Copy(t *testing.T) { } func TestAggregateRelToBuilder(t *testing.T) { - extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection) + extReg := 
expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollectionWithNoError()) aggregateFnID := extensions.ID{ URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml", Name: "avg", diff --git a/testcases/parser/parse_test.go b/testcases/parser/parse_test.go index 86e3b63..9b5d13f 100644 --- a/testcases/parser/parse_test.go +++ b/testcases/parser/parse_test.go @@ -50,7 +50,7 @@ add(120::i8, 10::i8) [overflow:ERROR] = {&types.Int16Type{Nullability: types.NullabilityRequired}, &types.Int16Type{Nullability: types.NullabilityRequired}}, {&types.Int8Type{Nullability: types.NullabilityRequired}, &types.Int8Type{Nullability: types.NullabilityRequired}}, } - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) basicGroupDesc := "'Basic examples without any special cases'" overflowGroupDesc := "Overflow examples demonstrating overflow behavior" groupDescs := []string{basicGroupDesc, basicGroupDesc, overflowGroupDesc} @@ -325,7 +325,7 @@ func TestParseAggregateFunc(t *testing.T) { avg((1,2,3)::fp32) = 2::fp64 sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = ` - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) arithUri := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" testFile, err := ParseTestCasesFromString(header + tests) require.NoError(t, err) @@ -544,7 +544,7 @@ LIST_AGG(t1.col0, ','::string) = 1::fp64 require.NotNil(t, testFile) assert.Len(t, testFile.TestCases, 1) assert.Equal(t, AggregateFuncType, testFile.TestCases[0].FuncType) - reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + reg := 
expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) aggFun, err := testFile.TestCases[0].GetAggregateFunctionInvocation(&reg, nil) require.NoError(t, err) assert.Equal(t, "string_agg", aggFun.Name()) @@ -724,7 +724,7 @@ func TestLoadAllSubstraitTestFiles(t *testing.T) { testFile, err := ParseTestCaseFileFromFS(got, filePath) require.NoError(t, err) require.NotNil(t, testFile) - reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError()) for _, tc := range testFile.TestCases { testGetFunctionInvocation(t, tc, &reg, funcRegistry) }