From 1853871bdd13110b7a8c84584103efe01f003e62 Mon Sep 17 00:00:00 2001 From: Cool Developer Date: Fri, 8 Nov 2024 10:26:20 -0500 Subject: [PATCH] fix the pair schema codec --- collections/indexing.go | 3 +++ collections/pair.go | 44 ++++++++++++++++++++++++++++++++++++-- indexer/postgres/params.go | 7 +++--- types/collections.go | 38 ++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 5 deletions(-) diff --git a/collections/indexing.go b/collections/indexing.go index bb039e7be256..fe2315441999 100644 --- a/collections/indexing.go +++ b/collections/indexing.go @@ -150,6 +150,9 @@ func (c collectionImpl[K, V]) schemaCodec() (*collectionSchemaCodec, error) { if err != nil { return nil, err } + if valueDecoder.ToSchemaType == nil { + return x, nil + } return valueDecoder.ToSchemaType(x) } ensureFieldNames(c.m.vc, "value", res.objectType.ValueFields) diff --git a/collections/pair.go b/collections/pair.go index 955cfe3d22b7..9f1945ec32fc 100644 --- a/collections/pair.go +++ b/collections/pair.go @@ -245,17 +245,43 @@ func (p pairKeyCodec[K1, K2]) SchemaCodec() (codec.SchemaCodec[Pair[K1, K2]], er return codec.SchemaCodec[Pair[K1, K2]]{}, fmt.Errorf("error getting key2 field: %w", err) } + codec1, err := codec.KeySchemaCodec(p.keyCodec1) + if err != nil { + return codec.SchemaCodec[Pair[K1, K2]]{}, fmt.Errorf("error getting key1 schema codec: %w", err) + } + + codec2, err := codec.KeySchemaCodec(p.keyCodec2) + if err != nil { + return codec.SchemaCodec[Pair[K1, K2]]{}, fmt.Errorf("error getting key2 schema codec: %w", err) + } + return codec.SchemaCodec[Pair[K1, K2]]{ Fields: []schema.Field{field1, field2}, ToSchemaType: func(pair Pair[K1, K2]) (any, error) { - return []interface{}{pair.K1(), pair.K2()}, nil + k1, err := toKeySchemaType(codec1, pair.K1()) + if err != nil { + return nil, err + } + k2, err := toKeySchemaType(codec2, pair.K2()) + if err != nil { + return nil, err + } + return []interface{}{k1, k2}, nil }, FromSchemaType: func(a any) (Pair[K1, K2], error) { aSlice, ok := a.([]interface{}) if !ok || len(aSlice) != 2 { return Pair[K1, K2]{}, fmt.Errorf("expected slice of length 2, got %T", a) } - return Join(aSlice[0].(K1), aSlice[1].(K2)), nil + k1, err := fromKeySchemaType(codec1, aSlice[0]) + if err != nil { + return Pair[K1, K2]{}, err + } + k2, err := fromKeySchemaType(codec2, aSlice[1]) + if err != nil { + return Pair[K1, K2]{}, err + } + return Join(k1, k2), nil }, }, nil } @@ -273,6 +299,20 @@ func getNamedKeyField[T any](keyCdc codec.KeyCodec[T], name string) (schema.Fiel return field, nil } +func toKeySchemaType[T any](cdc codec.SchemaCodec[T], key T) (any, error) { + if cdc.ToSchemaType != nil { + return cdc.ToSchemaType(key) + } + return key, nil +} + +func fromKeySchemaType[T any](cdc codec.SchemaCodec[T], key any) (T, error) { + if cdc.FromSchemaType != nil { + return cdc.FromSchemaType(key) + } + return key.(T), nil +} + // NewPrefixUntilPairRange defines a collection query which ranges until the provided Pair prefix. // Unstable: this API might change in the future. func NewPrefixUntilPairRange[K1, K2 any](prefix K1) *PairRange[K1, K2] { diff --git a/indexer/postgres/params.go b/indexer/postgres/params.go index ea7a1d486ea8..2aedc127fc4a 100644 --- a/indexer/postgres/params.go +++ b/indexer/postgres/params.go @@ -108,10 +108,11 @@ func (tm *objectIndexer) bindParam(field schema.Field, value interface{}) (param param = int64(t) } else if field.Kind == schema.AddressKind { - param, err = tm.options.addressCodec.BytesToString(value.([]byte)) - if err != nil { - return nil, fmt.Errorf("address encoding failed for field %q: %w", field.Name, err) + t, ok := value.(string) + if !ok { + return nil, fmt.Errorf("expected string value for field %q, got %T", field.Name, value) } + param = t } return } diff --git a/types/collections.go b/types/collections.go index df1e27617abb..e4f62411fd73 100644 --- a/types/collections.go +++ b/types/collections.go @@ -9,6 +9,7 @@ import ( "cosmossdk.io/collections" collcodec "cosmossdk.io/collections/codec" "cosmossdk.io/math" + "cosmossdk.io/schema" ) var ( @@ -120,6 +121,23 @@ func (a genericAddressKey[T]) SizeNonTerminal(key T) int { return collections.BytesKey.SizeNonTerminal(key) } +func (a genericAddressKey[T]) SchemaCodec() (collcodec.SchemaCodec[T], error) { + return collcodec.SchemaCodec[T]{ + Fields: []schema.Field{{Kind: schema.AddressKind}}, + ToSchemaType: func(t T) (any, error) { + return t.String(), nil + }, + FromSchemaType: func(s any) (T, error) { + var t T + sz, ok := s.(string) + if !ok { + return t, fmt.Errorf("expected string, got %T", s) + } + return a.stringDecoder(sz) + }, + }, nil +} + // Deprecated: lengthPrefixedAddressKey is a special key codec used to retain state backwards compatibility // when a generic address key (be: AccAddress, ValAddress, ConsAddress), is used as an index key. // More docs can be found in the LengthPrefixedAddressKey function. @@ -214,6 +232,26 @@ func (i intValueCodec) ValueType() string { return Int } +func (i intValueCodec) SchemaCodec() (collcodec.SchemaCodec[math.Int], error) { + return collcodec.SchemaCodec[math.Int]{ + Fields: []schema.Field{{Kind: schema.IntegerKind}}, + ToSchemaType: func(t math.Int) (any, error) { + return t.String(), nil + }, + FromSchemaType: func(s any) (math.Int, error) { + sz, ok := s.(string) + if !ok { + return math.Int{}, fmt.Errorf("expected string, got %T", s) + } + t, ok := math.NewIntFromString(sz) + if !ok { + return math.Int{}, fmt.Errorf("failed to parse Int from string: %s", sz) + } + return t, nil + }, + }, nil +} + type uintValueCodec struct{} func (i uintValueCodec) Encode(value math.Uint) ([]byte, error) {