diff --git a/management/server/account.go b/management/server/account.go
index 4c150fd7ee5..23781c915ab 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -161,6 +161,8 @@ type DefaultAccountManager struct {
 	eventStore activity.Store
 	geo        *geolocation.Geolocation
 
+	cache *AccountCache
+
 	// singleAccountMode indicates whether the instance has a single account.
 	// If true, then every new user will end up under the same account.
 	// This value will be set to false if management service has more than one account.
@@ -967,6 +969,7 @@ func BuildManager(
 		userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
 		integratedPeerValidator:  integratedPeerValidator,
 		metrics:                  metrics,
+		cache:                    NewAccountCache(ctx, store),
 	}
 	allAccounts := store.GetAllAccounts(ctx)
 	// enable single account mode only if configured by user and number of existing accounts is not greater than 1
diff --git a/management/server/account_cache.go b/management/server/account_cache.go
new file mode 100644
index 00000000000..13ce45819e6
--- /dev/null
+++ b/management/server/account_cache.go
@@ -0,0 +1,106 @@
+package server
+
+import (
+	"context"
+	"os"
+	"sync"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+)
+
+// AccountRequest holds the result channel to return the requested account.
+type AccountRequest struct {
+	AccountID  string
+	ResultChan chan *AccountResult
+}
+
+// AccountResult holds the account data or an error.
+type AccountResult struct {
+	Account *Account
+	Err     error
+}
+
+// AccountCache coalesces concurrent reads of the same account into a single
+// store query per buffer interval.
+type AccountCache struct {
+	store               Store
+	getAccountRequests  map[string][]*AccountRequest
+	mu                  sync.Mutex
+	getAccountRequestCh chan *AccountRequest
+	bufferInterval      time.Duration
+}
+
+func NewAccountCache(ctx context.Context, store Store) *AccountCache {
+	bufferInterval := 300 * time.Millisecond
+	if bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL"); bufferIntervalStr != "" {
+		parsed, err := time.ParseDuration(bufferIntervalStr)
+		if err != nil {
+			log.WithContext(ctx).Warnf("failed to parse account cache buffer interval, using default: %s", err)
+		} else {
+			bufferInterval = parsed
+		}
+	}
+
+	log.WithContext(ctx).Infof("set account cache buffer interval to %s", bufferInterval)
+
+	ac := AccountCache{
+		store:               store,
+		getAccountRequests:  make(map[string][]*AccountRequest),
+		getAccountRequestCh: make(chan *AccountRequest),
+		bufferInterval:      bufferInterval,
+	}
+
+	go ac.processGetAccountRequests(ctx)
+
+	return &ac
+}
+
+// GetAccountWithBackpressure enqueues the request and blocks until the
+// batched store read for this account completes.
+func (ac *AccountCache) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
+	req := &AccountRequest{
+		AccountID:  accountID,
+		ResultChan: make(chan *AccountResult, 1),
+	}
+
+	log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
+	startTime := time.Now()
+	ac.getAccountRequestCh <- req
+
+	result := <-req.ResultChan
+	log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
+	return result.Account, result.Err
+}
+
+// processGetAccountBatch drains the pending requests for one account, issues
+// a single store read, and fans the result out to every waiting caller.
+func (ac *AccountCache) processGetAccountBatch(ctx context.Context, accountID string) {
+	ac.mu.Lock()
+	requests := ac.getAccountRequests[accountID]
+	delete(ac.getAccountRequests, accountID)
+	ac.mu.Unlock()
+
+	if len(requests) == 0 {
+		return
+	}
+
+	startTime := time.Now()
+	account, err := ac.store.GetAccount(ctx, accountID)
+	log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime))
+	result := &AccountResult{Account: account, Err: err}
+
+	for _, req := range requests {
+		req.ResultChan <- result
+		close(req.ResultChan)
+	}
+}
+
+// processGetAccountRequests buffers incoming requests per account ID; the
+// first request for an account schedules a batch fetch after bufferInterval.
+func (ac *AccountCache) processGetAccountRequests(ctx context.Context) {
+	for {
+		select {
+		case req := <-ac.getAccountRequestCh:
+			ac.mu.Lock()
+			ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req)
+			if len(ac.getAccountRequests[req.AccountID]) == 1 {
+				go func(ctx context.Context, accountID string) {
+					time.Sleep(ac.bufferInterval)
+					ac.processGetAccountBatch(ctx, accountID)
+				}(ctx, req.AccountID)
+			}
+			ac.mu.Unlock()
+		case <-ctx.Done():
+			return
+		}
+	}
+}
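The coalescing is easiest to see from the caller's side. A minimal sketch, assuming a Store implementation is at hand (NewAccountCache and GetAccountWithBackpressure come from this patch; the surrounding scaffolding is illustrative only):

	cache := NewAccountCache(ctx, store)

	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// All calls for the same account arriving within one buffer
			// interval are answered by a single store.GetAccount read.
			account, err := cache.GetAccountWithBackpressure(ctx, "account-1")
			if err != nil {
				log.Errorf("get account: %v", err)
				return
			}
			_ = account
		}()
	}
	wg.Wait()

Note that every caller in a window receives the same *Account pointer, so callers must treat the result as read-only.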
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index fe1e36d47a4..aa9c0d81ec1 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -3,13 +3,17 @@ package server
 import (
 	"context"
 	"fmt"
+	"io"
 	"net"
 	"os"
 	"path/filepath"
 	"runtime"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
+	log "github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/require"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 	"google.golang.org/grpc"
@@ -24,6 +28,12 @@ import (
 	"github.com/netbirdio/netbird/util"
 )
 
+// TestingT is the subset of testing.TB required by startManagementForTest,
+// letting the helper serve both tests and benchmarks.
+type TestingT interface {
+	require.TestingT
+	Helper()
+	Cleanup(func())
+}
+
 var (
 	kaep = keepalive.EnforcementPolicy{
 		MinTime: 15 * time.Second,
@@ -86,7 +96,7 @@ func Test_SyncProtocol(t *testing.T) {
 	defer func() {
 		os.Remove(filepath.Join(dir, "store.json")) //nolint
 	}()
-	mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
+	mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{
 		Stuns: []*Host{{
 			Proto: "udp",
 			URI:   "stun:stun.wiretrustee.com:3468",
@@ -402,7 +412,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
 	}
 }
 
-func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
+func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
 	t.Helper()
 	lis, err := net.Listen("tcp", "localhost:0")
 	if err != nil {
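Since require.TestingT only demands Errorf and FailNow, and *testing.B also provides Helper and Cleanup, both *testing.T and *testing.B satisfy TestingT. A hypothetical benchmark (not part of this patch) could therefore reuse the helper:

	// Hypothetical: *testing.B satisfies TestingT; the Config fields would
	// mirror those used in Test_SyncProtocol.
	func BenchmarkLogin(b *testing.B) {
		mgmtServer, am, _, err := startManagementForTest(b, &Config{Datadir: b.TempDir()})
		if err != nil {
			b.Fatal(err)
		}
		defer mgmtServer.GracefulStop()
		_ = am
		// drive b.N LoginPeer calls here
	}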
@@ -485,7 +495,7 @@ func testSyncStatusRace(t *testing.T) {
 		os.Remove(filepath.Join(dir, "store.json")) //nolint
 	}()
 
-	mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
+	mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{
 		Stuns: []*Host{{
 			Proto: "udp",
 			URI:   "stun:stun.wiretrustee.com:3468",
@@ -545,7 +555,6 @@ func testSyncStatusRace(t *testing.T) {
 
 	ctx2, cancelFunc2 := context.WithCancel(context.Background())
 
-	//client.
 	sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
 		WgPubKey: concurrentPeerKey2.PublicKey().String(),
 		Body:     message2,
@@ -574,7 +583,7 @@ func testSyncStatusRace(t *testing.T) {
 
 	ctx, cancelFunc := context.WithCancel(context.Background())
 
-	//client.
+	// client.
 	sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
 		WgPubKey: peerWithInvalidStatus.PublicKey().String(),
 		Body:     message,
@@ -626,3 +635,208 @@ func testSyncStatusRace(t *testing.T) {
 		t.Fatal("Peer should be connected")
 	}
 }
+
+func Test_LoginPerformance(t *testing.T) {
+	if os.Getenv("CI") == "true" {
+		t.Skip("Skipping on CI")
+	}
+
+	t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
+
+	benchCases := []struct {
+		name     string
+		peers    int
+		accounts int
+	}{
+		// {"XXS", 5, 1},
+		// {"XS", 10, 1},
+		// {"S", 100, 1},
+		// {"M", 250, 1},
+		// {"L", 500, 1},
+		// {"XL", 750, 1},
+		{"XXL", 1000, 5},
+	}
+
+	log.SetOutput(io.Discard)
+	defer log.SetOutput(os.Stderr)
+
+	for _, bc := range benchCases {
+		t.Run(bc.name, func(t *testing.T) {
+			dir := t.TempDir()
+			err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer func() {
+				os.Remove(filepath.Join(dir, "store.json")) //nolint
+			}()
+
+			mgmtServer, am, _, err := startManagementForTest(t, &Config{
+				Stuns: []*Host{{
+					Proto: "udp",
+					URI:   "stun:stun.wiretrustee.com:3468",
+				}},
+				TURNConfig: &TURNConfig{
+					TimeBasedCredentials: false,
+					CredentialsTTL:       util.Duration{},
+					Secret:               "whatever",
+					Turns: []*Host{{
+						Proto: "udp",
+						URI:   "turn:stun.wiretrustee.com:3468",
+					}},
+				},
+				Signal: &Host{
+					Proto: "http",
+					URI:   "signal.wiretrustee.com:10000",
+				},
+				Datadir:    dir,
+				HttpConfig: nil,
+			})
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer mgmtServer.GracefulStop()
+
+			var counter int32
+			var counterStart int32
+			var wg sync.WaitGroup
+			var mu sync.Mutex
+			messageCalls := []func() error{}
+			for j := 0; j < bc.accounts; j++ {
+				wg.Add(1)
+				go func(j int, counter *int32, counterStart *int32) {
+					defer wg.Done()
+
+					account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
+					if err != nil {
+						t.Logf("account creation failed: %v", err)
+						return
+					}
+
+					setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
+					if err != nil {
+						t.Logf("error creating setup key: %v", err)
+						return
+					}
+
+					for i := 0; i < bc.peers; i++ {
+						key, err := wgtypes.GeneratePrivateKey()
+						if err != nil {
+							t.Logf("failed to generate key: %v", err)
+							return
+						}
+
+						meta := &mgmtProto.PeerSystemMeta{
+							Hostname:           key.PublicKey().String(),
+							GoOS:               runtime.GOOS,
+							OS:                 runtime.GOOS,
+							Core:               "core",
+							Platform:           "platform",
+							Kernel:             "kernel",
+							WiretrusteeVersion: "",
+						}
+
+						peerLogin := PeerLogin{
+							WireGuardPubKey: key.String(),
+							SSHKey:          "random",
+							Meta:            extractPeerMeta(context.Background(), meta),
+							SetupKey:        setupKey.Key,
+							ConnectionIP:    net.IP{1, 1, 1, 1},
+						}
+
+						login := func() error {
+							// use := here: the closures run concurrently later,
+							// so they must not share the enclosing err variable
+							_, _, _, err := am.LoginPeer(context.Background(), peerLogin)
+							if err != nil {
+								t.Logf("failed to login peer: %v", err)
+								return err
+							}
+							// read the counter via the atomic return value to avoid a race
+							if n := atomic.AddInt32(counter, 1); n%100 == 0 {
+								t.Logf("finished %d login calls", n)
+							}
+							return nil
+						}
+
+						mu.Lock()
+						messageCalls = append(messageCalls, login)
+						mu.Unlock()
+
+						_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
+						if err != nil {
+							t.Logf("failed to login peer: %v", err)
+							return
+						}
+
+						if n := atomic.AddInt32(counterStart, 1); n%100 == 0 {
+							t.Logf("registered %d peers", n)
+						}
+					}
+				}(j, &counter, &counterStart)
+			}
+
+			wg.Wait()
+
+			t.Logf("prepared %d login calls", len(messageCalls))
+			testLoginPerformance(t, messageCalls)
+		})
+	}
+}
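Test_LoginPerformance and the timing helper that follows measure latency only. A complementary check (hypothetical, not in this patch) could wrap the Store to count how many reads a burst of logins actually triggers:

	// Hypothetical: counts store reads; with the cache in front, a burst of
	// concurrent logins for one account should cause far fewer reads than logins.
	type countingStore struct {
		Store
		reads atomic.Int32
	}

	func (c *countingStore) GetAccount(ctx context.Context, id string) (*Account, error) {
		c.reads.Add(1)
		return c.Store.GetAccount(ctx, id)
	}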
+func testLoginPerformance(t *testing.T, loginCalls []func() error) {
+	t.Helper()
+	wgSetup := sync.WaitGroup{}
+	startChan := make(chan struct{})
+
+	wgDone := sync.WaitGroup{}
+	durations := []time.Duration{}
+	l := sync.Mutex{}
+
+	for _, function := range loginCalls {
+		wgSetup.Add(1)
+		wgDone.Add(1)
+		go func(function func() error) {
+			defer wgDone.Done()
+			wgSetup.Done()
+
+			// block until all goroutines are ready, so the calls start together
+			<-startChan
+			start := time.Now()
+
+			err := function()
+			if err != nil {
+				t.Logf("Error: %v", err)
+				return
+			}
+
+			duration := time.Since(start)
+			l.Lock()
+			durations = append(durations, duration)
+			l.Unlock()
+		}(function)
+	}
+
+	wgSetup.Wait()
+	t.Logf("starting login calls")
+	close(startChan)
+	wgDone.Wait()
+
+	if len(durations) == 0 {
+		t.Fatal("no login call succeeded")
+	}
+
+	var tMin, tMax, tSum time.Duration
+	for i, d := range durations {
+		if i == 0 {
+			tMin = d
+			tMax = d
+			tSum = d
+			continue
+		}
+		if d < tMin {
+			tMin = d
+		}
+		if d > tMax {
+			tMax = d
+		}
+		tSum += d
+	}
+	tAvg := tSum / time.Duration(len(durations))
+	t.Logf("Min: %v, Max: %v, Avg: %v", tMin, tMax, tAvg)
+}
diff --git a/management/server/peer.go b/management/server/peer.go
index 93234d9dee6..c7d757bb479 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -714,7 +714,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin
 	unlockPeer()
 	unlockPeer = nil
 
-	account, err := am.Store.GetAccount(ctx, accountID)
+	account, err := am.cache.GetAccountWithBackpressure(ctx, accountID)
 	if err != nil {
 		return nil, nil, nil, err
 	}
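One tradeoff worth noting with the LoginPeer change: a lone login now waits out the buffer window before its account is fetched, so uncontended login latency grows by up to the configured interval in exchange for bounded store load under bursts. The window is tunable through the environment variable introduced in account_cache.go; an illustrative setting:

	// Illustrative: NB_GET_ACCOUNT_BUFFER_INTERVAL accepts any
	// time.ParseDuration string. Shorter windows favor login latency,
	// longer windows favor fewer store reads under load.
	os.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "50ms")
	cache := NewAccountCache(ctx, store)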