Skip to content

Commit 01636bf

Browse files
authored
Merge pull request #13 from dispatchrun/serialize-primitive-slices
Native serialization of JSON-like slices & maps
2 parents 9371a7a + d4984b4 commit 01636bf

File tree

7 files changed

+183
-25
lines changed

7 files changed

+183
-25
lines changed

dispatchhttp/client.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !durable
2+
13
package dispatchhttp
24

35
import (

dispatchhttp/header.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !durable
2+
13
package dispatchhttp
24

35
import (

dispatchhttp/request.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !durable
2+
13
package dispatchhttp
24

35
import (

dispatchhttp/response.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !durable
2+
13
package dispatchhttp
24

35
import (

dispatchproto/any.go

Lines changed: 153 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,11 @@ func Duration(v time.Duration) Any {
8585
// Primitive values (booleans, integers, floats, strings, bytes, timestamps,
8686
// durations) are supported, along with values that implement either
8787
// proto.Message, json.Marshaler, encoding.TextMarshaler or
88-
// encoding.BinaryMarshaler.
88+
// encoding.BinaryMarshaler. Slices and maps are also supported, as long
89+
// as they are JSON-like in shape.
8990
func Marshal(v any) (Any, error) {
90-
if rv := reflect.ValueOf(v); rv.Kind() == reflect.Pointer && rv.IsNil() {
91+
rv := reflect.ValueOf(v)
92+
if rv.Kind() == reflect.Pointer && rv.IsNil() {
9193
return Nil(), nil
9294
}
9395
var m proto.Message
@@ -160,7 +162,10 @@ func Marshal(v any) (Any, error) {
160162
case []byte:
161163
m = wrapperspb.Bytes(vv)
162164
default:
163-
return Any{}, fmt.Errorf("cannot serialize %v (%T)", v, v)
165+
var err error
166+
if m, err = newStructpbValue(rv); err != nil {
167+
return Any{}, fmt.Errorf("cannot serialize %v: %w", v, err)
168+
}
164169
}
165170

166171
proto, err := anypb.New(m)
@@ -386,6 +391,10 @@ func (a Any) Unmarshal(v any) error {
386391
}
387392
}
388393

394+
if s, ok := m.(*structpb.Value); ok {
395+
return fromStructpbValue(elem, s)
396+
}
397+
389398
return fmt.Errorf("cannot deserialize %T into %v (%v kind)", m, elem.Type(), elem.Kind())
390399
}
391400

@@ -404,3 +413,144 @@ func (a Any) String() string {
404413
func (a Any) Equal(other Any) bool {
405414
return proto.Equal(a.proto, other.proto)
406415
}
416+
417+
func newStructpbValue(rv reflect.Value) (*structpb.Value, error) {
418+
switch rv.Kind() {
419+
case reflect.Bool:
420+
return structpb.NewBoolValue(rv.Bool()), nil
421+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
422+
n := rv.Int()
423+
f := float64(n)
424+
if int64(f) != n {
425+
return nil, fmt.Errorf("cannot serialize %d as number structpb.Value (%v) without losing information", n, f)
426+
}
427+
return structpb.NewNumberValue(f), nil
428+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
429+
n := rv.Uint()
430+
f := float64(n)
431+
if uint64(f) != n {
432+
return nil, fmt.Errorf("cannot serialize %d as number structpb.Value (%v) without losing information", n, f)
433+
}
434+
return structpb.NewNumberValue(f), nil
435+
case reflect.Float32, reflect.Float64:
436+
return structpb.NewNumberValue(rv.Float()), nil
437+
case reflect.String:
438+
return structpb.NewStringValue(rv.String()), nil
439+
case reflect.Interface:
440+
if rv.NumMethod() == 0 { // interface{} aka. any
441+
v := rv.Interface()
442+
if v == nil {
443+
return structpb.NewNullValue(), nil
444+
}
445+
return newStructpbValue(reflect.ValueOf(v))
446+
}
447+
case reflect.Slice:
448+
list := &structpb.ListValue{Values: make([]*structpb.Value, rv.Len())}
449+
for i := range list.Values {
450+
elem := rv.Index(i)
451+
var err error
452+
list.Values[i], err = newStructpbValue(elem)
453+
if err != nil {
454+
return nil, err
455+
}
456+
}
457+
return structpb.NewListValue(list), nil
458+
case reflect.Map:
459+
strct := &structpb.Struct{Fields: make(map[string]*structpb.Value, rv.Len())}
460+
iter := rv.MapRange()
461+
for iter.Next() {
462+
k := iter.Key()
463+
464+
var strKey string
465+
var hasStrKey bool
466+
switch k.Kind() {
467+
case reflect.String:
468+
strKey = k.String()
469+
hasStrKey = true
470+
case reflect.Interface:
471+
if s, ok := k.Interface().(string); ok {
472+
strKey = s
473+
hasStrKey = true
474+
}
475+
}
476+
if !hasStrKey {
477+
return nil, fmt.Errorf("cannot serialize map with %s (%s) key", k.Type(), k.Kind())
478+
}
479+
480+
v, err := newStructpbValue(iter.Value())
481+
if err != nil {
482+
return nil, err
483+
}
484+
strct.Fields[strKey] = v
485+
}
486+
return structpb.NewStructValue(strct), nil
487+
}
488+
return nil, fmt.Errorf("not implemented: %s", rv.Type())
489+
}
490+
491+
func fromStructpbValue(rv reflect.Value, s *structpb.Value) error {
492+
switch rv.Kind() {
493+
case reflect.Bool:
494+
if b, ok := s.Kind.(*structpb.Value_BoolValue); ok {
495+
rv.SetBool(b.BoolValue)
496+
return nil
497+
}
498+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
499+
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
500+
rv.SetInt(int64(n.NumberValue))
501+
return nil
502+
}
503+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
504+
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
505+
rv.SetUint(uint64(n.NumberValue))
506+
return nil
507+
}
508+
case reflect.Float32, reflect.Float64:
509+
if n, ok := s.Kind.(*structpb.Value_NumberValue); ok {
510+
rv.SetFloat(n.NumberValue)
511+
return nil
512+
}
513+
case reflect.String:
514+
if str, ok := s.Kind.(*structpb.Value_StringValue); ok {
515+
rv.SetString(str.StringValue)
516+
return nil
517+
}
518+
case reflect.Slice:
519+
if l, ok := s.Kind.(*structpb.Value_ListValue); ok {
520+
values := l.ListValue.GetValues()
521+
rv.Grow(len(values))
522+
rv.SetLen(len(values))
523+
for i, value := range values {
524+
if err := fromStructpbValue(rv.Index(i), value); err != nil {
525+
return err
526+
}
527+
}
528+
return nil
529+
}
530+
case reflect.Map:
531+
if strct, ok := s.Kind.(*structpb.Value_StructValue); ok {
532+
fields := strct.StructValue.Fields
533+
rv.Set(reflect.MakeMapWithSize(rv.Type(), len(fields)))
534+
valueType := rv.Type().Elem()
535+
for key, value := range fields {
536+
mv := reflect.New(valueType).Elem()
537+
if err := fromStructpbValue(mv, value); err != nil {
538+
return err
539+
}
540+
rv.SetMapIndex(reflect.ValueOf(key), mv)
541+
}
542+
return nil
543+
}
544+
case reflect.Interface:
545+
if rv.NumMethod() == 0 { // interface{} aka. any
546+
v := s.AsInterface()
547+
if v == nil {
548+
rv.SetZero()
549+
} else {
550+
rv.Set(reflect.ValueOf(s.AsInterface()))
551+
}
552+
return nil
553+
}
554+
}
555+
return fmt.Errorf("cannot deserialize %T into %v (%v kind)", s, rv.Type(), rv.Kind())
556+
}

dispatchproto/any_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"fmt"
88
"math"
9+
"net/http"
910
"reflect"
1011
"strings"
1112
"testing"
@@ -355,6 +356,20 @@ func TestAny(t *testing.T) {
355356
List: []any{nil, false, []any{"foo", "bar"}, map[string]any{"abc": "xyz"}},
356357
Object: map[string]any{"n": 3.14, "flag": true, "tags": []any{"x", "y", "z"}},
357358
}},
359+
360+
// slices
361+
[]string{"foo", "bar"},
362+
[]int{-1, 1, 111},
363+
[]bool{true, false, true},
364+
[]float64{3.14, 1.25},
365+
[][]string{{"foo", "bar"}, {"abc", "xyz"}},
366+
[]any{3.14, true, "x", nil},
367+
368+
// maps
369+
map[string]string{"abc": "xyz", "foo": "bar"},
370+
map[string]int{"n": 3},
371+
map[string]http.Header{"original": {"X-Foo": []string{"bar"}}},
372+
map[any]any{"foo": "bar", "pi": 3.14},
358373
} {
359374
t.Run(fmt.Sprintf("%v", v), func(t *testing.T) {
360375
boxed, err := dispatchproto.Marshal(v)

examples/fanout/main.go

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !durable
2+
13
package main
24

35
import (
@@ -12,14 +14,14 @@ import (
1214

1315
func main() {
1416
getRepo := dispatch.Func("getRepo", func(ctx context.Context, name string) (*dispatchhttp.Response, error) {
15-
return dispatchhttp.Get(context.Background(), "https://api.github.com/repos/dispatchrun/"+name)
17+
return dispatchhttp.Get(ctx, "https://api.github.com/repos/dispatchrun/"+name)
1618
})
1719

1820
getStargazers := dispatch.Func("getStargazers", func(ctx context.Context, url string) (*dispatchhttp.Response, error) {
19-
return dispatchhttp.Get(context.Background(), url)
21+
return dispatchhttp.Get(ctx, url)
2022
})
2123

22-
reduceStargazers := dispatch.Func("reduceStargazers", func(ctx context.Context, stargazerURLs strings) (strings, error) {
24+
reduceStargazers := dispatch.Func("reduceStargazers", func(ctx context.Context, stargazerURLs []string) ([]string, error) {
2325
responses, err := getStargazers.Gather(stargazerURLs)
2426
if err != nil {
2527
return nil, err
@@ -39,7 +41,7 @@ func main() {
3941
return maps.Keys(stargazers), nil
4042
})
4143

42-
fanout := dispatch.Func("fanout", func(ctx context.Context, repoNames strings) (strings, error) {
44+
fanout := dispatch.Func("fanout", func(ctx context.Context, repoNames []string) ([]string, error) {
4345
responses, err := getRepo.Gather(repoNames)
4446
if err != nil {
4547
return nil, err
@@ -65,7 +67,7 @@ func main() {
6567
}
6668

6769
go func() {
68-
if _, err := fanout.Dispatch(context.Background(), strings{"coroutine", "dispatch-py"}); err != nil {
70+
if _, err := fanout.Dispatch(context.Background(), []string{"coroutine", "dispatch-py"}); err != nil {
6971
log.Fatalf("failed to dispatch call: %v", err)
7072
}
7173
}()
@@ -74,20 +76,3 @@ func main() {
7476
log.Fatalf("failed to serve endpoint: %v", err)
7577
}
7678
}
77-
78-
// TODO: update dispatchproto.Marshal to support serializing slices/maps
79-
// natively (if they can be sent on the wire as structpb.Value)
80-
type strings []string
81-
82-
func (s strings) MarshalJSON() ([]byte, error) {
83-
return json.Marshal([]string(s))
84-
}
85-
86-
func (s *strings) UnmarshalJSON(b []byte) error {
87-
var c []string
88-
if err := json.Unmarshal(b, &c); err != nil {
89-
return err
90-
}
91-
*s = c
92-
return nil
93-
}

0 commit comments

Comments
 (0)