diff --git a/client.go b/client.go index 61772b6..57bf8c0 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ package sse import ( + "bufio" "bytes" "context" "encoding/base64" @@ -32,6 +33,12 @@ func ClientMaxBufferSize(s int) func(c *Client) { } } +func SplitFunc(f bufio.SplitFunc) func(c *Client) { + return func(c *Client) { + c.split = f + } +} + // ConnCallback defines a function to be called on a particular connection event type ConnCallback func(c *Client) @@ -55,6 +62,7 @@ type Client struct { mu sync.Mutex EncodingBase64 bool Connected bool + split bufio.SplitFunc } // NewClient creates a new client @@ -97,7 +105,7 @@ func (c *Client) SubscribeWithContext(ctx context.Context, stream string, handle } defer resp.Body.Close() - reader := NewEventStreamReader(resp.Body, c.maxBufferSize) + reader := NewEventStreamReader(resp.Body, c.maxBufferSize, c.split) eventChan, errorChan := c.startReadLoop(reader) for { @@ -155,7 +163,7 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch connected = true } - reader := NewEventStreamReader(resp.Body, c.maxBufferSize) + reader := NewEventStreamReader(resp.Body, c.maxBufferSize, c.split) eventChan, errorChan := c.startReadLoop(reader) for { @@ -387,4 +395,3 @@ func trimHeader(size int, data []byte) []byte { data = data[:len(data)-1] } return data -} diff --git a/event.go b/event.go index 1258038..421dc8b 100644 --- a/event.go +++ b/event.go @@ -32,26 +32,29 @@ type EventStreamReader struct { } // NewEventStreamReader creates an instance of EventStreamReader. -func NewEventStreamReader(eventStream io.Reader, maxBufferSize int) *EventStreamReader { +func NewEventStreamReader(eventStream io.Reader, maxBufferSize int, split bufio.SplitFunc) *EventStreamReader { scanner := bufio.NewScanner(eventStream) initBufferSize := minPosInt(4096, maxBufferSize) scanner.Buffer(make([]byte, initBufferSize), maxBufferSize) - split := func(data []byte, atEOF bool) (int, []byte, error) { - if atEOF && len(data) == 0 { + if split == nil { + split = func(data []byte, atEOF bool) (int, []byte, error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + + // We have a full event payload to parse. + if i, nlen := containsDoubleNewline(data); i >= 0 { + return i + nlen, data[0:i], nil + } + // If we're at EOF, we have all of the data. + if atEOF { + return len(data), data, nil + } + // Request more data. return 0, nil, nil } - // We have a full event payload to parse. - if i, nlen := containsDoubleNewline(data); i >= 0 { - return i + nlen, data[0:i], nil - } - // If we're at EOF, we have all of the data. - if atEOF { - return len(data), data, nil - } - // Request more data. - return 0, nil, nil } // Set the split function for the scanning operation. scanner.Split(split)