From daaca0bd8f53d180ea3ad83b3f5c542694d694f0 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Thu, 20 Jun 2024 15:28:56 +0000 Subject: [PATCH] Use watcher and get rid of RefreshState() This change uses the database watcher to watch for changes to the github entities, credentials and controller info. Signed-off-by: Gabriel Adrian Samfira --- auth/instance_middleware.go | 17 ++- auth/interfaces.go | 10 +- cmd/garm-cli/cmd/github_credentials.go | 7 ++ database/sql/enterprise.go | 8 +- database/sql/organizations.go | 8 +- database/sql/repositories.go | 7 +- database/watcher/filters.go | 26 ++++ params/params.go | 54 +++++---- runner/common/mocks/PoolManager.go | 18 --- runner/common/pool.go | 2 - runner/enterprises.go | 6 +- runner/enterprises_test.go | 14 +-- runner/interfaces.go | 3 - runner/mocks/PoolManagerController.go | 90 -------------- runner/organizations.go | 6 +- runner/organizations_test.go | 46 ++++--- runner/pool/pool.go | 156 +++++++++--------------- runner/pool/stub_client.go | 57 +++++++++ runner/pool/util.go | 18 +++ runner/pool/watcher.go | 154 ++++++++++++++++++++++++ runner/repositories.go | 7 +- runner/repositories_test.go | 44 ++++--- runner/runner.go | 158 +++---------------------- 23 files changed, 453 insertions(+), 463 deletions(-) create mode 100644 runner/pool/stub_client.go create mode 100644 runner/pool/watcher.go diff --git a/auth/instance_middleware.go b/auth/instance_middleware.go index 0d3e4f2a..c21be3e7 100644 --- a/auth/instance_middleware.go +++ b/auth/instance_middleware.go @@ -46,7 +46,20 @@ type InstanceJWTClaims struct { jwt.RegisteredClaims } -func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) { +func NewInstanceTokenGetter(jwtSecret string) (InstanceTokenGetter, error) { + if jwtSecret == "" { + return nil, fmt.Errorf("jwt secret is required") + } + return &instanceToken{ + jwtSecret: jwtSecret, + }, nil +} + +type instanceToken struct { + jwtSecret string +} + +func (i *instanceToken) NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) { // Token expiration is equal to the bootstrap timeout set on the pool plus the polling // interval garm uses to check for timed out runners. Runners that have not sent their info // by the end of this interval are most likely failed and will be reaped by garm anyway. @@ -67,7 +80,7 @@ func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolTy CreateAttempt: instance.CreateAttempt, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString([]byte(secret)) + tokenString, err := token.SignedString([]byte(i.jwtSecret)) if err != nil { return "", errors.Wrap(err, "signing token") } diff --git a/auth/interfaces.go b/auth/interfaces.go index fa5ca43c..4e4d370c 100644 --- a/auth/interfaces.go +++ b/auth/interfaces.go @@ -14,9 +14,17 @@ package auth -import "net/http" +import ( + "net/http" + + "github.com/cloudbase/garm/params" +) // Middleware defines an authentication middleware type Middleware interface { Middleware(next http.Handler) http.Handler } + +type InstanceTokenGetter interface { + NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) +} diff --git a/cmd/garm-cli/cmd/github_credentials.go b/cmd/garm-cli/cmd/github_credentials.go index 5f09a8e4..db2c8846 100644 --- a/cmd/garm-cli/cmd/github_credentials.go +++ b/cmd/garm-cli/cmd/github_credentials.go @@ -303,6 +303,10 @@ func parseCredentialsAddParams() (ret params.CreateGithubCredentialsParams, err func parseCredentialsUpdateParams() (params.UpdateGithubCredentialsParams, error) { var updateParams params.UpdateGithubCredentialsParams + if credentialsAppInstallationID != 0 || credentialsAppID != 0 || credentialsPrivateKeyPath != "" { + updateParams.App = ¶ms.GithubApp{} + } + if credentialsName != "" { updateParams.Name = &credentialsName } @@ -312,6 +316,9 @@ func parseCredentialsUpdateParams() (params.UpdateGithubCredentialsParams, error } if credentialsOAuthToken != "" { + if updateParams.PAT == nil { + updateParams.PAT = ¶ms.GithubPAT{} + } updateParams.PAT.OAuth2Token = credentialsOAuthToken } diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index c5af3bc4..30b42137 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -132,7 +132,7 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e } func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error { - enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials") + enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return errors.Wrap(err, "fetching enterprise") } @@ -206,17 +206,13 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, return errors.Wrap(q.Error, "saving enterprise") } - if creds.ID != 0 { - enterprise.Credentials = creds - } - return nil }) if err != nil { return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } - enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials") + enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 0f3d58a3..02ae5e62 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -123,7 +123,7 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio } func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err error) { - org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials") + org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return errors.Wrap(err, "fetching org") } @@ -198,17 +198,13 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para return errors.Wrap(q.Error, "saving org") } - if creds.ID != 0 { - org.Credentials = creds - } - return nil }) if err != nil { return params.Organization{}, errors.Wrap(err, "saving org") } - org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials") + org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return params.Organization{}, errors.Wrap(err, "updating enterprise") } diff --git a/database/sql/repositories.go b/database/sql/repositories.go index 5469950f..a08e815b 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -122,7 +122,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository, } func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) { - repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials") + repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return errors.Wrap(err, "fetching repo") } @@ -197,16 +197,13 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param return errors.Wrap(q.Error, "saving repo") } - if creds.ID != 0 { - repo.Credentials = creds - } return nil }) if err != nil { return params.Repository{}, errors.Wrap(err, "saving repo") } - repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials") + repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint") if err != nil { return params.Repository{}, errors.Wrap(err, "updating enterprise") } diff --git a/database/watcher/filters.go b/database/watcher/filters.go index 9b175d7a..ffff9320 100644 --- a/database/watcher/filters.go +++ b/database/watcher/filters.go @@ -32,6 +32,18 @@ func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc { } } +// WithAll returns a filter function that returns true if all of the provided filters return true. +func WithAll(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + for _, filter := range filters { + if !filter(payload) { + return false + } + } + return true + } +} + // WithEntityTypeFilter returns a filter function that filters payloads by entity type. // The filter function returns true if the payload's entity type matches the provided entity type. func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc { @@ -139,3 +151,17 @@ func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFun } } } + +// WithGithubCredentialsFilter returns a filter function that filters payloads by Github credentials. +func WithGithubCredentialsFilter(creds params.GithubCredentials) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + if payload.EntityType != dbCommon.GithubCredentialsEntityType { + return false + } + credsPayload, ok := payload.Payload.(params.GithubCredentials) + if !ok { + return false + } + return credsPayload.ID == creds.ID + } +} diff --git a/params/params.go b/params/params.go index b17a2d14..e1ea8327 100644 --- a/params/params.go +++ b/params/params.go @@ -419,10 +419,13 @@ func (r Repository) GetEntity() (GithubEntity, error) { return GithubEntity{}, fmt.Errorf("repository has no ID") } return GithubEntity{ - ID: r.ID, - EntityType: GithubEntityTypeRepository, - Owner: r.Owner, - Name: r.Name, + ID: r.ID, + EntityType: GithubEntityTypeRepository, + Owner: r.Owner, + Name: r.Name, + PoolBalancerType: r.PoolBalancerType, + Credentials: r.Credentials, + WebhookSecret: r.WebhookSecret, }, nil } @@ -470,10 +473,12 @@ func (o Organization) GetEntity() (GithubEntity, error) { return GithubEntity{}, fmt.Errorf("organization has no ID") } return GithubEntity{ - ID: o.ID, - EntityType: GithubEntityTypeOrganization, - Owner: o.Name, - WebhookSecret: o.WebhookSecret, + ID: o.ID, + EntityType: GithubEntityTypeOrganization, + Owner: o.Name, + WebhookSecret: o.WebhookSecret, + PoolBalancerType: o.PoolBalancerType, + Credentials: o.Credentials, }, nil } @@ -517,10 +522,12 @@ func (e Enterprise) GetEntity() (GithubEntity, error) { return GithubEntity{}, fmt.Errorf("enterprise has no ID") } return GithubEntity{ - ID: e.ID, - EntityType: GithubEntityTypeEnterprise, - Owner: e.Name, - WebhookSecret: e.WebhookSecret, + ID: e.ID, + EntityType: GithubEntityTypeEnterprise, + Owner: e.Name, + WebhookSecret: e.WebhookSecret, + PoolBalancerType: e.PoolBalancerType, + Credentials: e.Credentials, }, nil } @@ -685,11 +692,6 @@ type Provider struct { // used by swagger client generated code type Providers []Provider -type UpdatePoolStateParams struct { - WebhookSecret string - InternalConfig *Internal -} - type PoolManagerStatus struct { IsRunning bool `json:"running"` FailureReason string `json:"failure_reason,omitempty"` @@ -788,15 +790,23 @@ type UpdateSystemInfoParams struct { } type GithubEntity struct { - Owner string `json:"owner"` - Name string `json:"name"` - ID string `json:"id"` - EntityType GithubEntityType `json:"entity_type"` - Credentials GithubCredentials `json:"credentials"` + Owner string `json:"owner"` + Name string `json:"name"` + ID string `json:"id"` + EntityType GithubEntityType `json:"entity_type"` + Credentials GithubCredentials `json:"credentials"` + PoolBalancerType PoolBalancerType `json:"pool_balancing_type"` WebhookSecret string `json:"-"` } +func (g GithubEntity) GetPoolBalancerType() PoolBalancerType { + if g.PoolBalancerType == "" { + return PoolBalancerTypeRoundRobin + } + return g.PoolBalancerType +} + func (g GithubEntity) LabelScope() string { switch g.EntityType { case GithubEntityTypeRepository: diff --git a/runner/common/mocks/PoolManager.go b/runner/common/mocks/PoolManager.go index 57f65861..bf1af0c0 100644 --- a/runner/common/mocks/PoolManager.go +++ b/runner/common/mocks/PoolManager.go @@ -152,24 +152,6 @@ func (_m *PoolManager) InstallWebhook(ctx context.Context, param params.InstallW return r0, r1 } -// RefreshState provides a mock function with given fields: param -func (_m *PoolManager) RefreshState(param params.UpdatePoolStateParams) error { - ret := _m.Called(param) - - if len(ret) == 0 { - panic("no return value specified for RefreshState") - } - - var r0 error - if rf, ok := ret.Get(0).(func(params.UpdatePoolStateParams) error); ok { - r0 = rf(param) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // RootCABundle provides a mock function with given fields: func (_m *PoolManager) RootCABundle() (params.CertificateBundle, error) { ret := _m.Called() diff --git a/runner/common/pool.go b/runner/common/pool.go index fe833826..68a7ddf0 100644 --- a/runner/common/pool.go +++ b/runner/common/pool.go @@ -53,8 +53,6 @@ type PoolManager interface { // a repo, org or enterprise, we determine the destination of that webhook, retrieve the pool manager // for it and call this function with the WorkflowJob as a parameter. HandleWorkflowJob(job params.WorkflowJob) error - // RefreshState allows us to update webhook secrets and configuration for a pool manager. - RefreshState(param params.UpdatePoolStateParams) error // DeleteRunner will attempt to remove a runner from the pool. If forceRemove is true, any error // received from the provider will be ignored and we will proceed to remove the runner from the database. diff --git a/runner/enterprises.go b/runner/enterprises.go index 7b12b245..6fb86f96 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -174,11 +174,9 @@ func (r *Runner) UpdateEnterprise(ctx context.Context, enterpriseID string, para return params.Enterprise{}, errors.Wrap(err, "updating enterprise") } - // Use the admin context in the pool manager. Any access control is already done above when - // updating the store. - poolMgr, err := r.poolManagerCtrl.UpdateEnterprisePoolManager(r.ctx, enterprise) + poolMgr, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise) if err != nil { - return params.Enterprise{}, fmt.Errorf("failed to update enterprise pool manager: %w", err) + return params.Enterprise{}, fmt.Errorf("failed to get enterprise pool manager: %w", err) } enterprise.PoolManagerStatus = poolMgr.Status() diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go index 22946ae6..f912c8ef 100644 --- a/runner/enterprises_test.go +++ b/runner/enterprises_test.go @@ -45,7 +45,6 @@ type EnterpriseTestFixtures struct { CreateInstanceParams params.CreateInstanceParams UpdateRepoParams params.UpdateEntityParams UpdatePoolParams params.UpdatePoolParams - UpdatePoolStateParams params.UpdatePoolStateParams ErrMock error ProviderMock *runnerCommonMocks.Provider PoolMgrMock *runnerCommonMocks.PoolManager @@ -138,9 +137,6 @@ func (s *EnterpriseTestSuite) SetupTest() { Image: "test-images-updated", Flavor: "test-flavor-updated", }, - UpdatePoolStateParams: params.UpdatePoolStateParams{ - WebhookSecret: "test-update-repo-webhook-secret", - }, ErrMock: fmt.Errorf("mock error"), ProviderMock: providerMock, PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()), @@ -298,7 +294,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolMgrFailed() { } func (s *EnterpriseTestSuite) TestUpdateEnterprise() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil) param := s.Fixtures.UpdateRepoParams @@ -330,21 +326,21 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidCreds() { } func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMgrFailed() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) _, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(fmt.Sprintf("failed to update enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) + s.Require().Equal(fmt.Sprintf("failed to get enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) } func (s *EnterpriseTestSuite) TestUpdateEnterpriseCreateEnterprisePoolMgrFailed() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) _, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(fmt.Sprintf("failed to update enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) + s.Require().Equal(fmt.Sprintf("failed to get enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) } func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { diff --git a/runner/interfaces.go b/runner/interfaces.go index 05ae9c0f..ff8129ed 100644 --- a/runner/interfaces.go +++ b/runner/interfaces.go @@ -24,7 +24,6 @@ import ( type RepoPoolManager interface { CreateRepoPoolManager(ctx context.Context, repo params.Repository, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) - UpdateRepoPoolManager(ctx context.Context, repo params.Repository) (common.PoolManager, error) GetRepoPoolManager(repo params.Repository) (common.PoolManager, error) DeleteRepoPoolManager(repo params.Repository) error GetRepoPoolManagers() (map[string]common.PoolManager, error) @@ -32,7 +31,6 @@ type RepoPoolManager interface { type OrgPoolManager interface { CreateOrgPoolManager(ctx context.Context, org params.Organization, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) - UpdateOrgPoolManager(ctx context.Context, org params.Organization) (common.PoolManager, error) GetOrgPoolManager(org params.Organization) (common.PoolManager, error) DeleteOrgPoolManager(org params.Organization) error GetOrgPoolManagers() (map[string]common.PoolManager, error) @@ -40,7 +38,6 @@ type OrgPoolManager interface { type EnterprisePoolManager interface { CreateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) - UpdateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise) (common.PoolManager, error) GetEnterprisePoolManager(enterprise params.Enterprise) (common.PoolManager, error) DeleteEnterprisePoolManager(enterprise params.Enterprise) error GetEnterprisePoolManagers() (map[string]common.PoolManager, error) diff --git a/runner/mocks/PoolManagerController.go b/runner/mocks/PoolManagerController.go index 2fa40b8e..2e680daa 100644 --- a/runner/mocks/PoolManagerController.go +++ b/runner/mocks/PoolManagerController.go @@ -343,96 +343,6 @@ func (_m *PoolManagerController) GetRepoPoolManagers() (map[string]common.PoolMa return r0, r1 } -// UpdateEnterprisePoolManager provides a mock function with given fields: ctx, enterprise -func (_m *PoolManagerController) UpdateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise) (common.PoolManager, error) { - ret := _m.Called(ctx, enterprise) - - if len(ret) == 0 { - panic("no return value specified for UpdateEnterprisePoolManager") - } - - var r0 common.PoolManager - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, params.Enterprise) (common.PoolManager, error)); ok { - return rf(ctx, enterprise) - } - if rf, ok := ret.Get(0).(func(context.Context, params.Enterprise) common.PoolManager); ok { - r0 = rf(ctx, enterprise) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(common.PoolManager) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, params.Enterprise) error); ok { - r1 = rf(ctx, enterprise) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// UpdateOrgPoolManager provides a mock function with given fields: ctx, org -func (_m *PoolManagerController) UpdateOrgPoolManager(ctx context.Context, org params.Organization) (common.PoolManager, error) { - ret := _m.Called(ctx, org) - - if len(ret) == 0 { - panic("no return value specified for UpdateOrgPoolManager") - } - - var r0 common.PoolManager - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, params.Organization) (common.PoolManager, error)); ok { - return rf(ctx, org) - } - if rf, ok := ret.Get(0).(func(context.Context, params.Organization) common.PoolManager); ok { - r0 = rf(ctx, org) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(common.PoolManager) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, params.Organization) error); ok { - r1 = rf(ctx, org) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// UpdateRepoPoolManager provides a mock function with given fields: ctx, repo -func (_m *PoolManagerController) UpdateRepoPoolManager(ctx context.Context, repo params.Repository) (common.PoolManager, error) { - ret := _m.Called(ctx, repo) - - if len(ret) == 0 { - panic("no return value specified for UpdateRepoPoolManager") - } - - var r0 common.PoolManager - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, params.Repository) (common.PoolManager, error)); ok { - return rf(ctx, repo) - } - if rf, ok := ret.Get(0).(func(context.Context, params.Repository) common.PoolManager); ok { - r0 = rf(ctx, repo) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(common.PoolManager) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, params.Repository) error); ok { - r1 = rf(ctx, repo) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // NewPoolManagerController creates a new instance of PoolManagerController. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewPoolManagerController(t interface { diff --git a/runner/organizations.go b/runner/organizations.go index ae2e853f..ac55de54 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -203,11 +203,9 @@ func (r *Runner) UpdateOrganization(ctx context.Context, orgID string, param par return params.Organization{}, errors.Wrap(err, "updating org") } - // Use the admin context in the pool manager. Any access control is already done above when - // updating the store. - poolMgr, err := r.poolManagerCtrl.UpdateOrgPoolManager(r.ctx, org) + poolMgr, err := r.poolManagerCtrl.GetOrgPoolManager(org) if err != nil { - return params.Organization{}, fmt.Errorf("updating org pool manager: %w", err) + return params.Organization{}, fmt.Errorf("failed to get org pool manager: %w", err) } org.PoolManagerStatus = poolMgr.Status() diff --git a/runner/organizations_test.go b/runner/organizations_test.go index 4d439b76..f7513234 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -34,22 +34,21 @@ import ( ) type OrgTestFixtures struct { - AdminContext context.Context - DBFile string - Store dbCommon.Store - StoreOrgs map[string]params.Organization - Providers map[string]common.Provider - Credentials map[string]params.GithubCredentials - CreateOrgParams params.CreateOrgParams - CreatePoolParams params.CreatePoolParams - CreateInstanceParams params.CreateInstanceParams - UpdateRepoParams params.UpdateEntityParams - UpdatePoolParams params.UpdatePoolParams - UpdatePoolStateParams params.UpdatePoolStateParams - ErrMock error - ProviderMock *runnerCommonMocks.Provider - PoolMgrMock *runnerCommonMocks.PoolManager - PoolMgrCtrlMock *runnerMocks.PoolManagerController + AdminContext context.Context + DBFile string + Store dbCommon.Store + StoreOrgs map[string]params.Organization + Providers map[string]common.Provider + Credentials map[string]params.GithubCredentials + CreateOrgParams params.CreateOrgParams + CreatePoolParams params.CreatePoolParams + CreateInstanceParams params.CreateInstanceParams + UpdateRepoParams params.UpdateEntityParams + UpdatePoolParams params.UpdatePoolParams + ErrMock error + ProviderMock *runnerCommonMocks.Provider + PoolMgrMock *runnerCommonMocks.PoolManager + PoolMgrCtrlMock *runnerMocks.PoolManagerController } type OrgTestSuite struct { @@ -139,9 +138,6 @@ func (s *OrgTestSuite) SetupTest() { Image: "test-images-updated", Flavor: "test-flavor-updated", }, - UpdatePoolStateParams: params.UpdatePoolStateParams{ - WebhookSecret: "test-update-repo-webhook-secret", - }, ErrMock: fmt.Errorf("mock error"), ProviderMock: providerMock, PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()), @@ -312,7 +308,7 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolMgrFailed() { } func (s *OrgTestSuite) TestUpdateOrganization() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateOrgPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil) org, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams) @@ -326,7 +322,7 @@ func (s *OrgTestSuite) TestUpdateOrganization() { func (s *OrgTestSuite) TestUpdateRepositoryBalancingType() { s.Fixtures.UpdateRepoParams.PoolBalancerType = params.PoolBalancerTypePack - s.Fixtures.PoolMgrCtrlMock.On("UpdateOrgPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil) param := s.Fixtures.UpdateRepoParams @@ -355,21 +351,21 @@ func (s *OrgTestSuite) TestUpdateOrganizationInvalidCreds() { } func (s *OrgTestSuite) TestUpdateOrganizationPoolMgrFailed() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateOrgPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) _, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(fmt.Sprintf("updating org pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) + s.Require().Equal(fmt.Sprintf("failed to get org pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) } func (s *OrgTestSuite) TestUpdateOrganizationCreateOrgPoolMgrFailed() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateOrgPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) _, err := s.Runner.UpdateOrganization(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.UpdateRepoParams) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(fmt.Sprintf("updating org pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) + s.Require().Equal(fmt.Sprintf("failed to get org pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) } func (s *OrgTestSuite) TestCreateOrgPool() { diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 1de26244..c08c1e9c 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -35,6 +35,7 @@ import ( "github.com/cloudbase/garm-provider-common/util" "github.com/cloudbase/garm/auth" dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/params" "github.com/cloudbase/garm/runner/common" garmUtil "github.com/cloudbase/garm/util" @@ -61,16 +62,9 @@ const ( maxCreateAttempts = 5 ) -type urls struct { - callbackURL string - metadataURL string - webhookURL string - controllerWebhookURL string -} - -func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, cfgInternal params.Internal, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) { +func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, instanceTokenGetter auth.InstanceTokenGetter, providers map[string]common.Provider, store dbCommon.Store) (common.PoolManager, error) { ctx = garmUtil.WithContext(ctx, slog.Any("pool_mgr", entity.String()), slog.Any("pool_type", entity.EntityType)) - ghc, err := garmUtil.GithubClient(ctx, entity, cfgInternal.GithubCredentialsDetails) + ghc, err := garmUtil.GithubClient(ctx, entity, entity.Credentials) if err != nil { return nil, errors.Wrap(err, "getting github client") } @@ -79,38 +73,47 @@ func NewEntityPoolManager(ctx context.Context, entity params.GithubEntity, cfgIn return nil, errors.New("webhook secret is empty") } + controllerInfo, err := store.ControllerInfo() + if err != nil { + return nil, errors.Wrap(err, "getting controller info") + } + + consumerID := fmt.Sprintf("pool-manager-%s", entity.String()) + consumer, err := watcher.RegisterConsumer( + ctx, consumerID, + composeWatcherFilters(entity), + ) + if err != nil { + return nil, errors.Wrap(err, "registering consumer") + } + wg := &sync.WaitGroup{} keyMuxes := &keyMutex{} repo := &basePoolManager{ - ctx: ctx, - cfgInternal: cfgInternal, - entity: entity, - ghcli: ghc, - - store: store, - providers: providers, - controllerID: cfgInternal.ControllerID, - urls: urls{ - webhookURL: cfgInternal.BaseWebhookURL, - callbackURL: cfgInternal.InstanceCallbackURL, - metadataURL: cfgInternal.InstanceMetadataURL, - controllerWebhookURL: cfgInternal.ControllerWebhookURL, - }, - quit: make(chan struct{}), - credsDetails: cfgInternal.GithubCredentialsDetails, - wg: wg, - keyMux: keyMuxes, + ctx: ctx, + entity: entity, + ghcli: ghc, + controllerInfo: controllerInfo, + instanceTokenGetter: instanceTokenGetter, + + store: store, + providers: providers, + quit: make(chan struct{}), + wg: wg, + keyMux: keyMuxes, + consumer: consumer, } return repo, nil } type basePoolManager struct { - ctx context.Context - controllerID string - entity params.GithubEntity - ghcli common.GithubClient - cfgInternal params.Internal + ctx context.Context + entity params.GithubEntity + ghcli common.GithubClient + controllerInfo params.ControllerInfo + instanceTokenGetter auth.InstanceTokenGetter + consumer dbCommon.Consumer store dbCommon.Store @@ -118,13 +121,9 @@ type basePoolManager struct { tools []commonParams.RunnerApplicationDownload quit chan struct{} - credsDetails params.GithubCredentials - managerIsRunning bool managerErrorReason string - urls urls - mux sync.Mutex wg *sync.WaitGroup keyMux *keyMutex @@ -353,9 +352,9 @@ func (r *basePoolManager) updateTools() error { tools, err := r.FetchTools() if err != nil { slog.With(slog.Any("error", err)).ErrorContext( - r.ctx, "failed to update tools for repo") + r.ctx, "failed to update tools for entity", "entity", r.entity.String()) r.setPoolRunningState(false, err.Error()) - return fmt.Errorf("failed to update tools for repo %s: %w", r.entity.String(), err) + return fmt.Errorf("failed to update tools for entity %s: %w", r.entity.String(), err) } r.mux.Lock() r.tools = tools @@ -381,7 +380,7 @@ func (r *basePoolManager) cleanupOrphanedProviderRunners(runners []*github.Runne runnerNames := map[string]bool{} for _, run := range runners { - if !isManagedRunner(labelsFromRunner(run), r.controllerID) { + if !isManagedRunner(labelsFromRunner(run), r.controllerInfo.ControllerID.String()) { slog.DebugContext( r.ctx, "runner is not managed by a pool we manage", "runner_name", run.GetName()) @@ -457,7 +456,7 @@ func (r *basePoolManager) reapTimedOutRunners(runners []*github.Runner) error { runnersByName := map[string]*github.Runner{} for _, run := range runners { - if !isManagedRunner(labelsFromRunner(run), r.controllerID) { + if !isManagedRunner(labelsFromRunner(run), r.controllerInfo.ControllerID.String()) { slog.DebugContext( r.ctx, "runner is not managed by a pool we manage", "runner_name", run.GetName()) @@ -515,7 +514,7 @@ func (r *basePoolManager) cleanupOrphanedGithubRunners(runners []*github.Runner) poolInstanceCache := map[string][]commonParams.ProviderInstance{} g, ctx := errgroup.WithContext(r.ctx) for _, runner := range runners { - if !isManagedRunner(labelsFromRunner(runner), r.controllerID) { + if !isManagedRunner(labelsFromRunner(runner), r.controllerInfo.ControllerID.String()) { slog.DebugContext( r.ctx, "runner is not managed by a pool we manage", "runner_name", runner.GetName()) @@ -741,8 +740,8 @@ func (r *basePoolManager) AddRunner(ctx context.Context, poolID string, aditiona RunnerStatus: params.RunnerPending, OSArch: pool.OSArch, OSType: pool.OSType, - CallbackURL: r.urls.callbackURL, - MetadataURL: r.urls.metadataURL, + CallbackURL: r.controllerInfo.CallbackURL, + MetadataURL: r.controllerInfo.MetadataURL, CreateAttempt: 1, GitHubRunnerGroup: pool.GitHubRunnerGroup, AditionalLabels: aditionalLabels, @@ -832,7 +831,7 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error jwtValidity := pool.RunnerTimeout() entity := r.entity.String() - jwtToken, err := auth.NewInstanceJWTToken(instance, r.cfgInternal.JWTSecret, entity, pool.PoolType(), jwtValidity) + jwtToken, err := r.instanceTokenGetter.NewInstanceJWTToken(instance, entity, pool.PoolType(), jwtValidity) if err != nil { return errors.Wrap(err, "fetching instance jwt token") } @@ -852,7 +851,7 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error Image: pool.Image, ExtraSpecs: pool.ExtraSpecs, PoolID: instance.PoolID, - CACertBundle: r.credsDetails.CABundle, + CACertBundle: r.entity.Credentials.CABundle, GitHubRunnerGroup: instance.GitHubRunnerGroup, JitConfigEnabled: hasJITConfig, } @@ -954,7 +953,7 @@ func (r *basePoolManager) poolLabel(poolID string) string { } func (r *basePoolManager) controllerLabel() string { - return fmt.Sprintf("%s%s", controllerLabelPrefix, r.controllerID) + return fmt.Sprintf("%s%s", controllerLabelPrefix, r.controllerInfo.ControllerID.String()) } func (r *basePoolManager) updateArgsFromProviderInstance(providerInstance commonParams.ProviderInstance) params.UpdateInstanceParams { @@ -1525,6 +1524,7 @@ func (r *basePoolManager) Start() error { initialToolUpdate <- struct{}{} }() + go r.runWatcher() go func() { select { case <-r.quit: @@ -1552,37 +1552,6 @@ func (r *basePoolManager) Stop() error { return nil } -func (r *basePoolManager) RefreshState(param params.UpdatePoolStateParams) error { - r.mux.Lock() - - if param.WebhookSecret != "" { - r.entity.WebhookSecret = param.WebhookSecret - } - if param.InternalConfig != nil { - r.cfgInternal = *param.InternalConfig - r.urls = urls{ - webhookURL: r.cfgInternal.BaseWebhookURL, - callbackURL: r.cfgInternal.InstanceCallbackURL, - metadataURL: r.cfgInternal.InstanceMetadataURL, - controllerWebhookURL: r.cfgInternal.ControllerWebhookURL, - } - } - - ghc, err := garmUtil.GithubClient(r.ctx, r.entity, r.cfgInternal.GithubCredentialsDetails) - if err != nil { - return errors.Wrap(err, "getting github client") - } - r.ghcli = ghc - r.mux.Unlock() - - // Update the tools as soon as state is updated. This should revive a stopped pool manager - // or stop one if the supplied credentials are not okay. - if err := r.updateTools(); err != nil { - return fmt.Errorf("failed to update tools: %w", err) - } - return nil -} - func (r *basePoolManager) WebhookSecret() string { return r.entity.WebhookSecret } @@ -1688,7 +1657,7 @@ func (r *basePoolManager) consumeQueuedJobs() error { } poolsCache := poolsForTags{ - poolCacheType: r.PoolBalancerType(), + poolCacheType: r.entity.GetPoolBalancerType(), } slog.DebugContext( @@ -1812,7 +1781,7 @@ func (r *basePoolManager) consumeQueuedJobs() error { } func (r *basePoolManager) UninstallWebhook(ctx context.Context) error { - if r.urls.controllerWebhookURL == "" { + if r.controllerInfo.ControllerWebhookURL == "" { return errors.Wrap(runnerErrors.ErrBadRequest, "controller webhook url is empty") } @@ -1823,8 +1792,8 @@ func (r *basePoolManager) UninstallWebhook(ctx context.Context) error { var controllerHookID int64 var baseHook string - trimmedBase := strings.TrimRight(r.urls.webhookURL, "/") - trimmedController := strings.TrimRight(r.urls.controllerWebhookURL, "/") + trimmedBase := strings.TrimRight(r.controllerInfo.WebhookURL, "/") + trimmedController := strings.TrimRight(r.controllerInfo.ControllerWebhookURL, "/") for _, hook := range allHooks { hookInfo := hookToParamsHookInfo(hook) @@ -1859,7 +1828,7 @@ func (r *basePoolManager) InstallHook(ctx context.Context, req *github.Hook) (pa return params.HookInfo{}, errors.Wrap(err, "listing hooks") } - if err := validateHookRequest(r.cfgInternal.ControllerID, r.cfgInternal.BaseWebhookURL, allHooks, req); err != nil { + if err := validateHookRequest(r.controllerInfo.ControllerID.String(), r.controllerInfo.WebhookURL, allHooks, req); err != nil { return params.HookInfo{}, errors.Wrap(err, "validating hook request") } @@ -1879,7 +1848,7 @@ func (r *basePoolManager) InstallHook(ctx context.Context, req *github.Hook) (pa } func (r *basePoolManager) InstallWebhook(ctx context.Context, param params.InstallWebhookParams) (params.HookInfo, error) { - if r.urls.controllerWebhookURL == "" { + if r.controllerInfo.ControllerWebhookURL == "" { return params.HookInfo{}, errors.Wrap(runnerErrors.ErrBadRequest, "controller webhook url is empty") } @@ -1890,7 +1859,7 @@ func (r *basePoolManager) InstallWebhook(ctx context.Context, param params.Insta req := &github.Hook{ Active: github.Bool(true), Config: map[string]interface{}{ - "url": r.urls.controllerWebhookURL, + "url": r.controllerInfo.ControllerWebhookURL, "content_type": "json", "insecure_ssl": insecureSSL, "secret": r.WebhookSecret(), @@ -1978,21 +1947,14 @@ func (r *basePoolManager) GetGithubRunners() ([]*github.Runner, error) { return allRunners, nil } -func (r *basePoolManager) PoolBalancerType() params.PoolBalancerType { - if r.cfgInternal.PoolBalancerType == "" { - return params.PoolBalancerTypeRoundRobin - } - return r.cfgInternal.PoolBalancerType -} - func (r *basePoolManager) GithubURL() string { switch r.entity.EntityType { case params.GithubEntityTypeRepository: - return fmt.Sprintf("%s/%s/%s", r.cfgInternal.GithubCredentialsDetails.BaseURL, r.entity.Owner, r.entity.Name) + return fmt.Sprintf("%s/%s/%s", r.entity.Credentials.BaseURL, r.entity.Owner, r.entity.Name) case params.GithubEntityTypeOrganization: - return fmt.Sprintf("%s/%s", r.cfgInternal.GithubCredentialsDetails.BaseURL, r.entity.Owner) + return fmt.Sprintf("%s/%s", r.entity.Credentials.BaseURL, r.entity.Owner) case params.GithubEntityTypeEnterprise: - return fmt.Sprintf("%s/enterprises/%s", r.cfgInternal.GithubCredentialsDetails.BaseURL, r.entity.Owner) + return fmt.Sprintf("%s/enterprises/%s", r.entity.Credentials.BaseURL, r.entity.Owner) } return "" } @@ -2002,8 +1964,8 @@ func (r *basePoolManager) GetWebhookInfo(ctx context.Context) (params.HookInfo, if err != nil { return params.HookInfo{}, errors.Wrap(err, "listing hooks") } - trimmedBase := strings.TrimRight(r.urls.webhookURL, "/") - trimmedController := strings.TrimRight(r.urls.controllerWebhookURL, "/") + trimmedBase := strings.TrimRight(r.controllerInfo.WebhookURL, "/") + trimmedController := strings.TrimRight(r.controllerInfo.ControllerWebhookURL, "/") var controllerHookInfo *params.HookInfo var baseHookInfo *params.HookInfo @@ -2034,5 +1996,5 @@ func (r *basePoolManager) GetWebhookInfo(ctx context.Context) (params.HookInfo, } func (r *basePoolManager) RootCABundle() (params.CertificateBundle, error) { - return r.credsDetails.RootCertificateBundle() + return r.entity.Credentials.RootCertificateBundle() } diff --git a/runner/pool/stub_client.go b/runner/pool/stub_client.go new file mode 100644 index 00000000..df547501 --- /dev/null +++ b/runner/pool/stub_client.go @@ -0,0 +1,57 @@ +package pool + +import ( + "context" + + "github.com/google/go-github/v57/github" + + "github.com/cloudbase/garm/params" +) + +type stubGithubClient struct { + err error +} + +func (s *stubGithubClient) ListEntityHooks(_ context.Context, _ *github.ListOptions) ([]*github.Hook, *github.Response, error) { + return nil, nil, s.err +} + +func (s *stubGithubClient) GetEntityHook(_ context.Context, _ int64) (*github.Hook, error) { + return nil, s.err +} + +func (s *stubGithubClient) CreateEntityHook(_ context.Context, _ *github.Hook) (*github.Hook, error) { + return nil, s.err +} + +func (s *stubGithubClient) DeleteEntityHook(_ context.Context, _ int64) (*github.Response, error) { + return nil, s.err +} + +func (s *stubGithubClient) PingEntityHook(_ context.Context, _ int64) (*github.Response, error) { + return nil, s.err +} + +func (s *stubGithubClient) ListEntityRunners(_ context.Context, _ *github.ListOptions) (*github.Runners, *github.Response, error) { + return nil, nil, s.err +} + +func (s *stubGithubClient) ListEntityRunnerApplicationDownloads(_ context.Context) ([]*github.RunnerApplicationDownload, *github.Response, error) { + return nil, nil, s.err +} + +func (s *stubGithubClient) RemoveEntityRunner(_ context.Context, _ int64) (*github.Response, error) { + return nil, s.err +} + +func (s *stubGithubClient) CreateEntityRegistrationToken(_ context.Context) (*github.RegistrationToken, *github.Response, error) { + return nil, nil, s.err +} + +func (s *stubGithubClient) GetEntityJITConfig(_ context.Context, _ string, _ params.Pool, _ []string) (map[string]string, *github.Runner, error) { + return nil, nil, s.err +} + +func (s *stubGithubClient) GetWorkflowJobByID(_ context.Context, _, _ string, _ int64) (*github.WorkflowJob, *github.Response, error) { + return nil, nil, s.err +} diff --git a/runner/pool/util.go b/runner/pool/util.go index 8ceea49d..e2308160 100644 --- a/runner/pool/util.go +++ b/runner/pool/util.go @@ -10,6 +10,8 @@ import ( runnerErrors "github.com/cloudbase/garm-provider-common/errors" commonParams "github.com/cloudbase/garm-provider-common/params" + dbCommon "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" "github.com/cloudbase/garm/params" ) @@ -116,3 +118,19 @@ func isManagedRunner(labels []string, controllerID string) bool { runnerControllerID := controllerIDFromLabels(labels) return runnerControllerID == controllerID } + +func composeWatcherFilters(entity params.GithubEntity) dbCommon.PayloadFilterFunc { + // We want to watch for changes in either the controller or the + // entity itself. + return watcher.WithAny( + watcher.WithAll( + // Updates to the controller + watcher.WithEntityTypeFilter(dbCommon.ControllerEntityType), + watcher.WithOperationTypeFilter(dbCommon.UpdateOperation), + ), + // Any operation on the entity we're managing the pool for. + watcher.WithEntityFilter(entity), + // Watch for changes to the github credentials + watcher.WithGithubCredentialsFilter(entity.Credentials), + ) +} diff --git a/runner/pool/watcher.go b/runner/pool/watcher.go new file mode 100644 index 00000000..b50a85b2 --- /dev/null +++ b/runner/pool/watcher.go @@ -0,0 +1,154 @@ +package pool + +import ( + "log/slog" + + "github.com/pkg/errors" + + runnerErrors "github.com/cloudbase/garm-provider-common/errors" + "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/params" + runnerCommon "github.com/cloudbase/garm/runner/common" + garmUtil "github.com/cloudbase/garm/util" +) + +// entityGetter is implemented by all github entities (repositories, organizations and enterprises) +type entityGetter interface { + GetEntity() (params.GithubEntity, error) +} + +func (r *basePoolManager) handleControllerUpdateEvent(controllerInfo params.ControllerInfo) { + r.mux.Lock() + defer r.mux.Unlock() + + slog.DebugContext(r.ctx, "updating controller info", "controller_info", controllerInfo) + r.controllerInfo = controllerInfo +} + +func (r *basePoolManager) getClientOrStub() runnerCommon.GithubClient { + var err error + var ghc runnerCommon.GithubClient + ghc, err = garmUtil.GithubClient(r.ctx, r.entity, r.entity.Credentials) + if err != nil { + slog.WarnContext(r.ctx, "failed to create github client", "error", err) + ghc = &stubGithubClient{ + err: errors.Wrapf(runnerErrors.ErrUnauthorized, "failed to create github client; please update credentials: %v", err), + } + } + return ghc +} + +func (r *basePoolManager) handleEntityUpdate(entity params.GithubEntity) { + slog.DebugContext(r.ctx, "received entity update", "entity", entity.ID) + credentialsUpdate := r.entity.Credentials.ID != entity.Credentials.ID + defer func() { + slog.DebugContext(r.ctx, "deferred tools update", "credentials_update", credentialsUpdate) + if !credentialsUpdate { + return + } + slog.DebugContext(r.ctx, "updating tools", "entity", entity.ID) + if err := r.updateTools(); err != nil { + slog.ErrorContext(r.ctx, "failed to update tools", "error", err) + } + }() + + slog.DebugContext(r.ctx, "updating entity", "entity", entity.ID) + r.mux.Lock() + slog.DebugContext(r.ctx, "lock acquired", "entity", entity.ID) + + r.entity = entity + if credentialsUpdate { + if r.consumer != nil { + filters := composeWatcherFilters(r.entity) + r.consumer.SetFilters(filters) + } + slog.DebugContext(r.ctx, "credentials update", "entity", entity.ID) + r.ghcli = r.getClientOrStub() + } + r.mux.Unlock() + slog.DebugContext(r.ctx, "lock released", "entity", entity.ID) +} + +func (r *basePoolManager) handleCredentialsUpdate(credentials params.GithubCredentials) { + // when we switch credentials on an entity (like from one app to another or from an app + // to a PAT), we may still get events for the previous credentials as the channel is buffered. + // The watcher will watch for changes to the entity itself, which includes events that + // change the credentials name on the entity, but we also watch for changes to the credentials + // themselves, like an updated PAT token set on existing credentials entity. + // The handleCredentialsUpdate function handles situations where we have changes on the + // credentials entity itself, not on the entity that the credentials are set on. + // For example, we may have a credentials entity called org_pat set on a repo called + // test-repo. This function would handle situations where "org_pat" is updated. + // If "test-repo" is updated with new credentials, that event is handled above in + // handleEntityUpdate. + shouldUpdateTools := r.entity.Credentials.ID == credentials.ID + defer func() { + if !shouldUpdateTools { + return + } + slog.DebugContext(r.ctx, "deferred tools update", "credentials_id", credentials.ID) + if err := r.updateTools(); err != nil { + slog.ErrorContext(r.ctx, "failed to update tools", "error", err) + } + }() + + r.mux.Lock() + if !shouldUpdateTools { + slog.InfoContext(r.ctx, "credential ID mismatch; stale event?", "credentials_id", credentials.ID) + r.mux.Unlock() + return + } + + slog.DebugContext(r.ctx, "updating credentials", "credentials_id", credentials.ID) + r.entity.Credentials = credentials + r.ghcli = r.getClientOrStub() + r.mux.Unlock() +} + +func (r *basePoolManager) handleWatcherEvent(event common.ChangePayload) { + dbEntityType := common.DatabaseEntityType(r.entity.EntityType) + switch event.EntityType { + case common.GithubCredentialsEntityType: + credentials, ok := event.Payload.(params.GithubCredentials) + if !ok { + slog.ErrorContext(r.ctx, "failed to cast payload to github credentials") + return + } + r.handleCredentialsUpdate(credentials) + case common.ControllerEntityType: + controllerInfo, ok := event.Payload.(params.ControllerInfo) + if !ok { + slog.ErrorContext(r.ctx, "failed to cast payload to controller info") + return + } + r.handleControllerUpdateEvent(controllerInfo) + case dbEntityType: + entity, ok := event.Payload.(entityGetter) + if !ok { + slog.ErrorContext(r.ctx, "failed to cast payload to entity") + return + } + entityInfo, err := entity.GetEntity() + if err != nil { + slog.ErrorContext(r.ctx, "failed to get entity", "error", err) + return + } + r.handleEntityUpdate(entityInfo) + } +} + +func (r *basePoolManager) runWatcher() { + for { + select { + case <-r.quit: + return + case <-r.ctx.Done(): + return + case event, ok := <-r.consumer.Watch(): + if !ok { + return + } + go r.handleWatcherEvent(event) + } + } +} diff --git a/runner/repositories.go b/runner/repositories.go index e316aa47..ce0bbc73 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -197,16 +197,15 @@ func (r *Runner) UpdateRepository(ctx context.Context, repoID string, param para return params.Repository{}, runnerErrors.NewBadRequestError("invalid pool balancer type: %s", param.PoolBalancerType) } + slog.InfoContext(ctx, "updating repository", "repo_id", repoID, "param", param) repo, err := r.store.UpdateRepository(ctx, repoID, param) if err != nil { return params.Repository{}, errors.Wrap(err, "updating repo") } - // Use the admin context in the pool manager. Any access control is already done above when - // updating the store. - poolMgr, err := r.poolManagerCtrl.UpdateRepoPoolManager(r.ctx, repo) + poolMgr, err := r.poolManagerCtrl.GetRepoPoolManager(repo) if err != nil { - return params.Repository{}, fmt.Errorf("failed to update pool manager: %w", err) + return params.Repository{}, fmt.Errorf("failed to get pool manager: %w", err) } repo.PoolManagerStatus = poolMgr.Status() diff --git a/runner/repositories_test.go b/runner/repositories_test.go index f17bd93a..a13b6112 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -35,21 +35,20 @@ import ( ) type RepoTestFixtures struct { - AdminContext context.Context - Store dbCommon.Store - StoreRepos map[string]params.Repository - Providers map[string]common.Provider - Credentials map[string]params.GithubCredentials - CreateRepoParams params.CreateRepoParams - CreatePoolParams params.CreatePoolParams - CreateInstanceParams params.CreateInstanceParams - UpdateRepoParams params.UpdateEntityParams - UpdatePoolParams params.UpdatePoolParams - UpdatePoolStateParams params.UpdatePoolStateParams - ErrMock error - ProviderMock *runnerCommonMocks.Provider - PoolMgrMock *runnerCommonMocks.PoolManager - PoolMgrCtrlMock *runnerMocks.PoolManagerController + AdminContext context.Context + Store dbCommon.Store + StoreRepos map[string]params.Repository + Providers map[string]common.Provider + Credentials map[string]params.GithubCredentials + CreateRepoParams params.CreateRepoParams + CreatePoolParams params.CreatePoolParams + CreateInstanceParams params.CreateInstanceParams + UpdateRepoParams params.UpdateEntityParams + UpdatePoolParams params.UpdatePoolParams + ErrMock error + ProviderMock *runnerCommonMocks.Provider + PoolMgrMock *runnerCommonMocks.PoolManager + PoolMgrCtrlMock *runnerMocks.PoolManagerController } func init() { @@ -143,9 +142,6 @@ func (s *RepoTestSuite) SetupTest() { Image: "test-images-updated", Flavor: "test-flavor-updated", }, - UpdatePoolStateParams: params.UpdatePoolStateParams{ - WebhookSecret: "test-update-repo-webhook-secret", - }, ErrMock: fmt.Errorf("mock error"), ProviderMock: providerMock, PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()), @@ -327,7 +323,7 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolMgrFailed() { } func (s *RepoTestSuite) TestUpdateRepository() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateRepoPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil) repo, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams) @@ -341,7 +337,7 @@ func (s *RepoTestSuite) TestUpdateRepository() { } func (s *RepoTestSuite) TestUpdateRepositoryBalancingType() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateRepoPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) + s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil) updateRepoParams := s.Fixtures.UpdateRepoParams @@ -372,21 +368,21 @@ func (s *RepoTestSuite) TestUpdateRepositoryInvalidCreds() { } func (s *RepoTestSuite) TestUpdateRepositoryPoolMgrFailed() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateRepoPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) _, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(fmt.Sprintf("failed to update pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) + s.Require().Equal(fmt.Sprintf("failed to get pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) } func (s *RepoTestSuite) TestUpdateRepositoryCreateRepoPoolMgrFailed() { - s.Fixtures.PoolMgrCtrlMock.On("UpdateRepoPoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) + s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) _, err := s.Runner.UpdateRepository(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.UpdateRepoParams) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(fmt.Sprintf("failed to update pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) + s.Require().Equal(fmt.Sprintf("failed to get pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error()) } func (s *RepoTestSuite) TestCreateRepoPool() { diff --git a/runner/runner.go b/runner/runner.go index 2a08ae12..c7ff4534 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -100,23 +100,16 @@ func (p *poolManagerCtrl) CreateRepoPoolManager(ctx context.Context, repo params p.mux.Lock() defer p.mux.Unlock() - creds, err := p.store.GetGithubCredentials(ctx, repo.CredentialsID, true) + entity, err := repo.GetEntity() if err != nil { - return nil, errors.Wrap(err, "fetching credentials") + return nil, errors.Wrap(err, "getting entity") } - cfgInternal, err := p.getInternalConfig(ctx, creds, repo.GetBalancerType()) + instanceTokenGetter, err := auth.NewInstanceTokenGetter(p.config.JWTAuth.Secret) if err != nil { - return nil, errors.Wrap(err, "fetching internal config") + return nil, errors.Wrap(err, "creating instance token getter") } - entity := params.GithubEntity{ - Owner: repo.Owner, - Name: repo.Name, - ID: repo.ID, - WebhookSecret: repo.WebhookSecret, - EntityType: params.GithubEntityTypeRepository, - } - poolManager, err := pool.NewEntityPoolManager(ctx, entity, cfgInternal, providers, store) + poolManager, err := pool.NewEntityPoolManager(ctx, entity, instanceTokenGetter, providers, store) if err != nil { return nil, errors.Wrap(err, "creating repo pool manager") } @@ -124,36 +117,6 @@ func (p *poolManagerCtrl) CreateRepoPoolManager(ctx context.Context, repo params return poolManager, nil } -func (p *poolManagerCtrl) UpdateRepoPoolManager(ctx context.Context, repo params.Repository) (common.PoolManager, error) { - p.mux.Lock() - defer p.mux.Unlock() - - poolMgr, ok := p.repositories[repo.ID] - if !ok { - return nil, errors.Wrapf(runnerErrors.ErrNotFound, "repository %s/%s pool manager not loaded", repo.Owner, repo.Name) - } - - creds, err := p.store.GetGithubCredentials(ctx, repo.CredentialsID, true) - if err != nil { - return nil, errors.Wrap(err, "fetching credentials") - } - - internalCfg, err := p.getInternalConfig(ctx, creds, repo.GetBalancerType()) - if err != nil { - return nil, errors.Wrap(err, "fetching internal config") - } - - newState := params.UpdatePoolStateParams{ - WebhookSecret: repo.WebhookSecret, - InternalConfig: &internalCfg, - } - - if err := poolMgr.RefreshState(newState); err != nil { - return nil, errors.Wrap(err, "updating repo pool manager") - } - return poolMgr, nil -} - func (p *poolManagerCtrl) GetRepoPoolManager(repo params.Repository) (common.PoolManager, error) { if repoPoolMgr, ok := p.repositories[repo.ID]; ok { return repoPoolMgr, nil @@ -183,21 +146,16 @@ func (p *poolManagerCtrl) CreateOrgPoolManager(ctx context.Context, org params.O p.mux.Lock() defer p.mux.Unlock() - creds, err := p.store.GetGithubCredentials(ctx, org.CredentialsID, true) + entity, err := org.GetEntity() if err != nil { - return nil, errors.Wrap(err, "fetching credentials") + return nil, errors.Wrap(err, "getting entity") } - cfgInternal, err := p.getInternalConfig(ctx, creds, org.GetBalancerType()) + + instanceTokenGetter, err := auth.NewInstanceTokenGetter(p.config.JWTAuth.Secret) if err != nil { - return nil, errors.Wrap(err, "fetching internal config") + return nil, errors.Wrap(err, "creating instance token getter") } - entity := params.GithubEntity{ - Owner: org.Name, - ID: org.ID, - WebhookSecret: org.WebhookSecret, - EntityType: params.GithubEntityTypeOrganization, - } - poolManager, err := pool.NewEntityPoolManager(ctx, entity, cfgInternal, providers, store) + poolManager, err := pool.NewEntityPoolManager(ctx, entity, instanceTokenGetter, providers, store) if err != nil { return nil, errors.Wrap(err, "creating org pool manager") } @@ -205,35 +163,6 @@ func (p *poolManagerCtrl) CreateOrgPoolManager(ctx context.Context, org params.O return poolManager, nil } -func (p *poolManagerCtrl) UpdateOrgPoolManager(ctx context.Context, org params.Organization) (common.PoolManager, error) { - p.mux.Lock() - defer p.mux.Unlock() - - poolMgr, ok := p.organizations[org.ID] - if !ok { - return nil, errors.Wrapf(runnerErrors.ErrNotFound, "org %s pool manager not loaded", org.Name) - } - - creds, err := p.store.GetGithubCredentials(ctx, org.CredentialsID, true) - if err != nil { - return nil, errors.Wrap(err, "fetching credentials") - } - internalCfg, err := p.getInternalConfig(ctx, creds, org.GetBalancerType()) - if err != nil { - return nil, errors.Wrap(err, "fetching internal config") - } - - newState := params.UpdatePoolStateParams{ - WebhookSecret: org.WebhookSecret, - InternalConfig: &internalCfg, - } - - if err := poolMgr.RefreshState(newState); err != nil { - return nil, errors.Wrap(err, "updating repo pool manager") - } - return poolMgr, nil -} - func (p *poolManagerCtrl) GetOrgPoolManager(org params.Organization) (common.PoolManager, error) { if orgPoolMgr, ok := p.organizations[org.ID]; ok { return orgPoolMgr, nil @@ -263,22 +192,16 @@ func (p *poolManagerCtrl) CreateEnterprisePoolManager(ctx context.Context, enter p.mux.Lock() defer p.mux.Unlock() - creds, err := p.store.GetGithubCredentials(ctx, enterprise.CredentialsID, true) - if err != nil { - return nil, errors.Wrap(err, "fetching credentials") - } - cfgInternal, err := p.getInternalConfig(ctx, creds, enterprise.GetBalancerType()) + entity, err := enterprise.GetEntity() if err != nil { - return nil, errors.Wrap(err, "fetching internal config") + return nil, errors.Wrap(err, "getting entity") } - entity := params.GithubEntity{ - Owner: enterprise.Name, - ID: enterprise.ID, - WebhookSecret: enterprise.WebhookSecret, - EntityType: params.GithubEntityTypeEnterprise, + instanceTokenGetter, err := auth.NewInstanceTokenGetter(p.config.JWTAuth.Secret) + if err != nil { + return nil, errors.Wrap(err, "creating instance token getter") } - poolManager, err := pool.NewEntityPoolManager(ctx, entity, cfgInternal, providers, store) + poolManager, err := pool.NewEntityPoolManager(ctx, entity, instanceTokenGetter, providers, store) if err != nil { return nil, errors.Wrap(err, "creating enterprise pool manager") } @@ -286,35 +209,6 @@ func (p *poolManagerCtrl) CreateEnterprisePoolManager(ctx context.Context, enter return poolManager, nil } -func (p *poolManagerCtrl) UpdateEnterprisePoolManager(ctx context.Context, enterprise params.Enterprise) (common.PoolManager, error) { - p.mux.Lock() - defer p.mux.Unlock() - - poolMgr, ok := p.enterprises[enterprise.ID] - if !ok { - return nil, errors.Wrapf(runnerErrors.ErrNotFound, "enterprise %s pool manager not loaded", enterprise.Name) - } - - creds, err := p.store.GetGithubCredentials(ctx, enterprise.CredentialsID, true) - if err != nil { - return nil, errors.Wrap(err, "fetching credentials") - } - internalCfg, err := p.getInternalConfig(ctx, creds, enterprise.GetBalancerType()) - if err != nil { - return nil, errors.Wrap(err, "fetching internal config") - } - - newState := params.UpdatePoolStateParams{ - WebhookSecret: enterprise.WebhookSecret, - InternalConfig: &internalCfg, - } - - if err := poolMgr.RefreshState(newState); err != nil { - return nil, errors.Wrap(err, "updating repo pool manager") - } - return poolMgr, nil -} - func (p *poolManagerCtrl) GetEnterprisePoolManager(enterprise params.Enterprise) (common.PoolManager, error) { if enterprisePoolMgr, ok := p.enterprises[enterprise.ID]; ok { return enterprisePoolMgr, nil @@ -340,24 +234,6 @@ func (p *poolManagerCtrl) GetEnterprisePoolManagers() (map[string]common.PoolMan return p.enterprises, nil } -func (p *poolManagerCtrl) getInternalConfig(_ context.Context, creds params.GithubCredentials, poolBalancerType params.PoolBalancerType) (params.Internal, error) { - controllerInfo, err := p.store.ControllerInfo() - if err != nil { - return params.Internal{}, errors.Wrap(err, "fetching controller info") - } - - return params.Internal{ - ControllerID: controllerInfo.ControllerID.String(), - InstanceCallbackURL: controllerInfo.CallbackURL, - InstanceMetadataURL: controllerInfo.MetadataURL, - BaseWebhookURL: controllerInfo.WebhookURL, - ControllerWebhookURL: controllerInfo.ControllerWebhookURL, - JWTSecret: p.config.JWTAuth.Secret, - PoolBalancerType: poolBalancerType, - GithubCredentialsDetails: creds, - }, nil -} - type Runner struct { mux sync.Mutex