Skip to content

Commit

Permalink
Fix alias typing and tests (#788)
Browse files Browse the repository at this point in the history
* Fix alias typing and tests

* Fix ints

* errors.new instead of fmt

* Add array support to slice (#789)
  • Loading branch information
nolag authored Sep 23, 2024
1 parent 5d12585 commit 26df9ab
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 97 deletions.
9 changes: 6 additions & 3 deletions pkg/values/big_int.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,23 @@ func (b *BigInt) Unwrap() (any, error) {

func (b *BigInt) UnwrapTo(to any) error {
if b == nil || b.Underlying == nil {
return errors.New("could not unwrap nil values.BigInt")
return fmt.Errorf("could not unwrap nil")
}

// check any here because unwrap to will make the *any point to a big.Int instead of *big.Int
switch tb := to.(type) {
case *big.Int:
if tb == nil {
return fmt.Errorf("cannot unwrap to nil pointer")
return errors.New("cannot unwrap to nil pointer")
}
*tb = *b.Underlying
case *any:
if tb == nil {
return fmt.Errorf("cannot unwrap to nil pointer")
return errors.New("cannot unwrap to nil pointer")
}

*tb = b.Underlying
return nil
default:
rto := reflect.ValueOf(to)
if rto.CanConvert(reflect.TypeOf(new(big.Int))) {
Expand Down
122 changes: 106 additions & 16 deletions pkg/values/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package values
import (
"errors"
"fmt"
"math"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/values/pb"
Expand Down Expand Up @@ -41,27 +42,116 @@ func (i *Int64) UnwrapTo(to any) error {
return fmt.Errorf("cannot unwrap to nil pointer: %+v", to)
}

if reflect.ValueOf(to).Kind() != reflect.Pointer {
return fmt.Errorf("cannot unwrap to non-pointer value: %+v", to)
}
switch tv := to.(type) {
case *int64:
*tv = i.Underlying
return nil
case *int32:
if err := verifyBounds(math.MinInt32, math.MaxInt32, i.Underlying, "int32"); err != nil {
return err
}

*tv = int32(i.Underlying)
return nil
case *int16:
if err := verifyBounds(math.MinInt16, math.MaxInt16, i.Underlying, "int16"); err != nil {
return err
}

*tv = int16(i.Underlying)
return nil
case *int8:
if err := verifyBounds(math.MinInt8, math.MaxInt8, i.Underlying, "int8"); err != nil {
return err
}

rToVal := reflect.Indirect(reflect.ValueOf(to))
switch rToVal.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if rToVal.OverflowInt(i.Underlying) {
return fmt.Errorf("cannot unwrap int64 to %T: overflow", to)
*tv = int8(i.Underlying)
return nil
case *int:
if err := verifyBounds(math.MinInt, math.MaxInt, i.Underlying, "int"); err != nil {
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:

*tv = int(i.Underlying)
return nil
case *uint64:
if i.Underlying < 0 {
return fmt.Errorf("cannot unwrap int64 to %T: underflow", to)
return fmt.Errorf("value %d is too small for uint64", i.Underlying)
}

*tv = uint64(i.Underlying)
return nil
case *uint32:
if err := verifyBounds(0, math.MaxUint32, i.Underlying, "uint32"); err != nil {
return err
}

*tv = uint32(i.Underlying)
return nil
case *uint16:
if err := verifyBounds(0, math.MaxUint16, i.Underlying, "uint16"); err != nil {
return err
}

*tv = uint16(i.Underlying)
return nil
case *uint8:
if err := verifyBounds(0, math.MaxUint8, i.Underlying, "uint8"); err != nil {
return err
}

*tv = uint8(i.Underlying)
return nil
case *uint:
if math.MaxUint == math.MaxUint64 {
if i.Underlying < 0 {
return fmt.Errorf("value %d is too small for uint64", i.Underlying)
}
}
if rToVal.OverflowUint(uint64(i.Underlying)) {
return fmt.Errorf("cannot unwrap int64 to %T: overflow", to)

*tv = uint(i.Underlying)
return nil
case *any:
*tv = i.Underlying
return nil
}

rv := reflect.ValueOf(to)
if rv.Kind() == reflect.Ptr {
switch rv.Elem().Kind() {
case reflect.Int64:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int64(0)))).Interface())
case reflect.Int32:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int32(0)))).Interface())
case reflect.Int16:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int16(0)))).Interface())
case reflect.Int8:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int8(0)))).Interface())
case reflect.Int:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(0))).Interface())
case reflect.Uint64:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint64(0)))).Interface())
case reflect.Uint32:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint32(0)))).Interface())
case reflect.Uint16:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint16(0)))).Interface())
case reflect.Uint8:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint8(0)))).Interface())
case reflect.Uint:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint(0)))).Interface())
default:
// fall through to the error, default is required by lint
}
case reflect.Interface:
default:
return fmt.Errorf("cannot unwrap to type %T", to)
}

return unwrapTo(i.Underlying, to)
return fmt.Errorf("cannot unwrap to type %T", to)
}

func verifyBounds(min, max, value int64, tpe string) error {
if value < min {
return fmt.Errorf("value %d is too large for %s", value, tpe)
} else if value > max {
return fmt.Errorf("value %d is too small for %s", value, tpe)
}
return nil
}
44 changes: 28 additions & 16 deletions pkg/values/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,15 @@ func (l *List) UnwrapTo(to any) error {
switch ptrVal.Kind() {
case reflect.Slice:
newList := reflect.MakeSlice(ptrVal.Type(), len(l.Underlying), len(l.Underlying))
for i, el := range l.Underlying {
newElm := newList.Index(i)
if newElm.Kind() == reflect.Pointer {
newElm.Set(reflect.New(newElm.Type().Elem()))
} else {
newElm = newElm.Addr()
}

if el == nil {
continue
}
if err := el.UnwrapTo(newElm.Interface()); err != nil {
return err
}
return l.unwrapToSliceOrArray(newList, val)
case reflect.Array:
if ptrVal.Len() < len(l.Underlying) {
return fmt.Errorf("too many elements to unwrap")
} else if ptrVal.Len() > len(l.Underlying) {
return fmt.Errorf("too few elements to unwrap")
}
reflect.Indirect(val).Set(newList)
return nil
arr := reflect.New(ptrVal.Type()).Elem()
return l.unwrapToSliceOrArray(arr, val)
default:
dl := []any{}
err := l.UnwrapTo(&dl)
Expand All @@ -104,3 +96,23 @@ func (l *List) UnwrapTo(to any) error {
return fmt.Errorf("cannot unwrap to type %T", to)
}
}

func (l *List) unwrapToSliceOrArray(newList reflect.Value, val reflect.Value) error {
for i, el := range l.Underlying {
newElm := newList.Index(i)
if newElm.Kind() == reflect.Pointer {
newElm.Set(reflect.New(newElm.Type().Elem()))
} else {
newElm = newElm.Addr()
}

if el == nil {
continue
}
if err := el.UnwrapTo(newElm.Interface()); err != nil {
return err
}
}
reflect.Indirect(val).Set(newList)
return nil
}
27 changes: 27 additions & 0 deletions pkg/values/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,33 @@ func Test_ListUnwrapTo(t *testing.T) {
sliceTest[any](t, expected, got)
})

t.Run("arrays", func(t *testing.T) {
v, err := Wrap([2]string{"foo", "bar"})
require.NoError(t, err)

var got [2]string
err = v.UnwrapTo(&got)
require.NoError(t, err)

require.Equal(t, [2]string{"foo", "bar"}, got)
})

t.Run("arrays too many elements return error", func(t *testing.T) {
wrapped, err := Wrap([]string{"foo", "bar", "baz"})
require.NoError(t, err)
to := [2]string{}
err = wrapped.UnwrapTo(&to)
assert.ErrorContains(t, err, "too many elements to unwrap")
})

t.Run("arrays too few elements return error", func(t *testing.T) {
wrapped, err := Wrap([]string{"foo", "bar", "baz"})
require.NoError(t, err)
to := [4]string{}
err = wrapped.UnwrapTo(&to)
assert.ErrorContains(t, err, "too few elements to unwrap")
})

t.Run("cant be assigned to passed in var", func(t *testing.T) {
a := struct{}{}
l, err := Wrap([]int{1, 2, 3})
Expand Down
49 changes: 36 additions & 13 deletions pkg/values/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,24 @@ func Wrap(v any) (Value, error) {
return NewDecimal(tv), nil
case int64:
return NewInt64(tv), nil
case int32:
return NewInt64(int64(tv)), nil
case int16:
return NewInt64(int64(tv)), nil
case int8:
return NewInt64(int64(tv)), nil
case int:
return NewInt64(int64(tv)), nil
case uint64:
return NewInt64(int64(tv)), nil
case uint:
return NewInt64(int64(tv)), nil
case uint32:
return NewInt64(int64(tv)), nil
case uint16:
return NewInt64(int64(tv)), nil
case uint8:
return NewInt64(int64(tv)), nil
case uint:
return NewInt64(int64(tv)), nil
case *big.Int:
return NewBigInt(tv), nil
case nil:
Expand Down Expand Up @@ -103,20 +113,25 @@ func Wrap(v any) (Value, error) {
return NewMap(m)
// Better complex type support for slices
case reflect.Slice:
s := make([]any, val.Len())
for i := 0; i < val.Len(); i++ {
item := val.Index(i).Interface()
s[i] = item
if val.Type().Elem().Kind() == reflect.Uint8 {
return NewBytes(val.Bytes()), nil
}
return createListFromSlice(val)
case reflect.Array:
arrayLen := val.Len()
slice := reflect.MakeSlice(reflect.SliceOf(val.Type().Elem()), arrayLen, arrayLen)
for i := 0; i < arrayLen; i++ {
slice.Index(i).Set(val.Index(i))
}
return NewList(s)
return Wrap(slice.Interface())
case reflect.Struct:
return CreateMapFromStruct(v)
case reflect.Pointer:
if reflect.Indirect(reflect.ValueOf(v)).Kind() == reflect.Struct {
return CreateMapFromStruct(reflect.Indirect(reflect.ValueOf(v)).Interface())
}
// pointer can't be null or the switch statement above would catch it.
return Wrap(val.Elem().Interface())
case reflect.String:
return Wrap(val.Convert(reflect.TypeOf("")).Interface())

case reflect.Bool:
return Wrap(val.Convert(reflect.TypeOf(true)).Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
Expand All @@ -126,6 +141,15 @@ func Wrap(v any) (Value, error) {
return nil, fmt.Errorf("could not wrap into value: %+v", v)
}

func createListFromSlice(val reflect.Value) (Value, error) {
s := make([]any, val.Len())
for i := 0; i < val.Len(); i++ {
item := val.Index(i).Interface()
s[i] = item
}
return NewList(s)
}

func WrapMap(a any) (*Map, error) {
v, err := Wrap(a)
if err != nil {
Expand Down Expand Up @@ -281,9 +305,8 @@ func unwrapTo[T any](underlying T, to any) error {
return fmt.Errorf("cannot unwrap to value of type: %T", to)
}

if rUnderlying.CanConvert(reflect.Indirect(rTo).Type()) {
conv := rUnderlying.Convert(reflect.Indirect(rTo).Type())
reflect.Indirect(rTo).Set(conv)
if rUnderlying.Type().ConvertibleTo(rTo.Type().Elem()) {
reflect.Indirect(rTo).Set(rUnderlying.Convert(rTo.Type().Elem()))
return nil
}

Expand Down
Loading

0 comments on commit 26df9ab

Please sign in to comment.