Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to require explicit fetch/checkout #10451

Open
This pull request wants to merge 11 commits into the base branch (base: main) from the contributor's branch.
10 changes: 10 additions & 0 deletions dvc/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def loadd_from(stage, d_list):
remote = d.pop(Output.PARAM_REMOTE, None)
annot = {field: d.pop(field, None) for field in ANNOTATION_FIELDS}
files = d.pop(Output.PARAM_FILES, None)
pull = d.pop(Output.PARAM_PULL, True)
push = d.pop(Output.PARAM_PUSH, True)
hash_name = d.pop(Output.PARAM_HASH, None)
fs_config = d.pop(Output.PARAM_FS_CONFIG, None)
Expand All @@ -109,6 +110,7 @@ def loadd_from(stage, d_list):
remote=remote,
**annot,
files=files,
pull=pull,
push=push,
hash_name=hash_name,
fs_config=fs_config,
Expand Down Expand Up @@ -189,6 +191,7 @@ def load_from_pipeline(stage, data, typ="outs"):
Output.PARAM_PERSIST,
Output.PARAM_REMOTE,
Output.PARAM_PUSH,
Output.PARAM_PULL,
*ANNOTATION_FIELDS,
],
)
Expand Down Expand Up @@ -296,6 +299,7 @@ class Output:
PARAM_PLOT_HEADER = "header"
PARAM_PERSIST = "persist"
PARAM_REMOTE = "remote"
PARAM_PULL = "pull"
PARAM_PUSH = "push"
PARAM_CLOUD = "cloud"
PARAM_HASH = "hash"
Expand Down Expand Up @@ -323,6 +327,7 @@ def __init__( # noqa: PLR0913
repo=None,
fs_config=None,
files: Optional[list[dict[str, Any]]] = None,
pull: bool = True,
push: bool = True,
hash_name: Optional[str] = DEFAULT_ALGORITHM,
):
Expand Down Expand Up @@ -387,6 +392,7 @@ def __init__( # noqa: PLR0913
self.metric = False if self.IS_DEPENDENCY else metric
self.plot = False if self.IS_DEPENDENCY else plot
self.persist = persist
self.pull = pull
self.can_push = push

self.fs_path = self._parse_path(self.fs, fs_path)
Expand Down Expand Up @@ -870,6 +876,9 @@ def dumpd(self, **kwargs): # noqa: C901, PLR0912
if not self.can_push:
ret[self.PARAM_PUSH] = self.can_push

if not self.pull:
ret[self.PARAM_PULL] = self.pull

if with_files:
obj = self.obj or self.get_obj()
if obj:
Expand Down Expand Up @@ -1501,6 +1510,7 @@ def _merge_dir_version_meta(self, other: "Output"):
**ANNOTATION_SCHEMA,
Output.PARAM_CACHE: bool,
Output.PARAM_REMOTE: str,
Output.PARAM_PULL: bool,
Output.PARAM_PUSH: bool,
Output.PARAM_FILES: [DIR_FILES_SCHEMA],
Output.PARAM_FS_CONFIG: dict,
Expand Down
26 changes: 17 additions & 9 deletions dvc/repo/checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ def onerror(target, exc):
raise

view = self.index.targets_view(
targets, recursive=recursive, with_deps=with_deps, onerror=onerror
targets,
recursive=recursive,
with_deps=with_deps,
onerror=onerror,
)

with ui.progress(unit="entry", desc="Building workspace index", leave=True) as pb:
Expand All @@ -145,16 +148,16 @@ def onerror(target, exc):
_check_can_delete(diff.files_delete, new, self.root_dir, self.fs)

failed = set()
out_paths = [out.fs_path for out in view.outs if out.use_cache and out.is_in_repo]
outs = [out for out in view.outs if out.use_cache and out.is_in_repo]

def checkout_onerror(src_path, dest_path, _exc):
logger.debug(
"failed to create '%s' from '%s'", dest_path, src_path, exc_info=True
)

for out_path in out_paths:
if self.fs.isin_or_eq(dest_path, out_path):
failed.add(out_path)
for out in outs:
if self.fs.isin_or_eq(dest_path, out.fs_path):
failed.add(out)

with ui.progress(unit="file", desc="Applying changes", leave=True) as pb:
apply(
Expand All @@ -171,16 +174,21 @@ def checkout_onerror(src_path, dest_path, _exc):
out_changes = _build_out_changes(view, diff.changes)

typ_map = {ADD: "added", DELETE: "deleted", MODIFY: "modified"}
failed_paths = {out.fs_path for out in failed}
for key, typ in out_changes.items():
out_path = self.fs.join(self.root_dir, *key)

if out_path in failed:
if out_path in failed_paths:
self.fs.remove(out_path, recursive=True)
else:
self.state.save_link(out_path, self.fs)
stats[typ_map[typ]].append(_fspath_dir(out_path))

if failed and not allow_missing:
raise CheckoutError([relpath(out_path) for out_path in failed], stats)

unexpected_failure = [
out.fs_path for out in failed if out.pull or (targets and any(targets))
]
if unexpected_failure and not allow_missing:
raise CheckoutError(
[relpath(out_path) for out_path in unexpected_failure], stats
)
return stats
2 changes: 2 additions & 0 deletions dvc/repo/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def stage_filter(stage: "Stage") -> bool:
def outs_filter(out: "Output") -> bool:
    # Decide whether a given output takes part in this transfer.
    # NOTE(review): this is a closure — `push`, `targets`, and `remote` come
    # from the enclosing scope (not visible in this diff hunk); confirm their
    # semantics against the surrounding function.
    # When pushing, honor the per-out `push: false` flag.
    if push and not out.can_push:
        return False
    # When fetching (not pushing), skip outs flagged `pull: false` — unless
    # the user explicitly named targets, which overrides the flag.
    if not push and not out.pull and not (targets and any(targets)):
        return False
    # Respect per-out remote pinning: transfer only via the matching remote.
    return not (remote and out.remote and remote != out.remote)

for rev in repo.brancher(
Expand Down
1 change: 1 addition & 0 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
"checkpoint": bool,
Output.PARAM_REMOTE: str,
Output.PARAM_PUSH: bool,
Output.PARAM_PULL: bool,
}
}

Expand Down
3 changes: 3 additions & 0 deletions dvc/stage/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PARAM_DESC = Annotation.PARAM_DESC
PARAM_REMOTE = Output.PARAM_REMOTE
PARAM_PUSH = Output.PARAM_PUSH
PARAM_PULL = Output.PARAM_PULL

DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE

Expand All @@ -51,6 +52,8 @@ def _get_flags(out):
yield PARAM_REMOTE, out.remote
if not out.can_push:
yield PARAM_PUSH, False
if not out.pull:
yield PARAM_PULL, False


def _serialize_out(out):
Expand Down
26 changes: 26 additions & 0 deletions tests/func/test_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,3 +744,29 @@ def test_checkout_dir_compat(tmp_dir, dvc):
remove("data")
dvc.checkout()
assert (tmp_dir / "data").read_text() == {"foo": "foo"}


def test_checkout_for_files_with_explicit_pull_option_set(tmp_dir, dvc):
    """Checkout tolerates missing cache entries for ``pull: false`` outs.

    An out flagged ``pull: false`` whose cache entry is absent must not fail
    the checkout — unless the user targets it explicitly, in which case the
    usual ``CheckoutError`` is raised.

    Note: the unused ``copy_script`` fixture was removed — nothing in this
    test runs a script.
    """
    stages = tmp_dir.dvc_gen({"explicit_1": "x", "explicit_2": "y", "always": "z"})
    # Flag every stage except the last ("always") as explicit-pull only.
    for stage in stages[:-1]:
        stage.outs[0].pull = False
        stage.dump()

    remove(tmp_dir / "explicit_1")
    remove(tmp_dir / "explicit_2")
    remove(tmp_dir / "always")

    # Drop explicit_2's cache entry: a missing pull=False out must not error.
    explicit2_oid = (tmp_dir / "explicit_2.dvc").parse()["outs"][0]["md5"]
    dvc.cache.local.delete(explicit2_oid)

    dvc.checkout(force=True)
    # pull=False, but present in cache -> restored
    assert (tmp_dir / "explicit_1").read_text() == "x"
    # pull=False, not in cache -> silently skipped
    assert not (tmp_dir / "explicit_2").exists()
    # pull=True -> restored as usual
    assert (tmp_dir / "always").read_text() == "z"

    # Explicitly targeting the missing pull=False out still fails.
    with pytest.raises(CheckoutError):
        dvc.checkout(targets="explicit_2.dvc", force=True)
75 changes: 75 additions & 0 deletions tests/func/test_data_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,3 +678,78 @@ def test_pull_granular_excluding_import_that_cannot_be_pulled(
dvc.pull()
with pytest.raises(CloneError, match="SCM error"):
dvc.pull(imp_stage.addressing)


def test_fetch_for_files_with_explicit_pull_set(tmp_dir, dvc, local_remote):
    """Fetch skips outs flagged ``pull: false`` unless explicitly targeted."""
    created = tmp_dir.dvc_gen({"explicit_1": "x", "explicit_2": "y", "always": "z"})
    # Mark every stage except the last one ("always") as explicit-pull only.
    for skipped in created[:-1]:
        skipped.outs[0].pull = False
        skipped.dump()

    dvc.push()
    dvc.cache.local.clear()  # purge cache

    oid_by_path = {
        stage.outs[0].def_path: stage.outs[0].hash_info.value for stage in created
    }

    # A plain fetch transfers only the pull=True out.
    assert dvc.fetch() == 1
    assert set(dvc.cache.local.all()) == {oid_by_path["always"]}

    # Naming a target overrides its pull=False flag.
    assert dvc.fetch("explicit_1") == 1
    assert set(dvc.cache.local.all()) == {
        oid_by_path["always"],
        oid_by_path["explicit_1"],
    }


def test_pull_for_files_with_explicit_pull_set(tmp_dir, dvc, local_remote):
    """Pull skips ``pull: false`` outs unless they are named explicitly."""
    created = tmp_dir.dvc_gen({"explicit_1": "x", "explicit_2": "y", "always": "z"})
    # Every stage but the last ("always") is flagged explicit-pull only.
    for skipped in created[:-1]:
        skipped.outs[0].pull = False
        skipped.dump()

    dvc.push()
    dvc.cache.local.clear()  # purge cache
    for workspace_file in ("explicit_1", "explicit_2", "always"):
        remove(workspace_file)

    # Default pull restores only the pull=True out.
    expected_default = {
        "added": ["always"],
        "deleted": [],
        "fetched": 1,
        "modified": [],
    }
    assert dvc.pull() == expected_default

    # An explicitly named target is pulled despite its pull=False flag.
    expected_targeted = {
        "added": ["explicit_1"],
        "deleted": [],
        "fetched": 1,
        "modified": [],
    }
    assert dvc.pull("explicit_1") == expected_targeted


def test_pull_for_stage_outputs_with_explicit_pull_set(tmp_dir, dvc, local_remote):
    """Pipeline outs with ``pull: false`` are pulled only when targeted."""
    always_stage = dvc.stage.add(
        name="always", outs=["always"], cmd="echo always > always"
    )
    explicit_stage = dvc.stage.add(
        name="explicit", outs=["explicit"], cmd="echo explicit > explicit"
    )
    # Flag the second stage's out as explicit-pull only.
    explicit_stage.outs[0].pull = False
    explicit_stage.dump()

    assert set(dvc.reproduce()) == {always_stage, explicit_stage}
    dvc.push()

    dvc.cache.local.clear()  # purge cache
    for workspace_file in ("explicit", "always"):
        remove(workspace_file)

    # Default pull restores only the pull=True out.
    assert dvc.pull() == {
        "added": ["always"],
        "deleted": [],
        "fetched": 1,
        "modified": [],
    }

    # Targeting the flagged out pulls it anyway.
    assert dvc.pull("explicit") == {
        "added": ["explicit"],
        "deleted": [],
        "fetched": 1,
        "modified": [],
    }
Loading