diff --git a/src/encoding/xml/marshal.go b/src/encoding/xml/marshal.go index 133503fa2de41c..13fbeeeedc75ce 100644 --- a/src/encoding/xml/marshal.go +++ b/src/encoding/xml/marshal.go @@ -416,12 +416,6 @@ func (p *printer) popPrefix() { } } -var ( - marshalerType = reflect.TypeFor[Marshaler]() - marshalerAttrType = reflect.TypeFor[MarshalerAttr]() - textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]() -) - // marshalValue writes one or more XML elements representing val. // If val was obtained from a struct field, finfo must have its details. func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplate *StartElement) error { @@ -450,24 +444,32 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat typ := val.Type() // Check for marshaler. - if val.CanInterface() && typ.Implements(marshalerType) { - return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate)) + if val.CanInterface() { + if marshaler, ok := reflect.TypeAssert[Marshaler](val); ok { + return p.marshalInterface(marshaler, defaultStart(typ, finfo, startTemplate)) + } } if val.CanAddr() { pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(marshalerType) { - return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate)) + if pv.CanInterface() { + if marshaler, ok := reflect.TypeAssert[Marshaler](pv); ok { + return p.marshalInterface(marshaler, defaultStart(pv.Type(), finfo, startTemplate)) + } } } // Check for text marshaler. - if val.CanInterface() && typ.Implements(textMarshalerType) { - return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate)) + if val.CanInterface() { + if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](val); ok { + return p.marshalTextInterface(textMarshaler, defaultStart(typ, finfo, startTemplate)) + } } if val.CanAddr() { pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { - return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate)) + if pv.CanInterface() { + if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok { + return p.marshalTextInterface(textMarshaler, defaultStart(pv.Type(), finfo, startTemplate)) + } } } @@ -503,7 +505,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name } else { fv := xmlname.value(val, dontInitNilPointers) - if v, ok := fv.Interface().(Name); ok && v.Local != "" { + if v, ok := reflect.TypeAssert[Name](fv); ok && v.Local != "" { start.Name = v } } @@ -580,21 +582,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat // marshalAttr marshals an attribute with the given name and value, adding to start.Attr. func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) error { - if val.CanInterface() && val.Type().Implements(marshalerAttrType) { - attr, err := val.Interface().(MarshalerAttr).MarshalXMLAttr(name) - if err != nil { - return err - } - if attr.Name.Local != "" { - start.Attr = append(start.Attr, attr) - } - return nil - } - - if val.CanAddr() { - pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { - attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + if val.CanInterface() { + if marshaler, ok := reflect.TypeAssert[MarshalerAttr](val); ok { + attr, err := marshaler.MarshalXMLAttr(name) if err != nil { return err } @@ -605,19 +595,25 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) } } - if val.CanInterface() && val.Type().Implements(textMarshalerType) { - text, err := val.Interface().(encoding.TextMarshaler).MarshalText() - if err != nil { - return err + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() { + if marshaler, ok := reflect.TypeAssert[MarshalerAttr](pv); ok { + attr, err := marshaler.MarshalXMLAttr(name) + if err != nil { + return err + } + if attr.Name.Local != "" { + start.Attr = append(start.Attr, attr) + } + return nil + } } - start.Attr = append(start.Attr, Attr{name, string(text)}) - return nil } - if val.CanAddr() { - pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { - text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if val.CanInterface() { + if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](val); ok { + text, err := textMarshaler.MarshalText() if err != nil { return err } @@ -626,6 +622,20 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) } } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() { + if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok { + text, err := textMarshaler.MarshalText() + if err != nil { + return err + } + start.Attr = append(start.Attr, Attr{name, string(text)}) + return nil + } + } + } + // Dereference or skip nil pointer, interface values. switch val.Kind() { case reflect.Pointer, reflect.Interface: @@ -647,7 +657,8 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) } if val.Type() == attrType { - start.Attr = append(start.Attr, val.Interface().(Attr)) + attr, _ := reflect.TypeAssert[Attr](val) + start.Attr = append(start.Attr, attr) return nil } @@ -854,20 +865,9 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { if err := s.trim(finfo.parents); err != nil { return err } - if vf.CanInterface() && vf.Type().Implements(textMarshalerType) { - data, err := vf.Interface().(encoding.TextMarshaler).MarshalText() - if err != nil { - return err - } - if err := emit(p, data); err != nil { - return err - } - continue - } - if vf.CanAddr() { - pv := vf.Addr() - if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { - data, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if vf.CanInterface() { + if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](vf); ok { + data, err := textMarshaler.MarshalText() if err != nil { return err } @@ -877,6 +877,21 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { continue } } + if vf.CanAddr() { + pv := vf.Addr() + if pv.CanInterface() { + if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok { + data, err := textMarshaler.MarshalText() + if err != nil { + return err + } + if err := emit(p, data); err != nil { + return err + } + continue + } + } + } var scratch [64]byte vf = indirect(vf) @@ -902,7 +917,7 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { return err } case reflect.Slice: - if elem, ok := vf.Interface().([]byte); ok { + if elem, ok := reflect.TypeAssert[[]byte](vf); ok { if err := emit(p, elem); err != nil { return err } diff --git a/src/encoding/xml/read.go b/src/encoding/xml/read.go index af25c20f0618dc..d3cb74b2c4311a 100644 --- a/src/encoding/xml/read.go +++ b/src/encoding/xml/read.go @@ -255,28 +255,36 @@ func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { } val = val.Elem() } - if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) { + if val.CanInterface() { // This is an unmarshaler with a non-pointer receiver, // so it's likely to be incorrect, but we do what we're told. - return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr) + if unmarshaler, ok := reflect.TypeAssert[UnmarshalerAttr](val); ok { + return unmarshaler.UnmarshalXMLAttr(attr) + } } if val.CanAddr() { pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) { - return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr) + if pv.CanInterface() { + if unmarshaler, ok := reflect.TypeAssert[UnmarshalerAttr](pv); ok { + return unmarshaler.UnmarshalXMLAttr(attr) + } } } // Not an UnmarshalerAttr; try encoding.TextUnmarshaler. - if val.CanInterface() && val.Type().Implements(textUnmarshalerType) { + if val.CanInterface() { // This is an unmarshaler with a non-pointer receiver, // so it's likely to be incorrect, but we do what we're told. - return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value)) + if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](val); ok { + return textUnmarshaler.UnmarshalText([]byte(attr.Value)) + } } if val.CanAddr() { pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { - return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value)) + if pv.CanInterface() { + if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok { + return textUnmarshaler.UnmarshalText([]byte(attr.Value)) + } } } @@ -303,12 +311,7 @@ func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { return copyValue(val, []byte(attr.Value)) } -var ( - attrType = reflect.TypeFor[Attr]() - unmarshalerType = reflect.TypeFor[Unmarshaler]() - unmarshalerAttrType = reflect.TypeFor[UnmarshalerAttr]() - textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() -) +var attrType = reflect.TypeFor[Attr]() const ( maxUnmarshalDepth = 10000 @@ -352,27 +355,35 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) e val = val.Elem() } - if val.CanInterface() && val.Type().Implements(unmarshalerType) { + if val.CanInterface() { // This is an unmarshaler with a non-pointer receiver, // so it's likely to be incorrect, but we do what we're told. - return d.unmarshalInterface(val.Interface().(Unmarshaler), start) + if unmarshaler, ok := reflect.TypeAssert[Unmarshaler](val); ok { + return d.unmarshalInterface(unmarshaler, start) + } } if val.CanAddr() { pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(unmarshalerType) { - return d.unmarshalInterface(pv.Interface().(Unmarshaler), start) + if pv.CanInterface() { + if unmarshaler, ok := reflect.TypeAssert[Unmarshaler](pv); ok { + return d.unmarshalInterface(unmarshaler, start) + } } } - if val.CanInterface() && val.Type().Implements(textUnmarshalerType) { - return d.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler)) + if val.CanInterface() { + if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](val); ok { + return d.unmarshalTextInterface(textUnmarshaler) + } } if val.CanAddr() { pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { - return d.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler)) + if pv.CanInterface() { + if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok { + return d.unmarshalTextInterface(textUnmarshaler) + } } } @@ -453,7 +464,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) e return UnmarshalError(e) } fv := finfo.value(sv, initNilPointers) - if _, ok := fv.Interface().(Name); ok { + if _, ok := reflect.TypeAssert[Name](fv); ok { fv.Set(reflect.ValueOf(start.Name)) } } @@ -578,20 +589,24 @@ Loop: } } - if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) { - if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { - return err + if saveData.IsValid() && saveData.CanInterface() { + if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](saveData); ok { + if err := textUnmarshaler.UnmarshalText(data); err != nil { + return err + } + saveData = reflect.Value{} } - saveData = reflect.Value{} } if saveData.IsValid() && saveData.CanAddr() { pv := saveData.Addr() - if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { - if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { - return err + if pv.CanInterface() { + if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok { + if err := textUnmarshaler.UnmarshalText(data); err != nil { + return err + } + saveData = reflect.Value{} } - saveData = reflect.Value{} } }