diff --git a/cmd/login.go b/cmd/login.go index e504e37..fe652b7 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -7,14 +7,23 @@ import ( ) func init() { - rootCmd.AddCommand(&cobra.Command{ + var token string + + c := &cobra.Command{ Use: "login", Short: "Authenticate with your Supermodel account", - Long: `Prompts for an API key and saves it to ~/.supermodel/config.yaml. + Long: `Opens your browser to create an API key and automatically saves it. -Get a key at https://supermodeltools.com/dashboard`, +For CI or headless environments, pass the key directly: + supermodel login --token smsk_live_...`, RunE: func(cmd *cobra.Command, _ []string) error { + if token != "" { + return auth.LoginWithToken(token) + } return auth.Login(cmd.Context()) }, - }) + } + + c.Flags().StringVar(&token, "token", "", "API key for non-interactive login (CI)") + rootCmd.AddCommand(c) } diff --git a/internal/auth/handler.go b/internal/auth/handler.go index 656f776..3f187c2 100644 --- a/internal/auth/handler.go +++ b/internal/auth/handler.go @@ -3,10 +3,17 @@ package auth import ( "bufio" "context" + "crypto/rand" + "encoding/hex" "fmt" + "net" + "net/http" "os" + "os/exec" + "runtime" "strings" "syscall" + "time" "golang.org/x/term" @@ -14,30 +21,104 @@ import ( "github.com/supermodeltools/cli/internal/ui" ) -// Login prompts the user for an API key and saves it to the config file. -// Input is read without echo when a terminal is attached. -func Login(_ context.Context) error { - fmt.Println("Get your API key at https://supermodeltools.com/dashboard") - fmt.Print("Paste your API key: ") +const dashboardBase = "https://dashboard.supermodeltools.com" - key, err := readSecret() +// Login runs the browser-based login flow. Opens the dashboard to create an +// API key, receives it via localhost callback, validates, and saves it. +// Falls back to manual paste if the browser flow fails. +func Login(ctx context.Context) error { + cfg, err := config.Load() if err != nil { - return fmt.Errorf("read input: %w", err) + return err } - key = strings.TrimSpace(key) - if key == "" { - return fmt.Errorf("API key cannot be empty") + + // Start localhost server on a random port. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + fmt.Fprintln(os.Stderr, "Could not start local server — falling back to manual login.") + return loginManual(cfg) + } + port := listener.Addr().(*net.TCPAddr).Port + state := randomState() + + // Channel to receive the API key from the callback. + keyCh := make(chan string, 1) + errCh := make(chan error, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + http.Error(w, "Invalid state parameter", http.StatusBadRequest) + return + } + key := r.URL.Query().Get("key") + if key == "" { + http.Error(w, "Missing key", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `

✓ Authenticated

You can close this tab and return to your terminal.

`) + keyCh <- key + }) + + srv := &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} //nolint:gosec // localhost-only server + go func() { + if err := srv.Serve(listener); err != nil && err != http.ErrServerClosed { + errCh <- err + } + }() + defer srv.Close() + + // Build the dashboard URL and open the browser. + authURL := fmt.Sprintf("%s/cli-auth?port=%d&state=%s", dashboardBase, port, state) + fmt.Println("Opening browser to log in...") + fmt.Printf("If the browser doesn't open, visit:\n %s\n\n", authURL) + + if err := openBrowser(authURL); err != nil { + fmt.Fprintln(os.Stderr, "Could not open browser — falling back to manual login.") + srv.Close() + return loginManual(cfg) } + // Wait for callback or timeout. + fmt.Print("Waiting for authentication...") + select { + case key := <-keyCh: + fmt.Println() + cfg.APIKey = strings.TrimSpace(key) + if err := cfg.Save(); err != nil { + return err + } + ui.Success("Authenticated — key saved to %s", config.Path()) + return nil + case err := <-errCh: + fmt.Println() + return fmt.Errorf("local server error: %w", err) + case <-time.After(5 * time.Minute): + fmt.Println() + fmt.Fprintln(os.Stderr, "Timed out waiting for browser login — falling back to manual login.") + srv.Close() + return loginManual(cfg) + case <-ctx.Done(): + fmt.Println() + return ctx.Err() + } +} + +// LoginWithToken saves an API key directly (for CI/headless use). +func LoginWithToken(token string) error { + token = strings.TrimSpace(token) + if token == "" { + return fmt.Errorf("API key cannot be empty") + } cfg, err := config.Load() if err != nil { return err } - cfg.APIKey = key + cfg.APIKey = token if err := cfg.Save(); err != nil { return err } - ui.Success("Authenticated — key saved to %s", config.Path()) return nil } @@ -60,18 +141,58 @@ func Logout(_ context.Context) error { return nil } +// loginManual is the fallback paste-based login. +func loginManual(cfg *config.Config) error { + fmt.Println("Get your API key at https://dashboard.supermodeltools.com/api-keys") + fmt.Print("Paste your API key: ") + + key, err := readSecret() + if err != nil { + return fmt.Errorf("read input: %w", err) + } + key = strings.TrimSpace(key) + if key == "" { + return fmt.Errorf("API key cannot be empty") + } + + cfg.APIKey = key + if err := cfg.Save(); err != nil { + return err + } + ui.Success("Authenticated — key saved to %s", config.Path()) + return nil +} + +func openBrowser(url string) error { + switch runtime.GOOS { + case "darwin": + return exec.Command("open", url).Start() + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +func randomState() string { + b := make([]byte, 16) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + // readSecret reads a line from stdin, suppressing echo when a TTY is attached. func readSecret() (string, error) { fd := int(syscall.Stdin) //nolint:unconvert // syscall.Stdin is uintptr on Windows if term.IsTerminal(fd) { b, err := term.ReadPassword(fd) - fmt.Println() // restore newline after hidden input + fmt.Println() if err != nil { return "", err } return string(b), nil } - // Non-TTY (pipe, CI): read as plain text scanner := bufio.NewScanner(os.Stdin) if scanner.Scan() { return scanner.Text(), nil diff --git a/internal/auth/handler_test.go b/internal/auth/handler_test.go new file mode 100644 index 0000000..3d3a7f2 --- /dev/null +++ b/internal/auth/handler_test.go @@ -0,0 +1,185 @@ +package auth + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/supermodeltools/cli/internal/config" +) + +func TestLoginWithToken(t *testing.T) { + // Point config to a temp dir so we don't touch real config. + tmp := t.TempDir() + t.Setenv("HOME", tmp) + + if err := LoginWithToken("smsk_live_test123"); err != nil { + t.Fatalf("LoginWithToken: %v", err) + } + + cfg, err := config.Load() + if err != nil { + t.Fatal(err) + } + if cfg.APIKey != "smsk_live_test123" { + t.Errorf("expected key smsk_live_test123, got %q", cfg.APIKey) + } +} + +func TestLoginWithToken_Empty(t *testing.T) { + if err := LoginWithToken(""); err == nil { + t.Fatal("expected error for empty token") + } +} + +func TestLoginWithToken_Whitespace(t *testing.T) { + tmp := t.TempDir() + t.Setenv("HOME", tmp) + + if err := LoginWithToken(" smsk_live_padded "); err != nil { + t.Fatalf("LoginWithToken: %v", err) + } + + cfg, _ := config.Load() + if cfg.APIKey != "smsk_live_padded" { + t.Errorf("expected trimmed key, got %q", cfg.APIKey) + } +} + +func TestCallbackServer(t *testing.T) { + // Simulate the browser callback flow by starting the localhost server + // and hitting the callback endpoint directly. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + port := listener.Addr().(*net.TCPAddr).Port + state := "test-state-123" + + keyCh := make(chan string, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + http.Error(w, "bad state", http.StatusBadRequest) + return + } + key := r.URL.Query().Get("key") + if key == "" { + http.Error(w, "missing key", http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + keyCh <- key + }) + + srv := &http.Server{Handler: mux, ReadHeaderTimeout: 5 * time.Second} + go srv.Serve(listener) + defer srv.Close() + + // Simulate the dashboard redirect. + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/callback?key=smsk_live_from_browser&state=%s", port, state)) + if err != nil { + t.Fatalf("callback request: %v", err) + } + resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + select { + case key := <-keyCh: + if key != "smsk_live_from_browser" { + t.Errorf("expected smsk_live_from_browser, got %q", key) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for key") + } +} + +func TestCallbackServer_BadState(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != "correct-state" { + http.Error(w, "bad state", http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + }) + + ts := httptest.NewServer(mux) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/callback?key=smsk_live_x&state=wrong-state") + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("expected 400 for bad state, got %d", resp.StatusCode) + } +} + +func TestCallbackServer_MissingKey(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != "s" { + http.Error(w, "bad state", http.StatusBadRequest) + return + } + key := r.URL.Query().Get("key") + if key == "" { + http.Error(w, "missing key", http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + }) + + ts := httptest.NewServer(mux) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/callback?state=s") + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("expected 400 for missing key, got %d", resp.StatusCode) + } +} + +func TestRandomState(t *testing.T) { + s1 := randomState() + s2 := randomState() + if s1 == s2 { + t.Error("randomState should produce unique values") + } + if len(s1) != 32 { // 16 bytes = 32 hex chars + t.Errorf("expected 32 char hex string, got %d chars", len(s1)) + } +} + +func TestLogout(t *testing.T) { + tmp := t.TempDir() + t.Setenv("HOME", tmp) + + // Set up a config with a key. + cfg := &config.Config{APIKey: "smsk_live_toremove", APIBase: config.DefaultAPIBase, Output: "human"} + os.MkdirAll(filepath.Join(tmp, ".supermodel"), 0o700) + cfg.Save() + + if err := Logout(context.Background()); err != nil { + t.Fatal(err) + } + + cfg, _ = config.Load() + if cfg.APIKey != "" { + t.Errorf("expected empty key after logout, got %q", cfg.APIKey) + } +}