Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Fix state manager race conditions #2890

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
20 changes: 7 additions & 13 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"runtime"
"strings"
"sync"
"time"

"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
Expand Down Expand Up @@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err)
}

// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()

// don't block
go func() {
if err := s.stateManager.PersistState(ctx); err != nil {
// persist dns state right away
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()
Expand Down Expand Up @@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}

// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
}()

if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()
Expand Down
3 changes: 1 addition & 2 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"


nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
Expand Down Expand Up @@ -297,7 +296,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
if err := e.stateManager.PersistState(ctx); err != nil {
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}

Expand Down
23 changes: 10 additions & 13 deletions client/internal/routemanager/systemops/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,28 @@ package systemops

import (
"net/netip"
"sync"

"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)

type ShutdownState struct {
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
type ShutdownState ExclusionCounter

// Name returns the fixed identifier "route_state" for this state.
func (s *ShutdownState) Name() string {
return "route_state"
}

func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()

if s.Counter == nil {
return nil
}

sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter)
sysops.refCounter.LoadData((*ExclusionCounter)(s))

return sysops.refCounter.Flush()
}

// MarshalJSON returns the JSON encoding of the state by converting the
// receiver to *ExclusionCounter and delegating to its MarshalJSON method.
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}

// UnmarshalJSON decodes data into the state by converting the receiver to
// *ExclusionCounter and delegating to its UnmarshalJSON method.
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}
20 changes: 3 additions & 17 deletions client/internal/routemanager/systemops/systemops_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, err
},
func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table
// update state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)

Expand All @@ -75,12 +75,9 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return r.setupHooks(initAddresses)
}

// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager)

state.Counter = r.refCounter

if err := stateManager.UpdateState(state); err != nil {
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
Expand Down Expand Up @@ -532,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}

// getState returns the *ShutdownState held by stateManager, or a freshly
// allocated empty ShutdownState when GetState reports nil.
func getState(stateManager *statemanager.Manager) *ShutdownState {
	// A nil *ShutdownState is handed to GetState; presumably only its type is
	// used for the lookup — verify against statemanager.Manager.GetState.
	var lookup *ShutdownState

	if existing := stateManager.GetState(lookup); existing != nil {
		return existing.(*ShutdownState)
	}

	return &ShutdownState{}
}
44 changes: 33 additions & 11 deletions client/internal/statemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,20 @@ func (m *Manager) Stop(ctx context.Context) error {
return nil
}

var cancel context.CancelFunc
m.mu.Lock()
defer m.mu.Unlock()
cancel = m.cancel
m.mu.Unlock()

if m.cancel != nil {
m.cancel()
if cancel == nil {
return nil
}
cancel()

select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
}

return nil
Expand Down Expand Up @@ -179,14 +181,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}

bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not unlock m.mu here? In the worst case it blocks the UpdateState calls for 5 seconds. After this point the code does not touch m.states anymore, so it makes no sense to keep protecting it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to clean up the dirty map further down. If we unlock here, we might clear the dirty map while it contains new entries that haven't been written yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the long-running code (5 sec) really not cause any issues outside of the statemanager?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can delay updating iptables/nftables/routes/dns on network map updates. Because of routes it might also delay p2p connections

ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

done := make(chan error, 1)

start := time.Now()
go func() {
done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if this function runs for more than 10 seconds? The ticker will start another PersistState call and there will be a conflict on the file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what the ctx check and deadline is for in this fn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the check and the move are not one atomic operation. If these lines run in parallel with two different contexts, the outcome is unpredictable.

	```go
	// Check context again
	if ctx.Err() != nil {
		return ctx.Err()
	}

	if err = os.Rename(tempFileName, file); err != nil {
		return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
	}
	```

}()

select {
Expand Down Expand Up @@ -286,3 +292,19 @@ func (m *Manager) PerformCleanup() error {

return nberrors.FormatErrorOrNil(merr)
}

// marshalWithPanicRecovery JSON-encodes v with json.Marshal. If the encoding
// panics, the panic is recovered and reported as an error of the form
// "panic during marshal: <value>" with a nil byte slice, instead of
// propagating to the caller.
func marshalWithPanicRecovery(v any) (encoded []byte, marshalErr error) {
	defer func() {
		// Named results let the deferred handler overwrite the error after a
		// panic; encoded is still nil then because the return never completed.
		if recovered := recover(); recovered != nil {
			marshalErr = fmt.Errorf("panic during marshal: %v", recovered)
		}
	}()

	return json.Marshal(v)
}
82 changes: 82 additions & 0 deletions client/internal/statemanager/manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package statemanager

import (
"context"
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// MockState implements the State interface for testing. It carries no data.
type MockState struct{}

// Name returns the fixed identifier "mock_state".
func (MockState) Name() string {
	const stateName = "mock_state"
	return stateName
}

// Cleanup does nothing and always returns nil.
func (MockState) Cleanup() error {
	return nil
}

// TestManager_PersistState_SlowWrite exercises PersistState with two caller
// contexts: one with a 1s timeout, where the write is expected to succeed and
// the dirty set to be emptied, and one with a zero timeout (already expired),
// where PersistState is expected to fail with the table's errorType and leave
// the dirty entry in place.
func TestManager_PersistState_SlowWrite(t *testing.T) {
	tmpDir := t.TempDir()

	tests := []struct {
		name           string
		contextTimeout time.Duration
		expectError    bool
		errorType      error
	}{
		{
			name:           "write completes before deadline",
			contextTimeout: 1 * time.Second,
			expectError:    false,
		},
		{
			name:           "write exceeds deadline",
			contextTimeout: 0,
			expectError:    true,
			errorType:      context.DeadlineExceeded,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			stateFile := filepath.Join(tmpDir, tt.name+"-state.json")

			// Pre-create the state file, then close it right away: a handle
			// left open for the whole subtest can prevent the state file from
			// being replaced on Windows (NOTE(review): confirm on a Windows runner).
			file, err := os.Create(stateFile)
			require.NoError(t, err)
			require.NoError(t, file.Close())

			m := New(stateFile)

			// Register and update mock state so there is one dirty entry.
			mockState := &MockState{}
			m.RegisterState(mockState)
			err = m.UpdateState(mockState)
			require.NoError(t, err)

			// Create context with the per-case timeout.
			ctx, cancel := context.WithTimeout(context.Background(), tt.contextTimeout)
			defer cancel()

			// Attempt to persist state.
			err = m.PersistState(ctx)

			if tt.expectError {
				assert.Error(t, err)
				// Compare against the table's expected error; ErrorIs also
				// accepts an error that wraps it with %w.
				assert.ErrorIs(t, err, tt.errorType)
				assert.Len(t, m.dirty, 1)
				return
			}

			assert.NoError(t, err)
			assert.FileExists(t, stateFile)
			assert.Empty(t, m.dirty)
		})
	}
}
55 changes: 39 additions & 16 deletions util/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
Expand All @@ -14,6 +15,19 @@
log "github.com/sirupsen/logrus"
)

func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return fmt.Errorf("prepare config file dir: %w", err)
}

if err = EnforcePermission(file); err != nil {
return fmt.Errorf("enfore permission: %w", err)

Check failure on line 25 in util/file.go

View workflow job for this annotation

GitHub Actions / codespell

enfore ==> enforce
}

return writeBytes(ctx, file, err, configDir, configFileName, bs)
}

// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
Expand Down Expand Up @@ -82,29 +96,44 @@
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
// Check context before expensive operations
if ctx.Err() != nil {
return ctx.Err()
return fmt.Errorf("write json start: %w", ctx.Err())
}

// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
return fmt.Errorf("marshal: %w", err)
}

return writeBytes(ctx, file, err, configDir, configFileName, bs)
}

func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
if ctx.Err() != nil {
return ctx.Err()
return fmt.Errorf("write bytes start: %w", ctx.Err())
}

tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
return fmt.Errorf("create temp: %w", err)
}

tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()

if deadline, ok := ctx.Deadline(); ok {
if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
log.Warnf("failed to set deadline: %v", err)
}
}

_, err = tempFile.Write(bs)
if err != nil {
return err
_ = tempFile.Close()
return fmt.Errorf("write: %w", err)
}

if err = tempFile.Close(); err != nil {
return fmt.Errorf("close %s: %w", tempFileName, err)
}

defer func() {
Expand All @@ -114,19 +143,13 @@
}
}()

err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}

// Check context again
if ctx.Err() != nil {
return ctx.Err()
return fmt.Errorf("after temp file: %w", ctx.Err())
}

err = os.Rename(tempFileName, file)
if err != nil {
return err
if err = os.Rename(tempFileName, file); err != nil {
return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
}

return nil
Expand Down
Loading