Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 54 additions & 30 deletions nemo_run/core/packaging/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import uuid
from dataclasses import dataclass
from pathlib import Path
import re

from invoke.context import Context

Expand Down Expand Up @@ -72,6 +73,38 @@ class GitArchivePackager(Packager):
check_uncommitted_changes: bool = False
check_untracked_files: bool = False

def _concatenate_tar_files(
self, ctx: Context, output_file: str, files_to_concatenate: list[str]
):
"""Concatenate multiple uncompressed tar files into a single tar archive.

The list should include ALL fragments to merge (base + additions).
Creates/overwrites `output_file`.
"""
if not files_to_concatenate:
raise ValueError("files_to_concatenate must not be empty")

# Quote paths for shell safety
quoted_files = [shlex.quote(f) for f in files_to_concatenate]
quoted_output_file = shlex.quote(output_file)

if os.uname().sysname == "Linux":
# Start from the first archive then append the rest, to avoid self-append issues
first_file, *rest_files = quoted_files
ctx.run(f"cp {first_file} {quoted_output_file}")
if rest_files:
ctx.run(f"tar Af {quoted_output_file} {' '.join(rest_files)}")
# Remove all input fragments
ctx.run(f"rm {' '.join(quoted_files)}")
else:
# Extract all fragments and repack once (faster than iterative extract/append)
temp_dir = f"temp_extract_{uuid.uuid4()}"
ctx.run(f"mkdir -p {temp_dir}")
for file in quoted_files:
ctx.run(f"tar xf {file} -C {temp_dir}")
ctx.run(f"tar cf {quoted_output_file} -C {temp_dir} .")
ctx.run(f"rm -r {temp_dir} {' '.join(quoted_files)}")

def package(self, path: Path, job_dir: str, name: str) -> str:
output_file = os.path.join(job_dir, f"{name}.tar.gz")
if os.path.exists(output_file):
Expand Down Expand Up @@ -113,20 +146,11 @@ def package(self, path: Path, job_dir: str, name: str) -> str:
)

ctx = Context()
# we first add git files into an uncompressed archive
# then we add submodule files into that archive
# then we add an extra files from pattern to that archive
# finally we compress it (cannot compress right away, since adding files is not possible)
git_archive_cmd = (
f"git archive --format=tar --output={output_file}.tmp {self.ref}:{git_sub_path}"
)
if os.uname().sysname == "Linux":
tar_submodule_cmd = f"tar Af {output_file}.tmp $sha1.tmp && rm $sha1.tmp"
else:
tar_submodule_cmd = f"cat $sha1.tmp >> {output_file}.tmp && rm $sha1.tmp"

git_submodule_cmd = f"""git submodule foreach --recursive \
'git archive --format=tar --prefix=$sm_path/ --output=$sha1.tmp HEAD && {tar_submodule_cmd}'"""
# Build the base uncompressed archive, then separately generate all additional fragments.
# Finally, concatenate all fragments in one pass for performance and portability.
base_tar_tmp = f"{output_file}.tmp.base"
git_archive_cmd = f"git archive --format=tar --output={shlex.quote(base_tar_tmp)} {self.ref}:{git_sub_path}"
git_submodule_cmd = "git submodule foreach --recursive 'git archive --format=tar --prefix=$sm_path/ --output=$sha1.tmp HEAD'"

with ctx.cd(git_base_path):
ctx.run(git_archive_cmd)
Expand All @@ -143,6 +167,16 @@ def package(self, path: Path, job_dir: str, name: str) -> str:
"include_pattern and include_pattern_relative_path should have the same length"
)

# Collect submodule tar fragments (named as <40-hex-sha1>.tmp) if any
submodule_tmp_files: list[str] = []
if self.include_submodules:
for dirpath, _dirnames, filenames in os.walk(git_base_path):
for filename in filenames:
if re.fullmatch(r"[0-9a-f]{40}\.tmp", filename):
submodule_tmp_files.append(os.path.join(dirpath, filename))

# Generate additional fragments from include patterns and collect their paths
additional_tmp_files: list[str] = []
for include_pattern, include_pattern_relative_path in zip(
self.include_pattern, self.include_pattern_relative_path
):
Expand All @@ -158,26 +192,16 @@ def package(self, path: Path, job_dir: str, name: str) -> str:
include_pattern, include_pattern_relative_path
)
pattern_tar_file_name = os.path.join(git_base_path, pattern_tar_file_name)
include_pattern_cmd = (
f"find {relative_include_pattern} -type f | tar -cf {pattern_tar_file_name} -T -"
)
include_pattern_cmd = f"find {relative_include_pattern} -type f | tar -cf {shlex.quote(pattern_tar_file_name)} -T -"

with ctx.cd(include_pattern_relative_path):
ctx.run(include_pattern_cmd)
additional_tmp_files.append(pattern_tar_file_name)

with ctx.cd(git_base_path):
if os.uname().sysname == "Linux":
# On Linux, directly concatenate tar files
ctx.run(f"tar Af {output_file}.tmp {pattern_tar_file_name}")
ctx.run(f"rm {pattern_tar_file_name}")
else:
# Extract and repack approach for other platforms
temp_dir = f"temp_extract_{pattern_file_id}"
ctx.run(f"mkdir -p {temp_dir}")
ctx.run(f"tar xf {output_file}.tmp -C {temp_dir}")
ctx.run(f"tar xf {pattern_tar_file_name} -C {temp_dir}")
ctx.run(f"tar cf {output_file}.tmp -C {temp_dir} .")
ctx.run(f"rm -rf {temp_dir} {pattern_tar_file_name}")
# Concatenate all fragments in one pass into {output_file}.tmp
fragments_to_merge: list[str] = [base_tar_tmp] + submodule_tmp_files + additional_tmp_files
with ctx.cd(git_base_path):
self._concatenate_tar_files(ctx, f"{output_file}.tmp", fragments_to_merge)

gzip_cmd = f"gzip -c {output_file}.tmp > {output_file}"
rm_cmd = f"rm {output_file}.tmp"
Expand Down
94 changes: 94 additions & 0 deletions test/core/packaging/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import os
import shlex
import subprocess
import tarfile
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch

import invoke
Expand Down Expand Up @@ -418,3 +420,95 @@
),
)
assert len(os.listdir(os.path.join(job_dir, "extracted_output", "submodule"))) == 0


def _make_uncompressed_tar_from_dir(src_dir: Path, tar_path: Path):
# Create an uncompressed tar at tar_path from the contents of src_dir
# with files at the root of the archive
with tarfile.open(tar_path, mode="w") as tf:
for entry in sorted(src_dir.iterdir()):
tf.add(entry, arcname=entry.name)


@patch("nemo_run.core.packaging.git.Context", MockContext)
def test_concatenate_tar_files_non_linux_integration(tmp_path, monkeypatch):
# Force non-Linux path (extract+repack)
monkeypatch.setattr(os, "uname", lambda: SimpleNamespace(sysname="Darwin"))

# Prepare two small tar fragments
dir_a = tmp_path / "a"
dir_b = tmp_path / "b"
dir_a.mkdir()
dir_b.mkdir()
(dir_a / "fileA.txt").write_text("A")
(dir_b / "fileB.txt").write_text("B")

tar_a = tmp_path / "a.tar"
tar_b = tmp_path / "b.tar"
_make_uncompressed_tar_from_dir(dir_a, tar_a)
_make_uncompressed_tar_from_dir(dir_b, tar_b)

out_tar = tmp_path / "out.tar"
packager = GitArchivePackager()
ctx = MockContext()
packager._concatenate_tar_files(ctx, str(out_tar), [str(tar_a), str(tar_b)])

# Inputs removed
assert not tar_a.exists() and not tar_b.exists()

# Output contains both files at root
assert out_tar.exists()
with tarfile.open(out_tar, mode="r") as tf:
names = sorted(m.name for m in tf.getmembers() if m.isfile())
assert names == ["./fileA.txt", "./fileB.txt"]


def test_concatenate_tar_files_linux_emits_expected_commands(monkeypatch, tmp_path):
# Simulate Linux branch; use a dummy Context that records commands instead of executing
monkeypatch.setattr(os, "uname", lambda: SimpleNamespace(sysname="Linux"))

class DummyContext:
def __init__(self):
self.commands: list[str] = []

def run(self, cmd: str, **_kwargs):
self.commands.append(cmd)

# Support ctx.cd(...) context manager API
def cd(self, _path: Path):
class _CD:
def __enter__(self_nonlocal):
return self

def __exit__(self_nonlocal, exc_type, exc, tb):
return False

return _CD()

# Fake inputs (do not need to exist since we don't execute)
tar1 = str(tmp_path / "one.tar")
tar2 = str(tmp_path / "two.tar")
tar3 = str(tmp_path / "three.tar")
out_tar = str(tmp_path / "out.tar")

ctx = DummyContext()
packager = GitArchivePackager()
packager._concatenate_tar_files(ctx, out_tar, [tar1, tar2, tar3])

# Expected sequence: cp first -> tar Af rest -> rm all inputs
assert len(ctx.commands) == 3
assert ctx.commands[0] == f"cp {shlex.quote(tar1)} {shlex.quote(out_tar)}"
assert (
ctx.commands[1] == f"tar Af {shlex.quote(out_tar)} {shlex.quote(tar2)} {shlex.quote(tar3)}"
)
assert ctx.commands[2] == f"rm {shlex.quote(tar1)} {shlex.quote(tar2)} {shlex.quote(tar3)}"


@patch("nemo_run.core.packaging.git.Context", MockContext)
def test_include_pattern_length_mismatch_raises(packager, temp_repo):
# Mismatch between include_pattern and include_pattern_relative_path should raise
packager.include_pattern = ["extra"]
packager.include_pattern_relative_path = ["/tmp", "/also/tmp"]
with tempfile.TemporaryDirectory() as job_dir:
with pytest.raises(ValueError, match="same length"):
packager.package(Path(temp_repo), job_dir, "mismatch")
Loading