Skip to content

Commit

Permalink
feat: load the default extension collection lazily (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran authored Feb 12, 2025
1 parent 3be7ba3 commit 261dc94
Show file tree
Hide file tree
Showing 14 changed files with 84 additions and 62 deletions.
2 changes: 1 addition & 1 deletion expr/binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

var (
extReg = NewEmptyExtensionRegistry(&extensions.DefaultCollection)
extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())
uPointRef = extReg.GetTypeAnchor(extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml",
Name: "point",
Expand Down
2 changes: 1 addition & 1 deletion expr/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func TestExprBuilder(t *testing.T) {
b := expr.ExprBuilder{
Reg: expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()),
BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct),
}
precomputedLiteral, _ := expr.NewLiteral(int32(3), false)
Expand Down
8 changes: 4 additions & 4 deletions expr/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func TestExpressionsRoundtrip(t *testing.T) {
}
// get the extension set
extSet := ext.GetExtensionSet(&plan)
reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection)
reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError())
tests := []expr.Expression{
sampleNestedExpr(reg, substraitExtURI),
}
Expand All @@ -240,7 +240,7 @@ func TestExpressionsRoundtrip(t *testing.T) {
func ExampleExpression_Visit() {
const substraitExtURI = "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
var (
exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(&ext.DefaultCollection), substraitExtURI)
exp = sampleNestedExpr(expr.NewEmptyExtensionRegistry(ext.GetDefaultCollectionWithNoError()), substraitExtURI)
preVisit, postVisit expr.VisitFunc
)

Expand Down Expand Up @@ -347,7 +347,7 @@ func TestRoundTripUsingTestData(t *testing.T) {
require.NoError(t, err)
require.NoError(t, protojson.Unmarshal(raw, &protoSchema))
baseSchema := types.NewNamedStructFromProto(&protoSchema)
reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection)
reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError())
for _, tc := range tmp["cases"].([]any) {
tt := tc.(map[string]any)
t.Run(tt["name"].(string), func(t *testing.T) {
Expand Down Expand Up @@ -403,7 +403,7 @@ func TestRoundTripExtendedExpression(t *testing.T) {
var ex proto.ExtendedExpression
require.NoError(t, protojson.Unmarshal(buf.Bytes(), &ex))

result, err := expr.ExtendedFromProto(&ex, &ext.DefaultCollection)
result, err := expr.ExtendedFromProto(&ex, ext.GetDefaultCollectionWithNoError())
require.NoError(t, err)

out := result.ToProto()
Expand Down
80 changes: 51 additions & 29 deletions extensions/extension_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
package extensions

import (
"embed"
"fmt"
"io"
"io/fs"
"path"
"sort"
"sync"

"github.com/creasty/defaults"
"github.com/goccy/go-yaml"
substrait "github.com/substrait-io/substrait"
"github.com/substrait-io/substrait"
substraitgo "github.com/substrait-io/substrait-go/v3"
"github.com/substrait-io/substrait-go/v3/proto/extensions"
)
Expand All @@ -20,45 +22,65 @@ type AdvancedExtension = extensions.AdvancedExtension

const SubstraitDefaultURIPrefix = "https://github.com/substrait-io/substrait/blob/main/extensions/"

// DefaultCollection is loaded with the default Substrait extension
// definitions. Not all files are currently parsable.
// The parser needs to be enhanced to support all files.
var DefaultCollection Collection
var (
getDefaultCollectionOnce = sync.OnceValues[*Collection, error](loadDefaultCollection)
unsupportedExtensions = map[string]bool{
"unknown.yaml": true,
}
)

// GetDefaultCollectionWithNoError returns the Collection that is loaded with
// the default Substrait extension definitions. It panics if
// GetDefaultCollection reports an error.
//
// This variant is provided for the ease of use of legacy code. Please use
// GetDefaultCollection instead.
func GetDefaultCollectionWithNoError() *Collection {
	collection, loadErr := GetDefaultCollection()
	if loadErr == nil {
		return collection
	}
	panic(loadErr)
}

func init() {
// GetDefaultCollection returns the Collection that is loaded with the default
// Substrait extension definitions, together with any error produced while
// loading it. The result comes from getDefaultCollectionOnce.
func GetDefaultCollection() (*Collection, error) {
	collection, err := getDefaultCollectionOnce()
	return collection, err
}

func loadDefaultCollection() (*Collection, error) {
substraitFS := substrait.GetSubstraitExtensionsFS()
entries, err := substraitFS.ReadDir("extensions")
if err != nil {
return
return nil, err
}

var defaultCollection Collection
for _, ent := range entries {
f, err := substraitFS.Open(path.Join("extensions/", ent.Name()))
if err != nil {
panic(err)
err2 := loadExtensionFile(&defaultCollection, substraitFS, ent)
if err2 != nil {
return nil, err2
}
fileStat, err := f.Stat()
}
return &defaultCollection, nil
}

func loadExtensionFile(collection *Collection, substraitFS embed.FS, ent fs.DirEntry) error {
f, err := substraitFS.Open(path.Join("extensions/", ent.Name()))
if err != nil {
return err
}
defer func() {
_ = f.Close()
}()
fileStat, err := f.Stat()
if err != nil {
return err
}
fileName := path.Base(fileStat.Name())
if _, ok := unsupportedExtensions[fileName]; !ok {
err = collection.Load(SubstraitDefaultURIPrefix+ent.Name(), f)
if err != nil {
panic(err)
}
fileName := path.Base(fileStat.Name())
// Catch and ignore load error for a file
// Currently extension grammar is not fully parseable
// There is a parser fix planned, once that is done,
// we can panic instead of ignoring failed extension file load
defer func(f fs.File, fileName string) {
if r := recover(); r != nil {
fmt.Printf("Ignoring extension file:%s, Recovered from panic: %v\n", fileName, r)
}
if err1 := f.Close(); err1 != nil {
panic(err1)
}
}(f, fileName)
err1 := DefaultCollection.Load(SubstraitDefaultURIPrefix+ent.Name(), f)
if err1 != nil {
fmt.Printf("Ignoring extension file:%s err:%v, Skipping it \n", fileName, err1)
return err
}
}
return nil
}

// ID is the unique identifier for a substrait object
Expand Down
22 changes: 11 additions & 11 deletions extensions/extension_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,11 @@ func TestDefaultCollection(t *testing.T) {
)
switch tt.typ {
case scalarFunc:
variant, ok = extensions.DefaultCollection.GetScalarFunc(id)
variant, ok = extensions.GetDefaultCollectionWithNoError().GetScalarFunc(id)
case aggFunc:
variant, ok = extensions.DefaultCollection.GetAggregateFunc(id)
variant, ok = extensions.GetDefaultCollectionWithNoError().GetAggregateFunc(id)
case windowFunc:
variant, ok = extensions.DefaultCollection.GetWindowFunc(id)
variant, ok = extensions.GetDefaultCollectionWithNoError().GetWindowFunc(id)
}

require.True(t, ok)
Expand All @@ -295,17 +295,18 @@ func TestDefaultCollection(t *testing.T) {
})
}

et, ok := extensions.DefaultCollection.GetType(extensions.ID{
et, ok := extensions.GetDefaultCollectionWithNoError().GetType(extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml", Name: "point"})
assert.True(t, ok)
assert.Equal(t, "point", et.Name)
assert.Equal(t, map[string]interface{}{"latitude": "i32", "longitude": "i32"}, et.Structure)
}

func TestCollection_GetAllScalarFunctions(t *testing.T) {
scalarFunctions := extensions.DefaultCollection.GetAllScalarFunctions()
aggregateFunctions := extensions.DefaultCollection.GetAllAggregateFunctions()
windowFunctions := extensions.DefaultCollection.GetAllWindowFunctions()
defaultExtensions := extensions.GetDefaultCollectionWithNoError()
scalarFunctions := defaultExtensions.GetAllScalarFunctions()
aggregateFunctions := defaultExtensions.GetAllAggregateFunctions()
windowFunctions := defaultExtensions.GetAllWindowFunctions()
assert.GreaterOrEqual(t, len(scalarFunctions), 309)
assert.GreaterOrEqual(t, len(aggregateFunctions), 62)
assert.GreaterOrEqual(t, len(windowFunctions), 7)
Expand All @@ -323,21 +324,20 @@ func TestCollection_GetAllScalarFunctions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.signature, func(t *testing.T) {
assert.True(t, tt.isScalar || tt.isAggregate || tt.isWindow)
c := extensions.DefaultCollection
if tt.isScalar {
sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
sf, ok := defaultExtensions.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, scalarFunctions, sf)
// verify that default nullability is set to MIRROR
assert.Equal(t, extensions.MirrorNullability, sf.Nullability())
}
if tt.isAggregate {
af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
af, ok := defaultExtensions.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, aggregateFunctions, af)
}
if tt.isWindow {
wf, ok := c.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
wf, ok := defaultExtensions.GetWindowFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, windowFunctions, wf)
}
Expand Down
2 changes: 1 addition & 1 deletion extensions/variants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func TestMatchWithSyncParams(t *testing.T) {
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, testFileInfo.numTests)

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError())
for _, tc := range testFile.TestCases {
t.Run(tc.FuncName, func(t *testing.T) {
switch tc.FuncType {
Expand Down
2 changes: 1 addition & 1 deletion functions/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
var gFunctionRegistry FunctionRegistry

func TestMain(m *testing.M) {
gFunctionRegistry = NewFunctionRegistry(&extensions.DefaultCollection)
gFunctionRegistry = NewFunctionRegistry(extensions.GetDefaultCollectionWithNoError())
m.Run()
}

Expand Down
4 changes: 2 additions & 2 deletions functions/local_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ add(120::i8, 10::i8) [overflow:SILENT] = <!UNDEFINED>
assert.Len(t, testFile.TestCases, len(testResults))
require.GreaterOrEqual(t, len(testFile.TestCases), len(testResults))

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError())
for i, result := range testResults {
tc := testFile.TestCases[i]
t.Run(result.name, func(t *testing.T) {
Expand Down Expand Up @@ -220,7 +220,7 @@ sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.000000715255737304
testCases := append(testFile.TestCases, testFile1.TestCases...)
require.GreaterOrEqual(t, len(testCases), len(testResults))

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError())
for i, result := range testResults {
tc := testCases[i]
t.Run(result.name, func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion plan/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ type Builder interface {
const FETCH_COUNT_ALL_RECORDS = -1

func NewBuilderDefault() Builder {
return NewBuilder(&extensions.DefaultCollection)
return NewBuilder(extensions.GetDefaultCollectionWithNoError())
}

func NewBuilder(c *extensions.Collection) Builder {
Expand Down
2 changes: 1 addition & 1 deletion plan/internal/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestVirtualTableExpressionFromProto(t *testing.T) {
literal1 := expr.NewPrimitiveLiteral(int32(1), false)
expr1 := literal1.ToProto()

reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection)
reg := expr.NewExtensionRegistry(extSet, ext.GetDefaultCollectionWithNoError())
rows := &proto.Expression_Nested_Struct{Fields: []*proto.Expression{
expr1,
}}
Expand Down
6 changes: 3 additions & 3 deletions plan/plan_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestBasicEmitPlan(t *testing.T) {
protoPlan, err := p.ToProto()
require.NoError(t, err)

roundTrip, err := plan.FromProto(protoPlan, &extensions.DefaultCollection)
roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError())
require.NoError(t, err)

assert.Equal(t, p, roundTrip)
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestEmitEmptyPlan(t *testing.T) {
protoPlan, err := p.ToProto()
require.NoError(t, err)

roundTrip, err := plan.FromProto(protoPlan, &extensions.DefaultCollection)
roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError())
require.NoError(t, err)

assert.Equal(t, p, roundTrip)
Expand Down Expand Up @@ -169,7 +169,7 @@ func checkRoundTrip(t *testing.T, expectedJSON string, p *plan.Plan) {
assert.Truef(t, proto.Equal(&expectedProto, protoPlan), "JSON expected: %s\ngot: %s",
protojson.Format(&expectedProto), protojson.Format(protoPlan))

roundTrip, err := plan.FromProto(&expectedProto, &extensions.DefaultCollection)
roundTrip, err := plan.FromProto(&expectedProto, extensions.GetDefaultCollectionWithNoError())
require.NoError(t, err)

roundTripProto, err := roundTrip.ToProto()
Expand Down
2 changes: 1 addition & 1 deletion plan/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

func TestRelFromProto(t *testing.T) {

registry := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection)
registry := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())
literal5 := &proto.Expression_Literal{LiteralType: &proto.Expression_Literal_I64{I64: 5}}
exprLiteral5 := &proto.Expression{RexType: &proto.Expression_Literal_{Literal: literal5}}

Expand Down
4 changes: 2 additions & 2 deletions plan/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func createPrimitiveBool(value bool) expr.Expression {
}

func TestRelations_Copy(t *testing.T) {
extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection)
extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollectionWithNoError())
aggregateFnID := extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
Name: "avg",
Expand Down Expand Up @@ -414,7 +414,7 @@ func TestRelations_Copy(t *testing.T) {
}

func TestAggregateRelToBuilder(t *testing.T) {
extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection)
extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollectionWithNoError())
aggregateFnID := extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
Name: "avg",
Expand Down
8 changes: 4 additions & 4 deletions testcases/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ add(120::i8, 10::i8) [overflow:ERROR] = <!ERROR>
{&types.Int16Type{Nullability: types.NullabilityRequired}, &types.Int16Type{Nullability: types.NullabilityRequired}},
{&types.Int8Type{Nullability: types.NullabilityRequired}, &types.Int8Type{Nullability: types.NullabilityRequired}},
}
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError())
basicGroupDesc := "'Basic examples without any special cases'"
overflowGroupDesc := "Overflow examples demonstrating overflow behavior"
groupDescs := []string{basicGroupDesc, basicGroupDesc, overflowGroupDesc}
Expand Down Expand Up @@ -325,7 +325,7 @@ func TestParseAggregateFunc(t *testing.T) {
avg((1,2,3)::fp32) = 2::fp64
sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ERROR>`

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError())
arithUri := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
testFile, err := ParseTestCasesFromString(header + tests)
require.NoError(t, err)
Expand Down Expand Up @@ -544,7 +544,7 @@ LIST_AGG(t1.col0, ','::string) = 1::fp64
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, 1)
assert.Equal(t, AggregateFuncType, testFile.TestCases[0].FuncType)
reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection)
reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())
aggFun, err := testFile.TestCases[0].GetAggregateFunctionInvocation(&reg, nil)
require.NoError(t, err)
assert.Equal(t, "string_agg", aggFun.Name())
Expand Down Expand Up @@ -724,7 +724,7 @@ func TestLoadAllSubstraitTestFiles(t *testing.T) {
testFile, err := ParseTestCaseFileFromFS(got, filePath)
require.NoError(t, err)
require.NotNil(t, testFile)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError())
for _, tc := range testFile.TestCases {
testGetFunctionInvocation(t, tc, &reg, funcRegistry)
}
Expand Down

0 comments on commit 261dc94

Please sign in to comment.