Skip to content

Commit

Permalink
Add custom headers to PUT requests
Browse files Browse the repository at this point in the history
  • Loading branch information
mboersma committed Aug 18, 2023
1 parent 02ab8e6 commit 30c0c18
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 19 deletions.
19 changes: 19 additions & 0 deletions azure/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,25 @@ func (p userAgentPolicy) Do(req *policy.Request) (*http.Response, error) {
return req.Next()
}

// CustomPutHeaderPolicy adds custom headers to a PUT request.
// It implements the policy.Policy interface.
type CustomPutHeaderPolicy struct {
Getter ResourceSpecGetter
}

// Do adds any custom headers to a PUT request.
func (p CustomPutHeaderPolicy) Do(req *policy.Request) (*http.Response, error) {
if req.Raw().Method == http.MethodPut {
headerSpec, ok := p.Getter.(ResourceSpecGetterWithHeaders)
if ok {
for key, element := range headerSpec.CustomHeaders() {
req.Raw().Header.Set(key, element)
}
}
}
return req.Next()
}

// SetAutoRestClientDefaults set authorizer and user agent for autorest client.
func SetAutoRestClientDefaults(c *autorest.Client, auth autorest.Authorizer) {
c.Authorizer = auth
Expand Down
75 changes: 75 additions & 0 deletions azure/defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/go-autorest/autorest"
. "github.com/onsi/gomega"
"go.uber.org/mock/gomock"
"sigs.k8s.io/cluster-api-provider-azure/azure/mock_azure"
"sigs.k8s.io/cluster-api-provider-azure/util/tele"
)

Expand Down Expand Up @@ -118,6 +120,79 @@ func TestPerCallPolicies(t *testing.T) {
g.Expect(resp.StatusCode).To(Equal(http.StatusOK))
}

func TestCustomPutHeaderPolicy(t *testing.T) {
testHeaders := map[string]string{
"X-Test-Header": "test-value",
"X-Test-Header2": "test-value2",
}
testcases := []struct {
name string
method string
headers map[string]string
expected map[string]string
}{
{
name: "should add custom headers to PUT request",
method: http.MethodPut,
headers: testHeaders,
expected: testHeaders,
},
{
name: "should skip empty custom headers for PUT request",
method: http.MethodPut,
},
{
name: "should not add custom headers to GET request",
method: http.MethodGet,
headers: testHeaders,
},
{
name: "should not add custom headers to POST request",
method: http.MethodPost,
headers: testHeaders,
},
}
for _, tc := range testcases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
g := NewWithT(t)

mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

// This server will check that custom headers are set correctly.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for k, v := range tc.expected {
g.Expect(r.Header.Get(k)).To(Equal(v))
}
fmt.Fprintf(w, "Hello, %s", r.Proto)
}))
defer server.Close()

// Create a custom PUT header per-call policy
opts, err := ARMClientOptions("")
g.Expect(err).NotTo(HaveOccurred())
policies := opts.PerCallPolicies
getterMock := mock_azure.NewMockResourceSpecGetterWithHeaders(mockCtrl)
getterMock.EXPECT().CustomHeaders().Return(tc.headers).AnyTimes()
policy := CustomPutHeaderPolicy{Getter: getterMock}
policies = append(policies, policy)[len(policies)-1:]

// Create a request
req, err := runtime.NewRequest(context.Background(), tc.method, server.URL)
g.Expect(err).NotTo(HaveOccurred())

// Create a pipeline and send the request to the test server for validation.
pipeline := defaultTestPipeline(policies)
resp, err := pipeline.Do(req)
g.Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()
g.Expect(resp.StatusCode).To(Equal(http.StatusOK))
})
}
}

func defaultTestPipeline(policies []policy.Policy) runtime.Pipeline {
return runtime.NewPipeline(
"testmodule",
Expand Down
20 changes: 6 additions & 14 deletions azure/services/agentpools/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,26 @@ type azureClient struct {
}

// newClient creates a new agent pools client from subscription ID.
func newClient(auth azure.Authorizer) (*azureClient, error) {
c, err := newAgentPoolsClient(auth.SubscriptionID(), auth.CloudEnvironment())
func newClient(scope AgentPoolScope) (*azureClient, error) {
c, err := newAgentPoolsClient(scope)
if err != nil {
return nil, errors.Wrap(err, "failed to create managed clusters client")
}
return &azureClient{c}, nil
}

// newAgentPoolsClient creates a new agent pool client from subscription ID.
func newAgentPoolsClient(subscriptionID, azureEnvironment string) (armcontainerservice.AgentPoolsClient, error) {
func newAgentPoolsClient(scope AgentPoolScope) (armcontainerservice.AgentPoolsClient, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return armcontainerservice.AgentPoolsClient{}, errors.Wrap(err, "failed to create default Azure credential")
}
opts, err := azure.ARMClientOptions(azureEnvironment)
opts, err := azure.ARMClientOptions(scope.CloudEnvironment())
opts.PerCallPolicies = append(opts.PerCallPolicies, azure.CustomPutHeaderPolicy{Getter: scope.AgentPoolSpec()})
if err != nil {
return armcontainerservice.AgentPoolsClient{}, errors.Wrap(err, "failed to create ARM client options")
}
factory, err := armcontainerservice.NewClientFactory(subscriptionID, cred, opts)
factory, err := armcontainerservice.NewClientFactory(scope.SubscriptionID(), cred, opts)
if err != nil {
return armcontainerservice.AgentPoolsClient{}, errors.Wrap(err, "failed to create client factory")
}
Expand Down Expand Up @@ -88,15 +89,6 @@ func (ac *azureClient) CreateOrUpdateAsync(ctx context.Context, spec azure.Resou
agentPool = ap
}

// TODO: add in these custom headers
// headerSpec, ok := spec.(azure.ResourceSpecGetterWithHeaders)
// if !ok {
// return nil, nil, errors.Errorf("%T is not a azure.ResourceSpecGetterWithHeaders", spec)
// }
// for key, element := range headerSpec.CustomHeaders() {
// preparer.Header.Add(key, element)
// }

opts := &armcontainerservice.AgentPoolsClientBeginCreateOrUpdateOptions{ResumeToken: resumeToken}
log.V(4).Info("sending request", "resumeToken", resumeToken)
poller, err = ac.agentpools.BeginCreateOrUpdate(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName(), agentPool, opts)
Expand Down
11 changes: 6 additions & 5 deletions azure/services/managedclusters/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,26 @@ type azureClient struct {
}

// newClient creates a new managed cluster client from an authorizer.
func newClient(auth azure.Authorizer) (*azureClient, error) {
c, err := newManagedClustersClient(auth.SubscriptionID(), auth.CloudEnvironment())
func newClient(scope ManagedClusterScope) (*azureClient, error) {
c, err := newManagedClustersClient(scope)
if err != nil {
return nil, errors.Wrap(err, "failed to create managed clusters client")
}
return &azureClient{c}, nil
}

// newManagedClustersClient creates a new managed clusters client from subscription ID.
func newManagedClustersClient(subscriptionID, azureEnvironment string) (armcontainerservice.ManagedClustersClient, error) {
func newManagedClustersClient(scope ManagedClusterScope) (armcontainerservice.ManagedClustersClient, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return armcontainerservice.ManagedClustersClient{}, errors.Wrap(err, "failed to create default Azure credential")
}
opts, err := azure.ARMClientOptions(azureEnvironment)
opts, err := azure.ARMClientOptions(scope.CloudEnvironment())
opts.PerCallPolicies = append(opts.PerCallPolicies, azure.CustomPutHeaderPolicy{Getter: scope.ManagedClusterSpec()})
if err != nil {
return armcontainerservice.ManagedClustersClient{}, errors.Wrap(err, "failed to create ARM client options")
}
factory, err := armcontainerservice.NewClientFactory(subscriptionID, cred, opts)
factory, err := armcontainerservice.NewClientFactory(scope.SubscriptionID(), cred, opts)
if err != nil {
return armcontainerservice.ManagedClustersClient{}, errors.Wrap(err, "failed to create client factory")
}
Expand Down

0 comments on commit 30c0c18

Please sign in to comment.