diff --git a/internal/cas/benchmark_test.go b/internal/cas/benchmark_test.go index 7d512f527d..ba11d04202 100644 --- a/internal/cas/benchmark_test.go +++ b/internal/cas/benchmark_test.go @@ -21,6 +21,9 @@ func BenchmarkClone(b *testing.B) { l := logger.CreateLogger() + v, err := cas.OSVenv() + require.NoError(b, err) + b.Run("fresh clone", func(b *testing.B) { tempDir := b.TempDir() @@ -37,10 +40,8 @@ func BenchmarkClone(b *testing.B) { b.StartTimer() - require.NoError(b, c.Clone(b.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - Depth: -1, - }, repoURL)) + require.NoError(b, c.Clone(b.Context(), l, v, repoURL, cas.WithDir(targetPath), + cas.WithDepth(-1))) } }) @@ -52,10 +53,8 @@ func BenchmarkClone(b *testing.B) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(b, err) - require.NoError(b, c.Clone(b.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "initial"), - Depth: -1, - }, repoURL)) + require.NoError(b, c.Clone(b.Context(), l, v, repoURL, cas.WithDir(filepath.Join(tempDir, "initial")), + cas.WithDepth(-1))) b.ResetTimer() @@ -69,10 +68,8 @@ func BenchmarkClone(b *testing.B) { b.StartTimer() - require.NoError(b, c.Clone(b.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - Depth: -1, - }, repoURL)) + require.NoError(b, c.Clone(b.Context(), l, v, repoURL, cas.WithDir(targetPath), + cas.WithDepth(-1))) } }) } @@ -87,6 +84,9 @@ func BenchmarkContent(b *testing.B) { l := logger.CreateLogger() + v, err := cas.OSVenv() + require.NoError(b, err) + b.Run("store", func(b *testing.B) { for i := 0; b.Loop(); i++ { b.StopTimer() @@ -95,7 +95,7 @@ func BenchmarkContent(b *testing.B) { b.StartTimer() - require.NoError(b, content.Store(l, hash, testData)) + require.NoError(b, content.Store(l, v, hash, testData)) } }) @@ -121,7 +121,7 @@ func BenchmarkContent(b *testing.B) { mu.Unlock() - if err := content.Store(l, hash, testData); err != nil { + if err := content.Store(l, v, hash, testData); err != nil { b.Fatal(err) } diff --git 
a/internal/cas/cas.go b/internal/cas/cas.go index 0b9dd6ec24..e3ea8d9869 100644 --- a/internal/cas/cas.go +++ b/internal/cas/cas.go @@ -17,9 +17,7 @@ import ( "github.com/gruntwork-io/terragrunt/internal/errors" "github.com/gruntwork-io/terragrunt/internal/git" - "github.com/gruntwork-io/terragrunt/internal/telemetry" "github.com/gruntwork-io/terragrunt/internal/util" - "github.com/gruntwork-io/terragrunt/internal/vexec" "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/pkg/log" ) @@ -55,14 +53,50 @@ type CloneOptions struct { Mutable bool } +// CloneOption customizes a single [CAS.Clone] invocation. Each option +// mutates the per-call CloneOptions before the clone runs. Options are +// applied in order, so a later option overwrites fields set by an +// earlier one. +type CloneOption func(*CloneOptions) + +// WithDir sets the target directory for the clone. Empty means use the +// repository name. +func WithDir(dir string) CloneOption { + return func(o *CloneOptions) { o.Dir = dir } +} + +// WithBranch sets the branch to clone. Empty means HEAD. +func WithBranch(branch string) CloneOption { + return func(o *CloneOptions) { o.Branch = branch } +} + +// WithIncludedGitFiles preserves the named files from the .git +// directory in the materialized clone. +func WithIncludedGitFiles(files []string) CloneOption { + return func(o *CloneOptions) { o.IncludedGitFiles = files } +} + +// WithDepth sets the `git clone --depth` value for this clone. Positive +// values request a shallow clone; -1 means full history (Terragrunt +// omits --depth; git rejects --depth 0). Zero falls back to the +// CAS-wide default configured via [WithCloneDepth]. +func WithDepth(depth int) CloneOption { + return func(o *CloneOptions) { o.Depth = depth } +} + +// WithMutable copies blobs into the target directory instead of +// hardlinking from the CAS store, so the destination tree is safe to +// mutate without corrupting the shared store. 
+func WithMutable(mutable bool) CloneOption { + return func(o *CloneOptions) { o.Mutable = mutable } +} + // CAS clones a git repository using content-addressable storage. type CAS struct { - fs vfs.FS blobStore *Store treeStore *Store synthStore *Store gitStore *GitStore - git *git.GitRunner storePath string cloneDepth int } @@ -77,24 +111,15 @@ func WithStorePath(path string) Option { } } -// WithCloneDepth sets git clone --depth for CAS (positive shallow clone; negative, -// e.g. -1, means full history with no --depth). Terragrunt validates user-supplied -// values with ValidateCASCloneDepth (zero is invalid for git). Omit this option to -// keep cloneDepth unset so per-operation CloneOptions.Depth can fall back to DefaultCASCloneDepth. +// WithCloneDepth sets git clone --depth for CAS. Positive values request a +// shallow clone; -1 means full history (no --depth). Zero is invalid for git +// and is rejected by [ValidateCASCloneDepth] before reaching this option. func WithCloneDepth(depth int) Option { return func(c *CAS) { c.cloneDepth = depth } } -// WithFS specifies the filesystem for file operations. -// If not set, defaults to the real OS filesystem. -func WithFS(fs vfs.FS) Option { - return func(c *CAS) { - c.fs = fs - } -} - // New creates a new CAS instance with the given options. func New(opts ...Option) (*CAS, error) { c := &CAS{} @@ -103,17 +128,6 @@ func New(opts ...Option) (*CAS, error) { opt(c) } - if c.fs == nil { - c.fs = vfs.NewOSFS() - } - - // CAS shells out to git, which only sees the real disk. Validate - // here so a non-OS backing fails at the constructor instead of - // from a deeper store-init step. 
- if !vfs.IsOSFS(c.fs) { - return nil, ErrGitStoreFSNotOS - } - if c.storePath == "" { cacheDir, err := util.EnsureCacheDir() if err != nil { @@ -123,42 +137,14 @@ func New(opts ...Option) (*CAS, error) { c.storePath = filepath.Join(cacheDir, "cas", "store") } - if err := c.fs.MkdirAll(c.storePath, DefaultDirPerms); err != nil { - return nil, fmt.Errorf("failed to create CAS store path: %w", err) - } - - c.blobStore = NewStore(filepath.Join(c.storePath, "blobs")).WithFS(c.fs) - c.treeStore = NewStore(filepath.Join(c.storePath, "trees")).WithFS(c.fs) - c.synthStore = NewStore(filepath.Join(c.storePath, "synth", "trees")).WithFS(c.fs) - - for _, s := range []*Store{c.blobStore, c.treeStore, c.synthStore} { - if err := c.fs.MkdirAll(s.Path(), DefaultDirPerms); err != nil { - return nil, fmt.Errorf("failed to create CAS store subdirectory %s: %w", s.Path(), err) - } - } - - g, err := git.NewGitRunner(vexec.NewOSExec()) - if err != nil { - return nil, err - } - - c.git = g - - gs, err := NewGitStore(c.fs, g, filepath.Join(c.storePath, "git")) - if err != nil { - return nil, err - } - - c.gitStore = gs + c.blobStore = NewStore(filepath.Join(c.storePath, "blobs")) + c.treeStore = NewStore(filepath.Join(c.storePath, "trees")) + c.synthStore = NewStore(filepath.Join(c.storePath, "synth", "trees")) + c.gitStore = NewGitStore(filepath.Join(c.storePath, "git")) return c, nil } -// FS returns the configured filesystem. -func (c *CAS) FS() vfs.FS { - return c.fs -} - // BlobStore returns the store for blob content. func (c *CAS) BlobStore() *Store { return c.blobStore } @@ -168,62 +154,128 @@ func (c *CAS) TreeStore() *Store { return c.treeStore } // SynthStore returns the store for synthetic tree content. 
func (c *CAS) SynthStore() *Store { return c.synthStore } -// Clone performs the clone operation -// -// TODO: Make options optional -func (c *CAS) Clone(ctx context.Context, l log.Logger, opts *CloneOptions, url string) error { - if err := c.ensureCloneStores(); err != nil { - return err +// StorePath returns the root directory containing every CAS store. +func (c *CAS) StorePath() string { return c.storePath } + +// ensureStorePaths creates the store directory hierarchy on v.FS. Callers +// invoke this from any top-level entry point that may write to a store, so +// the directories appear lazily on first use rather than at construction. +func (c *CAS) ensureStorePaths(v Venv) error { + if !vfs.IsOSFS(v.FS) { + return ErrGitStoreFSNotOS } - return telemetry.TelemeterFromContext(ctx).Collect(ctx, "cas_clone", map[string]any{ - "url": url, - "branch": opts.Branch, - }, func(childCtx context.Context) error { - ref, err := c.resolveReference(childCtx, url, opts.Branch) - if err != nil { - return err + if err := v.FS.MkdirAll(c.storePath, DefaultDirPerms); err != nil { + return fmt.Errorf("create CAS store path: %w", err) + } + + for _, s := range []*Store{c.blobStore, c.treeStore, c.synthStore} { + if err := v.FS.MkdirAll(s.Path(), DefaultDirPerms); err != nil { + return fmt.Errorf("create CAS store subdirectory %s: %w", s.Path(), err) } + } - targetDir := c.prepareTargetDirectory(opts.Dir, url) + return nil +} - canonicalHash, err := c.populateTreeFromRef(childCtx, l, opts, ref) - if err != nil { - return err - } +// GitResolver is a [SourceResolver] for git URLs. +// +// Branch travels as a field rather than as a URL query parameter so +// SCP-form URLs (`git@host:path`) reach git intact: net/url.Parse +// rejects SCP form, so any encoding scheme that round-trips through +// it silently loses the branch. +type GitResolver struct { + // Venv supplies the git runner Probe shells out through. Required. 
+ Venv Venv + // Store enables an offline fast path: a full-length SHA in + // [GitResolver.Branch] is checked against the local store before + // reaching ls-remote. When nil, every Probe runs ls-remote. + Store *GitStore + // Branch is the ref to query. Empty means HEAD. + Branch string +} - treeContent := NewContent(c.treeStore) +// Scheme returns "git". +func (r *GitResolver) Scheme() string { return "git" } - treeData, err := treeContent.Read(canonicalHash) - if err != nil { - return err +// Probe returns the commit SHA for r.Branch (HEAD when empty). The +// returned SHA is the cache key verbatim and doubles as the git +// object name the fetcher consumes. +// +// `git ls-remote` is authoritative; ls-remote misses (the caller +// supplied a commit-form ref directly) surface as +// [ErrNoVersionMetadata] so the fetcher canonicalizes via rev-parse. +// When [GitResolver.Store] is set, a full-length SHA in r.Branch +// short-circuits ls-remote on a local cache hit. +func (r *GitResolver) Probe(ctx context.Context, rawURL string) (string, error) { + if r.Store != nil && looksLikeFullSHA(r.Branch) { + if hash, ok := r.Store.ProbeCachedCommit(ctx, r.Venv, rawURL, r.Branch); ok { + return hash, nil } + } - tree, err := git.ParseTree(treeData, targetDir) - if err != nil { - return err + results, err := r.Venv.Git.LsRemote(ctx, rawURL, r.Branch) + if err != nil { + if errors.Is(err, git.ErrNoMatchingReference) { + return "", ErrNoVersionMetadata } - var linkOpts []LinkTreeOption - if opts.Mutable { - linkOpts = append(linkOpts, WithForceCopy()) - } + return "", err + } - return LinkTree(childCtx, c.blobStore, c.treeStore, tree, targetDir, linkOpts...) 
+ if len(results) == 0 { + return "", ErrNoVersionMetadata + } + + return results[0].Hash, nil +} + +// Clone fetches url into the target directory through the CAS, using a +// [GitResolver] for the probe and ingesting via `git ls-tree -r` / +// `git cat-file` so the native git blob and tree formats reach the +// stores intact. Callers customize the clone by passing options such as +// [WithDir], [WithBranch], or [WithDepth]; calling Clone with no +// options runs against the zero CloneOptions. +// +// Requires v.FS for store I/O and v.Git for the underlying clone. +// Panics with [ErrVenvFSUnset] or [ErrVenvGitUnset] if either is unset. +func (c *CAS) Clone(ctx context.Context, l log.Logger, v Venv, url string, options ...CloneOption) error { + v.RequireFS() + v.RequireGit() + + opts := CloneOptions{} + for _, opt := range options { + opt(&opts) + } + + clonedOpts := opts + clonedOpts.Dir = c.prepareTargetDirectory(opts.Dir, url) + + return c.FetchSource(ctx, l, v, &clonedOpts, SourceRequest{ + Scheme: "git", + URL: url, + Resolver: &GitResolver{Venv: v, Store: c.gitStore, Branch: opts.Branch}, + Fetch: c.gitFetcher(url, &opts), + Attrs: map[string]any{"branch": opts.Branch}, }) } -// ensureCloneStores creates the blob and tree store directories that -// [CAS.Clone] writes to. Defensive: [New] already creates them, but a -// long-lived [CAS] instance could see them removed between calls. -func (c *CAS) ensureCloneStores() error { - for _, s := range []*Store{c.blobStore, c.treeStore} { - if err := c.fs.MkdirAll(s.Path(), DefaultDirPerms); err != nil { - return fmt.Errorf("create CAS store path %s: %w", s.Path(), err) +// gitFetcher returns a SourceFetcher that ingests through the git-native +// path (cat-file + ls-tree). A non-empty suggestedKey is the canonical +// commit SHA from ls-remote; empty means ls-remote produced no match and +// rev-parse against the central GitStore canonicalizes the user ref after +// fetching. 
+func (c *CAS) gitFetcher(url string, opts *CloneOptions) SourceFetcher { + return func(ctx context.Context, l log.Logger, v Venv, suggestedKey string) (string, error) { + var ref resolvedRef + if suggestedKey != "" { + ref = &symbolicRef{URL: url, Branch: opts.Branch, Hash: suggestedKey} + } else { + ref = &commitRef{URL: url, RawRef: opts.Branch} } - } - return nil + return c.populateTreeFromRef(ctx, l, v, opts, ref) + } } // populateTreeFromRef dispatches by ref kind, short-circuiting on a @@ -232,27 +284,28 @@ func (c *CAS) ensureCloneStores() error { func (c *CAS) populateTreeFromRef( ctx context.Context, l log.Logger, + v Venv, opts *CloneOptions, ref resolvedRef, ) (string, error) { switch ref := ref.(type) { case *symbolicRef: - if !c.treeStore.NeedsWrite(ref.Hash) { + if !c.treeStore.NeedsWrite(v, ref.Hash) { return ref.Hash, nil } - if err := c.populateTreeFromSymbolicRef(ctx, l, opts, ref); err != nil { + if err := c.populateTreeFromSymbolicRef(ctx, l, v, opts, ref); err != nil { return "", err } return ref.Hash, nil case *commitRef: - if ref.Hash != "" && !c.treeStore.NeedsWrite(ref.Hash) { + if ref.Hash != "" && !c.treeStore.NeedsWrite(v, ref.Hash) { return ref.Hash, nil } - return c.populateTreeFromCommitRef(ctx, l, opts, ref) + return c.populateTreeFromCommitRef(ctx, l, v, opts, ref) default: return "", fmt.Errorf("unsupported resolved ref type %T", ref) @@ -266,36 +319,37 @@ func (c *CAS) populateTreeFromRef( func (c *CAS) populateTreeFromSymbolicRef( ctx context.Context, l log.Logger, + v Venv, opts *CloneOptions, ref *symbolicRef, ) error { depth := resolveCloneDepth(opts.Depth, c.cloneDepth) - repo, err := c.gitStore.EnsureRef(ctx, l, c.fs, ref.URL, ref.Branch, ref.Hash, depth) + repo, err := c.gitStore.EnsureRef(ctx, l, v, ref.URL, ref.Branch, ref.Hash, depth) if err == nil { defer repo.Release(l) - runner := c.git.WithWorkDir(repo.Path) + runner := v.Git.WithWorkDir(repo.Path) - return c.storeRootTreeFrom(ctx, l, runner, ref.Hash, opts) + 
return c.storeRootTreeFrom(ctx, l, v, runner, ref.Hash, opts) } l.Warnf("central git store unavailable for %s, falling back to temporary clone: %v", ref.URL, err) - tempDir, cleanup, err := c.makeFallbackCloneDir(l) + tempDir, cleanup, err := c.makeFallbackCloneDir(l, v) if err != nil { return err } defer cleanup() - runner := c.git.WithWorkDir(tempDir) + runner := v.Git.WithWorkDir(tempDir) if err := runner.Clone(ctx, ref.URL, true, depth, ref.Branch); err != nil { return err } - return c.storeRootTreeFrom(ctx, l, runner, ref.Hash, opts) + return c.storeRootTreeFrom(ctx, l, v, runner, ref.Hash, opts) } // populateTreeFromCommitRef resolves ref via [GitStore.EnsureCommit] @@ -305,20 +359,21 @@ func (c *CAS) populateTreeFromSymbolicRef( func (c *CAS) populateTreeFromCommitRef( ctx context.Context, l log.Logger, + v Venv, opts *CloneOptions, ref *commitRef, ) (string, error) { - repo, err := c.gitStore.EnsureCommit(ctx, l, c.fs, ref.URL, ref.RawRef, ref.Hash) + repo, err := c.gitStore.EnsureCommit(ctx, l, v, ref.URL, ref.RawRef, ref.Hash) if err == nil { defer repo.Release(l) - if !c.treeStore.NeedsWrite(repo.Hash) { + if !c.treeStore.NeedsWrite(v, repo.Hash) { return repo.Hash, nil } - runner := c.git.WithWorkDir(repo.Path) + runner := v.Git.WithWorkDir(repo.Path) - if err := c.storeRootTreeFrom(ctx, l, runner, repo.Hash, opts); err != nil { + if err := c.storeRootTreeFrom(ctx, l, v, runner, repo.Hash, opts); err != nil { return "", err } @@ -331,14 +386,14 @@ func (c *CAS) populateTreeFromCommitRef( l.Warnf("central git store unavailable for %s, falling back to temporary clone: %v", ref.URL, err) - tempDir, cleanup, err := c.makeFallbackCloneDir(l) + tempDir, cleanup, err := c.makeFallbackCloneDir(l, v) if err != nil { return "", err } defer cleanup() - runner := c.git.WithWorkDir(tempDir) + runner := v.Git.WithWorkDir(tempDir) if err := runner.Clone(ctx, ref.URL, true, 0, ""); err != nil { return "", err @@ -357,11 +412,11 @@ func (c *CAS) 
populateTreeFromCommitRef( return "", err } - if !c.treeStore.NeedsWrite(canonicalHash) { + if !c.treeStore.NeedsWrite(v, canonicalHash) { return canonicalHash, nil } - if err := c.storeRootTreeFrom(ctx, l, runner, canonicalHash, opts); err != nil { + if err := c.storeRootTreeFrom(ctx, l, v, runner, canonicalHash, opts); err != nil { return "", err } @@ -370,14 +425,14 @@ func (c *CAS) populateTreeFromCommitRef( // makeFallbackCloneDir creates a temporary directory for a bare clone // fallback and returns a cleanup function that removes it. -func (c *CAS) makeFallbackCloneDir(l log.Logger) (string, func(), error) { - tempDir, err := vfs.MkdirTemp(c.fs, "", "terragrunt-cas-fallback-*") +func (c *CAS) makeFallbackCloneDir(l log.Logger, v Venv) (string, func(), error) { + tempDir, err := vfs.MkdirTemp(v.FS, "", "terragrunt-cas-fallback-*") if err != nil { return "", nil, fmt.Errorf("create fallback clone dir: %w", errors.Join(ErrFallbackCloneDir, err)) } cleanup := func() { - if rmErr := c.fs.RemoveAll(tempDir); rmErr != nil { + if rmErr := v.FS.RemoveAll(tempDir); rmErr != nil { l.Warnf("cleanup error: %v", rmErr) } } @@ -411,14 +466,10 @@ func (c *CAS) prepareTargetDirectory(dir, url string) string { return filepath.Clean(targetDir) } -// resolvedRef is what [CAS.resolveReference] returns: a [symbolicRef] -// when ls-remote resolved the input to a branch, tag, or HEAD; a -// [commitRef] when it did not. Sealed by package visibility. +// resolvedRef is a sealed sum type returned by [CAS.resolveReference]: +// [symbolicRef] when ls-remote canonicalized the input, [commitRef] +// otherwise. type resolvedRef interface { - // CommitHash returns the canonical commit hash for [symbolicRef] - // and the raw user-supplied ref for [commitRef] (no - // canonicalization: an abbreviated SHA is returned as-is, not - // expanded to a full hash). CommitHash() string } @@ -432,90 +483,58 @@ type symbolicRef struct { // CommitHash returns the canonical commit hash ls-remote resolved. 
func (r *symbolicRef) CommitHash() string { return r.Hash } -// commitRef carries a ref ls-remote did not canonicalize to a -// commit on the remote. The typical case is a SHA the server does -// not publish as a branch tip, but any user-supplied name -// ls-remote returned no match for funnels here too. Resolution -// against the central git store happens later via rev-parse, with -// a full-history fetch on a cache miss. +// commitRef carries a ref ls-remote did not canonicalize. The central git +// store resolves it later via rev-parse and a full-history fetch on a +// cache miss. type commitRef struct { // URL is the remote repository URL. URL string - - // RawRef is the user-supplied ref. SHAs (full SHA-1, full - // SHA-256, or abbreviated prefixes) are the common case because - // ls-remote does not surface commit hashes, but any name - // ls-remote did not match also lands here. The central git - // store canonicalizes via rev-parse, so any form git accepts - // works. + // RawRef is the user-supplied ref. Any form `git rev-parse` accepts + // works (full SHA, abbreviated SHA, name ls-remote did not match). RawRef string - - // Hash is the canonical full SHA pre-resolved by - // [GitStore.ProbeCachedCommit] before reaching ls-remote; - // empty when the commitRef arose from an ls-remote miss. When - // set, downstream code keys the tree-store short-circuit on it - // and forwards it to [GitStore.EnsureCommit] to skip a - // redundant rev-parse. + // Hash is the canonical SHA when [GitStore.ProbeCachedCommit] + // resolved RawRef locally before reaching ls-remote; empty + // otherwise. Lets downstream code skip a redundant rev-parse. Hash string } -// CommitHash returns the user-supplied ref. Hash is intentionally -// not returned: stacks key the CAS on this value, and the key -// must not depend on whether the central git store happened to -// have the commit cached. +// CommitHash returns the user-supplied ref, not r.Hash. 
Stacks key the +// CAS on this value, and the key must not depend on whether the central +// git store happened to have the commit cached. func (r *commitRef) CommitHash() string { return r.RawRef } -// resolveReference resolves branch into a [resolvedRef]. -// -// Full-length SHA input (40 or 64 hex chars) is probed against the -// central git store first. A cached hit returns a [*commitRef] -// carrying the canonical hash without contacting the remote, so -// previously-cloned commits succeed offline even when ls-remote -// would fail to spawn. The pre-resolved hash also lets -// [CAS.populateTreeFromCommitRef] skip a redundant rev-parse -// inside [GitStore.EnsureCommit]. Abbreviated SHAs skip the probe -// because a hex-named branch could share the abbreviation as a -// prefix of its tip, which would freeze that branch at the -// first-fetched tip; see [looksLikeFullSHA]. +// resolveReference resolves branch into a [resolvedRef] via [GitResolver]. // -// Otherwise ls-remote is authoritative: a result returns a -// [*symbolicRef] with the canonical hash; an empty result or -// [git.ErrNoMatchingReference] returns a [*commitRef] with an -// empty Hash so the central store resolves the input via rev-parse -// and a full-history fetch. -func (c *CAS) resolveReference(ctx context.Context, url, branch string) (resolvedRef, error) { +// Full-length SHAs (40 or 64 hex chars) are checked against the local git +// store first so previously-cloned commits resolve offline; abbreviated +// SHAs skip the probe to avoid mistaking a hex-named branch tip for the +// SHA prefix and freezing the branch at its first-fetched tip (see +// [looksLikeFullSHA]). 
+func (c *CAS) resolveReference(ctx context.Context, v Venv, url, branch string) (resolvedRef, error) { if looksLikeFullSHA(branch) { - if hash, ok := c.gitStore.ProbeCachedCommit(ctx, c.fs, url, branch); ok { + if hash, ok := c.gitStore.ProbeCachedCommit(ctx, v, url, branch); ok { return &commitRef{URL: url, RawRef: branch, Hash: hash}, nil } } - results, err := c.git.LsRemote(ctx, url, branch) + r := &GitResolver{Venv: v, Store: c.gitStore, Branch: branch} + + key, err := r.Probe(ctx, url) if err != nil { - if errors.Is(err, git.ErrNoMatchingReference) { + if errors.Is(err, ErrNoVersionMetadata) { return &commitRef{URL: url, RawRef: branch}, nil } return nil, err } - if len(results) == 0 { - return &commitRef{URL: url, RawRef: branch}, nil - } - - return &symbolicRef{URL: url, Branch: branch, Hash: results[0].Hash}, nil + return &symbolicRef{URL: url, Branch: branch, Hash: key}, nil } -// looksLikeFullSHA reports whether s is exactly 40 or 64 hex -// characters, the canonical full lengths for SHA-1 and SHA-256 -// commit hashes. -// -// Abbreviations are rejected so the probe in [CAS.resolveReference] -// cannot mistake a hex-named branch (e.g. branch "a1b2" whose tip -// happens to start with "a1b2...") for a cached commit prefix and -// freeze the branch at its first-fetched tip. Abbreviated SHAs -// still resolve correctly via the [GitStore.EnsureCommit] fallback -// once ls-remote returns no match. +// looksLikeFullSHA reports whether s is exactly 40 or 64 hex characters, +// the canonical lengths for SHA-1 and SHA-256 commit hashes. Abbreviations +// are intentionally rejected; see the offline-probe rationale on +// [CAS.resolveReference]. 
func looksLikeFullSHA(s string) bool { if len(s) != 40 && len(s) != 64 { return false @@ -533,6 +552,7 @@ func looksLikeFullSHA(s string) bool { func (c *CAS) storeRootTreeFrom( ctx context.Context, l log.Logger, + v Venv, runner *git.GitRunner, hash string, opts *CloneOptions, @@ -542,7 +562,7 @@ func (c *CAS) storeRootTreeFrom( return err } - if err = c.storeTreeRecursive(ctx, l, runner, hash, tree); err != nil { + if err = c.storeTreeRecursive(ctx, l, v, runner, hash, tree); err != nil { return err } @@ -552,13 +572,13 @@ func (c *CAS) storeRootTreeFrom( treeContent := NewContent(c.treeStore) - data, err := treeContent.Read(hash) + data, err := treeContent.Read(v, hash) if err != nil { return err } for _, file := range opts.IncludedGitFiles { - stat, err := c.fs.Stat(filepath.Join(runner.WorkDir, file)) + stat, err := v.FS.Stat(filepath.Join(runner.WorkDir, file)) if err != nil { return err } @@ -569,14 +589,14 @@ func (c *CAS) storeRootTreeFrom( workDirPath := filepath.Join(runner.WorkDir, file) - includedHash, err := hashFile(c.fs, workDirPath) + includedHash, err := hashFile(v.FS, workDirPath) if err != nil { return err } blobContent := NewContent(c.blobStore) - if err := blobContent.EnsureCopy(l, includedHash, workDirPath); err != nil { + if err := blobContent.EnsureCopy(l, v, includedHash, workDirPath); err != nil { return err } @@ -585,43 +605,42 @@ func (c *CAS) storeRootTreeFrom( data = append(data, fmt.Appendf(nil, "%06o blob %s\t%s\n", stat.Mode().Perm(), includedHash, path)...) } - // Overwrite the root tree with the new data - return treeContent.Store(l, hash, data) + return treeContent.Store(l, v, hash, data) } -// storeTreeRecursive stores a tree fetched from git ls-tree -r +// storeTreeRecursive stores a tree fetched from git ls-tree -r. 
func (c *CAS) storeTreeRecursive( ctx context.Context, l log.Logger, + v Venv, runner *git.GitRunner, hash string, tree *git.Tree, ) error { - if !c.treeStore.NeedsWrite(hash) { + if !c.treeStore.NeedsWrite(v, hash) { return nil } - if err := c.storeBlobs(ctx, runner, tree.Entries()); err != nil { + if err := c.storeBlobs(ctx, v, runner, tree.Entries()); err != nil { return err } - // Store the tree object itself treeContent := NewContent(c.treeStore) - if err := treeContent.EnsureWithWait(l, hash, tree.Data()); err != nil { + if err := treeContent.EnsureWithWait(l, v, hash, tree.Data()); err != nil { return err } return nil } -// storeBlobs stores blobs in the CAS -func (c *CAS) storeBlobs(ctx context.Context, runner *git.GitRunner, entries []git.TreeEntry) error { +// storeBlobs stores blobs in the CAS. +func (c *CAS) storeBlobs(ctx context.Context, v Venv, runner *git.GitRunner, entries []git.TreeEntry) error { for _, entry := range entries { - if !c.blobStore.NeedsWrite(entry.Hash) { + if !c.blobStore.NeedsWrite(v, entry.Hash) { continue } - if err := c.ensureBlob(ctx, runner, entry.Hash, gitFilePerm(entry.Mode)); err != nil { + if err := c.ensureBlob(ctx, v, runner, entry.Hash, gitFilePerm(entry.Mode)); err != nil { return err } } @@ -636,18 +655,27 @@ func (c *CAS) storeBlobs(ctx context.Context, runner *git.GitRunner, entries []g // this blob; the stored blob is chmodded to gitPerm with the // write bits cleared so the default-link path can hardlink the // blob directly without altering its executable-ness. -func (c *CAS) ensureBlob(ctx context.Context, runner *git.GitRunner, hash string, gitPerm os.FileMode) error { - needsWrite, lock, err := c.blobStore.EnsureWithWait(hash) +// +// err is a named return so the deferred unlock and tempfile cleanup +// can errors.Join their failures into what the caller actually sees; +// otherwise the assignments target a local variable that has no +// connection to the function's return slot. 
+func (c *CAS) ensureBlob( + ctx context.Context, + v Venv, + runner *git.GitRunner, + hash string, + gitPerm os.FileMode, +) (err error) { + needsWrite, lock, err := c.blobStore.EnsureWithWait(v, hash) if err != nil { return err } - // If content already exists or was written by another process, we're done if !needsWrite { return nil } - // We have the lock and need to write the content defer func() { if unlockErr := lock.Unlock(); unlockErr != nil { err = errors.Join(err, unlockErr) @@ -656,18 +684,16 @@ func (c *CAS) ensureBlob(ctx context.Context, runner *git.GitRunner, hash string content := NewContent(c.blobStore) - tmpHandle, err := content.GetTmpHandle(hash) + tmpHandle, err := content.GetTmpHandle(v, hash) if err != nil { return err } tmpPath := tmpHandle.Name() - // We want to make sure we remove the temporary file - // if we encounter an error defer func() { - if _, statErr := c.fs.Stat(tmpPath); statErr == nil { - err = errors.Join(err, c.fs.Remove(tmpPath)) + if _, statErr := v.FS.Stat(tmpPath); statErr == nil { + err = errors.Join(err, v.FS.Remove(tmpPath)) } }() @@ -676,7 +702,6 @@ func (c *CAS) ensureBlob(ctx context.Context, runner *git.GitRunner, hash string return err } - // For Windows, ensure data is synchronized to disk if runtime.GOOS == "windows" { if err = tmpHandle.Sync(); err != nil { return err @@ -687,7 +712,7 @@ func (c *CAS) ensureBlob(ctx context.Context, runner *git.GitRunner, hash string return err } - if err = c.fs.Rename(tmpPath, content.getPath(hash)); err != nil { + if err = v.FS.Rename(tmpPath, content.getPath(hash)); err != nil { return err } @@ -699,7 +724,7 @@ func (c *CAS) ensureBlob(ctx context.Context, runner *git.GitRunner, hash string storedPerm = StoredFilePerms } - if err = c.fs.Chmod(content.getPath(hash), storedPerm); err != nil { + if err = v.FS.Chmod(content.getPath(hash), storedPerm); err != nil { return err } diff --git a/internal/cas/cas_test.go b/internal/cas/cas_test.go index 8bbb80d3ce..7380441818 100644 
--- a/internal/cas/cas_test.go +++ b/internal/cas/cas_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/gruntwork-io/terragrunt/internal/vexec" "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/test/helpers" "github.com/gruntwork-io/terragrunt/test/helpers/logger" @@ -19,6 +21,9 @@ func TestCAS_Clone(t *testing.T) { l := logger.CreateLogger() repoURL := startTestServer(t) + v, err := cas.OSVenv() + require.NoError(t, err) + t.Run("clone new repository", func(t *testing.T) { t.Parallel() tempDir := helpers.TmpDirWOSymlinks(t) @@ -28,10 +33,8 @@ func TestCAS_Clone(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), l, v, repoURL, cas.WithDir(targetPath), + cas.WithDepth(-1)) require.NoError(t, err) // Verify repository was cloned @@ -52,11 +55,9 @@ func TestCAS_Clone(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - Branch: "main", - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), l, v, repoURL, cas.WithDir(targetPath), + cas.WithBranch("main"), + cas.WithDepth(-1)) require.NoError(t, err) // Verify repository was cloned @@ -73,11 +74,9 @@ func TestCAS_Clone(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - IncludedGitFiles: []string{"HEAD", "config"}, - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), l, v, repoURL, cas.WithDir(targetPath), + cas.WithIncludedGitFiles([]string{"HEAD", "config"}), + cas.WithDepth(-1)) require.NoError(t, err) // Verify repository was cloned @@ -110,10 +109,11 @@ func TestCAS_FallbackWhenGitStoreFails(t 
*testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - err = c.Clone(t.Context(), logger.CreateLogger(), &cas.CloneOptions{ - Dir: targetPath, - Depth: -1, - }, repoURL) + v, err := cas.OSVenv() + require.NoError(t, err) + + err = c.Clone(t.Context(), logger.CreateLogger(), v, repoURL, cas.WithDir(targetPath), + cas.WithDepth(-1)) require.NoError(t, err) _, err = os.Stat(filepath.Join(targetPath, "README.md")) @@ -148,10 +148,11 @@ func TestCAS_CloneRepoWithSymlink(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - err = c.Clone(t.Context(), logger.CreateLogger(), &cas.CloneOptions{ - Dir: targetPath, - Depth: -1, - }, repoURL) + v, err := cas.OSVenv() + require.NoError(t, err) + + err = c.Clone(t.Context(), logger.CreateLogger(), v, repoURL, cas.WithDir(targetPath), + cas.WithDepth(-1)) require.NoError(t, err) linkPath := filepath.Join(targetPath, "link.txt") @@ -173,13 +174,24 @@ func TestCAS_CloneRepoWithSymlink(t *testing.T) { assert.Equal(t, []byte("hello"), data) } -// TestCASRejectsNonOSFilesystem pins the early OS-filesystem gate -// in [cas.New]: a non-OS backing must fail at construction. +// TestCASRejectsNonOSFilesystem pins the early OS-filesystem gate when a +// non-OS-backed Venv reaches CAS operations: ensureStorePaths refuses to +// continue rather than producing surprising behavior against an in-memory FS. 
func TestCASRejectsNonOSFilesystem(t *testing.T) { t.Parallel() storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") - _, err := cas.New(cas.WithFS(vfs.NewMemMapFS()), cas.WithStorePath(storePath)) + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(t, err) + + runner, err := git.NewGitRunner(vexec.NewOSExec()) + require.NoError(t, err) + + v := cas.Venv{FS: vfs.NewMemMapFS(), Git: runner} + + err = c.Clone(t.Context(), logger.CreateLogger(), v, "https://example.com/repo.git", + cas.WithDir(filepath.Join(helpers.TmpDirWOSymlinks(t), "repo")), + cas.WithDepth(-1)) require.ErrorIs(t, err, cas.ErrGitStoreFSNotOS) } diff --git a/internal/cas/clone_e2e_test.go b/internal/cas/clone_e2e_test.go new file mode 100644 index 0000000000..ed3f3d90c8 --- /dev/null +++ b/internal/cas/clone_e2e_test.go @@ -0,0 +1,208 @@ +package cas_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCASClone_E2E_SymbolicRefSecondRunReusesCache(t *testing.T) { + t.Parallel() + + repoURL := startTestServer(t) + headHash := resolveHeadE2E(t, repoURL) + + tempDir := helpers.TmpDirWOSymlinks(t) + storePath := filepath.Join(tempDir, "store") + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + + // First clone: probe hits, fetcher runs (tree not cached yet). 
+ dst1 := filepath.Join(tempDir, "dst1") + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(dst1), + cas.WithBranch("main"), + cas.WithDepth(-1))) + + require.FileExists(t, filepath.Join(dst1, "README.md")) + require.FileExists(t, filepath.Join(dst1, "main.tf")) + + // Tree is stored under the commit SHA, the same key the second + // clone's probe will derive. + treeContent := cas.NewContent(c.TreeStore()) + _, err = treeContent.Read(v, headHash) + require.NoError(t, err, "tree must be stored under the canonical commit SHA") + + // Second clone: probe still hits ls-remote, derives the same key, + // FetchSource short-circuits via treeStore.NeedsWrite, fetcher + // never runs. + dst2 := filepath.Join(tempDir, "dst2") + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(dst2), + cas.WithBranch("main"), + cas.WithDepth(-1))) + + require.FileExists(t, filepath.Join(dst2, "README.md")) + require.FileExists(t, filepath.Join(dst2, "main.tf")) +} + +func TestCASClone_E2E_CommitFormRefRoundTrip(t *testing.T) { + t.Parallel() + + repoURL := startTestServer(t) + headHash := resolveHeadE2E(t, repoURL) + + tempDir := helpers.TmpDirWOSymlinks(t) + c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) + require.NoError(t, err) + + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + + // Probe will return ErrNoVersionMetadata (ls-remote can't resolve + // a raw SHA), so fetcher canonicalizes via populateTreeFromCommitRef. + dst := filepath.Join(tempDir, "dst") + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(dst), + cas.WithBranch(headHash), + cas.WithDepth(-1))) + + require.FileExists(t, filepath.Join(dst, "README.md")) + + // The canonical SHA path stores the tree under the resolved commit + // hash, so a follow-up symbolic clone of "main" finds it. 
+ treeContent := cas.NewContent(c.TreeStore()) + _, err = treeContent.Read(v, headHash) + require.NoError(t, err) +} + +func TestCASClone_E2E_ThroughCASGetter(t *testing.T) { + t.Parallel() + + repoURL := startTestServer(t) + + tempDir := helpers.TmpDirWOSymlinks(t) + c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) + require.NoError(t, err) + + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + + // Full CASGetter dispatch (Detect → Get → CAS.Clone). The + // CASGetter is responsible for the ?ref= round-trip. + g := getter.NewCASGetter(l, c, v, &cas.CloneOptions{Depth: -1}) + client := &getter.Client{Getters: []getter.Getter{g}} + + dst := filepath.Join(tempDir, "dst") + _, err = client.Get(t.Context(), &getter.Request{ + Src: "git::" + repoURL + "?ref=main", + Dst: dst, + GetMode: getter.ModeDir, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(dst, "README.md")) +} + +func TestCASClone_E2E_RemainsOfflineAfterFirstClone(t *testing.T) { + t.Parallel() + + // Once the tree is cached, a second Clone() keyed by the full + // commit SHA resolves entirely from the local CAS: + // - GitResolver.Probe sees looksLikeFullSHA(Branch) and runs + // ProbeCachedCommit against the bare repo (rev-parse, no + // network). + // - FetchSource finds the tree already stored under that SHA + // and short-circuits to linkStoredTree; Fetch never runs. + // + // We pin both halves by shutting the server down between clones: + // the second Clone() must succeed, and (since LsRemote would + // fail-fast against a dead listener) any path that still reaches + // it would surface here. 
+ srv := newEmptyTestServer(t) + require.NoError(t, srv.CommitFile("README.md", []byte("hi"), "init")) + + repoURL, err := srv.Start(t.Context()) + require.NoError(t, err) + + headHash, err := srv.Head() + require.NoError(t, err) + + tempDir := helpers.TmpDirWOSymlinks(t) + c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) + require.NoError(t, err) + + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + + dst1 := filepath.Join(tempDir, "dst1") + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(dst1), + cas.WithBranch("main"), + cas.WithDepth(-1))) + + // Shut the server down. Any subsequent ls-remote against repoURL + // would fail with "Could not resolve host" / "Connection refused". + require.NoError(t, srv.Close()) + + dst2 := filepath.Join(tempDir, "dst2") + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(dst2), + cas.WithBranch(headHash), + cas.WithDepth(-1)), "second clone keyed by full SHA must resolve from local CAS without ls-remote") + + require.FileExists(t, filepath.Join(dst2, "README.md")) +} + +func TestCASClone_E2E_MutableSetCopiesBlobs(t *testing.T) { + t.Parallel() + + repoURL := startTestServer(t) + + tempDir := helpers.TmpDirWOSymlinks(t) + c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) + require.NoError(t, err) + + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + + dst := filepath.Join(tempDir, "dst") + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(dst), + cas.WithBranch("main"), + cas.WithDepth(-1), + cas.WithMutable(true))) + + // Mutable=true: destination files have the original perms, not + // the write-bit-stripped read-only perms the default path uses. 
+ stat, err := os.Stat(filepath.Join(dst, "README.md")) + require.NoError(t, err) + assert.NotEqual(t, os.FileMode(0o444), stat.Mode().Perm(), + "mutable clone should not strip write bits; default path does") +} + +// resolveHeadE2E is a convenience wrapper used by several tests in +// this file; included here so the file is independent of +// commitref_test.go's helpers. +func resolveHeadE2E(t *testing.T, srv string) string { + t.Helper() + + results, err := newGitRunner(t).LsRemote(t.Context(), srv, "HEAD") + require.NoError(t, err) + require.NotEmpty(t, results) + + return results[0].Hash +} diff --git a/internal/cas/commitref_test.go b/internal/cas/commitref_test.go index f5523c1c67..a8d785c05d 100644 --- a/internal/cas/commitref_test.go +++ b/internal/cas/commitref_test.go @@ -22,6 +22,9 @@ func TestCASCloneByCommitRef(t *testing.T) { repoURL := startTestServer(t) headHash := resolveHead(t, repoURL) + v, err := cas.OSVenv() + require.NoError(t, err) + t.Run("clone with full commit SHA", func(t *testing.T) { t.Parallel() tempDir := helpers.TmpDirWOSymlinks(t) @@ -30,11 +33,9 @@ func TestCASCloneByCommitRef(t *testing.T) { require.NoError(t, err) targetPath := filepath.Join(tempDir, "repo") - err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - Branch: headHash, - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), l, v, repoURL, cas.WithDir(targetPath), + cas.WithBranch(headHash), + cas.WithDepth(-1)) require.NoError(t, err) _, err = os.Stat(filepath.Join(targetPath, "README.md")) @@ -49,11 +50,9 @@ func TestCASCloneByCommitRef(t *testing.T) { require.NoError(t, err) targetPath := filepath.Join(tempDir, "repo") - err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - Branch: headHash[:8], - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), l, v, repoURL, cas.WithDir(targetPath), + cas.WithBranch(headHash[:8]), + cas.WithDepth(-1)) require.NoError(t, err) _, err = os.Stat(filepath.Join(targetPath, "README.md")) @@ -69,11 
+68,9 @@ func TestCASCloneByCommitRef(t *testing.T) { require.NoError(t, err) // Prime the central git store. - require.NoError(t, c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "first"), - Branch: headHash, - Depth: -1, - }, repoURL)) + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(filepath.Join(tempDir, "first")), + cas.WithBranch(headHash), + cas.WithDepth(-1))) // Drop the test server: a cached clone must not need it. repoEntry := cas.EntryPathForURL(filepath.Join(storePath, "git"), repoURL) @@ -81,11 +78,9 @@ func TestCASCloneByCommitRef(t *testing.T) { require.NoError(t, err) secondClone := filepath.Join(tempDir, "second") - require.NoError(t, c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: secondClone, - Branch: headHash, - Depth: -1, - }, repoURL)) + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(secondClone), + cas.WithBranch(headHash), + cas.WithDepth(-1))) _, err = os.Stat(filepath.Join(secondClone, "README.md")) require.NoError(t, err) @@ -98,11 +93,9 @@ func TestCASCloneByCommitRef(t *testing.T) { c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) require.NoError(t, err) - err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "repo"), - Branch: "0000000000000000000000000000000000000000", - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), l, v, repoURL, cas.WithDir(filepath.Join(tempDir, "repo")), + cas.WithBranch("0000000000000000000000000000000000000000"), + cas.WithDepth(-1)) require.Error(t, err) assert.ErrorIs(t, err, git.ErrNoMatchingReference) }) @@ -114,19 +107,19 @@ func TestGitStoreEnsureCommit_CachedAfterFirstFetch(t *testing.T) { url := startTestServer(t) hash := resolveHead(t, url) - store, fs, _ := newTestGitStore(t) + store, v, _ := newTestGitStore(t) l := logger.CreateLogger() ctx := t.Context() // First call must fetch. 
- repo, err := store.EnsureCommit(ctx, l, fs, url, hash, "") + repo, err := store.EnsureCommit(ctx, l, v, url, hash, "") require.NoError(t, err) assert.Equal(t, hash, repo.Hash) assert.NotEmpty(t, repo.Path) require.NoError(t, repo.Unlock()) // Second call hits the local-cache short-circuit. - repo2, err := store.EnsureCommit(ctx, l, fs, url, hash, "") + repo2, err := store.EnsureCommit(ctx, l, v, url, hash, "") require.NoError(t, err) assert.Equal(t, hash, repo2.Hash) require.NoError(t, repo2.Unlock()) @@ -138,10 +131,10 @@ func TestGitStoreEnsureCommit_AbbreviatedSHA(t *testing.T) { url := startTestServer(t) hash := resolveHead(t, url) - store, fs, _ := newTestGitStore(t) + store, v, _ := newTestGitStore(t) l := logger.CreateLogger() - repo, err := store.EnsureCommit(t.Context(), l, fs, url, hash[:8], "") + repo, err := store.EnsureCommit(t.Context(), l, v, url, hash[:8], "") require.NoError(t, err) assert.Equal(t, hash, repo.Hash, "abbreviated SHA must canonicalize to the full hash") require.NoError(t, repo.Unlock()) @@ -152,10 +145,10 @@ func TestGitStoreEnsureCommit_UnresolvableSurfacesNoMatchingReference(t *testing url := startTestServer(t) - store, fs, _ := newTestGitStore(t) + store, v, _ := newTestGitStore(t) l := logger.CreateLogger() - _, err := store.EnsureCommit(t.Context(), l, fs, url, "0000000000000000000000000000000000000000", "") + _, err := store.EnsureCommit(t.Context(), l, v, url, "0000000000000000000000000000000000000000", "") require.Error(t, err) assert.ErrorIs(t, err, git.ErrNoMatchingReference) } @@ -189,12 +182,13 @@ func TestCASClone_NonTipCommit(t *testing.T) { c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + targetPath := filepath.Join(tempDir, "repo") - err = c.Clone(t.Context(), logger.CreateLogger(), &cas.CloneOptions{ - Dir: targetPath, - Branch: firstHash, - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), logger.CreateLogger(), 
v, repoURL, cas.WithDir(targetPath), + cas.WithBranch(firstHash), + cas.WithDepth(-1)) require.NoError(t, err) // Only the first commit's file should be present; later commits @@ -239,13 +233,14 @@ func TestCASClone_AbbreviatedHexBranchAdvancesAcrossClones(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + l := logger.CreateLogger() - require.NoError(t, c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "first"), - Branch: branch, - Depth: -1, - }, repoURL)) + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(filepath.Join(tempDir, "first")), + cas.WithBranch(branch), + cas.WithDepth(-1))) // Advance the branch to a new commit. ls-remote must see the new // tip on the second clone; the probe would otherwise serve the @@ -254,11 +249,9 @@ func TestCASClone_AbbreviatedHexBranchAdvancesAcrossClones(t *testing.T) { require.NoError(t, srv.Branch(branch)) secondDir := filepath.Join(tempDir, "second") - require.NoError(t, c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: secondDir, - Branch: branch, - Depth: -1, - }, repoURL)) + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(secondDir), + cas.WithBranch(branch), + cas.WithDepth(-1))) _, err = os.Stat(filepath.Join(secondDir, "v2.txt")) require.NoError(t, err, "second clone must reflect the moved branch tip, not the cached prefix-matching commit") @@ -287,12 +280,13 @@ func TestCASClone_HexBranchNameResolvesViaLsRemote(t *testing.T) { c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + targetPath := filepath.Join(tempDir, "repo") - err = c.Clone(t.Context(), logger.CreateLogger(), &cas.CloneOptions{ - Dir: targetPath, - Branch: hexBranch, - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), logger.CreateLogger(), v, repoURL, cas.WithDir(targetPath), + cas.WithBranch(hexBranch), + 
cas.WithDepth(-1)) require.NoError(t, err) _, err = os.Stat(filepath.Join(targetPath, "README.md")) @@ -316,12 +310,13 @@ func TestCASClone_TagRef(t *testing.T) { c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + targetPath := filepath.Join(tempDir, "repo") - err = c.Clone(t.Context(), logger.CreateLogger(), &cas.CloneOptions{ - Dir: targetPath, - Branch: "v1.0.0", - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), logger.CreateLogger(), v, repoURL, cas.WithDir(targetPath), + cas.WithBranch("v1.0.0"), + cas.WithDepth(-1)) require.NoError(t, err) _, err = os.Stat(filepath.Join(targetPath, "README.md")) @@ -398,17 +393,17 @@ func TestGitStoreEnsureCommit_OfflineWhenCached(t *testing.T) { repoURL, err := srv.Start(t.Context()) require.NoError(t, err) - store, fs, _ := newTestGitStore(t) + store, v, _ := newTestGitStore(t) l := logger.CreateLogger() ctx := t.Context() - primed, err := store.EnsureCommit(ctx, l, fs, repoURL, hash, "") + primed, err := store.EnsureCommit(ctx, l, v, repoURL, hash, "") require.NoError(t, err) require.NoError(t, primed.Unlock()) require.NoError(t, srv.Close()) - cached, err := store.EnsureCommit(ctx, l, fs, repoURL, hash, "") + cached, err := store.EnsureCommit(ctx, l, v, repoURL, hash, "") require.NoError(t, err, "cached commit must resolve without contacting the server") assert.Equal(t, hash, cached.Hash) require.NoError(t, cached.Unlock()) @@ -471,22 +466,21 @@ func TestCASClone_OfflineWhenCommitCached(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + l := logger.CreateLogger() - require.NoError(t, c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "primed"), - Branch: hash, - Depth: -1, - }, repoURL)) + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(filepath.Join(tempDir, "primed")), + 
cas.WithBranch(hash), + cas.WithDepth(-1))) require.NoError(t, srv.Close()) cachedDir := filepath.Join(tempDir, "cached") - require.NoError(t, c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: cachedDir, - Branch: hash, - Depth: -1, - }, repoURL), "cached commit ref must resolve without contacting the server") + require.NoError(t, c.Clone(t.Context(), l, v, repoURL, cas.WithDir(cachedDir), + cas.WithBranch(hash), + cas.WithDepth(-1)), "cached commit ref must resolve without contacting the server") _, err = os.Stat(filepath.Join(cachedDir, "README.md")) require.NoError(t, err) @@ -506,7 +500,10 @@ func TestCASGetterGet_WithCommitRef(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - g := getter.NewCASGetter(logger.CreateLogger(), c, &cas.CloneOptions{Depth: -1}) + v, err := cas.OSVenv() + require.NoError(t, err) + + g := getter.NewCASGetter(logger.CreateLogger(), c, v, &cas.CloneOptions{Depth: -1}) client := getter.Client{Getters: []getter.Getter{g}} dst := filepath.Join(tempDir, "repo") @@ -545,12 +542,13 @@ func TestCAS_CommitRefFallbackWhenGitStoreFails(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + targetPath := filepath.Join(tempDir, "repo") - err = c.Clone(t.Context(), logger.CreateLogger(), &cas.CloneOptions{ - Dir: targetPath, - Branch: headHash, - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), logger.CreateLogger(), v, repoURL, cas.WithDir(targetPath), + cas.WithBranch(headHash), + cas.WithDepth(-1)) require.NoError(t, err) _, err = os.Stat(filepath.Join(targetPath, "README.md")) @@ -574,6 +572,9 @@ func TestCASCloneByCommitRefConcurrentWithRacing(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + const workers = 4 var wg sync.WaitGroup @@ -586,11 +587,10 @@ func TestCASCloneByCommitRefConcurrentWithRacing(t *testing.T) { go 
func(idx int) { defer wg.Done() - errs[idx] = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "repo", "worker", string(rune('a'+idx))), - Branch: headHash, - Depth: -1, - }, repoURL) + errs[idx] = c.Clone(t.Context(), l, v, repoURL, + cas.WithDir(filepath.Join(tempDir, "repo", "worker", string(rune('a'+idx)))), + cas.WithBranch(headHash), + cas.WithDepth(-1)) }(i) } diff --git a/internal/cas/content.go b/internal/cas/content.go index c4e835eed2..288f064a98 100644 --- a/internal/cas/content.go +++ b/internal/cas/content.go @@ -16,30 +16,26 @@ import ( ) const ( - // DefaultDirPerms represents standard directory permissions (rwxr-xr-x) + // DefaultDirPerms represents standard directory permissions (rwxr-xr-x). DefaultDirPerms = os.FileMode(0755) - // StoredFilePerms represents read-only file permissions (r--r--r--) + // StoredFilePerms represents read-only file permissions (r--r--r--). StoredFilePerms = os.FileMode(0444) - // RegularFilePerms represents standard file permissions (rw-r--r--) + // RegularFilePerms represents standard file permissions (rw-r--r--). RegularFilePerms = os.FileMode(0644) // WriteBitMask covers all owner/group/other write bits. WriteBitMask = os.FileMode(0o222) - // WindowsOS is the name of the Windows operating system + // WindowsOS is the name of the Windows operating system. WindowsOS = "windows" ) -// Content manages git object storage and linking +// Content manages git object storage and linking. type Content struct { store *Store - fs vfs.FS } -// NewContent creates a new Content instance +// NewContent creates a new Content instance bound to store. func NewContent(store *Store) *Content { - return &Content{ - store: store, - fs: store.FS(), - } + return &Content{store: store} } // LinkOption configures a single Content.Link call. @@ -56,20 +52,20 @@ func WithLinkForceCopy() LinkOption { return func(o *linkOpts) { o.forceCopy = true } } -// Link materializes a stored blob at targetPath. 
gitPerm is the original git -// mode bits for the entry (e.g. 0o644 or 0o755). -// -// Default path: the destination has the original git perms with the write bit -// stripped, so the target is read-only and cannot poison the shared store. -// Stored blobs already carry these read-only-of-original perms, so the -// hardlink path covers both regular files (0o444) and executables (0o555). -// If the stored blob's perms do not match (rare collision: same content -// referenced under different modes), Link falls back to a copy at the -// requested perm. +// Link materializes a stored blob at targetPath under gitPerm. // -// WithLinkForceCopy: the destination has the exact original git perms, -// always via copy, so callers can edit the working tree freely. -func (c *Content) Link(ctx context.Context, hash, targetPath string, gitPerm os.FileMode, opts ...LinkOption) error { +// The default path hardlinks the stored blob with its write bits stripped, so +// the destination cannot be edited back into the shared store. The fallback +// copy path applies when stored perms don't match the request (rare cross-mode +// collision) or when [WithLinkForceCopy] is in effect, so callers can edit +// the working tree freely. +func (c *Content) Link( + ctx context.Context, + v Venv, + hash, targetPath string, + gitPerm os.FileMode, + opts ...LinkOption, +) error { var o linkOpts for _, opt := range opts { opt(&o) @@ -93,15 +89,17 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string, gitPerm os. // leak back into the shared store and so the destination carries the // requested mode. 
if !o.forceCopy { - if info, statErr := c.fs.Stat(sourcePath); statErr == nil && info.Mode().Perm() == desired { - if err := vfs.Link(c.fs, sourcePath, targetPath); err == nil || os.IsExist(err) { + if info, statErr := v.FS.Stat(sourcePath); statErr == nil && info.Mode().Perm() == desired { + if err := vfs.Link(v.FS, sourcePath, targetPath); err == nil { return nil } - // Fall through to copy on link failure. + // Fall through to copy on link failure. An existing + // targetPath is handled by the temp-file+rename below, + // which overwrites stale bytes atomically. } } - data, readErr := vfs.ReadFile(c.fs, sourcePath) + data, readErr := vfs.ReadFile(v.FS, sourcePath) if readErr != nil { return &WrappedError{ Op: "read_source", @@ -110,9 +108,8 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string, gitPerm os. } } - // Write to temporary file first tempPath := targetPath + ".tmp" - if err := vfs.WriteFile(c.fs, tempPath, data, desired); err != nil { + if err := vfs.WriteFile(v.FS, tempPath, data, desired); err != nil { return &WrappedError{ Op: "write_target", Path: tempPath, @@ -121,7 +118,7 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string, gitPerm os. } // Reapply perms after write to override any umask masking. - if err := c.fs.Chmod(tempPath, desired); err != nil { + if err := v.FS.Chmod(tempPath, desired); err != nil { return &WrappedError{ Op: "chmod_target", Path: tempPath, @@ -129,8 +126,7 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string, gitPerm os. } } - // Atomic rename to final path - if err := c.fs.Rename(tempPath, targetPath); err != nil { + if err := v.FS.Rename(tempPath, targetPath); err != nil { return &WrappedError{ Op: "rename_target", Path: tempPath, @@ -143,9 +139,9 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string, gitPerm os. } // Store stores a single content item. This is typically used for trees, -// As blobs are written directly from git cat-file stdout. 
-func (c *Content) Store(l log.Logger, hash string, data []byte) error { - lock, err := c.store.AcquireLock(hash) +// as blobs are written directly from git cat-file stdout. +func (c *Content) Store(l log.Logger, v Venv, hash string, data []byte) error { + lock, err := c.store.AcquireLock(v, hash) if err != nil { return fmt.Errorf("acquire lock for %s: %w", hash, err) } @@ -156,78 +152,74 @@ func (c *Content) Store(l log.Logger, hash string, data []byte) error { } }() - if err = c.fs.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { + if err = v.FS.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { return fmt.Errorf("create store dir %s: %w", c.store.Path(), ErrCreateDir) } - // Ensure partition directory exists partitionDir := c.getPartition(hash) - if err = c.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err = v.FS.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return fmt.Errorf("create partition dir %s: %w", partitionDir, ErrCreateDir) } - return c.writeContentToFile(l, hash, data) + return c.writeContentToFile(l, v, hash, data) } -// Ensure ensures that a content item exists in the store -func (c *Content) Ensure(l log.Logger, hash string, data []byte) error { +// Ensure ensures that a content item exists in the store. +func (c *Content) Ensure(l log.Logger, v Venv, hash string, data []byte) error { path := c.getPath(hash) - if c.store.hasContent(path) { + if c.store.hasContent(v, path) { return nil } - return c.Store(l, hash, data) + return c.Store(l, v, hash, data) } // EnsureWithWait ensures that a content item exists in the store, with optimization -// to wait for concurrent writes instead of doing redundant work -func (c *Content) EnsureWithWait(l log.Logger, hash string, data []byte) error { - needsWrite, lock, err := c.store.EnsureWithWait(hash) +// to wait for concurrent writes instead of doing redundant work. 
+func (c *Content) EnsureWithWait(l log.Logger, v Venv, hash string, data []byte) error { + needsWrite, lock, err := c.store.EnsureWithWait(v, hash) if err != nil { return fmt.Errorf("ensure content for %s: %w", hash, err) } - // If content already exists or was written by another process, we're done if !needsWrite { return nil } - // We have the lock and need to write the content defer func() { if unlockErr := lock.Unlock(); unlockErr != nil { l.Warnf("failed to unlock filesystem lock for hash %s: %v", hash, unlockErr) } }() - if err = c.fs.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { + if err = v.FS.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { return fmt.Errorf("create store dir %s: %w", c.store.Path(), ErrCreateDir) } - // Ensure partition directory exists partitionDir := c.getPartition(hash) - if err = c.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err = v.FS.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return fmt.Errorf("create partition dir %s: %w", partitionDir, ErrCreateDir) } - return c.writeContentToFile(l, hash, data) + return c.writeContentToFile(l, v, hash, data) } // EnsureCopy ensures that a content item exists in the store by copying from a file. // The stored blob is chmodded to the source file's perms with the write bits cleared, // so the default-link path can hardlink the blob directly without losing its // executable-ness or risking writes back into the shared store. 
-func (c *Content) EnsureCopy(l log.Logger, hash, src string) (err error) { +func (c *Content) EnsureCopy(l log.Logger, v Venv, hash, src string) (err error) { path := c.getPath(hash) - if c.store.hasContent(path) { + if c.store.hasContent(v, path) { return nil } - srcInfo, err := c.fs.Stat(src) + srcInfo, err := v.FS.Stat(src) if err != nil { return fmt.Errorf("stat source %s: %w", src, err) } - lock, err := c.store.AcquireLock(hash) + lock, err := c.store.AcquireLock(v, hash) if err != nil { return fmt.Errorf("acquire lock for %s: %w", hash, err) } @@ -240,27 +232,42 @@ func (c *Content) EnsureCopy(l log.Logger, hash, src string) (err error) { // a read-only blob between the lock-free hasContent check and AcquireLock. // Without this guard, Create below would fail with EACCES on the existing // 0o444 file. - if c.store.hasContent(path) { + if c.store.hasContent(v, path) { return nil } - // Ensure partition directory exists partitionDir := c.getPartition(hash) - if err = c.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err = v.FS.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return fmt.Errorf("create partition dir %s: %w", partitionDir, ErrCreateDir) } - f, err := c.fs.Create(path) + // Write through a tempPath so a crash mid-copy cannot leave a + // half-written blob at the final hash-addressed path. The rename + // is the publish step. + tempPath := path + ".tmp" + + f, err := v.FS.Create(tempPath) if err != nil { - return fmt.Errorf("create file %s: %w", path, err) + return fmt.Errorf("create file %s: %w", tempPath, err) } + // renamed flips after the publish step so the deferred cleanup + // removes a stale tempPath only on the error path. 
+ renamed := false + defer func() { - err = errors.Join(err, f.Close()) + if renamed { + return + } + + if rmErr := v.FS.Remove(tempPath); rmErr != nil && !os.IsNotExist(rmErr) { + err = errors.Join(err, rmErr) + } }() - r, err := c.fs.Open(src) + r, err := v.FS.Open(src) if err != nil { + err = errors.Join(err, f.Close()) return fmt.Errorf("open source %s: %w", src, err) } @@ -269,27 +276,40 @@ func (c *Content) EnsureCopy(l log.Logger, hash, src string) (err error) { }() if _, err := io.Copy(f, r); err != nil { - return fmt.Errorf("copy from %s: %w", src, err) + closeErr := f.Close() + return fmt.Errorf("copy from %s: %w", src, errors.Join(err, closeErr)) + } + + // Close the writer before rename so platforms that disallow + // renaming an open file (Windows) can complete the publish. + if err := f.Close(); err != nil { + return fmt.Errorf("close %s: %w", tempPath, err) } - if err := c.fs.Chmod(path, srcInfo.Mode().Perm()&^WriteBitMask); err != nil { - return fmt.Errorf("chmod %s: %w", path, err) + if err := v.FS.Chmod(tempPath, srcInfo.Mode().Perm()&^WriteBitMask); err != nil { + return fmt.Errorf("chmod %s: %w", tempPath, err) } + if err := v.FS.Rename(tempPath, path); err != nil { + return fmt.Errorf("finalize %s: %w", path, err) + } + + renamed = true + return nil } // GetTmpHandle returns a file handle to a temporary file where content will be stored. 
-func (c *Content) GetTmpHandle(hash string) (vfs.File, error) { +func (c *Content) GetTmpHandle(v Venv, hash string) (vfs.File, error) { partitionDir := c.getPartition(hash) - if err := c.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err := v.FS.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return nil, fmt.Errorf("create partition dir %s: %w", partitionDir, ErrCreateDir) } path := c.getPath(hash) tempPath := path + ".tmp" - f, err := c.fs.Create(tempPath) + f, err := v.FS.Create(tempPath) if err != nil { return nil, fmt.Errorf("create temp file %s: %w", tempPath, err) } @@ -297,19 +317,19 @@ func (c *Content) GetTmpHandle(hash string) (vfs.File, error) { return f, err } -// Read retrieves content from the store by hash -func (c *Content) Read(hash string) ([]byte, error) { +// Read retrieves content from the store by hash. +func (c *Content) Read(v Venv, hash string) ([]byte, error) { path := c.getPath(hash) - return vfs.ReadFile(c.fs, path) + return vfs.ReadFile(v.FS, path) } -// writeContentToFile writes data to a temporary file, -// sets appropriate permissions, and performs an atomic rename. -func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) error { +// writeContentToFile writes data to a temporary file, sets appropriate +// permissions, and performs an atomic rename. 
+func (c *Content) writeContentToFile(l log.Logger, v Venv, hash string, data []byte) error { path := c.getPath(hash) tempPath := path + ".tmp" - f, err := c.fs.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, RegularFilePerms) + f, err := v.FS.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, RegularFilePerms) if err != nil { return fmt.Errorf("create temp file %s: %w", tempPath, err) } @@ -321,7 +341,7 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err l.Warnf("failed to close temp file %s: %v", tempPath, closeErr) } - if removeErr := c.fs.Remove(tempPath); removeErr != nil { + if removeErr := v.FS.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } @@ -333,7 +353,7 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err l.Warnf("failed to close temp file %s: %v", tempPath, closeErr) } - if removeErr := c.fs.Remove(tempPath); removeErr != nil { + if removeErr := v.FS.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } @@ -341,46 +361,39 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err } if err := f.Close(); err != nil { - if removeErr := c.fs.Remove(tempPath); removeErr != nil { + if removeErr := v.FS.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } return fmt.Errorf("close %s: %w", tempPath, err) } - // Set read-only permissions on the temporary file - if err := c.fs.Chmod(tempPath, StoredFilePerms); err != nil { - if removeErr := c.fs.Remove(tempPath); removeErr != nil { + if err := v.FS.Chmod(tempPath, StoredFilePerms); err != nil { + if removeErr := v.FS.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } return fmt.Errorf("chmod temp %s: %w", tempPath, err) } - // For Windows, handle readonly attributes specifically if runtime.GOOS == 
WindowsOS { - // Check if a destination file exists and is read-only - if _, err := c.fs.Stat(path); err == nil { - // File exists, make it writable before rename operation - if err := c.fs.Chmod(path, RegularFilePerms); err != nil { + if _, err := v.FS.Stat(path); err == nil { + if err := v.FS.Chmod(path, RegularFilePerms); err != nil { l.Warnf("failed to make destination file writable %s: %v", path, err) } } } - // Atomic rename - if err := c.fs.Rename(tempPath, path); err != nil { - if removeErr := c.fs.Remove(tempPath); removeErr != nil { + if err := v.FS.Rename(tempPath, path); err != nil { + if removeErr := v.FS.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } return fmt.Errorf("finalize %s: %w", path, err) } - // For Windows, we need to set the permissions again after rename if runtime.GOOS == WindowsOS { - // Ensure the file has read-only permissions after rename - if err := c.fs.Chmod(path, StoredFilePerms); err != nil { + if err := v.FS.Chmod(path, StoredFilePerms); err != nil { return fmt.Errorf("chmod %s: %w", path, err) } } @@ -388,12 +401,12 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err return nil } -// getPartition returns the partition path for a given hash +// getPartition returns the partition path for a given hash. func (c *Content) getPartition(hash string) string { return filepath.Join(c.store.Path(), hash[:2]) } -// getPath returns the full path for a given hash +// getPath returns the full path for a given hash. 
func (c *Content) getPath(hash string) string { return filepath.Join(c.getPartition(hash), hash) } diff --git a/internal/cas/content_test.go b/internal/cas/content_test.go index 27c98d442f..6ca0734f45 100644 --- a/internal/cas/content_test.go +++ b/internal/cas/content_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/gruntwork-io/terragrunt/internal/vexec" "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/test/helpers/logger" "github.com/stretchr/testify/assert" @@ -22,21 +24,22 @@ func TestContent_Store(t *testing.T) { t.Run("store new content", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") - err := content.Store(l, testHash, testData) + err := content.Store(l, v, testHash, testData) require.NoError(t, err) // Verify content was stored partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := vfs.ReadFile(memFs, storedPath) + storedData, err := vfs.ReadFile(v.FS, storedPath) require.NoError(t, err) assert.Equal(t, testData, storedData) }) @@ -44,9 +47,10 @@ func TestContent_Store(t *testing.T) { t.Run("ensure existing content", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) testHash := testHashValue @@ -54,15 +58,15 @@ func TestContent_Store(t *testing.T) { differentData := 
[]byte("different content") // Store content twice - err := content.Ensure(l, testHash, testData) + err := content.Ensure(l, v, testHash, testData) require.NoError(t, err) - err = content.Ensure(l, testHash, differentData) + err = content.Ensure(l, v, testHash, differentData) require.NoError(t, err) // Verify original content remains partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := vfs.ReadFile(memFs, storedPath) + storedData, err := vfs.ReadFile(v.FS, storedPath) require.NoError(t, err) assert.Equal(t, testData, storedData) }) @@ -70,9 +74,10 @@ func TestContent_Store(t *testing.T) { t.Run("overwrite existing content", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) testHash := testHashValue @@ -80,15 +85,15 @@ func TestContent_Store(t *testing.T) { differentData := []byte("different content") // Store content twice - err := content.Store(l, testHash, testData) + err := content.Store(l, v, testHash, testData) require.NoError(t, err) - err = content.Store(l, testHash, differentData) + err = content.Store(l, v, testHash, differentData) require.NoError(t, err) // Verify content was overwritten partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := vfs.ReadFile(memFs, storedPath) + storedData, err := vfs.ReadFile(v.FS, storedPath) require.NoError(t, err) assert.Equal(t, differentData, storedData) }) @@ -102,27 +107,28 @@ func TestContent_Link(t *testing.T) { t.Run("create new link", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - require.NoError(t, memFs.MkdirAll("/target", 0755)) - store := 
cas.NewStore("/store").WithFS(memFs) + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + require.NoError(t, v.FS.MkdirAll("/target", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") // First store some content - err := content.Store(l, testHash, testData) + err := content.Store(l, v, testHash, testData) require.NoError(t, err) // Then create a link to it targetPath := filepath.Join("/target", "test.txt") - err = content.Link(t.Context(), testHash, targetPath, 0o644) + err = content.Link(t.Context(), v, testHash, targetPath, 0o644) require.NoError(t, err) // Verify link was created and contains correct content - linkedData, err := vfs.ReadFile(memFs, targetPath) + linkedData, err := vfs.ReadFile(v.FS, targetPath) require.NoError(t, err) assert.Equal(t, testData, linkedData) }) @@ -130,20 +136,22 @@ func TestContent_Link(t *testing.T) { t.Run("create hard link on real filesystem", func(t *testing.T) { t.Parallel() - osFs := vfs.NewOSFS() + v, err := cas.OSVenv() + require.NoError(t, err) + storeDir := t.TempDir() targetDir := t.TempDir() - store := cas.NewStore(storeDir).WithFS(osFs) + store := cas.NewStore(storeDir) content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") - err := content.Store(l, testHash, testData) + err = content.Store(l, v, testHash, testData) require.NoError(t, err) targetPath := filepath.Join(targetDir, "test.txt") - err = content.Link(t.Context(), testHash, targetPath, 0o644) + err = content.Link(t.Context(), v, testHash, targetPath, 0o644) require.NoError(t, err) // Verify hard link by comparing inodes @@ -158,20 +166,22 @@ func TestContent_Link(t *testing.T) { t.Run("force copy creates independent inode on real filesystem", func(t *testing.T) { t.Parallel() - osFs := vfs.NewOSFS() + v, err := cas.OSVenv() + require.NoError(t, err) + storeDir := t.TempDir() targetDir := t.TempDir() - store := 
cas.NewStore(storeDir).WithFS(osFs) + store := cas.NewStore(storeDir) content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") - err := content.Store(l, testHash, testData) + err = content.Store(l, v, testHash, testData) require.NoError(t, err) targetPath := filepath.Join(targetDir, "test.txt") - err = content.Link(t.Context(), testHash, targetPath, 0o644, cas.WithLinkForceCopy()) + err = content.Link(t.Context(), v, testHash, targetPath, 0o644, cas.WithLinkForceCopy()) require.NoError(t, err) sourcePath := filepath.Join(storeDir, testHash[:2], testHash) @@ -199,19 +209,21 @@ func TestContent_Link(t *testing.T) { t.Run("default path strips write bit from non-executable", func(t *testing.T) { t.Parallel() - osFs := vfs.NewOSFS() + v, err := cas.OSVenv() + require.NoError(t, err) + storeDir := t.TempDir() targetDir := t.TempDir() - store := cas.NewStore(storeDir).WithFS(osFs) + store := cas.NewStore(storeDir) content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") - require.NoError(t, content.Store(l, testHash, testData)) + require.NoError(t, content.Store(l, v, testHash, testData)) targetPath := filepath.Join(targetDir, "test.txt") - require.NoError(t, content.Link(t.Context(), testHash, targetPath, 0o644)) + require.NoError(t, content.Link(t.Context(), v, testHash, targetPath, 0o644)) info, err := os.Stat(targetPath) require.NoError(t, err) @@ -222,16 +234,18 @@ func TestContent_Link(t *testing.T) { t.Run("default path hardlinks executable when store carries matching perms", func(t *testing.T) { t.Parallel() - osFs := vfs.NewOSFS() + v, err := cas.OSVenv() + require.NoError(t, err) + storeDir := t.TempDir() targetDir := t.TempDir() - store := cas.NewStore(storeDir).WithFS(osFs) + store := cas.NewStore(storeDir) content := cas.NewContent(store) testHash := testHashValue testData := []byte("#!/bin/sh\necho hi\n") - require.NoError(t, content.Store(l, testHash, testData)) + require.NoError(t, 
content.Store(l, v, testHash, testData)) // Mirror the store-side chmod that the git-clone path applies: stored // blobs carry their original git mode with write bits cleared, so @@ -240,7 +254,7 @@ func TestContent_Link(t *testing.T) { require.NoError(t, os.Chmod(sourcePath, 0o555)) targetPath := filepath.Join(targetDir, "run.sh") - require.NoError(t, content.Link(t.Context(), testHash, targetPath, 0o755)) + require.NoError(t, content.Link(t.Context(), v, testHash, targetPath, 0o755)) info, err := os.Stat(targetPath) require.NoError(t, err) @@ -256,23 +270,25 @@ func TestContent_Link(t *testing.T) { t.Run("default path falls back to copy on perm collision", func(t *testing.T) { t.Parallel() - osFs := vfs.NewOSFS() + v, err := cas.OSVenv() + require.NoError(t, err) + storeDir := t.TempDir() targetDir := t.TempDir() - store := cas.NewStore(storeDir).WithFS(osFs) + store := cas.NewStore(storeDir) content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") - require.NoError(t, content.Store(l, testHash, testData)) + require.NoError(t, content.Store(l, v, testHash, testData)) // The blob landed in the store at 0o444 (treated as non-exec). A second // tree referencing the same content under mode 100755 wants 0o555. // Link must produce a fresh inode at 0o555 rather than hardlinking // the 0o444 blob. 
targetPath := filepath.Join(targetDir, "run.sh") - require.NoError(t, content.Link(t.Context(), testHash, targetPath, 0o755)) + require.NoError(t, content.Link(t.Context(), v, testHash, targetPath, 0o755)) info, err := os.Stat(targetPath) require.NoError(t, err) @@ -288,19 +304,21 @@ func TestContent_Link(t *testing.T) { t.Run("force copy preserves executable bits", func(t *testing.T) { t.Parallel() - osFs := vfs.NewOSFS() + v, err := cas.OSVenv() + require.NoError(t, err) + storeDir := t.TempDir() targetDir := t.TempDir() - store := cas.NewStore(storeDir).WithFS(osFs) + store := cas.NewStore(storeDir) content := cas.NewContent(store) testHash := testHashValue testData := []byte("#!/bin/sh\necho hi\n") - require.NoError(t, content.Store(l, testHash, testData)) + require.NoError(t, content.Store(l, v, testHash, testData)) targetPath := filepath.Join(targetDir, "run.sh") - require.NoError(t, content.Link(t.Context(), testHash, targetPath, 0o755, cas.WithLinkForceCopy())) + require.NoError(t, content.Link(t.Context(), v, testHash, targetPath, 0o755, cas.WithLinkForceCopy())) info, err := os.Stat(targetPath) require.NoError(t, err) @@ -308,35 +326,37 @@ func TestContent_Link(t *testing.T) { "force copy must reproduce git mode exactly (0o755)") }) - t.Run("link to existing file", func(t *testing.T) { + t.Run("link to existing file overwrites stale content", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - require.NoError(t, memFs.MkdirAll("/target", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + require.NoError(t, v.FS.MkdirAll("/target", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") // Store content - err := content.Store(l, testHash, testData) + err := content.Store(l, v, testHash, testData) require.NoError(t, err) - // Create 
target file + // Pre-populate the target with stale bytes. A previous failed + // run could leave the working tree in this state; Link must + // publish the CAS content rather than silently keep the stale + // file. targetPath := filepath.Join("/target", "test.txt") - err = vfs.WriteFile(memFs, targetPath, []byte("existing content"), 0644) + err = vfs.WriteFile(v.FS, targetPath, []byte("existing content"), 0644) require.NoError(t, err) - // Try to create link - err = content.Link(t.Context(), testHash, targetPath, 0o644) + err = content.Link(t.Context(), v, testHash, targetPath, 0o644) require.NoError(t, err) - // Verify original content remains - existingData, err := vfs.ReadFile(memFs, targetPath) + got, err := vfs.ReadFile(v.FS, targetPath) require.NoError(t, err) - assert.Equal(t, []byte("existing content"), existingData) + assert.Equal(t, testData, got) }) } @@ -348,26 +368,27 @@ func TestContent_EnsureWithWait(t *testing.T) { t.Run("content already exists", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") // Store content first - err := content.Store(l, testHash, testData) + err := content.Store(l, v, testHash, testData) require.NoError(t, err) // EnsureWithWait should not need to write again - err = content.EnsureWithWait(l, testHash, []byte("different content")) + err = content.EnsureWithWait(l, v, testHash, []byte("different content")) require.NoError(t, err) // Verify original content remains partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := vfs.ReadFile(memFs, storedPath) + storedData, err := vfs.ReadFile(v.FS, storedPath) require.NoError(t, err) assert.Equal(t, 
testData, storedData) }) @@ -375,22 +396,23 @@ func TestContent_EnsureWithWait(t *testing.T) { t.Run("content doesn't exist", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) testHash := "newcontent123456" testData := []byte("new test content") // EnsureWithWait should store the content - err := content.EnsureWithWait(l, testHash, testData) + err := content.EnsureWithWait(l, v, testHash, testData) require.NoError(t, err) // Verify content was stored partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := vfs.ReadFile(memFs, storedPath) + storedData, err := vfs.ReadFile(v.FS, storedPath) require.NoError(t, err) assert.Equal(t, testData, storedData) }) @@ -398,9 +420,10 @@ func TestContent_EnsureWithWait(t *testing.T) { t.Run("concurrent writes - optimization", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) testHash := "concurrent123456" @@ -414,7 +437,7 @@ func TestContent_EnsureWithWait(t *testing.T) { go func() { defer close(process1Done) - err := content.EnsureWithWait(l, testHash, []byte("process 1 data")) + err := content.EnsureWithWait(l, v, testHash, []byte("process 1 data")) assert.NoError(t, err) close(process1Started) @@ -427,7 +450,7 @@ func TestContent_EnsureWithWait(t *testing.T) { // Wait for process 1 to start <-process1Started - err := content.EnsureWithWait(l, testHash, []byte("process 2 data")) + err := content.EnsureWithWait(l, v, testHash, []byte("process 2 data")) 
assert.NoError(t, err) }() @@ -438,8 +461,17 @@ func TestContent_EnsureWithWait(t *testing.T) { // Verify only one content exists (from process 1) partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := vfs.ReadFile(memFs, storedPath) + storedData, err := vfs.ReadFile(v.FS, storedPath) require.NoError(t, err) assert.Equal(t, []byte("process 1 data"), storedData) }) } + +func newMemVenv(t *testing.T) cas.Venv { + t.Helper() + + runner, err := git.NewGitRunner(vexec.NewOSExec()) + require.NoError(t, err) + + return cas.Venv{FS: vfs.NewMemMapFS(), Git: runner} +} diff --git a/internal/cas/errors.go b/internal/cas/errors.go index 1f5e8bd619..cebd15b0cc 100644 --- a/internal/cas/errors.go +++ b/internal/cas/errors.go @@ -64,14 +64,15 @@ func (e *WrappedError) Unwrap() error { // Git operation errors var ( - ErrCommandSpawn = errors.New("failed to spawn git command") - ErrNoMatchingReference = errors.New("no matching reference") - ErrReadTree = errors.New("failed to read tree") - ErrNoWorkDir = errors.New("working directory not set") - ErrGitStorePath = errors.New("failed to prepare git store path") - ErrGitStoreLock = errors.New("failed to acquire git store lock") - ErrGitStoreFSNotOS = errors.New("git store requires an OS-backed filesystem") - ErrFallbackCloneDir = errors.New("failed to create fallback clone directory") + ErrCommandSpawn = errors.New("failed to spawn git command") + ErrNoMatchingReference = errors.New("no matching reference") + ErrReadTree = errors.New("failed to read tree") + ErrNoWorkDir = errors.New("working directory not set") + ErrGitStorePath = errors.New("failed to prepare git store path") + ErrGitStoreLock = errors.New("failed to acquire git store lock") + ErrGitStoreFSNotOS = errors.New("git store requires an OS-backed filesystem") + ErrFallbackCloneDir = errors.New("failed to create fallback clone directory") + ErrFetchClosureRequired = errors.New("fetch closure is 
required") ) // UpdateSourceWithCASRequiresCASError is returned when a block sets diff --git a/internal/cas/getter_ssh_test.go b/internal/cas/getter_ssh_test.go index 6f6a761883..fc98e64a91 100644 --- a/internal/cas/getter_ssh_test.go +++ b/internal/cas/getter_ssh_test.go @@ -51,11 +51,14 @@ func TestSSHCASGetterGet(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + opts := &cas.CloneOptions{ Branch: "main", } l := logger.CreateLogger() - g := getter.NewCASGetter(l, c, opts) + g := getter.NewCASGetter(l, c, v, opts) client := getter.Client{ Getters: []getter.Getter{g}, } diff --git a/internal/cas/getter_test.go b/internal/cas/getter_test.go index daa80f536f..bac62730c1 100644 --- a/internal/cas/getter_test.go +++ b/internal/cas/getter_test.go @@ -8,6 +8,7 @@ import ( "github.com/gruntwork-io/terragrunt/internal/cas" "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/test/helpers" "github.com/gruntwork-io/terragrunt/test/helpers/logger" "github.com/stretchr/testify/assert" @@ -44,8 +45,7 @@ func TestCASGetterDetect(t *testing.T) { tmp := helpers.TmpDirWOSymlinks(t) - os.MkdirAll(filepath.Join(tmp, "fake-module"), 0755) - os.WriteFile(filepath.Join(tmp, "fake-module", "main.tf"), []byte(""), 0644) + require.NoError(t, vfs.WriteFile(g.Venv.FS, filepath.Join(tmp, "fake-module", "main.tf"), []byte(""), 0644)) tests := []struct { expectedErr error @@ -107,13 +107,16 @@ func TestCASGetterGet(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + opts := &cas.CloneOptions{ Depth: -1, } l := logger.CreateLogger() - g := getter.NewCASGetter(l, c, opts) + g := getter.NewCASGetter(l, c, v, opts) client := getter.Client{ Getters: []getter.Getter{g}, } @@ -157,22 +160,22 @@ func TestCASGetterLocalDir(t *testing.T) { 
c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + opts := &cas.CloneOptions{ Branch: "main", } l := logger.CreateLogger() - g := getter.NewCASGetter(l, c, opts) + g := getter.NewCASGetter(l, c, v, opts) fakeModule := filepath.Join(tmp, "fake-module") - os.MkdirAll(fakeModule, 0755) - fakeModuleSubdir := filepath.Join(fakeModule, "subdir") - os.MkdirAll(fakeModuleSubdir, 0755) - os.WriteFile(filepath.Join(fakeModule, "main.tf"), []byte(""), 0644) - os.WriteFile(filepath.Join(fakeModuleSubdir, "subfile.tf"), []byte(""), 0644) + require.NoError(t, vfs.WriteFile(v.FS, filepath.Join(fakeModule, "main.tf"), []byte(""), 0644)) + require.NoError(t, vfs.WriteFile(v.FS, filepath.Join(fakeModuleSubdir, "subfile.tf"), []byte(""), 0644)) fakeDest := filepath.Join(tmp, "fake-dest") @@ -211,5 +214,8 @@ func newTestCASGetter(t *testing.T, opts *cas.CloneOptions) *getter.CASGetter { c, err := cas.New(cas.WithStorePath(filepath.Join(helpers.TmpDirWOSymlinks(t), "store"))) require.NoError(t, err) - return getter.NewCASGetter(logger.CreateLogger(), c, opts) + v, err := cas.OSVenv() + require.NoError(t, err) + + return getter.NewCASGetter(logger.CreateLogger(), c, v, opts) } diff --git a/internal/cas/gitstore.go b/internal/cas/gitstore.go index f1aa607ad0..2b5da60924 100644 --- a/internal/cas/gitstore.go +++ b/internal/cas/gitstore.go @@ -25,30 +25,24 @@ const ( ) // GitStore keeps one bare git repository per remote URL on disk so CAS cache -// misses can issue an incremental git fetch instead of a full shallow clone. +// misses can issue an incremental fetch instead of a fresh shallow clone. +// // Each per-URL repository is gated by an exclusive flock because pack-file -// writes are not safe to interleave with concurrent reads of the same repo. 
-// EnsureRef waits up to gitStoreLockTimeout for the lock; on context -// cancellation or timeout the caller can fall back to a temporary clone -// rather than block indefinitely on a hung holder. After acquiring the -// lock, EnsureRef re-checks for the requested object so a unit that simply -// waited out a peer's fetch can proceed without re-doing the work. -// The flock is held from EnsureRef return until the caller releases it. +// writes are not safe to interleave with concurrent reads. EnsureRef waits up +// to [gitStoreLockTimeout] before giving up so the caller can fall back to a +// temporary clone rather than block indefinitely on a hung holder. type GitStore struct { - runner *git.GitRunner rootPath string } -// GitStoreRepo is a locked handle to a per-URL bare repository. The -// caller has exclusive access to the underlying repo until [GitStoreRepo.Unlock] -// is called; failing to release the lock blocks subsequent fetches -// against the same URL. +// GitStoreRepo is a locked handle to a per-URL bare repository. The caller +// has exclusive access until [GitStoreRepo.Unlock] returns; failing to +// release the lock blocks every subsequent fetch against the same URL. type GitStoreRepo struct { unlocker vfs.Unlocker - // url records the URL acquire was called with so Release can - // include it in unlock-failure log messages without callers - // re-threading it. + // url is the source URL, kept so Release can name it in + // unlock-failure logs without callers re-threading it. url string // Path is the bare repository path, suitable for @@ -61,61 +55,39 @@ type GitStoreRepo struct { Hash string } -// Unlock releases the per-URL flock held by this repo handle and -// returns any unlock error to the caller. +// Unlock releases the per-URL flock and returns any unlock error. func (r *GitStoreRepo) Unlock() error { return r.unlocker.Unlock() } -// Release unlocks the repo handle, logging unlock failures against the -// originating URL. 
Intended for `defer repo.Release(l)`; callers that -// need the unlock error directly should call [GitStoreRepo.Unlock] instead. +// Release unlocks and logs any unlock error against the originating URL. +// Intended for `defer repo.Release(l)`; callers that need the error +// directly should use [GitStoreRepo.Unlock]. func (r *GitStoreRepo) Release(l log.Logger) { if err := r.unlocker.Unlock(); err != nil { l.Warnf("git store: failed to release lock for %s: %v", r.url, err) } } -// NewGitStore returns a [GitStore] rooted at rootPath, creating the directory -// on fs if needed. The filesystem is not retained; callers pass one explicitly -// to [GitStore.EnsureRef] and [GitStore.EnsureCommit]. -// -// The git store shells out to `git`, which only sees the real disk. Callers -// must pass an OS-backed [vfs.FS] from [vfs.NewOSFS]; an in-memory backing -// returns [ErrGitStoreFSNotOS]. -func NewGitStore(fs vfs.FS, runner *git.GitRunner, rootPath string) (*GitStore, error) { - if !vfs.IsOSFS(fs) { - return nil, ErrGitStoreFSNotOS - } - - if err := fs.MkdirAll(rootPath, DefaultDirPerms); err != nil { - return nil, fmt.Errorf("create git store at %s: %w", rootPath, errors.Join(ErrGitStorePath, err)) - } - - return &GitStore{ - runner: runner, - rootPath: rootPath, - }, nil +// NewGitStore returns a [GitStore] rooted at rootPath. The directory is +// created lazily on first write. +func NewGitStore(rootPath string) *GitStore { + return &GitStore{rootPath: rootPath} } // EnsureRef ensures the bare repository for url contains the object at -// hash, fetching ref at the requested depth if it does not. -// -// On success the returned repo's Path is suitable for -// [git.GitRunner.WithWorkDir], and the caller must release the embedded -// flock with [GitStoreRepo.Unlock] (or [GitStoreRepo.Release]) once done -// reading objects. -// -// On failure the lock is released before returning so callers can take -// a different code path without managing the lock themselves. 
+// hash, fetching ref at the requested depth on a cache miss. The returned +// handle holds the per-URL flock; the caller must release it via +// [GitStoreRepo.Unlock] or [GitStoreRepo.Release]. On error the lock is +// released before returning. func (s *GitStore) EnsureRef( ctx context.Context, l log.Logger, - fs vfs.FS, + v Venv, url, ref, hash string, depth int, ) (*GitStoreRepo, error) { - session, err := s.acquire(ctx, fs, l, url) + session, err := s.acquire(ctx, v, l, url) if err != nil { return nil, err } @@ -152,42 +124,25 @@ func (s *GitStore) EnsureRef( // EnsureCommit ensures the bare repository for url contains a commit // reachable from rawRef and returns its canonical full hash via -// [GitStoreRepo.Hash]. -// -// rawRef may be a full SHA-1 (40 hex chars), full SHA-256 (64 hex -// chars), or an abbreviated SHA that disambiguates inside the repo. -// Resolution runs `git rev-parse ^{commit}` against the per-URL -// bare repository, so any form git accepts works. -// -// If knownHash is non-empty, it is taken as the canonical hash of -// rawRef (typically from a prior [GitStore.ProbeCachedCommit] call). -// Presence is verified via [git.GitRunner.HasObject] and rev-parse -// is skipped on the cache-hit path. Pass "" to canonicalize via -// rev-parse. -// -// Behavior: +// [GitStoreRepo.Hash]. Any rawRef `git rev-parse` accepts works. // -// 1. If the commit is already cached in the bare repo, no network -// call is made. -// 2. Otherwise the bare repo is updated with a full-history fetch of -// every ref (no `--depth`). Tags are included so commits reachable -// only via tags resolve without a second fetch. Fetching by raw SHA -// is avoided because it requires `uploadpack.allowAnySHA1InWant`, -// which is not universally enabled on git servers. -// 3. 
If rev-parse still cannot resolve rawRef after the fetch, a -// [git.WrappedError] wrapping [git.ErrNoMatchingReference] is -// returned so callers can use [errors.Is] for the same condition -// `git ls-remote` surfaces for symbolic refs. +// If knownHash is non-empty (typically from [GitStore.ProbeCachedCommit]) +// the cache-hit path verifies it with [git.GitRunner.HasObject] and skips +// rev-parse. On a cache miss the bare repo fetches every branch with no +// --depth; fetching by raw SHA would require `uploadpack.allowAnySHA1InWant`, +// which is not universally enabled on git servers. An unresolvable rawRef +// after the fetch surfaces as [git.WrappedError] wrapping +// [git.ErrNoMatchingReference] so callers can match it with [errors.Is]. // -// On success the caller must release the lock via [GitStoreRepo.Unlock] -// or [GitStoreRepo.Release], matching the contract of [GitStore.EnsureRef]. +// Lock contract matches [GitStore.EnsureRef]: callers must release on +// success, the lock is released for them on error. func (s *GitStore) EnsureCommit( ctx context.Context, l log.Logger, - fs vfs.FS, + v Venv, url, rawRef, knownHash string, ) (*GitStoreRepo, error) { - session, err := s.acquire(ctx, fs, l, url) + session, err := s.acquire(ctx, v, l, url) if err != nil { return nil, err } @@ -230,12 +185,10 @@ func (s *GitStore) EnsureCommit( return session.keep(), nil } -// ensureKnownCommit handles the [GitStore.EnsureCommit] cache-hit path -// when the caller has already canonicalized rawRef via -// [GitStore.ProbeCachedCommit]. Presence of knownHash is verified -// with [git.GitRunner.HasObject], skipping the rev-parse spawn. On -// miss (e.g. a peer ran git-gc between the lock-free probe and the -// locked verify) a full-history fetch runs and presence is rechecked. +// ensureKnownCommit handles the [GitStore.EnsureCommit] path where the +// caller has already canonicalized rawRef. 
A locked miss (a peer ran +// git-gc between the lock-free probe and this verify) triggers a +// full-history fetch and a recheck. func (s *GitStore) ensureKnownCommit( ctx context.Context, session *repoSession, @@ -273,34 +226,30 @@ func (s *GitStore) ensureKnownCommit( return session.keep(), nil } -// ProbeCachedCommit returns the canonical commit hash if rawRef -// resolves to a commit already stored in the per-URL bare repository -// for url and rawRef is a prefix of that hash. Returns ok=false in -// any other case (no bare repo yet, unresolvable ref, or a name that -// resolved through ref lookup such as a hex-named branch whose tip -// is a different commit). +// ProbeCachedCommit returns the canonical commit hash when rawRef is a +// prefix of a commit already stored in the per-URL bare repository, and +// ok=false otherwise (no bare repo, unresolvable ref, or a name that +// happened to resolve through ref lookup, such as a hex-named branch). // -// Panics if fs is not OS-backed. git only sees the real disk, so a -// non-OS backing cannot satisfy the probe. +// The probe is lock-free: rev-parse only reads pack indices and refs, +// both updated atomically by git. Acquiring the per-URL flock here would +// queue every probe behind any in-flight fetch and erase the offline +// win. // -// The probe is lock-free on purpose. The per-URL flock serializes -// fetches so concurrent pack-file writes do not interleave, but -// rev-parse only reads pack indices and refs, both of which git -// updates atomically. Acquiring the flock for a read would queue the -// probe behind any in-flight fetch and erase the offline win. -func (s *GitStore) ProbeCachedCommit(ctx context.Context, fs vfs.FS, url, rawRef string) (string, bool) { - if !vfs.IsOSFS(fs) { +// Panics when v.FS is not OS-backed; git only sees the real disk. 
+func (s *GitStore) ProbeCachedCommit(ctx context.Context, v Venv, url, rawRef string) (string, bool) { + if !vfs.IsOSFS(v.FS) { panic(ErrGitStoreFSNotOS) } _, repoPath, _ := s.repoPaths(url) - initialized, err := bareRepoInitialized(fs, repoPath) + initialized, err := bareRepoInitialized(v.FS, repoPath) if err != nil || !initialized { return "", false } - hash, err := s.runner.WithWorkDir(repoPath).RevParseCommit(ctx, rawRef) + hash, err := v.Git.WithWorkDir(repoPath).RevParseCommit(ctx, rawRef) if err != nil { return "", false } @@ -346,11 +295,10 @@ func bareRepoInitialized(fs vfs.FS, repoPath string) (bool, error) { return vfs.FileExists(fs, filepath.Join(repoPath, "HEAD")) } -// repoSession bundles everything a [GitStore.acquire] caller needs -// to operate on a per-URL bare repository: the [GitStoreRepo] handle -// (the locked thing), a runner pointed at it, and a deferred-cleanup -// helper. Callers defer cleanup(); keep() promotes the handle so the -// lock survives until the caller releases it explicitly. +// repoSession bundles the locked repo handle, the runner pointed at it, +// and a deferred-cleanup helper. Callers `defer session.cleanup()` to +// release the lock on error and call `session.keep()` to promote the +// handle on success. type repoSession struct { l log.Logger repo *GitStoreRepo @@ -358,16 +306,14 @@ type repoSession struct { kept bool } -// keep promotes the session's repo handle to the caller. After keep -// the deferred cleanup is a no-op and the caller owns the lock until -// it invokes [GitStoreRepo.Unlock] or [GitStoreRepo.Release]. +// keep promotes the handle so cleanup is a no-op and the caller owns +// the lock until [GitStoreRepo.Unlock] or [GitStoreRepo.Release]. func (s *repoSession) keep() *GitStoreRepo { s.kept = true return s.repo } -// cleanup releases the lock unless keep was called. Intended for -// `defer session.cleanup()`. +// cleanup releases the lock unless keep was called. 
func (s *repoSession) cleanup() { if s.kept { return @@ -376,28 +322,28 @@ func (s *repoSession) cleanup() { s.repo.Release(s.l) } -// acquire claims the per-URL flock for url, prepares the bare-repo -// directory, and returns a [repoSession] carrying the locked handle -// and a runner pointed at it. The caller defers session.cleanup(); -// session.keep() promotes the handle so the lock survives. -// -// `git init --bare` is invoked only on first use of a store entry; -// subsequent calls detect the existing HEAD and skip the spawn. -func (s *GitStore) acquire(ctx context.Context, fs vfs.FS, l log.Logger, url string) (*repoSession, error) { - if !vfs.IsOSFS(fs) { +// acquire claims the per-URL flock and returns a [repoSession] carrying +// the locked handle. `git init --bare` runs only on first use of a store +// entry; subsequent calls detect HEAD and skip the spawn. +func (s *GitStore) acquire(ctx context.Context, v Venv, l log.Logger, url string) (*repoSession, error) { + if !vfs.IsOSFS(v.FS) { return nil, ErrGitStoreFSNotOS } + if err := v.FS.MkdirAll(s.rootPath, DefaultDirPerms); err != nil { + return nil, fmt.Errorf("create git store at %s: %w", s.rootPath, errors.Join(ErrGitStorePath, err)) + } + dir, repoPath, lockPath := s.repoPaths(url) - if err := fs.MkdirAll(dir, DefaultDirPerms); err != nil { + if err := v.FS.MkdirAll(dir, DefaultDirPerms); err != nil { return nil, fmt.Errorf("create git store entry %s: %w", dir, errors.Join(ErrGitStorePath, err)) } lockCtx, cancel := context.WithTimeout(ctx, gitStoreLockTimeout) defer cancel() - unlocker, err := vfs.LockContext(lockCtx, fs, lockPath) + unlocker, err := vfs.LockContext(lockCtx, v.FS, lockPath) if err != nil { return nil, fmt.Errorf("lock git store for %s: %w", url, errors.Join(ErrGitStoreLock, err)) } @@ -405,15 +351,15 @@ func (s *GitStore) acquire(ctx context.Context, fs vfs.FS, l log.Logger, url str session := &repoSession{ l: l, repo: &GitStoreRepo{unlocker: unlocker, url: url, Path: repoPath}, - 
runner: s.runner.WithWorkDir(repoPath), + runner: v.Git.WithWorkDir(repoPath), } - if err := fs.MkdirAll(repoPath, DefaultDirPerms); err != nil { + if err := v.FS.MkdirAll(repoPath, DefaultDirPerms); err != nil { session.cleanup() return nil, fmt.Errorf("create bare repo dir %s: %w", repoPath, errors.Join(ErrGitStorePath, err)) } - initialized, err := bareRepoInitialized(fs, repoPath) + initialized, err := bareRepoInitialized(v.FS, repoPath) if err != nil { session.cleanup() return nil, fmt.Errorf("inspect bare repo %s: %w", repoPath, errors.Join(ErrGitStorePath, err)) diff --git a/internal/cas/gitstore_test.go b/internal/cas/gitstore_test.go index 53795e8f18..a40946f36b 100644 --- a/internal/cas/gitstore_test.go +++ b/internal/cas/gitstore_test.go @@ -25,22 +25,22 @@ func TestGitStoreEnsureRef_InitsAndFetches(t *testing.T) { url := startTestServer(t) hash := resolveHead(t, url) - store, fs, root := newTestGitStore(t) + store, v, root := newTestGitStore(t) l := logger.CreateLogger() ctx := t.Context() - repo, err := store.EnsureRef(ctx, l, fs, url, "main", hash, 0) + repo, err := store.EnsureRef(ctx, l, v, url, "main", hash, 0) require.NoError(t, err) assert.True(t, strings.HasPrefix(repo.Path, root), "repo path %q should be under store root %q", repo.Path, root) - _, err = fs.Stat(filepath.Join(repo.Path, "HEAD")) + _, err = v.FS.Stat(filepath.Join(repo.Path, "HEAD")) require.NoError(t, err) require.NoError(t, repo.Unlock()) // Second call hits the cache-warm path: object already present, no fetch. 
- repo2, err := store.EnsureRef(ctx, l, fs, url, "main", hash, 0) + repo2, err := store.EnsureRef(ctx, l, v, url, "main", hash, 0) require.NoError(t, err) require.NoError(t, repo2.Unlock()) } @@ -51,7 +51,7 @@ func TestGitStoreEnsureRef_PartitionsByURL(t *testing.T) { url1 := startTestServer(t) url2 := startTestServer(t) - store, fs, root := newTestGitStore(t) + store, v, root := newTestGitStore(t) require.NotEmpty(t, root) l := logger.CreateLogger() @@ -60,11 +60,11 @@ func TestGitStoreEnsureRef_PartitionsByURL(t *testing.T) { hash1 := resolveHead(t, url1) hash2 := resolveHead(t, url2) - e1, err := store.EnsureRef(ctx, l, fs, url1, "main", hash1, 0) + e1, err := store.EnsureRef(ctx, l, v, url1, "main", hash1, 0) require.NoError(t, err) require.NoError(t, e1.Unlock()) - e2, err := store.EnsureRef(ctx, l, fs, url2, "main", hash2, 0) + e2, err := store.EnsureRef(ctx, l, v, url2, "main", hash2, 0) require.NoError(t, err) require.NoError(t, e2.Unlock()) @@ -77,7 +77,7 @@ func TestGitStoreEnsureRefConcurrentSameURLWithRacing(t *testing.T) { url := startTestServer(t) hash := resolveHead(t, url) - store, fs, root := newTestGitStore(t) + store, v, root := newTestGitStore(t) require.NotEmpty(t, root) l := logger.CreateLogger() @@ -94,7 +94,7 @@ func TestGitStoreEnsureRefConcurrentSameURLWithRacing(t *testing.T) { go func(idx int) { defer wg.Done() - repo, err := store.EnsureRef(t.Context(), l, fs, url, "main", hash, 0) + repo, err := store.EnsureRef(t.Context(), l, v, url, "main", hash, 0) if err != nil { errs[idx] = err return @@ -117,13 +117,13 @@ func TestGitStoreEnsureRef_LockHeldRespectsContextCancellation(t *testing.T) { url := startTestServer(t) hash := resolveHead(t, url) - store, fs, root := newTestGitStore(t) + store, v, root := newTestGitStore(t) require.NotEmpty(t, root) l := logger.CreateLogger() // First caller takes the per-URL lock and holds it. 
- repo, err := store.EnsureRef(t.Context(), l, fs, url, "main", hash, 0) + repo, err := store.EnsureRef(t.Context(), l, v, url, "main", hash, 0) require.NoError(t, err) require.NotEmpty(t, repo.Path) t.Cleanup(func() { _ = repo.Unlock() }) @@ -135,7 +135,7 @@ func TestGitStoreEnsureRef_LockHeldRespectsContextCancellation(t *testing.T) { start := time.Now() - _, err = store.EnsureRef(ctx, l, fs, url, "main", hash, 0) + _, err = store.EnsureRef(ctx, l, v, url, "main", hash, 0) require.Error(t, err) assert.Less(t, time.Since(start), 5*time.Second, "EnsureRef should not block past the context deadline") assert.True( @@ -151,12 +151,12 @@ func TestGitStoreEnsureRefLockReleaseAllowsWaiterToProceedWithRacing(t *testing. url := startTestServer(t) hash := resolveHead(t, url) - store, fs, root := newTestGitStore(t) + store, v, root := newTestGitStore(t) require.NotEmpty(t, root) l := logger.CreateLogger() - repo, err := store.EnsureRef(t.Context(), l, fs, url, "main", hash, 0) + repo, err := store.EnsureRef(t.Context(), l, v, url, "main", hash, 0) require.NoError(t, err) // Release the holder after a short delay so the waiter sees the lock open. @@ -169,7 +169,7 @@ func TestGitStoreEnsureRefLockReleaseAllowsWaiterToProceedWithRacing(t *testing. ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) defer cancel() - repo2, err := store.EnsureRef(ctx, l, fs, url, "main", hash, 0) + repo2, err := store.EnsureRef(ctx, l, v, url, "main", hash, 0) require.NoError(t, err) require.NoError(t, repo2.Unlock()) } @@ -177,12 +177,12 @@ func TestGitStoreEnsureRefLockReleaseAllowsWaiterToProceedWithRacing(t *testing. 
func TestGitStoreEnsureRef_FetchFailureSurfacesError(t *testing.T) { t.Parallel() - store, fs, root := newTestGitStore(t) + store, v, root := newTestGitStore(t) require.NotEmpty(t, root) l := logger.CreateLogger() - _, err := store.EnsureRef(t.Context(), l, fs, "file:///does/not/exist", "main", "deadbeef", 0) + _, err := store.EnsureRef(t.Context(), l, v, "file:///does/not/exist", "main", "deadbeef", 0) require.Error(t, err) } @@ -194,32 +194,32 @@ func TestGitStoreRejectsNonOSFilesystem(t *testing.T) { root := filepath.Join(helpers.TmpDirWOSymlinks(t), "gitstore") - _, err = cas.NewGitStore(vfs.NewMemMapFS(), runner, root) - require.ErrorIs(t, err, cas.ErrGitStoreFSNotOS) + store := cas.NewGitStore(root) - store, err := cas.NewGitStore(vfs.NewOSFS(), runner, root) - require.NoError(t, err) + memVenv := cas.Venv{FS: vfs.NewMemMapFS(), Git: runner} _, err = store.EnsureRef( - t.Context(), logger.CreateLogger(), vfs.NewMemMapFS(), + t.Context(), logger.CreateLogger(), memVenv, "file:///does/not/exist", "main", "deadbeef", 0, ) require.ErrorIs(t, err, cas.ErrGitStoreFSNotOS) + + require.PanicsWithValue(t, cas.ErrGitStoreFSNotOS, func() { + store.ProbeCachedCommit(t.Context(), memVenv, "file:///does/not/exist", "deadbeef") + }) } -func newTestGitStore(t *testing.T) (*cas.GitStore, vfs.FS, string) { +func newTestGitStore(t *testing.T) (*cas.GitStore, cas.Venv, string) { t.Helper() root := filepath.Join(helpers.TmpDirWOSymlinks(t), "gitstore") - runner, err := git.NewGitRunner(vexec.NewOSExec()) - require.NoError(t, err) - - fs := vfs.NewOSFS() - store, err := cas.NewGitStore(fs, runner, root) + v, err := cas.OSVenv() require.NoError(t, err) - return store, fs, root + store := cas.NewGitStore(root) + + return store, v, root } func resolveHead(t *testing.T, url string) string { diff --git a/internal/cas/integration_test.go b/internal/cas/integration_test.go index 4704aa35c0..3d79b7cef7 100644 --- a/internal/cas/integration_test.go +++ b/internal/cas/integration_test.go @@ 
-20,6 +20,9 @@ func TestIntegration_CloneAndReuse(t *testing.T) { l := logger.CreateLogger() repoURL := startTestServer(t) + v, err := cas.OSVenv() + require.NoError(t, err) + t.Run("clone same repo twice uses store", func(t *testing.T) { t.Parallel() tempDir := helpers.TmpDirWOSymlinks(t) @@ -29,10 +32,8 @@ func TestIntegration_CloneAndReuse(t *testing.T) { firstClonePath := filepath.Join(tempDir, "first") cas1, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - require.NoError(t, cas1.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: firstClonePath, - Depth: -1, - }, repoURL)) + require.NoError(t, cas1.Clone(t.Context(), l, v, repoURL, cas.WithDir(firstClonePath), + cas.WithDepth(-1))) // Get info about first clone firstReadme := filepath.Join(firstClonePath, "README.md") @@ -43,10 +44,8 @@ func TestIntegration_CloneAndReuse(t *testing.T) { secondClonePath := filepath.Join(tempDir, "second") cas2, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - require.NoError(t, cas2.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: secondClonePath, - Depth: -1, - }, repoURL)) + require.NoError(t, cas2.Clone(t.Context(), l, v, repoURL, cas.WithDir(secondClonePath), + cas.WithDepth(-1))) // Get info about second clone secondReadme := filepath.Join(secondClonePath, "README.md") @@ -68,11 +67,9 @@ func TestIntegration_CloneAndReuse(t *testing.T) { c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) require.NoError(t, err) - err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "repo"), - Branch: "nonexistent-branch", - Depth: -1, - }, repoURL) + err = c.Clone(t.Context(), l, v, repoURL, cas.WithDir(filepath.Join(tempDir, "repo")), + cas.WithBranch("nonexistent-branch"), + cas.WithDepth(-1)) require.Error(t, err) var wrappedErr *git.WrappedError @@ -87,10 +84,9 @@ func TestIntegration_CloneAndReuse(t *testing.T) { c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) require.NoError(t, 
err) - err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "repo"), - Depth: -1, - }, "http://127.0.0.1:1/nonexistent-repo.git") + err = c.Clone(t.Context(), l, v, "http://127.0.0.1:1/nonexistent-repo.git", + cas.WithDir(filepath.Join(tempDir, "repo")), + cas.WithDepth(-1)) require.Error(t, err) }) } @@ -102,6 +98,9 @@ func TestIntegration_TreeStorage(t *testing.T) { l := logger.CreateLogger() repoURL := startTestServer(t) + v, err := cas.OSVenv() + require.NoError(t, err) + t.Run("stores tree objects", func(t *testing.T) { t.Parallel() tempDir := helpers.TmpDirWOSymlinks(t) @@ -110,10 +109,8 @@ func TestIntegration_TreeStorage(t *testing.T) { // First clone to populate store c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - require.NoError(t, c.Clone(ctx, l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "repo"), - Depth: -1, - }, repoURL)) + require.NoError(t, c.Clone(ctx, l, v, repoURL, cas.WithDir(filepath.Join(tempDir, "repo")), + cas.WithDepth(-1))) // Get the commit hash for HEAD g, err := git.NewGitRunner(vexec.NewOSExec()) @@ -128,11 +125,11 @@ func TestIntegration_TreeStorage(t *testing.T) { treeStore := cas.NewStore(filepath.Join(storePath, "trees")) require.NoError(t, err) - assert.False(t, treeStore.NeedsWrite(commitHash), "Tree object should be stored") + assert.False(t, treeStore.NeedsWrite(v, commitHash), "Tree object should be stored") // Verify we can read the tree content content := cas.NewContent(treeStore) - treeData, err := content.Read(commitHash) + treeData, err := content.Read(v, commitHash) require.NoError(t, err) // Parse the tree data to confirm it's valid diff --git a/internal/cas/local.go b/internal/cas/local.go index 1955a7c58b..70e96ca21d 100644 --- a/internal/cas/local.go +++ b/internal/cas/local.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/fs" + "os" "path/filepath" "strings" @@ -14,6 +15,9 @@ import ( "github.com/gruntwork-io/terragrunt/pkg/log" ) +// gitSymlinkMode is the git 
tree entry mode for a symbolic link. +const gitSymlinkMode = "120000" + // DefaultLocalHashAlgorithm is used for content-addressed hashing of local source // trees. It is chosen independently of any git repository's object format because // local sources have no repository to inherit a format from. @@ -21,18 +25,23 @@ const DefaultLocalHashAlgorithm = HashSHA256 // StoreLocalDirectory persists all content from a local source directory into the CAS // and then links the persisted files to the target directory. +// +// Requires v.FS. v.Git is not used. func (c *CAS) StoreLocalDirectory( ctx context.Context, l log.Logger, + v Venv, sourceDir, targetDir string, opts ...LinkTreeOption, ) error { - hash, treeData, err := c.buildLocalTree(sourceDir, DefaultLocalHashAlgorithm) + v.RequireFS() + + hash, treeData, err := c.buildLocalTree(v, sourceDir, DefaultLocalHashAlgorithm) if err != nil { return fmt.Errorf("failed to hash local directory %s: %w", sourceDir, err) } - if err = c.storeLocalContent(l, sourceDir, hash, treeData, DefaultLocalHashAlgorithm); err != nil { + if err = c.storeFetchedContent(l, v, sourceDir, hash, treeData, DefaultLocalHashAlgorithm); err != nil { return fmt.Errorf("failed to store local content: %w", err) } @@ -41,30 +50,40 @@ func (c *CAS) StoreLocalDirectory( return fmt.Errorf("failed to parse local tree: %w", err) } - return LinkTree(ctx, c.blobStore, c.treeStore, tree, targetDir, opts...) + return LinkTree(ctx, v, c.blobStore, c.treeStore, tree, targetDir, opts...) } // ComputeLocalRootHash walks dir in deterministic (lexical) order and produces a // content-addressed hash over (relpath, mode, file-content-hash) triples. The -// returned hash plays the same role as a git ref hash does in the remote flow — +// returned hash plays the same role as a git ref hash does in the remote flow: // it is the "root" for DeterministicTreeHash calls when rewriting nested sources. 
// The same file-content hashes are used both inside the root-hash and as blob // hashes in the synthetic tree, so blob lookups and tree lookups stay consistent. -func (c *CAS) ComputeLocalRootHash(dir string, alg HashAlgorithm) (string, error) { - hash, _, err := c.buildLocalTree(dir, alg) +// +// Requires v.FS. v.Git is not used. +func (c *CAS) ComputeLocalRootHash(v Venv, dir string, alg HashAlgorithm) (string, error) { + v.RequireFS() + + hash, _, err := c.buildLocalTree(v, dir, alg) + return hash, err } // buildLocalTree walks dir and returns (rootHash, treeData). The treeData has // the same "<mode> blob <hash>\t<path>\n" format as a git tree, but with // file-content hashes taken in the chosen algorithm. -func (c *CAS) buildLocalTree(dir string, alg HashAlgorithm) (string, []byte, error) { +// +// Symlinks are preserved as 120000 entries whose blob hash is the hash of the +// link target string, matching git's symlink representation. Targets that +// escape dir are rejected at ingest time so the CAS cannot store a tree that +// would resolve outside the destination at materialize time. +func (c *CAS) buildLocalTree(v Venv, dir string, alg HashAlgorithm) (string, []byte, error) { var ( treeData []byte rootBuf []byte ) - err := vfs.WalkDir(c.fs, dir, func(path string, d fs.DirEntry, walkErr error) error { + err := vfs.WalkDir(v.FS, dir, func(path string, d fs.DirEntry, walkErr error) error { if walkErr != nil { return walkErr } @@ -78,32 +97,23 @@ func (c *CAS) buildLocalTree(dir string, alg HashAlgorithm) (string, []byte, err return fmt.Errorf("failed to stat file %s: %w", path, err) } - // Skip symlinks and other non-regular entries to keep the synthetic - // tree consistent with copyTree, which only copies regular files. - if !info.Mode().IsRegular() { - return nil - } - - relPath, err := filepath.Rel(dir, path) + relPath, err := localRelPath(dir, path) if err != nil { return err } - // Git-style forward slashes in tree entries, regardless of host OS.
- relPath = strings.ReplaceAll(relPath, string(filepath.Separator), "/") - - fileHash, err := hashFileAlg(c.fs, path, alg) + mode, blobHash, err := hashLocalEntry(v.FS, dir, path, info, alg) if err != nil { - return fmt.Errorf("failed to hash file %s: %w", path, err) + return err } - mode := fmt.Sprintf("%06o", info.Mode().Perm()) - treeData = append(treeData, fmt.Appendf(nil, "%s blob %s\t%s\n", mode, fileHash, relPath)...) + if mode == "" { + // Non-regular, non-symlink (device, fifo, socket); skip. + return nil + } - // Root hash input includes path, mode, and content hash so that two - // trees with identical files at different relative paths (or different - // permissions) get distinct root hashes. - rootBuf = append(rootBuf, fmt.Appendf(nil, "%s %s %s\n", relPath, mode, fileHash)...) + treeData = append(treeData, fmt.Appendf(nil, "%s blob %s\t%s\n", mode, blobHash, relPath)...) + rootBuf = append(rootBuf, fmt.Appendf(nil, "%s %s %s\n", relPath, mode, blobHash)...) return nil }) @@ -114,44 +124,52 @@ func (c *CAS) buildLocalTree(dir string, alg HashAlgorithm) (string, []byte, err return alg.Sum(rootBuf), treeData, nil } -// storeLocalContent stores the tree object and every blob referenced by it. -func (c *CAS) storeLocalContent(l log.Logger, sourceDir, dirHash string, treeData []byte, alg HashAlgorithm) error { - treeContent := NewContent(c.treeStore) - if err := treeContent.Ensure(l, dirHash, treeData); err != nil { - return fmt.Errorf("failed to store tree data: %w", err) +// localRelPath returns the git-style (forward-slash) relative path of path +// inside dir. 
+func localRelPath(dir, path string) (string, error) { + rel, err := filepath.Rel(dir, path) + if err != nil { + return "", err } - blobContent := NewContent(c.blobStore) + return strings.ReplaceAll(rel, string(filepath.Separator), "/"), nil +} - return vfs.WalkDir(c.fs, sourceDir, func(path string, d fs.DirEntry, err error) error { +// hashLocalEntry returns the git-style mode and content hash of a single +// directory entry. For regular files it hashes the bytes through alg; for +// symlinks it hashes the link target after [vfs.ValidateSymlinkTarget] checks +// it stays inside dir. mode is empty for entries that should be skipped +// entirely (devices, FIFOs, sockets). +func hashLocalEntry( + fsys vfs.FS, + dir, path string, + info os.FileInfo, + alg HashAlgorithm, +) (mode, hash string, err error) { + switch { + case info.Mode().IsRegular(): + fileHash, err := hashFileAlg(fsys, path, alg) if err != nil { - return err + return "", "", fmt.Errorf("failed to hash file %s: %w", path, err) } - if d.IsDir() { - return nil - } + return fmt.Sprintf("%06o", info.Mode().Perm()), fileHash, nil - info, err := d.Info() + case info.Mode()&os.ModeSymlink != 0: + target, err := vfs.Readlink(fsys, path) if err != nil { - return fmt.Errorf("failed to stat file %s: %w", path, err) - } - - if !info.Mode().IsRegular() { - return nil + return "", "", fmt.Errorf("read symlink %s: %w", path, err) } - fileHash, err := hashFileAlg(c.fs, path, alg) - if err != nil { - return fmt.Errorf("failed to hash file %s: %w", path, err) + if err := vfs.ValidateSymlinkTarget(dir, path, target); err != nil { + return "", "", err } - if err := blobContent.EnsureCopy(l, fileHash, path); err != nil { - return fmt.Errorf("failed to store file %s: %w", path, err) - } + return gitSymlinkMode, alg.Sum([]byte(target)), nil - return nil - }) + default: + return "", "", nil + } } // hashFileAlg hashes a file's contents using the given algorithm and returns diff --git a/internal/cas/local_test.go 
b/internal/cas/local_test.go index 3458f87b73..67ec18227f 100644 --- a/internal/cas/local_test.go +++ b/internal/cas/local_test.go @@ -8,50 +8,26 @@ import ( "github.com/gruntwork-io/terragrunt/internal/cas" "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) -// writeLocalFixture populates a fresh directory with a small deterministic tree -// and returns its absolute path. -func writeLocalFixture(t *testing.T, files map[string]string) string { - t.Helper() - - dir := helpers.TmpDirWOSymlinks(t) - for rel, body := range files { - full := filepath.Join(dir, rel) - require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) - require.NoError(t, os.WriteFile(full, []byte(body), 0o644)) - } - - return dir -} - -// newCAS constructs a CAS instance backed by a fresh per-test store directory. -func newCAS(t *testing.T) *cas.CAS { - t.Helper() - - storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") - c, err := cas.New(cas.WithStorePath(storePath)) - require.NoError(t, err) - - return c -} - func TestComputeLocalRootHash_Deterministic(t *testing.T) { t.Parallel() - c := newCAS(t) + c, v := newCAS(t) dir := writeLocalFixture(t, map[string]string{ "main.tf": `resource "null_resource" "a" {}`, "subdir/nested.tf": `variable "x" {}` + "\n", "subdir/README.md": "hello", }) - h1, err := c.ComputeLocalRootHash(dir, cas.HashSHA256) + h1, err := c.ComputeLocalRootHash(v, dir, cas.HashSHA256) require.NoError(t, err) - h2, err := c.ComputeLocalRootHash(dir, cas.HashSHA256) + h2, err := c.ComputeLocalRootHash(v, dir, cas.HashSHA256) require.NoError(t, err) assert.Equal(t, h1, h2, "same directory must produce the same root hash") @@ -61,17 +37,17 @@ func TestComputeLocalRootHash_Deterministic(t *testing.T) { func TestComputeLocalRootHash_DiffersOnContentChange(t *testing.T) { t.Parallel() - c := newCAS(t) + c, v 
:= newCAS(t) dir := writeLocalFixture(t, map[string]string{ "main.tf": "one", }) - before, err := c.ComputeLocalRootHash(dir, cas.HashSHA256) + before, err := c.ComputeLocalRootHash(v, dir, cas.HashSHA256) require.NoError(t, err) require.NoError(t, os.WriteFile(filepath.Join(dir, "main.tf"), []byte("two"), 0o644)) - after, err := c.ComputeLocalRootHash(dir, cas.HashSHA256) + after, err := c.ComputeLocalRootHash(v, dir, cas.HashSHA256) require.NoError(t, err) assert.NotEqual(t, before, after, "changing a file's contents must change the root hash") @@ -84,17 +60,17 @@ func TestComputeLocalRootHash_DiffersOnModeChange(t *testing.T) { t.Skip("file mode changes are not meaningfully observable on Windows") } - c := newCAS(t) + c, v := newCAS(t) dir := writeLocalFixture(t, map[string]string{ "script.sh": "#!/bin/sh\n", }) - before, err := c.ComputeLocalRootHash(dir, cas.HashSHA256) + before, err := c.ComputeLocalRootHash(v, dir, cas.HashSHA256) require.NoError(t, err) require.NoError(t, os.Chmod(filepath.Join(dir, "script.sh"), 0o755)) - after, err := c.ComputeLocalRootHash(dir, cas.HashSHA256) + after, err := c.ComputeLocalRootHash(v, dir, cas.HashSHA256) require.NoError(t, err) assert.NotEqual(t, before, after, "changing a file's mode must change the root hash") @@ -103,7 +79,7 @@ func TestComputeLocalRootHash_DiffersOnModeChange(t *testing.T) { func TestComputeLocalRootHash_IgnoresAbsolutePath(t *testing.T) { t.Parallel() - c := newCAS(t) + c, v := newCAS(t) files := map[string]string{ "main.tf": `resource "null_resource" "a" {}` + "\n", @@ -113,11 +89,169 @@ func TestComputeLocalRootHash_IgnoresAbsolutePath(t *testing.T) { dirA := writeLocalFixture(t, files) dirB := writeLocalFixture(t, files) - hashA, err := c.ComputeLocalRootHash(dirA, cas.HashSHA256) + hashA, err := c.ComputeLocalRootHash(v, dirA, cas.HashSHA256) require.NoError(t, err) - hashB, err := c.ComputeLocalRootHash(dirB, cas.HashSHA256) + hashB, err := c.ComputeLocalRootHash(v, dirB, cas.HashSHA256) 
require.NoError(t, err) assert.Equal(t, hashA, hashB, "identical contents at different absolute paths must hash identically") } + +// TestStoreLocalDirectoryConcurrentWithRacing pins the blob-then-tree +// write order in storeFetchedContent: a racing reader that finds the +// tree must always find every blob it references. Pre-refactor, the +// original storeLocalContent wrote the tree first and the blobs after, +// leaving a window where a reader could hit a `read_source: failed to +// read file` error when linking blobs. CI runs this test under -race +// per the WithRacing suffix convention. +func TestStoreLocalDirectoryConcurrentWithRacing(t *testing.T) { + t.Parallel() + + c, v := newCAS(t) + l := logger.CreateLogger() + + src := writeLocalFixture(t, map[string]string{ + "main.tf": `resource "null_resource" "test" {}`, + "vars.tf": `variable "x" {}`, + "README.md": "readme", + "sub/nest.tf": "# nested file", + }) + + const n = 8 + + dsts := make([]string, n) + for i := range n { + dsts[i] = filepath.Join(t.TempDir(), "dst") + } + + var g errgroup.Group + + for i := range n { + dst := dsts[i] + + g.Go(func() error { + return c.StoreLocalDirectory(t.Context(), l, v, src, dst) + }) + } + + require.NoError(t, g.Wait()) + + for _, dst := range dsts { + require.FileExists(t, filepath.Join(dst, "main.tf")) + require.FileExists(t, filepath.Join(dst, "sub", "nest.tf")) + } +} + +// TestStoreLocalDirectorySymlink covers symlink ingestion end-to-end: a +// fixture with an in-tree symlink round-trips through StoreLocalDirectory +// and the destination has a real symlink pointing at the same target. 
+func TestStoreLocalDirectorySymlink(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("os.Symlink on Windows requires special permissions; covered by Unix CI") + } + + c, v := newCAS(t) + l := logger.CreateLogger() + + src := writeLocalFixture(t, map[string]string{ + "main.tf": "ok", + }) + require.NoError(t, os.Symlink("main.tf", filepath.Join(src, "alias.tf"))) + + dst := filepath.Join(t.TempDir(), "dst") + require.NoError(t, c.StoreLocalDirectory(t.Context(), l, v, src, dst)) + + info, err := os.Lstat(filepath.Join(dst, "alias.tf")) + require.NoError(t, err) + assert.NotZero(t, info.Mode()&os.ModeSymlink, "alias.tf must be a real symlink") + + got, err := os.Readlink(filepath.Join(dst, "alias.tf")) + require.NoError(t, err) + assert.Equal(t, "main.tf", got) +} + +// TestStoreLocalDirectoryRejectsEscapingSymlink pins the safety check: a +// symlink whose target climbs above the source root is rejected at ingest +// time rather than poisoning the CAS with content a later materialize would +// have to refuse anyway. +func TestStoreLocalDirectoryRejectsEscapingSymlink(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("os.Symlink on Windows requires special permissions; covered by Unix CI") + } + + c, v := newCAS(t) + l := logger.CreateLogger() + + src := writeLocalFixture(t, map[string]string{ + "main.tf": "ok", + }) + require.NoError(t, os.Symlink("../etc/passwd", filepath.Join(src, "escape"))) + + dst := filepath.Join(t.TempDir(), "dst") + err := c.StoreLocalDirectory(t.Context(), l, v, src, dst) + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink target escapes") +} + +// TestComputeLocalRootHashIncludesSymlinks pins that swapping a symlink's +// target changes the root hash. The pre-symlink-support buildLocalTree +// silently skipped symlinks, so two trees that differed only in link target +// hashed identically, breaking the content-addressed contract. 
+func TestComputeLocalRootHashIncludesSymlinks(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("os.Symlink on Windows requires special permissions; covered by Unix CI") + } + + c, v := newCAS(t) + + dirA := writeLocalFixture(t, map[string]string{"main.tf": "ok"}) + require.NoError(t, os.Symlink("main.tf", filepath.Join(dirA, "link"))) + + dirB := writeLocalFixture(t, map[string]string{"main.tf": "ok"}) + require.NoError(t, os.Symlink("other.tf", filepath.Join(dirB, "link"))) + + hashA, err := c.ComputeLocalRootHash(v, dirA, cas.HashSHA256) + require.NoError(t, err) + + hashB, err := c.ComputeLocalRootHash(v, dirB, cas.HashSHA256) + require.NoError(t, err) + + assert.NotEqual(t, hashA, hashB, "symlink target must contribute to the root hash") +} + +// writeLocalFixture populates a fresh directory with a small deterministic tree +// and returns its absolute path. +func writeLocalFixture(t *testing.T, files map[string]string) string { + t.Helper() + + dir := helpers.TmpDirWOSymlinks(t) + for rel, body := range files { + full := filepath.Join(dir, rel) + require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) + require.NoError(t, os.WriteFile(full, []byte(body), 0o644)) + } + + return dir +} + +// newCAS constructs a CAS instance backed by a fresh per-test store directory +// along with a production [cas.Venv]. 
+func newCAS(t *testing.T) (*cas.CAS, cas.Venv) { + t.Helper() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := cas.OSVenv() + require.NoError(t, err) + + return c, v +} diff --git a/internal/cas/materialize_test.go b/internal/cas/materialize_test.go index 5107ff9639..7f91dfe323 100644 --- a/internal/cas/materialize_test.go +++ b/internal/cas/materialize_test.go @@ -26,28 +26,32 @@ func TestMaterializeTree_FromSynthStore(t *testing.T) { require.NoError(t, os.MkdirAll(s.Path(), 0755)) } + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + // Store a blob in the blob store blobData := []byte("hello world\n") blobHash := "abc123" blobContent := cas.NewContent(blobStore) - require.NoError(t, blobContent.Store(nil, blobHash, blobData)) + require.NoError(t, blobContent.Store(l, v, blobHash, blobData)) // Store a synthetic tree that references the blob treeData := []byte("100644 blob abc123\tREADME.md\n") treeHash := "synth999" synthContent := cas.NewContent(synthStore) - require.NoError(t, synthContent.Store(nil, treeHash, treeData)) + require.NoError(t, synthContent.Store(l, v, treeHash, treeData)) // Build a CAS instance using the same store paths c, err := cas.New(cas.WithStorePath(storeDir)) require.NoError(t, err) destDir := helpers.TmpDirWOSymlinks(t) - l := logger.CreateLogger() - err = c.MaterializeTree(t.Context(), l, treeHash, destDir) + err = c.MaterializeTree(t.Context(), l, v, treeHash, destDir) require.NoError(t, err) content, err := os.ReadFile(filepath.Join(destDir, "README.md")) @@ -68,26 +72,30 @@ func TestMaterializeTree_FromGitTreeStore(t *testing.T) { require.NoError(t, os.MkdirAll(s.Path(), 0755)) } + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + blobData := []byte("module content\n") blobHash := "blob111" blobContent := cas.NewContent(blobStore) - require.NoError(t, 
blobContent.Store(nil, blobHash, blobData)) + require.NoError(t, blobContent.Store(l, v, blobHash, blobData)) // Store a tree in the git tree store (not synth) treeData := []byte("100644 blob blob111\tmain.tf\n") treeHash := "tree222" treeContent := cas.NewContent(treeStore) - require.NoError(t, treeContent.Store(nil, treeHash, treeData)) + require.NoError(t, treeContent.Store(l, v, treeHash, treeData)) c, err := cas.New(cas.WithStorePath(storeDir)) require.NoError(t, err) destDir := helpers.TmpDirWOSymlinks(t) - l := logger.CreateLogger() - err = c.MaterializeTree(t.Context(), l, treeHash, destDir) + err = c.MaterializeTree(t.Context(), l, v, treeHash, destDir) require.NoError(t, err) content, err := os.ReadFile(filepath.Join(destDir, "main.tf")) @@ -103,10 +111,13 @@ func TestMaterializeTree_NotFound(t *testing.T) { c, err := cas.New(cas.WithStorePath(storeDir)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + destDir := helpers.TmpDirWOSymlinks(t) l := logger.CreateLogger() - err = c.MaterializeTree(t.Context(), l, "nonexistent", destDir) + err = c.MaterializeTree(t.Context(), l, v, "nonexistent", destDir) require.Error(t, err) assert.ErrorIs(t, err, cas.ErrTreeNotFound) } @@ -124,30 +135,34 @@ func TestMaterializeTree_SynthTakesPrecedence(t *testing.T) { require.NoError(t, os.MkdirAll(s.Path(), 0755)) } + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + blobA := []byte("synth version\n") blobB := []byte("git version\n") blobContent := cas.NewContent(blobStore) - require.NoError(t, blobContent.Store(nil, "blobA", blobA)) - require.NoError(t, blobContent.Store(nil, "blobB", blobB)) + require.NoError(t, blobContent.Store(l, v, "blobA", blobA)) + require.NoError(t, blobContent.Store(l, v, "blobB", blobB)) hash := "samehash" // Store in synth store (references blobA) synthContent := cas.NewContent(synthStore) - require.NoError(t, synthContent.Store(nil, hash, []byte("100644 blob blobA\tfile.txt\n"))) + 
require.NoError(t, synthContent.Store(l, v, hash, []byte("100644 blob blobA\tfile.txt\n"))) // Store in git tree store (references blobB) gitContent := cas.NewContent(treeStore) - require.NoError(t, gitContent.Store(nil, hash, []byte("100644 blob blobB\tfile.txt\n"))) + require.NoError(t, gitContent.Store(l, v, hash, []byte("100644 blob blobB\tfile.txt\n"))) c, err := cas.New(cas.WithStorePath(storeDir)) require.NoError(t, err) destDir := helpers.TmpDirWOSymlinks(t) - l := logger.CreateLogger() - err = c.MaterializeTree(t.Context(), l, hash, destDir) + err = c.MaterializeTree(t.Context(), l, v, hash, destDir) require.NoError(t, err) // Synth store is checked first, so the synth version should win @@ -189,23 +204,27 @@ func TestCASProtocolGetterGet(t *testing.T) { require.NoError(t, os.MkdirAll(s.Path(), 0755)) } + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + fileContent := []byte("resource {}\n") fileHash := "file123" blobContent := cas.NewContent(blobStore) - require.NoError(t, blobContent.Store(nil, fileHash, fileContent)) + require.NoError(t, blobContent.Store(l, v, fileHash, fileContent)) treeHash := "tree456" treeData := []byte("100644 blob file123\tmain.tf\n") synthContent := cas.NewContent(synthStore) - require.NoError(t, synthContent.Store(nil, treeHash, treeData)) + require.NoError(t, synthContent.Store(l, v, treeHash, treeData)) c, err := cas.New(cas.WithStorePath(storeDir)) require.NoError(t, err) - l := logger.CreateLogger() - g := getter.NewCASProtocolGetter(l, c) + g := getter.NewCASProtocolGetter(l, c, v) destDir := helpers.TmpDirWOSymlinks(t) @@ -241,23 +260,27 @@ func TestCASProtocolGetterGet_Mutable(t *testing.T) { require.NoError(t, os.MkdirAll(s.Path(), 0755)) } + v, err := cas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + fileContent := []byte("resource {}\n") fileHash := "filemut1" blobContent := cas.NewContent(blobStore) - require.NoError(t, blobContent.Store(nil, fileHash, 
fileContent)) + require.NoError(t, blobContent.Store(l, v, fileHash, fileContent)) treeHash := "treemut1" treeData := []byte("100644 blob filemut1\tmain.tf\n") synthContent := cas.NewContent(synthStore) - require.NoError(t, synthContent.Store(nil, treeHash, treeData)) + require.NoError(t, synthContent.Store(l, v, treeHash, treeData)) c, err := cas.New(cas.WithStorePath(storeDir)) require.NoError(t, err) - l := logger.CreateLogger() - g := getter.NewCASProtocolGetter(l, c) + g := getter.NewCASProtocolGetter(l, c, v) g.Mutable = true destDir := helpers.TmpDirWOSymlinks(t) @@ -290,8 +313,11 @@ func TestCASProtocolGetterGet_InvalidRef(t *testing.T) { c, err := cas.New(cas.WithStorePath(storeDir)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + l := logger.CreateLogger() - g := getter.NewCASProtocolGetter(l, c) + g := getter.NewCASProtocolGetter(l, c, v) req := &getter.Request{ Src: "badprefix:abc123", diff --git a/internal/cas/protocol.go b/internal/cas/protocol.go index 826181569e..8e1bc4ef4c 100644 --- a/internal/cas/protocol.go +++ b/internal/cas/protocol.go @@ -103,31 +103,35 @@ func FormatCASRefWithSubdir(hash, subdir string) string { // MaterializeTree reads a tree from the CAS store and links its contents to the destination directory. // It tries the synth store first, then falls back to the git tree store. +// +// Requires v.FS for reading the stored tree and writing links. v.Git is +// not used because materialization is a pure FS operation. func (c *CAS) MaterializeTree( ctx context.Context, l log.Logger, + v Venv, hash string, dest string, opts ...LinkTreeOption, ) error { + v.RequireFS() + var treeData []byte var treeStoreUsed *Store - // Try synth store first (synthetic trees from stack CAS processing). 
synthContent := NewContent(c.synthStore) - data, err := synthContent.Read(hash) + data, err := synthContent.Read(v, hash) if err == nil { treeData = data treeStoreUsed = c.synthStore } - // Fall back to main tree store (git-derived trees). if treeData == nil { treeContent := NewContent(c.treeStore) - data, err = treeContent.Read(hash) + data, err = treeContent.Read(v, hash) if err != nil { return &WrappedError{ Op: "materialize_tree", @@ -145,5 +149,5 @@ func (c *CAS) MaterializeTree( return fmt.Errorf("failed to parse CAS tree %s: %w", hash, err) } - return LinkTree(ctx, c.blobStore, treeStoreUsed, tree, dest, opts...) + return LinkTree(ctx, v, c.blobStore, treeStoreUsed, tree, dest, opts...) } diff --git a/internal/cas/race_test.go b/internal/cas/race_test.go index 468eea6092..e097fd9757 100644 --- a/internal/cas/race_test.go +++ b/internal/cas/race_test.go @@ -27,13 +27,16 @@ func TestCASGetterGetWithRacing(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + opts := &cas.CloneOptions{ Depth: -1, } l := logger.CreateLogger() - g := getter.NewCASGetter(l, c, opts) + g := getter.NewCASGetter(l, c, v, opts) client := getter.Client{ Getters: []getter.Getter{g}, } @@ -81,6 +84,9 @@ func TestProcessStackComponentLocalSourceConcurrentWithRacing(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + const workers = 4 results := make([]string, workers) @@ -96,7 +102,7 @@ func TestProcessStackComponentLocalSourceConcurrentWithRacing(t *testing.T) { source := root + "//stacks/my-stack" - result, runErr := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, runErr := c.ProcessStackComponent(t.Context(), l, v, source, "stack") if runErr != nil { errs[idx] = runErr return diff --git a/internal/cas/resolver_git_test.go b/internal/cas/resolver_git_test.go new file mode 100644 index 
0000000000..61598a3f78 --- /dev/null +++ b/internal/cas/resolver_git_test.go @@ -0,0 +1,246 @@ +package cas_test + +import ( + "context" + "testing" + + "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/gruntwork-io/terragrunt/internal/vexec" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGitResolver_ProbeHEAD(t *testing.T) { + t.Parallel() + + srv := newEmptyTestServer(t) + require.NoError(t, srv.CommitFile("README.md", []byte("hi"), "init")) + + headHash, err := srv.Head() + require.NoError(t, err) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + r := &cas.GitResolver{Venv: cas.Venv{Git: newGitRunner(t)}} + + // Empty Branch → resolver queries HEAD. + got, err := r.Probe(t.Context(), url) + require.NoError(t, err) + assert.Equal(t, headHash, got, "Probe(HEAD) must return the canonical commit SHA verbatim") +} + +func TestGitResolver_ProbeBranch(t *testing.T) { + t.Parallel() + + srv := newEmptyTestServer(t) + require.NoError(t, srv.CommitFile("README.md", []byte("hi"), "init")) + require.NoError(t, srv.Branch("feature")) + + branchHash, err := srv.Head() + require.NoError(t, err) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + r := &cas.GitResolver{Venv: cas.Venv{Git: newGitRunner(t)}, Branch: "feature"} + + got, err := r.Probe(t.Context(), url) + require.NoError(t, err) + assert.Equal(t, branchHash, got) +} + +func TestGitResolver_ProbeTag(t *testing.T) { + t.Parallel() + + srv := newEmptyTestServer(t) + require.NoError(t, srv.CommitFile("README.md", []byte("hi"), "init")) + require.NoError(t, srv.Tag("v1.0.0")) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + r := &cas.GitResolver{Venv: cas.Venv{Git: newGitRunner(t)}, Branch: "v1.0.0"} + + // Annotated-tag ls-remote returns the tag object's hash, not the + // commit it points to. 
We just assert the resolver returns a + // SHA-shaped string; the specific value is whatever git computed. + got, err := r.Probe(t.Context(), url) + require.NoError(t, err) + assert.Len(t, got, 40) +} + +func TestGitResolver_CommitFormRefReturnsErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + srv := newEmptyTestServer(t) + require.NoError(t, srv.CommitFile("README.md", []byte("hi"), "init")) + + commitSHA, err := srv.Head() + require.NoError(t, err) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + r := &cas.GitResolver{Venv: cas.Venv{Git: newGitRunner(t)}, Branch: commitSHA} + + // ls-remote does not resolve raw SHAs as refs; the caller passes + // a commit-form ref directly. Probe must surface this as + // ErrNoVersionMetadata so the fetcher canonicalizes via rev-parse. + _, err = r.Probe(t.Context(), url) + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +func TestGitResolver_UnknownRefReturnsErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + srv := newEmptyTestServer(t) + require.NoError(t, srv.CommitFile("README.md", []byte("hi"), "init")) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + r := &cas.GitResolver{Venv: cas.Venv{Git: newGitRunner(t)}, Branch: "does-not-exist"} + + _, err = r.Probe(t.Context(), url) + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +func TestGitResolver_TokenIsCacheKeyVerbatim(t *testing.T) { + t.Parallel() + + srv := newEmptyTestServer(t) + require.NoError(t, srv.CommitFile("README.md", []byte("hi"), "init")) + + headHash, err := srv.Head() + require.NoError(t, err) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + r := &cas.GitResolver{Venv: cas.Venv{Git: newGitRunner(t)}} + + got, err := r.Probe(t.Context(), url) + require.NoError(t, err) + // GitResolver returns the commit SHA itself; SourceCacheKey would + // be a no-op on top of this. 
Any future change that pre-hashes + // the token would break the git fetcher's use of the returned key + // as a git object name, so this contract is worth pinning. + assert.Len(t, got, 40, "git SHA-1 must surface as a 40-char hex string") + assert.Equal(t, headHash, got) +} + +// TestGitResolver_FullSHAHitsLocalCacheOffline pins the offline +// fast path. After EnsureCommit warms the local GitStore, Probe must +// resolve a full-SHA ref without contacting the remote, verified by +// shutting the test server down before calling Probe. +func TestGitResolver_FullSHAHitsLocalCacheOffline(t *testing.T) { + t.Parallel() + + srv := newEmptyTestServer(t) + require.NoError(t, srv.CommitFile("README.md", []byte("hi"), "init")) + + headHash, err := srv.Head() + require.NoError(t, err) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + store, v, _ := newTestGitStore(t) + l := logger.CreateLogger() + + repo, err := store.EnsureCommit(t.Context(), l, v, url, headHash, "") + require.NoError(t, err) + require.NoError(t, repo.Unlock()) + + require.NoError(t, srv.Close()) + + r := &cas.GitResolver{Venv: v, Store: store, Branch: headHash} + + got, err := r.Probe(t.Context(), url) + require.NoError(t, err, "fast path must skip ls-remote when commit is cached") + assert.Equal(t, headHash, got) +} + +// TestGitResolver_ProbeSCPURLWithBranchUsesSeparateArgs pins the +// regression where Probe glued ?ref= onto an SCP-form URL +// (`git@host:path`) and then handed the result to net/url.Parse, +// which rejects SCP form. Probe silently lost the branch and called +// `git ls-remote 'git@host:path?ref=feature' HEAD`. Branch must +// travel as a separate ls-remote argument, leaving the SCP URL +// intact. 
+func TestGitResolver_ProbeSCPURLWithBranchUsesSeparateArgs(t *testing.T) { + t.Parallel() + + var capturedArgs []string + + runner := newStubGitRunner(t, func(_ context.Context, inv vexec.Invocation) vexec.Result { + capturedArgs = inv.Args + return vexec.Result{Stdout: []byte("deadbeefcafefacedeadbeefcafefacedeadbeef\trefs/heads/main\n")} + }) + + r := &cas.GitResolver{Venv: cas.Venv{Git: runner}, Branch: "main"} + + _, err := r.Probe(t.Context(), "git@github.com:org/repo.git") + require.NoError(t, err) + assert.Equal(t, + []string{"ls-remote", "git@github.com:org/repo.git", "main"}, + capturedArgs, + "SCP URL must reach git as-is with branch passed as a separate ls-remote argument", + ) +} + +// TestGitResolver_ProbeHTTPURLWithBranchUsesSeparateArgs covers the +// non-SCP companion of the above: even for URLs net/url.Parse handles +// cleanly, the branch travels as a field on the resolver rather than +// being threaded through the URL. +func TestGitResolver_ProbeHTTPURLWithBranchUsesSeparateArgs(t *testing.T) { + t.Parallel() + + var capturedArgs []string + + runner := newStubGitRunner(t, func(_ context.Context, inv vexec.Invocation) vexec.Result { + capturedArgs = inv.Args + return vexec.Result{Stdout: []byte("deadbeefcafefacedeadbeefcafefacedeadbeef\trefs/heads/main\n")} + }) + + r := &cas.GitResolver{Venv: cas.Venv{Git: runner}, Branch: "main"} + + _, err := r.Probe(t.Context(), "https://example.com/org/repo.git") + require.NoError(t, err) + assert.Equal(t, + []string{"ls-remote", "https://example.com/org/repo.git", "main"}, + capturedArgs, + "HTTP URL must reach git without ?ref= glued on; branch is a separate argument", + ) +} + +// newGitRunner returns the *git.GitRunner the GitResolver shells out +// through. Tests run against the in-memory git HTTP server defined in +// internal/git so no network is touched. 
+func newGitRunner(t *testing.T) *git.GitRunner { + t.Helper() + + r, err := git.NewGitRunner(vexec.NewOSExec()) + require.NoError(t, err) + + return r +} + +// newStubGitRunner returns a *git.GitRunner backed by a +// [vexec.NewMemExec] handler so tests can capture the argv passed to +// git without spawning the binary. +func newStubGitRunner(t *testing.T, handler vexec.Handler) *git.GitRunner { + t.Helper() + + e := vexec.NewMemExec(handler, + vexec.WithLookPath(func(string) (string, error) { return "git", nil }), + ) + + r, err := git.NewGitRunner(e) + require.NoError(t, err) + + return r +} diff --git a/internal/cas/source.go b/internal/cas/source.go new file mode 100644 index 0000000000..1618071d78 --- /dev/null +++ b/internal/cas/source.go @@ -0,0 +1,327 @@ +package cas + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io/fs" + "maps" + + "github.com/gruntwork-io/terragrunt/internal/errors" + "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/gruntwork-io/terragrunt/internal/telemetry" + "github.com/gruntwork-io/terragrunt/internal/vfs" + "github.com/gruntwork-io/terragrunt/pkg/log" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +// ErrNoVersionMetadata reports that a SourceResolver had no usable +// version identifier for a source. Callers fall back to downloading, +// walking the result, and keying the tree by its content hash. +var ErrNoVersionMetadata = errors.New("no version metadata available for source") + +// SourceResolver derives a tree-store cache key for a source from a +// cheap remote probe, so FetchSource can short-circuit the download +// when the bytes haven't changed upstream. +type SourceResolver interface { + // Scheme returns the URL scheme this resolver handles (e.g. + // "s3", "gcs", "http"). + Scheme() string + + // Probe returns a cache key for rawURL. 
+ // + // Returns ErrNoVersionMetadata when the source has no cheap + // signal; FetchSource then falls back to downloading and + // content-hashing. Other errors are logged and treated the same + // way, so a misconfigured probe never breaks a fetch. + Probe(ctx context.Context, rawURL string) (cacheKey string, err error) +} + +// SourceFetcher downloads and ingests a source into CAS, returning the +// tree-store key the materialized tree was written under. +// +// suggestedKey is the probe-derived cache key, or empty when the probe +// produced none. Fetchers that learn the canonical key only after +// downloading (the git rev-parse path) may ignore it and return the +// canonical key instead. +type SourceFetcher func(ctx context.Context, l log.Logger, v Venv, suggestedKey string) (treeKey string, err error) + +// SourceRequest is the input to CAS.FetchSource. +type SourceRequest struct { + // Resolver probes the source for a cache key. Nil means always + // download and hash. + Resolver SourceResolver + // Fetch ingests the source. Required. + Fetch SourceFetcher + // Attrs are scheme-specific telemetry attributes merged into the + // cas_fetch_source span (e.g. the git path adds "branch"). + Attrs map[string]any + // Scheme is the URL scheme of URL ("s3", "gcs", "http"). Used in + // telemetry; resolvers that need it pull it from URL themselves. + Scheme string + // URL is the canonical source URL. Passed to Resolver.Probe and + // used in error messages. + URL string +} + +// FetchSource routes src through the CAS. On a probe hit it links the +// cached tree into opts.Dir without invoking Fetch. On a probe miss it +// calls Fetch and links the resulting tree. +// +// opts.Dir is the destination. opts.Mutable selects copy vs hardlink +// for the final link, matching the git path. +// +// Requires v.FS for store I/O. v.Git is only consulted by fetchers that +// shell out to git (e.g. 
the closure built by [CAS.Clone]); other +// fetchers are free to leave it unset. +func (c *CAS) FetchSource( + ctx context.Context, + l log.Logger, + v Venv, + opts *CloneOptions, + src SourceRequest, +) error { + v.RequireFS() + + if src.Fetch == nil { + return ErrFetchClosureRequired + } + + if err := c.ensureStorePaths(v); err != nil { + return err + } + + attrs := map[string]any{ + "url": src.URL, + "scheme": src.Scheme, + } + + maps.Copy(attrs, src.Attrs) + + tlm := telemetry.TelemeterFromContext(ctx) + + return tlm.Collect(ctx, "cas_fetch_source", attrs, func(childCtx context.Context) error { + suggestedKey := c.probeSource(childCtx, l, src) + + if suggestedKey != "" && !c.treeStore.NeedsWrite(v, suggestedKey) { + recordFetchOutcome(childCtx, true) + + return c.linkStoredTree(childCtx, v, opts, suggestedKey) + } + + recordFetchOutcome(childCtx, false) + + treeKey, err := src.Fetch(childCtx, l, v, suggestedKey) + if err != nil { + return fmt.Errorf("fetch %s: %w", src.URL, err) + } + + return c.linkStoredTree(childCtx, v, opts, treeKey) + }) +} + +// ContentKey derives a cache key for a probe token that is a content +// hash of the source bytes (S3 x-amz-checksum-sha256, GCS md5Hash, Hg +// node hash, ...). The scheme and URL drop out so identical bytes at +// different URLs share one entry. +func ContentKey(alg, token string) string { + h := sha256.New() + h.Write([]byte("content\x00")) + h.Write([]byte(alg)) + h.Write([]byte{0}) + h.Write([]byte(token)) + + return hex.EncodeToString(h.Sum(nil)) +} + +// OpaqueKey derives a URL-scoped cache key for a probe token that is +// not a content hash (ETag, Last-Modified). The token alone does not +// identify the bytes, so the scheme and URL stay in the key. 
+func OpaqueKey(scheme, url, token string) string { + h := sha256.New() + h.Write([]byte("source\x00")) + h.Write([]byte(scheme)) + h.Write([]byte{0}) + h.Write([]byte(url)) + h.Write([]byte{0}) + h.Write([]byte(token)) + + return hex.EncodeToString(h.Sum(nil)) +} + +// MakeFetchTempDir creates a scratch directory for a [SourceFetcher] and +// returns the path with a cleanup closure that logs failures rather than +// returning them. Exported so out-of-package [SourceFetcher] implementations +// share the same temp-dir layout. +// +// Requires v.FS. +func (c *CAS) MakeFetchTempDir(l log.Logger, v Venv) (string, func(), error) { + v.RequireFS() + + tempDir, err := vfs.MkdirTemp(v.FS, "", "terragrunt-cas-fetch-*") + if err != nil { + return "", nil, fmt.Errorf("create source fetch dir: %w", err) + } + + cleanup := func() { + if rmErr := v.FS.RemoveAll(tempDir); rmErr != nil { + l.Warnf("cleanup error for %s: %v", tempDir, rmErr) + } + } + + return tempDir, cleanup, nil +} + +// IngestDirectory hashes sourceDir under [DefaultLocalHashAlgorithm] and +// stores the tree and blobs in CAS. The returned tree key is suggestedKey +// when non-empty (probe-derived); otherwise it is the content hash of the +// tree. Exported so out-of-package [SourceFetcher] implementations ingest +// through the same path the local-source flow uses. +// +// Requires v.FS. 
+func (c *CAS) IngestDirectory(l log.Logger, v Venv, sourceDir, suggestedKey string) (string, error) { + v.RequireFS() + + hash, treeData, err := c.buildLocalTree(v, sourceDir, DefaultLocalHashAlgorithm) + if err != nil { + return "", fmt.Errorf("hash %s: %w", sourceDir, err) + } + + treeKey := suggestedKey + if treeKey == "" { + treeKey = hash + } + + if err := c.storeFetchedContent(l, v, sourceDir, treeKey, treeData, DefaultLocalHashAlgorithm); err != nil { + return "", fmt.Errorf("store %s: %w", sourceDir, err) + } + + return treeKey, nil +} + +// probeSource invokes the resolver and returns its cache key, or empty +// when no resolver is configured or the probe failed. See +// [SourceResolver.Probe] for the fallback contract. +func (c *CAS) probeSource(ctx context.Context, l log.Logger, src SourceRequest) string { + if src.Resolver == nil { + return "" + } + + key, err := src.Resolver.Probe(ctx, src.URL) + if err != nil { + if !errors.Is(err, ErrNoVersionMetadata) { + l.Debugf("cas: source probe for %s failed (falling back to content hash): %v", src.URL, err) + } + + return "" + } + + return key +} + +// recordFetchOutcome stamps cache_hit on the active cas_fetch_source span +// so dashboards can distinguish probe short-circuits from network fetches. +func recordFetchOutcome(ctx context.Context, cacheHit bool) { + span := trace.SpanFromContext(ctx) + if !span.IsRecording() { + return + } + + span.SetAttributes(attribute.Bool("cache_hit", cacheHit)) +} + +// linkStoredTree materializes the tree at key into opts.Dir. 
+func (c *CAS) linkStoredTree(ctx context.Context, v Venv, opts *CloneOptions, key string) error { + treeContent := NewContent(c.treeStore) + + treeData, err := treeContent.Read(v, key) + if err != nil { + return fmt.Errorf("read cached tree %s: %w", key, err) + } + + tree, err := git.ParseTree(treeData, opts.Dir) + if err != nil { + return fmt.Errorf("parse cached tree %s: %w", key, err) + } + + var linkOpts []LinkTreeOption + if opts.Mutable { + linkOpts = append(linkOpts, WithForceCopy()) + } + + return LinkTree(ctx, v, c.blobStore, c.treeStore, tree, opts.Dir, linkOpts...) +} + +// storeFetchedContent stores every blob referenced by the tree, then +// the tree object itself, under treeKey. Order matters: a racing +// reader that sees the tree must find every referenced blob. Writing +// the tree last means a treeStore.NeedsWrite hit implies the blobs +// are already present. +// +// Symlink entries are stored as blobs whose content is the link target +// string, matching git's representation. [hashLocalEntry] validates the +// target stays inside sourceDir so the CAS never persists an escape. +// +// treeKey is either a probe-derived key or, when no probe applies, the +// content hash from buildLocalTree. 
+func (c *CAS) storeFetchedContent( + l log.Logger, + v Venv, + sourceDir, treeKey string, + treeData []byte, + alg HashAlgorithm, +) error { + blobContent := NewContent(c.blobStore) + + walkErr := vfs.WalkDir(v.FS, sourceDir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if d.IsDir() { + return nil + } + + info, err := d.Info() + if err != nil { + return fmt.Errorf("stat %s: %w", path, err) + } + + mode, blobHash, err := hashLocalEntry(v.FS, sourceDir, path, info, alg) + if err != nil { + return err + } + + switch mode { + case "": + return nil + case gitSymlinkMode: + target, err := vfs.Readlink(v.FS, path) + if err != nil { + return fmt.Errorf("read symlink %s: %w", path, err) + } + + if err := blobContent.Ensure(l, v, blobHash, []byte(target)); err != nil { + return fmt.Errorf("store symlink blob %s: %w", path, err) + } + default: + if err := blobContent.EnsureCopy(l, v, blobHash, path); err != nil { + return fmt.Errorf("store blob %s: %w", path, err) + } + } + + return nil + }) + if walkErr != nil { + return walkErr + } + + treeContent := NewContent(c.treeStore) + if err := treeContent.EnsureWithWait(l, v, treeKey, treeData); err != nil { + return fmt.Errorf("store tree %s: %w", treeKey, err) + } + + return nil +} diff --git a/internal/cas/source_test.go b/internal/cas/source_test.go new file mode 100644 index 0000000000..5e2c1dfa29 --- /dev/null +++ b/internal/cas/source_test.go @@ -0,0 +1,342 @@ +package cas_test + +import ( + "context" + "path/filepath" + "sync/atomic" + "testing" + + "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/vfs" + "github.com/gruntwork-io/terragrunt/pkg/log" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestFetchSource_ProbeHitSkipsDownload(t *testing.T) { + t.Parallel() + + c, v := newCAS(t) + l := 
logger.CreateLogger() + + resolver := &fakeResolver{ + scheme: "http", + key: cas.OpaqueKey("http", "https://example.com/mod.tgz", "etag-abc"), + } + + var fetchCalls atomic.Int32 + + fetch := fakeFetcher(c, map[string]string{ + "main.tf": `# hello`, + "README": "readme", + "sub/x.tf": `variable "x" {}`, + }, &fetchCalls) + + dst1 := filepath.Join(t.TempDir(), "dst1") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dst1}, cas.SourceRequest{ + Scheme: "http", + URL: "https://example.com/mod.tgz", + Resolver: resolver, + Fetch: fetch, + })) + + require.Equal(t, int32(1), fetchCalls.Load()) + require.FileExists(t, filepath.Join(dst1, "main.tf")) + + dst2 := filepath.Join(t.TempDir(), "dst2") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dst2}, cas.SourceRequest{ + Scheme: "http", + URL: "https://example.com/mod.tgz", + Resolver: resolver, + Fetch: fetch, + })) + + assert.Equal(t, int32(1), fetchCalls.Load(), "second call must hit the cache and not re-fetch") + assert.FileExists(t, filepath.Join(dst2, "main.tf")) + assert.FileExists(t, filepath.Join(dst2, "sub", "x.tf")) +} + +func TestFetchSource_NoMetadataFallsBackToContentHash(t *testing.T) { + t.Parallel() + + c, v := newCAS(t) + l := logger.CreateLogger() + + resolver := &fakeResolver{ + scheme: "http", + err: cas.ErrNoVersionMetadata, + } + + var fetchCalls atomic.Int32 + + fetch := fakeFetcher(c, map[string]string{ + "main.tf": "content", + }, &fetchCalls) + + dst1 := filepath.Join(t.TempDir(), "dst1") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dst1}, cas.SourceRequest{ + Scheme: "http", + URL: "https://example.com/mod.tgz", + Resolver: resolver, + Fetch: fetch, + })) + + dst2 := filepath.Join(t.TempDir(), "dst2") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dst2}, cas.SourceRequest{ + Scheme: "http", + URL: "https://example.com/mod.tgz", + Resolver: resolver, + Fetch: fetch, + })) + + 
assert.Equal(t, int32(2), fetchCalls.Load(), "no probe means we re-download every time") + assert.Equal(t, int32(2), resolver.calls.Load()) + assert.FileExists(t, filepath.Join(dst2, "main.tf")) +} + +func TestFetchSource_NilResolverContentHashes(t *testing.T) { + t.Parallel() + + c, v := newCAS(t) + l := logger.CreateLogger() + + var fetchCalls atomic.Int32 + + fetch := fakeFetcher(c, map[string]string{ + "a.tf": "1", + }, &fetchCalls) + + dst := filepath.Join(t.TempDir(), "dst") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dst}, cas.SourceRequest{ + Scheme: "s3", + URL: "s3://bucket/key.tgz", + Fetch: fetch, + })) + + assert.Equal(t, int32(1), fetchCalls.Load()) + assert.FileExists(t, filepath.Join(dst, "a.tf")) +} + +func TestFetchSource_ContentAddressedDedupesAcrossURLs(t *testing.T) { + t.Parallel() + + c, v := newCAS(t) + l := logger.CreateLogger() + + // Both resolvers report the same content-addressed cache key. The + // URL drops out of the derivation so identical bytes at two URLs + // share a tree-store entry. 
+ contentKey := cas.ContentKey("sha256", "deadbeef") + + files := map[string]string{ + "main.tf": "ok", + } + + var fetchCalls atomic.Int32 + + resolverA := &fakeResolver{scheme: "s3", key: contentKey} + resolverB := &fakeResolver{scheme: "gcs", key: contentKey} + + dstA := filepath.Join(t.TempDir(), "dstA") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dstA}, cas.SourceRequest{ + Scheme: "s3", + URL: "s3://bucketA/mod.tgz", + Resolver: resolverA, + Fetch: fakeFetcher(c, files, &fetchCalls), + })) + + dstB := filepath.Join(t.TempDir(), "dstB") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dstB}, cas.SourceRequest{ + Scheme: "gcs", + URL: "gs://bucketB/different/path.tgz", + Resolver: resolverB, + Fetch: fakeFetcher(c, files, &fetchCalls), + })) + + assert.Equal(t, int32(1), fetchCalls.Load(), + "content-addressed probe must dedupe identical bytes across two URLs") + assert.FileExists(t, filepath.Join(dstB, "main.tf")) +} + +func TestFetchSource_OpaqueProbeURLScoped(t *testing.T) { + t.Parallel() + + c, v := newCAS(t) + l := logger.CreateLogger() + + files := map[string]string{"main.tf": "hello"} + + var fetchCalls atomic.Int32 + + resolverA := &fakeResolver{ + scheme: "http", + key: cas.OpaqueKey("http", "https://a.example.com/mod.tgz", "same-etag"), + } + resolverB := &fakeResolver{ + scheme: "http", + key: cas.OpaqueKey("http", "https://b.example.com/mod.tgz", "same-etag"), + } + + dstA := filepath.Join(t.TempDir(), "dstA") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dstA}, cas.SourceRequest{ + Scheme: "http", + URL: "https://a.example.com/mod.tgz", + Resolver: resolverA, + Fetch: fakeFetcher(c, files, &fetchCalls), + })) + + dstB := filepath.Join(t.TempDir(), "dstB") + require.NoError(t, c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dstB}, cas.SourceRequest{ + Scheme: "http", + URL: "https://b.example.com/mod.tgz", + Resolver: resolverB, + Fetch: 
fakeFetcher(c, files, &fetchCalls), + })) + + assert.Equal(t, int32(2), fetchCalls.Load(), + "opaque probe must not dedupe across distinct URLs even when the token matches") +} + +func TestFetchSource_ConcurrentFetchesOfSameKeyWithRacing(t *testing.T) { + t.Parallel() + + c, v := newCAS(t) + l := logger.CreateLogger() + + resolver := &fakeResolver{ + scheme: "http", + key: cas.OpaqueKey("http", "https://example.com/race.tgz", "etag-race"), + } + + var fetchCalls atomic.Int32 + + files := map[string]string{"main.tf": "race"} + + const n = 8 + + // Pre-allocate destination directories from the test goroutine. + // t.TempDir() invoked from worker goroutines races with t.Cleanup + // registration when many workers fire at once on macOS. + dsts := make([]string, n) + for i := range n { + dsts[i] = filepath.Join(t.TempDir(), "dst") + } + + var g errgroup.Group + + for i := range n { + dst := dsts[i] + + g.Go(func() error { + return c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: dst}, cas.SourceRequest{ + Scheme: "http", + URL: "https://example.com/race.tgz", + Resolver: resolver, + Fetch: fakeFetcher(c, files, &fetchCalls), + }) + }) + } + + require.NoError(t, g.Wait()) + + // At least one fetch occurred; concurrent racers may legitimately + // fetch into separate temp dirs and only the first to acquire the + // tree-store lock wins. 
+ assert.GreaterOrEqual(t, fetchCalls.Load(), int32(1)) +} + +func TestFetchSource_RequiresFetchClosure(t *testing.T) { + t.Parallel() + + c, v := newCAS(t) + l := logger.CreateLogger() + + err := c.FetchSource(t.Context(), l, v, &cas.CloneOptions{Dir: t.TempDir()}, cas.SourceRequest{ + Scheme: "http", + URL: "https://example.com", + }) + require.ErrorIs(t, err, cas.ErrFetchClosureRequired) +} + +func TestContentKey_URLIndependent(t *testing.T) { + t.Parallel() + + k1 := cas.ContentKey("sha256", "abc") + k2 := cas.ContentKey("sha256", "abc") + assert.Equal(t, k1, k2, "same alg+token must produce the same key") + + // Algorithm tag matters: different alg, same token, different key. + assert.NotEqual(t, k1, cas.ContentKey("md5", "abc")) + // Token matters. + assert.NotEqual(t, k1, cas.ContentKey("sha256", "abd")) //codespell:ignore abd +} + +func TestOpaqueKey_URLScoped(t *testing.T) { + t.Parallel() + + k1 := cas.OpaqueKey("http", "https://a.example/x", "etag") + k2 := cas.OpaqueKey("http", "https://b.example/x", "etag") + assert.NotEqual(t, k1, k2, "different URLs must produce different keys") + + // Scheme matters. + assert.NotEqual(t, k1, cas.OpaqueKey("https", "https://a.example/x", "etag")) + // Token matters. + assert.NotEqual(t, k1, cas.OpaqueKey("http", "https://a.example/x", "other")) +} + +func TestContentKey_DoesNotCollideWithOpaqueKey(t *testing.T) { + t.Parallel() + + // Both derivations would otherwise hash the same token "abc" but + // the namespace prefix ("content" vs "source") keeps them apart. + content := cas.ContentKey("sha256", "abc") + opaque := cas.OpaqueKey("sha256", "abc", "abc") + assert.NotEqual(t, content, opaque) +} + +// fakeResolver returns a canned cache key on Probe. 
+type fakeResolver struct { + err error + scheme string + key string + calls atomic.Int32 +} + +func (f *fakeResolver) Scheme() string { return f.scheme } + +func (f *fakeResolver) Probe(_ context.Context, _ string) (string, error) { + f.calls.Add(1) + + if f.err != nil { + return "", f.err + } + + return f.key, nil +} + +// fakeFetcher writes a deterministic fixture into a fresh temp dir, +// ingests it into CAS via IngestDirectory, and returns the resulting +// tree key. It counts invocations for assertions. +func fakeFetcher(c *cas.CAS, files map[string]string, calls *atomic.Int32) cas.SourceFetcher { + return func(_ context.Context, l log.Logger, v cas.Venv, suggestedKey string) (string, error) { + calls.Add(1) + + tempDir, cleanup, err := c.MakeFetchTempDir(l, v) + if err != nil { + return "", err + } + + defer cleanup() + + for rel, body := range files { + full := filepath.Join(tempDir, rel) + if err := vfs.WriteFile(v.FS, full, []byte(body), 0o644); err != nil { + return "", err + } + } + + return c.IngestDirectory(l, v, tempDir, suggestedKey) + } +} diff --git a/internal/cas/stacks.go b/internal/cas/stacks.go index 33144aecee..363103b776 100644 --- a/internal/cas/stacks.go +++ b/internal/cas/stacks.go @@ -11,7 +11,6 @@ import ( "strings" "github.com/gruntwork-io/terragrunt/internal/git" - "github.com/gruntwork-io/terragrunt/internal/vexec" "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/pkg/log" "github.com/hashicorp/go-getter/v2" @@ -69,13 +68,25 @@ type StackCASResult struct { // cloned into a temp directory; local sources are copied into a temp directory // so rewrites do not mutate the caller's working tree. The kind should be // "unit" or "stack". -func (c *CAS) ProcessStackComponent(ctx context.Context, l log.Logger, source, kind string) (*StackCASResult, error) { +// +// Requires v.FS unconditionally. Remote sources additionally require +// v.Git; the assertion fires once dispatch picks the remote branch. 
+func (c *CAS) ProcessStackComponent( + ctx context.Context, + l log.Logger, + v Venv, + source, kind string, +) (*StackCASResult, error) { + v.RequireFS() + repoURL, subdir := getter.SourceDirSubdir(source) - if isLocalPath(repoURL) { - return c.processLocalStackComponent(ctx, l, repoURL, subdir) + if isLocalPath(v.FS, repoURL) { + return c.processLocalStackComponent(ctx, l, v, repoURL, subdir) } + v.RequireGit() + detectedURL, err := DetectRemoteSource(repoURL) if err != nil { return nil, fmt.Errorf("failed to detect source URL %q: %w", repoURL, err) @@ -88,7 +99,6 @@ func (c *CAS) ProcessStackComponent(ctx context.Context, l log.Logger, source, k ref := parsedURL.Query().Get("ref") - // Remove ref from query so we can clone q := parsedURL.Query() q.Del("ref") parsedURL.RawQuery = q.Encode() @@ -101,58 +111,67 @@ func (c *CAS) ProcessStackComponent(ctx context.Context, l log.Logger, source, k // stacks; CommitHash returns the user input as-is for the // commit-ref path, and the canonical hash for the symbolic-ref // path. 
- resolved, err := c.resolveReference(ctx, cleanURL, ref) + resolved, err := c.resolveReference(ctx, v, cleanURL, ref) if err != nil { return nil, fmt.Errorf("failed to resolve reference %q: %w", ref, err) } refHash := resolved.CommitHash() - // Create temp dir for the clone - tempDir, err := os.MkdirTemp("", "terragrunt-cas-stack-*") + tempDir, err := vfs.MkdirTemp(v.FS, "", "terragrunt-cas-stack-*") if err != nil { return nil, fmt.Errorf("failed to create temp dir: %w", err) } cleanup := func() { - _ = os.RemoveAll(tempDir) + if rmErr := v.FS.RemoveAll(tempDir); rmErr != nil { + l.Warnf("cleanup error for %s: %v", tempDir, rmErr) + } } cloneDir := filepath.Join(tempDir, "repo") - // Clone the repo via CAS - if err := c.Clone(ctx, l, &CloneOptions{ - Dir: cloneDir, - Branch: ref, - Depth: c.cloneDepth, - }, cleanURL); err != nil { + if err := c.Clone(ctx, l, v, cleanURL, WithDir(cloneDir), + WithBranch(ref), + WithDepth(c.cloneDepth)); err != nil { cleanup() return nil, fmt.Errorf("failed to CAS clone %q: %w", cleanURL, err) } - // Detect the repository's hash algorithm from the cloned content. - hashAlg, err := detectRepoHashAlgorithm(ctx, cloneDir) + hashAlg, err := detectRepoHashAlgorithm(ctx, v.Git, cloneDir) if err != nil { l.Debugf("Failed to detect object format, defaulting to SHA-1: %v", err) hashAlg = HashSHA1 } - // Navigate to the subdir within the cloned repo contentDir := cloneDir + if subdir != "" { - contentDir = filepath.Join(cloneDir, subdir) + if filepath.IsAbs(subdir) { + cleanup() + + return nil, fmt.Errorf("%w: %q", ErrAbsoluteSource, subdir) + } + + contentDir = filepath.Clean(filepath.Join(cloneDir, subdir)) + + rel, err := filepath.Rel(cloneDir, contentDir) + if err != nil || rel == ".." 
|| strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + cleanup() + + return nil, fmt.Errorf("%w: %q", ErrSourceEscapesRepo, subdir) + } } - if _, err := os.Stat(contentDir); err != nil { + if _, err := v.FS.Stat(contentDir); err != nil { cleanup() return nil, fmt.Errorf("subdir %q not found in cloned repo: %w", subdir, err) } - // Process the directory: rewrite sources, create synthetic CAS entries - if err := c.processDirectory(ctx, l, cloneDir, contentDir, refHash, hashAlg); err != nil { + if err := c.processDirectory(ctx, l, v, cloneDir, contentDir, refHash, hashAlg); err != nil { cleanup() return nil, fmt.Errorf("failed to process directory for CAS: %w", err) @@ -190,16 +209,33 @@ func SplitSourceDoubleSlash(source string) (basePath, subdir string) { return before, after } -// detectRepoHashAlgorithm queries the git object format of a cloned repository. -func detectRepoHashAlgorithm(ctx context.Context, repoDir string) (HashAlgorithm, error) { - g, err := git.NewGitRunner(vexec.NewOSExec()) - if err != nil { - return "", fmt.Errorf("failed to create git runner: %w", err) +// ResolveInRepoSource resolves an update_source_with_cas source string relative to +// dirPath and returns the cleaned absolute path. Absolute sources and sources +// whose resolved path escapes repoRoot via ".." segments are rejected so CAS +// materialization stays scoped to the cloned repository. +func ResolveInRepoSource(repoRoot, dirPath, source string) (string, error) { + sourcePath, sourceSubdir := SplitSourceDoubleSlash(source) + if filepath.IsAbs(sourcePath) { + return "", fmt.Errorf("%w: %q", ErrAbsoluteSource, source) } - g.WorkDir = repoDir + resolved := filepath.Clean(filepath.Join(dirPath, sourcePath)) + if sourceSubdir != "" { + resolved = filepath.Join(resolved, sourceSubdir) + } - format, err := g.ObjectFormat(ctx) + rel, err := filepath.Rel(filepath.Clean(repoRoot), resolved) + if err != nil || rel == ".." 
|| strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("%w: %q", ErrSourceEscapesRepo, source) + } + + return resolved, nil +} + +// detectRepoHashAlgorithm queries the git object format of a cloned repository +// using the supplied runner so callers control the git/vexec binding. +func detectRepoHashAlgorithm(ctx context.Context, runner *git.GitRunner, repoDir string) (HashAlgorithm, error) { + format, err := runner.WithWorkDir(repoDir).ObjectFormat(ctx) if err != nil { return "", err } @@ -210,18 +246,18 @@ func detectRepoHashAlgorithm(ctx context.Context, repoDir string) (HashAlgorithm // processDirectory recursively processes a stack or unit directory, rewriting // sources and creating synthetic CAS entries. func (c *CAS) processDirectory( - ctx context.Context, l log.Logger, + ctx context.Context, l log.Logger, v Venv, repoRoot, dirPath, refHash string, hashAlg HashAlgorithm, ) error { stackFile := filepath.Join(dirPath, "terragrunt.stack.hcl") unitFile := filepath.Join(dirPath, "terragrunt.hcl") - if _, err := os.Stat(stackFile); err == nil { - return c.processStackFile(ctx, l, repoRoot, dirPath, stackFile, refHash, hashAlg) + if _, err := v.FS.Stat(stackFile); err == nil { + return c.processStackFile(ctx, l, v, repoRoot, dirPath, stackFile, refHash, hashAlg) } - if _, err := os.Stat(unitFile); err == nil { - return c.processUnitFile(l, repoRoot, dirPath, unitFile, refHash, hashAlg) + if _, err := v.FS.Stat(unitFile); err == nil { + return c.processUnitFile(l, v, repoRoot, dirPath, unitFile, refHash, hashAlg) } return nil @@ -230,10 +266,10 @@ func (c *CAS) processDirectory( // processStackFile processes a terragrunt.stack.hcl file, rewriting sources // for blocks that have update_source_with_cas = true. 
func (c *CAS) processStackFile( - ctx context.Context, l log.Logger, + ctx context.Context, l log.Logger, v Venv, repoRoot, dirPath, stackFile, refHash string, hashAlg HashAlgorithm, ) error { - content, err := os.ReadFile(stackFile) + content, err := vfs.ReadFile(v.FS, stackFile) if err != nil { return fmt.Errorf("failed to read stack file %s: %w", stackFile, err) } @@ -259,17 +295,15 @@ func (c *CAS) processStackFile( return fmt.Errorf("failed to resolve source for %s %q: %w", block.BlockType, block.Name, err) } - if err := c.processDirectory(ctx, l, repoRoot, targetDir, refHash, hashAlg); err != nil { + if err := c.processDirectory(ctx, l, v, repoRoot, targetDir, refHash, hashAlg); err != nil { return fmt.Errorf("failed to process %s %q source: %w", block.BlockType, block.Name, err) } - // Build a synthetic tree for the target directory - synthHash, err := c.buildSyntheticTree(l, targetDir, refHash, repoRoot, hashAlg) + synthHash, err := c.buildSyntheticTree(l, v, targetDir, refHash, repoRoot, hashAlg) if err != nil { return fmt.Errorf("failed to build synthetic tree for %s %q: %w", block.BlockType, block.Name, err) } - // Rewrite the source in the stack file newSource := FormatCASRef(synthHash) content, err = RewriteStackBlockSource(content, block.BlockType, block.Name, newSource) @@ -283,17 +317,22 @@ func (c *CAS) processStackFile( // The file may be a read-only hard link from the CAS store, so remove it // before writing the rewritten content to avoid permission errors and to // avoid mutating the stored blob. 
- if err := os.Remove(stackFile); err != nil && !os.IsNotExist(err) { + if err := v.FS.Remove(stackFile); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove stack file before rewrite %s: %w", stackFile, err) } - return os.WriteFile(stackFile, content, RegularFilePerms) + return vfs.WriteFile(v.FS, stackFile, content, RegularFilePerms) } // processUnitFile processes a terragrunt.hcl file, rewriting the // terraform.source if update_source_with_cas is set. -func (c *CAS) processUnitFile(l log.Logger, repoRoot, dirPath, unitFile, refHash string, hashAlg HashAlgorithm) error { - content, err := os.ReadFile(unitFile) +func (c *CAS) processUnitFile( + l log.Logger, + v Venv, + repoRoot, dirPath, unitFile, refHash string, + hashAlg HashAlgorithm, +) error { + content, err := vfs.ReadFile(v.FS, unitFile) if err != nil { return fmt.Errorf("failed to read unit file %s: %w", unitFile, err) } @@ -318,7 +357,7 @@ func (c *CAS) processUnitFile(l log.Logger, repoRoot, dirPath, unitFile, refHash return fmt.Errorf("failed to resolve terraform source %q: %w", source, err) } - synthHash, err := c.buildSyntheticTree(l, moduleDir, refHash, repoRoot, hashAlg) + synthHash, err := c.buildSyntheticTree(l, v, moduleDir, refHash, repoRoot, hashAlg) if err != nil { return fmt.Errorf("failed to build synthetic tree for terraform source %q: %w", source, err) } @@ -335,24 +374,29 @@ func (c *CAS) processUnitFile(l log.Logger, repoRoot, dirPath, unitFile, refHash // The file may be a read-only hard link from the CAS store, so remove it // before writing the rewritten content to avoid permission errors and to // avoid mutating the stored blob. 
- if err := os.Remove(unitFile); err != nil && !os.IsNotExist(err) { + if err := v.FS.Remove(unitFile); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove unit file before rewrite %s: %w", unitFile, err) } - return os.WriteFile(unitFile, content, RegularFilePerms) + return vfs.WriteFile(v.FS, unitFile, content, RegularFilePerms) } // buildSyntheticTree creates a synthetic CAS tree entry for a directory. It // hashes every file, stores the blobs, and writes a tree object into the synth // store. The resulting tree hash is deterministic: hashAlg(refHash + relPathInRepo). +// +// Symlinks are stored as 120000 entries whose blob is the link target string. +// [vfs.ValidateSymlinkTarget] rejects targets that escape dirPath, since the +// CAS protocol getter materializes synthetic trees into a self-contained +// destination directory and any escape would dangle. func (c *CAS) buildSyntheticTree( - l log.Logger, dirPath, refHash, repoRoot string, hashAlg HashAlgorithm, + l log.Logger, v Venv, dirPath, refHash, repoRoot string, hashAlg HashAlgorithm, ) (string, error) { var treeData []byte blobContent := NewContent(c.blobStore) - err := vfs.WalkDir(c.fs, dirPath, func(path string, d fs.DirEntry, walkErr error) error { + err := vfs.WalkDir(v.FS, dirPath, func(path string, d fs.DirEntry, walkErr error) error { if walkErr != nil { return walkErr } @@ -366,26 +410,42 @@ func (c *CAS) buildSyntheticTree( return err } - relPath, err := filepath.Rel(dirPath, path) + relPath, err := localRelPath(dirPath, path) if err != nil { return err } - // Convert to forward slashes for consistency (git-style paths) - relPath = strings.ReplaceAll(relPath, string(filepath.Separator), "/") + switch { + case info.Mode()&os.ModeSymlink != 0: + target, err := vfs.Readlink(v.FS, path) + if err != nil { + return fmt.Errorf("read symlink %s: %w", path, err) + } - fileHash, err := hashFileAlg(c.fs, path, hashAlg) - if err != nil { - return fmt.Errorf("failed to hash file %s: %w", 
path, err) - } + if err := vfs.ValidateSymlinkTarget(dirPath, path, target); err != nil { + return err + } - if err := blobContent.EnsureCopy(l, fileHash, path); err != nil { - return fmt.Errorf("failed to store blob %s: %w", path, err) - } + blobHash := hashAlg.Sum([]byte(target)) + if err := blobContent.Ensure(l, v, blobHash, []byte(target)); err != nil { + return fmt.Errorf("failed to store symlink blob %s: %w", path, err) + } - mode := gitTreeMode(info.Mode()) - treeLine := fmt.Sprintf("%s blob %s\t%s\n", mode, fileHash, relPath) - treeData = append(treeData, []byte(treeLine)...) + treeData = append(treeData, fmt.Appendf(nil, "%s blob %s\t%s\n", gitSymlinkMode, blobHash, relPath)...) + + case info.Mode().IsRegular(): + fileHash, err := hashFileAlg(v.FS, path, hashAlg) + if err != nil { + return fmt.Errorf("failed to hash file %s: %w", path, err) + } + + if err := blobContent.EnsureCopy(l, v, fileHash, path); err != nil { + return fmt.Errorf("failed to store blob %s: %w", path, err) + } + + mode := gitTreeMode(info.Mode()) + treeData = append(treeData, fmt.Appendf(nil, "%s blob %s\t%s\n", mode, fileHash, relPath)...) 
+ } return nil }) @@ -393,7 +453,6 @@ func (c *CAS) buildSyntheticTree( return "", err } - // Compute deterministic hash: hashAlg(refHash + relPathInRepo) relPathInRepo, err := filepath.Rel(repoRoot, dirPath) if err != nil { return "", fmt.Errorf("failed to compute relative path for deterministic hash: %w", err) @@ -403,38 +462,14 @@ func (c *CAS) buildSyntheticTree( treeHash := hashAlg.Sum([]byte(refHash + relPathInRepo)) - // Store in synth tree store synthContent := NewContent(c.synthStore) - if err := synthContent.Ensure(l, treeHash, treeData); err != nil { + if err := synthContent.Ensure(l, v, treeHash, treeData); err != nil { return "", fmt.Errorf("failed to store synthetic tree: %w", err) } return treeHash, nil } -// ResolveInRepoSource resolves an update_source_with_cas source string relative to -// dirPath and returns the cleaned absolute path. Absolute sources and sources -// whose resolved path escapes repoRoot via ".." segments are rejected so CAS -// materialization stays scoped to the cloned repository. -func ResolveInRepoSource(repoRoot, dirPath, source string) (string, error) { - sourcePath, sourceSubdir := SplitSourceDoubleSlash(source) - if filepath.IsAbs(sourcePath) { - return "", fmt.Errorf("%w: %q", ErrAbsoluteSource, source) - } - - resolved := filepath.Clean(filepath.Join(dirPath, sourcePath)) - if sourceSubdir != "" { - resolved = filepath.Join(resolved, sourceSubdir) - } - - rel, err := filepath.Rel(filepath.Clean(repoRoot), resolved) - if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { - return "", fmt.Errorf("%w: %q", ErrSourceEscapesRepo, source) - } - - return resolved, nil -} - // gitTreeMode returns the git tree-entry mode string for a file with the given // filesystem mode. Directories are handled by the caller, so only the regular // file, executable, and symlink cases are covered here. 
@@ -449,11 +484,11 @@ func gitTreeMode(mode os.FileMode) string { } } -// isLocalPath reports whether source refers to an existing directory on the -// local filesystem. Remote URLs, go-getter forcers (git::), SSH shorthand -// (git@host:…), and non-directory paths all return false and fall through to -// the remote processing flow. -func isLocalPath(source string) bool { +// isLocalPath reports whether source refers to an existing directory on fs. +// Remote URLs, go-getter forcers (git::), SSH shorthand (git@host:…), and +// non-directory paths all return false and fall through to the remote +// processing flow. +func isLocalPath(fs vfs.FS, source string) bool { if source == "" { return false } @@ -469,7 +504,7 @@ func isLocalPath(source string) bool { return true } - // SSH shorthand like git@github.com:owner/repo.git — no scheme but not local. + // SSH shorthand like git@github.com:owner/repo.git has no scheme but is not local. if strings.Contains(source, "@") && strings.Contains(source, ":") { return false } @@ -478,7 +513,7 @@ func isLocalPath(source string) bool { return false } - info, err := os.Stat(source) + info, err := fs.Stat(source) if err != nil { return false } @@ -491,14 +526,14 @@ func isLocalPath(source string) bool { // mutate the caller's working tree, computes a content-addressed root hash, // and dispatches through the same processDirectory pipeline as the remote case. 
func (c *CAS) processLocalStackComponent( - ctx context.Context, l log.Logger, sourceDir, subdir string, + ctx context.Context, l log.Logger, v Venv, sourceDir, subdir string, ) (*StackCASResult, error) { absSource, err := filepath.Abs(sourceDir) if err != nil { return nil, fmt.Errorf("failed to resolve local source %q: %w", sourceDir, err) } - info, err := os.Stat(absSource) + info, err := v.FS.Stat(absSource) if err != nil { return nil, fmt.Errorf("failed to stat local source %q: %w", absSource, err) } @@ -507,18 +542,20 @@ func (c *CAS) processLocalStackComponent( return nil, fmt.Errorf("%w: %s", ErrNotADirectory, absSource) } - tempDir, err := os.MkdirTemp("", "terragrunt-cas-stack-local-*") + tempDir, err := vfs.MkdirTemp(v.FS, "", "terragrunt-cas-stack-local-*") if err != nil { return nil, fmt.Errorf("failed to create temp dir: %w", err) } cleanup := func() { - _ = os.RemoveAll(tempDir) + if rmErr := v.FS.RemoveAll(tempDir); rmErr != nil { + l.Warnf("cleanup error for %s: %v", tempDir, rmErr) + } } repoRoot := filepath.Join(tempDir, "repo") - if err := c.copyTree(absSource, repoRoot); err != nil { + if err := c.copyTree(v, absSource, repoRoot); err != nil { cleanup() return nil, fmt.Errorf("failed to copy local source into temp dir: %w", err) @@ -530,7 +567,7 @@ func (c *CAS) processLocalStackComponent( if filepath.IsAbs(subdir) { cleanup() - return nil, fmt.Errorf("%w: %q", ErrSourceEscapesRepo, subdir) + return nil, fmt.Errorf("%w: %q", ErrAbsoluteSource, subdir) } contentDir = filepath.Clean(filepath.Join(repoRoot, subdir)) @@ -543,20 +580,20 @@ func (c *CAS) processLocalStackComponent( } } - if _, err := os.Stat(contentDir); err != nil { + if _, err := v.FS.Stat(contentDir); err != nil { cleanup() return nil, fmt.Errorf("subdir %q not found in local source: %w", subdir, err) } - rootHash, err := c.ComputeLocalRootHash(repoRoot, DefaultLocalHashAlgorithm) + rootHash, err := c.ComputeLocalRootHash(v, repoRoot, DefaultLocalHashAlgorithm) if err != nil { 
cleanup() return nil, fmt.Errorf("failed to compute local root hash: %w", err) } - if err := c.processDirectory(ctx, l, repoRoot, contentDir, rootHash, DefaultLocalHashAlgorithm); err != nil { + if err := c.processDirectory(ctx, l, v, repoRoot, contentDir, rootHash, DefaultLocalHashAlgorithm); err != nil { cleanup() return nil, fmt.Errorf("failed to process local source for CAS: %w", err) @@ -568,11 +605,11 @@ func (c *CAS) processLocalStackComponent( }, nil } -// copyTree copies the directory tree rooted at src into dst using c.fs for all +// copyTree copies the directory tree rooted at src into dst using v.FS for all // reads and writes, preserving file permissions. Regular files, directories, // and symlinks are copied; other special files are skipped. -func (c *CAS) copyTree(src, dst string) error { - return vfs.WalkDir(c.fs, src, func(path string, d fs.DirEntry, walkErr error) error { +func (c *CAS) copyTree(v Venv, src, dst string) error { + return vfs.WalkDir(v.FS, src, func(path string, d fs.DirEntry, walkErr error) error { if walkErr != nil { return walkErr } @@ -590,55 +627,47 @@ func (c *CAS) copyTree(src, dst string) error { } if d.IsDir() { - return c.fs.MkdirAll(target, DefaultDirPerms) + return v.FS.MkdirAll(target, DefaultDirPerms) } if info.Mode()&os.ModeSymlink != 0 { - linkTarget, err := vfs.Readlink(c.fs, path) + linkTarget, err := vfs.Readlink(v.FS, path) if err != nil { return err } - resolved := linkTarget - if !filepath.IsAbs(resolved) { - resolved = filepath.Join(filepath.Dir(path), resolved) - } - - resolved = filepath.Clean(resolved) - - rel, relErr := filepath.Rel(filepath.Clean(src), resolved) - if relErr != nil || rel == ".." 
|| strings.HasPrefix(rel, ".."+string(filepath.Separator)) { - return fmt.Errorf("%w: symlink %q -> %q", ErrSourceEscapesRepo, path, linkTarget) + if err := vfs.ValidateSymlinkTarget(src, path, linkTarget); err != nil { + return fmt.Errorf("%w: %w", ErrSourceEscapesRepo, err) } - if err := c.fs.MkdirAll(filepath.Dir(target), DefaultDirPerms); err != nil { + if err := v.FS.MkdirAll(filepath.Dir(target), DefaultDirPerms); err != nil { return err } - return vfs.Symlink(c.fs, linkTarget, target) + return vfs.Symlink(v.FS, linkTarget, target) } if !info.Mode().IsRegular() { return nil } - return c.copyFileInFS(path, target, info.Mode().Perm()) + return c.copyFileInFS(v, path, target, info.Mode().Perm()) }) } // copyFileInFS copies a single regular file from srcPath to dstPath through -// c.fs, creating any missing parent directories with DefaultDirPerms. -func (c *CAS) copyFileInFS(srcPath, dstPath string, perm fs.FileMode) error { - if err := c.fs.MkdirAll(filepath.Dir(dstPath), DefaultDirPerms); err != nil { +// v.FS, creating any missing parent directories with DefaultDirPerms. 
+func (c *CAS) copyFileInFS(v Venv, srcPath, dstPath string, perm fs.FileMode) error { + if err := v.FS.MkdirAll(filepath.Dir(dstPath), DefaultDirPerms); err != nil { return err } - in, err := c.fs.Open(srcPath) + in, err := v.FS.Open(srcPath) if err != nil { return err } - out, err := c.fs.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) + out, err := v.FS.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) if err != nil { _ = in.Close() diff --git a/internal/cas/stacks_local_test.go b/internal/cas/stacks_local_test.go index be620e261b..7e977aff87 100644 --- a/internal/cas/stacks_local_test.go +++ b/internal/cas/stacks_local_test.go @@ -19,197 +19,6 @@ import ( "github.com/stretchr/testify/require" ) -// buildLocalStackFixture lays out a directory tree on disk that mirrors the -// structure used by the remote stack tests so we can exercise the same -// processing pipeline against a local source. The returned path is the -// repo-root; callers append "//stacks/my-stack" to target the stack. 
-func buildLocalStackFixture(t *testing.T) string { - t.Helper() - - root := helpers.TmpDirWOSymlinks(t) - - write := func(rel, body string) { - full := filepath.Join(root, rel) - require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) - require.NoError(t, os.WriteFile(full, []byte(body), 0o644)) - } - - write("stacks/my-stack/terragrunt.stack.hcl", `unit "service" { - source = "../..//units/my-service" - - update_source_with_cas = true - - path = "service" -} - -unit "plain" { - source = "../../units/plain-service" - path = "plain" -} -`) - write("units/my-service/terragrunt.hcl", `terraform { - source = "../..//modules/vpc" - - update_source_with_cas = true -} -`) - write("units/plain-service/terragrunt.hcl", `terraform { - source = "../../modules/vpc" -} -`) - write("modules/vpc/main.tf", `resource "aws_vpc" "main" { - cidr_block = "10.0.0.0/16" -} -`) - write("modules/vpc/variables.tf", `variable "name" { - type = string -} -`) - - return root -} - -// buildSharedTemplateFixture lays out a stack with two unit blocks that point -// at the same unit-template directory. Reproduces issue #6141: the first -// block's pass over the shared terragrunt.hcl rewrites terraform.source to a -// cas:: ref, and the second block's pass over the same file used to treat -// that ref as a relative path and abort CAS processing. 
-func buildSharedTemplateFixture(t *testing.T) string { - t.Helper() - - root := helpers.TmpDirWOSymlinks(t) - - write := func(rel, body string) { - full := filepath.Join(root, rel) - require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) - require.NoError(t, os.WriteFile(full, []byte(body), 0o644)) - } - - write("stacks/my-stack/terragrunt.stack.hcl", `unit "first" { - source = "../..//units/shared" - - update_source_with_cas = true - - path = "first" -} - -unit "second" { - source = "../..//units/shared" - - update_source_with_cas = true - - path = "second" -} -`) - write("units/shared/terragrunt.hcl", `terraform { - source = "../..//modules/vpc" - - update_source_with_cas = true -} -`) - write("modules/vpc/main.tf", `resource "aws_vpc" "main" { - cidr_block = "10.0.0.0/16" -} -`) - - return root -} - -// buildSharedNestedStackFixture lays out a top-level stack with two stack -// blocks that point at the same nested-stack directory. The same shared- -// template failure mode applies through the recursive processStackFile path: -// the second block's recursive call re-reads the nested stack file the -// first block's pass already rewrote. 
-func buildSharedNestedStackFixture(t *testing.T) string { - t.Helper() - - root := helpers.TmpDirWOSymlinks(t) - - write := func(rel, body string) { - full := filepath.Join(root, rel) - require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) - require.NoError(t, os.WriteFile(full, []byte(body), 0o644)) - } - - write("stacks/parent/terragrunt.stack.hcl", `stack "alpha" { - source = "../..//stacks/nested" - - update_source_with_cas = true - - path = "alpha" -} - -stack "beta" { - source = "../..//stacks/nested" - - update_source_with_cas = true - - path = "beta" -} -`) - write("stacks/nested/terragrunt.stack.hcl", `unit "service" { - source = "../..//units/shared" - - update_source_with_cas = true - - path = "service" -} -`) - write("units/shared/terragrunt.hcl", `terraform { - source = "../..//modules/vpc" - - update_source_with_cas = true -} -`) - write("modules/vpc/main.tf", `resource "aws_vpc" "main" { - cidr_block = "10.0.0.0/16" -} -`) - - return root -} - -// snapshotTree reads every regular file under root and returns a sha256 of the -// (relpath, mode, contents) triples in walk order. Used to prove a run didn't -// mutate the source tree, including file permissions. 
-func snapshotTree(t *testing.T, root string) string { - t.Helper() - - h := sha256.New() - - err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - if info.IsDir() || !info.Mode().IsRegular() { - return nil - } - - rel, err := filepath.Rel(root, path) - if err != nil { - return err - } - - body, err := os.ReadFile(path) - if err != nil { - return err - } - - h.Write([]byte(rel)) - h.Write([]byte{0}) - h.Write([]byte(info.Mode().String())) - h.Write([]byte{0}) - h.Write(body) - h.Write([]byte{0}) - - return nil - }) - require.NoError(t, err) - - return hex.EncodeToString(h.Sum(nil)) -} - func TestProcessStackComponent_LocalSource_RewritesStackSources(t *testing.T) { t.Parallel() @@ -220,9 +29,12 @@ func TestProcessStackComponent_LocalSource_RewritesStackSources(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := root + "//stacks/my-stack" - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -250,9 +62,12 @@ func TestProcessStackComponent_LocalSource_RewritesUnitSources(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := root + "//stacks/my-stack" - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -271,130 +86,52 @@ func TestProcessStackComponent_LocalSource_RewritesUnitSources(t *testing.T) { assert.NotContains(t, contentStr, "modules/vpc", "module path should not appear in the rewritten source") } -// TestProcessStackComponent_LocalSource_SharedUnitTemplate covers issue -// #6141: two unit blocks pointing at the 
same unit-template directory must -// both rewrite cleanly to identical cas:: refs. Before the fix the second -// block's pass over the shared terragrunt.hcl re-read the already-rewritten -// file and treated "cas::sha256:..." as a relative path, failing the whole -// stack. -func TestProcessStackComponent_LocalSource_SharedUnitTemplate(t *testing.T) { +func TestProcessStackComponent_LocalSource_DoesNotMutateInput(t *testing.T) { t.Parallel() - root := buildSharedTemplateFixture(t) + root := buildLocalStackFixture(t) l := logger.CreateLogger() + before := snapshotTree(t, root) + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) - source := root + "//stacks/my-stack" - - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") - require.NoError(t, err, "shared unit template across two blocks must not fail CAS processing") - - defer result.Cleanup() - - content, err := os.ReadFile(filepath.Join(result.ContentDir, "terragrunt.stack.hcl")) + v, err := cas.OSVenv() require.NoError(t, err) - blocks, err := cas.ReadStackBlocks(content) + source := root + "//stacks/my-stack" + + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.NoError(t, err) - sources := map[string]string{} - for _, b := range blocks { - sources[b.Name] = b.Source - } + result.Cleanup() - require.Contains(t, sources, "first") - require.Contains(t, sources, "second") - assert.True(t, strings.HasPrefix(sources["first"], "cas::sha256:"), "first unit must be rewritten to cas:: ref") - assert.True(t, strings.HasPrefix(sources["second"], "cas::sha256:"), "second unit must be rewritten to cas:: ref") - assert.Equal(t, sources["first"], sources["second"], - "two blocks sharing one template must resolve to the same synthetic tree") + after := snapshotTree(t, root) + assert.Equal(t, before, after, "processing must not mutate the local source tree") } -// 
TestProcessStackComponent_LocalSource_SharedNestedStack is the stack-block -// analogue of TestProcessStackComponent_LocalSource_SharedUnitTemplate. Two -// stack blocks pointing at the same nested-stack directory must both rewrite -// to the same cas:: ref. Before the fix, the second block's recursive pass -// over the nested stack file re-read the already-rewritten file and tried to -// resolve its cas:: block sources as relative paths. -func TestProcessStackComponent_LocalSource_SharedNestedStack(t *testing.T) { +func TestProcessStackComponent_LocalSource_DeterministicOutput(t *testing.T) { t.Parallel() - root := buildSharedNestedStackFixture(t) + root := buildLocalStackFixture(t) l := logger.CreateLogger() - storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") - c, err := cas.New(cas.WithStorePath(storePath)) + v, err := cas.OSVenv() require.NoError(t, err) - source := root + "//stacks/parent" + readStackFile := func() string { + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(t, err) - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") - require.NoError(t, err, "shared nested stack across two blocks must not fail CAS processing") + source := root + "//stacks/my-stack" - defer result.Cleanup() + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") + require.NoError(t, err) - content, err := os.ReadFile(filepath.Join(result.ContentDir, "terragrunt.stack.hcl")) - require.NoError(t, err) - - blocks, err := cas.ReadStackBlocks(content) - require.NoError(t, err) - - sources := map[string]string{} - for _, b := range blocks { - sources[b.Name] = b.Source - } - - require.Contains(t, sources, "alpha") - require.Contains(t, sources, "beta") - assert.True(t, strings.HasPrefix(sources["alpha"], "cas::sha256:"), "alpha stack must be rewritten to cas:: ref") - assert.True(t, strings.HasPrefix(sources["beta"], "cas::sha256:"), "beta 
stack must be rewritten to cas:: ref") - assert.Equal(t, sources["alpha"], sources["beta"], - "two stack blocks sharing one nested-stack template must resolve to the same synthetic tree") -} - -func TestProcessStackComponent_LocalSource_DoesNotMutateInput(t *testing.T) { - t.Parallel() - - root := buildLocalStackFixture(t) - l := logger.CreateLogger() - - before := snapshotTree(t, root) - - storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") - c, err := cas.New(cas.WithStorePath(storePath)) - require.NoError(t, err) - - source := root + "//stacks/my-stack" - - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") - require.NoError(t, err) - - result.Cleanup() - - after := snapshotTree(t, root) - assert.Equal(t, before, after, "processing must not mutate the local source tree") -} - -func TestProcessStackComponent_LocalSource_DeterministicOutput(t *testing.T) { - t.Parallel() - - root := buildLocalStackFixture(t) - l := logger.CreateLogger() - - readStackFile := func() string { - storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") - c, err := cas.New(cas.WithStorePath(storePath)) - require.NoError(t, err) - - source := root + "//stacks/my-stack" - - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") - require.NoError(t, err) - - defer result.Cleanup() + defer result.Cleanup() content, err := os.ReadFile(filepath.Join(result.ContentDir, "terragrunt.stack.hcl")) require.NoError(t, err) @@ -412,7 +149,7 @@ func TestProcessStackComponent_LocalSource_ContentAddressedCacheKey(t *testing.T t.Parallel() // Two fixtures with the same relative layout but different module contents - // must yield different synthetic tree hashes — otherwise one source would + // must yield different synthetic tree hashes; otherwise one source would // poison the cache for the other. 
rootA := buildLocalStackFixture(t) rootB := buildLocalStackFixture(t) @@ -427,6 +164,9 @@ func TestProcessStackComponent_LocalSource_ContentAddressedCacheKey(t *testing.T l := logger.CreateLogger() + v, err := cas.OSVenv() + require.NoError(t, err) + runAndExtractServiceRef := func(root string) string { storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") c, err := cas.New(cas.WithStorePath(storePath)) @@ -434,7 +174,7 @@ func TestProcessStackComponent_LocalSource_ContentAddressedCacheKey(t *testing.T source := root + "//stacks/my-stack" - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -472,13 +212,13 @@ func TestProcessStackComponent_LocalSource_ContentAddressedCacheKey(t *testing.T func TestProcessStackComponent_LocalSource_NonExistentPath(t *testing.T) { t.Parallel() - c := newCAS(t) + c, v := newCAS(t) l := logger.CreateLogger() - // Absolute path that does not exist — must not be misinterpreted as a URL. + // Absolute path that does not exist must not be misinterpreted as a URL. source := filepath.Join(helpers.TmpDirWOSymlinks(t), "does-not-exist") - _, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + _, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.Error(t, err, "non-existent local path must fail") require.ErrorIs(t, err, fs.ErrNotExist, "error must be a local file-not-found error") } @@ -486,7 +226,7 @@ func TestProcessStackComponent_LocalSource_NonExistentPath(t *testing.T) { func TestProcessStackComponent_LocalSource_RegularFileRejected(t *testing.T) { t.Parallel() - c := newCAS(t) + c, v := newCAS(t) l := logger.CreateLogger() // A regular file is not a valid component source. 
The local flow rejects @@ -497,20 +237,20 @@ func TestProcessStackComponent_LocalSource_RegularFileRejected(t *testing.T) { filePath := filepath.Join(tmp, "a-file") require.NoError(t, os.WriteFile(filePath, []byte("x"), 0o644)) - _, err := c.ProcessStackComponent(t.Context(), l, filePath, "stack") + _, err := c.ProcessStackComponent(t.Context(), l, v, filePath, "stack") require.Error(t, err, "a regular file must not be accepted as a component source") } func TestProcessStackComponent_LocalSource_MissingSubdir(t *testing.T) { t.Parallel() - c := newCAS(t) + c, v := newCAS(t) l := logger.CreateLogger() root := buildLocalStackFixture(t) source := root + "//stacks/does-not-exist" - _, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + _, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.Error(t, err, "missing subdir inside a local source must fail") assert.Contains(t, err.Error(), "does-not-exist") } @@ -518,7 +258,7 @@ func TestProcessStackComponent_LocalSource_MissingSubdir(t *testing.T) { func TestProcessStackComponent_LocalSource_SymlinkEscapesRepo(t *testing.T) { t.Parallel() - c := newCAS(t) + c, v := newCAS(t) l := logger.CreateLogger() tmp := helpers.TmpDirWOSymlinks(t) @@ -536,7 +276,7 @@ func TestProcessStackComponent_LocalSource_SymlinkEscapesRepo(t *testing.T) { )) require.NoError(t, os.Symlink(outside, filepath.Join(root, "escape"))) - _, err := c.ProcessStackComponent(t.Context(), l, root+"//stacks/my-stack", "stack") + _, err := c.ProcessStackComponent(t.Context(), l, v, root+"//stacks/my-stack", "stack") require.Error(t, err, "symlink pointing outside the source root must be rejected") require.ErrorIs(t, err, cas.ErrSourceEscapesRepo) } @@ -547,10 +287,10 @@ func TestProcessStackComponent_LocalSource_SymlinkEscapesRepo(t *testing.T) { func TestProcessStackComponent_EmptySourceFails(t *testing.T) { t.Parallel() - c := newCAS(t) + c, v := newCAS(t) l := logger.CreateLogger() - _, err := 
c.ProcessStackComponent(t.Context(), l, "", "stack") + _, err := c.ProcessStackComponent(t.Context(), l, v, "", "stack") require.Error(t, err, "empty source must be rejected") } @@ -568,9 +308,12 @@ func TestProcessStackComponent_GitForcerRoutesRemote(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := "git::" + repoURL + "//stacks/my-stack?ref=main" - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.NoError(t, err, "git:: forcer must route through the remote flow and succeed against the test server") defer result.Cleanup() @@ -581,7 +324,7 @@ func TestProcessStackComponent_GitForcerRoutesRemote(t *testing.T) { // TestProcessStackComponent_SSHShorthandRoutesRemote confirms that SSH // shorthand (git@host:path) is treated as remote, not local. The test runs // inside a synctest bubble so the context deadline fires on the synthetic -// clock the moment every bubbled goroutine is idle — no real-time wait, no +// clock the moment every bubbled goroutine is idle, with no real-time wait and no // dependency on DNS or network behavior. All we care about here is which // branch of the dispatcher the source was routed through; we assert the // error originated in the remote pipeline, not the local one. 
@@ -589,13 +332,13 @@ func TestProcessStackComponent_SSHShorthandRoutesRemote(t *testing.T) { t.Parallel() synctest.Test(t, func(t *testing.T) { - c := newCAS(t) + c, v := newCAS(t) l := logger.CreateLogger() ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond) defer cancel() - _, err := c.ProcessStackComponent(ctx, l, "git@unreachable.invalid:owner/repo.git", "stack") + _, err := c.ProcessStackComponent(ctx, l, v, "git@unreachable.invalid:owner/repo.git", "stack") require.Error(t, err, "SSH shorthand must route through the remote flow and fail there") assert.NotContains(t, err.Error(), "local source", "error must come from the remote pipeline") }) @@ -612,9 +355,12 @@ func TestProcessStackComponent_LocalSource_MaterializeSynthTree(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := root + "//stacks/my-stack" - result, err := c.ProcessStackComponent(ctx, l, source, "stack") + result, err := c.ProcessStackComponent(ctx, l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -642,7 +388,288 @@ func TestProcessStackComponent_LocalSource_MaterializeSynthTree(t *testing.T) { require.NoError(t, err) destDir := helpers.TmpDirWOSymlinks(t) - require.NoError(t, c.MaterializeTree(ctx, l, hash, destDir)) + require.NoError(t, c.MaterializeTree(ctx, l, v, hash, destDir)) assert.FileExists(t, filepath.Join(destDir, "terragrunt.hcl")) } + +// buildLocalStackFixture lays out a directory tree on disk that mirrors the +// structure used by the remote stack tests so we can exercise the same +// processing pipeline against a local source. The returned path is the +// repo-root; callers append "//stacks/my-stack" to target the stack. 
+func buildLocalStackFixture(t *testing.T) string { + t.Helper() + + root := helpers.TmpDirWOSymlinks(t) + + write := func(rel, body string) { + full := filepath.Join(root, rel) + require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) + require.NoError(t, os.WriteFile(full, []byte(body), 0o644)) + } + + write("stacks/my-stack/terragrunt.stack.hcl", `unit "service" { + source = "../..//units/my-service" + + update_source_with_cas = true + + path = "service" +} + +unit "plain" { + source = "../../units/plain-service" + path = "plain" +} +`) + write("units/my-service/terragrunt.hcl", `terraform { + source = "../..//modules/vpc" + + update_source_with_cas = true +} +`) + write("units/plain-service/terragrunt.hcl", `terraform { + source = "../../modules/vpc" +} +`) + write("modules/vpc/main.tf", `resource "aws_vpc" "main" { + cidr_block = "10.0.0.0/16" +} +`) + write("modules/vpc/variables.tf", `variable "name" { + type = string +} +`) + + return root +} + +// buildSharedTemplateFixture lays out a stack with two unit blocks that point +// at the same unit-template directory. Reproduces issue #6141: the first +// block's pass over the shared terragrunt.hcl rewrites terraform.source to a +// cas:: ref, and the second block's pass over the same file used to treat +// that ref as a relative path and abort CAS processing. 
+func buildSharedTemplateFixture(t *testing.T) string { + t.Helper() + + root := helpers.TmpDirWOSymlinks(t) + + write := func(rel, body string) { + full := filepath.Join(root, rel) + require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) + require.NoError(t, os.WriteFile(full, []byte(body), 0o644)) + } + + write("stacks/my-stack/terragrunt.stack.hcl", `unit "first" { + source = "../..//units/shared" + + update_source_with_cas = true + + path = "first" +} + +unit "second" { + source = "../..//units/shared" + + update_source_with_cas = true + + path = "second" +} +`) + write("units/shared/terragrunt.hcl", `terraform { + source = "../..//modules/vpc" + + update_source_with_cas = true +} +`) + write("modules/vpc/main.tf", `resource "aws_vpc" "main" { + cidr_block = "10.0.0.0/16" +} +`) + + return root +} + +// buildSharedNestedStackFixture lays out a top-level stack with two stack +// blocks that point at the same nested-stack directory. The same shared- +// template failure mode applies through the recursive processStackFile path: +// the second block's recursive call re-reads the nested stack file the +// first block's pass already rewrote. 
+func buildSharedNestedStackFixture(t *testing.T) string { + t.Helper() + + root := helpers.TmpDirWOSymlinks(t) + + write := func(rel, body string) { + full := filepath.Join(root, rel) + require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) + require.NoError(t, os.WriteFile(full, []byte(body), 0o644)) + } + + write("stacks/parent/terragrunt.stack.hcl", `stack "alpha" { + source = "../..//stacks/nested" + + update_source_with_cas = true + + path = "alpha" +} + +stack "beta" { + source = "../..//stacks/nested" + + update_source_with_cas = true + + path = "beta" +} +`) + write("stacks/nested/terragrunt.stack.hcl", `unit "service" { + source = "../..//units/shared" + + update_source_with_cas = true + + path = "service" +} +`) + write("units/shared/terragrunt.hcl", `terraform { + source = "../..//modules/vpc" + + update_source_with_cas = true +} +`) + write("modules/vpc/main.tf", `resource "aws_vpc" "main" { + cidr_block = "10.0.0.0/16" +} +`) + + return root +} + +// snapshotTree reads every regular file under root and returns a sha256 of the +// (relpath, mode, contents) triples in walk order. Used to prove a run didn't +// mutate the source tree, including file permissions. 
+func snapshotTree(t *testing.T, root string) string { + t.Helper() + + h := sha256.New() + + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() || !info.Mode().IsRegular() { + return nil + } + + rel, err := filepath.Rel(root, path) + if err != nil { + return err + } + + body, err := os.ReadFile(path) + if err != nil { + return err + } + + h.Write([]byte(rel)) + h.Write([]byte{0}) + h.Write([]byte(info.Mode().String())) + h.Write([]byte{0}) + h.Write(body) + h.Write([]byte{0}) + + return nil + }) + require.NoError(t, err) + + return hex.EncodeToString(h.Sum(nil)) +} + +// TestProcessStackComponent_LocalSource_SharedUnitTemplate covers issue +// #6141: two unit blocks pointing at the same unit-template directory must +// both rewrite cleanly to identical cas:: refs. Before the fix the second +// block's pass over the shared terragrunt.hcl re-read the already-rewritten +// file and treated "cas::sha256:..." as a relative path, failing the whole +// stack. 
+func TestProcessStackComponent_LocalSource_SharedUnitTemplate(t *testing.T) { + t.Parallel() + + root := buildSharedTemplateFixture(t) + l := logger.CreateLogger() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := cas.OSVenv() + require.NoError(t, err) + + source := root + "//stacks/my-stack" + + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") + require.NoError(t, err, "shared unit template across two blocks must not fail CAS processing") + + defer result.Cleanup() + + content, err := os.ReadFile(filepath.Join(result.ContentDir, "terragrunt.stack.hcl")) + require.NoError(t, err) + + blocks, err := cas.ReadStackBlocks(content) + require.NoError(t, err) + + sources := map[string]string{} + for _, b := range blocks { + sources[b.Name] = b.Source + } + + require.Contains(t, sources, "first") + require.Contains(t, sources, "second") + assert.True(t, strings.HasPrefix(sources["first"], "cas::sha256:"), "first unit must be rewritten to cas:: ref") + assert.True(t, strings.HasPrefix(sources["second"], "cas::sha256:"), "second unit must be rewritten to cas:: ref") + assert.Equal(t, sources["first"], sources["second"], + "two blocks sharing one template must resolve to the same synthetic tree") +} + +// TestProcessStackComponent_LocalSource_SharedNestedStack is the stack-block +// analogue of TestProcessStackComponent_LocalSource_SharedUnitTemplate. Two +// stack blocks pointing at the same nested-stack directory must both rewrite +// to the same cas:: ref. Before the fix, the second block's recursive pass +// over the nested stack file re-read the already-rewritten file and tried to +// resolve its cas:: block sources as relative paths. 
+func TestProcessStackComponent_LocalSource_SharedNestedStack(t *testing.T) { + t.Parallel() + + root := buildSharedNestedStackFixture(t) + l := logger.CreateLogger() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := cas.OSVenv() + require.NoError(t, err) + + source := root + "//stacks/parent" + + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") + require.NoError(t, err, "shared nested stack across two blocks must not fail CAS processing") + + defer result.Cleanup() + + content, err := os.ReadFile(filepath.Join(result.ContentDir, "terragrunt.stack.hcl")) + require.NoError(t, err) + + blocks, err := cas.ReadStackBlocks(content) + require.NoError(t, err) + + sources := map[string]string{} + for _, b := range blocks { + sources[b.Name] = b.Source + } + + require.Contains(t, sources, "alpha") + require.Contains(t, sources, "beta") + assert.True(t, strings.HasPrefix(sources["alpha"], "cas::sha256:"), "alpha stack must be rewritten to cas:: ref") + assert.True(t, strings.HasPrefix(sources["beta"], "cas::sha256:"), "beta stack must be rewritten to cas:: ref") + assert.Equal(t, sources["alpha"], sources["beta"], + "two stack blocks sharing one nested-stack template must resolve to the same synthetic tree") +} diff --git a/internal/cas/stacks_test.go b/internal/cas/stacks_test.go index 4001be4d92..2b2a98155c 100644 --- a/internal/cas/stacks_test.go +++ b/internal/cas/stacks_test.go @@ -153,10 +153,13 @@ func TestProcessStackComponent_RewritesStackSources(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + // Source mimics what a stack generates: //?ref= source := repoURL + "//stacks/my-stack?ref=main" - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, err := c.ProcessStackComponent(t.Context(), l, 
v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -189,9 +192,12 @@ func TestProcessStackComponent_RewritesUnitSources(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := repoURL + "//stacks/my-stack?ref=main" - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -233,9 +239,12 @@ func TestProcessStackComponent_CreatesSyntheticTrees(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := repoURL + "//stacks/my-stack?ref=main" - result, err := c.ProcessStackComponent(ctx, l, source, "stack") + result, err := c.ProcessStackComponent(ctx, l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -268,11 +277,11 @@ func TestProcessStackComponent_CreatesSyntheticTrees(t *testing.T) { // The synthetic tree should be stored in the synth store synthStore := cas.NewStore(filepath.Join(storePath, "synth", "trees")) - assert.False(t, synthStore.NeedsWrite(hash), "synthetic tree should exist in synth store") + assert.False(t, synthStore.NeedsWrite(v, hash), "synthetic tree should exist in synth store") // Verify the tree can be read and contains entries synthContent := cas.NewContent(synthStore) - treeData, err := synthContent.Read(hash) + treeData, err := synthContent.Read(v, hash) require.NoError(t, err) assert.NotEmpty(t, treeData, "synthetic tree data should not be empty") } @@ -283,6 +292,9 @@ func TestProcessStackComponent_DeterministicOutput(t *testing.T) { repoURL := startStackTestServer(t) l := logger.CreateLogger() + v, err := cas.OSVenv() + require.NoError(t, err) + readStackFile := func() string { storePath := 
filepath.Join(helpers.TmpDirWOSymlinks(t), "store") c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) @@ -290,7 +302,7 @@ func TestProcessStackComponent_DeterministicOutput(t *testing.T) { source := repoURL + "//stacks/my-stack?ref=main" - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -305,7 +317,7 @@ func TestProcessStackComponent_DeterministicOutput(t *testing.T) { first := readStackFile() second := readStackFile() - // Both runs should produce identical output — the CAS hashes are + // Both runs should produce identical output. The CAS hashes are // deterministic based on ref + path, so regeneration must not produce diffs. assert.Equal(t, first, second, "processing the same source twice should produce identical output") } @@ -321,9 +333,12 @@ func TestProcessStackComponent_MaterializeSynthTree(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := repoURL + "//stacks/my-stack?ref=main" - result, err := c.ProcessStackComponent(ctx, l, source, "stack") + result, err := c.ProcessStackComponent(ctx, l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -353,7 +368,7 @@ func TestProcessStackComponent_MaterializeSynthTree(t *testing.T) { // Materialize the synthetic tree to a new directory destDir := helpers.TmpDirWOSymlinks(t) - err = c.MaterializeTree(ctx, l, hash, destDir) + err = c.MaterializeTree(ctx, l, v, hash, destDir) require.NoError(t, err) // The materialized tree should contain the unit's terragrunt.hcl @@ -370,9 +385,12 @@ func TestProcessStackComponent_InvalidRefFails(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := 
repoURL + "//stacks/my-stack?ref=nonexistent-tag" - _, err = c.ProcessStackComponent(t.Context(), l, source, "stack") + _, err = c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.Error(t, err, "should fail when ref does not exist") } @@ -386,9 +404,12 @@ func TestProcessStackComponent_InvalidSubdirFails(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := repoURL + "//nonexistent/path?ref=main" - _, err = c.ProcessStackComponent(t.Context(), l, source, "stack") + _, err = c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.Error(t, err, "should fail when subdir does not exist") } @@ -403,9 +424,12 @@ func TestProcessStackComponent_BlobsStoredInCAS(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := repoURL + "//stacks/my-stack?ref=main" - result, err := c.ProcessStackComponent(ctx, l, source, "stack") + result, err := c.ProcessStackComponent(ctx, l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -436,9 +460,12 @@ func TestProcessStackComponent_AcceptsExplicitGitPrefix(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + source := "git::" + repoURL + "//stacks/my-stack?ref=main" - result, err := c.ProcessStackComponent(t.Context(), l, source, "stack") + result, err := c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.NoError(t, err) defer result.Cleanup() @@ -455,11 +482,14 @@ func TestProcessStackComponent_ShorthandSourceReachesClone(t *testing.T) { c, err := cas.New(cas.WithStorePath(storePath), cas.WithCloneDepth(-1)) require.NoError(t, err) + v, err := cas.OSVenv() + require.NoError(t, err) + // Bogus org so the network call fails fast. 
The error shape proves the // shorthand was rewritten and reached `git ls-remote`. source := "github.com/gruntwork-io-this-org-does-not-exist/repo?ref=main" - _, err = c.ProcessStackComponent(t.Context(), l, source, "stack") + _, err = c.ProcessStackComponent(t.Context(), l, v, source, "stack") require.Error(t, err) require.ErrorIs(t, err, git.ErrCommandSpawn, "failure must come from a spawned git command") diff --git a/internal/cas/store.go b/internal/cas/store.go index 68df120abc..b00323cc06 100644 --- a/internal/cas/store.go +++ b/internal/cas/store.go @@ -6,70 +6,53 @@ import ( "github.com/gruntwork-io/terragrunt/internal/vfs" ) -// Store manages the store directory and filesystem locks to prevent concurrent writes +// Store manages the store directory and filesystem locks to prevent concurrent writes. type Store struct { - fs vfs.FS path string } -// NewStore creates a new Store instance with the OS filesystem. +// NewStore creates a new Store rooted at path. func NewStore(path string) *Store { - return &Store{ - path: path, - fs: vfs.NewOSFS(), - } -} - -// WithFS sets the filesystem for file operations and returns the Store for method chaining. -func (s *Store) WithFS(fs vfs.FS) *Store { - s.fs = fs - return s + return &Store{path: path} } -// FS returns the configured filesystem. -func (s *Store) FS() vfs.FS { - return s.fs -} - -// Path returns the current store path +// Path returns the current store path. func (s *Store) Path() string { return s.path } -// NeedsWrite checks if a given hash needs to be stored -func (s *Store) NeedsWrite(hash string) bool { +// NeedsWrite checks if a given hash needs to be stored. 
+func (s *Store) NeedsWrite(v Venv, hash string) bool { partitionDir := filepath.Join(s.path, hash[:2]) path := filepath.Join(partitionDir, hash) - return !s.hasContent(path) + return !s.hasContent(v, path) } -// AcquireLock acquires a filesystem lock for the given hash -// Returns the lock that should be unlocked when done -func (s *Store) AcquireLock(hash string) (vfs.Unlocker, error) { +// AcquireLock acquires a filesystem lock for the given hash. +// Returns the lock that should be unlocked when done. +func (s *Store) AcquireLock(v Venv, hash string) (vfs.Unlocker, error) { partitionDir := filepath.Join(s.path, hash[:2]) lockPath := filepath.Join(partitionDir, hash+".lock") - // Ensure the partition directory exists - if err := s.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err := v.FS.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return nil, err } - return vfs.Lock(s.fs, lockPath) + return vfs.Lock(v.FS, lockPath) } -// TryAcquireLock attempts to acquire a filesystem lock for the given hash without blocking -// Returns the lock and true if successful, nil and false if the lock is already held -func (s *Store) TryAcquireLock(hash string) (vfs.Unlocker, bool, error) { +// TryAcquireLock attempts to acquire a filesystem lock for the given hash without blocking. +// Returns the lock and true if successful, nil and false if the lock is already held. 
+func (s *Store) TryAcquireLock(v Venv, hash string) (vfs.Unlocker, bool, error) { partitionDir := filepath.Join(s.path, hash[:2]) lockPath := filepath.Join(partitionDir, hash+".lock") - // Ensure the partition directory exists - if err := s.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err := v.FS.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return nil, false, err } - return vfs.TryLock(s.fs, lockPath) + return vfs.TryLock(v.FS, lockPath) } // EnsureWithWait tries to acquire a lock for the given hash, and if another process @@ -77,48 +60,40 @@ func (s *Store) TryAcquireLock(hash string) (vfs.Unlocker, bool, error) { // This is an optimization for read operations that avoids duplicate writes. // // Returns: -// - needsWrite: true if content doesn't exist and caller should write it -// - lock: the acquired lock (nil if needsWrite is false) -// - error: any error that occurred -func (s *Store) EnsureWithWait(hash string) (needsWrite bool, lock vfs.Unlocker, err error) { - // Fast path: check if content already exists +// - needsWrite: true if content doesn't exist and caller should write it +// - lock: the acquired lock (nil if needsWrite is false) +// - error: any error that occurred +func (s *Store) EnsureWithWait(v Venv, hash string) (needsWrite bool, lock vfs.Unlocker, err error) { partitionDir := filepath.Join(s.path, hash[:2]) path := filepath.Join(partitionDir, hash) - if s.hasContent(path) { + if s.hasContent(v, path) { return false, nil, nil } - // Try to acquire lock without blocking - tryLock, acquired, err := s.TryAcquireLock(hash) + tryLock, acquired, err := s.TryAcquireLock(v, hash) if err != nil { return false, nil, err } if acquired { - // We got the lock immediately, check if we still need to write - // (another process might have completed while we were trying) - if !s.NeedsWrite(hash) { - // Content appeared while we were acquiring lock, no write needed + if !s.NeedsWrite(v, hash) { if err = tryLock.Unlock(); err != nil { 
return false, nil, err } return false, nil, nil } - // We have the lock and content doesn't exist, caller should write + return true, tryLock, nil } - // Lock is held by another process, wait for it to complete - waitLock, err := s.AcquireLock(hash) + waitLock, err := s.AcquireLock(v, hash) if err != nil { return false, nil, err } - // Now we have the lock, check if the other process wrote the content - if !s.NeedsWrite(hash) { - // Content was written by the other process, no write needed + if !s.NeedsWrite(v, hash) { if err := waitLock.Unlock(); err != nil { return false, nil, err } @@ -126,12 +101,11 @@ func (s *Store) EnsureWithWait(hash string) (needsWrite bool, lock vfs.Unlocker, return false, nil, nil } - // Content still doesn't exist, caller should write it return true, waitLock, nil } -func (s *Store) hasContent(path string) bool { - _, err := s.fs.Stat(path) +func (s *Store) hasContent(v Venv, path string) bool { + _, err := v.FS.Stat(path) return err == nil } diff --git a/internal/cas/store_test.go b/internal/cas/store_test.go index d6b6c2dd70..28a17c3a55 100644 --- a/internal/cas/store_test.go +++ b/internal/cas/store_test.go @@ -19,10 +19,9 @@ func TestStore(t *testing.T) { t.Run("custom path", func(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() customPath := "/custom-store" - store := cas.NewStore(customPath).WithFS(memFs) + store := cas.NewStore(customPath) assert.Equal(t, customPath, store.Path()) }) } @@ -30,19 +29,19 @@ func TestStore(t *testing.T) { func TestStore_NeedsWrite(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() + v := newMemVenv(t) storePath := defaultStorePath - store := cas.NewStore(storePath).WithFS(memFs) + store := cas.NewStore(storePath) // Create a fake content file testHash := "abcdef123456" // Create partition directory partitionDir := filepath.Join(store.Path(), testHash[:2]) - err := memFs.MkdirAll(partitionDir, 0755) + err := v.FS.MkdirAll(partitionDir, 0755) require.NoError(t, err, "Failed to create 
partition directory") testPath := filepath.Join(partitionDir, testHash) - err = vfs.WriteFile(memFs, testPath, []byte("test"), 0644) + err = vfs.WriteFile(v.FS, testPath, []byte("test"), 0644) require.NoError(t, err, "Failed to create test file") tests := []struct { @@ -65,7 +64,7 @@ func TestStore_NeedsWrite(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tt.want, store.NeedsWrite(tt.hash)) + assert.Equal(t, tt.want, store.NeedsWrite(v, tt.hash)) }) } } @@ -73,19 +72,19 @@ func TestStore_NeedsWrite(t *testing.T) { func TestStore_AcquireLock(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() + v := newMemVenv(t) storePath := defaultStorePath - store := cas.NewStore(storePath).WithFS(memFs) + store := cas.NewStore(storePath) testHash := "abcdef1234567890abcdef1234567890abcdef12" // Test successful lock acquisition - lock, err := store.AcquireLock(testHash) + lock, err := store.AcquireLock(v, testHash) require.NoError(t, err) assert.NotNil(t, lock) // Verify partition directory was created partitionDir := filepath.Join(storePath, testHash[:2]) - _, err = memFs.Stat(partitionDir) + _, err = v.FS.Stat(partitionDir) require.NoError(t, err) // Clean up @@ -96,19 +95,19 @@ func TestStore_AcquireLock(t *testing.T) { func TestStore_TryAcquireLock(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() + v := newMemVenv(t) storePath := defaultStorePath - store := cas.NewStore(storePath).WithFS(memFs) + store := cas.NewStore(storePath) testHash := "abcdef1234567890abcdef1234567890abcdef12" // Test successful lock acquisition - lock1, acquired, err := store.TryAcquireLock(testHash) + lock1, acquired, err := store.TryAcquireLock(v, testHash) require.NoError(t, err) assert.True(t, acquired) assert.NotNil(t, lock1) // Test lock contention - should fail to acquire - lock2, acquired, err := store.TryAcquireLock(testHash) + lock2, acquired, err := store.TryAcquireLock(v, testHash) require.NoError(t, err) 
assert.False(t, acquired) assert.Nil(t, lock2) @@ -118,7 +117,7 @@ func TestStore_TryAcquireLock(t *testing.T) { require.NoError(t, err) // Now should be able to acquire again - lock3, acquired, err := store.TryAcquireLock(testHash) + lock3, acquired, err := store.TryAcquireLock(v, testHash) require.NoError(t, err) assert.True(t, acquired) assert.NotNil(t, lock3) @@ -131,9 +130,9 @@ func TestStore_TryAcquireLock(t *testing.T) { func TestStore_LockConcurrency(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() + v := newMemVenv(t) storePath := defaultStorePath - store := cas.NewStore(storePath).WithFS(memFs) + store := cas.NewStore(storePath) testHash := "abcdef1234567890abcdef1234567890abcdef12" // Test that multiple goroutines can't acquire the same lock @@ -142,7 +141,7 @@ func TestStore_LockConcurrency(t *testing.T) { // First goroutine acquires lock and holds it briefly go func() { - lock, err := store.AcquireLock(testHash) + lock, err := store.AcquireLock(v, testHash) assert.NoError(t, err) acquired <- true @@ -161,7 +160,7 @@ func TestStore_LockConcurrency(t *testing.T) { // Should block until first lock is released start := time.Now() - lock, err := store.AcquireLock(testHash) + lock, err := store.AcquireLock(v, testHash) elapsed := time.Since(start) assert.NoError(t, err) @@ -181,9 +180,9 @@ func TestStore_LockConcurrency(t *testing.T) { func TestStore_EnsureWithWait(t *testing.T) { t.Parallel() - memFs := vfs.NewMemMapFS() + v := newMemVenv(t) storePath := defaultStorePath - store := cas.NewStore(storePath).WithFS(memFs) + store := cas.NewStore(storePath) testHash := "abcdef1234567890abcdef1234567890abcdef12" t.Run("content already exists", func(t *testing.T) { @@ -191,15 +190,15 @@ func TestStore_EnsureWithWait(t *testing.T) { // Create the content manually partitionDir := filepath.Join(storePath, testHash[:2]) - err := memFs.MkdirAll(partitionDir, 0755) + err := v.FS.MkdirAll(partitionDir, 0755) require.NoError(t, err) contentPath := 
filepath.Join(partitionDir, testHash) - err = vfs.WriteFile(memFs, contentPath, []byte("existing content"), 0644) + err = vfs.WriteFile(v.FS, contentPath, []byte("existing content"), 0644) require.NoError(t, err) // EnsureWithWait should return false (no write needed) - needsWrite, lock, err := store.EnsureWithWait(testHash) + needsWrite, lock, err := store.EnsureWithWait(v, testHash) require.NoError(t, err) assert.False(t, needsWrite) assert.Nil(t, lock) @@ -211,7 +210,7 @@ func TestStore_EnsureWithWait(t *testing.T) { testHashNew := "fedcba0987654321fedcba0987654321fedcba09" // EnsureWithWait should return true (write needed) and provide lock - needsWrite, lock, err := store.EnsureWithWait(testHashNew) + needsWrite, lock, err := store.EnsureWithWait(v, testHashNew) require.NoError(t, err) assert.True(t, needsWrite) assert.NotNil(t, lock) diff --git a/internal/cas/testserver_test.go b/internal/cas/testserver_test.go index 58abfb9d7b..3c15567651 100644 --- a/internal/cas/testserver_test.go +++ b/internal/cas/testserver_test.go @@ -46,9 +46,9 @@ func startTestServer(t *testing.T) string { // // The repo layout is: // -// stacks/my-stack/terragrunt.stack.hcl — stack file with update_source_with_cas -// units/my-service/terragrunt.hcl — unit file with update_source_with_cas -// modules/vpc/main.tf — plain Terraform module +// stacks/my-stack/terragrunt.stack.hcl stack file with update_source_with_cas +// units/my-service/terragrunt.hcl unit file with update_source_with_cas +// modules/vpc/main.tf plain Terraform module func startStackTestServer(t *testing.T) string { t.Helper() @@ -81,7 +81,7 @@ unit "plain" { `) require.NoError(t, srv.CommitFile("units/my-service/terragrunt.hcl", unitHCL, "add unit file")) - // Plain unit (no CAS flag) — should remain unchanged after processing. + // Plain unit (no CAS flag) should remain unchanged after processing. 
plainUnitHCL := []byte(`terraform { source = "../../modules/vpc" } diff --git a/internal/cas/tree.go b/internal/cas/tree.go index 380795683e..16d049c671 100644 --- a/internal/cas/tree.go +++ b/internal/cas/tree.go @@ -16,10 +16,9 @@ import ( // unixPermMask isolates the user/group/other rwx bits from a git tree mode. const unixPermMask = os.FileMode(0o777) -// Git tree entry mode constants. Git stores the entry type in the high bits of -// a six-digit octal mode; gitTypeMask isolates them so a symlink blob (120000) -// can be distinguished from a regular blob (100644 / 100755) at materialization -// time. +// Git stores the entry type in the high bits of a six-digit octal mode; +// gitTypeMask isolates them so a symlink blob (120000) can be distinguished +// from a regular blob (100644 / 100755) at materialization time. const ( gitTypeMask = uint64(0o170000) gitTypeSymlink = uint64(0o120000) @@ -43,6 +42,7 @@ func WithForceCopy() LinkTreeOption { // blobStore is used to resolve blob entries, treeStore is used to resolve subtree entries. func LinkTree( ctx context.Context, + v Venv, blobStore *Store, treeStore *Store, t *git.Tree, @@ -54,7 +54,7 @@ func LinkTree( opt(&o) } - return linkTree(ctx, blobStore, treeStore, t, targetDir, targetDir, &o) + return linkTree(ctx, v, blobStore, treeStore, t, targetDir, targetDir, &o) } // linkTree is the recursive implementation behind LinkTree. rootDir is the @@ -64,6 +64,7 @@ func LinkTree( // resolve outside the original tree even when the link sits in a subdirectory. func linkTree( ctx context.Context, + v Venv, blobStore *Store, treeStore *Store, t *git.Tree, @@ -102,10 +103,9 @@ func linkTree( parentDirPath := filepath.Dir(dirPath) delete(dirsToCreate, parentDirPath) - // Create work items based on entry type. 
Git stores symlinks as blobs - // whose content is the link target; the entry mode (120000) is the - // only signal that distinguishes them from regular files, so dispatch - // on the mode here instead of treating every blob as a file to copy. + // Git encodes a symlink as a blob whose body is the link target; the + // entry mode (120000) is the only signal that distinguishes it from a + // regular file, so dispatch on the mode rather than the type. switch entry.Type { case "blob": itemType := "link" @@ -129,15 +129,12 @@ func linkTree( } } - fs := blobStore.FS() - for dirPath := range dirsToCreate { - if err := fs.MkdirAll(dirPath, DefaultDirPerms); err != nil { + if err := v.FS.MkdirAll(dirPath, DefaultDirPerms); err != nil { return fmt.Errorf("mkdir %s: %w", dirPath, err) } } - // Use errgroup for concurrent processing g, ctx := errgroup.WithContext(ctx) // Use half the available CPUs (at least 1) to avoid saturating I/O during tree materialization. @@ -145,17 +142,16 @@ func linkTree( maxWorkers := max(1, runtime.GOMAXPROCS(0)/scalingFactor) g.SetLimit(maxWorkers) - // Process work items concurrently for _, work := range workItems { g.Go(func() error { switch work.itemType { case "link": - err := blobContent.Link(ctx, work.entry.Hash, work.path, gitFilePerm(work.entry.Mode), linkOpts...) + err := blobContent.Link(ctx, v, work.entry.Hash, work.path, gitFilePerm(work.entry.Mode), linkOpts...) 
if err != nil { return fmt.Errorf("link blob %s: %w", work.path, err) } case "symlink": - target, err := blobContent.Read(work.entry.Hash) + target, err := blobContent.Read(v, work.entry.Hash) if err != nil { return fmt.Errorf("read symlink blob %s: %w", work.entry.Hash, err) } @@ -164,11 +160,11 @@ func linkTree( return err } - if err := vfs.Symlink(fs, string(target), work.path); err != nil { + if err := vfs.Symlink(v.FS, string(target), work.path); err != nil { return fmt.Errorf("symlink %s -> %s: %w", work.path, string(target), err) } case "subtree": - treeData, err := treeContent.Read(work.entry.Hash) + treeData, err := treeContent.Read(v, work.entry.Hash) if err != nil { return fmt.Errorf("read tree %s: %w", work.entry.Hash, err) } @@ -178,7 +174,7 @@ func linkTree( return fmt.Errorf("parse tree %s: %w", work.entry.Hash, err) } - err = linkTree(ctx, blobStore, treeStore, subTree, rootDir, work.path, o) + err = linkTree(ctx, v, blobStore, treeStore, subTree, rootDir, work.path, o) if err != nil { return fmt.Errorf("link subtree %s: %w", work.path, err) } @@ -188,7 +184,6 @@ func linkTree( }) } - // Wait for all goroutines to complete and return first error if any return g.Wait() } @@ -209,10 +204,10 @@ func gitFilePerm(mode string) os.FileMode { return os.FileMode(n) & unixPermMask } -// gitEntryIsSymlink reports whether the given git tree entry mode is the -// symlink type (120000). Git tree modes encode the entry type in their high -// bits, so the permission-only view used by gitFilePerm cannot distinguish a -// symlink blob from a regular blob. +// gitEntryIsSymlink reports whether mode encodes the git symlink type +// (120000). The high bits of a six-digit octal mode carry the entry type; +// permission-only inspection cannot distinguish a symlink blob from a regular +// blob. 
func gitEntryIsSymlink(mode string) bool { if mode == "" { return false diff --git a/internal/cas/tree_test.go b/internal/cas/tree_test.go index 31adfca337..e7dc73bd24 100644 --- a/internal/cas/tree_test.go +++ b/internal/cas/tree_test.go @@ -138,13 +138,19 @@ invalid format`), } } +// TestLinkTreeSymlinks pins the materialize-time symlink contract: a 120000 +// entry whose blob holds the link target string surfaces as a real symbolic +// link in the destination, and absolute or dot-dot targets that climb above +// the root are refused. func TestLinkTreeSymlinks(t *testing.T) { t.Parallel() + l := logger.CreateLogger() + tests := []struct { - wantLinks map[string]string // path -> expected target - wantBlobs map[string][]byte // path -> expected file content for non-symlink entries - storeTargets map[string]string // hash -> target string for entries the test should not realize as symlinks + wantLinks map[string]string + wantBlobs map[string][]byte + storeTargets map[string]string name string treeData []byte wantErr bool @@ -179,11 +185,10 @@ func TestLinkTreeSymlinks(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - l := logger.CreateLogger() + v := newMemVenv(t) + require.NoError(t, v.FS.MkdirAll("/store", 0o755)) - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + store := cas.NewStore("/store") content := cas.NewContent(store) tree, err := git.ParseTree(tt.treeData, "test-repo") @@ -197,18 +202,18 @@ func TestLinkTreeSymlinks(t *testing.T) { target = tt.wantLinks[entry.Path] } - require.NoError(t, content.Store(l, entry.Hash, []byte(target))) + require.NoError(t, content.Store(l, v, entry.Hash, []byte(target))) default: if data, ok := tt.wantBlobs[entry.Path]; ok { - require.NoError(t, content.Store(l, entry.Hash, data)) + require.NoError(t, content.Store(l, v, entry.Hash, data)) } } } targetDir := "/target" - require.NoError(t, memFs.MkdirAll(targetDir, 0755)) + 
require.NoError(t, v.FS.MkdirAll(targetDir, 0o755)) - err = cas.LinkTree(t.Context(), store, store, tree, targetDir) + err = cas.LinkTree(t.Context(), v, store, store, tree, targetDir) if tt.wantErr { require.Error(t, err) return @@ -219,18 +224,18 @@ func TestLinkTreeSymlinks(t *testing.T) { for path, wantTarget := range tt.wantLinks { full := filepath.Join(targetDir, path) - info, err := vfs.Lstat(memFs, full) + info, err := vfs.Lstat(v.FS, full) require.NoError(t, err) assert.NotZero(t, info.Mode()&os.ModeSymlink, "%s is not a symlink (mode=%s)", full, info.Mode()) - got, err := vfs.Readlink(memFs, full) + got, err := vfs.Readlink(v.FS, full) require.NoError(t, err) assert.Equal(t, wantTarget, got) } for path, wantContent := range tt.wantBlobs { full := filepath.Join(targetDir, path) - got, err := vfs.ReadFile(memFs, full) + got, err := vfs.ReadFile(v.FS, full) require.NoError(t, err) assert.Equal(t, wantContent, got) } @@ -241,9 +246,11 @@ func TestLinkTreeSymlinks(t *testing.T) { func TestLinkTree(t *testing.T) { t.Parallel() + l := logger.CreateLogger() + tests := []struct { name string - setupStore func(t *testing.T) (*cas.Store, vfs.FS, string) + setupStore func(t *testing.T, v cas.Venv) (*cas.Store, string) treeData []byte wantFiles []struct { path string @@ -255,27 +262,27 @@ func TestLinkTree(t *testing.T) { }{ { name: "basic tree with files and directories", - setupStore: func(t *testing.T) (*cas.Store, vfs.FS, string) { + setupStore: func(t *testing.T, v cas.Venv) (*cas.Store, string) { t.Helper() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + + store := cas.NewStore("/store") content := cas.NewContent(store) // Create test content testData := []byte("test content") testHash := "a1b2c3d4" - err := content.Store(nil, testHash, testData) + err := content.Store(l, v, testHash, testData) require.NoError(t, err) // Create 
and store the src directory tree data srcTreeData := `100644 blob a1b2c3d4 README.md` srcTreeHash := "i9j0k1l2" - err = content.Store(nil, srcTreeHash, []byte(srcTreeData)) + err = content.Store(l, v, srcTreeHash, []byte(srcTreeData)) require.NoError(t, err) - return store, memFs, testHash + return store, testHash }, treeData: []byte(`100644 blob a1b2c3d4 README.md 100755 blob a1b2c3d4 scripts/test.sh @@ -312,14 +319,14 @@ func TestLinkTree(t *testing.T) { }, { name: "empty tree", - setupStore: func(t *testing.T) (*cas.Store, vfs.FS, string) { + setupStore: func(t *testing.T, v cas.Venv) (*cas.Store, string) { t.Helper() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) - return store, memFs, "" + store := cas.NewStore("/store") + + return store, "" }, treeData: []byte(""), wantFiles: []struct { @@ -331,14 +338,14 @@ func TestLinkTree(t *testing.T) { }, { name: "tree with missing content", - setupStore: func(t *testing.T) (*cas.Store, vfs.FS, string) { + setupStore: func(t *testing.T, v cas.Venv) (*cas.Store, string) { t.Helper() - memFs := vfs.NewMemMapFS() - require.NoError(t, memFs.MkdirAll("/store", 0755)) - store := cas.NewStore("/store").WithFS(memFs) + require.NoError(t, v.FS.MkdirAll("/store", 0755)) + + store := cas.NewStore("/store") - return store, memFs, "" + return store, "" }, treeData: []byte(`100644 blob missing123 README.md`), wantErr: true, @@ -349,8 +356,10 @@ func TestLinkTree(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + v := newMemVenv(t) + // Setup store - store, memFs, _ := tt.setupStore(t) + store, _ := tt.setupStore(t, v) // Parse the tree tree, err := git.ParseTree(tt.treeData, "test-repo") @@ -358,10 +367,10 @@ func TestLinkTree(t *testing.T) { // Create target directory targetDir := "/target" - require.NoError(t, memFs.MkdirAll(targetDir, 0755)) + require.NoError(t, 
v.FS.MkdirAll(targetDir, 0755)) // Link the tree - err = cas.LinkTree(t.Context(), store, store, tree, targetDir) + err = cas.LinkTree(t.Context(), v, store, store, tree, targetDir) if tt.wantErr { require.Error(t, err) return @@ -374,19 +383,19 @@ func TestLinkTree(t *testing.T) { path := filepath.Join(targetDir, want.path) // Check if file/directory exists - info, err := memFs.Stat(path) + info, err := v.FS.Stat(path) require.NoError(t, err) assert.Equal(t, want.isDir, info.IsDir()) if !want.isDir { // Check file content - data, err := vfs.ReadFile(memFs, path) + data, err := vfs.ReadFile(v.FS, path) require.NoError(t, err) assert.Equal(t, want.content, data) // Verify content matches store by reading from both locations storePath := filepath.Join(store.Path(), want.hash[:2], want.hash) - storeData, err := vfs.ReadFile(memFs, storePath) + storeData, err := vfs.ReadFile(v.FS, storePath) require.NoError(t, err) assert.Equal(t, storeData, data) } diff --git a/internal/cas/venv.go b/internal/cas/venv.go new file mode 100644 index 0000000000..fc25a45eaa --- /dev/null +++ b/internal/cas/venv.go @@ -0,0 +1,67 @@ +package cas + +import ( + "errors" + + "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/gruntwork-io/terragrunt/internal/vexec" + "github.com/gruntwork-io/terragrunt/internal/vfs" +) + +// ErrVenvFSUnset is the panic value [Venv.RequireFS] raises when a +// function declares it needs v.FS and the caller hands in a Venv with +// FS == nil. Production callers build Venv through [OSVenv], so the +// panic surfaces a test misconfiguration rather than a runtime +// condition. +var ErrVenvFSUnset = errors.New("cas.Venv.FS is required but unset") + +// ErrVenvGitUnset is the panic value [Venv.RequireGit] raises when a +// function declares it needs v.Git and the caller hands in a Venv with +// Git == nil. Production callers build Venv through [OSVenv], so the +// panic surfaces a test misconfiguration rather than a runtime +// condition. 
+var ErrVenvGitUnset = errors.New("cas.Venv.Git is required but unset") + +// Venv bundles the virtualized dependencies CAS operations need so callers +// pass both per call rather than CAS holding them as struct fields. Either +// field can be a stub in tests. +// +// Functions document which handles they touch and call [Venv.RequireFS] +// or [Venv.RequireGit] at entry so a missing handle panics at the +// offending call site instead of inside an unrelated stack frame. +type Venv struct { + // FS is the filesystem CAS reads and writes through. + FS vfs.FS + // Git shells out to the git binary. + Git *git.GitRunner +} + +// OSVenv builds the production [Venv]: the real OS filesystem and a git +// runner backed by [vexec.NewOSExec]. Returns an error if the git binary +// is not on PATH. +func OSVenv() (Venv, error) { + runner, err := git.NewGitRunner(vexec.NewOSExec()) + if err != nil { + return Venv{}, err + } + + return Venv{FS: vfs.NewOSFS(), Git: runner}, nil +} + +// RequireFS panics with [ErrVenvFSUnset] when v.FS is nil. Functions +// that touch the filesystem call this as their first statement so the +// contract sits next to the signature. +func (v Venv) RequireFS() { + if v.FS == nil { + panic(ErrVenvFSUnset) + } +} + +// RequireGit panics with [ErrVenvGitUnset] when v.Git is nil. +// Functions that shell out to git call this as their first statement so +// the contract sits next to the signature. 
+func (v Venv) RequireGit() { + if v.Git == nil { + panic(ErrVenvGitUnset) + } +} diff --git a/internal/cas/venv_test.go b/internal/cas/venv_test.go new file mode 100644 index 0000000000..143441d625 --- /dev/null +++ b/internal/cas/venv_test.go @@ -0,0 +1,42 @@ +package cas_test + +import ( + "testing" + + "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/gruntwork-io/terragrunt/internal/vexec" + "github.com/gruntwork-io/terragrunt/internal/vfs" + "github.com/stretchr/testify/require" +) + +// TestVenvRequireFS pins the FS contract: the zero Venv panics with the +// sentinel, a populated Venv passes. +func TestVenvRequireFS(t *testing.T) { + t.Parallel() + + require.PanicsWithValue(t, cas.ErrVenvFSUnset, func() { + cas.Venv{}.RequireFS() + }) + + require.NotPanics(t, func() { + cas.Venv{FS: vfs.NewOSFS()}.RequireFS() + }) +} + +// TestVenvRequireGit pins the Git contract. A Venv with FS but no Git +// must still panic; only a populated Git satisfies the check. +func TestVenvRequireGit(t *testing.T) { + t.Parallel() + + runner, err := git.NewGitRunner(vexec.NewOSExec()) + require.NoError(t, err) + + require.PanicsWithValue(t, cas.ErrVenvGitUnset, func() { + cas.Venv{FS: vfs.NewOSFS()}.RequireGit() + }) + + require.NotPanics(t, func() { + cas.Venv{FS: vfs.NewOSFS(), Git: runner}.RequireGit() + }) +} diff --git a/internal/getter/casgetter.go b/internal/getter/casgetter.go index 654741eeb2..1ea3b03ae9 100644 --- a/internal/getter/casgetter.go +++ b/internal/getter/casgetter.go @@ -15,23 +15,68 @@ import ( // ErrDirectoryNotFound is returned when CASGetter cannot stat a local source. var ErrDirectoryNotFound = errors.New("directory not found") -// CASGetter is the go-getter implementation that routes git/file sources -// through Terragrunt's content-addressable store. +// SchemeGit is the forced-getter marker for git sources. 
+const SchemeGit = "git" + +// CASGetter is the go-getter implementation that routes git, local, and +// configured non-git sources through Terragrunt's content-addressable store. type CASGetter struct { CAS *cas.CAS Logger log.Logger Opts *cas.CloneOptions + Venv cas.Venv + fetchers map[string]getter.Getter + resolvers map[string]cas.SourceResolver Detectors []Detector } -// NewCASGetter constructs a CASGetter wired with the standard detector chain. -// A nil opts is treated as a zero-value CloneOptions so Get can rely on g.Opts. -func NewCASGetter(l log.Logger, c *cas.CAS, opts *cas.CloneOptions) *CASGetter { +// CASGetterOption mutates a CASGetter at construction time. +type CASGetterOption func(*CASGetter) + +// WithGenericFetchers registers a scheme→getter map for non-git +// sources. Schemes not in the map fall through to whichever bare +// go-getter the outer client registers next. +func WithGenericFetchers(m map[string]getter.Getter) CASGetterOption { + return func(g *CASGetter) { g.fetchers = m } +} + +// WithGenericResolvers registers a scheme→resolver map for probing +// non-git sources. Schemes not in the map go through the fetch path +// without a probe (download then content-hash). +func WithGenericResolvers(m map[string]cas.SourceResolver) CASGetterOption { + return func(g *CASGetter) { g.resolvers = m } +} + +// WithDefaultGenericDispatch is the shorthand for the canonical pairing of +// [WithGenericFetchers]([DefaultGenericFetchers]) and [WithGenericResolvers] +// ([DefaultSourceResolvers]). fetcherOpts are forwarded so HTTP auth headers +// still reach the fetcher. +func WithDefaultGenericDispatch(fetcherOpts ...GenericFetcherOption) CASGetterOption { + return func(g *CASGetter) { + g.fetchers = DefaultGenericFetchers(fetcherOpts...) + g.resolvers = DefaultSourceResolvers() + } +} + +// NewCASGetter constructs a CASGetter with the standard detector chain for +// git and file canonicalization. 
Pass [WithDefaultGenericDispatch] (or +// [WithGenericFetchers] + [WithGenericResolvers]) to enable the non-git +// dispatch path. +// +// Requires v.FS and v.Git: the file-path branch in Detect stats through +// v.FS, and the git-path branch in Get clones through v.Git. Panics +// with [cas.ErrVenvFSUnset] or [cas.ErrVenvGitUnset] respectively when +// either is nil. Production callers build v through [cas.OSVenv], which +// always supplies both. +func NewCASGetter(l log.Logger, c *cas.CAS, v cas.Venv, opts *cas.CloneOptions, options ...CASGetterOption) *CASGetter { + v.RequireFS() + v.RequireGit() + if opts == nil { opts = &cas.CloneOptions{} } - return &CASGetter{ + g := &CASGetter{ Detectors: []Detector{ new(GitHubDetector), new(GitDetector), @@ -42,73 +87,34 @@ func NewCASGetter(l log.Logger, c *cas.CAS, opts *cas.CloneOptions) *CASGetter { CAS: c, Logger: l, Opts: opts, + Venv: v, + } + + for _, opt := range options { + opt(g) } + + return g } -// Get clones (or copies, for local sources) the source into the CAS store and -// links it into req.Dst. +// Get clones (or copies) the source into the CAS store and links it into +// req.Dst. Behavior is selected by req.Forced (the scheme). func (g *CASGetter) Get(ctx context.Context, req *getter.Request) error { if req.Copy { - // Local directory: persist to CAS and link. + // Local directory. var linkOpts []cas.LinkTreeOption if g.Opts.Mutable { linkOpts = append(linkOpts, cas.WithForceCopy()) } - return g.CAS.StoreLocalDirectory(ctx, g.Logger, req.Src, req.Dst, linkOpts...) + return g.CAS.StoreLocalDirectory(ctx, g.Logger, g.Venv, req.Src, req.Dst, linkOpts...) } - ref := "" - - u := req.URL() - - q := u.Query() - if len(q) > 0 { - ref = q.Get("ref") - q.Del("ref") - - u.RawQuery = q.Encode() + if g.isGenericScheme(req.Forced) { + return g.getGeneric(ctx, req) } - // Copy so concurrent Get calls against the same getter don't race on - // Branch/Dir mutation. 
- opts := *g.Opts - opts.Branch = ref - opts.Dir = req.Dst - - return g.CAS.Clone(ctx, g.Logger, &opts, GitCloneURL(u.String())) -} - -// GitCloneURL turns a v2-detected URL string into a clone target the -// underlying git client accepts. -// -// Two normalizations are needed: -// -// 1. Strip a leading "git::". The v2 outer client only splits the forced -// prefix into req.Forced when the source carried it on entry; when -// CASGetter.Detect runs its own detector chain (e.g. for github -// shorthand or git@host:path SCP), the v2 GitDetector reattaches -// "git::" to its result, and req.URL().String() preserves it. Passing -// it through to git makes git look up the missing "git-remote-git" -// helper. -// 2. Convert "ssh://git@host/path" into the SCP-style "git@host:path" -// git expects for SSH cloning. URLs that carry an explicit port -// (e.g. "ssh://git@host:2222/path") keep the URL form because git's -// SCP shorthand has no syntax for a port. -func GitCloneURL(urlStr string) string { - urlStr = strings.TrimPrefix(urlStr, "git::") - - if !strings.HasPrefix(urlStr, "ssh://") { - return urlStr - } - - if u, err := url.Parse(urlStr); err == nil && u.Port() != "" { - return urlStr - } - - after := strings.TrimPrefix(urlStr, "ssh://") - - return strings.Replace(after, "/", ":", 1) + return g.getGit(ctx, req) } // GetFile is not supported for the CAS getter. @@ -121,16 +127,26 @@ func (g *CASGetter) Mode(_ context.Context, _ *url.URL) (getter.Mode, error) { return getter.ModeDir, nil } -// Detect canonicalizes the source via the detector chain. For local sources -// it sets req.Copy=true so Get takes the StoreLocalDirectory path. +// Detect canonicalizes the source via the detector chain. Local +// sources get req.Copy = true so Get takes the StoreLocalDirectory +// path. Non-git schemes covered by Fetchers get archive=false appended +// to the URL so the outer client does not pre-decompress before +// invoking Get. 
func (g *CASGetter) Detect(req *getter.Request) (bool, error) { - if req.Forced == "git" { + if req.Forced == SchemeGit { return true, nil } if after, ok := strings.CutPrefix(req.Src, "git::"); ok { req.Src = after - req.Forced = "git" + req.Forced = SchemeGit + + return true, nil + } + + if scheme, src, ok := g.matchGenericScheme(req); ok { + req.Forced = scheme + req.Src = appendDisableArchive(src) return true, nil } @@ -146,7 +162,12 @@ func (g *CASGetter) Detect(req *getter.Request) (bool, error) { } if _, isFileDetector := detector.(*getter.FileDetector); isFileDetector { - info, statErr := g.CAS.FS().Stat(src) + // Repeats the NewCASGetter check so a caller that hand- + // assembles a CASGetter and skips the constructor still + // gets the typed panic instead of a runtime nil-deref. + g.Venv.RequireFS() + + info, statErr := g.Venv.FS.Stat(src) if statErr != nil { return false, fmt.Errorf("%w: %s", ErrDirectoryNotFound, src) } @@ -166,3 +187,244 @@ func (g *CASGetter) Detect(req *getter.Request) (bool, error) { return false, nil } + +// GitCloneURL turns a v2-detected URL string into a clone target the +// underlying git client accepts. +// +// Two normalizations are needed: +// +// 1. Strip a leading "git::". The v2 outer client only splits the +// forced prefix into req.Forced when the source carried it on +// entry; when CASGetter.Detect runs its own detector chain (e.g. +// for github shorthand or git@host:path SCP), the v2 GitDetector +// reattaches "git::" to its result, and req.URL().String() +// preserves it. Passing it through to git makes git look up the +// missing "git-remote-git" helper. +// 2. Convert "ssh://git@host/path" into the SCP-style +// "git@host:path" git expects for SSH cloning. URLs that carry an +// explicit port (e.g. "ssh://git@host:2222/path") keep the URL +// form because git's SCP shorthand has no syntax for a port. 
+func GitCloneURL(urlStr string) string { + urlStr = strings.TrimPrefix(urlStr, "git::") + + if !strings.HasPrefix(urlStr, "ssh://") { + return urlStr + } + + if u, err := url.Parse(urlStr); err == nil && u.Port() != "" { + return urlStr + } + + after := strings.TrimPrefix(urlStr, "ssh://") + + return strings.Replace(after, "/", ":", 1) +} + +// matchGenericScheme reports whether req should route through the non-git +// generic path and returns the scheme plus the (possibly canonicalized) +// source URL. req.Forced (set by the outer client when it stripped a +// "::" prefix) wins; otherwise the URL scheme is consulted. +// +// URL-scheme claiming is restricted to http and https. The bare go-getter +// v2 protocol getters for s3, gcs, hg, smb reject `://...` URLs +// (they expect canonical HTTPS forms or the forced-prefix syntax), so +// claiming those schemes here would set up a doomed inner fetch on every +// cache miss. +// +// HTTPS URLs against AWS S3 hosts are an exception: virtual-host forms +// (`.s3.amazonaws.com/`) would route through the HTTPS +// fetcher and bypass S3 auth, so the matcher rewrites them to the path- +// style form the bare s3 getter accepts and claims the s3 scheme. +func (g *CASGetter) matchGenericScheme(req *getter.Request) (string, string, bool) { + if g.fetchers == nil { + return "", req.Src, false + } + + if scheme, ok := g.lookupFetcher(strings.ToLower(req.Forced)); ok { + src := req.Src + + // A forced s3 prefix with an AWS virtual-host URL still needs + // the rewrite, since the bare s3 getter rejects virtual-host + // hosts regardless of how the scheme was claimed. 
+ if scheme == SchemeS3 { + if u, perr := url.Parse(req.Src); perr == nil { + if canonical, cok := canonicalAWSS3HTTPSURL(u); cok { + src = canonical + } + } + } + + return scheme, src, true + } + + u, err := url.Parse(req.Src) + if err != nil || u.Scheme == "" { + return "", req.Src, false + } + + scheme := strings.ToLower(u.Scheme) + switch scheme { + case SchemeHTTP, SchemeHTTPS: + if canonical, ok := canonicalAWSS3HTTPSURL(u); ok { + if _, fok := g.fetchers[SchemeS3]; fok { + return SchemeS3, canonical, true + } + } + + if _, ok := g.fetchers[scheme]; ok { + return scheme, req.Src, true + } + } + + return "", req.Src, false +} + +// isGenericScheme reports whether the forced scheme corresponds to a +// fetcher registered for the generic (non-git) dispatch path. +func (g *CASGetter) isGenericScheme(forced string) bool { + if g.fetchers == nil { + return false + } + + _, ok := g.lookupFetcher(strings.ToLower(forced)) + + return ok +} + +// lookupFetcher resolves scheme (or an alias of it) to a fetcher entry +// and returns the registry key on a hit. "gs" maps to "gcs"; all other +// inputs are taken as the registry key directly. +func (g *CASGetter) lookupFetcher(scheme string) (string, bool) { + if scheme == "gs" { + scheme = SchemeGCS + } + + if _, ok := g.fetchers[scheme]; ok { + return scheme, true + } + + return "", false +} + +// appendDisableArchive adds archive=false to the URL query, preserving +// any existing value. The marker tells the outer v2 client to skip its +// archive-extension pre-decompression so req.Dst reaches Get pointing +// at the original destination instead of a temporary archive path. 
+func appendDisableArchive(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + + q := u.Query() + if q.Get("archive") != "" { + return rawURL + } + + q.Set("archive", "false") + u.RawQuery = q.Encode() + + return u.String() +} + +// stripDisableArchive removes the archive=false marker before handing the +// URL to the inner client so archive extension detection runs there. +func stripDisableArchive(u *url.URL) string { + if u == nil { + return "" + } + + clone := *u + q := clone.Query() + + if q.Get("archive") == "false" { + q.Del("archive") + clone.RawQuery = q.Encode() + } + + return clone.String() +} + +// getGit clones via [cas.CAS.Clone] after lifting ?ref= out of the URL +// into [cas.CloneOptions.Branch]. +func (g *CASGetter) getGit(ctx context.Context, req *getter.Request) error { + ref := "" + + u := req.URL() + + q := u.Query() + if len(q) > 0 { + ref = q.Get("ref") + q.Del("ref") + + u.RawQuery = q.Encode() + } + + return g.CAS.Clone(ctx, g.Logger, g.Venv, GitCloneURL(u.String()), + cas.WithDir(req.Dst), + cas.WithBranch(ref), + cas.WithDepth(g.Opts.Depth), + cas.WithMutable(g.Opts.Mutable), + cas.WithIncludedGitFiles(g.Opts.IncludedGitFiles)) +} + +// getGeneric routes a non-git source through CAS. The archive=false +// marker Detect injected gets stripped before passing the URL to the +// inner getter.Client so archive extraction runs there. 
+func (g *CASGetter) getGeneric(ctx context.Context, req *getter.Request) error { + scheme, ok := g.lookupFetcher(strings.ToLower(req.Forced)) + if !ok { + return fmt.Errorf("CASGetter: no fetcher registered for scheme %q", strings.ToLower(req.Forced)) + } + + bare := g.fetchers[scheme] + + innerURL := stripDisableArchive(req.URL()) + + opts := *g.Opts + opts.Dir = req.Dst + + return g.CAS.FetchSource(ctx, g.Logger, g.Venv, &opts, cas.SourceRequest{ + Scheme: scheme, + URL: innerURL, + Resolver: g.resolvers[scheme], + Fetch: g.buildInnerFetch(bare, scheme, innerURL), + }) +} + +// buildInnerFetch returns a SourceFetcher that downloads urlStr into a +// fresh temp directory through a single-getter inner [getter.Client] and +// ingests the result via [cas.CAS.IngestDirectory]. The inner client uses +// the default decompressor map so `.tar.gz`/`.zip` URLs extract before +// ingest. +// +// scheme is set on the inner request's Forced field so the bare +// scheme-specific getter still claims the request. The bare go-getter v2 +// s3 and gcs getters reject `http://`/`gs://` URLs unless Forced matches +// their validScheme; without this the inner client falls through with a +// generic "error downloading". 
+func (g *CASGetter) buildInnerFetch(bare getter.Getter, scheme, urlStr string) cas.SourceFetcher { + return func(ctx context.Context, l log.Logger, v cas.Venv, suggestedKey string) (string, error) { + tempDir, cleanup, err := g.CAS.MakeFetchTempDir(l, v) + if err != nil { + return "", err + } + + defer cleanup() + + inner := &getter.Client{ + Getters: []getter.Getter{bare}, + } + + if _, err := inner.Get(ctx, &getter.Request{ + Src: urlStr, + Dst: tempDir, + Forced: scheme, + GetMode: getter.ModeAny, + }); err != nil { + return "", err + } + + return g.CAS.IngestDirectory(l, v, tempDir, suggestedKey) + } +} diff --git a/internal/getter/casgetter_detect_test.go b/internal/getter/casgetter_detect_test.go new file mode 100644 index 0000000000..987a248712 --- /dev/null +++ b/internal/getter/casgetter_detect_test.go @@ -0,0 +1,436 @@ +package getter_test + +import ( + "net/url" + "path/filepath" + "testing" + + tgcas "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/internal/vfs" + "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + gogetter "github.com/hashicorp/go-getter/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCASGetterDetect_GitForcedPrefix(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + tests := []struct { + name string + src string + reqForced string + want string + }{ + { + name: "git:: forced prefix is claimed and Src is stripped", + src: "git::https://example.com/repo.git", + want: getter.SchemeGit, + }, + { + name: "Forced field set to git is claimed without inspecting Src", + src: "https://example.com/repo.git", + reqForced: getter.SchemeGit, + want: getter.SchemeGit, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := &gogetter.Request{Src: tt.src, Forced: tt.reqForced} + + ok, 
err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, tt.want, req.Forced) + }) + } +} + +func TestCASGetterDetect_GenericForcedPrefixes(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + tests := []struct { + forced string + src string + }{ + {forced: getter.SchemeS3, src: "s3.amazonaws.com/bucket/key.tgz"}, + {forced: getter.SchemeGCS, src: "www.googleapis.com/storage/v1/bucket/key.tgz"}, + {forced: getter.SchemeHTTP, src: "example.com/mod.tar.gz"}, + {forced: getter.SchemeHTTPS, src: "example.com/mod.tar.gz"}, + {forced: getter.SchemeHg, src: "example.com/repo"}, + {forced: getter.SchemeSMB, src: "example.com/share/path"}, + } + + for _, tt := range tests { + t.Run("forced "+tt.forced, func(t *testing.T) { + t.Parallel() + + req := &gogetter.Request{Src: tt.src, Forced: tt.forced} + + ok, err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, tt.forced, req.Forced) + // Detect appends archive=false so the outer v2 client + // skips its pre-decompression step. + u, parseErr := url.Parse(req.Src) + require.NoError(t, parseErr) + assert.Equal(t, "false", u.Query().Get("archive")) + }) + } +} + +// TestCASGetterDetect_SchemeDetectionByURL pins URL-scheme claiming. +// Only http and https URLs are claimed by URL scheme alone; s3, gcs, +// hg, and smb sources route through the bare go-getter v2 protocol +// getters and reject the `://...` form, so claiming those +// schemes by URL would set up a doomed inner fetch on every cache +// miss. Routing those sources through CAS requires the explicit +// forced-prefix form (`s3::https://...`, `gcs::https://...`). 
+func TestCASGetterDetect_SchemeDetectionByURL(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + claimed := []struct { + name string + src string + forced string + }{ + {name: "http URL", src: "http://example.com/mod.tar.gz", forced: getter.SchemeHTTP}, + {name: "https URL", src: "https://example.com/mod.tar.gz", forced: getter.SchemeHTTPS}, + } + + for _, tt := range claimed { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := &gogetter.Request{Src: tt.src} + + ok, err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok, "detector should claim %s", tt.src) + assert.Equal(t, tt.forced, req.Forced) + }) + } + + unclaimed := []struct { + name string + src string + }{ + {name: "s3 URL", src: "s3://bucket/key.tgz"}, + {name: "gs URL", src: "gs://bucket/key.tgz"}, + } + + for _, tt := range unclaimed { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := &gogetter.Request{Src: tt.src} + + ok, _ := g.Detect(req) + assert.False(t, ok, "URL-scheme claim must not match %s; require :: forced prefix instead", tt.src) + }) + } +} + +// TestCASGetterDetect_AWSS3HTTPSRoutesToS3Fetcher pins that an https URL +// against an AWS S3 endpoint claims the s3 scheme (so the inner fetch +// uses S3 auth) and is rewritten to the path-style form the bare s3 +// getter accepts. Without this, virtual-host URLs would route through +// the plain HTTPS fetcher and silently fail for private buckets. 
+func TestCASGetterDetect_AWSS3HTTPSRoutesToS3Fetcher(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + tests := []struct { + name string + src string + wantSrc string + }{ + { + name: "global virtual-host rewritten to global path-style", + src: "https://my-bucket.s3.amazonaws.com/path.zip", + wantSrc: "https://s3.amazonaws.com/my-bucket/path.zip", + }, + { + name: "regional virtual-host rewritten to regional path-style", + src: "https://my-bucket.s3-us-west-2.amazonaws.com/path.zip", + wantSrc: "https://s3-us-west-2.amazonaws.com/my-bucket/path.zip", + }, + { + name: "global path-style claimed unchanged", + src: "https://s3.amazonaws.com/my-bucket/path.zip", + wantSrc: "https://s3.amazonaws.com/my-bucket/path.zip", + }, + { + name: "regional path-style claimed unchanged", + src: "https://s3-us-west-2.amazonaws.com/my-bucket/path.zip", + wantSrc: "https://s3-us-west-2.amazonaws.com/my-bucket/path.zip", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := &gogetter.Request{Src: tt.src} + + ok, err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok, "detector should claim %s", tt.src) + assert.Equal(t, getter.SchemeS3, req.Forced, "AWS S3 host must route through the s3 fetcher") + + u, parseErr := url.Parse(req.Src) + require.NoError(t, parseErr) + + q := u.Query() + assert.Equal(t, "false", q.Get("archive"), "Detect must append archive=false") + + q.Del("archive") + u.RawQuery = q.Encode() + + assert.Equal(t, tt.wantSrc, u.String(), "Src must be canonicalized to the bare s3 getter's accepted form") + }) + } +} + +// TestCASGetterDetect_AWSS3HTTPSPreservesQueryAndVersion pins that the +// ?version selector and other query parameters survive the rewrite, so +// versioned S3 objects keep resolving to the same VersionId after +// canonicalization. 
+func TestCASGetterDetect_AWSS3HTTPSPreservesQueryAndVersion(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + req := &gogetter.Request{Src: "https://my-bucket.s3.amazonaws.com/path.zip?version=abc123"} + + ok, err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, getter.SchemeS3, req.Forced) + + u, parseErr := url.Parse(req.Src) + require.NoError(t, parseErr) + assert.Equal(t, "abc123", u.Query().Get("version")) + assert.Equal(t, "/my-bucket/path.zip", u.Path) + assert.Equal(t, "s3.amazonaws.com", u.Host) +} + +// TestCASGetterDetect_NonS3AmazonAWSHostFallsThroughToHTTPS pins that +// non-S3 amazonaws.com hosts (iam, sts, ec2, ...) stay on the HTTPS +// fetcher rather than being misrouted through s3. canonicalAWSS3HTTPSURL +// is the gate. +func TestCASGetterDetect_NonS3AmazonAWSHostFallsThroughToHTTPS(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + tests := []string{ + "https://iam.amazonaws.com/bucket/key.tgz", + "https://sts.amazonaws.com/bucket/key.tgz", + "https://ec2.amazonaws.com/bucket/key.tgz", + } + + for _, src := range tests { + t.Run(src, func(t *testing.T) { + t.Parallel() + + req := &gogetter.Request{Src: src} + + ok, err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, getter.SchemeHTTPS, req.Forced, "non-S3 amazonaws.com hosts must route through HTTPS, not s3") + }) + } +} + +// TestCASGetterDetect_S3ForcedPrefixCanonicalizesVHost pins that +// `s3::https://.s3.amazonaws.com/` is rewritten to the +// path-style form before being handed to the bare s3 getter, which +// rejects virtual-host hosts. Without this rewrite, the forced-prefix +// form would set up a doomed inner fetch on every cache miss. 
+func TestCASGetterDetect_S3ForcedPrefixCanonicalizesVHost(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + req := &gogetter.Request{ + Src: "https://my-bucket.s3.amazonaws.com/path.zip", + Forced: getter.SchemeS3, + } + + ok, err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok) + assert.Equal(t, getter.SchemeS3, req.Forced) + + u, parseErr := url.Parse(req.Src) + require.NoError(t, parseErr) + assert.Equal(t, "/my-bucket/path.zip", u.Path) + assert.Equal(t, "s3.amazonaws.com", u.Host) +} + +// TestCASGetterDetect_ForcedPrefixNormalizesAlias pins that `gs::` +// forced inputs route through the gcs fetcher entry. Without this, +// `gs::` users would silently miss the GCS dispatch path. +func TestCASGetterDetect_ForcedPrefixNormalizesAlias(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + req := &gogetter.Request{ + Src: "https://www.googleapis.com/storage/v1/bucket/mod.tgz", + Forced: "gs", + } + + ok, err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok, "gs:: forced prefix must claim through the gcs fetcher") + assert.Equal(t, getter.SchemeGCS, req.Forced, "Forced must be normalized to the registry key") +} + +func TestCASGetterDetect_SchemeNotInRegistryFallsThrough(t *testing.T) { + t.Parallel() + + // A scheme that CASGetter does not handle, with no fetcher + // registered for it, must not be claimed by the generic-scheme + // matcher. (A higher-priority getter, TFR for instance, wins the + // outer registry race in this case.) + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + // No generic dispatch wired: only the git+file paths are active. 
+ g := getter.NewCASGetter(logger.CreateLogger(), c, v, &tgcas.CloneOptions{}) + + // An s3:// URL would be claimed if generic dispatch were on; + // without it, the generic-scheme matcher in Detect must return + // false (the FileDetector then runs and reports a stat error, + // but the assertion here only cares that the generic-scheme + // path did not silently match). + req := &gogetter.Request{Src: "s3://bucket/key.tgz"} + + ok, _ := g.Detect(req) + assert.False(t, ok, "without WithGenericFetchers, s3:// must not be claimed") +} + +// TestNewCASGetter_PanicsOnNilVenvFS pins the constructor-time rejection +// of a Venv missing FS. The misconfiguration surfaces at the offending +// NewCASGetter call rather than at first Detect. +func TestNewCASGetter_PanicsOnNilVenvFS(t *testing.T) { + t.Parallel() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + require.PanicsWithValue(t, tgcas.ErrVenvFSUnset, func() { + getter.NewCASGetter(logger.CreateLogger(), c, tgcas.Venv{}, &tgcas.CloneOptions{}) + }) +} + +// TestNewCASGetter_PanicsOnNilVenvGit pins the constructor-time +// rejection of a Venv with FS set but Git missing. CASGetter routes +// through git for any git source, so a missing runner would otherwise +// nil-deref deep inside the clone path. +func TestNewCASGetter_PanicsOnNilVenvGit(t *testing.T) { + t.Parallel() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v := tgcas.Venv{FS: vfs.NewOSFS()} + + require.PanicsWithValue(t, tgcas.ErrVenvGitUnset, func() { + getter.NewCASGetter(logger.CreateLogger(), c, v, &tgcas.CloneOptions{}) + }) +} + +// TestCASGetterDetect_PanicsOnNilVenvFS pins the in-Detect repeat of +// the constructor check. Only reachable when a caller hand-assembles +// CASGetter and skips NewCASGetter. 
+func TestCASGetterDetect_PanicsOnNilVenvFS(t *testing.T) { + t.Parallel() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + g := &getter.CASGetter{ + CAS: c, + Logger: logger.CreateLogger(), + Opts: &tgcas.CloneOptions{}, + Venv: tgcas.Venv{}, + Detectors: []getter.Detector{new(getter.FileDetector)}, + } + + require.PanicsWithValue(t, tgcas.ErrVenvFSUnset, func() { + _, _ = g.Detect(&gogetter.Request{Src: "./some/local/path", Pwd: t.TempDir()}) + }) +} + +func TestCASGetterDetect_PreservesExistingArchiveQueryValue(t *testing.T) { + t.Parallel() + + g := newCASGetterForDetect(t) + + // If the URL already carries archive=true, Detect must not + // overwrite it. Same for archive=false (which is what Detect + // would have added anyway). + tests := []struct { + name string + src string + want string + }{ + {name: "preserve archive=true", src: "https://example.com/mod.tar.gz?archive=true", want: "true"}, + {name: "preserve archive=false", src: "https://example.com/mod.tar.gz?archive=false", want: "false"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := &gogetter.Request{Src: tt.src} + + ok, err := g.Detect(req) + require.NoError(t, err) + require.True(t, ok) + + u, parseErr := url.Parse(req.Src) + require.NoError(t, parseErr) + assert.Equal(t, tt.want, u.Query().Get("archive")) + }) + } +} + +// newCASGetterForDetect returns a CASGetter with the default generic +// dispatch wiring so Detect's scheme-matching path is fully exercised. 
+func newCASGetterForDetect(t *testing.T) *getter.CASGetter { + t.Helper() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + return getter.NewCASGetter(logger.CreateLogger(), c, v, &tgcas.CloneOptions{}, getter.WithDefaultGenericDispatch()) +} diff --git a/internal/getter/casgetter_forced_test.go b/internal/getter/casgetter_forced_test.go new file mode 100644 index 0000000000..028079a47a --- /dev/null +++ b/internal/getter/casgetter_forced_test.go @@ -0,0 +1,149 @@ +package getter_test + +import ( + "context" + "net/url" + "path/filepath" + "sync/atomic" + "testing" + + tgcas "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + gogetter "github.com/hashicorp/go-getter/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCASGetter_ForcedThreadedToInnerClient pins the wiring fix that +// would otherwise let the inner getter.Client's bare scheme-specific +// getter (Detect rejects URLs whose scheme doesn't match its +// validScheme) silently refuse to claim the request, surfacing as a +// generic "error downloading" multierror wrap. +// +// Both the bare go-getter v2 s3.Getter and gcs.Getter implement +// Detect as: forced != "" → validScheme(forced) → claim; URL.Scheme +// → validScheme(scheme) → claim; otherwise reject. For an `http://` +// or `gs://` URL with no forced field set, Detect rejects, no getter +// claims, and the client wraps with "error downloading". The fix is +// to propagate the scheme as `Forced` on the inner request. 
+func TestCASGetter_ForcedThreadedToInnerClient(t *testing.T) { + t.Parallel() + + const scheme = "fakescheme" + + stub := &forcedRequiredGetter{scheme: scheme} + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + g := getter.NewCASGetter(logger.CreateLogger(), c, v, &tgcas.CloneOptions{}, + getter.WithGenericFetchers(map[string]gogetter.Getter{scheme: stub}), + // No resolver registered → forces the fetch path. + ) + + client := &gogetter.Client{Getters: []gogetter.Getter{g}} + + // The outer client.Get error is incidental to this test: the + // stub's no-op Get returns nil but CAS.FetchSource then walks + // the empty temp dir and ingests an empty tree, which may or + // may not fail depending on the storeFetchedContent path. Log + // whatever comes back so the assertion below is the contract. + if _, err := client.Get(t.Context(), &gogetter.Request{ + Src: scheme + "::http://example.com/anything.tar.gz", + Dst: filepath.Join(t.TempDir(), "out"), + GetMode: gogetter.ModeAny, + }); err != nil { + t.Logf("client.Get returned %v (incidental for this test)", err) + } + + // The stub records each Detect call; if Forced wasn't + // propagated to the inner request, the stub's Detect returns + // false because the URL scheme is `http` (not `fakescheme`) and + // the inner client falls through to "error downloading". The + // assertion below pins that Forced reached the stub. 
+ assert.Positive(t, stub.detectCalls.Load(), + "inner getter.Client should have invoked stub.Detect") + assert.Positive(t, stub.forcedCalls.Load(), + "stub.Detect should have seen req.Forced == %q at least once "+ + "(if not, CASGetter is failing to thread the scheme through)", scheme) +} + +// TestCASGetter_GetCanonicalizesForcedAlias pins that CASGetter.Get +// resolves an alias forced scheme (gs) to its registry key (gcs) +// before looking up the fetcher. Without canonicalization, a caller +// that reaches Get with req.Forced == "gs" passes the +// isGenericScheme gate (which uses lookupFetcher) and then trips a +// "no fetcher registered" failure inside getGeneric on the raw map +// lookup. +// +// Reaching the stub at all proves the alias resolved: the stub is +// only registered under "gcs", so an unresolved "gs" lookup would +// return the fetcher-not-registered error before any Detect call. +func TestCASGetter_GetCanonicalizesForcedAlias(t *testing.T) { + t.Parallel() + + stub := &forcedRequiredGetter{scheme: getter.SchemeGCS} + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + g := getter.NewCASGetter(logger.CreateLogger(), c, v, &tgcas.CloneOptions{}, + getter.WithGenericFetchers(map[string]gogetter.Getter{getter.SchemeGCS: stub}), + ) + + // Bypass the outer client so Detect does not normalize Forced; + // Get must do the canonicalization itself. 
+ if err := g.Get(t.Context(), &gogetter.Request{ + Src: "https://www.googleapis.com/storage/v1/bucket/mod.tgz", + Dst: filepath.Join(t.TempDir(), "out"), + Forced: "gs", + GetMode: gogetter.ModeAny, + }); err != nil { + t.Logf("Get returned %v (incidental for this test)", err) + } + + assert.Positive(t, stub.forcedCalls.Load(), + "gcs fetcher stub should have been reached through the gs alias") +} + +// forcedRequiredGetter is a bare-getter stub that only claims a +// request when req.Forced matches its scheme. Mirrors the bare s3 and +// gcs getters' Detect behavior. +type forcedRequiredGetter struct { + scheme string + detectCalls atomic.Int32 + forcedCalls atomic.Int32 +} + +func (g *forcedRequiredGetter) Detect(req *gogetter.Request) (bool, error) { + g.detectCalls.Add(1) + + if req.Forced == g.scheme { + g.forcedCalls.Add(1) + return true, nil + } + + return false, nil +} + +func (g *forcedRequiredGetter) Get(_ context.Context, _ *gogetter.Request) error { + return nil +} + +func (g *forcedRequiredGetter) GetFile(_ context.Context, _ *gogetter.Request) error { + return nil +} + +func (g *forcedRequiredGetter) Mode(_ context.Context, _ *url.URL) (gogetter.Mode, error) { + return gogetter.ModeDir, nil +} diff --git a/internal/getter/casgetter_generic_test.go b/internal/getter/casgetter_generic_test.go new file mode 100644 index 0000000000..9a35707c53 --- /dev/null +++ b/internal/getter/casgetter_generic_test.go @@ -0,0 +1,197 @@ +package getter_test + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "net/http" + "net/http/httptest" + "path/filepath" + "sync/atomic" + "testing" + + tgcas "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + gogetter "github.com/hashicorp/go-getter/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// tarballHandler serves 
the supplied tar.gz bytes with a stable ETag and +// counts GET vs HEAD requests so a test can assert the second run skips +// the GET. +type tarballHandler struct { + etag string + body []byte + heads atomic.Int32 + gets atomic.Int32 +} + +func (h *tarballHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("ETag", `"`+h.etag+`"`) + w.Header().Set("Content-Type", "application/gzip") + + switch r.Method { + case http.MethodHead: + h.heads.Add(1) + w.WriteHeader(http.StatusOK) + case http.MethodGet: + h.gets.Add(1) + w.WriteHeader(http.StatusOK) + + if _, err := w.Write(h.body); err != nil { + panic(err) + } + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +// makeTarGz packs files into an in-memory tar.gz tree. +func makeTarGz(t *testing.T, files map[string]string) []byte { + t.Helper() + + var buf bytes.Buffer + + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + for name, body := range files { + hdr := &tar.Header{ + Name: name, + Mode: 0o644, + Size: int64(len(body)), + } + require.NoError(t, tw.WriteHeader(hdr)) + + _, err := tw.Write([]byte(body)) + require.NoError(t, err) + } + + require.NoError(t, tw.Close()) + require.NoError(t, gzw.Close()) + + return buf.Bytes() +} + +func TestCASGetter_HTTPArchiveCachesSecondRun(t *testing.T) { + t.Parallel() + + body := makeTarGz(t, map[string]string{ + "main.tf": `resource "null_resource" "a" {}`, + "sub/x.tf": "variable \"x\" {}\n", + "README.md": "hello", + }) + + h := &tarballHandler{body: body, etag: "stable-etag"} + + srv := httptest.NewServer(h) + defer srv.Close() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + + g := getter.NewCASGetter(l, c, v, &tgcas.CloneOptions{}, getter.WithDefaultGenericDispatch()) + + client := &gogetter.Client{Getters: []gogetter.Getter{g}} + + src := 
srv.URL + "/mod.tar.gz" + + runOnce := func(t *testing.T) string { + t.Helper() + + dst := filepath.Join(t.TempDir(), "out") + + _, err := client.Get(t.Context(), &gogetter.Request{ + Src: src, + Dst: dst, + GetMode: gogetter.ModeAny, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(dst, "main.tf")) + require.FileExists(t, filepath.Join(dst, "sub", "x.tf")) + + return dst + } + + runOnce(t) + + firstGets := h.gets.Load() + firstHeads := h.heads.Load() + + assert.Equal(t, int32(1), firstGets, "first run must download the archive once") + assert.GreaterOrEqual(t, firstHeads, int32(1), "first run must probe via HEAD before downloading") + + runOnce(t) + + assert.Equal(t, firstGets, h.gets.Load(), + "second run must hit the CAS via the probe and skip the archive GET") + assert.Greater(t, h.heads.Load(), firstHeads, + "second run must still probe to confirm the cached version is current") +} + +// TestCASGetter_HTTPMissingETagFallsBackToContentHash exercises the no-probe +// path: a server without ETag/Last-Modified causes CAS to download every +// run, but blob storage still dedupes across runs. 
+func TestCASGetter_HTTPMissingETagFallsBackToContentHash(t *testing.T) { + t.Parallel() + + body := makeTarGz(t, map[string]string{"main.tf": "ok"}) + + var gets atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + + if r.Method == http.MethodGet { + gets.Add(1) + } + + w.WriteHeader(http.StatusOK) + + if r.Method == http.MethodGet { + if _, err := w.Write(body); err != nil { + panic(err) + } + } + })) + defer srv.Close() + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + l := logger.CreateLogger() + + g := getter.NewCASGetter(l, c, v, &tgcas.CloneOptions{}, getter.WithDefaultGenericDispatch()) + + client := &gogetter.Client{Getters: []gogetter.Getter{g}} + + src := srv.URL + "/mod.tar.gz" + + for range 2 { + dst := t.TempDir() + + _, err := client.Get(t.Context(), &gogetter.Request{ + Src: src, + Dst: filepath.Join(dst, "out"), + GetMode: gogetter.ModeAny, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(dst, "out", "main.tf")) + } + + assert.Equal(t, int32(2), gets.Load(), + "with no ETag, every run downloads; cache deduplication happens only at the blob level") +} diff --git a/internal/getter/casprotocol.go b/internal/getter/casprotocol.go index 9f44609dc5..56a5f499b9 100644 --- a/internal/getter/casprotocol.go +++ b/internal/getter/casprotocol.go @@ -16,14 +16,23 @@ import ( type CASProtocolGetter struct { CAS *cas.CAS Logger log.Logger + Venv cas.Venv Mutable bool } // NewCASProtocolGetter creates a new CASProtocolGetter. -func NewCASProtocolGetter(l log.Logger, c *cas.CAS) *CASProtocolGetter { +// +// Requires v.FS: Get dispatches to [cas.CAS.MaterializeTree], which +// reads and links through v.FS. v.Git is not consulted because +// materialization is a pure FS operation. 
Panics with +// [cas.ErrVenvFSUnset] when v.FS is nil. +func NewCASProtocolGetter(l log.Logger, c *cas.CAS, v cas.Venv) *CASProtocolGetter { + v.RequireFS() + return &CASProtocolGetter{ CAS: c, Logger: l, + Venv: v, } } @@ -41,7 +50,7 @@ func (g *CASProtocolGetter) Get(ctx context.Context, req *getter.Request) error linkOpts = append(linkOpts, cas.WithForceCopy()) } - return g.CAS.MaterializeTree(ctx, g.Logger, hash, req.Dst, linkOpts...) + return g.CAS.MaterializeTree(ctx, g.Logger, g.Venv, hash, req.Dst, linkOpts...) } // GetFile is not supported for the CAS protocol getter. diff --git a/internal/getter/defaults.go b/internal/getter/defaults.go index cff57cfb83..c187bce53b 100644 --- a/internal/getter/defaults.go +++ b/internal/getter/defaults.go @@ -8,31 +8,75 @@ import ( getter "github.com/hashicorp/go-getter/v2" ) -// newHTTPGetter constructs an HttpGetter with Netrc enabled (matching -// Terragrunt's historic behavior under v1's UpdateGetters customization) -// and an optional set of extra headers. Pass nil for `extra` to get the -// default getter; pass a non-nil header set to inject auth (used by -// WithHTTPAuth and WithHTTPSAuth for GitHub release downloads). -// -// XTerraformGet is left enabled (the default) so X-Terraform-Get redirects -// continue to work. -func newHTTPGetter(extra http.Header) *getter.HttpGetter { - return &getter.HttpGetter{Netrc: true, Header: extra} +// Registry keys for the non-git fetcher and resolver maps. They match +// the lowercased scheme strings CASGetter.Detect produces. Exported so +// callers can extend or replace specific entries in +// DefaultGenericFetchers and DefaultSourceResolvers. +const ( + SchemeS3 = "s3" + SchemeGCS = "gcs" + SchemeHTTP = "http" + SchemeHTTPS = "https" + SchemeHg = "hg" + SchemeSMB = "smb" +) + +// GenericFetcherOption configures DefaultGenericFetchers. 
+type GenericFetcherOption func(*genericFetcherConfig) + +type genericFetcherConfig struct { + httpExtra http.Header + httpsExtra http.Header +} + +// WithHTTPExtraHeaders attaches header to the bare http getter so +// auth headers reach the wire on a CAS miss. Intended for tests that +// talk to net/http/httptest; production callers want +// WithHTTPSExtraHeaders. +func WithHTTPExtraHeaders(header http.Header) GenericFetcherOption { + return func(c *genericFetcherConfig) { c.httpExtra = header } +} + +// WithHTTPSExtraHeaders attaches header to the bare https getter. +func WithHTTPSExtraHeaders(header http.Header) GenericFetcherOption { + return func(c *genericFetcherConfig) { c.httpsExtra = header } +} + +// DefaultGenericFetchers returns the per-scheme bare getters CASGetter +// uses on a cache miss. Exported so callers that build dedicated +// CAS-only clients (the CAS-experiment path in +// runner/run/download_source.go) share the fetcher set NewClient uses. +func DefaultGenericFetchers(opts ...GenericFetcherOption) map[string]getter.Getter { + var cfg genericFetcherConfig + for _, opt := range opts { + opt(&cfg) + } + + return map[string]getter.Getter{ + SchemeS3: new(s3.Getter), + SchemeGCS: new(gcs.Getter), + SchemeHTTP: &HTTPSchemeGetter{Inner: newHTTPGetter(cfg.httpExtra), Scheme: SchemeHTTP}, + SchemeHTTPS: &HTTPSchemeGetter{Inner: newHTTPGetter(cfg.httpsExtra), Scheme: SchemeHTTPS}, + SchemeHg: new(getter.HgGetter), + SchemeSMB: new(getter.SmbClientGetter), + } } -// buildGetters realizes the option set into the ordered slice of Getters -// the client will iterate. The order matters for v2's first-match detection: +// buildGetters realizes the option set into the ordered Getter slice +// the client iterates. v2 uses first-match detection so order matters: // // 1. tfr (Terraform Registry): must precede git so tfr:// wins forced // detection. -// 2. 
CAS protocol getter: when CAS is enabled it resolves `cas::` source -// references produced by `update_source_with_cas` stack rewrites. -// 3. CAS git wrapper: when CAS is enabled it intercepts git URLs ahead of -// the bare GitGetter so plain `git::` sources route through CAS. -// 4. Optional caller-prepended getters (tests). +// 2. CAS protocol getter (when CAS is enabled): resolves `cas::` +// source references produced by `update_source_with_cas` stack +// rewrites. +// 3. CAS getter (when CAS is enabled): intercepts git, s3, gcs, +// http(s), hg, and smb sources ahead of the bare protocol getters. +// 4. Caller-prepended getters (tests). // 5. The default protocol set: git, hg, smb, http(s), s3, gcs, file. // -// File goes last so it doesn't claim sources that other detectors recognize. +// file goes last so it does not claim sources another detector +// recognizes. func buildGetters(b *builder) []Getter { var ( out []Getter @@ -49,17 +93,35 @@ func buildGetters(b *builder) []Getter { gitGetter = NewGitGetter() - httpGetter = &HTTPSchemeGetter{Inner: newHTTPGetter(b.httpExtraHeader), Scheme: "http"} - httpsGetter = &HTTPSchemeGetter{Inner: newHTTPGetter(b.httpsExtraHeader), Scheme: "https"} + httpGetter = &HTTPSchemeGetter{Inner: newHTTPGetter(b.httpExtraHeader), Scheme: SchemeHTTP} + httpsGetter = &HTTPSchemeGetter{Inner: newHTTPGetter(b.httpsExtraHeader), Scheme: SchemeHTTPS} + + hgGetter := new(getter.HgGetter) + smbClientGetter := new(getter.SmbClientGetter) + smbMountGetter := new(getter.SmbMountGetter) + s3Getter := new(s3.Getter) + gcsGetter := new(gcs.Getter) if b.tfRegistry != nil { out = append(out, b.tfRegistry) } if b.casStore != nil { + fetchers := map[string]getter.Getter{ + SchemeS3: s3Getter, + SchemeGCS: gcsGetter, + SchemeHTTP: httpGetter, + SchemeHTTPS: httpsGetter, + SchemeHg: hgGetter, + SchemeSMB: smbClientGetter, + } + out = append(out, - NewCASProtocolGetter(b.logger, b.casStore), - NewCASGetter(b.logger, b.casStore, b.casCloneOpts), + 
NewCASProtocolGetter(b.logger, b.casStore, b.casVenv), + NewCASGetter(b.logger, b.casStore, b.casVenv, b.casCloneOpts, + WithGenericFetchers(fetchers), + WithGenericResolvers(DefaultSourceResolvers()), + ), ) } @@ -67,15 +129,27 @@ func buildGetters(b *builder) []Getter { out = append(out, gitGetter, - new(getter.HgGetter), - new(getter.SmbClientGetter), - new(getter.SmbMountGetter), + hgGetter, + smbClientGetter, + smbMountGetter, httpGetter, httpsGetter, - new(s3.Getter), - new(gcs.Getter), + s3Getter, + gcsGetter, fileGetter, ) return out } + +// newHTTPGetter constructs an HttpGetter with Netrc enabled (matching +// Terragrunt's historic behavior under v1's UpdateGetters customization) +// and an optional set of extra headers. Pass nil for `extra` to get the +// default getter; pass a non-nil header set to inject auth (used by +// WithHTTPAuth and WithHTTPSAuth for GitHub release downloads). +// +// XTerraformGet is left enabled (the default) so X-Terraform-Get +// redirects continue to work. +func newHTTPGetter(extra http.Header) *getter.HttpGetter { + return &getter.HttpGetter{Netrc: true, Header: extra} +} diff --git a/internal/getter/http.go b/internal/getter/http.go index 6db616916c..3a1cf3bd9e 100644 --- a/internal/getter/http.go +++ b/internal/getter/http.go @@ -8,14 +8,11 @@ import ( getter "github.com/hashicorp/go-getter/v2" ) -// HTTPSchemeGetter wraps an [getter.HttpGetter] so its Detect only matches a -// specific scheme. Two of these (one for "http", one for "https") are -// registered by [buildGetters] so the per-scheme auth headers configured via -// [WithHTTPAuth] and [WithHTTPSAuth] route to the correct slot. -// -// Without the wrapper the upstream HttpGetter.Detect matches both http and -// https schemes, so the first registered instance wins for both and the -// second slot's auth headers never reach the wire. +// HTTPSchemeGetter wraps a [getter.HttpGetter] so its Detect only matches +// one scheme. 
The upstream HttpGetter.Detect claims both http and https, +// so registering two HttpGetters for per-scheme auth would have the first +// shadow the second; the wrapper is what makes [WithHTTPAuth] and +// [WithHTTPSAuth] route to their intended slots. type HTTPSchemeGetter struct { Inner *getter.HttpGetter Scheme string @@ -36,12 +33,9 @@ func (g *HTTPSchemeGetter) Mode(ctx context.Context, u *url.URL) (getter.Mode, e return g.Inner.Mode(ctx, u) } -// Detect returns true only when the request's scheme (or forced-getter -// prefix) matches the configured scheme. -// -// The prefix check is case-sensitive. URLs reach this point through -// Terragrunt's detector chain in canonical lowercase form, so the -// case-sensitive check is intentional. +// Detect claims only requests whose scheme (or forced-getter prefix) +// equals [HTTPSchemeGetter.Scheme]. URLs are canonical lowercase by the +// time they reach this getter, so the comparison is case-sensitive. func (g *HTTPSchemeGetter) Detect(req *getter.Request) (bool, error) { if req.Forced != "" { return req.Forced == g.Scheme, nil diff --git a/internal/getter/options.go b/internal/getter/options.go index a522b44b20..e32e64adca 100644 --- a/internal/getter/options.go +++ b/internal/getter/options.go @@ -29,10 +29,12 @@ func WithTFRegistry(g *RegistryGetter) Option { } // WithCAS registers CASGetter, which intercepts git/file sources and routes -// them through Terragrunt's content-addressable storage. -func WithCAS(c *cas.CAS, cloneOpts *cas.CloneOptions) Option { +// them through Terragrunt's content-addressable storage. v supplies the +// filesystem and git runner used by every CAS operation. 
+func WithCAS(c *cas.CAS, v cas.Venv, cloneOpts *cas.CloneOptions) Option { return func(b *builder) { b.casStore = c + b.casVenv = v b.casCloneOpts = cloneOpts } } @@ -84,6 +86,7 @@ type builder struct { tfRegistry *RegistryGetter casStore *cas.CAS casCloneOpts *cas.CloneOptions + casVenv cas.Venv httpExtraHeader http.Header httpsExtraHeader http.Header decompressors map[string]Decompressor diff --git a/internal/getter/options_test.go b/internal/getter/options_test.go index 7c315c64a9..cb45958154 100644 --- a/internal/getter/options_test.go +++ b/internal/getter/options_test.go @@ -25,7 +25,10 @@ func TestWithCASRegistersCASGetter(t *testing.T) { c, err := cas.New(cas.WithStorePath(filepath.Join(helpers.TmpDirWOSymlinks(t), "store"))) require.NoError(t, err) - client := getter.NewClient(getter.WithCAS(c, &cas.CloneOptions{})) + v, err := cas.OSVenv() + require.NoError(t, err) + + client := getter.NewClient(getter.WithCAS(c, v, &cas.CloneOptions{})) assert.True(t, hasGetter[*getter.CASGetter](client.Getters), "WithCAS must register CASGetter") assert.True(t, hasGetter[*getter.CASProtocolGetter](client.Getters), "WithCAS must register CASProtocolGetter") @@ -41,7 +44,10 @@ func TestWithCASRoutesCASProtocolURLs(t *testing.T) { c, err := cas.New(cas.WithStorePath(filepath.Join(helpers.TmpDirWOSymlinks(t), "store"))) require.NoError(t, err) - client := getter.NewClient(getter.WithCAS(c, &cas.CloneOptions{})) + v, err := cas.OSVenv() + require.NoError(t, err) + + client := getter.NewClient(getter.WithCAS(c, v, &cas.CloneOptions{})) req := &getter.Request{Src: "cas::sha1:0000000000000000000000000000000000000000"} diff --git a/internal/getter/resolver.go b/internal/getter/resolver.go new file mode 100644 index 0000000000..38cf1bb3cb --- /dev/null +++ b/internal/getter/resolver.go @@ -0,0 +1,23 @@ +package getter + +import ( + "github.com/gruntwork-io/terragrunt/internal/cas" +) + +// SourceResolver is re-exported so callers configuring CASGetter only +// need to import 
internal/getter. +type SourceResolver = cas.SourceResolver + +// DefaultSourceResolvers returns the per-scheme resolvers CASGetter dispatches +// through. SMB has no cheap probe so smb:// sources fall through to the +// no-resolver path in [cas.CAS.FetchSource] (download then content-hash); git +// is handled separately by [cas.CAS.Clone]. +func DefaultSourceResolvers() map[string]SourceResolver { + return map[string]SourceResolver{ + "http": NewHTTPResolver(), + "https": NewHTTPSResolver(), + "s3": NewS3Resolver(), + "gcs": NewGCSResolver(), + "hg": NewHgResolver(), + } +} diff --git a/internal/getter/resolver_gcs.go b/internal/getter/resolver_gcs.go new file mode 100644 index 0000000000..93000a4bc9 --- /dev/null +++ b/internal/getter/resolver_gcs.go @@ -0,0 +1,206 @@ +package getter + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + "cloud.google.com/go/storage" + + "github.com/gruntwork-io/terragrunt/internal/cas" +) + +// gcsResolverTimeout caps the Attrs call so a slow remote can't stall CAS +// dispatch. +const gcsResolverTimeout = 10 * time.Second + +// gcsCanonicalPathSegments is the segment count produced by splitting +// `storage/VERSION/BUCKET/OBJECT` on "/" with limit 4. +const gcsCanonicalPathSegments = 4 + +// ErrGCSMissingBucket is returned when a gs:// URL has no host segment. +var ErrGCSMissingBucket = errors.New("missing bucket in GCS URL") + +// ErrGCSMissingObject is returned when a GCS URL names a bucket but no object. +var ErrGCSMissingObject = errors.New("missing object in GCS URL") + +// ErrGCSUnrecognizedURL is returned when an http(s) URL does not match the +// canonical /storage/VERSION/BUCKET/OBJECT shape. +var ErrGCSUnrecognizedURL = errors.New("not a recognized GCS URL") + +// ErrGCSUnsupportedScheme is returned when the URL scheme is neither gs nor http(s). +var ErrGCSUnsupportedScheme = errors.New("unsupported GCS URL scheme") + +// GCSObject is the subset of *storage.ObjectHandle a resolver uses.
+type GCSObject interface { + Attrs(ctx context.Context) (*storage.ObjectAttrs, error) +} + +// GCSClient is the subset of *storage.Client a resolver uses. +type GCSClient interface { + Object(bucket, object string) GCSObject + Close() error +} + +// GCSResolver is a [cas.SourceResolver] for objects in Google Cloud +// Storage. +type GCSResolver struct { + // NewClient builds a GCS client per request. Nil means + // [storage.NewClient] with the ambient application default + // credentials. + NewClient func(ctx context.Context) (GCSClient, error) +} + +// NewGCSResolver returns a resolver wired to the ambient ADC. +func NewGCSResolver() *GCSResolver { return &GCSResolver{} } + +// Scheme returns "gcs". +func (r *GCSResolver) Scheme() string { return "gcs" } + +// Probe reads object metadata via ObjectHandle.Attrs and returns a +// content-addressed cache key from MD5 (when present) or CRC32C +// (always populated by GCS). Errors surface as +// [cas.ErrNoVersionMetadata]. +func (r *GCSResolver) Probe(ctx context.Context, rawURL string) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("parse GCS URL %s: %w", rawURL, err) + } + + bucket, object, err := parseGCSURL(u) + if err != nil { + return "", fmt.Errorf("parse GCS URL %s: %w", rawURL, err) + } + + ctx, cancel := context.WithTimeout(ctx, gcsResolverTimeout) + defer cancel() + + client, err := r.client(ctx) + if err != nil { + return "", cas.ErrNoVersionMetadata + } + + key, probeErr := r.pickGCSCacheKeyFromAttrs(ctx, client, bucket, object) + + // Close errors only surface on the success path. On probe + // failure the primary error already explains the outcome, and + // joining a close error would mask the sentinel + // (ErrNoVersionMetadata) callers test for. 
+ closeErr := client.Close() + + if probeErr != nil { + return "", probeErr + } + + if closeErr != nil { + return "", fmt.Errorf("close GCS client: %w", closeErr) + } + + return key, nil +} + +// pickGCSCacheKeyFromAttrs fetches object metadata through client and +// returns the cascade-derived cache key. ErrNoVersionMetadata signals +// either a failed Attrs call or a nil attrs payload. +func (r *GCSResolver) pickGCSCacheKeyFromAttrs( + ctx context.Context, + client GCSClient, + bucket, object string, +) (string, error) { + attrs, err := client.Object(bucket, object).Attrs(ctx) + if err != nil { + return "", cas.ErrNoVersionMetadata + } + + return pickGCSCacheKey(attrs) +} + +func (r *GCSResolver) client(ctx context.Context) (GCSClient, error) { + if r.NewClient != nil { + return r.NewClient(ctx) + } + + c, err := storage.NewClient(ctx) + if err != nil { + return nil, err + } + + return &storageClientAdapter{c: c}, nil +} + +// pickGCSCacheKey walks the cascade MD5 → CRC32C and returns the +// cache key for the first match. CRC32C participates whenever MD5 is +// absent; its value zero is a real checksum (the empty object is the +// canonical example, but other byte sequences also collide on 0), not +// a "missing" sentinel. Older code gated CRC32C on `!= 0` and +// silently downgraded zero-CRC32C content to URL-scoped opaque keys. +// +// GCS populates CRC32C for every object the SDK returns Attrs for, so +// the cascade does not need an ETag fallback in practice. A nil attrs +// payload (only reachable from a fake test client or an SDK regression) +// surfaces as ErrNoVersionMetadata so the caller falls back to content +// hashing. 
+func pickGCSCacheKey(attrs *storage.ObjectAttrs) (string, error) { + if attrs == nil { + return "", cas.ErrNoVersionMetadata + } + + if len(attrs.MD5) > 0 { + return cas.ContentKey("md5", hex.EncodeToString(attrs.MD5)), nil + } + + return cas.ContentKey("crc32c", strconv.FormatUint(uint64(attrs.CRC32C), 16)), nil +} + +// parseGCSURL extracts bucket and object from either canonical form. +// Accepts `https://www.googleapis.com/storage/v1/BUCKET/OBJECT` and +// `gs://BUCKET/OBJECT`. URLs that name a bucket but no object are +// rejected at parse time so callers do not pay a doomed Attrs round +// trip for an SDK-shaped not-found error. +func parseGCSURL(u *url.URL) (bucket, object string, err error) { + switch strings.ToLower(u.Scheme) { + case "gs": + object = strings.TrimPrefix(u.Path, "/") + bucket = u.Host + + if bucket == "" { + return "", "", fmt.Errorf("%w: %q", ErrGCSMissingBucket, u.String()) + } + + if object == "" { + return "", "", fmt.Errorf("%w: %q", ErrGCSMissingObject, u.String()) + } + + return bucket, object, nil + case "http", "https": + // Canonical: /storage/VERSION/BUCKET/OBJECT + parts := strings.SplitN(strings.TrimPrefix(u.Path, "/"), "/", gcsCanonicalPathSegments) + if len(parts) < gcsCanonicalPathSegments || parts[0] != "storage" { + return "", "", fmt.Errorf("%w: %q", ErrGCSUnrecognizedURL, u.String()) + } + + if parts[3] == "" { + return "", "", fmt.Errorf("%w: %q", ErrGCSMissingObject, u.String()) + } + + return parts[2], parts[3], nil + } + + return "", "", fmt.Errorf("%w: %q", ErrGCSUnsupportedScheme, u.Scheme) +} + +// storageClientAdapter narrows *storage.Client to the GCSClient interface.
+type storageClientAdapter struct { + c *storage.Client +} + +func (a *storageClientAdapter) Object(bucket, object string) GCSObject { + return a.c.Bucket(bucket).Object(object) +} + +func (a *storageClientAdapter) Close() error { return a.c.Close() } diff --git a/internal/getter/resolver_gcs_test.go b/internal/getter/resolver_gcs_test.go new file mode 100644 index 0000000000..a5d28d5b98 --- /dev/null +++ b/internal/getter/resolver_gcs_test.go @@ -0,0 +1,241 @@ +package getter_test + +import ( + "context" + "encoding/hex" + "errors" + "testing" + + "cloud.google.com/go/storage" + + "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGCSResolver_PrefersMD5(t *testing.T) { + t.Parallel() + + md5 := []byte{0xde, 0xad, 0xbe, 0xef, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b} + client := &fakeGCSClient{objects: map[string]*fakeGCSObject{ + "path/to/key.tgz": {attrs: &storage.ObjectAttrs{ + MD5: md5, + CRC32C: 0xdeadbeef, + Etag: `"some-etag"`, + }}, + }} + + r := newGCSResolverWith(client) + + got, err := r.Probe(t.Context(), "gs://bucket/path/to/key.tgz") + require.NoError(t, err) + assert.Equal(t, cas.ContentKey("md5", hex.EncodeToString(md5)), got) +} + +func TestGCSResolver_FallsThroughToCRC32CWhenMD5Absent(t *testing.T) { + t.Parallel() + + client := &fakeGCSClient{objects: map[string]*fakeGCSObject{ + "composite.tgz": {attrs: &storage.ObjectAttrs{ + // MD5 nil, common on composite objects. + CRC32C: 0xdeadbeef, + Etag: `"some-etag"`, + }}, + }} + + r := newGCSResolverWith(client) + + got, err := r.Probe(t.Context(), "gs://bucket/composite.tgz") + require.NoError(t, err) + assert.Equal(t, cas.ContentKey("crc32c", "deadbeef"), got) +} + +// TestGCSResolver_PrefersCRC32CEvenWhenZero pins that CRC32C +// participates in the cascade even when the checksum value is +// literally 0. 
Some legitimate content (the empty object is the +// canonical example, but other byte sequences also hash to 0) +// produces a zero CRC32C, and treating that as "absent" silently +// downgrades the cache key to the opaque ETag fallback, losing +// cross-URL dedupe for content-addressable objects. +func TestGCSResolver_PrefersCRC32CEvenWhenZero(t *testing.T) { + t.Parallel() + + client := &fakeGCSClient{objects: map[string]*fakeGCSObject{ + "zero-crc.tgz": {attrs: &storage.ObjectAttrs{ + // MD5 absent so the cascade falls to CRC32C; the legal + // value 0 must not be treated as "no signal". + CRC32C: 0, + Etag: `"some-etag"`, + }}, + }} + + r := newGCSResolverWith(client) + + got, err := r.Probe(t.Context(), "gs://bucket/zero-crc.tgz") + require.NoError(t, err) + assert.Equal(t, cas.ContentKey("crc32c", "0"), got, + "CRC32C=0 is a real checksum and must be preferred over the opaque ETag") +} + +func TestGCSResolver_AttrsFailureSurfacesErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + client := &fakeGCSClient{objects: map[string]*fakeGCSObject{ + "err.tgz": {err: errors.New("transient GCS error")}, + }} + + r := newGCSResolverWith(client) + + _, err := r.Probe(t.Context(), "gs://bucket/err.tgz") + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +// TestGCSResolver_NilAttrsReturnsErrNoVersionMetadata covers the +// only path to ErrNoVersionMetadata after the cascade lost its +// ETag/empty-attrs fallback: an SDK that returns (nil, nil) from +// Attrs. The fake client below stands in for that regression mode; +// real GCS always populates CRC32C, so the empty-attrs path is +// unreachable from production. 
+func TestGCSResolver_NilAttrsReturnsErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + client := &fakeGCSClient{objects: map[string]*fakeGCSObject{ + "empty.tgz": {attrs: nil}, + }} + + r := newGCSResolverWith(client) + + _, err := r.Probe(t.Context(), "gs://bucket/empty.tgz") + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +func TestGCSResolver_AcceptsCanonicalAndShortURLs(t *testing.T) { + t.Parallel() + + md5 := []byte{0xde, 0xad, 0xbe, 0xef, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b} + client := &fakeGCSClient{objects: map[string]*fakeGCSObject{ + "path/to/key.tgz": {attrs: &storage.ObjectAttrs{MD5: md5}}, + }} + + r := newGCSResolverWith(client) + + short, err := r.Probe(t.Context(), "gs://bucket/path/to/key.tgz") + require.NoError(t, err) + + canonical, err := r.Probe(t.Context(), "https://www.googleapis.com/storage/v1/bucket/path/to/key.tgz") + require.NoError(t, err) + + // Both forms resolve to the same object metadata, so the + // content-addressed cache key is identical. + assert.Equal(t, short, canonical) +} + +func TestGCSResolver_RejectsUnknownURLShape(t *testing.T) { + t.Parallel() + + r := newGCSResolverWith(&fakeGCSClient{}) + + _, err := r.Probe(t.Context(), "ftp://bucket/key.tgz") + require.ErrorIs(t, err, getter.ErrGCSUnsupportedScheme) +} + +// TestGCSResolver_RejectsEmptyObject pins that parseGCSURL rejects +// URLs that name a bucket but no object. Without this guard the +// resolver passes object="" to ObjectHandle.Attrs, which fails +// downstream with an SDK-shaped error that surfaces as +// ErrNoVersionMetadata after a wasted round-trip. Fail at parse time +// instead so the caller sees a meaningful URL-shape error. 
+func TestGCSResolver_RejectsEmptyObject(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "gs bucket with no path", url: "gs://bucket"}, + {name: "gs bucket with trailing slash only", url: "gs://bucket/"}, + {name: "canonical with no object segment", url: "https://www.googleapis.com/storage/v1/bucket/"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := newGCSResolverWith(&fakeGCSClient{objects: map[string]*fakeGCSObject{ + "": {attrs: nil, err: errors.New("Attrs must not be called for an empty-object URL")}, + }}) + + _, err := r.Probe(t.Context(), tt.url) + require.ErrorIs(t, err, getter.ErrGCSMissingObject, + "parseGCSURL must reject %q with no object", tt.url) + require.NotErrorIs(t, err, cas.ErrNoVersionMetadata, + "rejection must come from parseGCSURL, not from an Attrs call on an empty object name") + }) + } +} + +// TestGCSResolver_RejectsEmptyBucket pins that parseGCSURL rejects a +// gs:// URL with no host. Without this guard the resolver passes +// bucket="" downstream and pays a doomed Attrs round trip. +func TestGCSResolver_RejectsEmptyBucket(t *testing.T) { + t.Parallel() + + r := newGCSResolverWith(&fakeGCSClient{}) + + _, err := r.Probe(t.Context(), "gs:///key.tgz") + require.ErrorIs(t, err, getter.ErrGCSMissingBucket) + require.NotErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +// TestGCSResolver_RejectsUnrecognizedCanonicalPath pins that an +// http(s) URL outside the canonical `/storage/VERSION/BUCKET/OBJECT`
+func TestGCSResolver_RejectsUnrecognizedCanonicalPath(t *testing.T) { + t.Parallel() + + r := newGCSResolverWith(&fakeGCSClient{}) + + _, err := r.Probe(t.Context(), "https://www.googleapis.com/not-storage/v1/bucket/key.tgz") + require.ErrorIs(t, err, getter.ErrGCSUnrecognizedURL) + require.NotErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +// fakeGCSObject returns canned object attributes. +type fakeGCSObject struct { + attrs *storage.ObjectAttrs + err error +} + +func (o *fakeGCSObject) Attrs(_ context.Context) (*storage.ObjectAttrs, error) { + if o.err != nil { + return nil, o.err + } + + return o.attrs, nil +} + +// fakeGCSClient routes Object(bucket, name) to a per-name fake. +type fakeGCSClient struct { + objects map[string]*fakeGCSObject +} + +func (c *fakeGCSClient) Object(_, name string) getter.GCSObject { + if o, ok := c.objects[name]; ok { + return o + } + + return &fakeGCSObject{err: errors.New("no such object")} +} + +func (c *fakeGCSClient) Close() error { return nil } + +func newGCSResolverWith(client *fakeGCSClient) *getter.GCSResolver { + r := getter.NewGCSResolver() + r.NewClient = func(_ context.Context) (getter.GCSClient, error) { + return client, nil + } + + return r +} diff --git a/internal/getter/resolver_hg.go b/internal/getter/resolver_hg.go new file mode 100644 index 0000000000..9632c5816a --- /dev/null +++ b/internal/getter/resolver_hg.go @@ -0,0 +1,117 @@ +package getter + +import ( + "context" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/vexec" +) + +// ErrInvalidHgRev is returned when the `?rev=` query parameter contains +// a character that would corrupt the argv passed to `hg identify` (NUL, +// newline, or carriage return). Shell metacharacters like `;` and `|` +// are not in this set: [vexec.Exec.Command] does not run through a +// shell, so they reach hg as part of the rev argument and hg rejects +// them on its own. 
+var ErrInvalidHgRev = errors.New("invalid hg rev") + +// hgResolverTimeout caps `hg identify` so a slow remote can't stall CAS. +const hgResolverTimeout = 10 * time.Second + +// HgResolver is a [cas.SourceResolver] for Mercurial sources. +type HgResolver struct { + // Exec runs the hg binary. Required; [NewHgResolver] wires + // [vexec.NewOSExec]. Tests substitute an in-memory backend. + Exec vexec.Exec + // HgBinary overrides the binary name resolved via [vexec.Exec.LookPath]. + // Empty means "hg". + HgBinary string +} + +// NewHgResolver returns a resolver bound to the real OS-backed exec +// and the ambient `hg` binary on PATH. +func NewHgResolver() *HgResolver { return &HgResolver{Exec: vexec.NewOSExec()} } + +// Scheme returns "hg". +func (r *HgResolver) Scheme() string { return "hg" } + +// Probe runs `hg identify --template '{node}\n'` against rawURL and +// returns the 40-char node hash as a content-addressed cache key. The +// ref comes from the URL's `rev` query parameter; absent or empty +// means "tip". Missing binary, timeout, or unreachable remote produce +// [cas.ErrNoVersionMetadata]. +// +// `--template '{node}'` is used instead of `--id` because `--id` +// returns the abbreviated 12-char short hash, which is not +// collision-safe for cache keying. 
+func (r *HgResolver) Probe(ctx context.Context, rawURL string) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("parse hg URL %s: %w", rawURL, err) + } + + rev := u.Query().Get("rev") + if err := validateHgRev(rev); err != nil { + return "", fmt.Errorf("parse hg URL %s: %w", rawURL, err) + } + + cleaned := *u + q := cleaned.Query() + + q.Del("rev") + cleaned.RawQuery = q.Encode() + + bin := r.HgBinary + if bin == "" { + bin = "hg" + } + + if _, err := r.Exec.LookPath(bin); err != nil { + return "", cas.ErrNoVersionMetadata + } + + ctx, cancel := context.WithTimeout(ctx, hgResolverTimeout) + defer cancel() + + // --rev=REV and the -- terminator keep a `-`-prefixed value + // from being reparsed by hg's option parser. + args := []string{"identify", "--template", "{node}\n"} + if rev != "" { + args = append(args, "--rev="+rev) + } + + args = append(args, "--", cleaned.String()) + + out, err := r.Exec.Command(ctx, bin, args...).Output() + if err != nil { + return "", cas.ErrNoVersionMetadata + } + + node := strings.TrimSpace(string(out)) + if node == "" { + return "", cas.ErrNoVersionMetadata + } + + return cas.ContentKey("hg-node", node), nil +} + +// validateHgRev rejects rev values that would corrupt the argv handed to +// `hg identify`. NUL, newline, and carriage return are the only +// characters guarded here: they break argument boundaries inside the +// child process (NUL truncates C strings; newlines split log lines and +// some hg parsers). Other special characters reach hg literally because +// [vexec.Exec.Command] does not invoke a shell.
+func validateHgRev(rev string) error { + for _, r := range rev { + if r == 0 || r == '\n' || r == '\r' { + return fmt.Errorf("%w: contains control character", ErrInvalidHgRev) + } + } + + return nil +} diff --git a/internal/getter/resolver_hg_test.go b/internal/getter/resolver_hg_test.go new file mode 100644 index 0000000000..c47c228c2f --- /dev/null +++ b/internal/getter/resolver_hg_test.go @@ -0,0 +1,241 @@ +package getter_test + +import ( + "context" + "errors" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/internal/vexec" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// errHgNotFound is the LookPath error the missing-binary tests inject. +var errHgNotFound = errors.New("hg not found") + +func TestHgResolver_MissingBinaryReturnsErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + e := vexec.NewMemExec( + hgHandler(vexec.Result{}), + vexec.WithLookPath(func(string) (string, error) { return "", errHgNotFound }), + ) + + r := &getter.HgResolver{Exec: e} + + _, err := r.Probe(t.Context(), "https://example.com/repo") + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +// TestHgResolver_BinaryFailureReturnsErrNoVersionMetadata feeds the +// resolver an Exec whose handler returns a non-zero exit code, so the +// resolver swallows the failure as ErrNoVersionMetadata. +func TestHgResolver_BinaryFailureReturnsErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + e := vexec.NewMemExec(hgHandler(vexec.Result{ExitCode: 1})) + + r := &getter.HgResolver{Exec: e} + + _, err := r.Probe(t.Context(), "https://example.com/repo") + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +// TestHgResolver_ParsesNodeFromStubOutput verifies the resolver picks +// up the hex hash from a successful command and wraps it in a +// content-addressed cache key. 
The stub returns the full 40-char node +// hash the resolver must request (the abbreviated 12-char form has +// only 48 bits, ~2.8e14 values, and is not collision-safe for use as a cache key). +func TestHgResolver_ParsesNodeFromStubOutput(t *testing.T) { + t.Parallel() + + const fullNode = "abcdef0123456789abcdef0123456789abcdef01" + + e := vexec.NewMemExec(hgHandler(vexec.Result{Stdout: []byte(fullNode + "\n")})) + + r := &getter.HgResolver{Exec: e} + + got, err := r.Probe(t.Context(), "https://example.com/repo?rev=tip") + require.NoError(t, err) + assert.Equal(t, cas.ContentKey("hg-node", fullNode), got) +} + +func TestHgResolver_EmptyOutputReturnsErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + e := vexec.NewMemExec(hgHandler(vexec.Result{Stdout: []byte("\n")})) + + r := &getter.HgResolver{Exec: e} + + _, err := r.Probe(t.Context(), "https://example.com/repo") + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +// TestHgResolver_PassesRevAsArg pins the argv shape. `--template +// '{node}\n'` is needed for the full 40-char node hash (the cache +// key); `--id` returns the 12-char short form and is not +// collision-safe. `--rev=REV` and the `--` URL terminator block a +// `-`-prefixed value from being reparsed as an hg flag.
+func TestHgResolver_PassesRevAsArg(t *testing.T) { + t.Parallel() + + var gotArgs []string + + handler := func(_ context.Context, inv vexec.Invocation) vexec.Result { + gotArgs = inv.Args + return vexec.Result{Stdout: []byte("abcdef0123456789abcdef0123456789abcdef01\n")} + } + + r := &getter.HgResolver{Exec: vexec.NewMemExec(handler)} + + _, err := r.Probe(t.Context(), "https://example.com/repo?rev=feature-x") + require.NoError(t, err) + + assert.Equal(t, + []string{"identify", "--template", "{node}\n", "--rev=feature-x", "--", "https://example.com/repo"}, + gotArgs, + ) +} + +// TestHgResolver_FlagLikeRevStaysBoundToOption pins that a +// `-`-prefixed rev value stays inside the --rev argv element instead +// of appearing as its own flag-shaped element. +func TestHgResolver_FlagLikeRevStaysBoundToOption(t *testing.T) { + t.Parallel() + + var gotArgs []string + + handler := func(_ context.Context, inv vexec.Invocation) vexec.Result { + gotArgs = inv.Args + return vexec.Result{Stdout: []byte("abcdef0123456789abcdef0123456789abcdef01\n")} + } + + r := &getter.HgResolver{Exec: vexec.NewMemExec(handler)} + + _, err := r.Probe(t.Context(), "https://example.com/repo?rev=--debugger") + require.NoError(t, err) + + assert.Contains(t, gotArgs, "--rev=--debugger", + "flag-like rev value must stay bound to --rev in a single argv element") + assert.NotContains(t, gotArgs, "--debugger", + "flag-like rev value must not appear as its own argv element") +} + +// TestHgResolver_RejectsRevWithControlCharacters pins that NUL, +// newline, and carriage return in the rev are rejected before any hg +// invocation. Other shell metacharacters (`;`, `|`, ...) pass through +// because [vexec.Exec.Command] does not run through a shell. 
+func TestHgResolver_RejectsRevWithControlCharacters(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + rev string + }{ + {name: "null byte", rev: "tip\x00rest"}, + {name: "newline", rev: "tip\nrest"}, + {name: "carriage return", rev: "tip\rrest"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var commandRan bool + + handler := func(_ context.Context, _ vexec.Invocation) vexec.Result { + commandRan = true + return vexec.Result{Stdout: []byte("abcdef0123456789abcdef0123456789abcdef01\n")} + } + + r := &getter.HgResolver{Exec: vexec.NewMemExec(handler)} + + rawURL := "https://example.com/repo?rev=" + url.QueryEscape(tt.rev) + + _, err := r.Probe(t.Context(), rawURL) + require.ErrorIs(t, err, getter.ErrInvalidHgRev) + assert.False(t, commandRan, "hg must not be invoked when rev is invalid") + }) + } +} + +// TestHgResolver_AcceptsRevWithShellMetacharacters pins that shell +// metacharacters reach hg literally instead of being rejected by the +// resolver. The exec layer does not invoke a shell, so `;` is just part +// of the rev string and hg rejects it on its own merits. +func TestHgResolver_AcceptsRevWithShellMetacharacters(t *testing.T) { + t.Parallel() + + var gotArgs []string + + handler := func(_ context.Context, inv vexec.Invocation) vexec.Result { + gotArgs = inv.Args + return vexec.Result{Stdout: []byte("abcdef0123456789abcdef0123456789abcdef01\n")} + } + + r := &getter.HgResolver{Exec: vexec.NewMemExec(handler)} + + _, err := r.Probe(t.Context(), "https://example.com/repo?rev="+url.QueryEscape("tip ; echo pwned")) + require.NoError(t, err) + assert.Contains(t, gotArgs, "--rev=tip ; echo pwned", + "shell metacharacters must reach hg as part of a single argv element") +} + +// TestHgResolver_AgainstRealHg verifies the resolver against the +// actual hg binary when it is installed. It uses a freshly-initialized +// repository on disk so the test does not reach the network. 
The +// assertion pins the resolver's key against a ContentKey derived +// from the full 40-char node hash; this regresses if the resolver +// reverts to `--id`'s 12-char short form. +func TestHgResolver_AgainstRealHg(t *testing.T) { + t.Parallel() + + if _, err := exec.LookPath("hg"); err != nil { + t.Skip("hg binary not installed on this host") + } + + repoDir := t.TempDir() + + hg := func(args ...string) { + cmd := exec.CommandContext(t.Context(), "hg", args...) + cmd.Dir = repoDir + + out, err := cmd.CombinedOutput() + require.NoErrorf(t, err, "hg %v failed: %s", args, string(out)) + } + + hg("init", ".") + require.NoError(t, os.WriteFile(filepath.Join(repoDir, "main.tf"), []byte("hello\n"), 0o644)) + hg("--config", "ui.username=test ", "commit", "-A", "-m", "initial") + + // Independently query the full 40-char node hash so the assertion + // reflects what the resolver should be folding into the key. + nodeCmd := exec.CommandContext(t.Context(), "hg", "identify", "--template", "{node}\n", repoDir) + out, err := nodeCmd.Output() + require.NoError(t, err) + + fullNode := strings.TrimSpace(string(out)) + require.Len(t, fullNode, 40, "hg must emit a 40-char node hash with --template '{node}'") + + r := getter.NewHgResolver() + + got, err := r.Probe(t.Context(), repoDir) + require.NoError(t, err) + assert.Equal(t, cas.ContentKey("hg-node", fullNode), got) +} + +// hgHandler returns a vexec.Handler that always produces the given +// Result, regardless of invocation arguments. Used to pin +// stdout/exit-code on a per-test basis. 
// httpResolverTimeout bounds the HEAD probe so a slow remote cannot
// stall CAS dispatch. It is applied to the default client only; a
// caller-supplied Client keeps its own timeout.
const httpResolverTimeout = 10 * time.Second

// HTTPResolver is a [cas.SourceResolver] for HTTP and HTTPS URLs.
type HTTPResolver struct {
	// Client overrides the http.Client used for the HEAD probe.
	// Nil means a copy of http.DefaultClient with httpResolverTimeout.
	Client *http.Client
	// scheme is the value Scheme() reports. It is set by
	// NewHTTPResolver and NewHTTPSResolver and is empty on a
	// zero-value resolver.
	scheme string
}

// NewHTTPResolver returns a resolver that reports the "http" scheme.
func NewHTTPResolver() *HTTPResolver {
	return &HTTPResolver{scheme: "http"}
}

// NewHTTPSResolver returns a resolver that reports the "https" scheme.
// Both schemes share one type; the separate constructor exists so each
// instance's Scheme() matches the scheme it is registered under.
func NewHTTPSResolver() *HTTPResolver {
	return &HTTPResolver{scheme: "https"}
}

// Scheme returns the URL scheme this resolver handles ("http" or
// "https"). A zero-value resolver, whose scheme was never set, reports
// "http".
func (r *HTTPResolver) Scheme() string {
	if r.scheme != "" {
		return r.scheme
	}

	return "http"
}
+// +// ETag is treated as opaque even when the server claims it is a strong +// content hash: there is no portable way to distinguish content hashes +// from server-assigned tokens. Network errors and non-2xx responses +// surface as [cas.ErrNoVersionMetadata]. +func (r *HTTPResolver) Probe(ctx context.Context, rawURL string) (string, error) { + client := r.Client + if client == nil { + c := *http.DefaultClient + c.Timeout = httpResolverTimeout + client = &c + } + + // The outer client strips these before invoking the HTTP getter, + // so probing with them attached would split cache entries that + // resolve to the same fetched bytes. + probeURL := stripHTTPMagicParams(rawURL) + + req, err := http.NewRequestWithContext(ctx, http.MethodHead, probeURL, http.NoBody) + if err != nil { + return "", fmt.Errorf("build HEAD request for %s: %w", rawURL, err) + } + + resp, err := client.Do(req) + if err != nil { + return "", cas.ErrNoVersionMetadata + } + + key, probeErr := r.pickHTTPCacheKey(probeURL, resp) + + // Body close errors only surface on the success path; on a + // probe failure the primary error already explains the outcome + // and joining would mask the ErrNoVersionMetadata sentinel + // callers test for. + closeErr := resp.Body.Close() + + if probeErr != nil { + return "", probeErr + } + + if closeErr != nil { + return "", fmt.Errorf("close HTTP response body for %s: %w", rawURL, closeErr) + } + + return key, nil +} + +// pickHTTPCacheKey reads the cache-key-bearing headers from resp and +// returns the OpaqueKey for the strongest available signal. 
+func (r *HTTPResolver) pickHTTPCacheKey(rawURL string, resp *http.Response) (string, error) { + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return "", cas.ErrNoVersionMetadata + } + + scheme := r.Scheme() + if u, parseErr := url.Parse(rawURL); parseErr == nil && u.Scheme != "" { + scheme = strings.ToLower(u.Scheme) + } + + if etag := strings.TrimSpace(resp.Header.Get("ETag")); etag != "" { + if normalized := normalizeETag(etag); normalized != "" { + return cas.OpaqueKey(scheme, rawURL, normalized), nil + } + } + + if lm := strings.TrimSpace(resp.Header.Get("Last-Modified")); lm != "" { + return cas.OpaqueKey(scheme, rawURL, lm), nil + } + + return "", cas.ErrNoVersionMetadata +} + +// normalizeETag strips the weak-validator W/ prefix and the surrounding +// quotes so the same bytes served with either form produce the same +// cache key. +func normalizeETag(etag string) string { + etag = strings.TrimPrefix(etag, "W/") + etag = strings.TrimPrefix(etag, "w/") + etag = strings.TrimPrefix(etag, "\"") + etag = strings.TrimSuffix(etag, "\"") + + return etag +} + +// httpMagicParams are the query keys the go-getter v2 outer client +// consumes itself rather than forwarding to the HTTP getter. +var httpMagicParams = []string{"archive", "checksum", "filename"} + +// stripHTTPMagicParams returns rawURL with [httpMagicParams] removed. +// Unparsable inputs are returned unchanged so the HEAD request +// surfaces the same error a fetch would. 
+func stripHTTPMagicParams(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + + q := u.Query() + + changed := false + + for _, k := range httpMagicParams { + if q.Has(k) { + q.Del(k) + + changed = true + } + } + + if !changed { + return rawURL + } + + u.RawQuery = q.Encode() + + return u.String() +} diff --git a/internal/getter/resolver_http_test.go b/internal/getter/resolver_http_test.go new file mode 100644 index 0000000000..e9667577da --- /dev/null +++ b/internal/getter/resolver_http_test.go @@ -0,0 +1,187 @@ +package getter_test + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gruntwork-io/terragrunt/internal/cas" + "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHTTPResolver_PrefersStrongETag(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodHead, r.Method) + w.Header().Set("ETag", `"abc123"`) + w.Header().Set("Last-Modified", "Mon, 01 Jan 2024 00:00:00 GMT") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + r := getter.NewHTTPResolver() + url := srv.URL + "/mod.tgz" + + key, err := r.Probe(t.Context(), url) + require.NoError(t, err) + assert.Equal(t, cas.OpaqueKey("http", url, "abc123"), key) +} + +func TestHTTPResolver_StripsWeakETagPrefix(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("ETag", `W/"weak-tag"`) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + r := getter.NewHTTPResolver() + url := srv.URL + "/mod.tgz" + + key, err := r.Probe(t.Context(), url) + require.NoError(t, err) + assert.Equal(t, cas.OpaqueKey("http", url, "weak-tag"), key) +} + +func TestHTTPResolver_FallsBackToLastModified(t *testing.T) { + t.Parallel() + + srv := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Last-Modified", "Mon, 01 Jan 2024 00:00:00 GMT") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + r := getter.NewHTTPResolver() + url := srv.URL + "/mod.tgz" + + key, err := r.Probe(t.Context(), url) + require.NoError(t, err) + assert.Equal(t, cas.OpaqueKey("http", url, "Mon, 01 Jan 2024 00:00:00 GMT"), key) +} + +func TestHTTPResolver_ReturnsErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + r := getter.NewHTTPResolver() + + _, err := r.Probe(t.Context(), srv.URL+"/mod.tgz") + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +func TestHTTPResolver_ReturnsErrOnNon2xx(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + r := getter.NewHTTPResolver() + + _, err := r.Probe(t.Context(), srv.URL+"/mod.tgz") + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +// TestHTTPResolver_LowercaseWeakETag pins that both `W/` and `w/` +// weak-validator prefixes normalize to the same key, since some +// servers emit the lowercase form despite the RFC specifying upper. 
+func TestHTTPResolver_LowercaseWeakETag(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("ETag", `w/"weak-tag"`) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + r := getter.NewHTTPResolver() + url := srv.URL + "/mod.tgz" + + key, err := r.Probe(t.Context(), url) + require.NoError(t, err) + assert.Equal(t, cas.OpaqueKey("http", url, "weak-tag"), key) +} + +// TestHTTPResolver_SchemeReportsRegisteredScheme pins the +// SourceResolver contract: the resolver registered under "https" in +// DefaultSourceResolvers must report "https" from Scheme(), not "http". +func TestHTTPResolver_SchemeReportsRegisteredScheme(t *testing.T) { + t.Parallel() + + assert.Equal(t, "http", getter.NewHTTPResolver().Scheme()) + assert.Equal(t, "https", getter.NewHTTPSResolver().Scheme()) +} + +// TestHTTPResolver_StripsOuterClientMagicParams pins that probes for +// the same URL with and without the outer-client magic params share +// a cache key, and that the params do not reach the wire on HEAD. 
+func TestHTTPResolver_StripsOuterClientMagicParams(t *testing.T) { + t.Parallel() + + var seenQueries []string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenQueries = append(seenQueries, r.URL.RawQuery) + + w.Header().Set("ETag", `"abc123"`) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + r := getter.NewHTTPResolver() + + plain := srv.URL + "/mod.tgz" + plainKey, err := r.Probe(t.Context(), plain) + require.NoError(t, err) + + withMagic := srv.URL + "/mod.tgz?archive=zip&checksum=sha256:deadbeef&filename=override.tgz" + magicKey, err := r.Probe(t.Context(), withMagic) + require.NoError(t, err) + + assert.Equal(t, plainKey, magicKey, + "magic params must not split the cache key; the outer client strips them before fetch") + + require.Len(t, seenQueries, 2) + assert.Empty(t, seenQueries[0], "first probe has no query") + assert.Empty(t, seenQueries[1], "magic params must be stripped before the HEAD request") +} + +// TestHTTPResolver_PreservesNonMagicQueryParams pins that stripping +// is scoped to the magic-param allowlist; caller-supplied params +// (auth tokens, server-honored selectors) must reach HEAD. 
// Sentinel errors for amazonaws.com URLs the resolver refuses to parse.
// They are wrapped with %w before being returned, so callers match them
// with errors.Is.
var (
	// ErrS3UnrecognizedURL is returned when an amazonaws.com URL does
	// not match a supported S3 path-style or legacy virtual-host shape.
	ErrS3UnrecognizedURL = errors.New("not a recognized S3 URL")

	// ErrS3ModernPathStyleUnsupported is returned for
	// `s3.<region>.amazonaws.com` URLs. The upstream go-getter/s3 v2
	// Getter rejects them, so the resolver rejects them too to keep
	// probe success aligned with fetch success. The message points at
	// the legacy regional host form (`s3-<region>.amazonaws.com`),
	// which is accepted.
	ErrS3ModernPathStyleUnsupported = errors.New("modern path-style S3 URL not supported (use s3-<region>.amazonaws.com instead)")
)
// ErrS3CompatibleUnrecognizedURL is returned when a non-amazonaws.com URL
// does not have the `<host>/<bucket>/<key>` path shape required for
// S3-compatible services.
var ErrS3CompatibleUnrecognizedURL = errors.New("not a recognized S3-compatible URL")

// s3ResolverTimeout caps the HeadObject call so a slow remote can't stall
// CAS dispatch.
const s3ResolverTimeout = 10 * time.Second

// Host-part counts for AWS S3 URL forms: the number of dot-separated
// labels in the hostname.
// Path style: `s3[-<region>].amazonaws.com`.
// Virtual-host style: `<bucket>.s3[-<region>].amazonaws.com`.
const (
	s3HostPartsPathStyle  = 3
	s3HostPartsVHostStyle = 4
	// s3URLPathSegments is the count produced by splitting `/bucket/key`
	// on "/" with limit 3: ["", "bucket", "key"]. Used as a validation
	// gate before indexing.
	s3URLPathSegments = 3
)

// S3API is the subset of *s3.Client a resolver depends on. Exported
// so tests can inject a fake.
type S3API interface {
	HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error)
}

// S3Resolver is a [cas.SourceResolver] for objects in Amazon S3 and
// S3-compatible services.
//
// Supported URL forms (constrained by the upstream go-getter/s3/v2
// Getter, whose parseUrl enforces a 3-part `amazonaws.com` hostname):
//
//	https://s3.amazonaws.com/<bucket>/<key>            (global path-style)
//	https://s3-<region>.amazonaws.com/<bucket>/<key>   (legacy regional path-style)
//	https://<host>/<bucket>/<key>?region=<region>      (S3-compatible service)
//
// Modern virtual-host URLs (`<bucket>.s3.<region>.amazonaws.com`,
// 5-part) and modern path-style URLs (`s3.<region>.amazonaws.com`,
// 4-part) are rejected by both the bare getter and this resolver. Use
// the legacy regional form above.
type S3Resolver struct {
	// NewClient builds an S3 client per request. Nil means the resolver
	// uses the AWS SDK default config (env, profile, IMDS) with a
	// region derived from the URL.
	NewClient func(ctx context.Context, region string) (S3API, error)
}

// NewS3Resolver returns a resolver wired to the default AWS SDK config.
func NewS3Resolver() *S3Resolver { return &S3Resolver{} }
+func (r *S3Resolver) Scheme() string { return "s3" } + +// Probe runs HeadObject with ChecksumMode=ENABLED and returns a +// cache key from the strongest available token. The cascade prefers +// content-addressed checksums (cross-URL dedupe) over the opaque ETag +// (URL-scoped): +// +// x-amz-checksum-sha256 +// x-amz-checksum-crc64nvme +// x-amz-checksum-sha1 +// x-amz-checksum-crc32c +// x-amz-checksum-crc32 +// ETag +// +// The ETag stays opaque even for single-part objects: multipart ETag +// `-` is not a content hash. Network or AWS errors surface as +// [cas.ErrNoVersionMetadata]. +func (r *S3Resolver) Probe(ctx context.Context, rawURL string) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("parse S3 URL %s: %w", rawURL, err) + } + + target, err := parseS3URL(u) + if err != nil { + return "", fmt.Errorf("parse S3 URL %s: %w", rawURL, err) + } + + ctx, cancel := context.WithTimeout(ctx, s3ResolverTimeout) + defer cancel() + + client, err := r.client(ctx, target.Region) + if err != nil { + return "", cas.ErrNoVersionMetadata + } + + input := &s3.HeadObjectInput{ + Bucket: aws.String(target.Bucket), + Key: aws.String(target.Key), + ChecksumMode: types.ChecksumModeEnabled, + } + + // The bare S3 getter forwards ?version= as GetObject's VersionId, + // so HeadObject must too. Without this, the probe describes the + // current version while the fetch downloads a different one. + if target.Version != "" { + input.VersionId = aws.String(target.Version) + } + + out, err := client.HeadObject(ctx, input) + if err != nil { + return "", cas.ErrNoVersionMetadata + } + + return pickS3CacheKey(rawURL, out) +} + +// client returns the S3 client for region, using the AWS SDK config +// chain when r.NewClient is unset. 
+func (r *S3Resolver) client(ctx context.Context, region string) (S3API, error) { + if r.NewClient != nil { + return r.NewClient(ctx, region) + } + + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return nil, err + } + + return s3.NewFromConfig(cfg), nil +} + +// pickS3CacheKey walks the checksum cascade and returns the cache key +// for the first match. ErrNoVersionMetadata signals an empty head with +// no checksum and no ETag. +func pickS3CacheKey(rawURL string, head *s3.HeadObjectOutput) (string, error) { + if head == nil { + return "", cas.ErrNoVersionMetadata + } + + if v := strPtr(head.ChecksumSHA256); v != "" { + return cas.ContentKey("sha256", v), nil + } + + if v := strPtr(head.ChecksumCRC64NVME); v != "" { + return cas.ContentKey("crc64nvme", v), nil + } + + if v := strPtr(head.ChecksumSHA1); v != "" { + return cas.ContentKey("sha1", v), nil + } + + if v := strPtr(head.ChecksumCRC32C); v != "" { + return cas.ContentKey("crc32c", v), nil + } + + if v := strPtr(head.ChecksumCRC32); v != "" { + return cas.ContentKey("crc32", v), nil + } + + if etag := strings.TrimSpace(strPtr(head.ETag)); etag != "" { + if normalized := normalizeETag(etag); normalized != "" { + return cas.OpaqueKey("s3", rawURL, normalized), nil + } + } + + return "", cas.ErrNoVersionMetadata +} + +// s3Target is the parsed form of an S3 URL: AWS region, bucket, object +// key, and the optional ?version= selector for versioned objects. +type s3Target struct { + Region string + Bucket string + Key string + Version string +} + +// parseS3URL extracts an [s3Target] from an S3 URL in any of the forms +// go-getter accepts. Returns an error if the URL is unrecognizable. +func parseS3URL(u *url.URL) (s3Target, error) { + version := u.Query().Get("version") + + if strings.Contains(u.Host, "amazonaws.com") { + hostParts := strings.Split(u.Host, ".") + switch len(hostParts) { + case s3HostPartsPathStyle: + // Path-style: .amazonaws.com//. 
+ // hostParts[0] must identify S3 (exactly "s3" or the + // "s3-" legacy regional prefix); otherwise the + // host belongs to a different AWS service (e.g. + // iam.amazonaws.com) and we must not silently parse it + // as path-style S3 with a bogus region. + region, ok := s3RegionFromHostLabel(hostParts[0]) + if !ok { + return s3Target{}, fmt.Errorf("%w: %q", ErrS3UnrecognizedURL, u.String()) + } + + pathParts := strings.SplitN(u.Path, "/", s3URLPathSegments) + if len(pathParts) != s3URLPathSegments { + return s3Target{}, fmt.Errorf("%w: %q", ErrS3UnrecognizedURL, u.String()) + } + + return s3Target{Region: region, Bucket: pathParts[1], Key: pathParts[2], Version: version}, nil + case s3HostPartsVHostStyle: + // hostParts[0] == "s3" is the modern path-style + // (`s3..amazonaws.com`), which the upstream + // go-getter/s3 v2 Getter rejects. Reject at probe time too + // so the failure mode matches the fetcher's rather than + // silently misparsing bucket="s3". + if hostParts[0] == "s3" { + return s3Target{}, fmt.Errorf("%w: %q", ErrS3ModernPathStyleUnsupported, u.String()) + } + + // Legacy virtual-host style: .s3[-].amazonaws.com/. + // hostParts[1] must identify S3 the same way as the + // path-style case, otherwise the host belongs to a + // non-S3 service (e.g. bucket.iam.amazonaws.com). 
+ region, ok := s3RegionFromHostLabel(hostParts[1]) + if !ok { + return s3Target{}, fmt.Errorf("%w: %q", ErrS3UnrecognizedURL, u.String()) + } + + return s3Target{ + Region: region, + Bucket: hostParts[0], + Key: strings.TrimPrefix(u.Path, "/"), + Version: version, + }, nil + } + + return s3Target{}, fmt.Errorf("%w: %q", ErrS3UnrecognizedURL, u.String()) + } + + // S3-compatible service: host//?region= + pathParts := strings.SplitN(u.Path, "/", s3URLPathSegments) + if len(pathParts) != s3URLPathSegments { + return s3Target{}, fmt.Errorf("%w: %q", ErrS3CompatibleUnrecognizedURL, u.String()) + } + + region := u.Query().Get("region") + if region == "" { + region = "us-east-1" + } + + return s3Target{Region: region, Bucket: pathParts[1], Key: pathParts[2], Version: version}, nil +} + +// s3RegionFromHostLabel parses an S3-identifying host label and +// returns the AWS region it encodes. The exact label "s3" is the +// global path-style endpoint and maps to us-east-1. A label of the +// form "s3-" maps to that region. Any other label is rejected +// (ok = false) so non-S3 amazonaws.com hosts (iam, sts, ec2, ...) do +// not silently parse as S3 with a bogus region. +func s3RegionFromHostLabel(label string) (region string, ok bool) { + if label == "s3" { + return "us-east-1", true + } + + if region, ok := strings.CutPrefix(label, "s3-"); ok && region != "" { + return region, true + } + + return "", false +} + +// strPtr safely dereferences a *string. +func strPtr(p *string) string { + if p == nil { + return "" + } + + return *p +} + +// canonicalAWSS3HTTPSURL returns the path-style HTTPS URL for an AWS S3 +// URL in any supported form, preserving the user's query string. ok is +// false when u is not http/https against an amazonaws.com host, or when +// the host matches a form the bare go-getter v2 s3 getter rejects (modern +// virtual-host and modern path-style). 
+// +// The rewrite exists because the bare s3 getter's parseUrl only accepts +// path-style hosts (`s3.amazonaws.com`, `s3-.amazonaws.com`), so +// routing a virtual-host URL to it without canonicalization would set up +// a doomed inner fetch on every cache miss. +func canonicalAWSS3HTTPSURL(u *url.URL) (string, bool) { + if u == nil { + return "", false + } + + scheme := strings.ToLower(u.Scheme) + if scheme != "http" && scheme != "https" { + return "", false + } + + if !strings.Contains(u.Host, "amazonaws.com") { + return "", false + } + + target, err := parseS3URL(u) + if err != nil { + return "", false + } + + canonical := *u + canonical.Scheme = "https" + canonical.Host = s3HostLabelForRegion(target.Region) + ".amazonaws.com" + canonical.Path = "/" + target.Bucket + "/" + target.Key + + return canonical.String(), true +} + +// s3HostLabelForRegion returns the path-style host label for an AWS +// region. us-east-1 maps to the global "s3" label so probes against +// region-unspecified URLs stay on the global endpoint instead of +// silently shifting to us-east-1's regional form. 
// s3HostLabelForRegion returns the path-style host label for an AWS
// region. us-east-1 maps to the global "s3" label, so URLs that parsed
// from the global endpoint stay on it; every other region maps to
// "s3-" followed by the region.
func s3HostLabelForRegion(region string) string {
	const globalLabel = "s3"

	if region != "us-east-1" {
		return globalLabel + "-" + region
	}

	return globalLabel
}
+ tests := []struct { + name string + head *s3.HeadObjectOutput + want string + }{ + { + name: "CRC64NVME only", + head: &s3.HeadObjectOutput{ + ChecksumCRC64NVME: aws.String("crc64nvme-token"), + ETag: aws.String(`"etag-token"`), + }, + want: cas.ContentKey("crc64nvme", "crc64nvme-token"), + }, + { + name: "SHA1 only", + head: &s3.HeadObjectOutput{ + ChecksumSHA1: aws.String("sha1-token"), + ETag: aws.String(`"etag-token"`), + }, + want: cas.ContentKey("sha1", "sha1-token"), + }, + { + name: "CRC32C only", + head: &s3.HeadObjectOutput{ + ChecksumCRC32C: aws.String("crc32c-token"), + ETag: aws.String(`"etag-token"`), + }, + want: cas.ContentKey("crc32c", "crc32c-token"), + }, + { + name: "CRC32 only", + head: &s3.HeadObjectOutput{ + ChecksumCRC32: aws.String("crc32-token"), + ETag: aws.String(`"etag-token"`), + }, + want: cas.ContentKey("crc32", "crc32-token"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := newS3ResolverWith(&fakeS3Head{out: tt.head}) + + got, err := r.Probe(t.Context(), "https://s3-us-east-1.amazonaws.com/bucket/key.tgz") + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestS3Resolver_FallsBackToOpaqueETag(t *testing.T) { + t.Parallel() + + r := newS3ResolverWith(&fakeS3Head{out: &s3.HeadObjectOutput{ + ETag: aws.String(`"etag-token"`), + }}) + + url := "https://s3-us-east-1.amazonaws.com/bucket/key.tgz" + got, err := r.Probe(t.Context(), url) + require.NoError(t, err) + assert.Equal(t, cas.OpaqueKey("s3", url, "etag-token"), got) +} + +func TestS3Resolver_MultipartETagStaysOpaque(t *testing.T) { + t.Parallel() + + r := newS3ResolverWith(&fakeS3Head{out: &s3.HeadObjectOutput{ + ETag: aws.String(`"d41d8cd98f00b204e9800998ecf8427e-3"`), + }}) + + url := "https://s3-us-east-1.amazonaws.com/bucket/key.tgz" + got, err := r.Probe(t.Context(), url) + require.NoError(t, err) + // Multipart ETag is treated opaquely, scoped to URL. 
+ assert.Equal(t, cas.OpaqueKey("s3", url, "d41d8cd98f00b204e9800998ecf8427e-3"), got) +} + +func TestS3Resolver_HeadFailureSurfacesErrNoVersionMetadata(t *testing.T) { + t.Parallel() + + r := newS3ResolverWith(&fakeS3Head{err: errors.New("transient AWS error")}) + + _, err := r.Probe(t.Context(), "https://s3-us-east-1.amazonaws.com/bucket/key.tgz") + require.ErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +// TestS3Resolver_RejectsModernURLForms pins the upstream +// go-getter/s3/v2 limitation: modern virtual-host URLs +// (`.s3..amazonaws.com`) and modern path-style URLs +// (`s3..amazonaws.com`) are rejected by the bare getter's +// parseUrl. The resolver tracks the bare getter's behavior so probe +// success aligns with fetch success. The fake fails the test if +// HeadObject is reached, since rejection has to happen at parse time +// rather than silently downgrade through a doomed HeadObject call. +func TestS3Resolver_RejectsModernURLForms(t *testing.T) { + t.Parallel() + + tests := []struct { + wantErr error + name string + url string + }{ + { + wantErr: getter.ErrS3UnrecognizedURL, + name: "modern virtual-host style with 5 host parts", + url: "https://bucket.s3.us-west-2.amazonaws.com/modules/example.tar.gz", + }, + { + wantErr: getter.ErrS3ModernPathStyleUnsupported, + name: "modern path-style with 4 host parts", + url: "https://s3.us-west-2.amazonaws.com/bucket/modules/example.tar.gz", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := newS3ResolverWith(&assertingS3Head{t: t}) + + _, err := r.Probe(t.Context(), tt.url) + require.ErrorIs(t, err, tt.wantErr, + "parseS3URL must reject %s; upstream go-getter/s3/v2 also rejects it", tt.url) + require.NotErrorIs(t, err, cas.ErrNoVersionMetadata, + "rejection must come from parseS3URL, not from an empty HeadObject result") + }) + } +} + +type fakeS3Head struct { + out *s3.HeadObjectOutput + err error + gotInput *s3.HeadObjectInput +} + +func (f *fakeS3Head) 
HeadObject(_ context.Context, in *s3.HeadObjectInput, _ ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { + f.gotInput = in + + if f.err != nil { + return nil, f.err + } + + return f.out, nil +} + +// TestS3Resolver_VersionedURLForwardsVersionIDToHeadObject pins +// probe/fetch alignment for versioned S3 objects. The upstream +// go-getter/s3/v2 Getter passes ?version= as VersionId on GetObject; +// without the same forwarding on HeadObject the probe describes the +// current version while the fetch downloads a different one, and the +// cache key derived from the probe's checksum no longer matches the +// fetched bytes. +func TestS3Resolver_VersionedURLForwardsVersionIDToHeadObject(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + { + name: "aws path-style with version", + url: "https://s3-us-east-1.amazonaws.com/bucket/key.tgz?version=abc123", + }, + { + name: "aws virtual-host style with version", + url: "https://bucket.s3-us-west-2.amazonaws.com/key.tgz?version=abc123", + }, + { + name: "s3-compatible with version", + url: "https://minio.example.com/bucket/key.tgz?version=abc123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + head := &fakeS3Head{out: &s3.HeadObjectOutput{ + ChecksumSHA256: aws.String("sha256-token"), + }} + r := newS3ResolverWith(head) + + _, err := r.Probe(t.Context(), tt.url) + require.NoError(t, err) + require.NotNil(t, head.gotInput) + require.NotNil(t, head.gotInput.VersionId, + "HeadObject must receive VersionId so the probe targets the same version the fetcher downloads") + assert.Equal(t, "abc123", aws.ToString(head.gotInput.VersionId)) + }) + } +} + +// TestS3Resolver_UnversionedURLOmitsVersionID pins that an absent +// ?version= leaves VersionId nil rather than passing an empty string, +// which would be a malformed HeadObject input. 
+func TestS3Resolver_UnversionedURLOmitsVersionID(t *testing.T) { + t.Parallel() + + head := &fakeS3Head{out: &s3.HeadObjectOutput{ + ChecksumSHA256: aws.String("sha256-token"), + }} + r := newS3ResolverWith(head) + + _, err := r.Probe(t.Context(), "https://s3-us-east-1.amazonaws.com/bucket/key.tgz") + require.NoError(t, err) + require.NotNil(t, head.gotInput) + assert.Nil(t, head.gotInput.VersionId, + "unversioned URL must leave VersionId nil so HeadObject targets the current version") +} + +// assertingS3Head fails the test if HeadObject is reached. Pin +// parse-time rejection: parseS3URL must filter the URL before any +// network call. +type assertingS3Head struct { + t *testing.T +} + +func (f *assertingS3Head) HeadObject(_ context.Context, _ *s3.HeadObjectInput, _ ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { + f.t.Fatalf("HeadObject must not be reached for an unsupported S3 URL form") + return nil, nil +} + +// TestS3Resolver_RejectsNonS3AmazonawsHosts pins that parseS3URL +// only claims hostnames whose first label identifies S3 +// (`s3.amazonaws.com` and `s3-.amazonaws.com`). Any other +// 3-part `*.amazonaws.com` host belongs to a different AWS service +// and must be rejected at parse time, not silently parsed with a +// bogus region. +// +// Before this fix the URL below parsed as path-style with +// region="iam", bucket="bucket", key="key", which caused a wasted +// HeadObject call against a non-S3 endpoint on every probe. 
+func TestS3Resolver_RejectsNonS3AmazonawsHosts(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + }{ + {name: "iam endpoint", url: "https://iam.amazonaws.com/bucket/key.tgz"}, + {name: "sts endpoint", url: "https://sts.amazonaws.com/bucket/key.tgz"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := newS3ResolverWith(&assertingS3Head{t: t}) + + _, err := r.Probe(t.Context(), tt.url) + require.ErrorIs(t, err, getter.ErrS3UnrecognizedURL, + "parseS3URL must reject non-S3 amazonaws.com host %q", tt.url) + require.NotErrorIs(t, err, cas.ErrNoVersionMetadata, + "rejection must come from parseS3URL, not from a HeadObject call against the wrong service") + }) + } +} + +// TestS3Resolver_RejectsS3CompatibleURLWithoutKey pins parse-time +// rejection of an S3-compatible host that names a bucket but no key. +// Failing at parse time keeps the resolver from issuing a doomed +// HeadObject against a non-AWS endpoint. +func TestS3Resolver_RejectsS3CompatibleURLWithoutKey(t *testing.T) { + t.Parallel() + + r := newS3ResolverWith(&assertingS3Head{t: t}) + + _, err := r.Probe(t.Context(), "https://minio.example.com/bucket") + require.ErrorIs(t, err, getter.ErrS3CompatibleUnrecognizedURL) + require.NotErrorIs(t, err, cas.ErrNoVersionMetadata) +} + +func newS3ResolverWith(head getter.S3API) *getter.S3Resolver { + r := getter.NewS3Resolver() + r.NewClient = func(_ context.Context, _ string) (getter.S3API, error) { + return head, nil + } + + return r +} diff --git a/internal/getter/types_test.go b/internal/getter/types_test.go index e9125323a4..5390c657d6 100644 --- a/internal/getter/types_test.go +++ b/internal/getter/types_test.go @@ -7,6 +7,7 @@ import ( "github.com/gruntwork-io/terragrunt/internal/cas" "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/test/helpers/logger" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" @@ -68,7 +69,7 @@ func TestRegistryGetterDetect(t *testing.T) { func TestCASProtocolGetterMode(t *testing.T) { t.Parallel() - g := getter.NewCASProtocolGetter(logger.CreateLogger(), nil) + g := getter.NewCASProtocolGetter(logger.CreateLogger(), nil, cas.Venv{FS: vfs.NewOSFS()}) mode, err := g.Mode(t.Context(), &url.URL{Scheme: "cas"}) require.NoError(t, err) @@ -80,7 +81,7 @@ func TestCASProtocolGetterMode(t *testing.T) { func TestCASProtocolGetterGetFile(t *testing.T) { t.Parallel() - g := getter.NewCASProtocolGetter(logger.CreateLogger(), nil) + g := getter.NewCASProtocolGetter(logger.CreateLogger(), nil, cas.Venv{FS: vfs.NewOSFS()}) err := g.GetFile(t.Context(), &getter.Request{}) require.ErrorIs(t, err, cas.ErrGetFileNotSupported) diff --git a/internal/git/git.go b/internal/git/git.go index fc47b05ec7..8e4e9636ec 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -74,6 +74,12 @@ func NewGitRunner(e vexec.Exec) (*GitRunner, error) { }, nil } +// ExtractRepoName extracts the repository name from a git URL +func ExtractRepoName(repo string) string { + name := filepath.Base(repo) + return strings.TrimSuffix(name, ".git") +} + // WithWorkDir returns a new GitRunner with the specified working directory func (g *GitRunner) WithWorkDir(workDir string) *GitRunner { if g == nil { @@ -143,27 +149,6 @@ func (g *GitRunner) GetRepoRoot(ctx context.Context) (string, error) { return root, nil } -// runRepoRoot performs the uncached `git rev-parse --show-toplevel`. Use -// GetRepoRoot for the memoized entry point. 
-func (g *GitRunner) runRepoRoot(ctx context.Context) (string, error) { - cmd := g.prepareCommand(ctx, "rev-parse", "--show-toplevel") - - var stdout, stderr bytes.Buffer - - cmd.SetStdout(&stdout) - cmd.SetStderr(&stderr) - - if err := cmd.Run(); err != nil { - return "", &WrappedError{ - Op: "git_rev_parse", - Context: stderr.String(), - Err: errors.Join(ErrCommandSpawn, err), - } - } - - return strings.TrimSpace(stdout.String()), nil -} - // LsRemoteResult represents the output of git ls-remote type LsRemoteResult struct { Hash string @@ -470,12 +455,6 @@ func (g *GitRunner) CreateTempDir() (string, func() error, error) { return tempDir, cleanup, nil } -// ExtractRepoName extracts the repository name from a git URL -func ExtractRepoName(repo string) string { - name := filepath.Base(repo) - return strings.TrimSuffix(name, ".git") -} - // LsTreeRecursive runs git ls-tree -r and returns all blobs recursively // This eliminates the need for multiple separate ls-tree calls on subtrees func (g *GitRunner) LsTreeRecursive(ctx context.Context, ref string) (*Tree, error) { @@ -886,6 +865,27 @@ func (g *GitRunner) ObjectFormat(ctx context.Context) (string, error) { return strings.TrimSpace(stdout.String()), nil } +// runRepoRoot performs the uncached `git rev-parse --show-toplevel`. Use +// GetRepoRoot for the memoized entry point. +func (g *GitRunner) runRepoRoot(ctx context.Context) (string, error) { + cmd := g.prepareCommand(ctx, "rev-parse", "--show-toplevel") + + var stdout, stderr bytes.Buffer + + cmd.SetStdout(&stdout) + cmd.SetStderr(&stderr) + + if err := cmd.Run(); err != nil { + return "", &WrappedError{ + Op: "git_rev_parse", + Context: stderr.String(), + Err: errors.Join(ErrCommandSpawn, err), + } + } + + return strings.TrimSpace(stdout.String()), nil +} + func (g *GitRunner) prepareCommand(ctx context.Context, name string, args ...string) vexec.Cmd { cmd := g.exec.Command(ctx, g.GitPath, append([]string{name}, args...)...) 
cmd.SetCancel(func() error { diff --git a/internal/runner/run/download_source.go b/internal/runner/run/download_source.go index 24645d69d0..17983a9983 100644 --- a/internal/runner/run/download_source.go +++ b/internal/runner/run/download_source.go @@ -379,13 +379,19 @@ func tryCASDownload(ctx context.Context, l log.Logger, src *tf.Source, opts *Opt return false, nil } + venv, err := cas.OSVenv() + if err != nil { + l.Warnf("Failed to initialize CAS environment: %v. Falling back to standard getter.", err) + return false, nil + } + cloneOpts := cas.CloneOptions{ Dir: src.DownloadDir, IncludedGitFiles: []string{"HEAD", "config"}, Mutable: mutable, } - casProtocol := getter.NewCASProtocolGetter(l, c) + casProtocol := getter.NewCASProtocolGetter(l, c, venv) casProtocol.Mutable = mutable // CAS-only client: CASProtocolGetter handles cas::sha1: sources @@ -395,7 +401,7 @@ func tryCASDownload(ctx context.Context, l log.Logger, src *tf.Source, opts *Opt client := &getter.Client{ Getters: []getter.Getter{ casProtocol, - getter.NewCASGetter(l, c, &cloneOpts), + getter.NewCASGetter(l, c, venv, &cloneOpts, getter.WithDefaultGenericDispatch()), }, } diff --git a/internal/services/catalog/module/repo.go b/internal/services/catalog/module/repo.go index b09dea0cf1..9ea07c3145 100644 --- a/internal/services/catalog/module/repo.go +++ b/internal/services/catalog/module/repo.go @@ -379,12 +379,17 @@ func (repo *Repo) performClone(ctx context.Context, l log.Logger, fsys vfs.FS, o return err } + venv, err := cas.OSVenv() + if err != nil { + return err + } + cloneOpts := cas.CloneOptions{ Dir: repo.path, IncludedGitFiles: includedGitFiles, } - clientOpts = append(clientOpts, getter.WithCAS(c, &cloneOpts)) + clientOpts = append(clientOpts, getter.WithCAS(c, venv, &cloneOpts)) } client := getter.NewClient(clientOpts...) 
diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go index 20fd56188c..ecc32658ca 100644 --- a/internal/vfs/vfs.go +++ b/internal/vfs/vfs.go @@ -14,7 +14,6 @@ import ( "path/filepath" "runtime" "slices" - "sort" "strings" "sync" "syscall" @@ -330,7 +329,7 @@ func WalkDirParallel(fsys FS, root string, fn fs.WalkDirFunc, opts ...WalkDirPar // but means that for very large directories WalkDir can be inefficient. // WalkDir does not follow symbolic links. // -// Adapted from spf13/afero#571 — replace with afero.WalkDir once merged. +// Adapted from spf13/afero#571; replace with afero.WalkDir once merged. func WalkDir(fsys FS, root string, fn fs.WalkDirFunc) error { info, err := lstatIfPossible(fsys, root) if err != nil { @@ -777,7 +776,7 @@ func (z *ZipDecompressor) extractRegularFile( } // FileInfoDirEntry wraps os.FileInfo to implement fs.DirEntry. -// Adapted from spf13/afero#571 — replace with afero equivalent once merged. +// Adapted from spf13/afero#571; replace with afero equivalent once merged. type FileInfoDirEntry struct { FileInfo os.FileInfo } @@ -1072,7 +1071,7 @@ func ReadDirEntries(fsys FS, dirname string) ([]fs.DirEntry, error) { return nil, err } - sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) + slices.SortFunc(entries, func(a, b fs.DirEntry) int { return strings.Compare(a.Name(), b.Name()) }) return entries, nil } @@ -1088,7 +1087,7 @@ func ReadDirEntries(fsys FS, dirname string) ([]fs.DirEntry, error) { entries[i] = FileInfoDirEntry{FileInfo: info} } - sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) + slices.SortFunc(entries, func(a, b fs.DirEntry) int { return strings.Compare(a.Name(), b.Name()) }) return entries, nil } @@ -1122,10 +1121,10 @@ func sanitizeZipPath(dst, name string) (string, error) { } // ValidateSymlinkTarget reports whether a symbolic link whose path is linkPath -// and whose stored target is target would resolve inside dst. 
Both absolute and -// dot-dot targets that climb above dst are rejected so callers can safely -// materialize symlinks from untrusted sources (zip archives, git trees) without -// letting them escape the destination directory. +// and whose stored target is target would resolve inside dst. Absolute targets +// and dot-dot targets that climb above dst are rejected so callers can safely +// materialize symlinks from untrusted sources (zip archives, fetched tarballs, +// git trees) without letting them escape the destination directory. func ValidateSymlinkTarget(dst, linkPath, target string) error { // Resolve the target relative to the link's directory absTarget := target diff --git a/pkg/config/stack.go b/pkg/config/stack.go index 5ce26f33a5..fcac1cfd36 100644 --- a/pkg/config/stack.go +++ b/pkg/config/stack.go @@ -175,21 +175,9 @@ func GenerateStackFile(ctx context.Context, l log.Logger, pctx *ParsingContext, return err } - var casInstance *cas.CAS - - if casEnabled { - if err := cas.ValidateCASCloneDepth(pctx.CASCloneDepth); err != nil { - return err - } - - c, casErr := cas.New(cas.WithCloneDepth(pctx.CASCloneDepth)) - if casErr != nil { - l.Warnf("Failed to initialize CAS for stack generation: %v. 
CAS features disabled.", casErr) - - casEnabled = false - } else { - casInstance = c - } + cs, err := setupCAS(l, casEnabled, pctx.CASCloneDepth) + if err != nil { + return err } genOpts := generateOpts{ @@ -203,8 +191,9 @@ func GenerateStackFile(ctx context.Context, l log.Logger, pctx *ParsingContext, targetDir: stackTargetDir, autoIncludes: autoIncludes, stackSrcBytes: stackSrcBytes, - casEnabled: casEnabled, - casInstance: casInstance, + casEnabled: cs.Enabled, + casInstance: cs.Instance, + casVenv: cs.Venv, strictControls: pctx.StrictControls, stackDepsEnabled: stackDepsEnabled, } @@ -252,10 +241,51 @@ func validateUpdateSourceWithCAS(stackFile *StackConfig, stackFilePath string, c return nil } +// casSetup is the result of setupCAS: the CAS instance and Venv that +// stack/unit generation threads through every CAS call, plus the +// Enabled flag callers gate CAS features on. Enabled is false either +// because casEnabled started false or because construction failed and +// the warning was already logged. +type casSetup struct { + Instance *cas.CAS + Venv cas.Venv + Enabled bool +} + +// setupCAS prepares the CAS bundle for stack generation. A non-nil +// error is reserved for user-facing misconfiguration (invalid clone +// depth); transient setup failures log a warning and return an +// Enabled=false bundle so the caller falls through to the standard +// getter. +func setupCAS(l log.Logger, enabled bool, cloneDepth int) (casSetup, error) { + if !enabled { + return casSetup{}, nil + } + + if err := cas.ValidateCASCloneDepth(cloneDepth); err != nil { + return casSetup{}, err + } + + c, err := cas.New(cas.WithCloneDepth(cloneDepth)) + if err != nil { + l.Warnf("Failed to initialize CAS for stack generation: %v. CAS features disabled.", err) + return casSetup{}, nil + } + + v, err := cas.OSVenv() + if err != nil { + l.Warnf("Failed to initialize CAS environment: %v. 
CAS features disabled.", err) + return casSetup{}, nil + } + + return casSetup{Instance: c, Venv: v, Enabled: true}, nil +} + // generateOpts holds the subset of options needed for stack/unit generation. type generateOpts struct { autoIncludes map[string]*inthclparse.AutoIncludeResolved casInstance *cas.CAS + casVenv cas.Venv sourceMap map[string]string strictControls strict.Controls rootWorkingDir string @@ -533,7 +563,7 @@ func fetchComponentSource( matOpts = append(matOpts, cas.WithForceCopy()) } - if err := opts.casInstance.MaterializeTree(ctx, l, hash, dest, matOpts...); err != nil { + if err := opts.casInstance.MaterializeTree(ctx, l, opts.casVenv, hash, dest, matOpts...); err != nil { return errors.Errorf("Failed to materialize CAS tree for %s %s: %w", kindStr, cmp.name, err) } @@ -581,7 +611,7 @@ func fetchViaCAS( ) error { resolvedSource := resolveLocalCASSource(l, sourceDir, source) - result, err := opts.casInstance.ProcessStackComponent(ctx, l, resolvedSource, kindStr) + result, err := opts.casInstance.ProcessStackComponent(ctx, l, opts.casVenv, resolvedSource, kindStr) if err != nil { return err } diff --git a/test/cas_archive_test.go b/test/cas_archive_test.go new file mode 100644 index 0000000000..2f27f4ad21 --- /dev/null +++ b/test/cas_archive_test.go @@ -0,0 +1,62 @@ +//go:build docker || aws || gcp + +package test_test + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "slices" + "testing" + + "github.com/stretchr/testify/require" +) + +// makeModuleArchive packs a minimal terragrunt-friendly module into a +// gzipped tarball. Used by the CAS-over-S3 and CAS-over-GCS +// integration tests so go-getter's archive detection recognizes the +// `.tar.gz` URL and extracts on download. The fixed content lets +// every variant test assert against the same materialized layout. 
+func makeModuleArchive(t *testing.T) []byte { + t.Helper() + + files := map[string]string{ + "main.tf": `resource "null_resource" "test" {}`, + "vars.tf": `variable "x" { type = string }`, + "README": "module readme", + "sub/x.tf": "# nested file", + } + + var buf bytes.Buffer + + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + + // Sort entry names so the resulting archive bytes are stable across + // runs. Map iteration order otherwise breaks any caller that pins + // behavior on the archive checksum. + names := make([]string, 0, len(files)) + for name := range files { + names = append(names, name) + } + + slices.Sort(names) + + for _, name := range names { + body := files[name] + + require.NoError(t, tw.WriteHeader(&tar.Header{ + Name: name, + Mode: 0o644, + Size: int64(len(body)), + })) + + _, err := tw.Write([]byte(body)) + require.NoError(t, err) + } + + require.NoError(t, tw.Close()) + require.NoError(t, gz.Close()) + + return buf.Bytes() +} diff --git a/test/helpers/package.go b/test/helpers/package.go index 9aa09d9beb..5e1d8296ae 100644 --- a/test/helpers/package.go +++ b/test/helpers/package.go @@ -81,8 +81,29 @@ const ( caKeyBits = 4096 semverPartsLen = 3 + + // cleanupTimeout caps the runtime of a cleanup helper invoked + // through CleanupContext. Two minutes is generous enough for + // S3/GCS/DynamoDB teardown including object listing and per-object + // deletion, short enough to bound a hung helper. + cleanupTimeout = 2 * time.Minute ) +// CleanupContext returns a context detached from t.Context()'s +// cancellation signal so cleanup helpers run correctly when invoked +// from t.Cleanup callbacks. By the time t.Cleanup fires, t.Context() +// is already canceled and any SDK call against it returns immediately +// with "context canceled", silently leaving cloud resources behind. 
+// +// The returned context inherits values from t.Context() (so SDK +// instrumentation propagates) but does not inherit cancellation; a +// fresh cleanupTimeout bounds the helper. Callers must call the +// cancel function when done. +func CleanupContext(t *testing.T) (context.Context, context.CancelFunc) { + t.Helper() + return context.WithTimeout(context.WithoutCancel(t.Context()), cleanupTimeout) +} + type TerraformOutput struct { Type any `json:"Type"` Value any `json:"Value"` @@ -255,10 +276,13 @@ func DeleteS3Bucket(t *testing.T, awsRegion string, bucketName string, opts ...o client := CreateS3ClientForTest(t, awsRegion, opts...) + ctx, cancel := CleanupContext(t) + defer cancel() + t.Logf("Deleting test s3 bucket %s", bucketName) // First check if bucket exists - _, err := client.HeadBucket(t.Context(), &s3.HeadBucketInput{Bucket: aws.String(bucketName)}) + _, err := client.HeadBucket(ctx, &s3.HeadBucketInput{Bucket: aws.String(bucketName)}) if err != nil { if isAWSResourceNotFoundError(err) { t.Logf("S3 bucket %s does not exist, cleanup already complete", bucketName) @@ -268,9 +292,9 @@ func DeleteS3Bucket(t *testing.T, awsRegion string, bucketName string, opts ...o t.Logf("Error checking if S3 bucket %s exists: %v", bucketName, err) } - cleanS3Bucket(t, client, bucketName) + cleanS3Bucket(t, ctx, client, bucketName) - if _, err := client.DeleteBucket(t.Context(), &s3.DeleteBucketInput{Bucket: aws.String(bucketName)}); err != nil { + if _, err := client.DeleteBucket(ctx, &s3.DeleteBucketInput{Bucket: aws.String(bucketName)}); err != nil { if isAWSResourceNotFoundError(err) { t.Logf("S3 bucket %s was already deleted", bucketName) return nil @@ -283,15 +307,15 @@ func DeleteS3Bucket(t *testing.T, awsRegion string, bucketName string, opts ...o // Sleep for a little bit first to give the bucket a chance to be ready. 
time.Sleep(1 * time.Second) - cleanS3Bucket(t, client, bucketName) + cleanS3Bucket(t, ctx, client, bucketName) - if _, err = client.DeleteBucket(t.Context(), &s3.DeleteBucketInput{Bucket: aws.String(bucketName)}); err != nil { + if _, err = client.DeleteBucket(ctx, &s3.DeleteBucketInput{Bucket: aws.String(bucketName)}); err != nil { if isAWSResourceNotFoundError(err) { t.Logf("S3 bucket %s was already deleted", bucketName) return nil } - t.Logf("Failed to delete S3 bucket %s: %v", bucketName, err) + t.Errorf("Failed to delete S3 bucket %s: %v", bucketName, err) return err } @@ -302,7 +326,7 @@ func DeleteS3Bucket(t *testing.T, awsRegion string, bucketName string, opts ...o return nil } -func cleanS3Bucket(t *testing.T, client *s3.Client, bucketName string) { +func cleanS3Bucket(t *testing.T, ctx context.Context, client *s3.Client, bucketName string) { t.Helper() versionsInput := &s3.ListObjectVersionsInput{ @@ -310,7 +334,7 @@ func cleanS3Bucket(t *testing.T, client *s3.Client, bucketName string) { } for { - out, err := client.ListObjectVersions(t.Context(), versionsInput) + out, err := client.ListObjectVersions(ctx, versionsInput) if err != nil { if isAWSResourceNotFoundError(err) { t.Logf("S3 bucket %s does not exist, skipping cleanup", bucketName) @@ -342,7 +366,7 @@ func cleanS3Bucket(t *testing.T, client *s3.Client, bucketName string) { }, } - _, err := client.DeleteObjects(t.Context(), deleteInput) + _, err := client.DeleteObjects(ctx, deleteInput) if err != nil { if isAWSResourceNotFoundError(err) { t.Logf("S3 bucket %s was deleted during cleanup", bucketName) @@ -370,7 +394,7 @@ func cleanS3Bucket(t *testing.T, client *s3.Client, bucketName string) { }, } - _, err := client.DeleteObjects(t.Context(), deleteInput) + _, err := client.DeleteObjects(ctx, deleteInput) if err != nil { if isAWSResourceNotFoundError(err) { t.Logf("S3 bucket %s was deleted during cleanup", bucketName) diff --git a/test/integration_aws_test.go b/test/integration_aws_test.go index 
34967a3c62..be361b88fb 100644 --- a/test/integration_aws_test.go +++ b/test/integration_aws_test.go @@ -2386,7 +2386,8 @@ func cleanupTableForTest(t *testing.T, tableName string, awsRegion string) { t.Logf("Deleting test DynamoDB table %s", tableName) - ctx := t.Context() + ctx, cancel := helpers.CleanupContext(t) + defer cancel() _, err := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{TableName: aws.String(tableName)}) if err != nil { diff --git a/test/integration_cas_gcs_test.go b/test/integration_cas_gcs_test.go new file mode 100644 index 0000000000..0878c47b69 --- /dev/null +++ b/test/integration_cas_gcs_test.go @@ -0,0 +1,93 @@ +//go:build gcp + +package test_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "cloud.google.com/go/storage" + + tgcas "github.com/gruntwork-io/terragrunt/internal/cas" + tggetter "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + "github.com/stretchr/testify/require" +) + +// TestGcpCASGCSMD5Probe exercises CASGetter end-to-end against a +// real GCS bucket. The MD5 metadata GCS records for single-chunk +// uploads drives the content-addressed cache key; a second +// CASGetter request materializes from CAS without re-downloading. 
+func TestGcpCASGCSMD5Probe(t *testing.T) { + t.Parallel() + + project := os.Getenv("GOOGLE_CLOUD_PROJECT") + if project == "" { + t.Skip("GOOGLE_CLOUD_PROJECT not set; skipping real-GCP test") + } + + bucket := "terragrunt-cas-test-" + strings.ToLower(helpers.UniqueID()) + object := "modules/example.tar.gz" + + createGCSBucket(t, project, terraformRemoteStateGcpRegion, bucket) + t.Cleanup(func() { deleteGCSBucket(t, bucket) }) + + uploadGCSObjectForCAS(t, bucket, object, makeModuleArchive(t)) + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + g := tggetter.NewCASGetter(logger.CreateLogger(), c, v, &tgcas.CloneOptions{}, tggetter.WithDefaultGenericDispatch()) + client := &tggetter.Client{Getters: []tggetter.Getter{g}} + + // The bare v2 gcs.Getter's parseURL only recognizes + // googleapis.com-hosted URLs; gs:// URLs land an empty bucket. + src := "gcs::https://www.googleapis.com/storage/v1/" + bucket + "/" + object + + first := filepath.Join(helpers.TmpDirWOSymlinks(t), "first") + _, err = client.Get(t.Context(), &tggetter.Request{ + Src: src, + Dst: first, + GetMode: tggetter.ModeAny, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(first, "main.tf")) + + second := filepath.Join(helpers.TmpDirWOSymlinks(t), "second") + _, err = client.Get(t.Context(), &tggetter.Request{ + Src: src, + Dst: second, + GetMode: tggetter.ModeAny, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(second, "main.tf")) +} + +// uploadGCSObjectForCAS writes body to bucket/object using the ambient +// application default credentials, the same path the resolver and the +// go-getter GCS fetcher take. 
+func uploadGCSObjectForCAS(t *testing.T, bucket, object string, body []byte) { + t.Helper() + + c, err := storage.NewClient(t.Context()) + require.NoError(t, err) + + t.Cleanup(func() { + if err := c.Close(); err != nil { + t.Logf("close GCS client: %v", err) + } + }) + + w := c.Bucket(bucket).Object(object).NewWriter(t.Context()) + + _, err = w.Write(body) + require.NoError(t, err) + require.NoError(t, w.Close()) +} diff --git a/test/integration_cas_rustfs_test.go b/test/integration_cas_rustfs_test.go new file mode 100644 index 0000000000..c8bda66bba --- /dev/null +++ b/test/integration_cas_rustfs_test.go @@ -0,0 +1,198 @@ +//go:build docker + +package test_test + +import ( + "bytes" + "context" + "net/url" + "path/filepath" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + + tgcas "github.com/gruntwork-io/terragrunt/internal/cas" + tggetter "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + "github.com/stretchr/testify/require" +) + +// TestCAS_S3_RustFS_ProbeAvoidsRedownload verifies the S3 → CAS path +// against an in-Docker RustFS instance: a second CASGetter request for +// the same object skips the download when HeadObject reports the same +// version metadata. 
+func TestCAS_S3_RustFS_ProbeAvoidsRedownload(t *testing.T) { //nolint: paralleltest + endpoint := setupRustFSForCAS(t) + + bucket := "cas-test-" + strings.ToLower(helpers.UniqueID()) + key := "modules/example.tar.gz" + + s3Client := newRustFSClient(t, endpoint) + createRustFSBucket(t, s3Client, bucket) + uploadRustFSObject(t, s3Client, bucket, key, makeModuleArchive(t)) + + // CASGetter wired with the default generic dispatch (S3 fetcher + // + S3 resolver). The S3 client's endpoint and credentials come + // from the standard AWS_* env vars; RustFS speaks plain HTTP and + // path-style URLs, so we set the resolver's NewClient hook to + // match. + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + resolvers := tggetter.DefaultSourceResolvers() + resolvers[tggetter.SchemeS3] = newRustFSS3Resolver(t, endpoint) + + g := tggetter.NewCASGetter(logger.CreateLogger(), c, v, &tgcas.CloneOptions{}, + tggetter.WithGenericFetchers(tggetter.DefaultGenericFetchers()), + tggetter.WithGenericResolvers(resolvers), + ) + + client := &tggetter.Client{Getters: []tggetter.Getter{g}} + + // Embed credentials in the URL query so the bare go-getter + // s3.Getter takes its endpoint-override branch. Without + // aws_access_key_id/secret in the query it assumes amazonaws.com + // and ignores the RustFS endpoint. The resolver's NewClient + // hook configures BaseEndpoint separately for its HeadObject + // probe. 
+ src := rustfsSourceURL(t, endpoint, bucket, key) + + first := filepath.Join(helpers.TmpDirWOSymlinks(t), "first") + _, err = client.Get(t.Context(), &tggetter.Request{ + Src: src, + Dst: first, + GetMode: tggetter.ModeAny, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(first, "main.tf")) + + second := filepath.Join(helpers.TmpDirWOSymlinks(t), "second") + _, err = client.Get(t.Context(), &tggetter.Request{ + Src: src, + Dst: second, + GetMode: tggetter.ModeAny, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(second, "main.tf")) +} + +// setupRustFSForCAS spins up the same RustFS container the existing +// integration tests use and exports the AWS_* env vars the SDK config +// chain reads. +func setupRustFSForCAS(t *testing.T) string { + t.Helper() + + _, addr := helpers.RunContainer(t, + "rustfs/rustfs:1.0.0-beta.2@sha256:6bd08dc511cebe0a4b5c35c266db465c7eb92cf3df4321c69967be66fe4cb395", + 9000, + testcontainers.WithCmd("/data"), + testcontainers.WithWaitStrategy(wait.ForLog("Starting:")), + ) + + t.Setenv("AWS_ACCESS_KEY_ID", "rustfsadmin") + t.Setenv("AWS_SECRET_ACCESS_KEY", "rustfsadmin") + t.Setenv("AWS_DEFAULT_REGION", "us-east-1") + + return addr +} + +func newRustFSClient(t *testing.T, endpoint string) *s3.Client { + t.Helper() + + c, err := newRustFSClientFor(t.Context(), endpoint) + require.NoError(t, err) + + return c +} + +// newRustFSClientFor builds an S3 client wired to the RustFS endpoint +// using ctx for credential / IMDS lookups. Used by the resolver's +// NewClient hook so the caller's context propagates into SDK config +// load, matching how a production resolver would honor request +// cancellation. 
+func newRustFSClientFor(ctx context.Context, endpoint string) (*s3.Client, error) { + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion("us-east-1"), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("rustfsadmin", "rustfsadmin", "")), + ) + if err != nil { + return nil, err + } + + return s3.NewFromConfig(cfg, func(o *s3.Options) { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = true + }), nil +} + +// newRustFSS3Resolver returns an S3Resolver that talks to RustFS via +// the SDK's path-style endpoint, mirroring how a real S3Resolver would +// be configured for an S3-compatible object store. +func newRustFSS3Resolver(t *testing.T, endpoint string) *tggetter.S3Resolver { + t.Helper() + + r := tggetter.NewS3Resolver() + r.NewClient = func(ctx context.Context, _ string) (tggetter.S3API, error) { + return newRustFSClientFor(ctx, endpoint) + } + + return r +} + +func createRustFSBucket(t *testing.T, c *s3.Client, bucket string) { + t.Helper() + + _, err := c.CreateBucket(t.Context(), &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + require.NoError(t, err) +} + +func uploadRustFSObject(t *testing.T, c *s3.Client, bucket, key string, body []byte) { + t.Helper() + + _, err := c.PutObject(t.Context(), &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: bytes.NewReader(body), + ChecksumAlgorithm: types.ChecksumAlgorithmSha256, + }) + require.NoError(t, err) +} + +// rustfsSourceURL builds the `s3::<scheme>://<host>/<bucket>/<key>` +// go-getter source string with RustFS credentials embedded in the +// query. The bare s3 getter parses the access-key query params and +// uses that as a signal to override its endpoint to u.Host, pointing at +// the testcontainer instead of real AWS.
+func rustfsSourceURL(t *testing.T, endpoint, bucket, key string) string { + t.Helper() + + u, err := url.Parse(endpoint) + require.NoError(t, err) + + q := url.Values{} + q.Set("aws_access_key_id", "rustfsadmin") + q.Set("aws_access_key_secret", "rustfsadmin") + + out := url.URL{ + Scheme: u.Scheme, + Host: u.Host, + Path: "/" + bucket + "/" + key, + RawQuery: q.Encode(), + } + + return "s3::" + out.String() +} diff --git a/test/integration_cas_s3_test.go b/test/integration_cas_s3_test.go new file mode 100644 index 0000000000..dcf6d4b439 --- /dev/null +++ b/test/integration_cas_s3_test.go @@ -0,0 +1,95 @@ +//go:build aws + +package test_test + +import ( + "bytes" + "path/filepath" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + + tgcas "github.com/gruntwork-io/terragrunt/internal/cas" + tggetter "github.com/gruntwork-io/terragrunt/internal/getter" + "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/test/helpers/logger" + "github.com/stretchr/testify/require" +) + +// TestAwsCASS3ChecksumProbe exercises CASGetter end-to-end against a +// real S3 bucket. The PutObject sets a SHA-256 checksum so the +// resolver's preferred content-addressed path runs; on a second +// CASGetter request CAS materializes from the local store without +// re-downloading the archive. 
+func TestAwsCASS3ChecksumProbe(t *testing.T) { + t.Parallel() + + region := helpers.TerraformRemoteStateS3Region + bucket := "terragrunt-cas-test-" + strings.ToLower(helpers.UniqueID()) + key := "modules/example.tar.gz" + + s3Client := helpers.CreateS3ClientForTest(t, region) + + _, err := s3Client.CreateBucket(t.Context(), &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + CreateBucketConfiguration: &types.CreateBucketConfiguration{ + LocationConstraint: types.BucketLocationConstraint(region), + }, + }) + require.NoError(t, err) + + t.Cleanup(func() { + if err := helpers.DeleteS3Bucket(t, region, bucket); err != nil { + t.Logf("delete bucket %s: %v", bucket, err) + } + }) + + body := makeModuleArchive(t) + _, err = s3Client.PutObject(t.Context(), &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: bytes.NewReader(body), + ChecksumAlgorithm: types.ChecksumAlgorithmSha256, + }) + require.NoError(t, err) + + storePath := filepath.Join(helpers.TmpDirWOSymlinks(t), "store") + c, err := tgcas.New(tgcas.WithStorePath(storePath)) + require.NoError(t, err) + + v, err := tgcas.OSVenv() + require.NoError(t, err) + + g := tggetter.NewCASGetter(logger.CreateLogger(), c, v, &tgcas.CloneOptions{}, tggetter.WithDefaultGenericDispatch()) + client := &tggetter.Client{Getters: []tggetter.Getter{g}} + + // Legacy regional path-style URL: the bare go-getter s3.Getter's + // parseUrl only handles 3-part hostnames. Modern virtual-host + // URLs (`bucket.s3.region.amazonaws.com`, 5 parts) and modern + // path-style URLs (`s3.region.amazonaws.com`, 4 parts) both fail + // the bare getter's len(hostParts) != 3 check. `s3-region.amazonaws.com` + // is the form both the bare getter and our S3Resolver's parseS3URL + // recognize, so the test URL stays compatible with both. 
+ src := "s3::https://s3-" + region + ".amazonaws.com/" + bucket + "/" + key + + first := filepath.Join(helpers.TmpDirWOSymlinks(t), "first") + _, err = client.Get(t.Context(), &tggetter.Request{ + Src: src, + Dst: first, + GetMode: tggetter.ModeAny, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(first, "main.tf")) + + second := filepath.Join(helpers.TmpDirWOSymlinks(t), "second") + _, err = client.Get(t.Context(), &tggetter.Request{ + Src: src, + Dst: second, + GetMode: tggetter.ModeAny, + }) + require.NoError(t, err) + require.FileExists(t, filepath.Join(second, "main.tf")) +} diff --git a/test/integration_gcp_test.go b/test/integration_gcp_test.go index 47bd71baef..5b522231b8 100644 --- a/test/integration_gcp_test.go +++ b/test/integration_gcp_test.go @@ -598,7 +598,8 @@ func createGCSBucket(t *testing.T, projectID string, location string, bucketName func deleteGCSBucket(t *testing.T, bucketName string) { t.Helper() - ctx := t.Context() + ctx, cancel := helpers.CleanupContext(t) + defer cancel() extGCSCfg := &gcsbackend.ExtendedRemoteStateConfigGCS{} @@ -623,20 +624,32 @@ func deleteGCSBucket(t *testing.T, bucketName string) { break } + // Tests that exercise the "no bootstrap" path never create the bucket, + // so cleanup hitting a missing bucket is expected and not a test failure. 
+ if errors.Is(err, storage.ErrBucketNotExist) { + t.Logf("GCS bucket %s does not exist; skipping cleanup", bucketName) + return + } + if err != nil { - t.Logf("Failed to list objects and versions in GCS bucket %s: %v", bucketName, err) + t.Errorf("Failed to list objects and versions in GCS bucket %s: %v", bucketName, err) return } // purge the object version if err := bucket.Object(objectAttrs.Name).Generation(objectAttrs.Generation).Delete(ctx); err != nil { - t.Logf("Failed to delete GCS bucket object %s: %v", objectAttrs.Name, err) + t.Errorf("Failed to delete GCS bucket object %s: %v", objectAttrs.Name, err) return } } // remote empty bucket if err := bucket.Delete(ctx); err != nil { - t.Fatalf("Failed to delete GCS bucket %s: %v", bucketName, err) + if errors.Is(err, storage.ErrBucketNotExist) { + t.Logf("GCS bucket %s does not exist; skipping cleanup", bucketName) + return + } + + t.Errorf("Failed to delete GCS bucket %s: %v", bucketName, err) } }