Skip to content

Commit

Permalink
grpc: fix message length checks when compression is enabled and maxRe…
Browse files Browse the repository at this point in the history
…ceiveMessageSize is MaxInt (grpc#7918)
  • Loading branch information
vinothkumarr227 authored Jan 23, 2025
1 parent 67bee55 commit 8cf8fd1
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 30 deletions.
69 changes: 39 additions & 30 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,30 +828,13 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
return nil, st.Err()
}

var size int
if pf.isCompressed() {
defer compressed.Free()

// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
var uncompressedBuf []byte
uncompressedBuf, err = dc.Do(compressed.Reader())
if err == nil {
out = mem.BufferSlice{mem.SliceBuffer(uncompressedBuf)}
}
size = len(uncompressedBuf)
} else {
out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool)
}
out, err = decompress(compressor, compressed, dc, maxReceiveMessageSize, p.bufferPool)
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if size > maxReceiveMessageSize {
out.Free()
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
return nil, err
}
} else {
out = compressed
Expand All @@ -866,20 +849,46 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
return out, nil
}

// Using compressor, decompress d, returning data and size.
// Optionally, if data will be over maxReceiveMessageSize, just return the size.
func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) {
dcReader, err := compressor.Decompress(d.Reader())
if err != nil {
return nil, 0, err
// decompress processes the given data by decompressing it using either a custom decompressor or a standard compressor.
// If a custom decompressor is provided, it takes precedence. The function validates that the decompressed data
// does not exceed the specified maximum size and returns an error if this limit is exceeded.
// On success, it returns the decompressed data. Otherwise, it returns an error if decompression fails or the data exceeds the size limit.
func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompressor, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) {
if dc != nil {
uncompressed, err := dc.Do(d.Reader())
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if len(uncompressed) > maxReceiveMessageSize {
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", len(uncompressed), maxReceiveMessageSize)
}
return mem.BufferSlice{mem.SliceBuffer(uncompressed)}, nil
}
if compressor != nil {
dcReader, err := compressor.Decompress(d.Reader())
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err)
}

out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool)
if err != nil {
out.Free()
return nil, 0, err
out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool)
if err != nil {
out.Free()
return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err)
}

if out.Len() == maxReceiveMessageSize && !atEOF(dcReader) {
out.Free()
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize)
}
return out, nil
}
return out, out.Len(), nil
return nil, status.Errorf(codes.Internal, "grpc: no decompressor available for compressed payload")
}

// atEOF reads data from r and returns true if zero bytes could be read and r.Read returns EOF.
func atEOF(dcReader io.Reader) bool {
n, err := dcReader.Read(make([]byte, 1))
return n == 0 && err == io.EOF
}

type recvCompressor interface {
Expand Down
127 changes: 127 additions & 0 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@ package grpc
import (
"bytes"
"compress/gzip"
"errors"
"io"
"math"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/gzip"
protoenc "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/transport"
Expand All @@ -36,6 +41,11 @@ import (
"google.golang.org/protobuf/proto"
)

const (
defaultDecompressedData = "default decompressed data"
decompressionErrorMsg = "invalid compression format"
)

type fullReader struct {
data []byte
}
Expand Down Expand Up @@ -294,3 +304,120 @@ func BenchmarkGZIPCompressor512KiB(b *testing.B) {
func BenchmarkGZIPCompressor1MiB(b *testing.B) {
bmCompressor(b, 1024*1024, NewGZIPCompressor())
}

// compressWithDeterministicError compresses the input data and returns a BufferSlice.
func compressWithDeterministicError(t *testing.T, input []byte) mem.BufferSlice {
t.Helper()
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
if _, err := gz.Write(input); err != nil {
t.Fatalf("compressInput() failed to write data: %v", err)
}
if err := gz.Close(); err != nil {
t.Fatalf("compressInput() failed to close gzip writer: %v", err)
}
compressedData := buf.Bytes()
return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)}
}

// MockDecompressor is a mock implementation of a decompressor used for testing purposes.
// It simulates decompression behavior, returning either decompressed data or an error based on the ShouldError flag.
type MockDecompressor struct {
ShouldError bool // Flag to control whether the decompression should simulate an error.
}

// Do simulates decompression. It returns a predefined error if ShouldError is true,
// or a fixed set of decompressed data if ShouldError is false.
func (m *MockDecompressor) Do(_ io.Reader) ([]byte, error) {
if m.ShouldError {
return nil, errors.New(decompressionErrorMsg)
}
return []byte(defaultDecompressedData), nil
}

// Type returns the string identifier for the MockDecompressor.
func (m *MockDecompressor) Type() string {
return "MockDecompressor"
}

// TestDecompress tests the decompress function behaves correctly for following scenarios
// decompress successfully when message is <= maxReceiveMessageSize
// errors when message > maxReceiveMessageSize
// decompress successfully when maxReceiveMessageSize is MaxInt
// errors when the decompressed message has an invalid format
// errors when the decompressed message exceeds the maxReceiveMessageSize.
func (s) TestDecompress(t *testing.T) {
compressor := encoding.GetCompressor("gzip")
validDecompressor := &MockDecompressor{ShouldError: false}
invalidFormatDecompressor := &MockDecompressor{ShouldError: true}

testCases := []struct {
name string
input mem.BufferSlice
dc Decompressor
maxReceiveMessageSize int
want []byte
wantErr error
}{
{
name: "Decompresses successfully with sufficient buffer size",
input: compressWithDeterministicError(t, []byte("decompressed data")),
dc: nil,
maxReceiveMessageSize: 50,
want: []byte("decompressed data"),
wantErr: nil,
},
{
name: "Fails due to exceeding maxReceiveMessageSize",
input: compressWithDeterministicError(t, []byte("message that is too large")),
dc: nil,
maxReceiveMessageSize: len("message that is too large") - 1,
want: nil,
wantErr: status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", len("message that is too large")-1),
},
{
name: "Decompresses to exactly maxReceiveMessageSize",
input: compressWithDeterministicError(t, []byte("exact size message")),
dc: nil,
maxReceiveMessageSize: len("exact size message"),
want: []byte("exact size message"),
wantErr: nil,
},
{
name: "Decompresses successfully with maxReceiveMessageSize MaxInt",
input: compressWithDeterministicError(t, []byte("large message")),
dc: nil,
maxReceiveMessageSize: math.MaxInt,
want: []byte("large message"),
wantErr: nil,
},
{
name: "Fails with decompression error due to invalid format",
input: compressWithDeterministicError(t, []byte("invalid compressed data")),
dc: invalidFormatDecompressor,
maxReceiveMessageSize: 50,
want: nil,
wantErr: status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", errors.New(decompressionErrorMsg)),
},
{
name: "Fails with resourceExhausted error when decompressed message exceeds maxReceiveMessageSize",
input: compressWithDeterministicError(t, []byte("large compressed data")),
dc: validDecompressor,
maxReceiveMessageSize: 20,
want: nil,
wantErr: status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", 25, 20),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
output, err := decompress(compressor, tc.input, tc.dc, tc.maxReceiveMessageSize, mem.DefaultBufferPool())
if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) {
t.Fatalf("decompress() err = %v, wantErr = %v", err, tc.wantErr)
}
if !cmp.Equal(tc.want, output.Materialize()) {
t.Fatalf("decompress() output mismatch: got = %v, want = %v", output.Materialize(), tc.want)
}
})
}
}

0 comments on commit 8cf8fd1

Please sign in to comment.