diff --git a/app/auth/plugins/azure_managed_identity/outgoing.go b/app/auth/plugins/azure_managed_identity/outgoing.go index 00d281b..5b2ace0 100644 --- a/app/auth/plugins/azure_managed_identity/outgoing.go +++ b/app/auth/plugins/azure_managed_identity/outgoing.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/url" + "os" "strconv" "sync" "time" @@ -88,19 +89,18 @@ func (a *AzureManagedIdentity) AddAuth(ctx context.Context, r *http.Request, par } func fetchToken(ctx context.Context, resource, clientID string) (string, time.Time, error) { - q := url.Values{} - q.Set("resource", resource) - q.Set("api-version", "2018-02-01") - if clientID != "" { - q.Set("client_id", clientID) + metaURL, headers, err := metadataRequest(resource, clientID) + if err != nil { + return "", time.Time{}, err } - metaURL := fmt.Sprintf("%s/metadata/identity/oauth2/token?%s", MetadataHost, q.Encode()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, metaURL, nil) if err != nil { return "", time.Time{}, err } - req.Header.Set("Metadata", "true") + for k, v := range headers { + req.Header.Set(k, v) + } resp, err := HTTPClient.Do(req) if err != nil { @@ -127,6 +127,35 @@ func fetchToken(ctx context.Context, resource, clientID string) (string, time.Ti return tr.AccessToken, exp, nil } +func metadataRequest(resource, clientID string) (string, map[string]string, error) { + q := url.Values{} + q.Set("resource", resource) + if clientID != "" { + q.Set("client_id", clientID) + } + + if endpoint := os.Getenv("IDENTITY_ENDPOINT"); endpoint != "" { + header := os.Getenv("IDENTITY_HEADER") + if header == "" { + return "", nil, fmt.Errorf("missing IDENTITY_HEADER for IDENTITY_ENDPOINT") + } + q.Set("api-version", "2019-08-01") + return fmt.Sprintf("%s?%s", endpoint, q.Encode()), map[string]string{"X-IDENTITY-HEADER": header}, nil + } + + if endpoint := os.Getenv("MSI_ENDPOINT"); endpoint != "" { + secret := os.Getenv("MSI_SECRET") + if secret == "" { + return "", nil, fmt.Errorf("missing MSI_SECRET for MSI_ENDPOINT") + } + q.Set("api-version", "2017-09-01") + return fmt.Sprintf("%s?%s", endpoint, q.Encode()), map[string]string{"Secret": secret}, nil + } + + q.Set("api-version", "2018-02-01") + return fmt.Sprintf("%s/metadata/identity/oauth2/token?%s", MetadataHost, q.Encode()), map[string]string{"Metadata": "true"}, nil +} + func parseExpiry(expiresOn string, expiresIn json.Number) time.Time { if expiresOn != "" { if ts, err := strconv.ParseInt(expiresOn, 10, 64); err == nil && ts > 0 { diff --git a/app/auth/plugins/azure_managed_identity/outgoing_test.go b/app/auth/plugins/azure_managed_identity/outgoing_test.go index d431729..8b970dc 100644 --- a/app/auth/plugins/azure_managed_identity/outgoing_test.go +++ b/app/auth/plugins/azure_managed_identity/outgoing_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" "sync/atomic" "testing" "time" @@ -313,3 +314,77 @@ func TestFetchTokenBadURL(t *testing.T) { t.Fatal("expected url parse error") } } + +func TestFetchTokenUsesIdentityEndpoint(t *testing.T) { + resetCache() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-IDENTITY-HEADER"); got != "secret" { + t.Fatalf("unexpected identity header %q", got) + } + if api := r.URL.Query().Get("api-version"); api != "2019-08-01" { + t.Fatalf("unexpected api-version %q", api) + } + fmt.Fprint(w, `{"access_token":"tok","expires_in":120}`) + })) + defer ts.Close() + + t.Setenv("IDENTITY_ENDPOINT", ts.URL+"/token") + t.Setenv("IDENTITY_HEADER", "secret") + + oldClient := HTTPClient + HTTPClient = ts.Client() + defer func() { HTTPClient = oldClient }() + + tok, _, err := fetchToken(context.Background(), "api://res", "") + if err != nil { + t.Fatalf("fetchToken failed: %v", err) + } + if tok != "tok" { + t.Fatalf("unexpected token %q", tok) + } +} + +func TestFetchTokenMissingIdentityHeader(t *testing.T) { + t.Setenv("IDENTITY_ENDPOINT", "http://localhost/identity") + t.Setenv("IDENTITY_HEADER", "") + + if _, _, err := fetchToken(context.Background(), "api://res", ""); err == nil { + t.Fatal("expected error for missing identity header") + } +} + +func TestFetchTokenUsesMSIEndpoint(t *testing.T) { + resetCache() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Secret"); got != "msi-secret" { + t.Fatalf("unexpected MSI secret %q", got) + } + if api := r.URL.Query().Get("api-version"); api != "2017-09-01" { + t.Fatalf("unexpected api-version %q", api) + } + fmt.Fprint(w, `{"access_token":"tok","expires_in":120}`) + })) + defer ts.Close() + + t.Setenv("MSI_ENDPOINT", ts.URL+"/msi/token") + t.Setenv("MSI_SECRET", "msi-secret") + t.Setenv("IDENTITY_ENDPOINT", "") + t.Setenv("IDENTITY_HEADER", "") + + oldClient := HTTPClient + HTTPClient = ts.Client() + defer func() { HTTPClient = oldClient }() + + tok, _, err := fetchToken(context.Background(), "api://res", "") + if err != nil { + t.Fatalf("fetchToken failed: %v", err) + } + if tok != "tok" { + t.Fatalf("unexpected token %q", tok) + } + if got := os.Getenv("IDENTITY_ENDPOINT"); got != "" { + t.Fatalf("IDENTITY_ENDPOINT should be cleared, got %q", got) + } +}