Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions tapgarden/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,10 @@ type MockKeyRing struct {
KeyIndex uint32

Keys map[keychain.KeyLocator]*btcec.PrivateKey

// deriveNextKeyCallCount is used to track the number of calls to
// DeriveNextKey.
deriveNextKeyCallCount atomic.Uint64
}

var _ KeyRing = (*MockKeyRing)(nil)
Expand All @@ -869,8 +873,11 @@ func NewMockKeyRing() *MockKeyRing {
keyRing.On(
"DeriveNextKey", mock.Anything,
keychain.KeyFamily(asset.TaprootAssetsKeyFamily),
).Return(nil)
keyRing.On("DeriveNextTaprootAssetKey", mock.Anything).Return(nil)
).Return(keychain.KeyDescriptor{}, nil)

keyRing.On(
"DeriveNextTaprootAssetKey", mock.Anything,
).Return(keychain.KeyDescriptor{}, nil)

return keyRing
}
Expand All @@ -880,6 +887,7 @@ func NewMockKeyRing() *MockKeyRing {
func (m *MockKeyRing) DeriveNextTaprootAssetKey(
ctx context.Context) (keychain.KeyDescriptor, error) {

// No need to lock mutex here, DeriveNextKey does that for us.
m.Called(ctx)

return m.DeriveNextKey(ctx, asset.TaprootAssetsKeyFamily)
Expand All @@ -888,20 +896,21 @@ func (m *MockKeyRing) DeriveNextTaprootAssetKey(
func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {

m.Lock()
defer func() {
m.KeyIndex++
m.Unlock()
}()

m.Called(ctx, keyFam)
m.deriveNextKeyCallCount.Add(1)

select {
case <-ctx.Done():
return keychain.KeyDescriptor{}, fmt.Errorf("shutting down")
default:
}

m.Lock()
defer func() {
m.KeyIndex++
m.Unlock()
}()

priv, err := btcec.NewPrivateKey()
if err != nil {
return keychain.KeyDescriptor{}, err
Expand All @@ -925,10 +934,10 @@ func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
func (m *MockKeyRing) IsLocalKey(ctx context.Context,
d keychain.KeyDescriptor) bool {

m.Called(ctx, d)
m.Lock()
defer m.Unlock()

m.RLock()
defer m.RUnlock()
m.Called(ctx, d)

priv, ok := m.Keys[d.KeyLocator]
if ok && priv.PubKey().IsEqual(d.PubKey) {
Expand All @@ -945,8 +954,8 @@ func (m *MockKeyRing) IsLocalKey(ctx context.Context,
}

func (m *MockKeyRing) PubKeyAt(t *testing.T, idx uint32) *btcec.PublicKey {
m.RLock()
defer m.RUnlock()
m.Lock()
defer m.Unlock()

loc := keychain.KeyLocator{
Index: idx,
Expand All @@ -962,8 +971,8 @@ func (m *MockKeyRing) PubKeyAt(t *testing.T, idx uint32) *btcec.PublicKey {
}

func (m *MockKeyRing) ScriptKeyAt(t *testing.T, idx uint32) asset.ScriptKey {
m.RLock()
defer m.RUnlock()
m.Lock()
defer m.Unlock()

loc := keychain.KeyLocator{
Index: idx,
Expand All @@ -984,13 +993,13 @@ func (m *MockKeyRing) ScriptKeyAt(t *testing.T, idx uint32) asset.ScriptKey {
func (m *MockKeyRing) DeriveSharedKey(_ context.Context, key *btcec.PublicKey,
locator *keychain.KeyLocator) ([sha256.Size]byte, error) {

m.Lock()
defer m.Unlock()

if locator == nil {
return [32]byte{}, fmt.Errorf("locator is nil")
}

m.RLock()
defer m.RUnlock()

priv, ok := m.Keys[*locator]
if !ok {
return [32]byte{}, fmt.Errorf("script key not found at index "+
Expand All @@ -1003,6 +1012,19 @@ func (m *MockKeyRing) DeriveSharedKey(_ context.Context, key *btcec.PublicKey,
return ecdh.ECDH(key)
}

// DeriveNextKeyCallCount returns the number of calls to DeriveNextKey. This is
// useful in tests to assert that the key ring was used as expected in
// concurrent scenarios.
func (m *MockKeyRing) DeriveNextKeyCallCount() int {
return int(m.deriveNextKeyCallCount.Load())
}

// ResetDeriveNextKeyCallCount resets the call counter for DeriveNextKey to
// zero. This is useful in tests to ensure a clean state for assertions.
func (m *MockKeyRing) ResetDeriveNextKeyCallCount() {
m.deriveNextKeyCallCount.Store(0)
}

type MockGenSigner struct {
KeyRing *MockKeyRing
failSigning atomic.Bool
Expand Down
45 changes: 31 additions & 14 deletions tapgarden/planter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ func (t *mintingTestHarness) queueSeedlingsInBatch(isFunded bool,

for i, seedling := range seedlings {
seedling := seedling

t.keyRing.ResetDeriveNextKeyCallCount()
keyCount := 0
t.keyRing.Calls = nil

// For the first seedling sent, we should get a new request,
// representing the batch internal key.
Expand Down Expand Up @@ -310,18 +311,21 @@ func (t *mintingTestHarness) queueSeedlingsInBatch(isFunded bool,
// The received update should be a state of MintingStateSeed.
require.Equal(t, tapgarden.MintingStateSeed, update.NewState)

require.Eventually(t, func() bool {
err = wait.NoError(func() error {
// Assert that the key ring method DeriveNextKey was
// called the expected number of times.
count := 0
for _, call := range t.keyRing.Calls {
if call.Method == "DeriveNextKey" {
count++
}
expectedCount := keyCount
actualCount := t.keyRing.DeriveNextKeyCallCount()

if actualCount < expectedCount {
return fmt.Errorf("expected %d calls to key "+
"derivation, got %d", expectedCount,
actualCount)
}

return count == keyCount
}, defaultTimeout, wait.PollInterval)
return nil
}, defaultTimeout)
require.NoError(t, err)
}
}

Expand All @@ -332,13 +336,26 @@ func (t *mintingTestHarness) assertPendingBatchExists(numSeedlings int) {

// The planter is a state machine, so we need to wait until it has
// reached the expected state.
require.Eventually(t, func() bool {
err := wait.NoError(func() error {
batch, err := t.planter.PendingBatch()
require.NoError(t, err)
if err != nil {
return fmt.Errorf("unable to fetch pending batch: %w",
err)
}

if batch == nil {
return fmt.Errorf("expected pending batch to be " +
"non-nil")
}

if len(batch.Seedlings) < numSeedlings {
return fmt.Errorf("expected %d seedlings, got %d",
numSeedlings, len(batch.Seedlings))
}

require.NotNil(t, batch)
return len(batch.Seedlings) == numSeedlings
}, defaultTimeout, wait.PollInterval)
return nil
}, defaultTimeout)
require.NoError(t, err)
}

// assertNoActiveBatch asserts that no pending batch exists.
Expand Down
Loading