@@ -39,12 +39,10 @@ func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.T
39
39
40
40
func (d * decoderBuilder ) newStructUnionDecoder (t reflect.Type ) decoderFunc {
41
41
type variantDecoder struct {
42
- decoder decoderFunc
43
- field reflect.StructField
44
- discriminatorValue any
42
+ decoder decoderFunc
43
+ field reflect.StructField
45
44
}
46
-
47
- variants := []variantDecoder {}
45
+ decoders := []variantDecoder {}
48
46
for i := 0 ; i < t .NumField (); i ++ {
49
47
field := t .Field (i )
50
48
@@ -53,18 +51,26 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
53
51
}
54
52
55
53
decoder := d .typeDecoder (field .Type )
56
- variants = append (variants , variantDecoder {
54
+ decoders = append (decoders , variantDecoder {
57
55
decoder : decoder ,
58
56
field : field ,
59
57
})
60
58
}
61
59
60
+ type discriminatedDecoder struct {
61
+ variantDecoder
62
+ discriminator any
63
+ }
64
+ discriminatedDecoders := []discriminatedDecoder {}
62
65
unionEntry , discriminated := unionRegistry [t ]
63
- for _ , unionVariant := range unionEntry .variants {
64
- for i := 0 ; i < len (variants ); i ++ {
65
- variant := & variants [i ]
66
- if variant .field .Type .Elem () == unionVariant .Type {
67
- variant .discriminatorValue = unionVariant .DiscriminatorValue
66
+ for _ , variant := range unionEntry .variants {
67
+ // For each union variant, find a matching decoder and save it
68
+ for _ , decoder := range decoders {
69
+ if decoder .field .Type .Elem () == variant .Type {
70
+ discriminatedDecoders = append (discriminatedDecoders , discriminatedDecoder {
71
+ decoder ,
72
+ variant .DiscriminatorValue ,
73
+ })
68
74
break
69
75
}
70
76
}
@@ -73,10 +79,10 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
73
79
return func (n gjson.Result , v reflect.Value , state * decoderState ) error {
74
80
if discriminated && n .Type == gjson .JSON && len (unionEntry .discriminatorKey ) != 0 {
75
81
discriminator := n .Get (unionEntry .discriminatorKey ).Value ()
76
- for _ , variant := range variants {
77
- if discriminator == variant . discriminatorValue {
78
- inner := v .FieldByIndex (variant .field .Index )
79
- return variant .decoder (n , inner , state )
82
+ for _ , decoder := range discriminatedDecoders {
83
+ if discriminator == decoder . discriminator {
84
+ inner := v .FieldByIndex (decoder .field .Index )
85
+ return decoder .decoder (n , inner , state )
80
86
}
81
87
}
82
88
return errors .New ("apijson: was not able to find discriminated union variant" )
@@ -85,15 +91,15 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
85
91
// Set bestExactness to worse than loose
86
92
bestExactness := loose - 1
87
93
bestVariant := - 1
88
- for i , variant := range variants {
94
+ for i , decoder := range decoders {
89
95
// Pointers are used to discern JSON object variants from value variants
90
- if n .Type != gjson .JSON && variant .field .Type .Kind () == reflect .Ptr {
96
+ if n .Type != gjson .JSON && decoder .field .Type .Kind () == reflect .Ptr {
91
97
continue
92
98
}
93
99
94
100
sub := decoderState {strict : state .strict , exactness : exact }
95
- inner := v .FieldByIndex (variant .field .Index )
96
- err := variant .decoder (n , inner , & sub )
101
+ inner := v .FieldByIndex (decoder .field .Index )
102
+ err := decoder .decoder (n , inner , & sub )
97
103
if err != nil {
98
104
continue
99
105
}
@@ -116,11 +122,11 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
116
122
return errors .New ("apijson: was not able to coerce type as union strictly" )
117
123
}
118
124
119
- for i := 0 ; i < len (variants ); i ++ {
125
+ for i := 0 ; i < len (decoders ); i ++ {
120
126
if i == bestVariant {
121
127
continue
122
128
}
123
- v .FieldByIndex (variants [i ].field .Index ).SetZero ()
129
+ v .FieldByIndex (decoders [i ].field .Index ).SetZero ()
124
130
}
125
131
126
132
return nil
0 commit comments