Skip to content

Commit

Permalink
Fix alias typing and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nolag committed Sep 20, 2024
1 parent 34e8551 commit 6e11a49
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 64 deletions.
6 changes: 4 additions & 2 deletions pkg/values/big_int.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package values

import (
"errors"
"fmt"
"math/big"
"reflect"
Expand Down Expand Up @@ -31,9 +30,10 @@ func (b *BigInt) Unwrap() (any, error) {

func (b *BigInt) UnwrapTo(to any) error {
if b == nil || b.Underlying == nil {
return errors.New("could not unwrap nil values.BigInt")
return fmt.Errorf("could not unwrap nil")
}

// check any here because unwrap to will make the *any point to a big.Int instead of *big.Int
switch tb := to.(type) {
case *big.Int:
if tb == nil {
Expand All @@ -44,7 +44,9 @@ func (b *BigInt) UnwrapTo(to any) error {
if tb == nil {
return fmt.Errorf("cannot unwrap to nil pointer")
}

*tb = b.Underlying
return nil
default:
rto := reflect.ValueOf(to)
if rto.CanConvert(reflect.TypeOf(new(big.Int))) {
Expand Down
47 changes: 34 additions & 13 deletions pkg/values/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,24 @@ func Wrap(v any) (Value, error) {
return NewDecimal(tv), nil
case int64:
return NewInt64(tv), nil
case int32:
return NewInt64(int64(tv)), nil
case int16:
return NewInt64(int64(tv)), nil
case int8:
return NewInt64(int64(tv)), nil
case int:
return NewInt64(int64(tv)), nil
case uint64:
return NewInt64(int64(tv)), nil
case uint:
return NewInt64(int64(tv)), nil
case uint32:
return NewInt64(int64(tv)), nil
case uint16:
return NewInt64(int64(tv)), nil
case uint8:
return NewInt64(int64(tv)), nil
case uint:
return NewInt64(int64(tv)), nil
case *big.Int:
return NewBigInt(tv), nil
case nil:
Expand Down Expand Up @@ -103,20 +113,18 @@ func Wrap(v any) (Value, error) {
return NewMap(m)
// Better complex type support for slices
case reflect.Slice:
s := make([]any, val.Len())
for i := 0; i < val.Len(); i++ {
item := val.Index(i).Interface()
s[i] = item
if val.Type().Elem().Kind() == reflect.Uint8 {
return NewBytes(val.Bytes()), nil
}
return NewList(s)
return createListFromSlice(val)
case reflect.Struct:
return CreateMapFromStruct(v)
case reflect.Pointer:
if reflect.Indirect(reflect.ValueOf(v)).Kind() == reflect.Struct {
return CreateMapFromStruct(reflect.Indirect(reflect.ValueOf(v)).Interface())
}
// pointer can't be null or the switch statement above would catch it.
return Wrap(val.Elem().Interface())
case reflect.String:
return Wrap(val.Convert(reflect.TypeOf("")).Interface())

case reflect.Bool:
return Wrap(val.Convert(reflect.TypeOf(true)).Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
Expand All @@ -126,6 +134,15 @@ func Wrap(v any) (Value, error) {
return nil, fmt.Errorf("could not wrap into value: %+v", v)
}

func createListFromSlice(val reflect.Value) (Value, error) {
s := make([]any, val.Len())
for i := 0; i < val.Len(); i++ {
item := val.Index(i).Interface()
s[i] = item
}
return NewList(s)
}

func WrapMap(a any) (*Map, error) {
v, err := Wrap(a)
if err != nil {
Expand Down Expand Up @@ -277,13 +294,17 @@ func unwrapTo[T any](underlying T, to any) error {
// eg: type FeedId string allows verification of FeedId's shape while unmarshalling
rTo := reflect.ValueOf(to)
rUnderlying := reflect.ValueOf(underlying)
underlyingPtr := reflect.PointerTo(rUnderlying.Type())
if rTo.Kind() != reflect.Pointer {
return fmt.Errorf("cannot unwrap to value of type: %T", to)
}

if rTo.CanConvert(underlyingPtr) {
reflect.Indirect(rTo.Convert(underlyingPtr)).Set(rUnderlying)
if rUnderlying.Type().ConvertibleTo(rTo.Type().Elem()) {
// special case, don't unwrap bytes to string and vice versa
if rUnderlying.Kind() != rTo.Type().Elem().Kind() {
return fmt.Errorf("cannot unwrap to value of type: %T", to)
}

reflect.Indirect(rTo).Set(rUnderlying.Convert(rTo.Type().Elem()))
return nil
}

Expand Down
163 changes: 114 additions & 49 deletions pkg/values/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package values
import (
"math"
"math/big"
"reflect"
"testing"

"github.com/go-viper/mapstructure/v2"
Expand Down Expand Up @@ -228,6 +229,41 @@ func Test_Value(t *testing.T) {
}
}

func Test_WrapPointers(t *testing.T) {
underlying := "foo"
actual, err := Wrap(&underlying)
require.NoError(t, err)

expected, err := Wrap("foo")
require.NoError(t, err)

assert.True(t, reflect.DeepEqual(expected, actual))
}

func Test_IntTypes(t *testing.T) {
anyValue := int64(100)
testCases := []struct {
name string
test func(tt *testing.T)
}{
{name: "int32", test: func(tt *testing.T) { wrappableTest[int64, int32](tt, anyValue) }},
{name: "int16", test: func(tt *testing.T) { wrappableTest[int64, int16](tt, anyValue) }},
{name: "int8", test: func(tt *testing.T) { wrappableTest[int64, int8](tt, anyValue) }},
{name: "int", test: func(tt *testing.T) { wrappableTest[int64, int](tt, anyValue) }},
{name: "uint64", test: func(tt *testing.T) { wrappableTest[int64, uint64](tt, anyValue) }},
{name: "uint32", test: func(tt *testing.T) { wrappableTest[int64, uint32](tt, anyValue) }},
{name: "uint16", test: func(tt *testing.T) { wrappableTest[int64, uint16](tt, anyValue) }},
{name: "uint8", test: func(tt *testing.T) { wrappableTest[int64, uint8](tt, anyValue) }},
{name: "uint", test: func(tt *testing.T) { wrappableTest[int64, uint](tt, anyValue) }},
}

for _, tc := range testCases {
t.Run(tc.name, func(st *testing.T) {

})
}
}

func Test_StructWrapUnwrap(t *testing.T) {
// TODO: https://smartcontract-it.atlassian.net/browse/KS-439 decimal.Decimal is broken when encoded.
type sStruct struct {
Expand Down Expand Up @@ -383,81 +419,110 @@ func Test_Copy(t *testing.T) {
}
}

type aliasByte []byte
type aliasBytes []byte
type aliasString string
type aliasInt int
type aliasMap map[string]any
type aliasSingleByte uint8
type aliasByte uint8
type decimalAlias decimal.Decimal
type bigIntAlias big.Int
type bigIntPtrAlias *big.Int

func Test_Aliases(t *testing.T) {
testCases := []struct {
name string
val func() any
alias func() any
convert func(any) any
name string
test func(t *testing.T)
}{
{
name: "alias to []byte",
val: func() any { return []byte("string") },
alias: func() any { return aliasByte([]byte{}) },
name: "[]byte alias",
test: func(t *testing.T) { wrappableTest[[]byte, aliasBytes](t, []byte("hello")) },
},
{
name: "byte alias in slice",
test: func(tt *testing.T) {
wrappableTestWithConversion[[]byte, []aliasByte](tt, []byte("hello"), func(native []aliasByte) []byte {
converted := make([]byte, len(native))
for i, b := range native {
converted[i] = byte(b)
}
return converted
})
},
},
{
name: "simple aliases",
val: func() any { return "string" },
alias: func() any { return aliasString("") },
name: "basic alias",
test: func(tt *testing.T) { wrappableTest[string, aliasString](tt, "hello") },
},
{
name: "aliasByte -> []byte",
val: func() any { return []byte("string") },
alias: func() any { return aliasByte([]byte{}) },
name: "integer",
test: func(tt *testing.T) { wrappableTest[int, aliasInt](tt, 1) },
},
{
name: "[]aliasSingleByte -> []byte",
val: func() any { return []byte("string") },
alias: func() any { return []aliasSingleByte{} },
name: "map",
test: func(tt *testing.T) { wrappableTest[map[string]any, aliasMap](tt, map[string]any{"hello": "world"}) },
},
{
name: "int",
val: func() any { return 2 },
alias: func() any { return aliasInt(0) },
convert: func(a any) any { return int(a.(int64)) },
name: "decimal alias",
test: func(tt *testing.T) { wrappableTest[decimal.Decimal, decimalAlias](tt, decimal.NewFromFloat(1.0)) },
},
{
name: "[][]byte -> []aliasByte",
val: func() any { return [][]byte{[]byte("hello")} },
alias: func() any { return []aliasByte{} },
convert: func(a any) any {
to := [][]byte{}
for _, v := range a.([]interface{}) {
to = append(to, v.([]byte))
}

return to
name: "big int alias",
test: func(tt *testing.T) {
testBigIntType[*bigIntAlias](tt, big.NewInt(1), func() *bigIntAlias {
return new(bigIntAlias)
})
},
},
{
name: "aliasMap -> map[string]any",
val: func() any { return map[string]any{} },
alias: func() any { return aliasMap{} },
convert: func(a any) any { return map[string]any(a.(aliasMap)) },
name: "big int pointer alias",
test: func(tt *testing.T) {
testBigIntType[bigIntPtrAlias](tt, big.NewInt(1), func() bigIntPtrAlias {
return new(big.Int)
})
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(st *testing.T) {
v := tc.val()
wv, err := Wrap(v)
require.NoError(t, err)

a := tc.alias()
err = wv.UnwrapTo(&a)
require.NoError(t, err)

if tc.convert != nil {
assert.Equal(t, tc.convert(a), v)
} else {
assert.Equal(t, a, v)
}
tc.test(st)
})
}
}

func wrappableTest[Native, Alias any](t *testing.T, native Native) {
wrappableTestWithConversion(t, native, func(alias Alias) Native {
return reflect.ValueOf(alias).Convert(reflect.TypeOf(native)).Interface().(Native)
})
}

func testBigIntType[Alias any](t *testing.T, native *big.Int, create func() Alias) {
wv, err := Wrap(native)
require.NoError(t, err)

a := create()

err = wv.UnwrapTo(a)
require.NoError(t, err)

assert.Equal(t, native, reflect.ValueOf(a).Convert(reflect.TypeOf(native)).Interface())

aliasWrapped, err := Wrap(a)
require.NoError(t, err)
assert.Equal(t, wv, aliasWrapped)
}

func wrappableTestWithConversion[Native, Alias any](t *testing.T, native Native, convert func(native Alias) Native) {
wv, err := Wrap(native)
require.NoError(t, err)

var a Alias
err = wv.UnwrapTo(&a)
require.NoError(t, err)

assert.Equal(t, native, convert(a))

aliasWrapped, err := Wrap(a)
require.NoError(t, err)
assert.Equal(t, wv, aliasWrapped)
}

0 comments on commit 6e11a49

Please sign in to comment.