Skip to content

Commit 5468ee8

Browse files
fix(internal): unmarshal correctly when there are multiple discriminators
1 parent 5b87899 commit 5468ee8

File tree

2 files changed

+115
-21
lines changed

2 files changed

+115
-21
lines changed

internal/apijson/decodeparam_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,36 @@ func init() {
351351
})
352352
}
353353

354+
type FooVariant struct {
355+
Type string `json:"type,required"`
356+
Value string `json:"value,required"`
357+
}
358+
359+
type BarVariant struct {
360+
Type string `json:"type,required"`
361+
Enable bool `json:"enable,required"`
362+
}
363+
364+
type MultiDiscriminatorUnion struct {
365+
OfFoo *FooVariant `json:",inline"`
366+
OfBar *BarVariant `json:",inline"`
367+
368+
paramUnion
369+
}
370+
371+
func init() {
372+
apijson.RegisterDiscriminatedUnion[MultiDiscriminatorUnion]("type", map[string]reflect.Type{
373+
"foo": reflect.TypeOf(FooVariant{}),
374+
"foo_v2": reflect.TypeOf(FooVariant{}),
375+
"bar": reflect.TypeOf(BarVariant{}),
376+
"bar_legacy": reflect.TypeOf(BarVariant{}),
377+
})
378+
}
379+
380+
func (m *MultiDiscriminatorUnion) UnmarshalJSON(data []byte) error {
381+
return apijson.UnmarshalRoot(data, m)
382+
}
383+
354384
func (d *DiscriminatedUnion) UnmarshalJSON(data []byte) error {
355385
return apijson.UnmarshalRoot(data, d)
356386
}
@@ -408,3 +438,61 @@ func TestDiscriminatedUnion(t *testing.T) {
408438
})
409439
}
410440
}
441+
442+
func TestMultiDiscriminatorUnion(t *testing.T) {
443+
tests := map[string]struct {
444+
raw string
445+
target MultiDiscriminatorUnion
446+
shouldFail bool
447+
}{
448+
"foo_variant": {
449+
raw: `{"type":"foo","value":"test"}`,
450+
target: MultiDiscriminatorUnion{OfFoo: &FooVariant{
451+
Type: "foo",
452+
Value: "test",
453+
}},
454+
},
455+
"foo_v2_variant": {
456+
raw: `{"type":"foo_v2","value":"test_v2"}`,
457+
target: MultiDiscriminatorUnion{OfFoo: &FooVariant{
458+
Type: "foo_v2",
459+
Value: "test_v2",
460+
}},
461+
},
462+
"bar_variant": {
463+
raw: `{"type":"bar","enable":true}`,
464+
target: MultiDiscriminatorUnion{OfBar: &BarVariant{
465+
Type: "bar",
466+
Enable: true,
467+
}},
468+
},
469+
"bar_legacy_variant": {
470+
raw: `{"type":"bar_legacy","enable":false}`,
471+
target: MultiDiscriminatorUnion{OfBar: &BarVariant{
472+
Type: "bar_legacy",
473+
Enable: false,
474+
}},
475+
},
476+
"invalid_type": {
477+
raw: `{"type":"unknown","value":"test"}`,
478+
target: MultiDiscriminatorUnion{},
479+
shouldFail: true,
480+
},
481+
}
482+
483+
for name, test := range tests {
484+
t.Run(name, func(t *testing.T) {
485+
var dst MultiDiscriminatorUnion
486+
err := json.Unmarshal([]byte(test.raw), &dst)
487+
if err != nil && !test.shouldFail {
488+
t.Fatalf("failed unmarshal with err: %v", err)
489+
}
490+
if err == nil && test.shouldFail {
491+
t.Fatalf("expected unmarshal to fail but it succeeded")
492+
}
493+
if !reflect.DeepEqual(dst, test.target) {
494+
t.Fatalf("failed equality, got %#v but expected %#v", dst, test.target)
495+
}
496+
})
497+
}
498+
}

internal/apijson/union.go

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,10 @@ func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.T
3939

4040
func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
4141
type variantDecoder struct {
42-
decoder decoderFunc
43-
field reflect.StructField
44-
discriminatorValue any
42+
decoder decoderFunc
43+
field reflect.StructField
4544
}
46-
47-
variants := []variantDecoder{}
45+
decoders := []variantDecoder{}
4846
for i := 0; i < t.NumField(); i++ {
4947
field := t.Field(i)
5048

@@ -53,18 +51,26 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
5351
}
5452

5553
decoder := d.typeDecoder(field.Type)
56-
variants = append(variants, variantDecoder{
54+
decoders = append(decoders, variantDecoder{
5755
decoder: decoder,
5856
field: field,
5957
})
6058
}
6159

60+
type discriminatedDecoder struct {
61+
variantDecoder
62+
discriminator any
63+
}
64+
discriminatedDecoders := []discriminatedDecoder{}
6265
unionEntry, discriminated := unionRegistry[t]
63-
for _, unionVariant := range unionEntry.variants {
64-
for i := 0; i < len(variants); i++ {
65-
variant := &variants[i]
66-
if variant.field.Type.Elem() == unionVariant.Type {
67-
variant.discriminatorValue = unionVariant.DiscriminatorValue
66+
for _, variant := range unionEntry.variants {
67+
// For each union variant, find a matching decoder and save it
68+
for _, decoder := range decoders {
69+
if decoder.field.Type.Elem() == variant.Type {
70+
discriminatedDecoders = append(discriminatedDecoders, discriminatedDecoder{
71+
decoder,
72+
variant.DiscriminatorValue,
73+
})
6874
break
6975
}
7076
}
@@ -73,10 +79,10 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
7379
return func(n gjson.Result, v reflect.Value, state *decoderState) error {
7480
if discriminated && n.Type == gjson.JSON && len(unionEntry.discriminatorKey) != 0 {
7581
discriminator := n.Get(unionEntry.discriminatorKey).Value()
76-
for _, variant := range variants {
77-
if discriminator == variant.discriminatorValue {
78-
inner := v.FieldByIndex(variant.field.Index)
79-
return variant.decoder(n, inner, state)
82+
for _, decoder := range discriminatedDecoders {
83+
if discriminator == decoder.discriminator {
84+
inner := v.FieldByIndex(decoder.field.Index)
85+
return decoder.decoder(n, inner, state)
8086
}
8187
}
8288
return errors.New("apijson: was not able to find discriminated union variant")
@@ -85,15 +91,15 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
8591
// Set bestExactness to worse than loose
8692
bestExactness := loose - 1
8793
bestVariant := -1
88-
for i, variant := range variants {
94+
for i, decoder := range decoders {
8995
// Pointers are used to discern JSON object variants from value variants
90-
if n.Type != gjson.JSON && variant.field.Type.Kind() == reflect.Ptr {
96+
if n.Type != gjson.JSON && decoder.field.Type.Kind() == reflect.Ptr {
9197
continue
9298
}
9399

94100
sub := decoderState{strict: state.strict, exactness: exact}
95-
inner := v.FieldByIndex(variant.field.Index)
96-
err := variant.decoder(n, inner, &sub)
101+
inner := v.FieldByIndex(decoder.field.Index)
102+
err := decoder.decoder(n, inner, &sub)
97103
if err != nil {
98104
continue
99105
}
@@ -116,11 +122,11 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
116122
return errors.New("apijson: was not able to coerce type as union strictly")
117123
}
118124

119-
for i := 0; i < len(variants); i++ {
125+
for i := 0; i < len(decoders); i++ {
120126
if i == bestVariant {
121127
continue
122128
}
123-
v.FieldByIndex(variants[i].field.Index).SetZero()
129+
v.FieldByIndex(decoders[i].field.Index).SetZero()
124130
}
125131

126132
return nil

0 commit comments

Comments
 (0)