From 8b1aee450bc8237f6e07e876b9700038c634daa4 Mon Sep 17 00:00:00 2001 From: Jeremy Eder Date: Fri, 27 Mar 2026 23:17:41 -0400 Subject: [PATCH 1/2] feat(cli): add browser-based OAuth2 login with PKCE Add `acpctl login browser` subcommand that authenticates via OAuth2 Authorization Code + PKCE flow. Opens the user's browser to the identity provider, receives the callback on a local HTTP server, and exchanges the code for access and refresh tokens. - Add pkg/oauth/ package: OIDC discovery, PKCE generation, state generation, authorize URL construction, token exchange, local callback server, cross-platform browser opener - Add `login browser` subcommand with --issuer-url, --client-id, --scopes flags (env var and config file fallbacks) - Extend Config with RefreshToken, IssuerURL, ClientID fields - Support both automatic browser callback and manual redirect URL paste - 15 tests covering PKCE, state, OIDC discovery, token exchange, and callback server scenarios Co-Authored-By: Claude Opus 4.6 (1M context) --- .../cmd/acpctl/login/browser/cmd.go | 206 +++++++++++++++++ .../ambient-cli/cmd/acpctl/login/cmd.go | 7 +- components/ambient-cli/pkg/config/config.go | 18 ++ components/ambient-cli/pkg/oauth/browser.go | 23 ++ components/ambient-cli/pkg/oauth/callback.go | 86 +++++++ .../ambient-cli/pkg/oauth/callback_test.go | 135 +++++++++++ components/ambient-cli/pkg/oauth/oauth.go | 139 ++++++++++++ .../ambient-cli/pkg/oauth/oauth_test.go | 210 ++++++++++++++++++ 8 files changed, 822 insertions(+), 2 deletions(-) create mode 100644 components/ambient-cli/cmd/acpctl/login/browser/cmd.go create mode 100644 components/ambient-cli/pkg/oauth/browser.go create mode 100644 components/ambient-cli/pkg/oauth/callback.go create mode 100644 components/ambient-cli/pkg/oauth/callback_test.go create mode 100644 components/ambient-cli/pkg/oauth/oauth.go create mode 100644 components/ambient-cli/pkg/oauth/oauth_test.go diff --git a/components/ambient-cli/cmd/acpctl/login/browser/cmd.go b/components/ambient-cli/cmd/acpctl/login/browser/cmd.go new file mode 100644 index 000000000..1407d6ec8 --- /dev/null +++ b/components/ambient-cli/cmd/acpctl/login/browser/cmd.go @@ -0,0 +1,206 @@ +// Package browser implements browser-based OAuth2 login using Authorization Code + PKCE. +package browser + +import ( + "bufio" + "context" + "fmt" + "net/url" + "os" + "strings" + "time" + + "github.com/ambient-code/platform/components/ambient-cli/pkg/config" + "github.com/ambient-code/platform/components/ambient-cli/pkg/oauth" + "github.com/spf13/cobra" +) + +var args struct { + issuerURL string + clientID string + scopes string +} + +var Cmd = &cobra.Command{ + Use: "browser", + Short: "Log in via browser-based OAuth2 flow", + Long: `Open a browser to authenticate with the identity provider using OAuth2 +Authorization Code + PKCE. The CLI starts a local callback server to receive the +authorization code, then exchanges it for access and refresh tokens.`, + Args: cobra.NoArgs, + RunE: run, +} + +func init() { + flags := Cmd.Flags() + flags.StringVar(&args.issuerURL, "issuer-url", "", "OIDC issuer URL (e.g. https://keycloak.example.com/realms/myrealm)") + flags.StringVar(&args.clientID, "client-id", "", "OAuth2 client ID") + flags.StringVar(&args.scopes, "scopes", "openid email profile", "OAuth2 scopes to request") +} + +func run(cmd *cobra.Command, _ []string) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + issuerURL := args.issuerURL + if issuerURL == "" { + issuerURL = cfg.GetIssuerURL() + } + if issuerURL == "" { + return fmt.Errorf("--issuer-url is required (or set AMBIENT_ISSUER_URL / issuer_url in config)") + } + + clientID := args.clientID + if clientID == "" { + clientID = cfg.GetClientID() + } + if clientID == "" { + return fmt.Errorf("--client-id is required (or set AMBIENT_CLIENT_ID / client_id in config)") + } + + fmt.Fprintf(cmd.OutOrStdout(), "Authenticating with %s...\n", issuerURL) + + oidcCfg, err := oauth.DiscoverEndpoints(issuerURL) + if err != nil { + return fmt.Errorf("OIDC discovery: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + state, err := oauth.GenerateState() + if err != nil { + return err + } + + pkce, err := oauth.GeneratePKCE() + if err != nil { + return err + } + + addr, resultCh, cleanup, err := oauth.StartCallbackServer(ctx, state) + if err != nil { + return err + } + defer cleanup() + + redirectURI := "http://" + addr + "/callback" + authorizeURL := oauth.BuildAuthorizeURL( + oidcCfg.AuthorizationEndpoint, + clientID, + redirectURI, + state, + pkce.Challenge, + args.scopes, + ) + + if err := oauth.OpenBrowser(authorizeURL); err != nil { + fmt.Fprintf(cmd.ErrOrStderr(), "Could not open browser: %v\n", err) + } + + fmt.Fprintln(cmd.OutOrStdout(), "If the browser did not open, visit this URL:") + fmt.Fprintln(cmd.OutOrStdout(), authorizeURL) + fmt.Fprintln(cmd.OutOrStdout()) + fmt.Fprintln(cmd.OutOrStdout(), "Or paste the redirect URL here:") + + // Listen for both callback and manual URL paste. + // Use a pipe so we can close the reader to unblock the goroutine. + pr, pw, err := os.Pipe() + if err != nil { + return fmt.Errorf("create pipe: %w", err) + } + defer pr.Close() + + // Copy stdin to pipe in background so we can close pr to stop the scanner. + go func() { + defer pw.Close() + buf := make([]byte, 4096) + for { + n, err := os.Stdin.Read(buf) + if n > 0 { + pw.Write(buf[:n]) //nolint:errcheck + } + if err != nil { + return + } + } + }() + + manualCh := make(chan oauth.CallbackResult, 1) + go func() { + scanner := bufio.NewScanner(pr) + if scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + return + } + parsed, err := url.Parse(line) + if err != nil { + manualCh <- oauth.CallbackResult{Err: fmt.Errorf("invalid URL: %w", err)} + return + } + code := parsed.Query().Get("code") + pastedState := parsed.Query().Get("state") + if code == "" { + manualCh <- oauth.CallbackResult{Err: fmt.Errorf("URL missing 'code' parameter")} + return + } + if pastedState != "" && pastedState != state { + manualCh <- oauth.CallbackResult{Err: fmt.Errorf("state mismatch in pasted URL")} + return + } + manualCh <- oauth.CallbackResult{Code: code} + } + }() + + var result oauth.CallbackResult + select { + case result = <-resultCh: + case result = <-manualCh: + case <-ctx.Done(): + return fmt.Errorf("login timed out after 5 minutes") + } + + if result.Err != nil { + return fmt.Errorf("authorization failed: %w", result.Err) + } + + fmt.Fprintln(cmd.OutOrStdout(), "Authorization code received, exchanging for tokens...") + + tokenResp, err := oauth.ExchangeCode( + oidcCfg.TokenEndpoint, + clientID, + result.Code, + redirectURI, + pkce.Verifier, + ) + if err != nil { + return fmt.Errorf("token exchange: %w", err) + } + + cfg.AccessToken = tokenResp.AccessToken + cfg.RefreshToken = tokenResp.RefreshToken + cfg.IssuerURL = issuerURL + cfg.ClientID = clientID + + if err := config.Save(cfg); err != nil { + return fmt.Errorf("save config: %w", err) + } + + location, err := config.Location() + if err != nil { + fmt.Fprintln(cmd.OutOrStdout(), "Login successful. Configuration saved.") + } else { + fmt.Fprintf(cmd.OutOrStdout(), "Login successful. Configuration saved to %s\n", location) + } + + if exp, err := config.TokenExpiry(tokenResp.AccessToken); err == nil && !exp.IsZero() { + if time.Until(exp) < 24*time.Hour { + fmt.Fprintf(cmd.ErrOrStderr(), "Note: token expires at %s\n", exp.Format(time.RFC3339)) + } + } + + return nil +} diff --git a/components/ambient-cli/cmd/acpctl/login/cmd.go b/components/ambient-cli/cmd/acpctl/login/cmd.go index 4269c33f6..3cc2c1454 100644 --- a/components/ambient-cli/cmd/acpctl/login/cmd.go +++ b/components/ambient-cli/cmd/acpctl/login/cmd.go @@ -6,6 +6,7 @@ import ( "net/url" "time" + "github.com/ambient-code/platform/components/ambient-cli/cmd/acpctl/login/browser" "github.com/ambient-code/platform/components/ambient-cli/pkg/config" "github.com/spf13/cobra" ) @@ -27,15 +28,17 @@ var Cmd = &cobra.Command{ func init() { flags := Cmd.Flags() - flags.StringVar(&args.token, "token", "", "Access token (required)") + flags.StringVar(&args.token, "token", "", "Access token (required when not using 'browser' subcommand)") flags.StringVar(&args.url, "url", "", "API server URL (default: http://localhost:8000)") flags.StringVar(&args.project, "project", "", "Default project name") flags.BoolVar(&args.insecureSkipVerify, "insecure-skip-tls-verify", false, "Skip TLS certificate verification (insecure)") + + Cmd.AddCommand(browser.Cmd) } func run(cmd *cobra.Command, positional []string) error { if args.token == "" { - return fmt.Errorf("--token is required") + return fmt.Errorf("--token is required (or use 'acpctl login browser' for browser-based OAuth login)") } cfg, err := config.Load() diff --git a/components/ambient-cli/pkg/config/config.go b/components/ambient-cli/pkg/config/config.go index 4b09ca0ed..5a493bc7a 100644 --- a/components/ambient-cli/pkg/config/config.go +++ b/components/ambient-cli/pkg/config/config.go @@ -13,6 +13,9 @@ import ( type Config struct { APIUrl string `json:"api_url,omitempty"` AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + IssuerURL string `json:"issuer_url,omitempty"` + ClientID string `json:"client_id,omitempty"` Project string `json:"project,omitempty"` Pager string `json:"pager,omitempty"` // TODO: Wire pager support into output commands (e.g. pipe through less) RequestTimeout int `json:"request_timeout,omitempty"` // Request timeout in seconds @@ -80,6 +83,21 @@ func Save(cfg *Config) error { func (c *Config) ClearToken() { c.AccessToken = "" + c.RefreshToken = "" +} + +func (c *Config) GetIssuerURL() string { + if env := os.Getenv("AMBIENT_ISSUER_URL"); env != "" { + return env + } + return c.IssuerURL +} + +func (c *Config) GetClientID() string { + if env := os.Getenv("AMBIENT_CLIENT_ID"); env != "" { + return env + } + return c.ClientID } func (c *Config) GetAPIUrl() string { diff --git a/components/ambient-cli/pkg/oauth/browser.go b/components/ambient-cli/pkg/oauth/browser.go new file mode 100644 index 000000000..604869c15 --- /dev/null +++ b/components/ambient-cli/pkg/oauth/browser.go @@ -0,0 +1,23 @@ +package oauth + +import ( + "fmt" + "os/exec" + "runtime" +) + +// OpenBrowser opens the specified URL in the user's default browser. +func OpenBrowser(url string) error { + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", url) + case "linux": + cmd = exec.Command("xdg-open", url) + case "windows": + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url) + default: + return fmt.Errorf("unsupported platform %q", runtime.GOOS) + } + return cmd.Start() +} diff --git a/components/ambient-cli/pkg/oauth/callback.go b/components/ambient-cli/pkg/oauth/callback.go new file mode 100644 index 000000000..eca03ef2d --- /dev/null +++ b/components/ambient-cli/pkg/oauth/callback.go @@ -0,0 +1,86 @@ +package oauth + +import ( + "context" + "fmt" + "net" + "net/http" + "time" +) + +// CallbackResult holds the authorization code received from the callback. +type CallbackResult struct { + Code string + Err error +} + +// StartCallbackServer starts a local HTTP server on a random port to receive +// the OAuth callback. It returns the server's address and a channel that will +// receive the authorization code. +func StartCallbackServer(ctx context.Context, expectedState string) (addr string, resultCh <-chan CallbackResult, cleanup func(), err error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", nil, nil, fmt.Errorf("listen on localhost: %w", err) + } + + ch := make(chan CallbackResult, 1) + mux := http.NewServeMux() + server := &http.Server{Handler: mux} + + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + state := r.URL.Query().Get("state") + if state != expectedState { + http.Error(w, "Invalid state parameter", http.StatusBadRequest) + ch <- CallbackResult{Err: fmt.Errorf("state mismatch: expected %q, got %q", expectedState, state)} + return + } + + errParam := r.URL.Query().Get("error") + if errParam != "" { + desc := r.URL.Query().Get("error_description") + http.Error(w, "Authorization failed: "+errParam, http.StatusBadRequest) + ch <- CallbackResult{Err: fmt.Errorf("authorization error: %s: %s", errParam, desc)} + return + } + + code := r.URL.Query().Get("code") + if code == "" { + http.Error(w, "Missing authorization code", http.StatusBadRequest) + ch <- CallbackResult{Err: fmt.Errorf("callback missing authorization code")} + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, successHTML) + ch <- CallbackResult{Code: code} + }) + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + ch <- CallbackResult{Err: fmt.Errorf("callback server: %w", err)} + } + }() + + cleanupFn := func() { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + server.Shutdown(shutdownCtx) //nolint:errcheck + } + + return listener.Addr().String(), ch, cleanupFn, nil +} + +const successHTML = ` +Login Successful + +
+

Login Successful

+

You can close this window and return to the terminal.

+
` diff --git a/components/ambient-cli/pkg/oauth/callback_test.go b/components/ambient-cli/pkg/oauth/callback_test.go new file mode 100644 index 000000000..59262805e --- /dev/null +++ b/components/ambient-cli/pkg/oauth/callback_test.go @@ -0,0 +1,135 @@ +package oauth + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" +) + +func TestCallbackServer_Success(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + state := "test-state-123" + addr, resultCh, cleanup, err := StartCallbackServer(ctx, state) + if err != nil { + t.Fatalf("StartCallbackServer() error: %v", err) + } + defer cleanup() + + callbackURL := fmt.Sprintf("http://%s/callback?code=test-code&state=%s", addr, state) + resp, err := http.Get(callbackURL) //nolint:gosec + if err != nil { + t.Fatalf("GET callback error: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("callback status = %d, want 200", resp.StatusCode) + } + + select { + case result := <-resultCh: + if result.Err != nil { + t.Fatalf("unexpected error: %v", result.Err) + } + if result.Code != "test-code" { + t.Errorf("code = %q, want %q", result.Code, "test-code") + } + case <-ctx.Done(): + t.Fatal("timed out waiting for result") + } +} + +func TestCallbackServer_StateMismatch(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + addr, resultCh, cleanup, err := StartCallbackServer(ctx, "expected-state") + if err != nil { + t.Fatalf("StartCallbackServer() error: %v", err) + } + defer cleanup() + + callbackURL := fmt.Sprintf("http://%s/callback?code=test-code&state=wrong-state", addr) + resp, err := http.Get(callbackURL) //nolint:gosec + if err != nil { + t.Fatalf("GET callback error: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("callback status = %d, want 400", resp.StatusCode) + } + + select { + case result := <-resultCh: + if result.Err == nil { + t.Fatal("expected error for state mismatch") + } + case <-ctx.Done(): + t.Fatal("timed out waiting for result") + } +} + +func TestCallbackServer_MissingCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + state := "test-state" + addr, resultCh, cleanup, err := StartCallbackServer(ctx, state) + if err != nil { + t.Fatalf("StartCallbackServer() error: %v", err) + } + defer cleanup() + + callbackURL := fmt.Sprintf("http://%s/callback?state=%s", addr, state) + resp, err := http.Get(callbackURL) //nolint:gosec + if err != nil { + t.Fatalf("GET callback error: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("callback status = %d, want 400", resp.StatusCode) + } + + select { + case result := <-resultCh: + if result.Err == nil { + t.Fatal("expected error for missing code") + } + case <-ctx.Done(): + t.Fatal("timed out waiting for result") + } +} + +func TestCallbackServer_AuthorizationError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + state := "test-state" + addr, resultCh, cleanup, err := StartCallbackServer(ctx, state) + if err != nil { + t.Fatalf("StartCallbackServer() error: %v", err) + } + defer cleanup() + + callbackURL := fmt.Sprintf("http://%s/callback?state=%s&error=access_denied&error_description=user+denied", addr, state) + resp, err := http.Get(callbackURL) //nolint:gosec + if err != nil { + t.Fatalf("GET callback error: %v", err) + } + resp.Body.Close() + + select { + case result := <-resultCh: + if result.Err == nil { + t.Fatal("expected error for authorization error") + } + case <-ctx.Done(): + t.Fatal("timed out waiting for result") + } +} diff --git a/components/ambient-cli/pkg/oauth/oauth.go b/components/ambient-cli/pkg/oauth/oauth.go new file mode 100644 index 000000000..693ab7b58 --- /dev/null +++ b/components/ambient-cli/pkg/oauth/oauth.go @@ -0,0 +1,139 @@ +// Package oauth implements OAuth2 Authorization Code + PKCE for CLI authentication. +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +var httpClient = &http.Client{Timeout: 30 * time.Second} + +// OIDCConfig holds the endpoints discovered from the issuer. +type OIDCConfig struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` +} + +// TokenResponse holds the tokens returned by the token endpoint. +type TokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + IDToken string `json:"id_token,omitempty"` +} + +// PKCE holds the code verifier and challenge pair. +type PKCE struct { + Verifier string + Challenge string +} + +// DiscoverEndpoints fetches OIDC configuration from the issuer's well-known endpoint. +func DiscoverEndpoints(issuerURL string) (*OIDCConfig, error) { + wellKnown := strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration" + + resp, err := httpClient.Get(wellKnown) //nolint:gosec // URL is user-provided issuer, not attacker-controlled + if err != nil { + return nil, fmt.Errorf("fetch OIDC discovery: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("OIDC discovery returned %d: %s", resp.StatusCode, string(body)) + } + + var cfg OIDCConfig + if err := json.NewDecoder(resp.Body).Decode(&cfg); err != nil { + return nil, fmt.Errorf("parse OIDC discovery: %w", err) + } + + if cfg.AuthorizationEndpoint == "" || cfg.TokenEndpoint == "" { + return nil, fmt.Errorf("OIDC discovery missing required endpoints") + } + + return &cfg, nil +} + +// GeneratePKCE creates a PKCE code verifier (43 chars) and S256 code challenge. +func GeneratePKCE() (*PKCE, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return nil, fmt.Errorf("generate PKCE verifier: %w", err) + } + verifier := base64.RawURLEncoding.EncodeToString(buf) + + h := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(h[:]) + + return &PKCE{Verifier: verifier, Challenge: challenge}, nil +} + +// GenerateState creates a cryptographically random state parameter. +func GenerateState() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", fmt.Errorf("generate state: %w", err) + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +// BuildAuthorizeURL constructs the full authorization URL with all required parameters. +func BuildAuthorizeURL(authEndpoint, clientID, redirectURI, state, codeChallenge, scopes string) string { + params := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {redirectURI}, + "state": {state}, + "code_challenge": {codeChallenge}, + "code_challenge_method": {"S256"}, + "scope": {scopes}, + } + return authEndpoint + "?" + params.Encode() +} + +// ExchangeCode exchanges an authorization code for tokens. +func ExchangeCode(tokenEndpoint, clientID, code, redirectURI, codeVerifier string) (*TokenResponse, error) { + data := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {clientID}, + "code": {code}, + "redirect_uri": {redirectURI}, + "code_verifier": {codeVerifier}, + } + + resp, err := httpClient.PostForm(tokenEndpoint, data) //nolint:gosec // URL is from OIDC discovery + if err != nil { + return nil, fmt.Errorf("token exchange: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read token response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp TokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("parse token response: %w", err) + } + + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("token response missing access_token") + } + + return &tokenResp, nil +} diff --git a/components/ambient-cli/pkg/oauth/oauth_test.go b/components/ambient-cli/pkg/oauth/oauth_test.go new file mode 100644 index 000000000..7fcfb182b --- /dev/null +++ b/components/ambient-cli/pkg/oauth/oauth_test.go @@ -0,0 +1,210 @@ +package oauth + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func TestGeneratePKCE(t *testing.T) { + pkce, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if len(pkce.Verifier) != 43 { + t.Errorf("verifier length = %d, want 43", len(pkce.Verifier)) + } + + // Verify S256 challenge matches verifier + h := sha256.Sum256([]byte(pkce.Verifier)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + if pkce.Challenge != expected { + t.Errorf("challenge mismatch:\n got %q\n want %q", pkce.Challenge, expected) + } +} + +func TestGeneratePKCE_Uniqueness(t *testing.T) { + p1, _ := GeneratePKCE() + p2, _ := GeneratePKCE() + if p1.Verifier == p2.Verifier { + t.Error("two PKCE verifiers should not be identical") + } +} + +func TestGenerateState(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState() error: %v", err) + } + + if len(state) != 43 { + t.Errorf("state length = %d, want 43", len(state)) + } +} + +func TestGenerateState_Uniqueness(t *testing.T) { + s1, _ := GenerateState() + s2, _ := GenerateState() + if s1 == s2 { + t.Error("two state values should not be identical") + } +} + +func TestBuildAuthorizeURL(t *testing.T) { + result := BuildAuthorizeURL( + "https://auth.example.com/authorize", + "my-client", + "http://localhost:12345/callback", + "test-state", + "test-challenge", + "openid email", + ) + + parsed, err := url.Parse(result) + if err != nil { + t.Fatalf("failed to parse URL: %v", err) + } + + if parsed.Scheme != "https" || parsed.Host != "auth.example.com" || parsed.Path != "/authorize" { + t.Errorf("unexpected base URL: %s", result) + } + + params := parsed.Query() + tests := map[string]string{ + "response_type": "code", + "client_id": "my-client", + "redirect_uri": "http://localhost:12345/callback", + "state": "test-state", + "code_challenge": "test-challenge", + "code_challenge_method": "S256", + "scope": "openid email", + } + + for key, want := range tests { + if got := params.Get(key); got != want { + t.Errorf("param %q = %q, want %q", key, got, want) + } + } +} + +func TestDiscoverEndpoints(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + http.NotFound(w, r) + return + } + json.NewEncoder(w).Encode(OIDCConfig{ + AuthorizationEndpoint: "https://auth.example.com/authorize", + TokenEndpoint: "https://auth.example.com/token", + }) + })) + defer server.Close() + + cfg, err := DiscoverEndpoints(server.URL) + if err != nil { + t.Fatalf("DiscoverEndpoints() error: %v", err) + } + + if cfg.AuthorizationEndpoint != "https://auth.example.com/authorize" { + t.Errorf("authorization_endpoint = %q", cfg.AuthorizationEndpoint) + } + if cfg.TokenEndpoint != "https://auth.example.com/token" { + t.Errorf("token_endpoint = %q", cfg.TokenEndpoint) + } +} + +func TestDiscoverEndpoints_MissingEndpoints(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{"issuer": "https://example.com"}) + })) + defer server.Close() + + _, err := DiscoverEndpoints(server.URL) + if err == nil { + t.Fatal("expected error for missing endpoints") + } + if !strings.Contains(err.Error(), "missing required endpoints") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestDiscoverEndpoints_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal error", http.StatusInternalServerError) + })) + defer server.Close() + + _, err := DiscoverEndpoints(server.URL) + if err == nil { + t.Fatal("expected error for server error") + } +} + +func TestExchangeCode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "bad form", http.StatusBadRequest) + return + } + + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "bad grant_type", http.StatusBadRequest) + return + } + + json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "Bearer", + ExpiresIn: 3600, + }) + })) + defer server.Close() + + resp, err := ExchangeCode(server.URL, "client", "code", "http://localhost/callback", "verifier") + if err != nil { + t.Fatalf("ExchangeCode() error: %v", err) + } + + if resp.AccessToken != "test-access-token" { + t.Errorf("access_token = %q", resp.AccessToken) + } + if resp.RefreshToken != "test-refresh-token" { + t.Errorf("refresh_token = %q", resp.RefreshToken) + } +} + +func TestExchangeCode_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"invalid_grant"}`, http.StatusBadRequest) + })) + defer server.Close() + + _, err := ExchangeCode(server.URL, "client", "bad-code", "http://localhost/callback", "verifier") + if err == nil { + t.Fatal("expected error for bad grant") + } +} + +func TestExchangeCode_MissingAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{"token_type": "Bearer"}) + })) + defer server.Close() + + _, err := ExchangeCode(server.URL, "client", "code", "http://localhost/callback", "verifier") + if err == nil { + t.Fatal("expected error for missing access_token") + } +} From 97541122ebd4e8bfe44bfb58fd086f2ef7d25474 Mon Sep 17 00:00:00 2001 From: Jeremy Eder Date: Fri, 27 Mar 2026 23:44:55 -0400 Subject: [PATCH 2/2] fix(oauth): address CodeRabbit security review feedback - Require state parameter in manual URL paste path to prevent CSRF bypass - Stop persisting unused refresh token to config file - Validate OIDC discovery issuer matches expected issuer (RFC 8414) - Validate discovery endpoints are absolute URLs - Use url.Parse for BuildAuthorizeURL to preserve existing query params - Add TestDiscoverEndpoints_IssuerMismatch test Co-Authored-By: Claude Opus 4.6 (1M context) --- .../cmd/acpctl/login/browser/cmd.go | 13 ++++-- components/ambient-cli/pkg/oauth/oauth.go | 40 ++++++++++++++----- .../ambient-cli/pkg/oauth/oauth_test.go | 31 +++++++++++++- 3 files changed, 69 insertions(+), 15 deletions(-) diff --git a/components/ambient-cli/cmd/acpctl/login/browser/cmd.go b/components/ambient-cli/cmd/acpctl/login/browser/cmd.go index 1407d6ec8..05546613d 100644 --- a/components/ambient-cli/cmd/acpctl/login/browser/cmd.go +++ b/components/ambient-cli/cmd/acpctl/login/browser/cmd.go @@ -87,7 +87,7 @@ func run(cmd *cobra.Command, _ []string) error { defer cleanup() redirectURI := "http://" + addr + "/callback" - authorizeURL := oauth.BuildAuthorizeURL( + authorizeURL, err := oauth.BuildAuthorizeURL( oidcCfg.AuthorizationEndpoint, clientID, redirectURI, @@ -95,6 +95,9 @@ func run(cmd *cobra.Command, _ []string) error { pkce.Challenge, args.scopes, ) + if err != nil { + return fmt.Errorf("build authorize URL: %w", err) + } if err := oauth.OpenBrowser(authorizeURL); err != nil { fmt.Fprintf(cmd.ErrOrStderr(), "Could not open browser: %v\n", err) @@ -147,7 +150,11 @@ func run(cmd *cobra.Command, _ []string) error { manualCh <- oauth.CallbackResult{Err: fmt.Errorf("URL missing 'code' parameter")} return } - if pastedState != "" && pastedState != state { + if pastedState == "" { + manualCh <- oauth.CallbackResult{Err: fmt.Errorf("URL missing 'state' parameter")} + return + } + if pastedState != state { manualCh <- oauth.CallbackResult{Err: fmt.Errorf("state mismatch in pasted URL")} return } @@ -181,7 +188,7 @@ func run(cmd *cobra.Command, _ []string) error { } cfg.AccessToken = tokenResp.AccessToken - cfg.RefreshToken = tokenResp.RefreshToken + cfg.RefreshToken = "" cfg.IssuerURL = issuerURL cfg.ClientID = clientID diff --git a/components/ambient-cli/pkg/oauth/oauth.go b/components/ambient-cli/pkg/oauth/oauth.go index 693ab7b58..a007688b8 100644 --- a/components/ambient-cli/pkg/oauth/oauth.go +++ b/components/ambient-cli/pkg/oauth/oauth.go @@ -18,6 +18,7 @@ var httpClient = &http.Client{Timeout: 30 * time.Second} // OIDCConfig holds the endpoints discovered from the issuer. type OIDCConfig struct { + Issuer string `json:"issuer"` AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` } @@ -57,10 +58,25 @@ func DiscoverEndpoints(issuerURL string) (*OIDCConfig, error) { return nil, fmt.Errorf("parse OIDC discovery: %w", err) } + expectedIssuer := strings.TrimRight(issuerURL, "/") + if cfg.Issuer != expectedIssuer { + return nil, fmt.Errorf("OIDC discovery issuer mismatch: got %q, want %q", cfg.Issuer, expectedIssuer) + } + if cfg.AuthorizationEndpoint == "" || cfg.TokenEndpoint == "" { return nil, fmt.Errorf("OIDC discovery missing required endpoints") } + for name, raw := range map[string]string{ + "authorization_endpoint": cfg.AuthorizationEndpoint, + "token_endpoint": cfg.TokenEndpoint, + } { + u, err := url.Parse(raw) + if err != nil || u.Scheme == "" || u.Host == "" { + return nil, fmt.Errorf("OIDC discovery returned invalid %s: %q", name, raw) + } + } + return &cfg, nil } @@ -88,17 +104,21 @@ func GenerateState() (string, error) { } // BuildAuthorizeURL constructs the full authorization URL with all required parameters. -func BuildAuthorizeURL(authEndpoint, clientID, redirectURI, state, codeChallenge, scopes string) string { - params := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {redirectURI}, - "state": {state}, - "code_challenge": {codeChallenge}, - "code_challenge_method": {"S256"}, - "scope": {scopes}, +func BuildAuthorizeURL(authEndpoint, clientID, redirectURI, state, codeChallenge, scopes string) (string, error) { + u, err := url.Parse(authEndpoint) + if err != nil { + return "", fmt.Errorf("parse authorization endpoint: %w", err) } - return authEndpoint + "?" + params.Encode() + params := u.Query() + params.Set("response_type", "code") + params.Set("client_id", clientID) + params.Set("redirect_uri", redirectURI) + params.Set("state", state) + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + params.Set("scope", scopes) + u.RawQuery = params.Encode() + return u.String(), nil } // ExchangeCode exchanges an authorization code for tokens. diff --git a/components/ambient-cli/pkg/oauth/oauth_test.go b/components/ambient-cli/pkg/oauth/oauth_test.go index 7fcfb182b..052aa867a 100644 --- a/components/ambient-cli/pkg/oauth/oauth_test.go +++ b/components/ambient-cli/pkg/oauth/oauth_test.go @@ -57,7 +57,7 @@ func TestGenerateState_Uniqueness(t *testing.T) { } func TestBuildAuthorizeURL(t *testing.T) { - result := BuildAuthorizeURL( + result, err := BuildAuthorizeURL( "https://auth.example.com/authorize", "my-client", "http://localhost:12345/callback", @@ -65,6 +65,9 @@ func TestBuildAuthorizeURL(t *testing.T) { "test-challenge", "openid email", ) + if err != nil { + t.Fatalf("BuildAuthorizeURL() error: %v", err) + } parsed, err := url.Parse(result) if err != nil { @@ -94,17 +97,20 @@ func TestBuildAuthorizeURL(t *testing.T) { } func TestDiscoverEndpoints(t *testing.T) { + var serverURL string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/.well-known/openid-configuration" { http.NotFound(w, r) return } json.NewEncoder(w).Encode(OIDCConfig{ + Issuer: serverURL, AuthorizationEndpoint: "https://auth.example.com/authorize", TokenEndpoint: "https://auth.example.com/token", }) })) defer server.Close() + serverURL = server.URL cfg, err := DiscoverEndpoints(server.URL) if err != nil { @@ -120,10 +126,12 @@ func TestDiscoverEndpoints(t *testing.T) { } func TestDiscoverEndpoints_MissingEndpoints(t *testing.T) { + var serverURL string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(map[string]string{"issuer": "https://example.com"}) + json.NewEncoder(w).Encode(map[string]string{"issuer": serverURL}) })) defer server.Close() + serverURL = server.URL _, err := DiscoverEndpoints(server.URL) if err == nil { @@ -134,6 +142,25 @@ func TestDiscoverEndpoints_MissingEndpoints(t *testing.T) { } } +func TestDiscoverEndpoints_IssuerMismatch(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "issuer": "https://wrong-issuer.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + }) + })) + defer server.Close() + + _, err := DiscoverEndpoints(server.URL) + if err == nil { + t.Fatal("expected error for issuer mismatch") + } + if !strings.Contains(err.Error(), "issuer mismatch") { + t.Errorf("unexpected error: %v", err) + } +} + func TestDiscoverEndpoints_ServerError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal error", http.StatusInternalServerError)