diff --git a/azure/defaults.go b/azure/defaults.go index 6dab387adf3..9a7d44ce7a2 100644 --- a/azure/defaults.go +++ b/azure/defaults.go @@ -384,6 +384,25 @@ func (p userAgentPolicy) Do(req *policy.Request) (*http.Response, error) { return req.Next() } +// CustomPutHeaderPolicy adds custom headers to a PUT request. +// It implements the policy.Policy interface. +type CustomPutHeaderPolicy struct { + Getter ResourceSpecGetter +} + +// Do adds any custom headers to a PUT request. +func (p CustomPutHeaderPolicy) Do(req *policy.Request) (*http.Response, error) { + if req.Raw().Method == http.MethodPut { + headerSpec, ok := p.Getter.(ResourceSpecGetterWithHeaders) + if ok { + for key, element := range headerSpec.CustomHeaders() { + req.Raw().Header.Set(key, element) + } + } + } + return req.Next() +} + // SetAutoRestClientDefaults set authorizer and user agent for autorest client. func SetAutoRestClientDefaults(c *autorest.Client, auth autorest.Authorizer) { c.Authorizer = auth diff --git a/azure/defaults_test.go b/azure/defaults_test.go index 3b1e96ac4e2..17372374f3b 100644 --- a/azure/defaults_test.go +++ b/azure/defaults_test.go @@ -29,6 +29,8 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/go-autorest/autorest" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" + "sigs.k8s.io/cluster-api-provider-azure/azure/mock_azure" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) @@ -118,6 +120,79 @@ func TestPerCallPolicies(t *testing.T) { g.Expect(resp.StatusCode).To(Equal(http.StatusOK)) } +func TestCustomPutHeaderPolicy(t *testing.T) { + testHeaders := map[string]string{ + "X-Test-Header": "test-value", + "X-Test-Header2": "test-value2", + } + testcases := []struct { + name string + method string + headers map[string]string + expected map[string]string + }{ + { + name: "should add custom headers to PUT request", + method: http.MethodPut, + headers: testHeaders, + expected: testHeaders, + }, + { + name: "should skip empty custom headers for PUT request", + method: http.MethodPut, + }, + { + name: "should not add custom headers to GET request", + method: http.MethodGet, + headers: testHeaders, + }, + { + name: "should not add custom headers to POST request", + method: http.MethodPost, + headers: testHeaders, + }, + } + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + // This server will check that custom headers are set correctly. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for k, v := range tc.expected { + g.Expect(r.Header.Get(k)).To(Equal(v)) + } + fmt.Fprintf(w, "Hello, %s", r.Proto) + })) + defer server.Close() + + // Create a custom PUT header per-call policy + opts, err := ARMClientOptions("") + g.Expect(err).NotTo(HaveOccurred()) + policies := opts.PerCallPolicies + getterMock := mock_azure.NewMockResourceSpecGetterWithHeaders(mockCtrl) + getterMock.EXPECT().CustomHeaders().Return(tc.headers).AnyTimes() + policy := CustomPutHeaderPolicy{Getter: getterMock} + policies = append(policies, policy)[len(policies)-1:] + + // Create a request + req, err := runtime.NewRequest(context.Background(), tc.method, server.URL) + g.Expect(err).NotTo(HaveOccurred()) + + // Create a pipeline and send the request to the test server for validation. + pipeline := defaultTestPipeline(policies) + resp, err := pipeline.Do(req) + g.Expect(err).NotTo(HaveOccurred()) + defer resp.Body.Close() + g.Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + } +} + func defaultTestPipeline(policies []policy.Policy) runtime.Pipeline { return runtime.NewPipeline( "testmodule", diff --git a/azure/services/agentpools/client.go b/azure/services/agentpools/client.go index e46618e9f5c..cb111591af5 100644 --- a/azure/services/agentpools/client.go +++ b/azure/services/agentpools/client.go @@ -34,8 +34,8 @@ type azureClient struct { } // newClient creates a new agent pools client from subscription ID. -func newClient(auth azure.Authorizer) (*azureClient, error) { - c, err := newAgentPoolsClient(auth.SubscriptionID(), auth.CloudEnvironment()) +func newClient(scope AgentPoolScope) (*azureClient, error) { + c, err := newAgentPoolsClient(scope) if err != nil { return nil, errors.Wrap(err, "failed to create managed clusters client") } @@ -43,16 +43,17 @@ func newClient(auth azure.Authorizer) (*azureClient, error) { } // newAgentPoolsClient creates a new agent pool client from subscription ID. -func newAgentPoolsClient(subscriptionID, azureEnvironment string) (armcontainerservice.AgentPoolsClient, error) { +func newAgentPoolsClient(scope AgentPoolScope) (armcontainerservice.AgentPoolsClient, error) { cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return armcontainerservice.AgentPoolsClient{}, errors.Wrap(err, "failed to create default Azure credential") } - opts, err := azure.ARMClientOptions(azureEnvironment) + opts, err := azure.ARMClientOptions(scope.CloudEnvironment()) + opts.PerCallPolicies = append(opts.PerCallPolicies, azure.CustomPutHeaderPolicy{Getter: scope.AgentPoolSpec()}) if err != nil { return armcontainerservice.AgentPoolsClient{}, errors.Wrap(err, "failed to create ARM client options") } - factory, err := armcontainerservice.NewClientFactory(subscriptionID, cred, opts) + factory, err := armcontainerservice.NewClientFactory(scope.SubscriptionID(), cred, opts) if err != nil { return armcontainerservice.AgentPoolsClient{}, errors.Wrap(err, "failed to create client factory") } @@ -88,15 +89,6 @@ func (ac *azureClient) CreateOrUpdateAsync(ctx context.Context, spec azure.Resou agentPool = ap } - // TODO: add in these custom headers - // headerSpec, ok := spec.(azure.ResourceSpecGetterWithHeaders) - // if !ok { - // return nil, nil, errors.Errorf("%T is not a azure.ResourceSpecGetterWithHeaders", spec) - // } - // for key, element := range headerSpec.CustomHeaders() { - // preparer.Header.Add(key, element) - // } - opts := &armcontainerservice.AgentPoolsClientBeginCreateOrUpdateOptions{ResumeToken: resumeToken} log.V(4).Info("sending request", "resumeToken", resumeToken) poller, err = ac.agentpools.BeginCreateOrUpdate(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName(), agentPool, opts) diff --git a/azure/services/managedclusters/client.go b/azure/services/managedclusters/client.go index 6bd442eb745..429e05336b5 100644 --- a/azure/services/managedclusters/client.go +++ b/azure/services/managedclusters/client.go @@ -39,8 +39,8 @@ type azureClient struct { } // newClient creates a new managed cluster client from an authorizer. -func newClient(auth azure.Authorizer) (*azureClient, error) { - c, err := newManagedClustersClient(auth.SubscriptionID(), auth.CloudEnvironment()) +func newClient(scope ManagedClusterScope) (*azureClient, error) { + c, err := newManagedClustersClient(scope) if err != nil { return nil, errors.Wrap(err, "failed to create managed clusters client") } @@ -48,16 +48,17 @@ func newClient(auth azure.Authorizer) (*azureClient, error) { } // newManagedClustersClient creates a new managed clusters client from subscription ID. -func newManagedClustersClient(subscriptionID, azureEnvironment string) (armcontainerservice.ManagedClustersClient, error) { +func newManagedClustersClient(scope ManagedClusterScope) (armcontainerservice.ManagedClustersClient, error) { cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return armcontainerservice.ManagedClustersClient{}, errors.Wrap(err, "failed to create default Azure credential") } - opts, err := azure.ARMClientOptions(azureEnvironment) + opts, err := azure.ARMClientOptions(scope.CloudEnvironment()) + opts.PerCallPolicies = append(opts.PerCallPolicies, azure.CustomPutHeaderPolicy{Getter: scope.ManagedClusterSpec()}) if err != nil { return armcontainerservice.ManagedClustersClient{}, errors.Wrap(err, "failed to create ARM client options") } - factory, err := armcontainerservice.NewClientFactory(subscriptionID, cred, opts) + factory, err := armcontainerservice.NewClientFactory(scope.SubscriptionID(), cred, opts) if err != nil { return armcontainerservice.ManagedClustersClient{}, errors.Wrap(err, "failed to create client factory") }