Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions internal/reflect/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,28 @@ func decodeFixedSizeTypes(t ttype, b []byte, p unsafe.Pointer) int {
}
}

// minWireSize maps a Thrift wire type to the smallest number of bytes that a
// single value of that type can occupy in the binary encoding.
//
// The map and list/set decoding paths divide the remaining buffer length by
// these values to bound the declared element count before allocating, so a
// corrupted count that could never fit is rejected up front.
//
// Types that are not listed keep the zero value.
// NOTE(review): the values are used as divisors; confirm that only the wire
// types listed here are ever looked up.
var minWireSize = [256]int8{
	// fixed-size scalars
	tBOOL:   1,
	tBYTE:   1,
	tI16:    2,
	tI32:    4,
	tI64:    8,
	tDOUBLE: 8,

	// variable-size values: only the mandatory leading bytes are counted
	tSTRING: 4, // 4-byte length prefix; the payload may be empty
	tSTRUCT: 1, // a lone tSTOP byte
	tLIST:   5, // 1-byte element type + 4-byte count; may hold zero elements
	tSET:    5, // 1-byte element type + 4-byte count; may hold zero elements
	tMAP:    6, // 1-byte key type + 1-byte value type + 4-byte count; may hold zero entries
}

func decodeStringNoCopy(t *tType, b []byte, p unsafe.Pointer) (i int, err error) {
if len(b) < strHeaderLen {
return 0, io.ErrShortBuffer
}
l := int(int32(binary.BigEndian.Uint32(b)))
if l < 0 {
err = errNegativeSize
Expand All @@ -177,8 +198,8 @@ func decodeStringNoCopy(t *tType, b []byte, p unsafe.Pointer) (i int, err error)
return
}

if i+l-1 >= len(b) {
return i, io.ErrShortBuffer
if l > len(b)-i {
return i, newSizeExceedsBufferException(l, len(b)-i)
}

if t.Tag == defs.T_binary {
Expand All @@ -199,6 +220,9 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
}
switch t.T {
case tSTRING:
if len(b) < strHeaderLen {
return 0, io.ErrShortBuffer
}
l := int(int32(binary.BigEndian.Uint32(b)))
if l < 0 {
return 0, errNegativeSize
Expand All @@ -213,8 +237,8 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
return i, nil
}

if i+l-1 >= len(b) {
return i, io.ErrShortBuffer
if l > len(b)-i {
return i, newSizeExceedsBufferException(l, len(b)-i)
}

x := d.Malloc(l, 1, 0)
Expand All @@ -229,6 +253,9 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int

case tMAP:
// map header
if len(b) < mapHeaderLen {
return 0, io.ErrShortBuffer
}
t0, t1, l := ttype(b[0]), ttype(b[1]), int(int32(binary.BigEndian.Uint32(b[2:])))
if l < 0 {
return 0, errNegativeSize
Expand All @@ -241,6 +268,13 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
return 0, newTypeMismatchKV(kt.WT, vt.WT, t0, t1)
}

// reject corrupted lengths before allocating the map: every entry needs
// at least minWireSize[key]+minWireSize[value] bytes, so l entries can
// not fit if they exceed the remaining buffer. likely data is broken.
if remain := len(b) - mapHeaderLen; l > remain/(int(minWireSize[kt.WT])+int(minWireSize[vt.WT])) {
return mapHeaderLen, newSizeExceedsBufferException(l, remain)
}

// decode map

// tmp vars
Expand Down Expand Up @@ -318,6 +352,9 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int

case tLIST, tSET: // NOTE: for tSET, it may be map in the future
// list header
if len(b) < listHeaderLen {
return 0, io.ErrShortBuffer
}
tp, l := ttype(b[0]), int(int32(binary.BigEndian.Uint32(b[1:])))
if l < 0 {
return 0, errNegativeSize
Expand All @@ -336,6 +373,14 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
h.Zero()
return i, nil
}

// reject corrupted lengths before allocating the slice: every element
// needs at least minWireSize[et] bytes, so l elements can not fit if
// they exceed the remaining buffer. likely the data is broken.
if remain := len(b) - i; l > remain/int(minWireSize[et.WT]) {
return i, newSizeExceedsBufferException(l, remain)
}

x := d.Malloc(l*et.Size, et.Align, et.MallocAbiType) // malloc for slice. make([]Type, l, l)
h.Data = x
h.Len = l
Expand Down
77 changes: 66 additions & 11 deletions internal/reflect/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package reflect

import (
"bytes"
"errors"
"io"
"math"
"math/rand"
Expand Down Expand Up @@ -456,21 +457,22 @@ func TestDecodeStringShortBuffer(t *testing.T) {
var result string
ptr := unsafe.Pointer(&result)

// io.ErrShortBuffer: decodeType
data := []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l'} // length=5, only 4 bytes
n, err := decoder.decodeType(typ, data, ptr, 1)
assert.True(t, err == io.ErrShortBuffer)
assert.True(t, n <= len(data))
// header truncated: fewer than strHeaderLen bytes to even read the length.
// this is a genuine short buffer, not corrupted data.
for _, data := range [][]byte{nil, {0x00}, {0x00, 0x00, 0x00}} {
n, err := decoder.decodeType(typ, data, ptr, 1)
assert.True(t, err == io.ErrShortBuffer)
assert.True(t, n <= len(data))

// io.ErrShortBuffer: decodeStringNoCopy
n, err = decodeStringNoCopy(typ, data, ptr)
assert.True(t, err == io.ErrShortBuffer)
assert.True(t, n <= len(data))
n, err = decodeStringNoCopy(typ, data, ptr)
assert.True(t, err == io.ErrShortBuffer)
assert.True(t, n <= len(data))
}

// Normal case: decodeType
data = []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l', 'o'}
data := []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l', 'o'}
result = ""
n, err = decoder.decodeType(typ, data, ptr, 1)
n, err := decoder.decodeType(typ, data, ptr, 1)
assert.Nil(t, err)
assert.Equal(t, len(data), n)
assert.Equal(t, "hello", result)
Expand All @@ -481,3 +483,56 @@ func TestDecodeStringShortBuffer(t *testing.T) {
assert.Equal(t, len(data), n)
assert.Equal(t, "hello", result)
}

// assertSizeLimit fails the test unless err unwraps to a
// *thrift.ProtocolException whose type ID is thrift.SIZE_LIMIT, which is the
// error the decoder reports when a decoded size exceeds the remaining buffer.
// A non-matching error type aborts the test immediately.
func assertSizeLimit(t *testing.T, err error) {
	t.Helper()

	var protoErr *thrift.ProtocolException
	if ok := errors.As(err, &protoErr); !ok {
		t.Fatalf("expected *thrift.ProtocolException, got %v", err)
	}

	want := int32(thrift.SIZE_LIMIT)
	assert.Equal(t, want, protoErr.TypeID())
}

func TestDecodeSizeExceedsBuffer(t *testing.T) {
// string / binary: declared length exceeds the remaining buffer
decoder := &tDecoder{}
typ := &tType{T: tSTRING, Tag: defs.T_string}
var s string
data := []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l'} // length=5, only 4 bytes
_, err := decoder.decodeType(typ, data, unsafe.Pointer(&s), 1)
assertSizeLimit(t, err)
_, err = decodeStringNoCopy(typ, data, unsafe.Pointer(&s))
assertSizeLimit(t, err)

// list: element count far exceeds the remaining buffer
{
type Msg struct {
L []int32 `frugal:"1,default,list<i32>"`
}
b := []byte{
byte(tLIST), 0x00, 0x01, // field: type=LIST id=1
byte(tI32), // element type
0x7f, 0xff, 0xff, 0xff, // count = 2147483647
byte(tSTOP),
}
_, err := Decode(b, &Msg{})
assertSizeLimit(t, err)
}

// map: entry count far exceeds the remaining buffer
{
type Msg struct {
M map[int32]int32 `frugal:"1,default,map<i32:i32>"`
}
b := []byte{
byte(tMAP), 0x00, 0x01, // field: type=MAP id=1
byte(tI32), byte(tI32), // key type, value type
0x7f, 0xff, 0xff, 0xff, // count = 2147483647
byte(tSTOP),
}
_, err := Decode(b, &Msg{})
assertSizeLimit(t, err)
}
}
11 changes: 11 additions & 0 deletions internal/reflect/exception.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ func newRequiredFieldNotSetException(name string) error {
)
}

// newSizeExceedsBufferException builds the SIZE_LIMIT protocol exception
// reported when a decoded length or element count cannot fit in what is left
// of the buffer.
//
// size is the decoded length/count and remain is the number of bytes still
// available; both are embedded in the error message. This differs from
// io.ErrShortBuffer, which only says more bytes are needed: here the size
// field itself is implausible, so the input is most likely corrupted.
func newSizeExceedsBufferException(size, remain int) error {
	msg := fmt.Sprintf("decoded size %d exceeds remaining buffer %d, data may be corrupted", size, remain)
	return thrift.NewProtocolException(thrift.SIZE_LIMIT, msg)
}

func newTypeMismatch(expect, got ttype) error {
return thrift.NewProtocolException(
thrift.INVALID_DATA,
Expand Down
Loading