diff --git a/README.md b/README.md index c87c063..2e3f14d 100644 --- a/README.md +++ b/README.md @@ -211,22 +211,34 @@ tools: By default, blocked patterns run in `arg` mode (each argument is matched independently). Use `match: command` when a regex needs to span multiple args (for example `repo\\s+delete`). -### Forced environment variables +### Environment variables -Variables that are always set and cannot be overridden by the agent: +The `env` key supports three value types — all entries are admin-controlled and cannot be overridden by the agent: ```yaml +credentials: + gog-keyring-password: + source: pass:gog/keyring + db-password: + source: op://vault/db/password + tools: gog: binary: /home/linuxbrew/.linuxbrew/bin/gog env: + # Credential reference: value matches a defined credential name GOG_KEYRING_PASSWORD: gog-keyring-password - forced_env: + + # Template interpolation: {{ name }} substituted inline + DATABASE_URL: "postgres://app:{{ db-password }}@localhost/mydb" + + # Literal value: no credential refs, used as-is GOG_ENABLE_COMMANDS: 'gmail,calendar,drive,tasks,contacts,keep,time' ``` -The agent cannot change `GOG_ENABLE_COMMANDS` — it's stripped from inherited environment and set by -the daemon. +The agent cannot override any `env` entry — values are stripped from inherited environment and set by the daemon. + +> **Deprecated:** `forced_env` is deprecated. Use `env` instead — literal values (without credential refs) work the same way. ### Output redaction diff --git a/internal/config/config.go b/internal/config/config.go index dfd2fb9..cb8ebbc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -171,9 +171,11 @@ type CredentialDef struct { // ToolDef defines a wrapped tool. type ToolDef struct { - Binary string `yaml:"binary"` - Timeout string `yaml:"timeout,omitempty"` - Env map[string]string `yaml:"env,omitempty"` + Binary string `yaml:"binary"` + Timeout string `yaml:"timeout,omitempty"` + Env map[string]string `yaml:"env,omitempty"` // Unified env: credential refs, {{ interpolation }}, or literals + // Deprecated: Use Env instead. ForcedEnv values are always treated as literals. + // Will be removed in a future version. ForcedEnv map[string]string `yaml:"forced_env,omitempty"` Mode string `yaml:"mode,omitempty"` // "blocklist" (default) or "allowlist" BlockedArgs []BlockedArg `yaml:"blocked_args,omitempty"` @@ -282,15 +284,24 @@ func (c *Config) Validate() error { log.Printf("[WARN] tool %q: binary %q not found on disk: %v", toolName, tool.Binary, err) } - for envVar, credName := range tool.Env { + // Build credential names set for validation + credNames := CredentialNamesSet(c.Credentials) + + // Validate env entries (unified: credential refs, {{ interpolation }}, or literals) + for envVar, value := range tool.Env { if !envVarNameRegex.MatchString(envVar) { return fmt.Errorf("tool %q: invalid env var name %q", toolName, envVar) } - if _, ok := c.Credentials[credName]; !ok { - return fmt.Errorf("tool %q: references undefined credential %q", toolName, credName) + // Validate any credential references in the value + if missing := ValidateEnvRefs(value, credNames); len(missing) > 0 { + return fmt.Errorf("tool %q: env %q references undefined credential(s): %v", toolName, envVar, missing) } } + // Validate forced_env (deprecated) - emit warning and validate var names + if len(tool.ForcedEnv) > 0 { + log.Printf("[WARN] tool %q: forced_env is deprecated, use env instead (values without credential refs are treated as literals)", toolName) + } for envVar := range tool.ForcedEnv { if !envVarNameRegex.MatchString(envVar) { return fmt.Errorf("tool %q: invalid forced_env var name %q", toolName, envVar) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index e63c4c4..b4db55b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -806,24 +806,43 @@ func TestValidate_EmptyBinary(t *testing.T) { } func TestValidate_MissingCredentialRef(t *testing.T) { + // With unified env, plain values are literals (allowed). + // Template refs {{ name }} must reference defined credentials. cfg := &Config{ Credentials: map[string]CredentialDef{}, // no credentials defined Tools: map[string]ToolDef{ "gh": { Binary: "/usr/bin/gh", - Env: map[string]string{"GH_TOKEN": "nonexistent-cred"}, + Env: map[string]string{"GH_TOKEN": "{{ nonexistent-cred }}"}, }, }, } err := cfg.Validate() if err == nil { - t.Fatal("Validate() should reject undefined credential ref") + t.Fatal("Validate() should reject undefined credential ref in template") } if !strings.Contains(err.Error(), "undefined credential") { t.Errorf("error = %v, want 'undefined credential'", err) } } +func TestValidate_LiteralEnvValue(t *testing.T) { + // Plain values without {{ refs }} are treated as literals (allowed). + cfg := &Config{ + Credentials: map[string]CredentialDef{}, + Tools: map[string]ToolDef{ + "gh": { + Binary: "/usr/bin/gh", + Env: map[string]string{"PATH": "/usr/bin:/bin"}, + }, + }, + } + err := cfg.Validate() + if err != nil { + t.Errorf("Validate() unexpected error for literal env value: %v", err) + } +} + func TestValidate_ValidCredentialRef(t *testing.T) { cfg := &Config{ Credentials: map[string]CredentialDef{ diff --git a/internal/config/interpolate.go b/internal/config/interpolate.go new file mode 100644 index 0000000..78946ce --- /dev/null +++ b/internal/config/interpolate.go @@ -0,0 +1,138 @@ +// Package config handles loading and parsing the wrappers.yaml configuration. +package config + +import ( + "fmt" + "regexp" +) + +// credentialNameRe matches valid credential names in {{ name }} templates. +// More restrictive than the existing credentialRefRe: only allows valid credential name chars. +var credentialNameRe = regexp.MustCompile(`\{\{\s*([a-zA-Z0-9_-]+)\s*\}\}`) + +// FindCredentialRefs extracts all credential reference names from a template string. +// Returns unique names in order of first appearance. +// Example: "prefix:{{ foo }}:{{ bar }}:{{ foo }}" → ["foo", "bar"] +func FindCredentialRefs(value string) []string { + matches := credentialNameRe.FindAllStringSubmatch(value, -1) + if len(matches) == 0 { + return nil + } + + seen := make(map[string]bool) + refs := make([]string, 0, len(matches)) + for _, m := range matches { + if len(m) > 1 { + name := m[1] + if !seen[name] { + seen[name] = true + refs = append(refs, name) + } + } + } + return refs +} + +// HasCredentialRefs returns true if the value contains any {{ name }} templates. +func HasCredentialRefs(value string) bool { + return credentialNameRe.MatchString(value) +} + +// Interpolate replaces all {{ name }} templates in value with resolved credentials. +// The resolver function is called for each unique credential name. +// Returns error if any credential resolution fails. +func Interpolate(value string, resolver func(name string) (string, error)) (string, error) { + refs := FindCredentialRefs(value) + if len(refs) == 0 { + return value, nil + } + + // Resolve all unique credentials first + resolved := make(map[string]string, len(refs)) + for _, name := range refs { + secret, err := resolver(name) + if err != nil { + return "", fmt.Errorf("credential %q: %w", name, err) + } + resolved[name] = secret + } + + // Replace all occurrences + result := credentialNameRe.ReplaceAllStringFunc(value, func(match string) string { + // Extract name from match (handles whitespace) + m := credentialNameRe.FindStringSubmatch(match) + if len(m) > 1 { + return resolved[m[1]] + } + return match + }) + + return result, nil +} + +// ClassifyEnvValue determines how an env value should be resolved: +// - "credential": value is an exact credential name → fetch entire value +// - "interpolate": value contains {{ refs }} → interpolate +// - "literal": no credential refs → use as-is +func ClassifyEnvValue(value string, credentialNames map[string]struct{}) string { + // Check exact match first + if _, exists := credentialNames[value]; exists { + return "credential" + } + // Check for template refs + if HasCredentialRefs(value) { + return "interpolate" + } + return "literal" +} + +// ValidateEnvRefs validates that all credential references in an env value exist. +// For exact credential matches, checks the value itself. +// For interpolated values, checks all {{ name }} refs. +// Returns list of missing credential names, or nil if all valid. +func ValidateEnvRefs(value string, credentialNames map[string]struct{}) []string { + classification := ClassifyEnvValue(value, credentialNames) + + switch classification { + case "credential": + // Already validated by ClassifyEnvValue returning "credential" + return nil + case "interpolate": + refs := FindCredentialRefs(value) + var missing []string + for _, ref := range refs { + if _, exists := credentialNames[ref]; !exists { + missing = append(missing, ref) + } + } + return missing + default: + return nil + } +} + +// ResolveEnvValue resolves an env value to its final string. +// Handles all three cases: exact credential, interpolated, and literal. +func ResolveEnvValue(value string, credentialNames map[string]struct{}, resolver func(name string) (string, error)) (string, error) { + classification := ClassifyEnvValue(value, credentialNames) + + switch classification { + case "credential": + return resolver(value) + case "interpolate": + return Interpolate(value, resolver) + default: + return value, nil + } +} + +// CredentialNamesSet builds a set of credential names from a credentials map. +// Helper to avoid repeatedly building this set. +func CredentialNamesSet(credentials map[string]CredentialDef) map[string]struct{} { + names := make(map[string]struct{}, len(credentials)) + for name := range credentials { + names[name] = struct{}{} + } + return names +} + diff --git a/internal/config/interpolate_test.go b/internal/config/interpolate_test.go new file mode 100644 index 0000000..657dac3 --- /dev/null +++ b/internal/config/interpolate_test.go @@ -0,0 +1,366 @@ +package config + +import ( + "errors" + "testing" +) + +func TestFindCredentialRefs(t *testing.T) { + tests := []struct { + name string + value string + want []string + }{ + { + name: "no refs", + value: "literal value", + want: nil, + }, + { + name: "single ref", + value: "prefix:{{ foo }}:suffix", + want: []string{"foo"}, + }, + { + name: "multiple refs", + value: "{{ foo }}:{{ bar }}", + want: []string{"foo", "bar"}, + }, + { + name: "duplicate refs", + value: "{{ foo }}:{{ bar }}:{{ foo }}", + want: []string{"foo", "bar"}, + }, + { + name: "whitespace tolerance", + value: "{{ foo }}:{{bar}}:{{ baz }}", + want: []string{"foo", "bar", "baz"}, + }, + { + name: "hyphens in name", + value: "{{ my-secret }}", + want: []string{"my-secret"}, + }, + { + name: "underscores in name", + value: "{{ my_secret }}", + want: []string{"my_secret"}, + }, + { + name: "numbers in name", + value: "{{ secret123 }}", + want: []string{"secret123"}, + }, + { + name: "unclosed brace - literal", + value: "literal {{ text", + want: nil, + }, + { + name: "empty braces - no match", + value: "{{ }}", + want: nil, + }, + { + name: "nested braces - extracts inner", + value: "{{ {{ inner }} }}", + want: []string{"inner"}, // regex finds valid inner match + }, + { + name: "empty string", + value: "", + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FindCredentialRefs(tt.value) + if !slicesEqual(got, tt.want) { + t.Errorf("FindCredentialRefs(%q) = %v, want %v", tt.value, got, tt.want) + } + }) + } +} + +func TestHasCredentialRefs(t *testing.T) { + tests := []struct { + value string + want bool + }{ + {"literal", false}, + {"{{ foo }}", true}, + {"prefix:{{ foo }}:suffix", true}, + {"{{ }}", false}, // empty ref doesn't count + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.value, func(t *testing.T) { + got := HasCredentialRefs(tt.value) + if got != tt.want { + t.Errorf("HasCredentialRefs(%q) = %v, want %v", tt.value, got, tt.want) + } + }) + } +} + +func TestInterpolate(t *testing.T) { + secrets := map[string]string{ + "foo": "secret-foo", + "bar": "secret-bar", + "my-secret": "hyphenated-value", + } + resolver := func(name string) (string, error) { + if v, ok := secrets[name]; ok { + return v, nil + } + return "", errors.New("not found") + } + + tests := []struct { + name string + value string + want string + wantErr bool + }{ + { + name: "no refs - passthrough", + value: "literal value", + want: "literal value", + }, + { + name: "single ref", + value: "prefix:{{ foo }}:suffix", + want: "prefix:secret-foo:suffix", + }, + { + name: "multiple refs", + value: "{{ foo }}:{{ bar }}", + want: "secret-foo:secret-bar", + }, + { + name: "duplicate refs", + value: "{{ foo }}:{{ foo }}", + want: "secret-foo:secret-foo", + }, + { + name: "whitespace tolerance", + value: "{{ foo }}", + want: "secret-foo", + }, + { + name: "hyphenated name", + value: "{{ my-secret }}", + want: "hyphenated-value", + }, + { + name: "missing credential", + value: "{{ missing }}", + wantErr: true, + }, + { + name: "empty string", + value: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Interpolate(tt.value, resolver) + if tt.wantErr { + if err == nil { + t.Errorf("Interpolate(%q) error = nil, want error", tt.value) + } + return + } + if err != nil { + t.Errorf("Interpolate(%q) error = %v, want nil", tt.value, err) + return + } + if got != tt.want { + t.Errorf("Interpolate(%q) = %q, want %q", tt.value, got, tt.want) + } + }) + } +} + +func TestClassifyEnvValue(t *testing.T) { + creds := map[string]struct{}{ + "github-token": {}, + "db-password": {}, + } + + tests := []struct { + value string + want string + }{ + {"github-token", "credential"}, + {"db-password", "credential"}, + {"/usr/bin:/bin", "literal"}, + {"literal-value", "literal"}, + {"postgres://user:{{ db-password }}@localhost/db", "interpolate"}, + {"{{ github-token }}", "interpolate"}, // template syntax, not exact match + {"unknown-cred", "literal"}, // doesn't match any credential name + } + + for _, tt := range tests { + t.Run(tt.value, func(t *testing.T) { + got := ClassifyEnvValue(tt.value, creds) + if got != tt.want { + t.Errorf("ClassifyEnvValue(%q) = %q, want %q", tt.value, got, tt.want) + } + }) + } +} + +func TestValidateEnvRefs(t *testing.T) { + creds := map[string]struct{}{ + "github-token": {}, + "db-password": {}, + } + + tests := []struct { + name string + value string + missing []string + }{ + { + name: "exact credential", + value: "github-token", + missing: nil, + }, + { + name: "valid interpolation", + value: "prefix:{{ db-password }}:suffix", + missing: nil, + }, + { + name: "literal", + value: "/usr/bin", + missing: nil, + }, + { + name: "missing ref", + value: "{{ missing }}", + missing: []string{"missing"}, + }, + { + name: "multiple missing", + value: "{{ foo }}:{{ bar }}", + missing: []string{"foo", "bar"}, + }, + { + name: "one valid one missing", + value: "{{ github-token }}:{{ missing }}", + missing: []string{"missing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ValidateEnvRefs(tt.value, creds) + if !slicesEqual(got, tt.missing) { + t.Errorf("ValidateEnvRefs(%q) = %v, want %v", tt.value, got, tt.missing) + } + }) + } +} + +func TestResolveEnvValue(t *testing.T) { + creds := map[string]struct{}{ + "github-token": {}, + "db-password": {}, + } + secrets := map[string]string{ + "github-token": "ghp_xxx", + "db-password": "hunter2", + } + resolver := func(name string) (string, error) { + if v, ok := secrets[name]; ok { + return v, nil + } + return "", errors.New("not found") + } + + tests := []struct { + name string + value string + want string + wantErr bool + }{ + { + name: "exact credential", + value: "github-token", + want: "ghp_xxx", + }, + { + name: "interpolation", + value: "postgres://user:{{ db-password }}@localhost/db", + want: "postgres://user:hunter2@localhost/db", + }, + { + name: "literal", + value: "/usr/bin:/bin", + want: "/usr/bin:/bin", + }, + { + name: "missing credential", + value: "missing-cred", + want: "missing-cred", // literal, not error + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ResolveEnvValue(tt.value, creds, resolver) + if tt.wantErr { + if err == nil { + t.Errorf("ResolveEnvValue(%q) error = nil, want error", tt.value) + } + return + } + if err != nil { + t.Errorf("ResolveEnvValue(%q) error = %v, want nil", tt.value, err) + return + } + if got != tt.want { + t.Errorf("ResolveEnvValue(%q) = %q, want %q", tt.value, got, tt.want) + } + }) + } +} + +func TestCredentialNamesSet(t *testing.T) { + creds := map[string]CredentialDef{ + "foo": {Source: "pass:foo"}, + "bar": {Source: "env:BAR"}, + } + + set := CredentialNamesSet(creds) + + if _, ok := set["foo"]; !ok { + t.Error("expected 'foo' in set") + } + if _, ok := set["bar"]; !ok { + t.Error("expected 'bar' in set") + } + if _, ok := set["baz"]; ok { + t.Error("unexpected 'baz' in set") + } +} + +// slicesEqual compares two string slices for equality. +func slicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/daemon/executor.go b/internal/daemon/executor.go index f668d3c..ef926fa 100644 --- a/internal/daemon/executor.go +++ b/internal/daemon/executor.go @@ -259,11 +259,12 @@ func (e *ToolExecutor) buildEnvironment() ([]string, error) { envMap["TERM"] = "dumb" } - // Add credentials from tool.Env - for envVar, credName := range e.tool.Env { - credDef, ok := e.cfg.Credentials[credName] + // Build credential names set and resolver for unified env handling + credNames := config.CredentialNamesSet(e.cfg.Credentials) + credResolver := func(name string) (string, error) { + credDef, ok := e.cfg.Credentials[name] if !ok { - return nil, fmt.Errorf("missing credential config: %s", credName) + return "", fmt.Errorf("undefined credential: %s", name) } value, err := credentials.Fetch( credDef.Source, @@ -272,16 +273,28 @@ func (e *ToolExecutor) buildEnvironment() ([]string, error) { credentials.WithBWBinary(e.cfg.GetBWBinary()), ) if err != nil { - return nil, fmt.Errorf("fetch credential %s: %w", credName, err) + return "", err } if value == "" { - return nil, fmt.Errorf("empty credential: %s", credName) + return "", fmt.Errorf("credential %s returned empty value", name) } - envMap[envVar] = value + return value, nil } - // Add forced_env (these cannot be overridden) + // All env entries are admin-controlled (forced), cannot be overridden by request forcedKeys := make(map[string]bool) + + // Process unified env: credential refs, {{ interpolation }}, or literals + for envVar, value := range e.tool.Env { + resolved, err := config.ResolveEnvValue(value, credNames, credResolver) + if err != nil { + return nil, fmt.Errorf("env %s: %w", envVar, err) + } + envMap[envVar] = resolved + forcedKeys[envVar] = true + } + + // Process deprecated forced_env (always treated as literals) for k, v := range e.tool.ForcedEnv { envMap[k] = v forcedKeys[k] = true diff --git a/internal/daemon/executor_env_test.go b/internal/daemon/executor_env_test.go index dd10b48..c0b0317 100644 --- a/internal/daemon/executor_env_test.go +++ b/internal/daemon/executor_env_test.go @@ -412,3 +412,85 @@ func TestBuildEnvironment_UseProxyRequiresToken(t *testing.T) { t.Fatalf("error = %v, want missing proxy auth token", err) } } + +func TestBuildEnvironment_UnifiedEnvLiteralValue(t *testing.T) { + // Env value without credential refs is treated as literal + executor := &ToolExecutor{ + req: &protocol.ProxyRequest{}, + tool: &config.ToolDef{ + Env: map[string]string{ + "MY_LITERAL": "/usr/bin:/bin", + }, + }, + cfg: &config.Config{}, + } + + env, err := executor.buildEnvironment() + if err != nil { + t.Fatalf("buildEnvironment() error: %v", err) + } + + val, found := envContains(env, "MY_LITERAL") + if !found { + t.Fatal("buildEnvironment() missing MY_LITERAL") + } + if val != "/usr/bin:/bin" { + t.Errorf("MY_LITERAL = %q, want %q", val, "/usr/bin:/bin") + } +} + +func TestBuildEnvironment_UnifiedEnvCannotBeOverridden(t *testing.T) { + // All unified env entries are forced (cannot be overridden by request) + executor := &ToolExecutor{ + req: &protocol.ProxyRequest{ + Env: map[string]string{ + "MY_VAR": "attacker-value", + }, + }, + tool: &config.ToolDef{ + Env: map[string]string{ + "MY_VAR": "admin-value", + }, + }, + cfg: &config.Config{}, + } + + env, err := executor.buildEnvironment() + if err != nil { + t.Fatalf("buildEnvironment() error: %v", err) + } + + val, found := envContains(env, "MY_VAR") + if !found { + t.Fatal("buildEnvironment() missing MY_VAR") + } + if val != "admin-value" { + t.Errorf("MY_VAR = %q, want %q (unified env must win over request)", val, "admin-value") + } +} + +func TestBuildEnvironment_UnifiedEnvAllowsDangerousVars(t *testing.T) { + // Admin-controlled env can set dangerous vars (like forced_env) + executor := &ToolExecutor{ + req: &protocol.ProxyRequest{}, + tool: &config.ToolDef{ + Env: map[string]string{ + "LD_PRELOAD": "/lib/admin-controlled.so", + }, + }, + cfg: &config.Config{}, + } + + env, err := executor.buildEnvironment() + if err != nil { + t.Fatalf("buildEnvironment() error: %v", err) + } + + val, found := envContains(env, "LD_PRELOAD") + if !found { + t.Fatal("buildEnvironment() should allow LD_PRELOAD via unified env (admin-controlled)") + } + if val != "/lib/admin-controlled.so" { + t.Errorf("LD_PRELOAD = %q, want %q", val, "/lib/admin-controlled.so") + } +}