Skip to content

Commit

Permalink
refactor WhereCache func to import performance
Browse files Browse the repository at this point in the history
Signed-off-by: Yan Zhu <[email protected]>
  • Loading branch information
halfcrazy committed Oct 7, 2023
1 parent 89850f8 commit 50f2203
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 227 deletions.
236 changes: 137 additions & 99 deletions cache/cache_test.go

Large diffs are not rendered by default.

64 changes: 23 additions & 41 deletions client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ type API interface {
// If it has a capacity != 0, only 'capacity' elements will be filled in
List(ctx context.Context, result interface{}) error

// Create a Conditional API from a Function that is used to filter cached data
// WhereCache creates a Conditional API from a Function that is used to filter cached data
// The function must accept a Model implementation and return a boolean. E.g:
// ConditionFromFunc(func(l *LogicalSwitch) bool { return l.Enabled })
WhereCache(predicate interface{}) ConditionalAPI
// WhereCache(&LogicalSwitchPort{}, func(m model.Model) bool { l := m.(*LogicalSwitchPort); return l.Enabled })
WhereCache(model.Model, func(model.Model) bool) ConditionalAPI

// Create a ConditionalAPI from a Model's index data, where operations
// apply to elements that match the values provided in one or more
Expand Down Expand Up @@ -120,13 +120,20 @@ func (a api) List(ctx context.Context, result interface{}) error {
// structs
var appendValue func(reflect.Value)
var m model.Model
var ok bool
if resultVal.Type().Elem().Kind() == reflect.Ptr {
m = reflect.New(resultVal.Type().Elem().Elem()).Interface()
m, ok = reflect.New(resultVal.Type().Elem().Elem()).Interface().(model.Model)
if !ok {
return &ErrWrongType{resultPtr.Type(), "Expected pointer to slice of valid Models"}
}
appendValue = func(v reflect.Value) {
resultVal.Set(reflect.Append(resultVal, v))
}
} else {
m = reflect.New(resultVal.Type().Elem()).Interface()
m, ok = reflect.New(resultVal.Type().Elem()).Interface().(model.Model)
if !ok {
return &ErrWrongType{resultPtr.Type(), "Expected pointer to slice of valid Models"}
}
appendValue = func(v reflect.Value) {
resultVal.Set(reflect.Append(resultVal, reflect.Indirect(v)))
}
Expand Down Expand Up @@ -194,14 +201,17 @@ func (a api) WhereAll(m model.Model, cond ...model.Condition) ConditionalAPI {
}

// WhereCache returns a conditionalAPI based a Predicate
func (a api) WhereCache(predicate interface{}) ConditionalAPI {
return newConditionalAPI(a.cache, a.conditionFromFunc(predicate), a.logger)
func (a api) WhereCache(m model.Model, predicate func(model.Model) bool) ConditionalAPI {
return newConditionalAPI(a.cache, a.conditionFromFunc(m, predicate), a.logger)
}

// Conditional interface implementation
// FromFunc returns a Condition from a function
func (a api) conditionFromFunc(predicate interface{}) Conditional {
table, err := a.getTableFromFunc(predicate)
func (a api) conditionFromFunc(m model.Model, predicate func(model.Model) bool) Conditional {
if predicate == nil {
return newErrorConditional(fmt.Errorf("expect predicate as a function, got nil"))
}
table, err := a.getTableFromModel(m)
if err != nil {
return newErrorConditional(err)
}
Expand Down Expand Up @@ -538,43 +548,15 @@ func (a api) Wait(untilConFun ovsdb.WaitCondition, timeout *int, model model.Mod

// getTableFromModel returns the table name from a Model object after performing
// type verifications on the model
func (a api) getTableFromModel(m interface{}) (string, error) {
if _, ok := m.(model.Model); !ok {
return "", &ErrWrongType{reflect.TypeOf(m), "Type does not implement Model interface"}
}
table := a.cache.DatabaseModel().FindTable(reflect.TypeOf(m))
if table == "" {
func (a api) getTableFromModel(m model.Model) (string, error) {
table := m.Table()
_, found := a.cache.DatabaseModel().Types()[table]
if table == "" || !found {
return "", &ErrWrongType{reflect.TypeOf(m), "Model not found in Database Model"}
}
return table, nil
}

// getTableFromModel returns the table name from a the predicate after performing
// type verifications
func (a api) getTableFromFunc(predicate interface{}) (string, error) {
predType := reflect.TypeOf(predicate)
if predType == nil || predType.Kind() != reflect.Func {
return "", &ErrWrongType{predType, "Expected function"}
}
if predType.NumIn() != 1 || predType.NumOut() != 1 || predType.Out(0).Kind() != reflect.Bool {
return "", &ErrWrongType{predType, "Expected func(Model) bool"}
}

modelInterface := reflect.TypeOf((*model.Model)(nil)).Elem()
modelType := predType.In(0)
if !modelType.Implements(modelInterface) {
return "", &ErrWrongType{predType,
fmt.Sprintf("Type %s does not implement Model interface", modelType.String())}
}

table := a.cache.DatabaseModel().FindTable(modelType)
if table == "" {
return "", &ErrWrongType{predType,
fmt.Sprintf("Model %s not found in Database Model", modelType.String())}
}
return table, nil
}

// newAPI returns a new API to interact with the database
func newAPI(cache *cache.TableCache, logger *logr.Logger) API {
return api{
Expand Down
98 changes: 40 additions & 58 deletions client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,52 +217,51 @@ func TestAPIListPredicate(t *testing.T) {

test := []struct {
name string
predicate interface{}
m model.Model
predicate func(model.Model) bool
content []model.Model
err bool
}{
{
name: "none",
predicate: func(t *testLogicalSwitch) bool {
m: &testLogicalSwitch{},
predicate: func(t model.Model) bool {
return false
},
content: []model.Model{},
err: false,
},
{
name: "all",
predicate: func(t *testLogicalSwitch) bool {
m: &testLogicalSwitch{},
predicate: func(t model.Model) bool {
return true
},
content: lscacheList,
err: false,
},
{
name: "nil function must fail",
m: &testLogicalSwitch{},
err: true,
},
{
name: "arbitrary condition",
predicate: func(t *testLogicalSwitch) bool {
m: &testLogicalSwitch{},
predicate: func(m model.Model) bool {
t := m.(*testLogicalSwitch)
return strings.HasPrefix(t.Name, "magic")
},
content: []model.Model{lscacheList[1], lscacheList[3]},
err: false,
},
{
name: "error wrong type",
predicate: func(t testLogicalSwitch) string {
return "foo"
},
err: true,
},
}

for _, tt := range test {
t.Run(fmt.Sprintf("ApiListPredicate: %s", tt.name), func(t *testing.T) {
var result []*testLogicalSwitch
api := newAPI(tcache, &discardLogger)
cond := api.WhereCache(tt.predicate)
cond := api.WhereCache(tt.m, tt.predicate)
err := cond.List(context.Background(), &result)
if tt.err {
assert.NotNil(t, err)
Expand Down Expand Up @@ -536,26 +535,14 @@ func TestAPIListMulti(t *testing.T) {
func TestConditionFromFunc(t *testing.T) {
test := []struct {
name string
arg interface{}
m model.Model
arg func(model.Model) bool
err bool
}{
{
name: "wrong function must fail",
arg: func(s string) bool {
return false
},
err: true,
},
{
name: "wrong function must fail2 ",
arg: func(t *testLogicalSwitch) string {
return "foo"
},
err: true,
},
{
name: "correct func should succeed",
arg: func(t *testLogicalSwitch) bool {
m: &testLogicalSwitch{},
arg: func(m model.Model) bool {
return true
},
err: false,
Expand All @@ -566,7 +553,7 @@ func TestConditionFromFunc(t *testing.T) {
t.Run(fmt.Sprintf("conditionFromFunc: %s", tt.name), func(t *testing.T) {
cache := apiTestCache(t, nil)
apiIface := newAPI(cache, &discardLogger)
condition := apiIface.(api).conditionFromFunc(tt.arg)
condition := apiIface.(api).conditionFromFunc(tt.m, tt.arg)
if tt.err {
assert.IsType(t, &errorConditional{}, condition)
} else {
Expand All @@ -584,23 +571,6 @@ func TestConditionFromModel(t *testing.T) {
conds []model.Condition
err bool
}{
{
name: "wrong model must fail",
models: []model.Model{
&struct{ a string }{},
},
err: true,
},
{
name: "wrong condition must fail",
models: []model.Model{
&struct {
a string `ovsdb:"_uuid"`
}{},
},
conds: []model.Condition{{Field: "foo"}},
err: true,
},
{
name: "correct model must succeed",
models: []model.Model{
Expand Down Expand Up @@ -1009,7 +979,8 @@ func TestAPIMutate(t *testing.T) {
{
name: "select single by predicate name insert element in map",
condition: func(a API) ConditionalAPI {
return a.WhereCache(func(lsp *testLogicalSwitchPort) bool {
return a.WhereCache(&testLogicalSwitchPort{}, func(m model.Model) bool {
lsp := m.(*testLogicalSwitchPort)
return lsp.Name == "lsp2"
})
},
Expand All @@ -1033,7 +1004,8 @@ func TestAPIMutate(t *testing.T) {
{
name: "select many by predicate name insert element in map",
condition: func(a API) ConditionalAPI {
return a.WhereCache(func(lsp *testLogicalSwitchPort) bool {
return a.WhereCache(&testLogicalSwitchPort{}, func(m model.Model) bool {
lsp := m.(*testLogicalSwitchPort)
return lsp.Type == "someType"
})
},
Expand Down Expand Up @@ -1063,7 +1035,8 @@ func TestAPIMutate(t *testing.T) {
{
name: "No mutations should error",
condition: func(a API) ConditionalAPI {
return a.WhereCache(func(lsp *testLogicalSwitchPort) bool {
return a.WhereCache(&testLogicalSwitchPort{}, func(m model.Model) bool {
lsp := m.(*testLogicalSwitchPort)
return lsp.Type == "someType"
})
},
Expand Down Expand Up @@ -1411,7 +1384,8 @@ func TestAPIUpdate(t *testing.T) {
{
name: "select multiple by predicate change multiple field",
condition: func(a API) ConditionalAPI {
return a.WhereCache(func(t *testLogicalSwitchPort) bool {
return a.WhereCache(&testLogicalSwitchPort{}, func(m model.Model) bool {
t := m.(*testLogicalSwitchPort)
return t.Enabled != nil && *t.Enabled == true
})
},
Expand Down Expand Up @@ -1645,7 +1619,8 @@ func TestAPIDelete(t *testing.T) {
{
name: "select multiple by predicate",
condition: func(a API) ConditionalAPI {
return a.WhereCache(func(t *testLogicalSwitchPort) bool {
return a.WhereCache(&testLogicalSwitchPort{}, func(m model.Model) bool {
t := m.(*testLogicalSwitchPort)
return t.Enabled != nil && *t.Enabled == true
})
},
Expand Down Expand Up @@ -1724,29 +1699,36 @@ func BenchmarkAPIList(b *testing.B) {

test := []struct {
name string
predicate interface{}
m model.Model
predicate func(model.Model) bool
}{
{
name: "predicate returns none",
predicate: func(t *testLogicalSwitchPort) bool {
m: &testLogicalSwitchPort{},
predicate: func(t model.Model) bool {
return false
},
},
{
name: "predicate returns all",
predicate: func(t *testLogicalSwitchPort) bool {
m: &testLogicalSwitchPort{},
predicate: func(t model.Model) bool {
return true
},
},
{
name: "predicate on an arbitrary condition",
predicate: func(t *testLogicalSwitchPort) bool {
m: &testLogicalSwitchPort{},
predicate: func(m model.Model) bool {
t := m.(*testLogicalSwitchPort)
return strings.HasPrefix(t.Name, "ls1")
},
},
{
name: "predicate matches name",
predicate: func(t *testLogicalSwitchPort) bool {
m: &testLogicalSwitchPort{},
predicate: func(m model.Model) bool {
t := m.(*testLogicalSwitchPort)
return t.Name == lscacheList[index].Name
},
},
Expand All @@ -1763,7 +1745,7 @@ func BenchmarkAPIList(b *testing.B) {
api := newAPI(tcache, &discardLogger)
var cond ConditionalAPI
if tt.predicate != nil {
cond = api.WhereCache(tt.predicate)
cond = api.WhereCache(tt.m, tt.predicate)
} else {
cond = api.Where(lscacheList[index])
}
Expand Down Expand Up @@ -1979,7 +1961,7 @@ func TestAPIWait(t *testing.T) {
{
name: "no operation",
condition: func(a API) ConditionalAPI {
return a.WhereCache(func(t *testLogicalSwitchPort) bool { return false })
return a.WhereCache(&testLogicalSwitchPort{}, func(t model.Model) bool { return false })
},
until: "==",
prepare: func() (model.Model, []interface{}) {
Expand Down
4 changes: 2 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,6 @@ func (o *ovsdbClient) WhereAll(m model.Model, conditions ...model.Condition) Con
}

// WhereCache implements the API interface's WhereCache function
func (o *ovsdbClient) WhereCache(predicate interface{}) ConditionalAPI {
return o.primaryDB().api.WhereCache(predicate)
func (o *ovsdbClient) WhereCache(m model.Model, predicate func(model.Model) bool) ConditionalAPI {
return o.primaryDB().api.WhereCache(m, predicate)
}
8 changes: 8 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ type Bridge struct {
STPEnable bool `ovsdb:"stp_enable"`
}

func (b *Bridge) Table() string {
return "Bridge"
}

// OpenvSwitch defines an object in Open_vSwitch table
type OpenvSwitch struct {
UUID string `ovsdb:"_uuid"`
Expand All @@ -103,6 +107,10 @@ type OpenvSwitch struct {
SystemVersion *string `ovsdb:"system_version"`
}

func (o *OpenvSwitch) Table() string {
return "Open_vSwitch"
}

var defDB, _ = model.NewClientDBModel("Open_vSwitch",
map[string]model.Model{
"Open_vSwitch": &OpenvSwitch{},
Expand Down
Loading

0 comments on commit 50f2203

Please sign in to comment.