Skip to content

Commit c7308c4

Browse files
committed
rebase & support sql.Scanner iface
1 parent 6e27ad1 commit c7308c4

File tree

4 files changed

+54
-6
lines changed

4 files changed

+54
-6
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Added support of custom types to row.ScanStruct using sql.Scanner interface
2+
13
## v3.104.5
24
* Added query client session pool metrics: create_in_progress, in_use, waiters_queue
35
* Added pool item closing for not-alived item

internal/value/cast.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
package value
22

3+
import (
4+
"database/sql"
5+
"database/sql/driver"
6+
7+
"github.com/google/uuid"
8+
)
9+
310
func CastTo(v Value, dst interface{}) error {
411
if dst == nil {
512
return errNilDestination
@@ -10,5 +17,20 @@ func CastTo(v Value, dst interface{}) error {
1017
return nil
1118
}
1219

20+
if _, ok := dst.(*uuid.UUID); ok {
21+
return v.castTo(dst)
22+
}
23+
24+
if scanner, has := dst.(sql.Scanner); has {
25+
dv := new(driver.Value)
26+
27+
err := v.castTo(dv)
28+
if err != nil {
29+
return err
30+
}
31+
32+
return scanner.Scan(*dv)
33+
}
34+
1335
return v.castTo(dst)
1436
}

internal/value/cast_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package value
22

33
import (
44
"database/sql/driver"
5+
"errors"
56
"reflect"
67
"testing"
78
"time"
@@ -32,6 +33,22 @@ func loadLocation(t *testing.T, name string) *time.Location {
3233
return loc
3334
}
3435

36+
<<<<<<< HEAD
37+
=======
38+
type testStringSQLScanner string
39+
40+
func (s *testStringSQLScanner) Scan(value any) error {
41+
ts, ok := value.(string)
42+
if !ok {
43+
return errors.New("can't cast from " + reflect.TypeOf(value).String() + " to string")
44+
}
45+
46+
*s = testStringSQLScanner(ts)
47+
48+
return nil
49+
}
50+
51+
>>>>>>> 74dbdc6f (rebase & support sql.Scanner iface)
3552
func TestCastTo(t *testing.T) {
3653
testsCases := []struct {
3754
name string
@@ -428,6 +445,13 @@ func TestCastTo(t *testing.T) {
428445
exp: DateValueFromTime(time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)),
429446
err: nil,
430447
},
448+
{
449+
name: xtest.CurrentFileLine(),
450+
value: TextValue("text-string"),
451+
dst: ptr[testStringSQLScanner](),
452+
exp: testStringSQLScanner("text-string"),
453+
err: nil,
454+
},
431455
}
432456
for _, tt := range testsCases {
433457
t.Run(tt.name, func(t *testing.T) {

internal/value/value.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,7 +1515,7 @@ func (v *listValue) castTo(dst any) error {
15151515
inner.Set(newSlice)
15161516

15171517
for i, item := range v.ListItems() {
1518-
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
1518+
if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil {
15191519
return xerrors.WithStackTrace(fmt.Errorf(
15201520
"%w '%s(%+v)' to '%T' destination",
15211521
ErrCannotCast, v.Type().Yql(), v, dstValue,
@@ -1649,7 +1649,7 @@ func (v *setValue) castTo(dst any) error {
16491649
inner.Set(newSlice)
16501650

16511651
for i, item := range v.items {
1652-
if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil {
1652+
if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil {
16531653
return xerrors.WithStackTrace(fmt.Errorf(
16541654
"%w '%s(%+v)' to '%T' destination",
16551655
ErrCannotCast, v.Type().Yql(), v, dstValue,
@@ -1757,7 +1757,7 @@ func (v *optionalValue) castTo(dst any) error {
17571757
return nil
17581758
}
17591759

1760-
if err := v.value.castTo(ptr.Interface()); err != nil {
1760+
if err := CastTo(v.value, (ptr.Interface())); err != nil {
17611761
return xerrors.WithStackTrace(err)
17621762
}
17631763

@@ -1772,7 +1772,7 @@ func (v *optionalValue) castTo(dst any) error {
17721772

17731773
inner.Set(reflect.New(inner.Type().Elem()))
17741774

1775-
if err := v.value.castTo(inner.Interface()); err != nil {
1775+
if err := CastTo(v.value, inner.Interface()); err != nil {
17761776
return xerrors.WithStackTrace(err)
17771777
}
17781778

@@ -1853,7 +1853,7 @@ func (v *structValue) castTo(dst any) error {
18531853
}
18541854

18551855
for i, field := range v.fields {
1856-
if err := field.V.castTo(inner.Field(i).Addr().Interface()); err != nil {
1856+
if err := CastTo(field.V, inner.Field(i).Addr().Interface()); err != nil {
18571857
return xerrors.WithStackTrace(fmt.Errorf(
18581858
"scan error on struct field name '%s': %w",
18591859
field.Name, err,
@@ -2031,7 +2031,7 @@ func (v *tupleValue) TupleItems() []Value {
20312031

20322032
func (v *tupleValue) castTo(dst any) error {
20332033
if len(v.items) == 1 {
2034-
return v.items[0].castTo(dst)
2034+
return CastTo(v.items[0], dst)
20352035
}
20362036

20372037
switch dstValue := dst.(type) {

0 commit comments

Comments
 (0)