From 8cf8fd1433873cb11e6bec22991a3b490d7b7b19 Mon Sep 17 00:00:00 2001 From: vinothkumarr227 Date: Fri, 24 Jan 2025 00:08:25 +0530 Subject: [PATCH] grpc: fix message length checks when compression is enabled and maxReceiveMessageSize is MaxInt (#7918) --- rpc_util.go | 69 ++++++++++++++----------- rpc_util_test.go | 127 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 30 deletions(-) diff --git a/rpc_util.go b/rpc_util.go index 9fac2b08b48b..94160b130897 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -828,30 +828,13 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM return nil, st.Err() } - var size int if pf.isCompressed() { defer compressed.Free() - // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // use this decompressor as the default. - if dc != nil { - var uncompressedBuf []byte - uncompressedBuf, err = dc.Do(compressed.Reader()) - if err == nil { - out = mem.BufferSlice{mem.SliceBuffer(uncompressedBuf)} - } - size = len(uncompressedBuf) - } else { - out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool) - } + out, err = decompress(compressor, compressed, dc, maxReceiveMessageSize, p.bufferPool) if err != nil { - return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) - } - if size > maxReceiveMessageSize { - out.Free() - // TODO: Revisit the error code. Currently keep it consistent with java - // implementation. - return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize) + return nil, err } } else { out = compressed @@ -866,20 +849,46 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM return out, nil } -// Using compressor, decompress d, returning data and size. -// Optionally, if data will be over maxReceiveMessageSize, just return the size. -func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) { - dcReader, err := compressor.Decompress(d.Reader()) - if err != nil { - return nil, 0, err +// decompress processes the given data by decompressing it using either a custom decompressor or a standard compressor. +// If a custom decompressor is provided, it takes precedence. The function validates that the decompressed data +// does not exceed the specified maximum size and returns an error if this limit is exceeded. +// On success, it returns the decompressed data. Otherwise, it returns an error if decompression fails or the data exceeds the size limit. +func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompressor, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) { + if dc != nil { + uncompressed, err := dc.Do(d.Reader()) + if err != nil { + return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) + } + if len(uncompressed) > maxReceiveMessageSize { + return nil, status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", len(uncompressed), maxReceiveMessageSize) + } + return mem.BufferSlice{mem.SliceBuffer(uncompressed)}, nil } + if compressor != nil { + dcReader, err := compressor.Decompress(d.Reader()) + if err != nil { + return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err) + } - out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool) - if err != nil { - out.Free() - return nil, 0, err + out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool) + if err != nil { + out.Free() + return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err) + } + + if out.Len() == maxReceiveMessageSize && !atEOF(dcReader) { + out.Free() + return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize) + } + return out, nil } - return out, out.Len(), nil + return nil, status.Errorf(codes.Internal, "grpc: no decompressor available for compressed payload") +} + +// atEOF reads data from r and returns true if zero bytes could be read and r.Read returns EOF. +func atEOF(dcReader io.Reader) bool { + n, err := dcReader.Read(make([]byte, 1)) + return n == 0 && err == io.EOF } type recvCompressor interface { diff --git a/rpc_util_test.go b/rpc_util_test.go index 94f50bc24ade..608cc1002471 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -21,12 +21,17 @@ package grpc import ( "bytes" "compress/gzip" + "errors" "io" "math" "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" + _ "google.golang.org/grpc/encoding/gzip" protoenc "google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/transport" @@ -36,6 +41,11 @@ import ( "google.golang.org/protobuf/proto" ) +const ( + defaultDecompressedData = "default decompressed data" + decompressionErrorMsg = "invalid compression format" +) + type fullReader struct { data []byte } @@ -294,3 +304,120 @@ func BenchmarkGZIPCompressor512KiB(b *testing.B) { func BenchmarkGZIPCompressor1MiB(b *testing.B) { bmCompressor(b, 1024*1024, NewGZIPCompressor()) } + +// compressWithDeterministicError compresses the input data and returns a BufferSlice. +func compressWithDeterministicError(t *testing.T, input []byte) mem.BufferSlice { + t.Helper() + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + if _, err := gz.Write(input); err != nil { + t.Fatalf("compressInput() failed to write data: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("compressInput() failed to close gzip writer: %v", err) + } + compressedData := buf.Bytes() + return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)} +} + +// MockDecompressor is a mock implementation of a decompressor used for testing purposes. +// It simulates decompression behavior, returning either decompressed data or an error based on the ShouldError flag. +type MockDecompressor struct { + ShouldError bool // Flag to control whether the decompression should simulate an error. +} + +// Do simulates decompression. It returns a predefined error if ShouldError is true, +// or a fixed set of decompressed data if ShouldError is false. +func (m *MockDecompressor) Do(_ io.Reader) ([]byte, error) { + if m.ShouldError { + return nil, errors.New(decompressionErrorMsg) + } + return []byte(defaultDecompressedData), nil +} + +// Type returns the string identifier for the MockDecompressor. +func (m *MockDecompressor) Type() string { + return "MockDecompressor" +} + +// TestDecompress tests the decompress function behaves correctly for following scenarios +// decompress successfully when message is <= maxReceiveMessageSize +// errors when message > maxReceiveMessageSize +// decompress successfully when maxReceiveMessageSize is MaxInt +// errors when the decompressed message has an invalid format +// errors when the decompressed message exceeds the maxReceiveMessageSize. +func (s) TestDecompress(t *testing.T) { + compressor := encoding.GetCompressor("gzip") + validDecompressor := &MockDecompressor{ShouldError: false} + invalidFormatDecompressor := &MockDecompressor{ShouldError: true} + + testCases := []struct { + name string + input mem.BufferSlice + dc Decompressor + maxReceiveMessageSize int + want []byte + wantErr error + }{ + { + name: "Decompresses successfully with sufficient buffer size", + input: compressWithDeterministicError(t, []byte("decompressed data")), + dc: nil, + maxReceiveMessageSize: 50, + want: []byte("decompressed data"), + wantErr: nil, + }, + { + name: "Fails due to exceeding maxReceiveMessageSize", + input: compressWithDeterministicError(t, []byte("message that is too large")), + dc: nil, + maxReceiveMessageSize: len("message that is too large") - 1, + want: nil, + wantErr: status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", len("message that is too large")-1), + }, + { + name: "Decompresses to exactly maxReceiveMessageSize", + input: compressWithDeterministicError(t, []byte("exact size message")), + dc: nil, + maxReceiveMessageSize: len("exact size message"), + want: []byte("exact size message"), + wantErr: nil, + }, + { + name: "Decompresses successfully with maxReceiveMessageSize MaxInt", + input: compressWithDeterministicError(t, []byte("large message")), + dc: nil, + maxReceiveMessageSize: math.MaxInt, + want: []byte("large message"), + wantErr: nil, + }, + { + name: "Fails with decompression error due to invalid format", + input: compressWithDeterministicError(t, []byte("invalid compressed data")), + dc: invalidFormatDecompressor, + maxReceiveMessageSize: 50, + want: nil, + wantErr: status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", errors.New(decompressionErrorMsg)), + }, + { + name: "Fails with resourceExhausted error when decompressed message exceeds maxReceiveMessageSize", + input: compressWithDeterministicError(t, []byte("large compressed data")), + dc: validDecompressor, + maxReceiveMessageSize: 20, + want: nil, + wantErr: status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", 25, 20), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + output, err := decompress(compressor, tc.input, tc.dc, tc.maxReceiveMessageSize, mem.DefaultBufferPool()) + if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) { + t.Fatalf("decompress() err = %v, wantErr = %v", err, tc.wantErr) + } + if !cmp.Equal(tc.want, output.Materialize()) { + t.Fatalf("decompress() output mismatch: got = %v, want = %v", output.Materialize(), tc.want) + } + }) + } +}