Skip to content

Commit

Permalink
add support for stages
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Jul 4, 2024
1 parent b59a683 commit ed4eef0
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 25 deletions.
4 changes: 4 additions & 0 deletions dvc/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def load_from_pipeline(stage, data, typ="outs"):
Output.PARAM_PERSIST,
Output.PARAM_REMOTE,
Output.PARAM_PUSH,
Output.PARAM_PULL,
*ANNOTATION_FIELDS,
],
)
Expand Down Expand Up @@ -875,6 +876,9 @@ def dumpd(self, **kwargs): # noqa: C901, PLR0912
if not self.can_push:
ret[self.PARAM_PUSH] = self.can_push

if not self.pull:
ret[self.PARAM_PULL] = self.pull

if with_files:
obj = self.obj or self.get_obj()
if obj:
Expand Down
1 change: 1 addition & 0 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
"checkpoint": bool,
Output.PARAM_REMOTE: str,
Output.PARAM_PUSH: bool,
Output.PARAM_PULL: bool,
}
}

Expand Down
3 changes: 3 additions & 0 deletions dvc/stage/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PARAM_DESC = Annotation.PARAM_DESC
PARAM_REMOTE = Output.PARAM_REMOTE
PARAM_PUSH = Output.PARAM_PUSH
PARAM_PULL = Output.PARAM_PULL

DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE

Expand All @@ -51,6 +52,8 @@ def _get_flags(out):
yield PARAM_REMOTE, out.remote
if not out.can_push:
yield PARAM_PUSH, False
if not out.pull:
yield PARAM_PULL, False


def _serialize_out(out):
Expand Down
10 changes: 5 additions & 5 deletions tests/func/test_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,11 +746,11 @@ def test_checkout_dir_compat(tmp_dir, dvc):
assert (tmp_dir / "data").read_text() == {"foo": "foo"}


def test_checkout_pull_option(tmp_dir, dvc, copy_script):
tmp_dir.dvc_gen({"explicit_1": "x", "explicit_2": "y", "always": "z"})
for name in ["explicit_1", "explicit_2"]:
with (tmp_dir / f"{name}.dvc").modify() as d:
d["outs"][0]["pull"] = False
def test_checkout_for_files_with_explicit_pull_option_set(tmp_dir, dvc, copy_script):
stages = tmp_dir.dvc_gen({"explicit_1": "x", "explicit_2": "y", "always": "z"})
for stage in stages[:-1]:
stage.outs[0].pull = False
stage.dump()

remove(tmp_dir / "explicit_1")
remove(tmp_dir / "explicit_2")
Expand Down
88 changes: 68 additions & 20 deletions tests/func/test_data_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,28 +680,76 @@ def test_pull_granular_excluding_import_that_cannot_be_pulled(
dvc.pull(imp_stage.addressing)


def test_fetch_pull_option(tmp_dir, dvc, local_remote):
file_config = {"explicit_1": "x", "explicit_2": "y", "always": "z"}
tmp_dir.dvc_gen(file_config)
for name in ["explicit_1", "explicit_2"]:
with (tmp_dir / f"{name}.dvc").modify() as d:
d["outs"][0]["pull"] = False
def test_fetch_for_files_with_explicit_pull_set(tmp_dir, dvc, local_remote):
stages = tmp_dir.dvc_gen({"explicit_1": "x", "explicit_2": "y", "always": "z"})
for stage in stages[:-1]:
stage.outs[0].pull = False
stage.dump()

dvc.push()
dvc.cache.local.clear() # purge cache

oids = {stage.outs[0].def_path: stage.outs[0].hash_info.value for stage in stages}

assert dvc.fetch() == 1
assert set(dvc.cache.local.all()) == {oids["always"]}

assert dvc.fetch("explicit_1") == 1
assert set(dvc.cache.local.all()) == {oids["always"], oids["explicit_1"]}


def test_pull_for_files_with_explicit_pull_set(tmp_dir, dvc, local_remote):
stages = tmp_dir.dvc_gen({"explicit_1": "x", "explicit_2": "y", "always": "z"})
for stage in stages[:-1]:
stage.outs[0].pull = False
stage.dump()

dvc.push()
oids_by_name = {
name: (tmp_dir / f"{name}.dvc").parse()["outs"][0]["md5"]
for name in file_config
dvc.cache.local.clear() # purge cache
remove("explicit_1")
remove("explicit_2")
remove("always")

assert dvc.pull() == {
"added": ["always"],
"deleted": [],
"fetched": 1,
"modified": [],
}
all_oids = set(oids_by_name.values())
assert set(dvc.cache.local.oids_exist(all_oids)) == all_oids

# purge cache
dvc.cache.local.clear()
assert set(dvc.cache.local.oids_exist(all_oids)) == set()
assert dvc.pull("explicit_1") == {
"added": ["explicit_1"],
"deleted": [],
"fetched": 1,
"modified": [],
}

dvc.fetch()
expected = {oids_by_name[name] for name in ["always"]}
assert set(dvc.cache.local.oids_exist(all_oids)) == expected

dvc.fetch("explicit_1")
expected = {oids_by_name[name] for name in ["always", "explicit_1"]}
assert set(dvc.cache.local.oids_exist(all_oids)) == expected
def test_pull_for_stage_outputs_with_explicit_pull_set(tmp_dir, dvc, local_remote):
stage1 = dvc.stage.add(name="always", outs=["always"], cmd="echo always > always")
stage2 = dvc.stage.add(
name="explicit", outs=["explicit"], cmd="echo explicit > explicit"
)
stage2.outs[0].pull = False
stage2.dump()

assert set(dvc.reproduce()) == {stage1, stage2}
dvc.push()

dvc.cache.local.clear() # purge cache
remove("explicit")
remove("always")

assert dvc.pull() == {
"added": ["always"],
"deleted": [],
"fetched": 1,
"modified": [],
}

assert dvc.pull("explicit") == {
"added": ["explicit"],
"deleted": [],
"fetched": 1,
"modified": [],
}

0 comments on commit ed4eef0

Please sign in to comment.