diff --git a/cmd/login.go b/cmd/login.go
index e504e37..fe652b7 100644
--- a/cmd/login.go
+++ b/cmd/login.go
@@ -7,14 +7,23 @@ import (
)
func init() {
- rootCmd.AddCommand(&cobra.Command{
+ var token string
+
+ c := &cobra.Command{
Use: "login",
Short: "Authenticate with your Supermodel account",
- Long: `Prompts for an API key and saves it to ~/.supermodel/config.yaml.
+ Long: `Opens your browser to create an API key and automatically saves it.
-Get a key at https://supermodeltools.com/dashboard`,
+For CI or headless environments, pass the key directly:
+ supermodel login --token smsk_live_...`,
RunE: func(cmd *cobra.Command, _ []string) error {
+ if token != "" {
+ return auth.LoginWithToken(token)
+ }
return auth.Login(cmd.Context())
},
- })
+ }
+
+ c.Flags().StringVar(&token, "token", "", "API key for non-interactive login (CI)")
+ rootCmd.AddCommand(c)
}
diff --git a/internal/auth/handler.go b/internal/auth/handler.go
index 656f776..3f187c2 100644
--- a/internal/auth/handler.go
+++ b/internal/auth/handler.go
@@ -3,10 +3,17 @@ package auth
import (
"bufio"
"context"
+ "crypto/rand"
+ "encoding/hex"
"fmt"
+ "net"
+ "net/http"
"os"
+ "os/exec"
+ "runtime"
"strings"
"syscall"
+ "time"
"golang.org/x/term"
@@ -14,30 +21,104 @@ import (
"github.com/supermodeltools/cli/internal/ui"
)
-// Login prompts the user for an API key and saves it to the config file.
-// Input is read without echo when a terminal is attached.
-func Login(_ context.Context) error {
- fmt.Println("Get your API key at https://supermodeltools.com/dashboard")
- fmt.Print("Paste your API key: ")
+const dashboardBase = "https://dashboard.supermodeltools.com"
- key, err := readSecret()
+// Login runs the browser-based login flow. Opens the dashboard to create an
+// API key, receives it via localhost callback, validates, and saves it.
+// Falls back to manual paste if the browser flow fails.
+func Login(ctx context.Context) error {
+ cfg, err := config.Load()
if err != nil {
- return fmt.Errorf("read input: %w", err)
+ return err
}
- key = strings.TrimSpace(key)
- if key == "" {
- return fmt.Errorf("API key cannot be empty")
+
+ // Start localhost server on a random port.
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Could not start local server — falling back to manual login.")
+ return loginManual(cfg)
+ }
+ port := listener.Addr().(*net.TCPAddr).Port
+ state := randomState()
+
+ // Channel to receive the API key from the callback.
+ keyCh := make(chan string, 1)
+ errCh := make(chan error, 1)
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("state") != state {
+ http.Error(w, "Invalid state parameter", http.StatusBadRequest)
+ return
+ }
+ key := r.URL.Query().Get("key")
+ if key == "" {
+ http.Error(w, "Missing key", http.StatusBadRequest)
+ return
+ }
+ w.Header().Set("Content-Type", "text/html")
+ fmt.Fprint(w, `
✓ Authenticated
You can close this tab and return to your terminal.
`)
+ keyCh <- key
+ })
+
+ srv := &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} //nolint:gosec // localhost-only server
+ go func() {
+ if err := srv.Serve(listener); err != nil && err != http.ErrServerClosed {
+ errCh <- err
+ }
+ }()
+ defer srv.Close()
+
+ // Build the dashboard URL and open the browser.
+ authURL := fmt.Sprintf("%s/cli-auth?port=%d&state=%s", dashboardBase, port, state)
+ fmt.Println("Opening browser to log in...")
+ fmt.Printf("If the browser doesn't open, visit:\n %s\n\n", authURL)
+
+ if err := openBrowser(authURL); err != nil {
+ fmt.Fprintln(os.Stderr, "Could not open browser — falling back to manual login.")
+ srv.Close()
+ return loginManual(cfg)
}
+ // Wait for callback or timeout.
+ fmt.Print("Waiting for authentication...")
+ select {
+ case key := <-keyCh:
+ fmt.Println()
+ cfg.APIKey = strings.TrimSpace(key)
+ if err := cfg.Save(); err != nil {
+ return err
+ }
+ ui.Success("Authenticated — key saved to %s", config.Path())
+ return nil
+ case err := <-errCh:
+ fmt.Println()
+ return fmt.Errorf("local server error: %w", err)
+ case <-time.After(5 * time.Minute):
+ fmt.Println()
+ fmt.Fprintln(os.Stderr, "Timed out waiting for browser login — falling back to manual login.")
+ srv.Close()
+ return loginManual(cfg)
+ case <-ctx.Done():
+ fmt.Println()
+ return ctx.Err()
+ }
+}
+
+// LoginWithToken saves an API key directly (for CI/headless use).
+func LoginWithToken(token string) error {
+ token = strings.TrimSpace(token)
+ if token == "" {
+ return fmt.Errorf("API key cannot be empty")
+ }
cfg, err := config.Load()
if err != nil {
return err
}
- cfg.APIKey = key
+ cfg.APIKey = token
if err := cfg.Save(); err != nil {
return err
}
-
ui.Success("Authenticated — key saved to %s", config.Path())
return nil
}
@@ -60,18 +141,58 @@ func Logout(_ context.Context) error {
return nil
}
+// loginManual is the fallback paste-based login.
+func loginManual(cfg *config.Config) error {
+ fmt.Println("Get your API key at https://dashboard.supermodeltools.com/api-keys")
+ fmt.Print("Paste your API key: ")
+
+ key, err := readSecret()
+ if err != nil {
+ return fmt.Errorf("read input: %w", err)
+ }
+ key = strings.TrimSpace(key)
+ if key == "" {
+ return fmt.Errorf("API key cannot be empty")
+ }
+
+ cfg.APIKey = key
+ if err := cfg.Save(); err != nil {
+ return err
+ }
+ ui.Success("Authenticated — key saved to %s", config.Path())
+ return nil
+}
+
+func openBrowser(url string) error {
+ switch runtime.GOOS {
+ case "darwin":
+ return exec.Command("open", url).Start()
+ case "linux":
+ return exec.Command("xdg-open", url).Start()
+ case "windows":
+ return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
+ default:
+ return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
+ }
+}
+
+func randomState() string {
+ b := make([]byte, 16)
+ _, _ = rand.Read(b)
+ return hex.EncodeToString(b)
+}
+
// readSecret reads a line from stdin, suppressing echo when a TTY is attached.
func readSecret() (string, error) {
fd := int(syscall.Stdin) //nolint:unconvert // syscall.Stdin is uintptr on Windows
if term.IsTerminal(fd) {
b, err := term.ReadPassword(fd)
- fmt.Println() // restore newline after hidden input
+ fmt.Println()
if err != nil {
return "", err
}
return string(b), nil
}
- // Non-TTY (pipe, CI): read as plain text
scanner := bufio.NewScanner(os.Stdin)
if scanner.Scan() {
return scanner.Text(), nil
diff --git a/internal/auth/handler_test.go b/internal/auth/handler_test.go
new file mode 100644
index 0000000..3d3a7f2
--- /dev/null
+++ b/internal/auth/handler_test.go
@@ -0,0 +1,185 @@
+package auth
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/supermodeltools/cli/internal/config"
+)
+
+func TestLoginWithToken(t *testing.T) {
+ // Point config to a temp dir so we don't touch real config.
+ tmp := t.TempDir()
+ t.Setenv("HOME", tmp)
+
+ if err := LoginWithToken("smsk_live_test123"); err != nil {
+ t.Fatalf("LoginWithToken: %v", err)
+ }
+
+ cfg, err := config.Load()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if cfg.APIKey != "smsk_live_test123" {
+ t.Errorf("expected key smsk_live_test123, got %q", cfg.APIKey)
+ }
+}
+
+func TestLoginWithToken_Empty(t *testing.T) {
+ if err := LoginWithToken(""); err == nil {
+ t.Fatal("expected error for empty token")
+ }
+}
+
+func TestLoginWithToken_Whitespace(t *testing.T) {
+ tmp := t.TempDir()
+ t.Setenv("HOME", tmp)
+
+ if err := LoginWithToken(" smsk_live_padded "); err != nil {
+ t.Fatalf("LoginWithToken: %v", err)
+ }
+
+ cfg, _ := config.Load()
+ if cfg.APIKey != "smsk_live_padded" {
+ t.Errorf("expected trimmed key, got %q", cfg.APIKey)
+ }
+}
+
+func TestCallbackServer(t *testing.T) {
+ // Simulate the browser callback flow by starting the localhost server
+ // and hitting the callback endpoint directly.
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ port := listener.Addr().(*net.TCPAddr).Port
+ state := "test-state-123"
+
+ keyCh := make(chan string, 1)
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("state") != state {
+ http.Error(w, "bad state", http.StatusBadRequest)
+ return
+ }
+ key := r.URL.Query().Get("key")
+ if key == "" {
+ http.Error(w, "missing key", http.StatusBadRequest)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ keyCh <- key
+ })
+
+ srv := &http.Server{Handler: mux, ReadHeaderTimeout: 5 * time.Second}
+ go srv.Serve(listener)
+ defer srv.Close()
+
+ // Simulate the dashboard redirect.
+ resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/callback?key=smsk_live_from_browser&state=%s", port, state))
+ if err != nil {
+ t.Fatalf("callback request: %v", err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatalf("expected 200, got %d", resp.StatusCode)
+ }
+
+ select {
+ case key := <-keyCh:
+ if key != "smsk_live_from_browser" {
+ t.Errorf("expected smsk_live_from_browser, got %q", key)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for key")
+ }
+}
+
+func TestCallbackServer_BadState(t *testing.T) {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("state") != "correct-state" {
+ http.Error(w, "bad state", http.StatusBadRequest)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ resp, err := http.Get(ts.URL + "/callback?key=smsk_live_x&state=wrong-state")
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != http.StatusBadRequest {
+ t.Errorf("expected 400 for bad state, got %d", resp.StatusCode)
+ }
+}
+
+func TestCallbackServer_MissingKey(t *testing.T) {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("state") != "s" {
+ http.Error(w, "bad state", http.StatusBadRequest)
+ return
+ }
+ key := r.URL.Query().Get("key")
+ if key == "" {
+ http.Error(w, "missing key", http.StatusBadRequest)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ resp, err := http.Get(ts.URL + "/callback?state=s")
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != http.StatusBadRequest {
+ t.Errorf("expected 400 for missing key, got %d", resp.StatusCode)
+ }
+}
+
+func TestRandomState(t *testing.T) {
+ s1 := randomState()
+ s2 := randomState()
+ if s1 == s2 {
+ t.Error("randomState should produce unique values")
+ }
+ if len(s1) != 32 { // 16 bytes = 32 hex chars
+ t.Errorf("expected 32 char hex string, got %d chars", len(s1))
+ }
+}
+
+func TestLogout(t *testing.T) {
+ tmp := t.TempDir()
+ t.Setenv("HOME", tmp)
+
+ // Set up a config with a key.
+ cfg := &config.Config{APIKey: "smsk_live_toremove", APIBase: config.DefaultAPIBase, Output: "human"}
+ os.MkdirAll(filepath.Join(tmp, ".supermodel"), 0o700)
+ cfg.Save()
+
+ if err := Logout(context.Background()); err != nil {
+ t.Fatal(err)
+ }
+
+ cfg, _ = config.Load()
+ if cfg.APIKey != "" {
+ t.Errorf("expected empty key after logout, got %q", cfg.APIKey)
+ }
+}