From 6c77e26f02bed5586784f824ecd495d64d7d3a70 Mon Sep 17 00:00:00 2001 From: Shengming Liang Date: Tue, 23 Jun 2026 17:59:08 -0700 Subject: [PATCH] feat(cli): prompt for updates on startup --- cmd/dedalus/main.go | 11 +- pkg/cmd/startup_update.go | 338 +++++++++++++++++++++++++++++++++ pkg/cmd/startup_update_test.go | 204 ++++++++++++++++++++ 3 files changed, 552 insertions(+), 1 deletion(-) create mode 100644 pkg/cmd/startup_update.go create mode 100644 pkg/cmd/startup_update_test.go diff --git a/cmd/dedalus/main.go b/cmd/dedalus/main.go index 7551ab0..2ccba35 100644 --- a/cmd/dedalus/main.go +++ b/cmd/dedalus/main.go @@ -18,6 +18,7 @@ import ( func main() { app := cmd.Command + ctx := context.Background() if slices.Contains(os.Args, "__complete") { prepareForAutocomplete(app) @@ -30,7 +31,15 @@ func main() { } } - if err := app.Run(context.Background(), os.Args); err != nil { + updated, err := cmd.MaybeRunStartupUpdate(ctx, os.Args, os.Stdin, os.Stdout, os.Stderr) + if err == nil && updated { + return + } + + if err == nil { + err = app.Run(ctx, os.Args) + } + if err != nil { exitCode := 1 // Check if error has a custom exit code diff --git a/pkg/cmd/startup_update.go b/pkg/cmd/startup_update.go new file mode 100644 index 0000000..0481264 --- /dev/null +++ b/pkg/cmd/startup_update.go @@ -0,0 +1,338 @@ +// Copyright (c) 2026 Dedalus Labs, Inc. All rights reserved. + +package cmd + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/charmbracelet/x/term" +) + +const ( + dedalusHomeEnv = "DEDALUS_HOME" + disableUpdateCheckEnv = "DEDALUS_NO_UPDATE_CHECK" + startupUpdateCacheFile = "version.json" + startupUpdateCheckInterval = 20 * time.Hour + startupUpdateCheckTimeout = 2 * time.Second +) + +type startupVersionInfo struct { + LatestVersion string `json:"latest_version"` + LastCheckedAt time.Time `json:"last_checked_at"` +} + +type startupUpdatePrompt struct { + stdin io.Reader + stderr io.Writer + getenv func(string) string + homeDir func() (string, error) + now func() time.Time + isInteractive func() bool + latestVersion func(context.Context) (string, error) + runUpdate func(context.Context) error + cachePath string +} + +// MaybeRunStartupUpdate prompts interactive users to update before command execution. +func MaybeRunStartupUpdate(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) (bool, error) { + p := newStartupUpdatePrompt(stdin, stdout, stderr) + return p.run(ctx, args) +} + +func newStartupUpdatePrompt(stdin io.Reader, stdout, stderr io.Writer) *startupUpdatePrompt { + p := &startupUpdatePrompt{ + stdin: stdin, + stderr: stderr, + getenv: os.Getenv, + homeDir: os.UserHomeDir, + now: time.Now, + } + p.isInteractive = func() bool { + return terminalReader(stdin) && terminalWriter(stdout) && terminalWriter(stderr) + } + p.latestVersion = func(ctx context.Context) (string, error) { + updater := newUpdater(io.Discard, io.Discard) + return updater.latestVersion(ctx) + } + p.runUpdate = func(ctx context.Context) error { + return newUpdater(stdout, stderr).update(ctx, updateOptions{}) + } + return p +} + +func (p *startupUpdatePrompt) run(ctx context.Context, args []string) (bool, error) { + if p.shouldSkip(args) { + return false, nil + } + + latest, ok := p.upgradeVersion(ctx) + if !ok { + return false, nil + } + + return p.prompt(ctx, latest) +} + +func (p *startupUpdatePrompt) shouldSkip(args []string) bool { + if !p.isInteractive() { + return true + } + if envTruthy(p.getenv("CI")) || envTruthy(p.getenv(disableUpdateCheckEnv)) { + return true + } + if hasHelpOrVersionArg(args) { + return true + } + + switch rootCommandArg(args) { + case "update", "__complete", "@completion", "@manpages", "help": + return true + default: + return false + } +} + +func (p *startupUpdatePrompt) upgradeVersion(ctx context.Context) (string, bool) { + cachePath, err := p.versionCachePath() + if err != nil { + return "", false + } + + info, cacheOK := readStartupVersionInfo(cachePath) + if cacheOK && p.now().Sub(info.LastCheckedAt) < startupUpdateCheckInterval { + return newerStartupVersion(info.LatestVersion) + } + + checkCtx, cancel := context.WithTimeout(ctx, startupUpdateCheckTimeout) + defer cancel() + + latest, err := p.latestVersion(checkCtx) + if err == nil { + if err := writeStartupVersionInfo(cachePath, startupVersionInfo{ + LatestVersion: latest, + LastCheckedAt: p.now(), + }); err != nil { + return newerStartupVersion(latest) + } + return newerStartupVersion(latest) + } + + if cacheOK { + return newerStartupVersion(info.LatestVersion) + } + return "", false +} + +func (p *startupUpdatePrompt) prompt(ctx context.Context, latest string) (bool, error) { + fmt.Fprintf(p.stderr, "A new Dedalus CLI is available.\n\n") + fmt.Fprintf(p.stderr, "Current: %s\n", versionTag(Version)) + fmt.Fprintf(p.stderr, "Latest: %s\n\n", versionTag(latest)) + fmt.Fprint(p.stderr, "Update now?\n") + fmt.Fprint(p.stderr, "> Yes\n") + fmt.Fprint(p.stderr, " Not now\n\n") + fmt.Fprint(p.stderr, "Press Enter to update, or type n then Enter to skip: ") + + answer, err := bufio.NewReader(p.stdin).ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + fmt.Fprintln(p.stderr) + return false, nil + } + if errors.Is(err, io.EOF) && strings.TrimSpace(answer) == "" { + fmt.Fprintln(p.stderr) + return false, nil + } + if shouldRunStartupUpdate(answer) { + fmt.Fprintln(p.stderr) + return true, p.runUpdate(ctx) + } + + fmt.Fprintln(p.stderr) + return false, nil +} + +func (p *startupUpdatePrompt) versionCachePath() (string, error) { + if p.cachePath != "" { + return p.cachePath, nil + } + home := strings.TrimSpace(p.getenv(dedalusHomeEnv)) + if home == "" { + userHome, err := p.homeDir() + if err != nil { + return "", err + } + home = filepath.Join(userHome, ".dedalus") + } + return filepath.Join(home, startupUpdateCacheFile), nil +} + +func readStartupVersionInfo(path string) (startupVersionInfo, bool) { + data, err := os.ReadFile(path) + if err != nil { + return startupVersionInfo{}, false + } + var info startupVersionInfo + if err := json.Unmarshal(data, &info); err != nil { + return startupVersionInfo{}, false + } + if strings.TrimSpace(info.LatestVersion) == "" || info.LastCheckedAt.IsZero() { + return startupVersionInfo{}, false + } + return info, true +} + +func writeStartupVersionInfo(path string, info startupVersionInfo) error { + data, err := json.Marshal(info) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + return os.WriteFile(path, append(data, '\n'), 0644) +} + +func newerStartupVersion(latest string) (string, bool) { + if isNewerVersion(latest, Version) { + return versionTag(latest), true + } + return "", false +} + +func shouldRunStartupUpdate(answer string) bool { + answer = strings.ToLower(strings.TrimSpace(answer)) + return answer == "" || answer == "y" || answer == "yes" +} + +func rootCommandArg(args []string) string { + for i := 1; i < len(args); i++ { + arg := args[i] + if arg == "--" { + return "" + } + if strings.HasPrefix(arg, "-") { + if rootFlagTakesValue(arg) && !strings.Contains(arg, "=") { + i++ + } + continue + } + return arg + } + return "" +} + +func rootFlagTakesValue(arg string) bool { + name := strings.TrimLeft(strings.SplitN(arg, "=", 2)[0], "-") + switch name { + case "api-key", "base-url", "dedalus-org-id", "format", "format-error", "transform", "transform-error", "x-api-key": + return true + default: + return false + } +} + +func hasHelpOrVersionArg(args []string) bool { + for _, arg := range args[1:] { + switch arg { + case "-h", "--help", "-v", "--version": + return true + } + } + return false +} + +func envTruthy(value string) bool { + switch strings.ToLower(strings.TrimSpace(value)) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +func terminalReader(r io.Reader) bool { + f, ok := r.(*os.File) + return ok && term.IsTerminal(f.Fd()) +} + +func terminalWriter(w io.Writer) bool { + f, ok := w.(*os.File) + return ok && term.IsTerminal(f.Fd()) +} + +type parsedVersion struct { + major int + minor int + patch int + prerelease string +} + +func isNewerVersion(candidate, current string) bool { + candidateVersion, okCandidate := parseVersion(candidate) + currentVersion, okCurrent := parseVersion(current) + if !okCandidate || !okCurrent { + return !sameVersion(candidate, current) + } + + if candidateVersion.major != currentVersion.major { + return candidateVersion.major > currentVersion.major + } + if candidateVersion.minor != currentVersion.minor { + return candidateVersion.minor > currentVersion.minor + } + if candidateVersion.patch != currentVersion.patch { + return candidateVersion.patch > currentVersion.patch + } + if candidateVersion.prerelease == currentVersion.prerelease { + return false + } + if candidateVersion.prerelease == "" { + return true + } + if currentVersion.prerelease == "" { + return false + } + return candidateVersion.prerelease > currentVersion.prerelease +} + +func parseVersion(version string) (parsedVersion, bool) { + version = strings.TrimPrefix(strings.TrimSpace(version), "v") + if version == "" { + return parsedVersion{}, false + } + + mainVersion, prerelease, _ := strings.Cut(version, "-") + parts := strings.Split(mainVersion, ".") + if len(parts) != 3 { + return parsedVersion{}, false + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return parsedVersion{}, false + } + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return parsedVersion{}, false + } + patch, err := strconv.Atoi(parts[2]) + if err != nil { + return parsedVersion{}, false + } + + return parsedVersion{ + major: major, + minor: minor, + patch: patch, + prerelease: prerelease, + }, true +} diff --git a/pkg/cmd/startup_update_test.go b/pkg/cmd/startup_update_test.go new file mode 100644 index 0000000..bece302 --- /dev/null +++ b/pkg/cmd/startup_update_test.go @@ -0,0 +1,204 @@ +// Copyright (c) 2026 Dedalus Labs, Inc. All rights reserved. + +package cmd + +import ( + "bytes" + "context" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestStartupUpdatePromptUsesFreshCacheAndRunsUpdate(t *testing.T) { + now := time.Date(2026, 6, 24, 12, 0, 0, 0, time.UTC) + cachePath := filepath.Join(t.TempDir(), startupUpdateCacheFile) + if err := writeStartupVersionInfo(cachePath, startupVersionInfo{ + LatestVersion: "v9.9.9", + LastCheckedAt: now, + }); err != nil { + t.Fatalf("writeStartupVersionInfo() returned unexpected error: %v", err) + } + + var stderr bytes.Buffer + ranUpdate := false + prompt := newTestStartupUpdatePrompt(t, cachePath, "\n", &stderr) + prompt.now = func() time.Time { return now } + prompt.latestVersion = func(context.Context) (string, error) { + t.Fatal("fresh cache should not fetch latest version") + return "", nil + } + prompt.runUpdate = func(context.Context) error { + ranUpdate = true + return nil + } + + updated, err := prompt.run(context.Background(), []string{"dedalus", "machines", "list"}) + if err != nil { + t.Fatalf("startup update prompt returned unexpected error: %v", err) + } + if !updated { + t.Fatal("startup update prompt did not report that update ran") + } + if !ranUpdate { + t.Fatal("startup update prompt did not run updater") + } + + got := stderr.String() + for _, want := range []string{ + "A new Dedalus CLI is available.", + "Current: v" + Version, + "Latest: v9.9.9", + "Update now?\n> Yes\n Not now", + } { + if !strings.Contains(got, want) { + t.Errorf("prompt output = %q, want substring %q", got, want) + } + } +} + +func TestStartupUpdatePromptDeclinesUpdate(t *testing.T) { + now := time.Date(2026, 6, 24, 12, 0, 0, 0, time.UTC) + cachePath := filepath.Join(t.TempDir(), startupUpdateCacheFile) + if err := writeStartupVersionInfo(cachePath, startupVersionInfo{ + LatestVersion: "v9.9.9", + LastCheckedAt: now, + }); err != nil { + t.Fatalf("writeStartupVersionInfo() returned unexpected error: %v", err) + } + + var stderr bytes.Buffer + prompt := newTestStartupUpdatePrompt(t, cachePath, "n\n", &stderr) + prompt.now = func() time.Time { return now } + prompt.runUpdate = func(context.Context) error { + t.Fatal("declining the prompt should not run updater") + return nil + } + + updated, err := prompt.run(context.Background(), []string{"dedalus", "machines", "list"}) + if err != nil { + t.Fatalf("startup update prompt returned unexpected error: %v", err) + } + if updated { + t.Fatal("startup update prompt reported an update after user declined") + } + if got, want := stderr.String(), "Update now?\n> Yes\n Not now"; !strings.Contains(got, want) { + t.Errorf("prompt output = %q, want substring %q", got, want) + } +} + +func TestStartupUpdatePromptSkipsNonInteractiveLaunch(t *testing.T) { + cachePath := filepath.Join(t.TempDir(), startupUpdateCacheFile) + if err := writeStartupVersionInfo(cachePath, startupVersionInfo{ + LatestVersion: "v9.9.9", + LastCheckedAt: time.Now(), + }); err != nil { + t.Fatalf("writeStartupVersionInfo() returned unexpected error: %v", err) + } + + var stderr bytes.Buffer + prompt := newTestStartupUpdatePrompt(t, cachePath, "\n", &stderr) + prompt.isInteractive = func() bool { return false } + prompt.runUpdate = func(context.Context) error { + t.Fatal("noninteractive launch should not run updater") + return nil + } + + updated, err := prompt.run(context.Background(), []string{"dedalus", "machines", "list"}) + if err != nil { + t.Fatalf("startup update prompt returned unexpected error: %v", err) + } + if updated { + t.Fatal("noninteractive launch reported an update") + } + if got := stderr.String(); got != "" { + t.Errorf("noninteractive prompt output = %q, want empty", got) + } +} + +func TestStartupUpdatePromptRefreshesStaleCache(t *testing.T) { + now := time.Date(2026, 6, 24, 12, 0, 0, 0, time.UTC) + cachePath := filepath.Join(t.TempDir(), startupUpdateCacheFile) + if err := writeStartupVersionInfo(cachePath, startupVersionInfo{ + LatestVersion: "v" + Version, + LastCheckedAt: now.Add(-startupUpdateCheckInterval - time.Minute), + }); err != nil { + t.Fatalf("writeStartupVersionInfo() returned unexpected error: %v", err) + } + + var stderr bytes.Buffer + prompt := newTestStartupUpdatePrompt(t, cachePath, "n\n", &stderr) + prompt.now = func() time.Time { return now } + prompt.latestVersion = func(context.Context) (string, error) { + return "v9.9.9", nil + } + + updated, err := prompt.run(context.Background(), []string{"dedalus", "machines", "list"}) + if err != nil { + t.Fatalf("startup update prompt returned unexpected error: %v", err) + } + if updated { + t.Fatal("declining refreshed prompt reported an update") + } + + info, ok := readStartupVersionInfo(cachePath) + if !ok { + t.Fatal("expected refreshed startup version cache") + } + if info.LatestVersion != "v9.9.9" { + t.Errorf("refreshed latest version = %q, want v9.9.9", info.LatestVersion) + } + if !info.LastCheckedAt.Equal(now) { + t.Errorf("refreshed last_checked_at = %s, want %s", info.LastCheckedAt, now) + } +} + +func TestRootCommandArgSkipsRootUpdateOnly(t *testing.T) { + t.Parallel() + + if got, want := rootCommandArg([]string{"dedalus", "--format", "json", "update"}), "update"; got != want { + t.Errorf("rootCommandArg(root update) = %q, want %q", got, want) + } + if got, want := rootCommandArg([]string{"dedalus", "machines", "update"}), "machines"; got != want { + t.Errorf("rootCommandArg(nested update) = %q, want %q", got, want) + } +} + +func TestIsNewerVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + candidate string + current string + want bool + }{ + {candidate: "v0.4.1", current: "0.4.0", want: true}, + {candidate: "v0.4.0", current: "0.4.0", want: false}, + {candidate: "v0.3.9", current: "0.4.0", want: false}, + {candidate: "v1.0.0-beta", current: "v1.0.0", want: false}, + {candidate: "v1.0.0", current: "v1.0.0-beta", want: true}, + } + + for _, tt := range tests { + if got := isNewerVersion(tt.candidate, tt.current); got != tt.want { + t.Errorf("isNewerVersion(%q, %q) = %t, want %t", tt.candidate, tt.current, got, tt.want) + } + } +} + +func newTestStartupUpdatePrompt(t *testing.T, cachePath string, stdin string, stderr *bytes.Buffer) *startupUpdatePrompt { + t.Helper() + + return &startupUpdatePrompt{ + stdin: strings.NewReader(stdin), + stderr: stderr, + getenv: func(string) string { return "" }, + homeDir: func() (string, error) { return t.TempDir(), nil }, + now: time.Now, + isInteractive: func() bool { return true }, + latestVersion: func(context.Context) (string, error) { return "v" + Version, nil }, + runUpdate: func(context.Context) error { return nil }, + cachePath: cachePath, + } +}