From 63156440652afb5eaf8beaadfb068a4c51036ca6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 13:04:36 +0300 Subject: [PATCH 1/6] Add peer store methods Signed-off-by: bcmmbaga --- management/server/sql_store.go | 89 +++++++++++++++++++++++++++++++++- management/server/store.go | 5 ++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 2f951cd2e1..979c7842d6 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1068,7 +1068,15 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { - return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID) + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&peers, "account_id = ? AND user_id = ?", accountID, userID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get peers from store") + } + + return peers, nil } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { @@ -1112,6 +1120,85 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng return peersMap, nil } +// GetAccountPeerDNSLabels retrieves all unique DNS labels for peers associated with a specified account. +func (s *SqlStore) GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + var labels []string + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Where(accountIDCondition, accountID).Pluck("dns_label", &labels) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no peers found for the account") + } + log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + } + + return labels, nil +} + +// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user. +func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store") + } + + return peers, nil +} + +// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user. +func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store") + } + + return peers, nil +} + +// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing. +func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) { + var allEphemeralPeers, batchPeers []*nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("ephemeral = ?", true). + FindInBatches(&batchPeers, 1000, func(tx *gorm.DB, batch int) error { + allEphemeralPeers = append(allEphemeralPeers, batchPeers...) + return nil + }) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error) + return nil, fmt.Errorf("failed to retrieve ephemeral peers") + } + + return allEphemeralPeers, nil +} + +// DeletePeer removes a peer from the store. +func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete peer from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "peer not found") + } + + return nil +} + func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) diff --git a/management/server/store.go b/management/server/store.go index b16ad8a1aa..6e49a494b6 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -94,6 +94,7 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error @@ -101,9 +102,13 @@ type Store interface { GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) + GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) + GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) + GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error + DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error From 8420a525633ce9542be4257e93cadc6ed933b7da Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 13:04:49 +0300 Subject: [PATCH 2/6] Refactor ephemeral peers Signed-off-by: bcmmbaga --- management/server/ephemeral.go | 44 ++++++++++++++--------------- management/server/ephemeral_test.go | 18 +++++------- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 590b1d708b..6e245ec5ac 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -20,10 +20,10 @@ var ( ) type ephemeralPeer struct { - id string - account *Account - deadline time.Time - next *ephemeralPeer + id string + accountID string + deadline time.Time + next *ephemeralPeer } // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it @@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) - a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err) - return - } - e.peersLock.Lock() defer e.peersLock.Unlock() @@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. return } - e.addPeer(peer.ID, a, newDeadLine()) + e.addPeer(peer.AccountID, peer.ID, newDeadLine()) if e.timer == nil { e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { e.cleanup(ctx) @@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { - accounts := e.store.GetAllAccounts(context.Background()) + peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare) + if err != nil { + log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) + return + } + t := newDeadLine() count := 0 - for _, a := range accounts { - for id, p := range a.Peers { - if p.Ephemeral { - count++ - e.addPeer(id, a, t) - } + for _, p := range peers { + if p.Ephemeral { + count++ + e.addPeer(p.AccountID, p.ID, t) } } + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) } @@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { for id, p := range deletePeers { log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) + err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) } } } -func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { +func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { ep := &ephemeralPeer{ - id: id, - account: account, - deadline: deadline, + id: peerID, + accountID: accountID, + deadline: deadline, } if e.headPeer == nil { diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 1390352a5d..00e5d777a7 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,7 +7,6 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/status" ) type MockStore struct { @@ -15,17 +14,14 @@ type MockStore struct { account *Account } -func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { - return []*Account{s.account} -} - -func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { - _, ok := s.account.Peers[peerId] - if ok { - return s.account, nil +func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ LockingStrength) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + for _, v := range s.account.Peers { + if v.Ephemeral { + peers = append(peers, v) + } } - - return nil, status.NewPeerNotFoundError(peerId) + return peers, nil } type MocAccountManager struct { From f5e7449d01ab7e0906b6630c596fcc0a7fd4557c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 19:24:51 +0300 Subject: [PATCH 3/6] Add lock for peer store methods Signed-off-by: bcmmbaga --- management/server/sql_store.go | 29 +++++++++++++++++++++-------- management/server/store.go | 9 +++++---- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 979c7842d6..b921ed47d3 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -300,12 +300,12 @@ func (s *SqlStore) GetInstallationID() string { return installation.InstallationIDValue } -func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { +func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -355,7 +355,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID return nil } -func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { +func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -363,7 +363,7 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe "peer_status_last_seen", "peer_status_connected", "peer_status_login_expired", "peer_status_required_approval", } - result := s.db.Model(&nbpeer.Peer{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Select(fieldsToUpdate). Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) @@ -378,14 +378,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe return nil } -func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { +func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peerWithLocation *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer // Since the location field has been migrated to JSON serialization, // updating the struct ensures the correct data format is inserted into the database. peerCopy.Location = peerWithLocation.Location - result := s.db.Model(&nbpeer.Peer{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). Updates(peerCopy) @@ -740,9 +740,10 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) return accountID, nil } -func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { +func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) { var accountID string - result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&User{}). + Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -1066,6 +1067,18 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// GetAccountPeers retrieves peers for an account. +func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get peers from store") + } + + return peers, nil +} + // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer diff --git a/management/server/store.go b/management/server/store.go index 6e49a494b6..9ecb9c1698 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -48,7 +48,7 @@ type Store interface { GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) - GetAccountIDByUserID(userID string) (string, error) + GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later @@ -99,15 +99,16 @@ type Store interface { AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) - SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(accountID string, peer *nbpeer.Peer) error + SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error + SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) From 7d849a92c0ce8436f52f1ea721a62b6d7f5534d3 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 19:32:34 +0300 Subject: [PATCH 4/6] Refactor peer handlers Signed-off-by: bcmmbaga --- management/server/http/peers_handler.go | 95 ++++++------- management/server/http/peers_handler_test.go | 141 +++++++++++-------- 2 files changed, 126 insertions(+), 110 deletions(-) diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index a5856a0e43..235e744b35 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) +func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { + peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { util.WriteError(ctx, err, w) return @@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee } dnsDomain := h.accountManager.GetDNSDomain() - groupsInfo := toGroupsInfo(account.Groups, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + if err != nil { + util.WriteError(ctx, err, w) + return + } + groupsInfo := toGroupsInfo(peerGroups) - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } @@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) + peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) if err != nil { util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + if err != nil { + util.WriteError(ctx, err, w) + return + } + groupMinimumInfo := toGroupsInfo(peerGroups) - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { case http.MethodDelete: h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodGet, http.MethodPut: - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if r.Method == http.MethodGet { - h.getPeer(r.Context(), account, peerID, userID, w) - } else { - h.updatePeer(r.Context(), account, userID, peerID, w, r) - } + case http.MethodGet: + h.getPeer(r.Context(), accountID, peerID, userID, w) + return + case http.MethodPut: + h.updatePeer(r.Context(), accountID, userID, peerID, w, r) return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(account.Peers)) - for _, peer := range account.Peers { + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + + peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + groupMinimumInfo := toGroupsInfo(peerGroups) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(account) + validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request } } - dnsDomain := h.accountManager.GetDNSDomain() - - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) + dnsDomain := h.accountManager.GetDNSDomain() + + customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) @@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - var groupsInfo []api.GroupMinimum - groupsChecked := make(map[string]struct{}) +func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum { + groupsInfo := make([]api.GroupMinimum, 0, len(groups)) for _, group := range groups { - _, ok := groupsChecked[group.ID] - if ok { - continue - } - groupsChecked[group.ID] = struct{}{} - for _, pk := range group.Peers { - if pk == peerID { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - groupsInfo = append(groupsInfo, info) - break - } - } + groupsInfo = append(groupsInfo, api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + }) } return groupsInfo } diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index dd49c03b84..9279fc5361 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -39,6 +39,68 @@ const ( ) func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer.Copy() + } + + policy := &server.Policy{ + ID: "policy", + AccountID: "test_id", + Name: "policy", + Enabled: true, + Rules: []*server.PolicyRule{ + { + ID: "rule", + Name: "rule", + Enabled: true, + Action: "accept", + Destinations: []string{"group1"}, + Sources: []string{"group1"}, + Bidirectional: true, + Protocol: "all", + Ports: []string{"80"}, + }, + }, + } + + srvUser := server.NewRegularUser(serviceUser) + srvUser.IsServiceUser = true + + account := &server.Account{ + Id: "test_id", + Domain: "hotmail.com", + Peers: peersMap, + Users: map[string]*server.User{ + adminUser: server.NewAdminUser(adminUser), + regularUser: server.NewRegularUser(regularUser), + serviceUser: srvUser, + }, + Groups: map[string]*nbgroup.Group{ + "group1": { + ID: "group1", + AccountID: "test_id", + Name: "group1", + Issued: "api", + Peers: maps.Keys(peersMap), + }, + }, + Settings: &server.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour, + }, + Policies: []*server.Policy{policy}, + Network: &server.Network{ + Identifier: "ciclqisab2ss43jdn8q0", + Net: net.IPNet{ + IP: net.ParseIP("100.67.0.0"), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + Serial: 51, + }, + } + return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -67,74 +129,31 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, + GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + peersID := make([]string, len(peers)) + for _, peer := range peers { + peersID = append(peersID, peer.ID) + } + return []*nbgroup.Group{ + { + ID: "group1", + AccountID: accountID, + Name: "group1", + Issued: "api", + Peers: peersID, + }, + }, nil + }, GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, + GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) { + return account, nil + }, GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { - peersMap := make(map[string]*nbpeer.Peer) - for _, peer := range peers { - peersMap[peer.ID] = peer.Copy() - } - - policy := &server.Policy{ - ID: "policy", - AccountID: accountID, - Name: "policy", - Enabled: true, - Rules: []*server.PolicyRule{ - { - ID: "rule", - Name: "rule", - Enabled: true, - Action: "accept", - Destinations: []string{"group1"}, - Sources: []string{"group1"}, - Bidirectional: true, - Protocol: "all", - Ports: []string{"80"}, - }, - }, - } - - srvUser := server.NewRegularUser(serviceUser) - srvUser.IsServiceUser = true - - account := &server.Account{ - Id: accountID, - Domain: "hotmail.com", - Peers: peersMap, - Users: map[string]*server.User{ - adminUser: server.NewAdminUser(adminUser), - regularUser: server.NewRegularUser(regularUser), - serviceUser: srvUser, - }, - Groups: map[string]*nbgroup.Group{ - "group1": { - ID: "group1", - AccountID: accountID, - Name: "group1", - Issued: "api", - Peers: maps.Keys(peersMap), - }, - }, - Settings: &server.Settings{ - PeerLoginExpirationEnabled: true, - PeerLoginExpiration: time.Hour, - }, - Policies: []*server.Policy{policy}, - Network: &server.Network{ - Identifier: "ciclqisab2ss43jdn8q0", - Net: net.IPNet{ - IP: net.ParseIP("100.67.0.0"), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - Serial: 51, - }, - } - return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { From c557c983908ca2eaf9cec3943c511332626860cf Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 19:33:57 +0300 Subject: [PATCH 5/6] Refactor peer to use store methods Signed-off-by: bcmmbaga --- management/server/account.go | 105 ++- management/server/integrated_validator.go | 39 +- management/server/mock_server/account_mock.go | 24 +- management/server/peer.go | 664 +++++++++++------- management/server/peer/peer.go | 2 +- management/server/user.go | 36 +- 6 files changed, 543 insertions(+), 327 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 5e9d6ebbc1..4222179d95 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -92,7 +92,7 @@ type AccountManager interface { GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) @@ -112,6 +112,7 @@ type AccountManager interface { DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error + GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error @@ -134,7 +135,7 @@ type AccountManager interface { GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -145,7 +146,7 @@ type AccountManager interface { GetIdpManager() idp.Manager UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(account *Account) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error @@ -1160,17 +1161,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) + err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err } @@ -1185,21 +1186,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { event := activity.AccountPeerInactivityExpirationEnabled if !newSettings.PeerInactivityExpirationEnabled { event = activity.AccountPeerInactivityExpirationDisabled am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } return nil @@ -1207,73 +1208,64 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID) - return account.GetNextPeerExpiration() + return 0, false } - expiredPeers := account.GetExpiredPeers() var peerIDs []string for _, peer := range expiredPeers { peerIDs = append(peerIDs, peer.ID) } - log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { - log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id) - return account.GetNextPeerExpiration() + if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { + log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) + return 0, false } - return account.GetNextPeerExpiration() + return am.getNextPeerExpiration(ctx, accountID) } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { - am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextPeerExpiration(); ok { - go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { + go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) } } // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { - log.Errorf("failed getting account %s expiring peers", accountID) - return account.GetNextInactivePeerExpiration() + log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) + return 0, false } - expiredPeers := account.GetInactivePeers() var peerIDs []string - for _, peer := range expiredPeers { + for _, peer := range inactivePeers { peerIDs = append(peerIDs, peer.ID) } - log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) - return account.GetNextInactivePeerExpiration() + if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", accountID) + return 0, false } - return account.GetNextInactivePeerExpiration() + return am.getNextInactivePeerExpiration(ctx, accountID) } } // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions -func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { - am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { - go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) { + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID)) } } @@ -1409,7 +1401,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) @@ -2188,7 +2180,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -2235,7 +2227,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -2292,17 +2284,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer peerUnlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, nil, nil, status.NewGetAccountError(err) - } - - peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) + peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -2316,12 +2303,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer peerUnlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return status.NewGetAccountError(err) - } - - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) + err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -2339,12 +2321,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st unlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) + _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { return mapError(ctx, err) } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 0c70b702a0..1692507dad 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -4,6 +4,8 @@ import ( "context" "errors" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" @@ -73,6 +75,39 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { - return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { + var err error + var groups []*nbgroup.Group + var peers []*nbpeer.Peer + var settings *Settings + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + peers, err = transaction.GetAccountPeers(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + return err + }) + if err != nil { + return nil, err + } + + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 46a4fbc1fa..e1a84b4f9c 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,6 +47,7 @@ type MockAccountManager struct { DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error) DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) @@ -90,7 +91,7 @@ type MockAccountManager struct { GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -130,7 +131,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { + account, err := am.GetAccountFunc(ctx, accountID) + if err != nil { + return nil, err + } + approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} @@ -221,7 +227,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } @@ -682,9 +688,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { - return am.SyncPeerFunc(ctx, sync, account) + return am.SyncPeerFunc(ctx, sync, accountID) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } @@ -831,3 +837,11 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) } return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") } + +// GetPeerGroups mocks GetPeerGroups of the AccountManager interface +func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) { + if am.GetPeerGroupsFunc != nil { + return am.GetPeerGroupsFunc(ctx, accountID, peerID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented") +} diff --git a/management/server/peer.go b/management/server/peer.go index a941f404fc..ba79a5b480 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -11,8 +11,10 @@ import ( "sync" "time" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" @@ -53,43 +55,55 @@ type PeerLogin struct { // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(account) + if user.IsRegularUser() && settings.RegularUsersViewBlocked { + return []*nbpeer.Peer{}, nil + } + + accountPeers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } + peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) - regularUser := !user.HasAdminPower() && !user.IsServiceUser - - if regularUser && account.Settings.RegularUsersViewBlocked { - return peers, nil - } - - for _, peer := range account.Peers { - if regularUser && user.Id != peer.UserID { + for _, peer := range accountPeers { + if user.IsRegularUser() && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin continue } - p := peer.Copy() - peers = append(peers, p) - peersMap[peer.ID] = p + peers = append(peers, peer) + peersMap[peer.ID] = peer } - if !regularUser { + if user.IsAdminOrServiceUser() { return peers, nil } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, err + } + + approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) + if err != nil { + return nil, err + } + // fetch all the peers that have access to the user's peers for _, peer := range peers { aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) @@ -98,48 +112,46 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID } } - peers = make([]*nbpeer.Peer, 0, len(peersMap)) - for _, peer := range peersMap { - peers = append(peers, peer) - } - - return peers, nil + return maps.Values(peersMap), nil } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { - peer, err := account.FindPeerByPubKey(peerPubKey) +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey) if err != nil { - return fmt.Errorf("failed to find peer by pub key: %w", err) + return err } - expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { - return fmt.Errorf("failed to update peer status and location: %w", err) + return err } - log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected) + expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID) + if err != nil { + return err + } if peer.AddedWithSSOLogin() { - if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account.Id) + am.updateAccountPeers(ctx, accountID) } return nil } -func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -159,18 +171,16 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context peer.Location.CountryCode = location.Country.ISOCode peer.Location.CityName = location.City.Names.En peer.Location.GeoNameID = location.City.GeonameID - err = am.Store.SavePeerLocation(account.Id, peer) + err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer) if err != nil { log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } } } - account.UpdatePeer(peer) - - err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus) if err != nil { - return false, fmt.Errorf("failed to save peer status: %w", err) + return false, err } return oldStatus.LoginExpired, nil @@ -181,37 +191,51 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - peer := account.GetPeer(update.ID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID) + if err != nil { + return nil, err } var requiresPeerUpdates bool - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) if err != nil { return nil, err } + var sshChanged, peerLabelChanged, loginExpirationChanged, inactivityExpirationChanged bool + if peer.SSHEnabled != update.SSHEnabled { peer.SSHEnabled = update.SSHEnabled - event := activity.PeerSSHEnabled - if !update.SSHEnabled { - event = activity.PeerSSHDisabled - } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + sshChanged = true } - peerLabelUpdated := peer.Name != update.Name - - if peerLabelUpdated { + if peer.Name != update.Name { peer.Name = update.Name + peerLabelChanged = true - existingLabels := account.getPeerDNSLabels() + existingLabels, err := am.getPeerDNSLabels(ctx, accountID) + if err != nil { + return nil, err + } newLabel, err := getPeerHostLabel(peer.Name, existingLabels) if err != nil { @@ -219,108 +243,69 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } peer.DNSLabel = newLabel - - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) } if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { - if !peer.AddedWithSSOLogin() { return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") } - peer.LoginExpirationEnabled = update.LoginExpirationEnabled - - event := activity.PeerLoginExpirationEnabled - if !update.LoginExpirationEnabled { - event = activity.PeerLoginExpirationDisabled - } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) - } + loginExpirationChanged = true } if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { - if !peer.AddedWithSSOLogin() { - return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated") } - peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled - - event := activity.PeerInactivityExpirationEnabled - if !update.InactivityExpirationEnabled { - event = activity.PeerInactivityExpirationDisabled - } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - - if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) - } + inactivityExpirationChanged = true } - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil { return nil, err } - if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, accountID) + if sshChanged { + event := activity.PeerSSHEnabled + if !peer.SSHEnabled { + event = activity.PeerSSHDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } - return peer, nil -} - -// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error { + if peerLabelChanged { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) + } - // the first loop is needed to ensure all peers present under the account before modifying, otherwise - // we might have some inconsistencies - peers := make([]*nbpeer.Peer, 0, len(peerIDs)) - for _, peerID := range peerIDs { + if loginExpirationChanged { + event := activity.PeerLoginExpirationEnabled + if !peer.LoginExpirationEnabled { + event = activity.PeerLoginExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) + if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - peers = append(peers, peer) } - // the 2nd loop performs the actual modification - for _, peer := range peers { + if inactivityExpirationChanged { + event := activity.PeerInactivityExpirationEnabled + if !peer.InactivityExpirationEnabled { + event = activity.PeerInactivityExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID) - if err != nil { - return err + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } + } - account.DeletePeer(peer.ID) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, - &UpdateMessage{ - Update: &proto.SyncResponse{ - // fill those field for backward compatibility - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - // new field - NetworkMap: &proto.NetworkMap{ - Serial: account.Network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - }, - }, - NetworkMap: &NetworkMap{}, - }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) - am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + if peerLabelChanged || requiresPeerUpdates { + am.updateAccountPeers(ctx, accountID) } - return nil + return peer, nil } // DeletePeer removes peer from the account by its IP @@ -328,24 +313,30 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID) if err != nil { return err } - updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) - if err != nil { - return err - } + var peer *nbpeer.Peer + var addPeerRemovedEvents []func() - err = am.deletePeers(ctx, account, []string{peerID}, userID) - if err != nil { - return err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID) + if err != nil { + return err + } - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return err + addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) + if err != nil { + return err + } + + return transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + }) + + for _, addPeerRemovedEvent := range addPeerRemovedEvents { + addPeerRemovedEvent() } if updateAccountPeers { @@ -411,7 +402,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s addedByUser := false if len(userID) > 0 { addedByUser = true - accountID, err = am.Store.GetAccountIDByUserID(userID) + accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) } else { accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) } @@ -442,12 +433,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer - var groupsToAdd []string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { var setupKeyID string var setupKeyName string var ephemeral bool + var groupsToAdd []string if addedByUser { user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) if err != nil { @@ -590,39 +581,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s unlock() unlock = nil - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, status.NewGetAccountError(err) - } - - allGroup, err := account.GetGroupAll() - if err != nil { - return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) - } - groupsToAdd = append(groupsToAdd, allGroup.ID) - - newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID) if err != nil { return nil, nil, nil, err } - if newGroupsAffectsPeers { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - approvedPeersMap, err := am.GetValidatedPeers(account) - if err != nil { - return nil, nil, nil, err - } - - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) - if err != nil { - return nil, nil, nil, err - } - - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) - return newPeer, networkMap, postureChecks, nil + return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { @@ -645,16 +613,16 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() } if peer.UserID != "" { - user, err := account.FindUser(peer.UserID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get user: %w", err) + return nil, nil, nil, err } err = checkIfPeerOwnerIsBlocked(peer, user) @@ -663,52 +631,38 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } } - if peerLoginExpired(ctx, peer, account.Settings) { - return nil, nil, nil, status.NewPeerLoginExpiredError() + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, nil, nil, err } - updated := peer.UpdateMetaIfNew(sync.Meta) - if updated { - err = am.Store.SavePeer(ctx, account.Id, peer) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err) - } - - if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account.Id) - } + if peerLoginExpired(ctx, peer, settings) { + return nil, nil, nil, status.NewPeerLoginExpiredError() } - peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err) - } - - var postureChecks []*posture.Checks - - if peerNotValid { - emptyMap := &NetworkMap{ - Network: account.Network.Copy(), - } - return peer, emptyMap, postureChecks, nil + return nil, nil, nil, err } - if isStatusChanged { - am.updateAccountPeers(ctx, account.Id) + peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra) + if err != nil { + return nil, nil, nil, err } - validPeersMap, err := am.GetValidatedPeers(account) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) + updated := peer.UpdateMetaIfNew(sync.Meta) + if updated { + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + if err != nil { + return nil, nil, nil, err + } } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) - if err != nil { - return nil, nil, nil, err + if isStatusChanged || (updated && sync.UpdateAccountPeers) { + am.updateAccountPeers(ctx, accountID) } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } // LoginPeer logs in or registers a peer. @@ -814,7 +768,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if shouldStorePeer { - err = am.Store.SavePeer(ctx, accountID, peer) + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) if err != nil { return nil, nil, nil, err } @@ -823,16 +777,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) unlockPeer() unlockPeer = nil - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - if updateRemotePeers || isStatusChanged { am.updateAccountPeers(ctx, accountID) } - return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) + return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) } // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO @@ -864,22 +813,30 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - var postureChecks []*posture.Checks - +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { if isRequiresApproval { + network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, nil, nil, err + } + emptyMap := &NetworkMap{ - Network: account.Network.Copy(), + Network: network.Copy(), } return peer, emptyMap, nil, nil } - approvedPeersMap, err := am.GetValidatedPeers(account) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, nil, nil, err } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) + if err != nil { + return nil, nil, nil, err + } + + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, peer.ID) if err != nil { return nil, nil, nil, err } @@ -896,7 +853,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. peer = peer.UpdateLastLogin() - err = am.Store.SavePeer(ctx, peer.AccountID, peer) + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer) if err != nil { return err } @@ -943,41 +900,47 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + if user.IsRegularUser() && settings.RegularUsersViewBlocked { return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) } - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return nil, err } // if admin or user owns this peer, return peer - if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID { + if user.IsAdminOrServiceUser() || peer.UserID == userID { return peer, nil } // it is also possible that user doesn't own the peer but some of his peers have access to it, // this is a valid case, show the peer as well. - userPeers, err := account.FindUserPeers(userID) + userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID) + if err != nil { + return nil, err + } + + approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) if err != nil { return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(account) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, err } @@ -1006,12 +969,13 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) return } + peers := account.GetPeers() - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) return @@ -1037,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) return } @@ -1050,22 +1014,240 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } -func ConvertSliceToMap(existingLabels []string) map[string]struct{} { - labelMap := make(map[string]struct{}, len(existingLabels)) - for _, label := range existingLabels { - labelMap[label] = struct{}{} +// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) + return 0, false } - return labelMap + + if len(peersWithExpiry) == 0 { + return 0, false + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return 0, false + } + + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) + return 0, false + } + + if len(peersWithInactivity) == 0 { + return 0, false + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return 0, false + } + + var nextExpiry *time.Duration + for _, peer := range peersWithInactivity { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// getExpiredPeers returns peers that have been expired. +func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + var peers []*nbpeer.Peer + for _, peer := range peersWithExpiry { + expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers, nil +} + +// getInactivePeers returns peers that have been expired by inactivity +func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + var peers []*nbpeer.Peer + for _, inactivePeer := range peersWithInactivity { + inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + + return peers, nil +} + +// GetPeerGroups returns groups that the peer is part of. +func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peerGroups := make([]*nbgroup.Group, 0) + for _, group := range groups { + if slices.Contains(group.Peers, peerID) { + peerGroups = append(peerGroups, group) + } + } + + return peerGroups, nil +} + +// getPeerGroupIDs returns the IDs of the groups that the peer is part of. +func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) { + groups, err := am.GetPeerGroups(ctx, accountID, peerID) + if err != nil { + return nil, err + } + + groupIDs := make([]string, 0, len(groups)) + for _, group := range groups { + groupIDs = append(groupIDs, group.ID) + } + + return groupIDs, err +} + +func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) { + dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + existingLabels := make(lookupMap) + for _, label := range dnsLabels { + existingLabels[label] = struct{}{} + } + return existingLabels, nil } // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { - peerGroupIDs := make([]string, 0) - for _, group := range account.Groups { - if slices.Contains(group.Peers, peerID) { - peerGroupIDs = append(peerGroupIDs, group.ID) +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) { + peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID) + if err != nil { + return false, err + } + return areGroupChangesAffectPeers(ctx, am.Store, accountID, peerGroupIDs) // TODO: use transaction +} + +// deletePeers deletes all specified peers and sends updates to the remote peers. +// Returns a slice of functions to save events after successful peer deletion. +func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { + var peerDeletedEvents []func() + + for _, peer := range peers { + if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { + return nil, err + } + + network, err := transaction.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err } + + if err = transaction.DeletePeer(ctx, LockingStrengthUpdate, accountID, peer.ID); err != nil { + return nil, err + } + + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + }, + }, + NetworkMap: &NetworkMap{}, + }) + am.peersUpdateManager.CloseChannel(ctx, peer.ID) + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + }) + } + + return peerDeletedEvents, nil +} + +func ConvertSliceToMap(existingLabels []string) map[string]struct{} { + labelMap := make(map[string]struct{}, len(existingLabels)) + for _, label := range existingLabels { + labelMap[label] = struct{}{} } - return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) + return labelMap } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 34d7918446..146af88617 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -44,7 +44,7 @@ type Peer struct { // CreatedAt records the time the peer was created CreatedAt time.Time // Indicate ephemeral peer attribute - Ephemeral bool + Ephemeral bool `gorm:"index"` // Geo location based on connection IP Location Location `gorm:"embedded;embeddedPrefix:location_"` } diff --git a/management/server/user.go b/management/server/user.go index 74062112af..823eaa311c 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -487,6 +487,10 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account } delete(account.Users, targetUserID) + if updateAccountPeers { + account.Network.IncSerial() + } + err = am.Store.SaveAccount(ctx, account) if err != nil { return err @@ -511,12 +515,16 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU return false, nil } - peerIDs := make([]string, 0, len(peers)) - for _, peer := range peers { - peerIDs = append(peerIDs, peer.ID) + eventsToStore, err := deletePeers(ctx, am, am.Store, account.Id, initiatorUserID, peers) + if err != nil { + return false, err } - return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID) + for _, storeEvent := range eventsToStore { + storeEvent() + } + + return hadPeers, nil } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. @@ -823,7 +831,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if len(expiredPeers) > 0 { - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + if err := am.expireAndUpdatePeers(ctx, account.Id, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } @@ -1104,7 +1112,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -1115,16 +1123,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou } peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { - return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err) - } - - log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID) + if err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { + return err + } am.StoreEvent( ctx, - peer.UserID, peer.ID, account.Id, + peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), ) } @@ -1132,7 +1137,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account.Id) + am.updateAccountPeers(ctx, accountID) } return nil } @@ -1234,6 +1239,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account deletedUsersMeta[targetUserID] = meta } + if updateAccountPeers { + account.Network.IncSerial() + } err = am.Store.SaveAccount(ctx, account) if err != nil { return fmt.Errorf("failed to delete users: %w", err) From f6f7260897ac9b84e348fc34e15fe07fcb041586 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 14 Nov 2024 19:34:05 +0300 Subject: [PATCH 6/6] Fix tests Signed-off-by: bcmmbaga --- management/server/account_test.go | 19 ++++++------------- management/server/sql_store_test.go | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index c8c2d59410..a13b89f335 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1472,7 +1472,6 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - userID := "account_creator" account, err := createAccount(manager, "test_account", userID, "netbird.cloud") if err != nil { t.Fatal(err) @@ -1501,7 +1500,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID) + err = manager.DeletePeer(context.Background(), account.Id, peer.ID, userID) if err != nil { return } @@ -1523,7 +1522,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { assert.Equal(t, peer.Name, ev.Meta["name"]) assert.Equal(t, peer.FQDN(account.Domain), ev.Meta["fqdn"]) assert.Equal(t, userID, ev.InitiatorID) - assert.Equal(t, peer.IP.String(), ev.TargetID) + assert.Equal(t, peer.ID, ev.TargetID) assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } @@ -1853,13 +1852,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1927,11 +1923,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1962,7 +1955,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index b568b7fe03..7f36eb5061 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -400,7 +400,7 @@ func TestSqlite_SavePeer(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } ctx := context.Background() - err = store.SavePeer(ctx, account.Id, peer) + err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -416,7 +416,7 @@ func TestSqlite_SavePeer(t *testing.T) { updatedPeer.Status.Connected = false updatedPeer.Meta.Hostname = "updatedpeer" - err = store.SavePeer(ctx, account.Id, updatedPeer) + err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, updatedPeer) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -442,7 +442,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { // save status of non-existing peer newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -461,7 +461,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -472,7 +472,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { newStatus.Connected = true - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -507,7 +507,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { Meta: nbpeer.PeerSystemMeta{}, } // error is expected as peer is not in store yet - err = store.SavePeerLocation(account.Id, peer) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) account.Peers[peer.ID] = peer @@ -519,7 +519,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { peer.Location.CityName = "Berlin" peer.Location.GeoNameID = 2950159 - err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, account.Peers[peer.ID]) assert.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -529,7 +529,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { assert.Equal(t, peer.Location, actual) peer.ID = "non-existing-peer" - err = store.SavePeerLocation(account.Id, peer) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -908,7 +908,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { // save status of non-existing peer newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) // save new status of existing peer @@ -924,7 +924,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id)