diff --git a/tools/walletextension/services/session_key_activity.go b/tools/walletextension/services/session_key_activity.go index 499d291eee..47ac7526f6 100644 --- a/tools/walletextension/services/session_key_activity.go +++ b/tools/walletextension/services/session_key_activity.go @@ -1,12 +1,14 @@ package services import ( + "container/list" "sync" "time" gethcommon "github.com/ethereum/go-ethereum/common" gethlog "github.com/ethereum/go-ethereum/log" "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/storage" ) // SessionKeyActivityTracker exposes a minimal API for tracking activity @@ -16,62 +18,224 @@ type SessionKeyActivityTracker interface { ListAll() []common.SessionKeyActivity Load(items []common.SessionKeyActivity) Delete(addr gethcommon.Address) bool + // Stop gracefully shuts down the tracker, flushing pending writes + Stop() +} + +// lruEntry represents an entry in the LRU cache +type lruEntry struct { + addr gethcommon.Address + userID []byte + lastActive time.Time } type sessionKeyActivityTracker struct { - mu sync.RWMutex - byKey map[gethcommon.Address]sessionKeyActivityState - // maxEntries bounds memory usage; when full, oldest entry is evicted upon new insert + mu sync.RWMutex + + // LRU cache: doubly-linked list for O(1) eviction of oldest entry + // Front = most recently used, Back = least recently used (oldest) + lruList *list.List + // Map for O(1) lookup by address + byKey map[gethcommon.Address]*list.Element + + // maxEntries bounds memory usage; when full, oldest entry is evicted maxEntries int logger gethlog.Logger -} -// sessionKeyActivityState is the internal storage value; address is the map key -type sessionKeyActivityState struct { - UserID []byte - LastActive time.Time + // Async write queue for persisting evicted entries to DB + persistQueue chan common.SessionKeyActivity + persistStorage storage.SessionKeyActivityStorage + stopChan chan struct{} + stopOnce sync.Once + wg sync.WaitGroup } -// defaultMaxActivityEntries defines an upper bound to avoid unbounded memory growth -const defaultMaxActivityEntries = 100000 +// Configuration constants +const ( + defaultMaxActivityEntries = 100000 + persistQueueSize = 10000 + persistBatchSize = 100 + persistFlushInterval = 5 * time.Second +) func NewSessionKeyActivityTracker(logger gethlog.Logger) SessionKeyActivityTracker { - return &sessionKeyActivityTracker{ - byKey: make(map[gethcommon.Address]sessionKeyActivityState), - maxEntries: defaultMaxActivityEntries, - logger: logger, + return NewSessionKeyActivityTrackerWithStorage(logger, nil) +} + +// NewSessionKeyActivityTrackerWithStorage creates a tracker with async DB persistence +func NewSessionKeyActivityTrackerWithStorage(logger gethlog.Logger, persistStorage storage.SessionKeyActivityStorage) SessionKeyActivityTracker { + t := &sessionKeyActivityTracker{ + lruList: list.New(), + byKey: make(map[gethcommon.Address]*list.Element), + maxEntries: defaultMaxActivityEntries, + logger: logger, + persistStorage: persistStorage, + stopChan: make(chan struct{}), } + + // Start async persistence worker if storage is provided + if persistStorage != nil { + t.persistQueue = make(chan common.SessionKeyActivity, persistQueueSize) + t.wg.Add(1) + go t.persistWorker() + } + + return t +} + +// persistWorker runs in the background and batches writes to CosmosDB +func (t *sessionKeyActivityTracker) persistWorker() { + defer t.wg.Done() + + batch := make([]common.SessionKeyActivity, 0, persistBatchSize) + ticker := time.NewTicker(persistFlushInterval) + defer ticker.Stop() + + flush := func() { + if len(batch) == 0 { + return + } + if err := t.persistStorage.SaveBatch(batch); err != nil { + if t.logger != nil { + t.logger.Warn("Failed to persist evicted session key activities", "count", len(batch), "error", err) + } + } else { + if t.logger != nil { + t.logger.Debug("Persisted evicted session key activities", "count", len(batch)) + } + } + batch = batch[:0] + } + + for { + select { + case item, ok := <-t.persistQueue: + if !ok { + // Channel closed, flush remaining and exit + flush() + return + } + batch = append(batch, item) + if len(batch) >= persistBatchSize { + flush() + } + case <-ticker.C: + flush() + case <-t.stopChan: + // Drain remaining items from queue, checking for closed channel + for { + select { + case item, ok := <-t.persistQueue: + if !ok { + // Channel closed, flush and exit + flush() + return + } + batch = append(batch, item) + default: + flush() + return + } + } + } + } +} + +// Stop gracefully shuts down the tracker, flushing pending writes +func (t *sessionKeyActivityTracker) Stop() { + t.stopOnce.Do(func() { + close(t.stopChan) + if t.persistQueue != nil { + close(t.persistQueue) + } + }) + t.wg.Wait() } func (t *sessionKeyActivityTracker) MarkActive(userID []byte, addr gethcommon.Address) { now := time.Now() t.mu.Lock() defer t.mu.Unlock() - // if the address is already in the map, update the last active time - if state, ok := t.byKey[addr]; ok { - state.LastActive = now - t.byKey[addr] = state - } else { - // check if the map is at capacity - if len(t.byKey) >= t.maxEntries { + + // If the address already exists, update and move to front (most recently used) + if elem, ok := t.byKey[addr]; ok { + entry := elem.Value.(*lruEntry) + entry.lastActive = now + t.lruList.MoveToFront(elem) + return + } + + // New entry: check capacity + if len(t.byKey) >= t.maxEntries { + // Evict the oldest entry (back of the list) + t.evictOldest() + } + + // Add new entry at front (most recently used) + entry := &lruEntry{ + addr: addr, + userID: userID, + lastActive: now, + } + elem := t.lruList.PushFront(entry) + t.byKey[addr] = elem +} + +// evictOldest removes the least recently used entry and queues it for DB persistence +// Must be called with lock held +func (t *sessionKeyActivityTracker) evictOldest() { + back := t.lruList.Back() + if back == nil { + return + } + + entry := back.Value.(*lruEntry) + + // Queue for async DB persistence before removing from memory + if t.persistQueue != nil { + activity := common.SessionKeyActivity{ + Addr: entry.addr, + UserID: entry.userID, + LastActive: entry.lastActive, + } + select { + case t.persistQueue <- activity: + // Successfully queued + default: + // Queue full, log warning but continue with eviction if t.logger != nil { - t.logger.Warn("SessionKeyActivityTracker capacity reached; dropping new activity", "capacity", t.maxEntries, "addr", addr.Hex()) + t.logger.Warn("Persist queue full, evicted entry may be lost", "addr", entry.addr.Hex()) } - } else { - // if the map is not at capacity, add the address to the map - t.byKey[addr] = sessionKeyActivityState{UserID: userID, LastActive: now} } } + + // Remove from cache + t.lruList.Remove(back) + delete(t.byKey, entry.addr) + + if t.logger != nil { + t.logger.Debug("Evicted oldest session key activity", "addr", entry.addr.Hex(), "lastActive", entry.lastActive) + } } func (t *sessionKeyActivityTracker) ListOlderThan(cutoff time.Time) []common.SessionKeyActivity { t.mu.RLock() defer t.mu.RUnlock() - // preallocate with current size upper bound; filter below - result := make([]common.SessionKeyActivity, 0, len(t.byKey)) - for addr, state := range t.byKey { - if state.LastActive.Before(cutoff) { - result = append(result, common.SessionKeyActivity{Addr: addr, UserID: state.UserID, LastActive: state.LastActive}) + + result := make([]common.SessionKeyActivity, 0) + for elem := t.lruList.Back(); elem != nil; elem = elem.Prev() { + entry := elem.Value.(*lruEntry) + if entry.lastActive.Before(cutoff) { + result = append(result, common.SessionKeyActivity{ + Addr: entry.addr, + UserID: entry.userID, + LastActive: entry.lastActive, + }) + } else { + // Since list is ordered by last access time (oldest at back), + // once we hit an entry newer than cutoff, all remaining entries + // will also be newer, so we can stop early + break } } return result @@ -80,37 +244,73 @@ func (t *sessionKeyActivityTracker) ListOlderThan(cutoff time.Time) []common.Ses func (t *sessionKeyActivityTracker) ListAll() []common.SessionKeyActivity { t.mu.RLock() defer t.mu.RUnlock() + result := make([]common.SessionKeyActivity, 0, len(t.byKey)) - for addr, state := range t.byKey { - result = append(result, common.SessionKeyActivity{Addr: addr, UserID: state.UserID, LastActive: state.LastActive}) + for elem := t.lruList.Front(); elem != nil; elem = elem.Next() { + entry := elem.Value.(*lruEntry) + result = append(result, common.SessionKeyActivity{ + Addr: entry.addr, + UserID: entry.userID, + LastActive: entry.lastActive, + }) } return result } func (t *sessionKeyActivityTracker) Load(items []common.SessionKeyActivity) { t.mu.Lock() - // Enforce capacity limit by truncating the input slice if necessary - if len(items) > t.maxEntries { + defer t.mu.Unlock() + + // Clear existing data + t.lruList = list.New() + t.byKey = make(map[gethcommon.Address]*list.Element) + + // Sort items by LastActive (oldest first) so we can build the LRU list correctly + // Items at the front will be most recent, items at back will be oldest + // We'll add them in reverse order (oldest first) so oldest ends up at back + sorted := make([]common.SessionKeyActivity, len(items)) + copy(sorted, items) + + // Simple insertion sort by LastActive (ascending = oldest first) + for i := 1; i < len(sorted); i++ { + j := i + for j > 0 && sorted[j].LastActive.Before(sorted[j-1].LastActive) { + sorted[j], sorted[j-1] = sorted[j-1], sorted[j] + j-- + } + } + + // Enforce capacity limit + startIdx := 0 + if len(sorted) > t.maxEntries { + startIdx = len(sorted) - t.maxEntries if t.logger != nil { - t.logger.Warn("ReplaceAll truncated due to capacity", "requested", len(items), "capacity", t.maxEntries) + t.logger.Warn("Load truncated due to capacity, oldest entries dropped", + "total", len(sorted), "dropped", startIdx, "loaded", t.maxEntries) } - items = items[:t.maxEntries] } - newMap := make(map[gethcommon.Address]sessionKeyActivityState, len(items)) - for _, it := range items { - newMap[it.Addr] = sessionKeyActivityState{UserID: it.UserID, LastActive: it.LastActive} + // Add entries: oldest first (will be at back of list), newest last (will be at front) + for i := startIdx; i < len(sorted); i++ { + item := sorted[i] + entry := &lruEntry{ + addr: item.Addr, + userID: item.UserID, + lastActive: item.LastActive, + } + elem := t.lruList.PushFront(entry) + t.byKey[item.Addr] = elem } - t.byKey = newMap - t.mu.Unlock() } func (t *sessionKeyActivityTracker) Delete(addr gethcommon.Address) bool { t.mu.Lock() defer t.mu.Unlock() - _, existed := t.byKey[addr] - if existed { + + if elem, ok := t.byKey[addr]; ok { + t.lruList.Remove(elem) delete(t.byKey, addr) + return true } - return existed + return false } diff --git a/tools/walletextension/services/session_key_activity_test.go b/tools/walletextension/services/session_key_activity_test.go new file mode 100644 index 0000000000..2e060122b8 --- /dev/null +++ b/tools/walletextension/services/session_key_activity_test.go @@ -0,0 +1,349 @@ +package services + +import ( + "sync" + "testing" + "time" + + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ten-protocol/go-ten/tools/walletextension/common" +) + +// mockActivityStorage is a mock implementation of SessionKeyActivityStorage for testing +type mockActivityStorage struct { + mu sync.Mutex + items map[gethcommon.Address]common.SessionKeyActivity + saved []common.SessionKeyActivity + errors map[string]error +} + +func newMockActivityStorage() *mockActivityStorage { + return &mockActivityStorage{ + items: make(map[gethcommon.Address]common.SessionKeyActivity), + saved: make([]common.SessionKeyActivity, 0), + errors: make(map[string]error), + } +} + +func (m *mockActivityStorage) Load() ([]common.SessionKeyActivity, error) { + m.mu.Lock() + defer m.mu.Unlock() + result := make([]common.SessionKeyActivity, 0, len(m.items)) + for _, item := range m.items { + result = append(result, item) + } + return result, m.errors["Load"] +} + +func (m *mockActivityStorage) Save(items []common.SessionKeyActivity) error { + m.mu.Lock() + defer m.mu.Unlock() + m.items = make(map[gethcommon.Address]common.SessionKeyActivity) + for _, item := range items { + m.items[item.Addr] = item + } + return m.errors["Save"] +} + +func (m *mockActivityStorage) SaveBatch(items []common.SessionKeyActivity) error { + m.mu.Lock() + defer m.mu.Unlock() + m.saved = append(m.saved, items...) + for _, item := range items { + m.items[item.Addr] = item + } + return m.errors["SaveBatch"] +} + +func (m *mockActivityStorage) ListOlderThan(cutoff time.Time) ([]common.SessionKeyActivity, error) { + m.mu.Lock() + defer m.mu.Unlock() + result := make([]common.SessionKeyActivity, 0) + for _, item := range m.items { + if item.LastActive.Before(cutoff) { + result = append(result, item) + } + } + return result, m.errors["ListOlderThan"] +} + +func (m *mockActivityStorage) Delete(addr gethcommon.Address) error { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.items, addr) + return m.errors["Delete"] +} + +func (m *mockActivityStorage) getSaved() []common.SessionKeyActivity { + m.mu.Lock() + defer m.mu.Unlock() + return append([]common.SessionKeyActivity{}, m.saved...) +} + +func TestSessionKeyActivityTracker_MarkActive_Basic(t *testing.T) { + tracker := NewSessionKeyActivityTracker(nil) + defer tracker.Stop() + + addr := gethcommon.HexToAddress("0x1234567890123456789012345678901234567890") + userID := []byte("test-user-1") + + tracker.MarkActive(userID, addr) + + all := tracker.ListAll() + require.Len(t, all, 1) + assert.Equal(t, addr, all[0].Addr) + assert.Equal(t, userID, all[0].UserID) +} + +func TestSessionKeyActivityTracker_MarkActive_UpdateExisting(t *testing.T) { + tracker := NewSessionKeyActivityTracker(nil) + defer tracker.Stop() + + addr := gethcommon.HexToAddress("0x1234567890123456789012345678901234567890") + userID := []byte("test-user-1") + + // First activation + tracker.MarkActive(userID, addr) + firstAll := tracker.ListAll() + firstTime := firstAll[0].LastActive + + // Wait a bit and activate again + time.Sleep(10 * time.Millisecond) + tracker.MarkActive(userID, addr) + + all := tracker.ListAll() + require.Len(t, all, 1) + assert.True(t, all[0].LastActive.After(firstTime), "LastActive should be updated") +} + +func TestSessionKeyActivityTracker_ListOlderThan(t *testing.T) { + tracker := NewSessionKeyActivityTracker(nil) + defer tracker.Stop() + + addr1 := gethcommon.HexToAddress("0x1111111111111111111111111111111111111111") + addr2 := gethcommon.HexToAddress("0x2222222222222222222222222222222222222222") + + // Add first entry + tracker.MarkActive([]byte("user1"), addr1) + + // Wait and add second entry + time.Sleep(20 * time.Millisecond) + cutoff := time.Now() + time.Sleep(20 * time.Millisecond) + + tracker.MarkActive([]byte("user2"), addr2) + + // Only addr1 should be older than cutoff + older := tracker.ListOlderThan(cutoff) + require.Len(t, older, 1) + assert.Equal(t, addr1, older[0].Addr) +} + +func TestSessionKeyActivityTracker_Delete(t *testing.T) { + tracker := NewSessionKeyActivityTracker(nil) + defer tracker.Stop() + + addr := gethcommon.HexToAddress("0x1234567890123456789012345678901234567890") + + tracker.MarkActive([]byte("user"), addr) + require.Len(t, tracker.ListAll(), 1) + + deleted := tracker.Delete(addr) + assert.True(t, deleted) + assert.Empty(t, tracker.ListAll()) + + // Delete non-existent + deleted = tracker.Delete(addr) + assert.False(t, deleted) +} + +func TestSessionKeyActivityTracker_Load(t *testing.T) { + tracker := NewSessionKeyActivityTracker(nil) + defer tracker.Stop() + + items := []common.SessionKeyActivity{ + { + Addr: gethcommon.HexToAddress("0x1111111111111111111111111111111111111111"), + UserID: []byte("user1"), + LastActive: time.Now().Add(-2 * time.Hour), + }, + { + Addr: gethcommon.HexToAddress("0x2222222222222222222222222222222222222222"), + UserID: []byte("user2"), + LastActive: time.Now().Add(-1 * time.Hour), + }, + } + + tracker.Load(items) + + all := tracker.ListAll() + require.Len(t, all, 2) +} + +func TestSessionKeyActivityTracker_LRU_EvictsOldest(t *testing.T) { + storage := newMockActivityStorage() + tracker := NewSessionKeyActivityTrackerWithStorage(nil, storage).(*sessionKeyActivityTracker) + // Override max entries for testing + tracker.maxEntries = 3 + + addr1 := gethcommon.HexToAddress("0x1111111111111111111111111111111111111111") + addr2 := gethcommon.HexToAddress("0x2222222222222222222222222222222222222222") + addr3 := gethcommon.HexToAddress("0x3333333333333333333333333333333333333333") + addr4 := gethcommon.HexToAddress("0x4444444444444444444444444444444444444444") + + // Add 3 entries (at capacity) + tracker.MarkActive([]byte("user1"), addr1) + time.Sleep(5 * time.Millisecond) + tracker.MarkActive([]byte("user2"), addr2) + time.Sleep(5 * time.Millisecond) + tracker.MarkActive([]byte("user3"), addr3) + + require.Len(t, tracker.ListAll(), 3) + + // Add 4th entry - should evict addr1 (oldest) + time.Sleep(5 * time.Millisecond) + tracker.MarkActive([]byte("user4"), addr4) + + // Should still have 3 entries + all := tracker.ListAll() + require.Len(t, all, 3) + + // addr1 should be gone, addr2, addr3, addr4 should remain + addrs := make(map[gethcommon.Address]bool) + for _, a := range all { + addrs[a.Addr] = true + } + + assert.False(t, addrs[addr1], "addr1 should have been evicted") + assert.True(t, addrs[addr2], "addr2 should remain") + assert.True(t, addrs[addr3], "addr3 should remain") + assert.True(t, addrs[addr4], "addr4 should remain") + + // Stop the tracker to flush pending writes + tracker.Stop() + + // Check that evicted entry was queued for persistence + saved := storage.getSaved() + require.Len(t, saved, 1) + assert.Equal(t, addr1, saved[0].Addr) +} + +func TestSessionKeyActivityTracker_LRU_UpdateMovesToFront(t *testing.T) { + storage := newMockActivityStorage() + tracker := NewSessionKeyActivityTrackerWithStorage(nil, storage).(*sessionKeyActivityTracker) + tracker.maxEntries = 3 + + addr1 := gethcommon.HexToAddress("0x1111111111111111111111111111111111111111") + addr2 := gethcommon.HexToAddress("0x2222222222222222222222222222222222222222") + addr3 := gethcommon.HexToAddress("0x3333333333333333333333333333333333333333") + addr4 := gethcommon.HexToAddress("0x4444444444444444444444444444444444444444") + + // Add 3 entries + tracker.MarkActive([]byte("user1"), addr1) // oldest + time.Sleep(5 * time.Millisecond) + tracker.MarkActive([]byte("user2"), addr2) + time.Sleep(5 * time.Millisecond) + tracker.MarkActive([]byte("user3"), addr3) + + // Re-activate addr1 - this should move it to front (most recently used) + time.Sleep(5 * time.Millisecond) + tracker.MarkActive([]byte("user1"), addr1) + + // Now add addr4 - addr2 should be evicted (it's now the oldest) + time.Sleep(5 * time.Millisecond) + tracker.MarkActive([]byte("user4"), addr4) + + all := tracker.ListAll() + require.Len(t, all, 3) + + addrs := make(map[gethcommon.Address]bool) + for _, a := range all { + addrs[a.Addr] = true + } + + assert.True(t, addrs[addr1], "addr1 should remain (was re-activated)") + assert.False(t, addrs[addr2], "addr2 should have been evicted (was oldest after addr1 re-activation)") + assert.True(t, addrs[addr3], "addr3 should remain") + assert.True(t, addrs[addr4], "addr4 should remain") + + tracker.Stop() +} + +func TestSessionKeyActivityTracker_ConcurrentAccess(t *testing.T) { + tracker := NewSessionKeyActivityTracker(nil) + defer tracker.Stop() + + var wg sync.WaitGroup + numGoroutines := 10 + numOpsPerGoroutine := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + for j := 0; j < numOpsPerGoroutine; j++ { + addr := gethcommon.HexToAddress("0x" + string(rune('0'+idx)) + "234567890123456789012345678901234567890") + tracker.MarkActive([]byte("user"), addr) + tracker.ListAll() + tracker.ListOlderThan(time.Now()) + } + }(i) + } + + wg.Wait() + // Should not panic or deadlock +} + +func TestSessionKeyActivityTracker_LoadWithCapacityLimit(t *testing.T) { + tracker := NewSessionKeyActivityTracker(nil).(*sessionKeyActivityTracker) + tracker.maxEntries = 3 + defer tracker.Stop() + + // Create more items than capacity + items := []common.SessionKeyActivity{ + { + Addr: gethcommon.HexToAddress("0x1111111111111111111111111111111111111111"), + UserID: []byte("user1"), + LastActive: time.Now().Add(-4 * time.Hour), // oldest - should be dropped + }, + { + Addr: gethcommon.HexToAddress("0x2222222222222222222222222222222222222222"), + UserID: []byte("user2"), + LastActive: time.Now().Add(-3 * time.Hour), // second oldest - should be dropped + }, + { + Addr: gethcommon.HexToAddress("0x3333333333333333333333333333333333333333"), + UserID: []byte("user3"), + LastActive: time.Now().Add(-2 * time.Hour), + }, + { + Addr: gethcommon.HexToAddress("0x4444444444444444444444444444444444444444"), + UserID: []byte("user4"), + LastActive: time.Now().Add(-1 * time.Hour), + }, + { + Addr: gethcommon.HexToAddress("0x5555555555555555555555555555555555555555"), + UserID: []byte("user5"), + LastActive: time.Now(), // newest + }, + } + + tracker.Load(items) + + all := tracker.ListAll() + require.Len(t, all, 3, "should only load up to capacity") + + // Should have the 3 newest entries + addrs := make(map[gethcommon.Address]bool) + for _, a := range all { + addrs[a.Addr] = true + } + + assert.False(t, addrs[gethcommon.HexToAddress("0x1111111111111111111111111111111111111111")], "oldest should be dropped") + assert.False(t, addrs[gethcommon.HexToAddress("0x2222222222222222222222222222222222222222")], "second oldest should be dropped") + assert.True(t, addrs[gethcommon.HexToAddress("0x3333333333333333333333333333333333333333")]) + assert.True(t, addrs[gethcommon.HexToAddress("0x4444444444444444444444444444444444444444")]) + assert.True(t, addrs[gethcommon.HexToAddress("0x5555555555555555555555555555555555555555")]) +} diff --git a/tools/walletextension/services/session_key_expiration.go b/tools/walletextension/services/session_key_expiration.go index 9e4fcd892b..00cb120eec 100644 --- a/tools/walletextension/services/session_key_expiration.go +++ b/tools/walletextension/services/session_key_expiration.go @@ -4,6 +4,7 @@ import ( "context" "time" + gethcommon "github.com/ethereum/go-ethereum/common" gethlog "github.com/ethereum/go-ethereum/log" "github.com/ten-protocol/go-ten/go/common/stopcontrol" wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" @@ -103,8 +104,20 @@ func (s *SessionKeyExpirationService) sessionKeyExpiration() { s.logger.Info("Session key expiration check started") cutoff := time.Now().Add(-s.config.SessionKeyExpirationThreshold) - candidates := s.activityTracker.ListOlderThan(cutoff) - s.logger.Info("Session key expiration check", "cutoff", cutoff, "candidatesFound", len(candidates)) + + // Get expired candidates from in-memory LRU cache + memoryCandidates := s.activityTracker.ListOlderThan(cutoff) + + // Also query CosmosDB for expired entries (evicted from memory but still valid) + dbCandidates, err := s.activityStorage.ListOlderThan(cutoff) + if err != nil { + s.logger.Warn("Failed to query DB for expired session keys", "error", err) + dbCandidates = nil + } + + // Merge and deduplicate: in-memory takes precedence for recent updates + candidates := s.mergeActivityLists(memoryCandidates, dbCandidates) + s.logger.Info("Session key expiration check", "cutoff", cutoff, "memoryCount", len(memoryCandidates), "dbCount", len(dbCandidates), "mergedCount", len(candidates)) for _, c := range candidates { s.logger.Info("Processing expired session key candidate", @@ -120,8 +133,8 @@ func (s *SessionKeyExpirationService) sessionKeyExpiration() { // Ensure this session key still belongs to the user if _, ok := user.SessionKeys[c.Addr]; !ok { - // The session key may have been deleted; remove from tracker - s.activityTracker.Delete(c.Addr) + // The session key may have been deleted; remove from both tracker and DB + s.deleteActivity(c.Addr) continue } @@ -148,11 +161,44 @@ func (s *SessionKeyExpirationService) sessionKeyExpiration() { "sessionKeyAddress", c.Addr.Hex(), "txHash", txHash.Hex()) - // After successful external operation, delete from tracker - _ = s.activityTracker.Delete(c.Addr) + // After successful external operation, delete from both tracker and DB + s.deleteActivity(c.Addr) } // store all activities in the database to make them persistent and recoverable in case of restart allActivities := s.activityTracker.ListAll() _ = s.activityStorage.Save(allActivities) } + +// mergeActivityLists merges activities from memory and DB, deduplicating by address. +// In-memory entries take precedence as they have the most recent state. +func (s *SessionKeyExpirationService) mergeActivityLists(memory, db []wecommon.SessionKeyActivity) []wecommon.SessionKeyActivity { + seen := make(map[gethcommon.Address]struct{}) + result := make([]wecommon.SessionKeyActivity, 0, len(memory)+len(db)) + + // Add all memory entries first (they take precedence) + for _, item := range memory { + if _, exists := seen[item.Addr]; !exists { + seen[item.Addr] = struct{}{} + result = append(result, item) + } + } + + // Add DB entries that aren't already in memory + for _, item := range db { + if _, exists := seen[item.Addr]; !exists { + seen[item.Addr] = struct{}{} + result = append(result, item) + } + } + + return result +} + +// deleteActivity removes an activity from both the in-memory tracker and the database +func (s *SessionKeyExpirationService) deleteActivity(addr gethcommon.Address) { + s.activityTracker.Delete(addr) + if err := s.activityStorage.Delete(addr); err != nil { + s.logger.Warn("Failed to delete activity from DB", "addr", addr.Hex(), "error", err) + } +} diff --git a/tools/walletextension/services/wallet_extension.go b/tools/walletextension/services/wallet_extension.go index 34179290ce..b9034e9b91 100644 --- a/tools/walletextension/services/wallet_extension.go +++ b/tools/walletextension/services/wallet_extension.go @@ -71,7 +71,7 @@ const MaxAccountsPerUser = 100 // ErrMaxAccountsPerUserReached indicates a user has reached the allowed account limit var ErrMaxAccountsPerUserReached = errors.New("maximum number of accounts per user reached") -func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.UserStorage, stopControl *stopcontrol.StopControl, version string, logger gethlog.Logger, metricsTracker metrics.Metrics, config *common.Config) *Services { +func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.UserStorage, activityStorage storage.SessionKeyActivityStorage, stopControl *stopcontrol.StopControl, version string, logger gethlog.Logger, metricsTracker metrics.Metrics, config *common.Config) *Services { var newGatewayCache cache.Cache var err error @@ -89,7 +89,8 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.UserSto rateLimiter := ratelimiter.NewRateLimiter(config.RateLimitUserComputeTime, config.RateLimitWindow, uint32(config.RateLimitMaxConcurrentRequests), logger) httpRateLimiter := ratelimiter.NewHTTPRateLimiter(config.HTTPRateLimitGlobalRate, config.HTTPRateLimitPerIPRate, logger) - activityTracker := NewSessionKeyActivityTracker(logger) + // Create activity tracker with async persistence to DB for evicted entries + activityTracker := NewSessionKeyActivityTrackerWithStorage(logger, activityStorage) services := Services{ HostAddrHTTP: hostAddrHTTP, @@ -375,5 +376,9 @@ func (w *Services) Stop() { if w.HTTPRateLimiter != nil { w.HTTPRateLimiter.Stop() } + // Stop the activity tracker to flush pending writes + if w.ActivityTracker != nil { + w.ActivityTracker.Stop() + } close(w.cacheInvalidationCh) } diff --git a/tools/walletextension/storage/database/cosmosdb/session_key_activity_storage.go b/tools/walletextension/storage/database/cosmosdb/session_key_activity_storage.go index 90d5401acd..62241148b0 100644 --- a/tools/walletextension/storage/database/cosmosdb/session_key_activity_storage.go +++ b/tools/walletextension/storage/database/cosmosdb/session_key_activity_storage.go @@ -26,8 +26,16 @@ const ( // SessionKeyActivityStorage interface defines the session key activity storage operations type SessionKeyActivityStorage interface { + // Load retrieves all stored session key activities (used on startup) Load() ([]wecommon.SessionKeyActivity, error) + // Save performs a full replacement of all stored activities Save([]wecommon.SessionKeyActivity) error + // SaveBatch upserts a batch of activities (used for async persistence of evicted entries) + SaveBatch([]wecommon.SessionKeyActivity) error + // ListOlderThan returns activities with LastActive before the given cutoff + ListOlderThan(cutoff time.Time) ([]wecommon.SessionKeyActivity, error) + // Delete removes a specific activity by address + Delete(addr gethcommon.Address) error } type sessionKeyActivityStorageCosmosDB struct { @@ -295,3 +303,140 @@ func (s *sessionKeyActivityStorageCosmosDB) shardIndexForAddress(addr gethcommon func (s *sessionKeyActivityStorageCosmosDB) getShardDocumentIDByIndex(index int) string { return fmt.Sprintf("%s%d", skShardPrefix, index) } + +// SaveBatch upserts a batch of session key activities to the database. +// Unlike Save(), this method performs incremental updates - it reads existing shard data, +// merges with the new items, and writes back. This is used for async persistence of +// evicted entries from the in-memory LRU cache. +func (s *sessionKeyActivityStorageCosmosDB) SaveBatch(items []wecommon.SessionKeyActivity) error { + if len(items) == 0 { + return nil + } + + ctx := context.Background() + timestamp := time.Now().UTC().Format(time.RFC3339) + + // Group items by shard index + itemsByShardIndex := make(map[int][]wecommon.SessionKeyActivity) + for _, item := range items { + shardIdx := s.shardIndexForAddress(item.Addr) + itemsByShardIndex[shardIdx] = append(itemsByShardIndex[shardIdx], item) + } + + // For each affected shard, read existing data, merge, and write back + for shardIdx, newItems := range itemsByShardIndex { + existing, err := s.readShard(ctx, shardIdx) + if err != nil { + return fmt.Errorf("failed to read shard %d for merge: %w", shardIdx, err) + } + + // Merge: use a map to deduplicate by address, new items take precedence + merged := make(map[gethcommon.Address]wecommon.SessionKeyActivity) + for _, item := range existing { + merged[item.Addr] = item + } + for _, item := range newItems { + merged[item.Addr] = item + } + + // Convert back to slice + mergedSlice := make([]wecommon.SessionKeyActivity, 0, len(merged)) + for _, item := range merged { + mergedSlice = append(mergedSlice, item) + } + + // Write the merged data + if err := s.writeShardData(ctx, shardIdx, mergedSlice, timestamp); err != nil { + return fmt.Errorf("failed to write merged shard %d: %w", shardIdx, err) + } + } + + return nil +} + +// readShard reads all activities from a specific shard. +func (s *sessionKeyActivityStorageCosmosDB) readShard(ctx context.Context, shardIdx int) ([]wecommon.SessionKeyActivity, error) { + shardID := s.getShardDocumentIDByIndex(shardIdx) + pk := azcosmos.NewPartitionKeyString(shardID) + resp, err := s.container.ReadItem(ctx, pk, shardID, nil) + if err != nil { + if strings.Contains(err.Error(), "NotFound") { + return nil, nil // Shard doesn't exist yet + } + return nil, err + } + + // Unmarshal into EncryptedDocument + var doc EncryptedDocument + if err := json.Unmarshal(resp.Value, &doc); err != nil { + return nil, fmt.Errorf("failed to unmarshal document: %w", err) + } + + // Decrypt the data + data, err := s.encryptor.Decrypt(doc.Data) + if err != nil { + return nil, fmt.Errorf("failed to decrypt data: %w", err) + } + + // Unmarshal decrypted JSON into DTO + var dto sessionKeyActivityDTO + if err := json.Unmarshal(data, &dto); err != nil { + return nil, fmt.Errorf("failed to unmarshal session key activity data: %w", err) + } + + result := make([]wecommon.SessionKeyActivity, 0, len(dto.Items)) + for _, it := range dto.Items { + result = append(result, wecommon.SessionKeyActivity{ + Addr: gethcommon.BytesToAddress(it.Addr), + UserID: it.UserID, + LastActive: it.LastActive, + }) + } + return result, nil +} + +// ListOlderThan returns all activities with LastActive before the given cutoff time. +// This queries all shards and filters by the cutoff timestamp. +func (s *sessionKeyActivityStorageCosmosDB) ListOlderThan(cutoff time.Time) ([]wecommon.SessionKeyActivity, error) { + ctx := context.Background() + result := make([]wecommon.SessionKeyActivity, 0) + + for i := 0; i < s.shardCount; i++ { + shardItems, err := s.readShard(ctx, i) + if err != nil { + return nil, fmt.Errorf("failed to read shard %d: %w", i, err) + } + + for _, item := range shardItems { + if item.LastActive.Before(cutoff) { + result = append(result, item) + } + } + } + + return result, nil +} + +// Delete removes a specific activity by address from the database. +// It reads the shard containing the address, removes the entry, and writes back. +func (s *sessionKeyActivityStorageCosmosDB) Delete(addr gethcommon.Address) error { + ctx := context.Background() + shardIdx := s.shardIndexForAddress(addr) + timestamp := time.Now().UTC().Format(time.RFC3339) + + existing, err := s.readShard(ctx, shardIdx) + if err != nil { + return fmt.Errorf("failed to read shard %d for delete: %w", shardIdx, err) + } + + // Filter out the address to delete + filtered := make([]wecommon.SessionKeyActivity, 0, len(existing)) + for _, item := range existing { + if item.Addr != addr { + filtered = append(filtered, item) + } + } + + // Write back the filtered data + return s.writeShardData(ctx, shardIdx, filtered, timestamp) +} diff --git a/tools/walletextension/storage/session_key_activity_storage.go b/tools/walletextension/storage/session_key_activity_storage.go index dc09f8cf01..3adca6c790 100644 --- a/tools/walletextension/storage/session_key_activity_storage.go +++ b/tools/walletextension/storage/session_key_activity_storage.go @@ -1,14 +1,25 @@ package storage import ( + "time" + + gethcommon "github.com/ethereum/go-ethereum/common" wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" "github.com/ten-protocol/go-ten/tools/walletextension/storage/database/cosmosdb" ) // SessionKeyActivityStorage defines persistence for session key activity tracker type SessionKeyActivityStorage interface { + // Load retrieves all stored session key activities (used on startup) Load() ([]wecommon.SessionKeyActivity, error) + // Save performs a full replacement of all stored activities Save([]wecommon.SessionKeyActivity) error + // SaveBatch upserts a batch of activities (used for async persistence of evicted entries) + SaveBatch([]wecommon.SessionKeyActivity) error + // ListOlderThan returns activities with LastActive before the given cutoff + ListOlderThan(cutoff time.Time) ([]wecommon.SessionKeyActivity, error) + // Delete removes a specific activity by address + Delete(addr gethcommon.Address) error } // NewSessionKeyActivityStorage is a factory that returns a concrete storage based on dbType @@ -31,3 +42,11 @@ func (n *noOpSessionKeyActivityStorage) Load() ([]wecommon.SessionKeyActivity, e } func (n *noOpSessionKeyActivityStorage) Save([]wecommon.SessionKeyActivity) error { return nil } + +func (n *noOpSessionKeyActivityStorage) SaveBatch([]wecommon.SessionKeyActivity) error { return nil } + +func (n *noOpSessionKeyActivityStorage) ListOlderThan(time.Time) ([]wecommon.SessionKeyActivity, error) { + return nil, nil +} + +func (n *noOpSessionKeyActivityStorage) Delete(gethcommon.Address) error { return nil } diff --git a/tools/walletextension/walletextension_container.go b/tools/walletextension/walletextension_container.go index c886b6a6a3..7ac8e498f1 100644 --- a/tools/walletextension/walletextension_container.go +++ b/tools/walletextension/walletextension_container.go @@ -105,7 +105,7 @@ func NewContainerFromConfig(config wecommon.Config, logger gethlog.Logger) *Cont } stopControl := stopcontrol.New() - walletExt := services.NewServices(hostRPCBindAddrHTTP, hostRPCBindAddrWS, userStorage, stopControl, version, logger, metricsTracker, &config) + walletExt := services.NewServices(hostRPCBindAddrHTTP, hostRPCBindAddrWS, userStorage, activityStorage, stopControl, version, logger, metricsTracker, &config) // Create session key expiration service after services are created var sessionKeyExpirationService *services.SessionKeyExpirationService