Skip to content

Commit

Permalink
[client] Improve state write timeout and abort work early on timeout (#…
Browse files Browse the repository at this point in the history
…2882)

* Improve state write timeout and abort work early on timeout

* Don't block on initial persist state
  • Loading branch information
lixmal authored Nov 13, 2024
1 parent 20a5afc commit 39329e1
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 53 deletions.
8 changes: 5 additions & 3 deletions client/firewall/iptables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}

// persist early to ensure cleanup of chains
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()

return nil
}
Expand Down
8 changes: 5 additions & 3 deletions client/firewall/nftables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}

// persist early
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()

return nil
}
Expand Down
6 changes: 3 additions & 3 deletions client/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err
}

Expand All @@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {

// WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(path, config)
return util.WriteJson(context.Background(), path, config)
}

// createNewConfig creates a new config generating a new Wireguard key and saving to file
Expand Down Expand Up @@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
}

if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil {
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err
}
}
Expand Down
10 changes: 7 additions & 3 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}

// don't block
go func() {
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()

if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
Expand Down
20 changes: 5 additions & 15 deletions client/internal/statemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps"

nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/util"
)

// State interface defines the methods that all state types must implement
Expand Down Expand Up @@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}

ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

done := make(chan error, 1)

start := time.Now()
go func() {
data, err := json.MarshalIndent(m.states, "", " ")
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}

// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}

done <- nil
done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
}()

select {
Expand All @@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
}
}

log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))

clear(m.dirty)

Expand Down
22 changes: 5 additions & 17 deletions client/internal/statemanager/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,20 @@ import (
"os"
"path/filepath"
"runtime"

log "github.com/sirupsen/logrus"
)

// GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
// It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string {
var path string

switch runtime.GOOS {
case "windows":
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux":
path = "/var/lib/netbird/state.json"
return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly":
path = "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
return "/var/db/netbird/state.json"
}

dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return ""
}
return ""

return path
}
2 changes: 1 addition & 1 deletion management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now()
err := util.WriteJson(file, s)
err := util.WriteJson(context.Background(), file, s)
if err != nil {
return err
}
Expand Down
23 changes: 18 additions & 5 deletions util/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
Expand All @@ -26,18 +26,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
return err
}

return writeJson(file, obj, configDir, configFileName)
return writeJson(ctx, file, obj, configDir, configFileName)
}

// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error {
func WriteJson(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}

return writeJson(file, obj, configDir, configFileName)
return writeJson(ctx, file, obj, configDir, configFileName)
}

// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
Expand Down Expand Up @@ -79,14 +79,22 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil
}

func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
// Check context before expensive operations
if ctx.Err() != nil {
return ctx.Err()
}

// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}

if ctx.Err() != nil {
return ctx.Err()
}

tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
Expand All @@ -111,6 +119,11 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
return err
}

// Check context again
if ctx.Err() != nil {
return ctx.Err()
}

err = os.Rename(tempFileName, file)
if err != nil {
return err
Expand Down
7 changes: 4 additions & 3 deletions util/file_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package util

import (
"context"
"crypto/md5"
"encoding/hex"
"io"
Expand Down Expand Up @@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()

err := WriteJson(tmpDir+"/testconfig.json", tt.config)
err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
require.NoError(t, err)

read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
Expand Down Expand Up @@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"

err := WriteJson(src, tt.srcContent)
err := WriteJson(context.Background(), src, tt.srcContent)
require.NoError(t, err)

err = CopyFileContents(src, dst)
Expand Down Expand Up @@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
_ = os.Remove(cfgFile)
}()

err := WriteJson(cfgFile, tt.config)
err := WriteJson(context.Background(), cfgFile, tt.config)
require.NoError(t, err)

read, err := ReadJson(cfgFile, &TestConfig{})
Expand Down

0 comments on commit 39329e1

Please sign in to comment.