Skip to content

Commit 2d8cb52

Browse files
committed
Enable nested UnmarshalMaxMindDB support
Extend the UnmarshalMaxMindDB interface to work recursively with nested types, matching the behavior of encoding/json's UnmarshalJSON. Custom unmarshalers are now called for: - Struct fields that implement Unmarshaler - Pointer fields (creates value if nil, then checks for Unmarshaler) - Slice elements that implement Unmarshaler - Map values that implement Unmarshaler This enhancement allows for more flexible custom decoding strategies in complex data structures, improving performance optimization opportunities for nested types.
1 parent b97e1e3 commit 2d8cb52

File tree

5 files changed

+275
-4
lines changed

5 files changed

+275
-4
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
`json.Unmarshaler`.
1414
- Added public `Decoder` type with methods for manual decoding including
1515
`DecodeMap()`, `DecodeSlice()`, `DecodeString()`, `DecodeUInt32()`, etc.
16+
- Enhanced `UnmarshalMaxMindDB` to work with nested struct fields, slice
17+
elements, and map values. The custom unmarshaler is now called recursively
18+
for any type that implements the `Unmarshaler` interface, similar to
19+
`encoding/json`.
1620

1721
## 2.0.0-beta.3 - 2025-02-16
1822

internal/decoder/decoder.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package decoder
22

33
import (
4+
"errors"
45
"fmt"
56
"iter"
67

@@ -362,6 +363,13 @@ func (d *Decoder) setNextOffset(offset uint) {
362363
}
363364
}
364365

366+
func (d *Decoder) getNextOffset() (uint, error) {
367+
if !d.hasNextOffset {
368+
return 0, errors.New("no next offset available")
369+
}
370+
return d.nextOffset, nil
371+
}
372+
365373
func unexpectedTypeErr(expectedType, actualType Type) error {
366374
return fmt.Errorf("unexpected type %d, expected %d", actualType, expectedType)
367375
}

internal/decoder/decoder_test.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,16 +382,22 @@ func TestBoundsChecking(t *testing.T) {
382382
require.Contains(t, err.Error(), "exceeds buffer length")
383383

384384
// Test DecodeBytes bounds checking with a separate buffer
385-
bytesBuffer := []byte{0x84, 0x41} // Type bytes (4 << 5 = 0x80), size 4 (0x04), but only 2 bytes total
385+
bytesBuffer := []byte{
386+
0x84,
387+
0x41,
388+
} // Type bytes (4 << 5 = 0x80), size 4 (0x04), but only 2 bytes total
386389
dd3 := NewDataDecoder(bytesBuffer)
387390
decoder3 := &Decoder{d: dd3, offset: 0}
388-
391+
389392
_, err = decoder3.DecodeBytes()
390393
require.Error(t, err)
391394
require.Contains(t, err.Error(), "exceeds buffer length")
392395

393-
// Test DecodeUInt128 bounds checking
394-
uint128Buffer := []byte{0x0B, 0x03} // Extended type (0x0), size 11, TypeUint128-7=3, but only 2 bytes total
396+
// Test DecodeUInt128 bounds checking
397+
uint128Buffer := []byte{
398+
0x0B,
399+
0x03,
400+
} // Extended type (0x0), size 11, TypeUint128-7=3, but only 2 bytes total
395401
dd2 := NewDataDecoder(uint128Buffer)
396402
decoder2 := &Decoder{d: dd2, offset: 0}
397403

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
package decoder
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
// Inner type with UnmarshalMaxMindDB.
10+
type testInnerNested struct {
11+
Value string
12+
custom bool // track if custom unmarshaler was called
13+
}
14+
15+
func (i *testInnerNested) UnmarshalMaxMindDB(d *Decoder) error {
16+
i.custom = true
17+
str, err := d.DecodeString()
18+
if err != nil {
19+
return err
20+
}
21+
i.Value = "custom:" + str
22+
return nil
23+
}
24+
25+
// TestNestedUnmarshaler tests that UnmarshalMaxMindDB is called for nested struct fields.
26+
func TestNestedUnmarshaler(t *testing.T) {
27+
// Outer type without UnmarshalMaxMindDB
28+
type Outer struct {
29+
Field testInnerNested
30+
Name string
31+
}
32+
33+
// Create test data: a map with "Field" -> "test" and "Name" -> "example"
34+
data := []byte{
35+
// Map with 2 items
36+
0xe2,
37+
// Key "Field"
38+
0x45, 'F', 'i', 'e', 'l', 'd',
39+
// Value "test" (string)
40+
0x44, 't', 'e', 's', 't',
41+
// Key "Name"
42+
0x44, 'N', 'a', 'm', 'e',
43+
// Value "example" (string)
44+
0x47, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
45+
}
46+
47+
t.Run("nested field with UnmarshalMaxMindDB", func(t *testing.T) {
48+
d := New(data)
49+
var result Outer
50+
51+
err := d.Decode(0, &result)
52+
require.NoError(t, err)
53+
54+
// Check that custom unmarshaler WAS called for nested field
55+
require.True(
56+
t,
57+
result.Field.custom,
58+
"Custom unmarshaler should be called for nested fields",
59+
)
60+
require.Equal(t, "custom:test", result.Field.Value)
61+
require.Equal(t, "example", result.Name)
62+
})
63+
}
64+
65+
// testInnerPointer with UnmarshalMaxMindDB for pointer test.
66+
type testInnerPointer struct {
67+
Value string
68+
custom bool
69+
}
70+
71+
func (i *testInnerPointer) UnmarshalMaxMindDB(d *Decoder) error {
72+
i.custom = true
73+
str, err := d.DecodeString()
74+
if err != nil {
75+
return err
76+
}
77+
i.Value = "ptr:" + str
78+
return nil
79+
}
80+
81+
// TestNestedUnmarshalerPointer tests UnmarshalMaxMindDB with pointer fields.
82+
func TestNestedUnmarshalerPointer(t *testing.T) {
83+
type Outer struct {
84+
Field *testInnerPointer
85+
Name string
86+
}
87+
88+
// Test data
89+
data := []byte{
90+
// Map with 2 items
91+
0xe2,
92+
// Key "Field"
93+
0x45, 'F', 'i', 'e', 'l', 'd',
94+
// Value "test"
95+
0x44, 't', 'e', 's', 't',
96+
// Key "Name"
97+
0x44, 'N', 'a', 'm', 'e',
98+
// Value "example"
99+
0x47, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
100+
}
101+
102+
t.Run("pointer field with UnmarshalMaxMindDB", func(t *testing.T) {
103+
d := New(data)
104+
var result Outer
105+
err := d.Decode(0, &result)
106+
require.NoError(t, err)
107+
108+
// The pointer should be created and unmarshaled with custom unmarshaler
109+
require.NotNil(t, result.Field)
110+
require.True(
111+
t,
112+
result.Field.custom,
113+
"Custom unmarshaler should be called for pointer fields",
114+
)
115+
require.Equal(t, "ptr:test", result.Field.Value)
116+
require.Equal(t, "example", result.Name)
117+
})
118+
}
119+
120+
// testItem with UnmarshalMaxMindDB for slice test.
121+
type testItem struct {
122+
ID int
123+
custom bool
124+
}
125+
126+
func (item *testItem) UnmarshalMaxMindDB(d *Decoder) error {
127+
item.custom = true
128+
id, err := d.DecodeUInt32()
129+
if err != nil {
130+
return err
131+
}
132+
item.ID = int(id) * 2
133+
return nil
134+
}
135+
136+
// TestNestedUnmarshalerInSlice tests UnmarshalMaxMindDB for slice elements.
137+
func TestNestedUnmarshalerInSlice(t *testing.T) {
138+
type Container struct {
139+
Items []testItem
140+
}
141+
142+
// Test data: a map with "Items" -> [1, 2, 3]
143+
data := []byte{
144+
// Map with 1 item (TypeMap=7 << 5 | size=1)
145+
0xe1,
146+
// Key "Items" (TypeString=2 << 5 | size=5)
147+
0x45, 'I', 't', 'e', 'm', 's',
148+
// Slice with 3 items - TypeSlice=11, which is > 7, so we need extended type
149+
// Extended type: ctrl_byte = (TypeExtended << 5) | size = (0 << 5) | 3 = 0x03
150+
// Next byte: TypeSlice - 7 = 11 - 7 = 4
151+
0x03, 0x04,
152+
// Value 1 (TypeUint32=6 << 5 | size=1)
153+
0xc1, 0x01,
154+
// Value 2 (TypeUint32=6 << 5 | size=1)
155+
0xc1, 0x02,
156+
// Value 3 (TypeUint32=6 << 5 | size=1)
157+
0xc1, 0x03,
158+
}
159+
160+
t.Run("slice elements with UnmarshalMaxMindDB", func(t *testing.T) {
161+
d := New(data)
162+
var result Container
163+
err := d.Decode(0, &result)
164+
require.NoError(t, err)
165+
166+
require.Len(t, result.Items, 3)
167+
// With custom unmarshaler, values should be doubled
168+
require.True(
169+
t,
170+
result.Items[0].custom,
171+
"Custom unmarshaler should be called for slice elements",
172+
)
173+
require.Equal(t, 2, result.Items[0].ID) // 1 * 2
174+
require.Equal(t, 4, result.Items[1].ID) // 2 * 2
175+
require.Equal(t, 6, result.Items[2].ID) // 3 * 2
176+
})
177+
}
178+
179+
// testValue with UnmarshalMaxMindDB for map test.
180+
type testValue struct {
181+
Data string
182+
custom bool
183+
}
184+
185+
func (v *testValue) UnmarshalMaxMindDB(d *Decoder) error {
186+
v.custom = true
187+
str, err := d.DecodeString()
188+
if err != nil {
189+
return err
190+
}
191+
v.Data = "map:" + str
192+
return nil
193+
}
194+
195+
// TestNestedUnmarshalerInMap tests UnmarshalMaxMindDB for map values.
196+
func TestNestedUnmarshalerInMap(t *testing.T) {
197+
// Test data: {"key1": "value1", "key2": "value2"}
198+
data := []byte{
199+
// Map with 2 items
200+
0xe2,
201+
// Key "key1"
202+
0x44, 'k', 'e', 'y', '1',
203+
// Value "value1"
204+
0x46, 'v', 'a', 'l', 'u', 'e', '1',
205+
// Key "key2"
206+
0x44, 'k', 'e', 'y', '2',
207+
// Value "value2"
208+
0x46, 'v', 'a', 'l', 'u', 'e', '2',
209+
}
210+
211+
t.Run("map values with UnmarshalMaxMindDB", func(t *testing.T) {
212+
d := New(data)
213+
var result map[string]testValue
214+
err := d.Decode(0, &result)
215+
require.NoError(t, err)
216+
217+
require.Len(t, result, 2)
218+
require.True(t, result["key1"].custom, "Custom unmarshaler should be called for map values")
219+
require.Equal(t, "map:value1", result["key1"].Data)
220+
require.Equal(t, "map:value2", result["key2"].Data)
221+
})
222+
}

internal/decoder/reflection.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,37 @@ func (d *ReflectionDecoder) decode(offset uint, result reflect.Value, depth int)
143143
"exceeded maximum data structure depth; database is likely corrupt",
144144
)
145145
}
146+
147+
// First handle pointers by creating the value if needed, similar to indirect()
148+
// but we don't want to fully indirect yet as we need to check for Unmarshaler
149+
if result.Kind() == reflect.Ptr {
150+
if result.IsNil() {
151+
result.Set(reflect.New(result.Type().Elem()))
152+
}
153+
// Now check if the pointed-to type implements Unmarshaler
154+
if unmarshaler, ok := result.Interface().(Unmarshaler); ok {
155+
decoder := &Decoder{d: d.DataDecoder, offset: offset}
156+
if err := unmarshaler.UnmarshalMaxMindDB(decoder); err != nil {
157+
return 0, err
158+
}
159+
return decoder.getNextOffset()
160+
}
161+
// Continue with the pointed-to value
162+
return d.decode(offset, result.Elem(), depth)
163+
}
164+
165+
// Check if the value implements Unmarshaler interface
166+
// We need to check if result can be addressed and if the pointer type implements Unmarshaler
167+
if result.CanAddr() {
168+
if unmarshaler, ok := result.Addr().Interface().(Unmarshaler); ok {
169+
decoder := &Decoder{d: d.DataDecoder, offset: offset}
170+
if err := unmarshaler.UnmarshalMaxMindDB(decoder); err != nil {
171+
return 0, err
172+
}
173+
return decoder.getNextOffset()
174+
}
175+
}
176+
146177
typeNum, size, newOffset, err := d.decodeCtrlData(offset)
147178
if err != nil {
148179
return 0, err

0 commit comments

Comments
 (0)