Skip to content

json: support omitzero #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
version: "2"
linters:
exclusions:
rules:
# Tests copied from the stdlib are not meant to be linted.
- path: 'golang_(.+_)?test\.go'
source: "^" # regex
2 changes: 1 addition & 1 deletion benchmarks/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ benchstat := ${GOPATH}/bin/benchstat
all:

$(benchstat):
go install golang.org/x/perf/cmd/benchstat
go install golang.org/x/perf/cmd/benchstat@latest

$(benchmark.cmd.dir)/message.pb.go: $(benchmark.cmd.dir)/message.proto
@protoc -I. \
Expand Down
139 changes: 83 additions & 56 deletions json/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding"
"encoding/json"
"fmt"
"maps"
"math/big"
"reflect"
"sort"
Expand Down Expand Up @@ -73,12 +74,9 @@ func cacheLoad() map[unsafe.Pointer]codec {

func cacheStore(typ reflect.Type, cod codec, oldCodecs map[unsafe.Pointer]codec) {
newCodecs := make(map[unsafe.Pointer]codec, len(oldCodecs)+1)
maps.Copy(newCodecs, oldCodecs)
newCodecs[typeid(typ)] = cod

for t, c := range oldCodecs {
newCodecs[t] = c
}

cache.Store(&newCodecs)
}

Expand Down Expand Up @@ -205,7 +203,7 @@ func constructCodec(t reflect.Type, seen map[reflect.Type]*structType, canAddr b
c = constructUnsupportedTypeCodec(t)
}

p := reflect.PtrTo(t)
p := reflect.PointerTo(t)

if canAddr {
switch {
Expand Down Expand Up @@ -291,7 +289,7 @@ func constructSliceCodec(t reflect.Type, seen map[reflect.Type]*structType) code
// Go 1.7+ behavior: slices of byte types (and aliases) may override the
// default encoding and decoding behaviors by implementing marshaler and
// unmarshaler interfaces.
p := reflect.PtrTo(e)
p := reflect.PointerTo(e)
c := codec{}

switch {
Expand Down Expand Up @@ -391,7 +389,7 @@ func constructMapCodec(t reflect.Type, seen map[reflect.Type]*structType) codec
kc := codec{}
vc := constructCodec(v, seen, false)

if k.Implements(textMarshalerType) || reflect.PtrTo(k).Implements(textUnmarshalerType) {
if k.Implements(textMarshalerType) || reflect.PointerTo(k).Implements(textUnmarshalerType) {
kc.encode = constructTextMarshalerEncodeFunc(k, false)
kc.decode = constructTextUnmarshalerDecodeFunc(k, true)

Expand Down Expand Up @@ -570,6 +568,7 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se
anonymous = f.Anonymous
tag = false
omitempty = false
omitzero = false
stringify = false
unexported = len(f.PkgPath) != 0
)
Expand All @@ -595,6 +594,8 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se
switch tag {
case "omitempty":
omitempty = true
case "omitzero":
omitzero = true
case "string":
stringify = true
}
Expand Down Expand Up @@ -677,9 +678,11 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se
fields = append(fields, structField{
codec: codec,
offset: offset + f.Offset,
empty: emptyFuncOf(f.Type),
isEmpty: emptyFuncOf(f.Type),
isZero: zeroFuncOf(f.Type),
tag: tag,
omitempty: omitempty,
omitzero: omitzero,
name: name,
index: i << 32,
typ: f.Type,
Expand Down Expand Up @@ -897,6 +900,18 @@ func isValidTag(s string) bool {
return true
}

func zeroFuncOf(t reflect.Type) emptyFunc {
if t.Implements(isZeroerType) {
return func(p unsafe.Pointer) bool {
return unsafeToAny(t, p).(isZeroer).IsZero()
}
}

return func(p unsafe.Pointer) bool {
return reflectDeref(t, p).IsZero()
}
}

func emptyFuncOf(t reflect.Type) emptyFunc {
switch t {
case bytesType, rawMessageType:
Expand All @@ -910,7 +925,7 @@ func emptyFuncOf(t reflect.Type) emptyFunc {
}

case reflect.Map:
return func(p unsafe.Pointer) bool { return reflect.NewAt(t, p).Elem().Len() == 0 }
return func(p unsafe.Pointer) bool { return reflectDeref(t, p).Len() == 0 }

case reflect.Slice:
return func(p unsafe.Pointer) bool { return (*slice)(p).len == 0 }
Expand Down Expand Up @@ -955,6 +970,14 @@ func emptyFuncOf(t reflect.Type) emptyFunc {
return func(unsafe.Pointer) bool { return false }
}

func reflectDeref(t reflect.Type, p unsafe.Pointer) reflect.Value {
return reflect.NewAt(t, p).Elem()
}

func unsafeToAny(t reflect.Type, p unsafe.Pointer) any {
return reflectDeref(t, p).Interface()
}

type iface struct {
typ unsafe.Pointer
ptr unsafe.Pointer
Expand All @@ -972,15 +995,16 @@ type structType struct {
ficaseIndex map[string]*structField
keyset []byte
typ reflect.Type
inlined bool
}

type structField struct {
codec codec
offset uintptr
empty emptyFunc
isEmpty emptyFunc
isZero emptyFunc
tag bool
omitempty bool
omitzero bool
json string
html string
name string
Expand Down Expand Up @@ -1066,53 +1090,56 @@ type sliceHeader struct {
Cap int
}

type isZeroer interface{ IsZero() bool }

var (
nullType = reflect.TypeOf(nil)
boolType = reflect.TypeOf(false)

intType = reflect.TypeOf(int(0))
int8Type = reflect.TypeOf(int8(0))
int16Type = reflect.TypeOf(int16(0))
int32Type = reflect.TypeOf(int32(0))
int64Type = reflect.TypeOf(int64(0))

uintType = reflect.TypeOf(uint(0))
uint8Type = reflect.TypeOf(uint8(0))
uint16Type = reflect.TypeOf(uint16(0))
uint32Type = reflect.TypeOf(uint32(0))
uint64Type = reflect.TypeOf(uint64(0))
uintptrType = reflect.TypeOf(uintptr(0))

float32Type = reflect.TypeOf(float32(0))
float64Type = reflect.TypeOf(float64(0))

bigIntType = reflect.TypeOf(new(big.Int))
numberType = reflect.TypeOf(json.Number(""))
stringType = reflect.TypeOf("")
stringsType = reflect.TypeOf([]string(nil))
bytesType = reflect.TypeOf(([]byte)(nil))
durationType = reflect.TypeOf(time.Duration(0))
timeType = reflect.TypeOf(time.Time{})
rawMessageType = reflect.TypeOf(RawMessage(nil))

numberPtrType = reflect.PtrTo(numberType)
durationPtrType = reflect.PtrTo(durationType)
timePtrType = reflect.PtrTo(timeType)
rawMessagePtrType = reflect.PtrTo(rawMessageType)

sliceInterfaceType = reflect.TypeOf(([]any)(nil))
sliceStringType = reflect.TypeOf(([]any)(nil))
mapStringInterfaceType = reflect.TypeOf((map[string]any)(nil))
mapStringRawMessageType = reflect.TypeOf((map[string]RawMessage)(nil))
mapStringStringType = reflect.TypeOf((map[string]string)(nil))
mapStringStringSliceType = reflect.TypeOf((map[string][]string)(nil))
mapStringBoolType = reflect.TypeOf((map[string]bool)(nil))

interfaceType = reflect.TypeOf((*any)(nil)).Elem()
jsonMarshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
jsonUnmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
boolType = reflect.TypeFor[bool]()

intType = reflect.TypeFor[int]()
int8Type = reflect.TypeFor[int8]()
int16Type = reflect.TypeFor[int16]()
int32Type = reflect.TypeFor[int32]()
int64Type = reflect.TypeFor[int64]()

uintType = reflect.TypeFor[uint]()
uint8Type = reflect.TypeFor[uint8]()
uint16Type = reflect.TypeFor[uint16]()
uint32Type = reflect.TypeFor[uint32]()
uint64Type = reflect.TypeFor[uint64]()
uintptrType = reflect.TypeFor[uintptr]()

float32Type = reflect.TypeFor[float32]()
float64Type = reflect.TypeFor[float64]()

bigIntType = reflect.TypeFor[*big.Int]()
numberType = reflect.TypeFor[json.Number]()
stringType = reflect.TypeFor[string]()
stringsType = reflect.TypeFor[[]string]()
bytesType = reflect.TypeFor[[]byte]()
durationType = reflect.TypeFor[time.Duration]()
timeType = reflect.TypeFor[time.Time]()
rawMessageType = reflect.TypeFor[RawMessage]()

numberPtrType = reflect.PointerTo(numberType)
durationPtrType = reflect.PointerTo(durationType)
timePtrType = reflect.PointerTo(timeType)
rawMessagePtrType = reflect.PointerTo(rawMessageType)

sliceInterfaceType = reflect.TypeFor[[]any]()
sliceStringType = reflect.TypeFor[[]any]()
mapStringInterfaceType = reflect.TypeFor[map[string]any]()
mapStringRawMessageType = reflect.TypeFor[map[string]RawMessage]()
mapStringStringType = reflect.TypeFor[map[string]string]()
mapStringStringSliceType = reflect.TypeFor[map[string][]string]()
mapStringBoolType = reflect.TypeFor[map[string]bool]()

interfaceType = reflect.TypeFor[any]()
jsonMarshalerType = reflect.TypeFor[Marshaler]()
jsonUnmarshalerType = reflect.TypeFor[Unmarshaler]()
textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]()
textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
isZeroerType = reflect.TypeFor[isZeroer]()

bigIntDecoder = constructJSONUnmarshalerDecodeFunc(bigIntType, false)
)
Expand Down
4 changes: 2 additions & 2 deletions json/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1410,7 +1410,7 @@ func (d decoder) decodeMaybeEmptyInterface(b []byte, p unsafe.Pointer, t reflect
return d.decodeUnmarshalTypeError(b, p, t)
}

func (d decoder) decodeUnmarshalTypeError(b []byte, p unsafe.Pointer, t reflect.Type) ([]byte, error) {
func (d decoder) decodeUnmarshalTypeError(b []byte, _ unsafe.Pointer, t reflect.Type) ([]byte, error) {
v, b, _, err := d.parseValue(b)
if err != nil {
return b, err
Expand Down Expand Up @@ -1500,7 +1500,7 @@ func (d decoder) decodeTextUnmarshaler(b []byte, p unsafe.Pointer, t reflect.Typ
value = "array"
}

return b, &UnmarshalTypeError{Value: value, Type: reflect.PtrTo(t)}
return b, &UnmarshalTypeError{Value: value, Type: reflect.PointerTo(t)}
}

func (d decoder) prependField(key, field string) string {
Expand Down
6 changes: 5 additions & 1 deletion json/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,11 @@ func (e encoder) encodeStruct(b []byte, p unsafe.Pointer, st *structType) ([]byt
f := &st.fields[i]
v := unsafe.Pointer(uintptr(p) + f.offset)

if f.omitempty && f.empty(v) {
switch {
case f.omitempty && f.isEmpty(v):
continue

case f.omitzero && f.isZero(v):
continue
}

Expand Down
4 changes: 2 additions & 2 deletions json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
type Delim = json.Delim

// InvalidUTF8Error is documented at https://golang.org/pkg/encoding/json/#InvalidUTF8Error
type InvalidUTF8Error = json.InvalidUTF8Error
type InvalidUTF8Error = json.InvalidUTF8Error //nolint:staticcheck // compat.

// InvalidUnmarshalError is documented at https://golang.org/pkg/encoding/json/#InvalidUnmarshalError
type InvalidUnmarshalError = json.InvalidUnmarshalError
Expand All @@ -39,7 +39,7 @@ type SyntaxError = json.SyntaxError
type Token = json.Token

// UnmarshalFieldError is documented at https://golang.org/pkg/encoding/json/#UnmarshalFieldError
type UnmarshalFieldError = json.UnmarshalFieldError
type UnmarshalFieldError = json.UnmarshalFieldError //nolint:staticcheck // compat.

// UnmarshalTypeError is documented at https://golang.org/pkg/encoding/json/#UnmarshalTypeError
type UnmarshalTypeError = json.UnmarshalTypeError
Expand Down
11 changes: 3 additions & 8 deletions json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ var testValues = [...]any{
A string `json:"name"`
B string `json:"-"`
C string `json:",omitempty"`
D map[string]any `json:",string"`
D map[string]any `json:",string"` //nolint:staticcheck // intentional
e string
}{A: "Luke", D: map[string]any{"answer": float64(42)}},
struct{ point }{point{1, 2}},
Expand Down Expand Up @@ -880,12 +880,11 @@ func TestDecodeLines(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
d := NewDecoder(test.reader)
var count int
var err error
for {
var o obj
err = d.Decode(&o)
err := d.Decode(&o)
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
break
}

Expand All @@ -904,10 +903,6 @@ func TestDecodeLines(t *testing.T) {
count++
}

if err != nil && err != io.EOF {
t.Error(err)
}

if count != test.expectCount {
t.Errorf("expected %d objects, got %d", test.expectCount, count)
}
Expand Down
5 changes: 0 additions & 5 deletions json/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ const (
cr = '\r'
)

const (
escape = '\\'
quote = '"'
)

func internalParseFlags(b []byte) (flags ParseFlags) {
// Don't consider surrounding whitespace
b = skipSpaces(b)
Expand Down