From 5b38c56faa7551889c5c1a27125320523a418cee Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 11:18:28 +0100 Subject: [PATCH 01/11] Fix routes state race condition --- .../internal/routemanager/systemops/state.go | 23 ++++++++----------- .../systemops/systemops_generic.go | 6 +---- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 42590892297..8e158711e50 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -2,31 +2,28 @@ package systemops import ( "net/netip" - "sync" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type ShutdownState struct { - Counter *ExclusionCounter `json:"counter,omitempty"` - mu sync.RWMutex -} +type ShutdownState ExclusionCounter func (s *ShutdownState) Name() string { return "route_state" } func (s *ShutdownState) Cleanup() error { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.Counter == nil { - return nil - } - sysops := NewSysOps(nil, nil) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData(s.Counter) + sysops.refCounter.LoadData((*ExclusionCounter)(s)) return sysops.refCounter.Flush() } + +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + return (*ExclusionCounter)(s).MarshalJSON() +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return (*ExclusionCounter)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 4ff34aa5162..d1e1bf0fd8a 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -76,11 +76,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana } func (r *SysOps) updateState(stateManager *statemanager.Manager) { - state := getState(stateManager) - - state.Counter = r.refCounter - - if err := stateManager.UpdateState(state); err != nil { + if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) } } From 3c95f6fc20a1048d8de8d86d6dc0d679eed9ef9b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 12:59:39 +0100 Subject: [PATCH 02/11] Ensure lock is in place during marshaling --- client/internal/statemanager/manager.go | 8 ++++++-- util/file.go | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 580ccdfc78a..8b085b882d2 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } + bs, err := json.MarshalIndent(m.states, "", " ") + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) - start := time.Now() go func() { - done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) + done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs) }() select { diff --git a/util/file.go b/util/file.go index 4641cc1b825..7be5742b3e4 100644 --- a/util/file.go +++ b/util/file.go @@ -14,6 +14,19 @@ import ( log "github.com/sirupsen/logrus" ) +func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error { + configDir, configFileName, err := prepareConfigFileDir(file) + if err != nil { + return fmt.Errorf("prepare config file dir: %w", err) + } + + if err = EnforcePermission(file); err != nil { + return fmt.Errorf("enfore permission: %w", err) + } + + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) @@ -91,6 +104,10 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri return err } + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + +func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error { if ctx.Err() != nil { return ctx.Err() } From 00a4edc812f58613c4d2a2cbd129553865439de2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 13:43:34 +0100 Subject: [PATCH 03/11] Add file deadline --- util/file.go | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/util/file.go b/util/file.go index 7be5742b3e4..f547cd76c41 100644 --- a/util/file.go +++ b/util/file.go @@ -114,14 +114,26 @@ func writeBytes(ctx context.Context, file string, err error, configDir string, c tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { - return err + return fmt.Errorf("create temp: %w", err) } tempFileName := tempFile.Name() - // closing file ops as windows doesn't allow to move it - err = tempFile.Close() + + if deadline, ok := ctx.Deadline(); ok { + if err := tempFile.SetDeadline(deadline); err != nil { + //if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { + log.Warnf("failed to set write deadline: %v", err) + } + } + + _, err = tempFile.Write(bs) if err != nil { - return err + _ = tempFile.Close() + return fmt.Errorf("write: %w", err) + } + + if err = tempFile.Close(); err != nil { + return fmt.Errorf("close %s: %w", tempFileName, err) } defer func() { @@ -131,19 +143,13 @@ func writeBytes(ctx context.Context, file string, err error, configDir string, c } }() - err = os.WriteFile(tempFileName, bs, 0600) - if err != nil { - return err - } - // Check context again if ctx.Err() != nil { return ctx.Err() } - err = os.Rename(tempFileName, file) - if err != nil { - return err + if err = os.Rename(tempFileName, file); err != nil { + return fmt.Errorf("move %s to %s: %w", tempFileName, file, err) } return nil From cd0dbae1ecdf3f29ad8bece4e51c54ee6fc84a6c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 14:00:13 +0100 Subject: [PATCH 04/11] Add test --- client/internal/statemanager/manager_test.go | 82 ++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 client/internal/statemanager/manager_test.go diff --git a/client/internal/statemanager/manager_test.go b/client/internal/statemanager/manager_test.go new file mode 100644 index 00000000000..f3ca8187fa6 --- /dev/null +++ b/client/internal/statemanager/manager_test.go @@ -0,0 +1,82 @@ +package statemanager + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockState implements the State interface for testing +type MockState struct { +} + +func (m MockState) Name() string { + return "mock_state" +} + +func (m MockState) Cleanup() error { + return nil +} + +func TestManager_PersistState_SlowWrite(t *testing.T) { + tmpDir := t.TempDir() + + tests := []struct { + name string + contextTimeout time.Duration + expectError bool + errorType error + }{ + { + name: "write completes before deadline", + contextTimeout: 1 * time.Second, + expectError: false, + }, + { + name: "write exceeds deadline", + contextTimeout: 0, + expectError: true, + errorType: context.DeadlineExceeded, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stateFile := filepath.Join(tmpDir, tt.name+"-state.json") + + file, err := os.Create(stateFile) + require.NoError(t, err) + defer file.Close() + + m := New(stateFile) + + // Register and update mock state + mockState := &MockState{} + m.RegisterState(mockState) + err = m.UpdateState(mockState) + require.NoError(t, err) + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), tt.contextTimeout) + defer cancel() + + // Attempt to persist state + err = m.PersistState(ctx) + + if tt.expectError { + assert.Error(t, err) + assert.Equal(t, context.DeadlineExceeded, err) + assert.Len(t, m.dirty, 1) + } else { + assert.NoError(t, err) + assert.FileExists(t, stateFile) + assert.Empty(t, m.dirty) + } + }) + } +} From e07caa8f0514bcd81d6a9dff0feedda286a80224 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 14:18:47 +0100 Subject: [PATCH 05/11] Remove unused function --- .../routemanager/systemops/systemops_generic.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index d1e1bf0fd8a..f8b3ebbb8c1 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, err }, func(prefix netip.Prefix, nexthop Nexthop) error { - // remove from state even if we have trouble removing it from the route table + // update state even if we have trouble removing it from the route table // it could be already gone r.updateState(stateManager) @@ -75,6 +75,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return r.setupHooks(initAddresses) } +// updateState updates state on every change so it will be persisted regularly func (r *SysOps) updateState(stateManager *statemanager.Manager) { if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) @@ -528,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } - -func getState(stateManager *statemanager.Manager) *ShutdownState { - var shutdownState *ShutdownState - if state := stateManager.GetState(shutdownState); state != nil { - shutdownState = state.(*ShutdownState) - } else { - shutdownState = &ShutdownState{} - } - - return shutdownState -} From 9a56fc0137aa132cac83cf88236379f592a7012b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 15:30:26 +0100 Subject: [PATCH 06/11] Catch marshal panics --- client/internal/statemanager/manager.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 8b085b882d2..263806ab0ad 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -179,7 +179,7 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } - bs, err := json.MarshalIndent(m.states, "", " ") + bs, err := marshalWithPanicRecovery(m.states) if err != nil { return fmt.Errorf("marshal states: %w", err) } @@ -290,3 +290,19 @@ func (m *Manager) PerformCleanup() error { return nberrors.FormatErrorOrNil(merr) } + +func marshalWithPanicRecovery(v any) ([]byte, error) { + var bs []byte + var err error + + func() { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during marshal: %v", r) + } + }() + bs, err = json.MarshalIndent(v, "", " ") + }() + + return bs, err +} From 81f0810918a03b70a06b6ccc88dcd793469f5b5d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 20:14:24 +0100 Subject: [PATCH 07/11] Don't prettify json --- client/internal/statemanager/manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 263806ab0ad..7c9d8742720 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -301,7 +301,7 @@ func marshalWithPanicRecovery(v any) ([]byte, error) { err = fmt.Errorf("panic during marshal: %v", r) } }() - bs, err = json.MarshalIndent(v, "", " ") + bs, err = json.Marshal(v) }() return bs, err From 41c9c395b45409fcb0837aa131c2976bbacddbf6 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 20:14:31 +0100 Subject: [PATCH 08/11] Add error context --- util/file.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/util/file.go b/util/file.go index f547cd76c41..e75c988de06 100644 --- a/util/file.go +++ b/util/file.go @@ -95,13 +95,13 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { // Check context before expensive operations if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write json start: %w", ctx.Err()) } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") if err != nil { - return err + return fmt.Errorf("marshal: %w", err) } return writeBytes(ctx, file, err, configDir, configFileName, bs) @@ -109,7 +109,7 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error { if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write bytes start: %w", ctx.Err()) } tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) @@ -145,7 +145,7 @@ func writeBytes(ctx context.Context, file string, err error, configDir string, c // Check context again if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("after temp file: %w", ctx.Err()) } if err = os.Rename(tempFileName, file); err != nil { From 3c581d8fdca1763c954a78f5f3a3eeed0fe0019f Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 20:28:21 +0100 Subject: [PATCH 09/11] Don't use cancelled contexts --- client/internal/dns/server.go | 20 +++++++------------- client/internal/engine.go | 3 +-- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6c4dccae74a..f0277319cd5 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,7 +7,6 @@ import ( "runtime" "strings" "sync" - "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { log.Error(err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - - // don't block go func() { - if err := s.stateManager.PersistState(ctx); err != nil { + // persist dns state right away + if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } }() @@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks( l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } + go func() { + if err := s.stateManager.PersistState(s.ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } + }() if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() diff --git a/client/internal/engine.go b/client/internal/engine.go index 190d795cdbe..cce69b6d79a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,7 +38,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -297,7 +296,7 @@ func (e *Engine) Stop() error { if err := e.stateManager.Stop(ctx); err != nil { return fmt.Errorf("failed to stop state manager: %w", err) } - if err := e.stateManager.PersistState(ctx); err != nil { + if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } From e1af05654b80f8e2247558217183dff2f024f775 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 20:30:24 +0100 Subject: [PATCH 10/11] Ignore deadline not supported error --- util/file.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/util/file.go b/util/file.go index e75c988de06..fcb1b5184a4 100644 --- a/util/file.go +++ b/util/file.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -120,9 +121,8 @@ func writeBytes(ctx context.Context, file string, err error, configDir string, c tempFileName := tempFile.Name() if deadline, ok := ctx.Deadline(); ok { - if err := tempFile.SetDeadline(deadline); err != nil { - //if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { - log.Warnf("failed to set write deadline: %v", err) + if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { + log.Warnf("failed to set deadline: %v", err) } } From eceab3669791a3b2fb6e67c8827b4b5c9786e5f0 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 21:24:34 +0100 Subject: [PATCH 11/11] Avoid deadlock on stop --- client/internal/statemanager/manager.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 7c9d8742720..4feef38e09c 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -71,18 +71,20 @@ func (m *Manager) Stop(ctx context.Context) error { return nil } + var cancel context.CancelFunc m.mu.Lock() - defer m.mu.Unlock() + cancel = m.cancel + m.mu.Unlock() - if m.cancel != nil { - m.cancel() + if cancel == nil { + return nil + } + cancel() - select { - case <-ctx.Done(): - return ctx.Err() - case <-m.done: - return nil - } + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: } return nil