diff --git a/internal/transport/controlbuf.go b/internal/transport/controlbuf.go index 752d4e8f8562..0f9fd70461df 100644 --- a/internal/transport/controlbuf.go +++ b/internal/transport/controlbuf.go @@ -496,6 +496,16 @@ const ( serverSide ) +// maxWriteBufSize is the maximum length (number of elements) the cached +// writeBuf can grow to. The length depends on the number of buffers +// contained within the BufferSlice produced by the codec, which is +// generally small. +// +// If a writeBuf larger than this limit is required, it will be allocated +// and freed after use, rather than being cached. This avoids holding +// on to large amounts of memory. +const maxWriteBufSize = 64 + // Loopy receives frames from the control buffer. // Each frame is handled individually; most of the work done by loopy goes // into handling data frames. Loopy maintains a queue of active streams, and each @@ -530,6 +540,8 @@ type loopyWriter struct { // Side-specific handlers ssGoAwayHandler func(*goAway) (bool, error) + + writeBuf [][]byte // cached slice to avoid heap allocations for calls to mem.Reader.Peek. } func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter { @@ -962,11 +974,11 @@ func (l *loopyWriter) processData() (bool, error) { if len(dataItem.h) == 0 && reader.Remaining() == 0 { // Empty data frame // Client sends out empty data frame with endStream = true - if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { + if err := l.framer.writeData(dataItem.streamID, dataItem.endStream, nil); err != nil { return false, err } str.itl.dequeue() // remove the empty data item from stream - _ = reader.Close() + reader.Close() if str.itl.isEmpty() { str.state = empty } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. @@ -999,25 +1011,20 @@ func (l *loopyWriter) processData() (bool, error) { remainingBytes := len(dataItem.h) + reader.Remaining() - hSize - dSize size := hSize + dSize - var buf *[]byte - - if hSize != 0 && dSize == 0 { - buf = &dataItem.h - } else { - // Note: this is only necessary because the http2.Framer does not support - // partially writing a frame, so the sequence must be materialized into a buffer. - // TODO: Revisit once https://github.com/golang/go/issues/66655 is addressed. - pool := l.bufferPool - if pool == nil { - // Note that this is only supposed to be nil in tests. Otherwise, stream is - // always initialized with a BufferPool. - pool = mem.DefaultBufferPool() + l.writeBuf = l.writeBuf[:0] + if hSize > 0 { + l.writeBuf = append(l.writeBuf, dataItem.h[:hSize]) + } + if dSize > 0 { + var err error + l.writeBuf, err = reader.Peek(dSize, l.writeBuf) + if err != nil { + // This must never happen since the reader must have at least dSize + // bytes. + clear(l.writeBuf) + l.writeBuf = nil + return false, err } - buf = pool.Get(size) - defer pool.Put(buf) - - copy((*buf)[:hSize], dataItem.h) - _, _ = reader.Read((*buf)[hSize:]) } // Now that outgoing flow controls are checked we can replenish str's write quota @@ -1030,7 +1037,14 @@ func (l *loopyWriter) processData() (bool, error) { if dataItem.onEachWrite != nil { dataItem.onEachWrite() } - if err := l.framer.fr.WriteData(dataItem.streamID, endStream, (*buf)[:size]); err != nil { + err := l.framer.writeData(dataItem.streamID, endStream, l.writeBuf) + reader.Discard(dSize) + if cap(l.writeBuf) > maxWriteBufSize { + l.writeBuf = nil + } else { + clear(l.writeBuf) + } + if err != nil { return false, err } str.bytesOutStanding += size @@ -1038,7 +1052,7 @@ func (l *loopyWriter) processData() (bool, error) { dataItem.h = dataItem.h[hSize:] if remainingBytes == 0 { // All the data from that message was written out. - _ = reader.Close() + reader.Close() str.itl.dequeue() } if str.itl.isEmpty() { diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index e3663f87f391..a1ef2057bf1d 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -389,8 +389,9 @@ func toIOError(err error) error { } type framer struct { - writer *bufWriter - fr *http2.Framer + writer *bufWriter + fr *http2.Framer + headerBuf []byte // cached slice for framer headers to reduce heap allocs. } var writeBufferPoolMap = make(map[int]*sync.Pool) @@ -422,6 +423,41 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu return f } +// writeData writes a DATA frame. +// +// It is the caller's responsibility not to violate the maximum frame size. +func (f *framer) writeData(streamID uint32, endStream bool, data [][]byte) error { + var flags http2.Flags + if endStream { + flags = http2.FlagDataEndStream + } + length := uint32(0) + for _, d := range data { + length += uint32(len(d)) + } + // TODO: Replace the header write with the framer API being added in + // https://github.com/golang/go/issues/66655. + f.headerBuf = append(f.headerBuf[:0], + byte(length>>16), + byte(length>>8), + byte(length), + byte(http2.FrameData), + byte(flags), + byte(streamID>>24), + byte(streamID>>16), + byte(streamID>>8), + byte(streamID)) + if _, err := f.writer.Write(f.headerBuf); err != nil { + return err + } + for _, d := range data { + if _, err := f.writer.Write(d); err != nil { + return err + } + } + return nil +} + func getWriteBufferPool(size int) *sync.Pool { writeBufferMutex.Lock() defer writeBufferMutex.Unlock() diff --git a/mem/buffer_slice.go b/mem/buffer_slice.go index 9fcb12b989e5..084fb19c6d15 100644 --- a/mem/buffer_slice.go +++ b/mem/buffer_slice.go @@ -19,6 +19,7 @@ package mem import ( + "fmt" "io" ) @@ -126,9 +127,10 @@ func (s BufferSlice) Reader() *Reader { } // Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface -// with other parts systems. It also provides an additional convenience method -// Remaining(), which returns the number of unread bytes remaining in the slice. +// with other systems. +// // Buffers will be freed as they are read. +// // A Reader can be constructed from a BufferSlice; alternatively the zero value // of a Reader may be used after calling Reset on it. type Reader struct { @@ -285,3 +287,59 @@ nextBuffer: } } } + +// Discard skips the next n bytes, returning the number of bytes discarded. +// +// It frees buffers as they are fully consumed. +// +// If Discard skips fewer than n bytes, it also returns an error. +func (r *Reader) Discard(n int) (discarded int, err error) { + total := n + for n > 0 && r.len > 0 { + curData := r.data[0].ReadOnlyData() + curSize := min(n, len(curData)-r.bufferIdx) + n -= curSize + r.len -= curSize + r.bufferIdx += curSize + if r.bufferIdx >= len(curData) { + r.data[0].Free() + r.data = r.data[1:] + r.bufferIdx = 0 + } + } + discarded = total - n + if n > 0 { + return discarded, fmt.Errorf("insufficient bytes in reader") + } + return discarded, nil +} + +// Peek returns the next n bytes without advancing the reader. +// +// Peek appends results to the provided res slice and returns the updated slice. +// This pattern allows re-using the storage of res if it has sufficient +// capacity. +// +// The returned subslices are views into the underlying buffers and are only +// valid until the reader is advanced past the corresponding buffer. +// +// If Peek returns fewer than n bytes, it also returns an error. +func (r *Reader) Peek(n int, res [][]byte) ([][]byte, error) { + for i := 0; n > 0 && i < len(r.data); i++ { + curData := r.data[i].ReadOnlyData() + start := 0 + if i == 0 { + start = r.bufferIdx + } + curSize := min(n, len(curData)-start) + if curSize == 0 { + continue + } + res = append(res, curData[start:start+curSize]) + n -= curSize + } + if n > 0 { + return nil, fmt.Errorf("insufficient bytes in reader") + } + return res, nil +} diff --git a/mem/buffer_slice_test.go b/mem/buffer_slice_test.go index bb9303f0e9e1..822acbda87da 100644 --- a/mem/buffer_slice_test.go +++ b/mem/buffer_slice_test.go @@ -484,3 +484,193 @@ func (t *testPool) Put(buf *[]byte) { } delete(t.allocated, buf) } + +func (s) TestBufferSlice_Iteration(t *testing.T) { + tests := []struct { + name string + buffers [][]byte + operations func(t *testing.T, c *mem.Reader) + }{ + { + name: "empty", + operations: func(t *testing.T, r *mem.Reader) { + if r.Remaining() != 0 { + t.Fatalf("Remaining() = %v, want 0", r.Remaining()) + } + _, err := r.Peek(1, nil) + if err == nil { + t.Fatalf("Peek(1) returned error , want non-nil") + } + discarded, err := r.Discard(1) + if got, want := discarded, 0; got != want { + t.Fatalf("Discard(1) = %d, want %d", got, want) + } + if err == nil { + t.Fatalf("Discard(1) returned error , want non-nil") + } + if r.Remaining() != 0 { + t.Fatalf("Remaining() after Discard = %v, want 0", r.Remaining()) + } + }, + }, + { + name: "single_buffer", + buffers: [][]byte{[]byte("0123456789")}, + operations: func(t *testing.T, r *mem.Reader) { + if r.Remaining() != 10 { + t.Fatalf("Remaining() = %v, want 10", r.Remaining()) + } + + res := make([][]byte, 0, 10) + res, err := r.Peek(5, res) + if err != nil { + t.Fatalf("Peek(5) return error %v, want ", err) + } + if len(res) != 1 || !bytes.Equal(res[0], []byte("01234")) { + t.Fatalf("Peek(5) = %v, want [[01234]]", res) + } + if cap(res) != 10 { + t.Fatalf("Peek(5) did not use the provided slice.") + } + + discarded, err := r.Discard(5) + if got, want := discarded, 5; got != want { + t.Fatalf("Discard(5) = %d, want %d", got, want) + } + if err != nil { + t.Fatalf("Discard(5) return error %v, want ", err) + } + if r.Remaining() != 5 { + t.Fatalf("Remaining() after Discard(5) = %v, want 5", r.Remaining()) + } + res, err = r.Peek(5, res[:0]) + if err != nil { + t.Fatalf("Peek(5) return error %v, want ", err) + } + if len(res) != 1 || !bytes.Equal(res[0], []byte("56789")) { + t.Fatalf("Peek(5) after Discard(5) = %v, want [[56789]]", res) + } + + discarded, err = r.Discard(100) + if got, want := discarded, 5; got != want { + t.Fatalf("Discard(100) = %d, want %d", got, want) + } + if err == nil { + t.Fatalf("Discard(100) returned error , want non-nil") + } + if r.Remaining() != 0 { + t.Fatalf("Remaining() after Discard(100) = %v, want 0", r.Remaining()) + } + }, + }, + { + name: "multiple_buffers", + buffers: [][]byte{[]byte("012"), []byte("345"), []byte("6789")}, + operations: func(t *testing.T, r *mem.Reader) { + if r.Remaining() != 10 { + t.Fatalf("Remaining() = %v, want 10", r.Remaining()) + } + + res, err := r.Peek(5, nil) + if err != nil { + t.Fatalf("Peek(5) return error %v, want ", err) + } + if len(res) != 2 || !bytes.Equal(res[0], []byte("012")) || !bytes.Equal(res[1], []byte("34")) { + t.Fatalf("Peek(5) = %v, want [[012] [34]]", res) + } + + discarded, err := r.Discard(5) + if got, want := discarded, 5; got != want { + t.Fatalf("Discard(5) = %d, want %d", got, want) + } + if err != nil { + t.Fatalf("Discard(5) return error %v, want ", err) + } + if r.Remaining() != 5 { + t.Fatalf("Remaining() after Discard(5) = %v, want 5", r.Remaining()) + } + + res, err = r.Peek(5, res[:0]) + if err != nil { + t.Fatalf("Peek(5) return error %v, want ", err) + } + if len(res) != 2 || !bytes.Equal(res[0], []byte("5")) || !bytes.Equal(res[1], []byte("6789")) { + t.Fatalf("Peek(5) after advance = %v, want [[5] [6789]]", res) + } + }, + }, + { + name: "close", + buffers: [][]byte{[]byte("0123456789")}, + operations: func(t *testing.T, r *mem.Reader) { + r.Close() + if r.Remaining() != 0 { + t.Fatalf("Remaining() after Close = %v, want 0", r.Remaining()) + } + }, + }, + { + name: "reset", + buffers: [][]byte{[]byte("0123")}, + operations: func(t *testing.T, r *mem.Reader) { + newSlice := mem.BufferSlice{mem.SliceBuffer([]byte("56789"))} + r.Reset(newSlice) + if r.Remaining() != 5 { + t.Fatalf("Remaining() after Reset = %v, want 5", r.Remaining()) + } + res, err := r.Peek(5, nil) + if err != nil { + t.Fatalf("Peek(5) return error %v, want ", err) + } + if len(res) != 1 || !bytes.Equal(res[0], []byte("56789")) { + t.Fatalf("Peek(5) after Reset = %v, want [[56789]]", res) + } + }, + }, + { + name: "zero_ops", + buffers: [][]byte{[]byte("01234")}, + operations: func(t *testing.T, c *mem.Reader) { + if c.Remaining() != 5 { + t.Fatalf("Remaining() = %v, want 5", c.Remaining()) + } + res, err := c.Peek(0, nil) + if err != nil { + t.Fatalf("Peek(0) return error %v, want ", err) + } + if len(res) != 0 { + t.Fatalf("Peek(0) got slices: %v, want empty", res) + } + discarded, err := c.Discard(0) + if err != nil { + t.Fatalf("Discard(0) return error %v, want ", err) + } + if got, want := discarded, 0; got != want { + t.Fatalf("Discard(0) = %d, want %d", got, want) + } + if c.Remaining() != 5 { + t.Fatalf("Remaining() after Discard(0) = %v, want 5", c.Remaining()) + } + res, err = c.Peek(2, res[:0]) + if err != nil { + t.Fatalf("Peek(2) return error %v, want ", err) + } + if len(res) != 1 || !bytes.Equal(res[0], []byte("01")) { + t.Fatalf("Peek(2) after zero ops = %v, want [[01]]", res) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var slice mem.BufferSlice + for _, b := range tt.buffers { + slice = append(slice, mem.SliceBuffer(b)) + } + c := slice.Reader() + slice.Free() + defer c.Close() + tt.operations(t, c) + }) + } +}