diff --git a/benchmark_test.go b/benchmark_test.go index 0fb1307..4147d70 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -85,3 +85,16 @@ func BenchmarkNativeFromTextualUsingV2(b *testing.B) { _ = nativeFromTextUsingV2(b, codec, textData) } } + +func BenchmarkScanBinaryUsingV2(b *testing.B) { + avroBlob, err := os.ReadFile("fixtures/quickstop-null.avro") + if err != nil { + b.Fatal(err) + } + nativeData, codec := nativeFromAvroUsingV2(b, avroBlob) + binaryData := binaryFromNativeUsingV2(b, codec, nativeData) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = scanBinaryUsingV2(b, codec, binaryData) + } +} diff --git a/codec.go b/codec.go index ee5bda1..6d1cb68 100644 --- a/codec.go +++ b/codec.go @@ -57,6 +57,8 @@ type Codec struct { nativeFromBinary func([]byte) (interface{}, []byte, error) textualFromNative func([]byte, interface{}) ([]byte, error) + scanBinary func([]byte, ...interface{}) ([]byte, error) + Rabin uint64 } @@ -583,6 +585,55 @@ func (c *Codec) TextualFromNative(buf []byte, datum interface{}) ([]byte, error) return newBuf, nil } +// ScanBinary copies the values from the binary encoded byte slice into the +// values pointed to by dest in the order of the fields of the Avro schema +// supplied when creating the Codec. On success, it returns a byte slice +// containing the remaining undecoded bytes, and a nil error value. On error, it +// returns the original byte slice, and the error message. +// +// func ExampleCodec_ScanBinary_avro() { +// codec, err := NewCodec(` +// { +// "type": "record", +// "name": "r1", +// "fields" : [ +// {"name": "f1", "type": "string"}, +// {"name": "f2", "type": "int"} +// ] +// } +// `) +// +// if err != nil { +// log.Fatal(err) +// } +// +// binary := []byte{ +// 0x10, // field1 size = 8 +// 't', 'h', 'i', 'r', 't', 'e', 'e', 'n', +// 0x1a, // field2 == 13 +// } +// +// var f1 string +// var f2 int +// if _, err = codec.ScanBinary(binary, &f1, &f2); err != nil { +// log.Fatal(err) +// } +// +// fmt.Printf("f1: %v, f2: %v", f1, f2) +// // Output: f1: thirteen, f2: 13 +// } +func (c *Codec) ScanBinary(buf []byte, dest ...interface{}) ([]byte, error) { + // TODO: implement for every type and remove + if c.scanBinary == nil { + return buf, fmt.Errorf("ScanBinary not implemented for codec with schema: %s", c.schemaOriginal) + } + newBuf, err := c.scanBinary(buf, dest...) + if err != nil { + return buf, err // if error, return original byte slice + } + return newBuf, nil +} + // Schema returns the original schema used to create the Codec. func (c *Codec) Schema() string { return c.schemaOriginal diff --git a/convert.go b/convert.go new file mode 100644 index 0000000..eddc324 --- /dev/null +++ b/convert.go @@ -0,0 +1,321 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Modified by Datastax Inc. 7/4/2022 + +// Type conversions for Scan. +package goavro + +import ( + "errors" + "fmt" + "math/big" + "reflect" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/shopspring/decimal" +) + +var errNilPtr = errors.New("destination pointer is nil") + +func convertAssign(dest, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + } + case big.Int: + switch d := dest.(type) { + case *big.Int: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case decimal.Decimal: + switch d := dest.(type) { + case *decimal.Decimal: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case time.Time: + switch d := dest.(type) { + case *time.Time: + *d = s + return nil + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + } + case uuid.UUID: + switch d := dest.(type) { + case *uuid.UUID: + *d = s + return nil + case *string: + if d == nil { + return errNilPtr + } + *d = s.String() + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = s[:] + return nil + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *string: + *d = "" + return nil + case *int64: + *d = 0 + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *interface{}: + *d = src + return nil + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Pointer { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(cloneBytes(b))) + default: + dv.Set(sv) + } + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Pointer: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssign(dv.Interface(), src) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + case reflect.Map: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + case reflect.Array, reflect.Slice: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + c := make([]byte, len(b)) + copy(c, b) + return c +} + +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} diff --git a/convert_test.go b/convert_test.go new file mode 100644 index 0000000..8489f7d --- /dev/null +++ b/convert_test.go @@ -0,0 +1,293 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Modified by Datastax Inc. 7/4/2022 + +package goavro + +import ( + "fmt" + "math/big" + "reflect" + "testing" + "time" + + "github.com/google/uuid" + "github.com/shopspring/decimal" +) + +var someTime = time.Unix(123, 0) +var answer int64 = 42 + +type ( + userDefined float64 + userDefinedSlice []int + userDefinedString string +) + +type conversionTest struct { + s, d interface{} // source and destination + + // following are used if they're non-zero + wantint int64 + wantuint uint64 + wantBigInt *big.Int + wantstr string + wantbytes []byte + wantf32 float32 + wantf64 float64 + wantDecimal *decimal.Decimal + wanttime time.Time + wantUUID uuid.UUID + wantbool bool // used if d is of type *bool + wanterr string + wantiface interface{} + wantmap map[string]int + wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr + wantnil bool // if true, *d must be *int64(nil) + wantusrdef userDefined + wantusrstr userDefinedString +} + +// Target variables for scanning into. +var ( + scanstr string + scanbytes []byte + scanint int + scanuint8 uint8 + scanuint16 uint16 + scanBigInt *big.Int + scanbool bool + scanf32 float32 + scanf64 float64 + scanDecimal *decimal.Decimal + scantime time.Time + scanUUID uuid.UUID + scanptr *int64 + scaniface interface{} + scanmap map[string]int +) + +func conversionTests() []conversionTest { + id := uuid.MustParse("12345678-1234-5678-1234-567812345678") + dec, _ := decimal.NewFromString("1.23456789") + + // Return a fresh instance to test so "go test -count 2" works correctly. + return []conversionTest{ + // Exact conversions (destination pointer type matches source type) + {s: "foo", d: &scanstr, wantstr: "foo"}, + {s: 123, d: &scanint, wantint: 123}, + {s: someTime, d: &scantime, wanttime: someTime}, + {s: dec, d: &scanDecimal, wantDecimal: &dec}, + {s: big.NewInt(123), d: &scanBigInt, wantBigInt: big.NewInt(123)}, + {s: id, d: &scanUUID, wantUUID: id}, + + // To strings + {s: "string", d: &scanstr, wantstr: "string"}, + {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, + {s: 123, d: &scanstr, wantstr: "123"}, + {s: int8(123), d: &scanstr, wantstr: "123"}, + {s: int64(123), d: &scanstr, wantstr: "123"}, + {s: uint8(123), d: &scanstr, wantstr: "123"}, + {s: uint16(123), d: &scanstr, wantstr: "123"}, + {s: uint32(123), d: &scanstr, wantstr: "123"}, + {s: uint64(123), d: &scanstr, wantstr: "123"}, + {s: 1.5, d: &scanstr, wantstr: "1.5"}, + {s: id, d: &scanstr, wantstr: "12345678-1234-5678-1234-567812345678"}, + {s: nil, d: &scanstr, wantstr: ""}, + + // From time.Time: + {s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"}, + {s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"}, + {s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"}, + {s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"}, + {s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")}, + {s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()}, + + // To uuid.UUID + {s: id, d: &scanUUID, wantUUID: id}, + + // To []byte + {s: nil, d: &scanbytes, wantbytes: nil}, + {s: "string", d: &scanbytes, wantbytes: []byte("string")}, + {s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")}, + {s: 123, d: &scanbytes, wantbytes: []byte("123")}, + {s: int8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: int64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint16(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint32(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")}, + {s: id, d: &scanbytes, wantbytes: id[:]}, + + // Strings to integers + {s: "255", d: &scanuint8, wantuint: 255}, + {s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"}, + {s: "256", d: &scanuint16, wantuint: 256}, + {s: "-1", d: &scanint, wantint: -1}, + {s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"}, + + // int64 to smaller integers + {s: int64(5), d: &scanuint8, wantuint: 5}, + {s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"}, + {s: int64(256), d: &scanuint16, wantuint: 256}, + {s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"}, + + // True bools + {s: true, d: &scanbool, wantbool: true}, + + // False bools + {s: false, d: &scanbool, wantbool: false}, + + // Not bools + {s: "yup", d: &scanbool, wanterr: "unsupported Scan, storing driver.Value type string into type *bool"}, + {s: 2, d: &scanbool, wanterr: "unsupported Scan, storing driver.Value type int into type *bool"}, + + // Floats + {s: 1.5, d: &scanf64, wantf64: 1.5}, + {s: int64(1), d: &scanf64, wantf64: float64(1)}, + {s: 1.5, d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf64, wantf64: 1.5}, + + // Pointers + {s: interface{}(nil), d: &scanptr, wantnil: true}, + {s: int64(42), d: &scanptr, wantptr: &answer}, + + // To interface{} + {s: 1.5, d: &scaniface, wantiface: 1.5}, + {s: int64(1), d: &scaniface, wantiface: int64(1)}, + {s: "str", d: &scaniface, wantiface: "str"}, + {s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")}, + {s: true, d: &scaniface, wantiface: true}, + {s: nil, d: &scaniface}, + {s: []byte(nil), d: &scaniface, wantiface: []byte(nil)}, + + // Maps + {s: map[string]int{"a": 1}, d: &scanmap, wantmap: map[string]int{"a": 1}}, + {s: nil, d: &scanmap, wantmap: nil}, + + // To a user-defined type + {s: 1.5, d: new(userDefined), wantusrdef: 1.5}, + {s: int64(123), d: new(userDefined), wantusrdef: 123}, + {s: "1.5", d: new(userDefined), wantusrdef: 1.5}, + {s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *goavro.userDefinedSlice`}, + {s: "str", d: new(userDefinedString), wantusrstr: "str"}, + + // Other errors + {s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`}, + } +} + +func intPtrValue(intptr interface{}) interface{} { + return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int() +} + +func intValue(intptr interface{}) int64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Int() +} + +func uintValue(intptr interface{}) uint64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Uint() +} + +func float64Value(ptr interface{}) float64 { + return *(ptr.(*float64)) +} + +func float32Value(ptr interface{}) float32 { + return *(ptr.(*float32)) +} + +func timeValue(ptr interface{}) time.Time { + return *(ptr.(*time.Time)) +} + +func TestConversions(t *testing.T) { + for n, ct := range conversionTests() { + err := convertAssign(ct.d, ct.s) + errstr := "" + if err != nil { + errstr = err.Error() + } + errf := func(format string, args ...interface{}) { + base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d) + t.Errorf(base+format, args...) + } + if errstr != ct.wanterr { + errf("got error %q, want error %q", errstr, ct.wanterr) + } + if ct.wantstr != "" && ct.wantstr != scanstr { + errf("want string %q, got %q", ct.wantstr, scanstr) + } + if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) { + errf("want byte %q, got %q", ct.wantbytes, scanbytes) + } + if ct.wantint != 0 && ct.wantint != intValue(ct.d) { + errf("want int %d, got %d", ct.wantint, intValue(ct.d)) + } + if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) { + errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d)) + } + if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d)) + } + if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d)) + } + if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { + errf("want bool %v, got %v", ct.wantbool, *bp) + } + if ct.wantUUID.String() != "00000000-0000-0000-0000-000000000000" && ct.wantUUID != scanUUID { + errf("want UUID %q, got %q", ct.wantUUID, scanUUID) + } + if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) { + errf("want time %v, got %v", ct.wanttime, timeValue(ct.d)) + } + if ct.wantnil && *ct.d.(**int64) != nil { + errf("want nil, got %v", intPtrValue(ct.d)) + } + if ct.wantptr != nil { + if *ct.d.(**int64) == nil { + errf("want pointer to %v, got nil", *ct.wantptr) + } else if *ct.wantptr != intPtrValue(ct.d) { + errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d)) + } + } + if ifptr, ok := ct.d.(*interface{}); ok { + if !reflect.DeepEqual(ct.wantiface, scaniface) { + errf("want interface %#v, got %#v", ct.wantiface, scaniface) + continue + } + if srcBytes, ok := ct.s.([]byte); ok { + dstBytes := (*ifptr).([]byte) + if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] { + errf("copy into interface{} didn't copy []byte data") + } + } + } + if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) { + errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined)) + } + if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) { + errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString)) + } + } +} + +// https://golang.org/issues/13905 +func TestUserDefinedBytes(t *testing.T) { + type userDefinedBytes []byte + var u userDefinedBytes + v := []byte("foo") + + err := convertAssign(&u, v) + if err != nil { + t.Fatalf("convertAssign(%v, %v) unexpected error: %v", u, v, err) + } + if &u[0] == &v[0] { + t.Fatal("userDefinedBytes got potentially dirty driver memory") + } +} diff --git a/go.mod b/go.mod index f282a30..d2cdf9f 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,7 @@ go 1.12 require ( github.com/golang/snappy v0.0.1 + github.com/google/uuid v1.3.0 + github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.7.5 ) diff --git a/go.sum b/go.sum index 9532d30..d703ec6 100644 --- a/go.sum +++ b/go.sum @@ -3,13 +3,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.5 h1:s5PTfem8p8EbKQOctVV53k6jCJt3UX4IEJzwh+C324Q= github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/helperV2_test.go b/helperV2_test.go index 2a84137..fcd302d 100644 --- a/helperV2_test.go +++ b/helperV2_test.go @@ -101,3 +101,25 @@ func nativeFromTextUsingV2(tb testing.TB, codec *Codec, textData [][]byte) []int } return nativeData } + +type person struct { + ID int64 + First, Last, Phone string + Age int +} + +func scanBinaryUsingV2(tb testing.TB, codec *Codec, binaryData [][]byte) []person { + tb.Helper() + nativeData := make([]person, len(binaryData)) + for i, binaryDatum := range binaryData { + d := nativeData[i] + buf, err := codec.ScanBinary(binaryDatum, &d.ID, &d.First, &d.Last, &d.Phone, &d.Age) + if err != nil { + tb.Fatal(err) + } + if len(buf) > 0 { + tb.Fatalf("BinaryDecode ought to have returned nil buffer: %v", buf) + } + } + return nativeData +} diff --git a/record.go b/record.go index e5ac9e4..17ac1b3 100644 --- a/record.go +++ b/record.go @@ -229,5 +229,21 @@ func makeRecordCodec(st map[string]*Codec, enclosingNamespace string, schemaMap return genericMapTextEncoder(buf, datum, nil, codecFromFieldName) } + c.scanBinary = func(buf []byte, dest ...interface{}) ([]byte, error) { + for i, fieldCodec := range codecFromIndex { + name := nameFromIndex[i] + var value interface{} + var err error + value, buf, err = fieldCodec.nativeFromBinary(buf) + if err != nil { + return nil, fmt.Errorf("cannot decode binary record %q field %q: %w", c.typeName, name, err) + } + if err := convertAssign(dest[i], value); err != nil { + return nil, fmt.Errorf("cannot convert binary record %q field %q: %w", c.typeName, name, err) + } + } + return buf, nil + } + return c, nil } diff --git a/record_test.go b/record_test.go index f8145c0..98a1de7 100644 --- a/record_test.go +++ b/record_test.go @@ -12,6 +12,7 @@ package goavro import ( "bytes" "fmt" + "log" "testing" ) @@ -612,6 +613,37 @@ func ExampleCodec_TextualFromNative_avro() { // Output: {"next":{"LongList":{"next":{"LongList":{"next":null}}}}} } +func ExampleCodec_ScanBinary_avro() { + codec, err := NewCodec(` +{ + "type": "record", + "name": "r1", + "fields" : [ + {"name": "f1", "type": "string"}, + {"name": "f2", "type": "int"} + ] +} +`) + if err != nil { + log.Fatal(err) + } + + binary := []byte{ + 0x10, // field1 size = 8 + 't', 'h', 'i', 'r', 't', 'e', 'e', 'n', + 0x1a, // field2 == 13 + } + + var f1 string + var f2 int + if _, err = codec.ScanBinary(binary, &f1, &f2); err != nil { + log.Fatal(err) + } + + fmt.Printf("f1: %v, f2: %v", f1, f2) + // Output: f1: thirteen, f2: 13 +} + func TestRecordFieldFixedDefaultValue(t *testing.T) { testSchemaValid(t, `{"type": "record", "name": "r1", "fields":[{"name": "f1", "type": {"type": "fixed", "name": "someFixed", "size": 1}, "default": "\u0000"}]}`) }