diff --git a/connection.go b/connection.go index 9cd90fa0..3de549d3 100644 --- a/connection.go +++ b/connection.go @@ -24,6 +24,7 @@ package ouroboros import ( + "context" "errors" "fmt" "io" @@ -78,6 +79,8 @@ type Connection struct { delayProtocolStart bool fullDuplex bool peerSharingEnabled bool + ctx context.Context + cancelCtx context.CancelFunc // Mini-protocols blockFetch *blockfetch.BlockFetch blockFetchConfig *blockfetch.Config @@ -266,6 +269,52 @@ func (c *Connection) shutdown() { close(c.errorChan) } +// handleConnectionError handles connection-level errors centrally +func (c *Connection) handleConnectionError(err error) error { + if err == nil { + return nil + } + + // Only propagate EOF errors when acting as a client with active server-side protocols + if errors.Is(err, io.EOF) { + // Check if we have any active server-side protocols + if c.server { + return err + } + + // For clients, only propagate EOF if we have active server protocols + hasActiveServerProtocols := false + if c.chainSync != nil && c.chainSync.Server != nil && !c.chainSync.Server.IsDone() { + hasActiveServerProtocols = true + } + if c.blockFetch != nil && c.blockFetch.Server != nil && !c.blockFetch.Server.IsDone() { + hasActiveServerProtocols = true + } + if c.txSubmission != nil && c.txSubmission.Server != nil && !c.txSubmission.Server.IsDone() { + hasActiveServerProtocols = true + } + if c.localStateQuery != nil && c.localStateQuery.Server != nil && !c.localStateQuery.Server.IsDone() { + hasActiveServerProtocols = true + } + if c.localTxMonitor != nil && c.localTxMonitor.Server != nil && !c.localTxMonitor.Server.IsDone() { + hasActiveServerProtocols = true + } + if c.localTxSubmission != nil && c.localTxSubmission.Server != nil && !c.localTxSubmission.Server.IsDone() { + hasActiveServerProtocols = true + } + + if hasActiveServerProtocols { + return err + } + + // EOF with no active server protocols is normal connection closure + return nil + } + + // For non-EOF errors, always propagate + return err +} + // setupConnection establishes the muxer, configures and starts the handshake process, and initializes // the appropriate mini-protocols func (c *Connection) setupConnection() error { @@ -276,10 +325,13 @@ func (c *Connection) setupConnection() error { c.networkMagic, ) } + // Create context for connection + c.ctx, c.cancelCtx = context.WithCancel(context.Background()) // Start Goroutine to shutdown when doneChan is closed c.doneChan = make(chan any) go func() { <-c.doneChan + c.cancelCtx() c.shutdown() }() // Populate connection ID @@ -301,16 +353,20 @@ func (c *Connection) setupConnection() error { if !ok { return } - var connErr *muxer.ConnectionClosedError - if errors.As(err, &connErr) { - // Pass through ConnectionClosedError from muxer - c.errorChan <- err - } else { - // Wrap error message to denote it comes from the muxer - c.errorChan <- fmt.Errorf("muxer error: %w", err) + + // Use centralized connection error handling + if handledErr := c.handleConnectionError(err); handledErr != nil { + var connErr *muxer.ConnectionClosedError + if errors.As(handledErr, &connErr) { + // Pass through ConnectionClosedError from muxer + c.errorChan <- handledErr + } else { + // Wrap error message to denote it comes from the muxer + c.errorChan <- fmt.Errorf("muxer error: %w", handledErr) + } + // Close connection on muxer errors + c.Close() } - // Close connection on muxer errors - c.Close() } }() protoOptions := protocol.ProtocolOptions{ @@ -318,6 +374,7 @@ func (c *Connection) setupConnection() error { Muxer: c.muxer, Logger: c.logger, ErrorChan: c.protoErrorChan, + Context: c.ctx, } if c.useNodeToNodeProto { protoOptions.Mode = protocol.ProtocolModeNodeToNode diff --git a/connection_test.go b/connection_test.go index 412b72d2..a8d936f6 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2025 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,70 +15,339 @@ package ouroboros_test import ( - "fmt" + "strings" "testing" "time" ouroboros "github.com/blinklabs-io/gouroboros" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/chainsync" + "github.com/blinklabs-io/gouroboros/protocol/common" + "github.com/blinklabs-io/gouroboros/protocol/handshake" ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" "go.uber.org/goleak" ) -// Ensure that we don't panic when closing the Connection object after a failed Dial() call -func TestDialFailClose(t *testing.T) { +// TestErrorHandlingWithActiveProtocols tests that connection errors are propagated +// when protocols are active, and ignored when protocols are stopped +func TestErrorHandlingWithActiveProtocols(t *testing.T) { defer goleak.VerifyNone(t) - oConn, err := ouroboros.New() - if err != nil { - t.Fatalf("unexpected error when creating Connection object: %s", err) - } - err = oConn.Dial("unix", "/path/does/not/exist") - if err == nil { - t.Fatalf("did not get expected failure on Dial()") - } - // Close connection - oConn.Close() + + t.Run("ErrorsPropagatedWhenProtocolsActive", func(t *testing.T) { + // Create a mock connection that will complete handshake and start the chainsync protocol + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + // MsgProposeVersions from mock client + ouroboros_mock.ConversationEntryOutput{ + ProtocolId: handshake.ProtocolId, + Messages: []protocol.Message{ + handshake.NewMsgProposeVersions( + protocol.ProtocolVersionMap{ + (10 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14( + ouroboros_mock.MockNetworkMagic, + ), + }, + ), + }, + }, + // MsgAcceptVersion from server + ouroboros_mock.ConversationEntryInput{ + ProtocolId: handshake.ProtocolId, + IsResponse: true, + MsgFromCborFunc: handshake.NewMsgFromCbor, + Message: handshake.NewMsgAcceptVersion( + (10 + protocol.ProtocolVersionNtCOffset), + protocol.VersionDataNtC9to14( + ouroboros_mock.MockNetworkMagic, + ), + ), + }, + // ChainSync messages + ouroboros_mock.ConversationEntryOutput{ + ProtocolId: chainsync.ProtocolIdNtC, + Messages: []protocol.Message{ + chainsync.NewMsgFindIntersect( + []common.Point{ + { + Slot: 21600, + Hash: []byte("19297addad3da631einos029"), + }, + }, + ), + }, + }, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithServer(true), + ouroboros.WithChainSyncConfig( + chainsync.NewConfig( + chainsync.WithFindIntersectFunc( + func(ctx chainsync.CallbackContext, points []common.Point) (common.Point, chainsync.Tip, error) { + // Wait for shutdown instead of sleeping + <-ctx.Done() + return common.Point{}, chainsync.Tip{}, ctx.Err() + }, + ), + ), + ), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Wait for handshake to complete by checking if protocols are initialized + var chainSyncProtocol *chainsync.ChainSync + for i := 0; i < 100; i++ { + chainSyncProtocol = oConn.ChainSync() + if chainSyncProtocol != nil && chainSyncProtocol.Server != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + + if chainSyncProtocol == nil || chainSyncProtocol.Server == nil { + oConn.Close() + t.Fatal("chain sync protocol not initialized") + } + + // Wait a bit for protocol to start + time.Sleep(100 * time.Millisecond) + + // Close the mock connection first to trigger an error + mockConn.Close() + + // We should receive a connection error since protocols were active when error occurred + timeout := time.After(2 * time.Second) + for { + select { + case err, ok := <-oConn.ErrorChan(): + if !ok { + t.Log("Error channel closed") + goto done + } + if err == nil { + t.Error("received nil error") + continue + } + t.Logf("Received connection error (expected with active protocols): %s", err) + if strings.Contains(err.Error(), "EOF") || + strings.Contains(err.Error(), "use of closed network connection") { + goto done + } + case <-timeout: + t.Error("timed out waiting for connection error") + goto done + } + } + done: + // Clean up - wait for connection to fully shut down + oConn.Close() + // Give time for goroutines to clean up + time.Sleep(100 * time.Millisecond) + }) + + t.Run("ErrorsIgnoredWhenProtocolsStopped", func(t *testing.T) { + // Create a mock connection that will send a Done message to stop the protocol + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + // MsgProposeVersions from mock client + ouroboros_mock.ConversationEntryOutput{ + ProtocolId: handshake.ProtocolId, + Messages: []protocol.Message{ + handshake.NewMsgProposeVersions( + protocol.ProtocolVersionMap{ + (10 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14( + ouroboros_mock.MockNetworkMagic, + ), + }, + ), + }, + }, + // MsgAcceptVersion from server + ouroboros_mock.ConversationEntryInput{ + ProtocolId: handshake.ProtocolId, + IsResponse: true, + MsgFromCborFunc: handshake.NewMsgFromCbor, + Message: handshake.NewMsgAcceptVersion( + (10 + protocol.ProtocolVersionNtCOffset), + protocol.VersionDataNtC9to14( + ouroboros_mock.MockNetworkMagic, + ), + ), + }, + // Send Done message to stop the protocol + ouroboros_mock.ConversationEntryOutput{ + ProtocolId: chainsync.ProtocolIdNtC, + Messages: []protocol.Message{chainsync.NewMsgDone()}, + }, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithServer(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Wait for handshake to complete + var chainSyncProtocol *chainsync.ChainSync + for i := 0; i < 100; i++ { + chainSyncProtocol = oConn.ChainSync() + if chainSyncProtocol != nil && chainSyncProtocol.Server != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + + if chainSyncProtocol == nil || chainSyncProtocol.Server == nil { + oConn.Close() + t.Fatal("chain sync protocol not initialized") + } + + // Wait for protocol to be done (Done message from mock should trigger this) + select { + case <-chainSyncProtocol.Server.DoneChan(): + // Protocol is stopped + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for protocol to stop") + } + + // Close the mock connection + mockConn.Close() + + // With protocols stopped, we should either get no error or just connection closed errors + timeout := time.After(2 * time.Second) + for { + select { + case err, ok := <-oConn.ErrorChan(): + if !ok { + t.Log("Error channel closed") + goto done + } + if err != nil { + if !strings.Contains(err.Error(), "EOF") && + !strings.Contains(err.Error(), "use of closed network connection") { + t.Errorf("Unexpected error during shutdown: %s", err) + } + } + case <-time.After(500 * time.Millisecond): + t.Log("No connection error received (expected when protocols are stopped)") + goto done + case <-timeout: + t.Error("timed out waiting for connection cleanup") + goto done + } + } + done: + // Clean up + oConn.Close() + // Give time for goroutines to clean up + time.Sleep(100 * time.Millisecond) + }) } -func TestDoubleClose(t *testing.T) { +// TestErrorHandlingWithMultipleProtocols tests error handling with multiple active protocols +func TestErrorHandlingWithMultipleProtocols(t *testing.T) { defer goleak.VerifyNone(t) + mockConn := ouroboros_mock.NewConnection( ouroboros_mock.ProtocolRoleClient, []ouroboros_mock.ConversationEntry{ ouroboros_mock.ConversationEntryHandshakeRequestGeneric, - ouroboros_mock.ConversationEntryHandshakeNtCResponse, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, }, ) + oConn, err := ouroboros.New( ouroboros.WithConnection(mockConn), ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), ) if err != nil { t.Fatalf("unexpected error when creating Connection object: %s", err) } - // Async error handler - go func() { - err, ok := <-oConn.ErrorChan() - if !ok { - return + + // Wait for handshake to complete + time.Sleep(100 * time.Millisecond) + + // Close mock connection first to generate error + mockConn.Close() + + // Should receive error since protocols were active + timeout := time.After(2 * time.Second) + for { + select { + case err, ok := <-oConn.ErrorChan(): + if !ok { + t.Log("Error channel closed") + goto done + } + if err == nil { + t.Error("received nil error") + continue + } + t.Logf("Received connection error with multiple active protocols: %s", err) + if strings.Contains(err.Error(), "EOF") || + strings.Contains(err.Error(), "use of closed network connection") { + goto done + } + case <-timeout: + t.Error("timed out waiting for connection error") + goto done } - // We can't call t.Fatalf() from a different Goroutine, so we panic instead - panic(fmt.Sprintf("unexpected Ouroboros connection error: %s", err)) - }() - // Close connection - if err := oConn.Close(); err != nil { - t.Fatalf("unexpected error when closing Connection object: %s", err) } - // Close connection again - if err := oConn.Close(); err != nil { - t.Fatalf( - "unexpected error when closing Connection object again: %s", - err, +done: + // Clean up + oConn.Close() + // Give time for goroutines to clean up + time.Sleep(100 * time.Millisecond) +} + +// TestBasicErrorHandling tests basic error handling scenarios +func TestBasicErrorHandling(t *testing.T) { + defer goleak.VerifyNone(t) + + t.Run("DialFailure", func(t *testing.T) { + oConn, err := ouroboros.New( + ouroboros.WithNetworkMagic(764824073), ) - } - // Wait for connection shutdown - select { - case <-oConn.ErrorChan(): - case <-time.After(10 * time.Second): - t.Errorf("did not shutdown within timeout") - } + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + err = oConn.Dial("tcp", "invalid-hostname:9999") + if err == nil { + t.Fatal("expected dial error, got nil") + } + + oConn.Close() + }) + + t.Run("DoubleClose", func(t *testing.T) { + oConn, err := ouroboros.New( + ouroboros.WithNetworkMagic(764824073), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // First close + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error on first close: %s", err) + } + + // Second close should also work + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error on second close: %s", err) + } + }) } diff --git a/protocol/blockfetch/blockfetch_test.go b/protocol/blockfetch/blockfetch_test.go new file mode 100644 index 00000000..2a5cfbba --- /dev/null +++ b/protocol/blockfetch/blockfetch_test.go @@ -0,0 +1,166 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package blockfetch + +import ( + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/ledger" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/common" + "github.com/stretchr/testify/assert" +) + +// testAddr implements net.Addr for testing +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +// testConn implements net.Conn for testing with buffered writes +type testConn struct { + writeChan chan []byte + closed bool + closeChan chan struct{} +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Write(b []byte) (n int, err error) { + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} +func (c *testConn) Close() error { + if !c.closed { + close(c.closeChan) + c.closed = true + } + return nil +} +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + go mux.Start() + go func() { + <-conn.(*testConn).closeChan + mux.Stop() + }() + + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + } +} + +func TestNewBlockFetch(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + bf := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, bf.Client) + assert.NotNil(t, bf.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 5*time.Second, cfg.BatchStartTimeout) + assert.Equal(t, 60*time.Second, cfg.BlockTimeout) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithBatchStartTimeout(10*time.Second), + WithBlockTimeout(30*time.Second), + WithRecvQueueSize(100), + ) + assert.Equal(t, 10*time.Second, cfg.BatchStartTimeout) + assert.Equal(t, 30*time.Second, cfg.BlockTimeout) + assert.Equal(t, 100, cfg.RecvQueueSize) + }) +} + +func TestCallbackRegistration(t *testing.T) { + conn := newTestConn() + defer conn.Close() + + t.Run("Block callback registration", func(t *testing.T) { + blockFunc := func(ctx CallbackContext, slot uint, block ledger.Block) error { + return nil + } + cfg := NewConfig(WithBlockFunc(blockFunc)) + client := NewClient(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, client) + assert.NotNil(t, cfg.BlockFunc) + }) + + t.Run("RequestRange callback registration", func(t *testing.T) { + requestRangeFunc := func(ctx CallbackContext, start, end common.Point) error { + return nil + } + cfg := NewConfig(WithRequestRangeFunc(requestRangeFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestRangeFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + client := NewClient(getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + client.Start() + defer client.Stop() + + // Send a done message + err := client.SendMessage(NewMsgClientDone()) + assert.NoError(t, err) + + // Verify message was written to connection + select { + case <-conn.writeChan: + // Message was sent successfully + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message send") + } + }) +} diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index d3ad3eb7..7ce5c792 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -113,8 +113,12 @@ func (c *Client) Stop() error { "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) - msg := NewMsgClientDone() - err = c.SendMessage(msg) + if !c.IsDone() { + msg := NewMsgClientDone() + if err = c.SendMessage(msg); err != nil { + return + } + } }) return err } diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index 40ae8450..51e1433d 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -16,6 +16,7 @@ package chainsync import ( + "context" "fmt" "sync" "time" @@ -238,6 +239,7 @@ const ( // Callback context type CallbackContext struct { + context.Context ConnectionId connection.ConnectionId Client *Client Server *Server diff --git a/protocol/chainsync/chainsync_test.go b/protocol/chainsync/chainsync_test.go new file mode 100644 index 00000000..d660717d --- /dev/null +++ b/protocol/chainsync/chainsync_test.go @@ -0,0 +1,167 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package chainsync + +import ( + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/common" + "github.com/stretchr/testify/assert" +) + +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +type testConn struct { + writeChan chan []byte + closed bool + closeChan chan struct{} +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Write(b []byte) (n int, err error) { + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} +func (c *testConn) Close() error { + if !c.closed { + close(c.closeChan) + c.closed = true + } + return nil +} +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + go mux.Start() + go func() { + <-conn.(*testConn).closeChan + mux.Stop() + }() + + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + Mode: protocol.ProtocolModeNodeToClient, + } +} + +func TestNewChainSync(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + cs := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, cs.Client) + assert.NotNil(t, cs.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 5*time.Second, cfg.IntersectTimeout) + assert.Equal(t, 300*time.Second, cfg.BlockTimeout) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithIntersectTimeout(10*time.Second), + WithBlockTimeout(30*time.Second), + WithPipelineLimit(10), + WithRecvQueueSize(100), + ) + assert.Equal(t, 10*time.Second, cfg.IntersectTimeout) + assert.Equal(t, 30*time.Second, cfg.BlockTimeout) + assert.Equal(t, 10, cfg.PipelineLimit) + assert.Equal(t, 100, cfg.RecvQueueSize) + }) +} + +func TestCallbackRegistration(t *testing.T) { + conn := newTestConn() + defer conn.Close() + + t.Run("RollForward callback registration", func(t *testing.T) { + rollForwardFunc := func(ctx CallbackContext, blockType uint, blockData any, tip Tip) error { + return nil + } + cfg := NewConfig(WithRollForwardFunc(rollForwardFunc)) + client := NewClient(nil, getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, client) + assert.NotNil(t, cfg.RollForwardFunc) + }) + + t.Run("FindIntersect callback registration", func(t *testing.T) { + findIntersectFunc := func(ctx CallbackContext, points []common.Point) (common.Point, Tip, error) { + return common.Point{}, Tip{}, nil + } + cfg := NewConfig(WithFindIntersectFunc(findIntersectFunc)) + server := NewServer(nil, getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.FindIntersectFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + client := NewClient(nil, getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + client.Start() + defer client.Stop() + + // Send a done message + err := client.SendMessage(NewMsgDone()) + assert.NoError(t, err) + + // Verify message was written to connection + select { + case <-conn.writeChan: + // Message was sent successfully + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message send") + } + }) +} diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 0bf49a99..4e95cf66 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -81,6 +81,7 @@ func NewClient( wantIntersectFoundChan: make(chan chan<- clientPointResult, 1), } c.callbackContext = CallbackContext{ + Context: protoOptions.Context, Client: c, ConnectionId: protoOptions.ConnectionId, } @@ -147,9 +148,11 @@ func (c *Client) Stop() error { ) c.busyMutex.Lock() defer c.busyMutex.Unlock() - msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return + if !c.IsDone() { + msg := NewMsgDone() + if err = c.SendMessage(msg); err != nil { + return + } } }) return err diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index 73588334..a22f20c3 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -45,6 +45,7 @@ func NewServer( stateContext: stateContext, } s.callbackContext = CallbackContext{ + Context: protoOptions.Context, Server: s, ConnectionId: protoOptions.ConnectionId, } diff --git a/protocol/protocol.go b/protocol/protocol.go index f769131f..313a7eff 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -17,6 +17,7 @@ package protocol import ( "bytes" + "context" "errors" "fmt" "io" @@ -38,6 +39,7 @@ const DefaultRecvQueueSize = 50 // Protocol implements the base functionality of an Ouroboros mini-protocol type Protocol struct { config ProtocolConfig + currentState State doneChan chan struct{} muxerSendChan chan *muxer.Segment muxerRecvChan chan *muxer.Segment @@ -56,7 +58,6 @@ type Protocol struct { pendingRecvBytes int pendingRecvSizes []int // Track sizes of pending received messages for accurate decrement currentStateMu sync.RWMutex - currentState State } // ProtocolConfig provides the configuration for Protocol @@ -105,11 +106,13 @@ type ProtocolOptions struct { // TODO: remove me Role ProtocolRole Version uint16 + Context context.Context } type protocolStateTransition struct { - msg Message - errorChan chan<- error + msg Message + errorChan chan<- error + stateRespChan chan<- State } // MessageHandlerFunc represents a function that handles an incoming message @@ -132,6 +135,36 @@ func New(config ProtocolConfig) *Protocol { return p } +// CurrentState returns the current protocol state +func (p *Protocol) CurrentState() State { + p.currentStateMu.RLock() + defer p.currentStateMu.RUnlock() + return p.currentState +} + +// IsDone checks if the protocol is in a done/completed state +func (p *Protocol) IsDone() bool { + currentState := p.CurrentState() + // return true if current state has AgencyNone + if entry, exists := p.config.StateMap[currentState]; exists { + if entry.Agency == AgencyNone { + return true + } + } + // return true if current state is the initial state + return currentState == p.config.InitialState +} + +// GetDoneState returns the done state from the state map +func (s StateMap) GetDoneState() State { + for state, entry := range s { + if entry.Agency == AgencyNone { + return state + } + } + return State{} +} + // Start initializes the mini-protocol func (p *Protocol) Start() { p.onceStart.Do(func() { @@ -586,6 +619,7 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { ) } } + getTimerChan := func() <-chan time.Time { if transitionTimer == nil { return nil @@ -660,7 +694,7 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) { func (p *Protocol) transitionState(msg Message) error { errorChan := make(chan error, 1) - p.stateTransitionChan <- protocolStateTransition{msg, errorChan} + p.stateTransitionChan <- protocolStateTransition{msg, errorChan, nil} return <-errorChan } diff --git a/protocol/txsubmission/client.go b/protocol/txsubmission/client.go index aa6be02d..1d9ec33e 100644 --- a/protocol/txsubmission/client.go +++ b/protocol/txsubmission/client.go @@ -86,6 +86,7 @@ func (c *Client) Init() { func (c *Client) messageHandler(msg protocol.Message) error { c.Protocol.Logger(). Debug(fmt.Sprintf("%s: client message for %+v", ProtocolName, c.callbackContext.ConnectionId.RemoteAddr)) + var err error switch msg.Type() { case MessageTypeRequestTxIds: diff --git a/protocol/txsubmission/txsubmission_test.go b/protocol/txsubmission/txsubmission_test.go new file mode 100644 index 00000000..eb7e5e60 --- /dev/null +++ b/protocol/txsubmission/txsubmission_test.go @@ -0,0 +1,181 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package txsubmission + +import ( + "io" + "log/slog" + "net" + "sync" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +type testConn struct { + writeChan chan []byte + closed bool + closeChan chan struct{} + closeOnce sync.Once + mu sync.Mutex +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Close() error { + c.closeOnce.Do(func() { + c.mu.Lock() + defer c.mu.Unlock() + close(c.closeChan) + c.closed = true + }) + return nil +} +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func (c *testConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return 0, io.EOF + } + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + } +} + +func TestNewTxSubmission(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + ts := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, ts.Client) + assert.NotNil(t, ts.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 300*time.Second, cfg.IdleTimeout) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithIdleTimeout(60*time.Second), + WithRequestTxIdsFunc(func(ctx CallbackContext, blocking bool, ack, req uint16) ([]TxIdAndSize, error) { + return nil, nil + }), + WithRequestTxsFunc(func(ctx CallbackContext, txIds []TxId) ([]TxBody, error) { + return nil, nil + }), + ) + assert.Equal(t, 60*time.Second, cfg.IdleTimeout) + assert.NotNil(t, cfg.RequestTxIdsFunc) + assert.NotNil(t, cfg.RequestTxsFunc) + }) +} +func TestCallbackRegistration(t *testing.T) { + conn := newTestConn() + defer conn.Close() + + t.Run("RequestTxIds callback registration", func(t *testing.T) { + requestTxIdsFunc := func(ctx CallbackContext, blocking bool, ack, req uint16) ([]TxIdAndSize, error) { + return nil, nil + } + cfg := NewConfig(WithRequestTxIdsFunc(requestTxIdsFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestTxIdsFunc) + }) + + t.Run("RequestTxs callback registration", func(t *testing.T) { + requestTxsFunc := func(ctx CallbackContext, txIds []TxId) ([]TxBody, error) { + return nil, nil + } + cfg := NewConfig(WithRequestTxsFunc(requestTxsFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestTxsFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + client := NewClient(getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + client.Start() + defer client.Stop() + + err := client.SendMessage(NewMsgInit()) + require.NoError(t, err) + + select { + case msg := <-conn.writeChan: + assert.NotEmpty(t, msg, "expected message to be written") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for message send") + } + }) +} + +func TestServerMessageHandling(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + server := NewServer(getTestProtocolOptions(conn), &cfg) + + t.Run("Server can be started", func(t *testing.T) { + server.Start() + defer server.Stop() + assert.NotNil(t, server) + }) +}