diff --git a/alphatrion/artifact/artifact.py b/alphatrion/artifact/artifact.py
index eaf3b32..34a40af 100644
--- a/alphatrion/artifact/artifact.py
+++ b/alphatrion/artifact/artifact.py
@@ -21,7 +21,7 @@ def push(
         repo_name: str,
         paths: str | list[str],
         version: str = "latest",
-    ):
+    ) -> str:
         """
         Push files or all files in a folder to the artifact registry.
         You can specify either files or folder, but not both.
@@ -48,13 +48,16 @@
             raise ValueError("No files to push.")
 
         url = self._url if self._url.endswith("/") else f"{self._url}/"
-        target = f"{url}{self._project_id}/{repo_name}:{version}"
+        path = f"{self._project_id}/{repo_name}:{version}"
+        target = f"{url}{path}"
 
         try:
             self._client.push(target, files=files_to_push)
         except Exception as e:
             raise RuntimeError("Failed to push artifacts") from e
 
+        return path
+
     # TODO: should we store it in the metadb instead?
     def list_versions(self, repo_name: str) -> list[str]:
         url = self._url if self._url.endswith("/") else f"{self._url}/"
diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py
index 7b0c831..cde584a 100644
--- a/alphatrion/log/log.py
+++ b/alphatrion/log/log.py
@@ -3,11 +3,13 @@
 from alphatrion.trial.trial import current_trial_id
 from alphatrion.utils import time as utime
 
+ARTIFACT_PATH = "artifact_path"
+
 
 async def log_artifact(
     paths: str | list[str],
     version: str = "latest",
-):
+) -> str:
     """
     Log artifacts (files) to the artifact registry.
 
@@ -20,16 +22,16 @@
 
     :param version: the version (tag) to log the files
     """
-    return log_artifact_sync(
+    return log_artifact_in_sync(
         paths=paths,
         version=version,
     )
 
 
-def log_artifact_sync(
+def log_artifact_in_sync(
     paths: str | list[str],
     version: str = "latest",
-):
+) -> str:
     """
     Log artifacts (files) to the artifact registry (synchronous version).
 
@@ -61,7 +63,7 @@
     if exp is None:
         raise RuntimeError("No running experiment found in the current context.")
 
-    runtime._artifact.push(repo_name=str(exp.id), paths=paths, version=version)
+    return runtime._artifact.push(repo_name=str(exp.id), paths=paths, version=version)
 
 
 # log_params is used to save a set of parameters, which is a dict of key-value pairs.
@@ -128,10 +130,14 @@ async def log_metrics(metrics: dict[str, float]):
 
     # TODO: we should save the artifact path in the metrics as well.
     if should_checkpoint:
-        await log_artifact(
+        address = await log_artifact(
             paths=trial.config().checkpoint.path,
             version=utime.now_2_hash(),
         )
+        runtime._metadb.update_run(
+            run_id=run_id,
+            meta={ARTIFACT_PATH: address},
+        )
 
     if should_early_stop or should_stop_on_target:
         trial.done()
diff --git a/alphatrion/metadata/sql.py b/alphatrion/metadata/sql.py
index a07f4ae..ad7d7f4 100644
--- a/alphatrion/metadata/sql.py
+++ b/alphatrion/metadata/sql.py
@@ -338,6 +338,7 @@ def update_run(self, run_id: uuid.UUID, **kwargs) -> None:
         run = session.query(Run).filter(Run.uuid == run_id, Run.is_del == 0).first()
         if run:
             for key, value in kwargs.items():
+                # TODO: meta update should be merged instead of replaced
                 setattr(run, key, value)
             session.commit()
         session.close()
diff --git a/alphatrion/utils/time.py b/alphatrion/utils/time.py
index e0f11c1..90792fb 100644
--- a/alphatrion/utils/time.py
+++ b/alphatrion/utils/time.py
@@ -24,6 +24,4 @@ def humanize_time(timestamp: str) -> str:
 def now_2_hash() -> str:
     timestamp = str(int(datetime.now(UTC).timestamp()))
     unique_hash = hashlib.sha1(timestamp.encode()).hexdigest()[:7]
-
-    print("Generated hash from timestamp:", timestamp, "->", unique_hash)
     return unique_hash
diff --git a/tests/integration/test_log.py b/tests/integration/test_log.py
index 18c0dda..a997b99 100644
--- a/tests/integration/test_log.py
+++ b/tests/integration/test_log.py
@@ -8,6 +8,7 @@
 import pytest
 
 import alphatrion as alpha
+from alphatrion.log.log import ARTIFACT_PATH
 from alphatrion.metadata.sql_models import Status
 from alphatrion.trial.trial import current_trial_id
 
@@ -137,11 +138,18 @@ async def log_metric(metrics: dict):
 
 @pytest.mark.asyncio
 async def test_log_metrics_with_save_on_max():
-    alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
+    project_id = uuid.uuid4()
+    alpha.init(project_id=project_id, artifact_insecure=True, init_tables=True)
 
     async def log_metric(value: float):
         await alpha.log_metrics({"accuracy": value})
 
+    def find_unused_version(used_versions, all_versions):
+        for v in all_versions:
+            if v not in used_versions:
+                return v
+        return None
+
     async with alpha.CraftExperiment.setup(
         name="log_metrics_with_save_on_max",
         description="Context manager test",
@@ -171,8 +179,17 @@ async def log_metric(value: float):
         run = trial.start_run(lambda: log_metric(0.90))
         await run.wait()
 
+        # We need this because the returned versions are unordered.
+        used_version = []
+
         versions = exp._runtime._artifact.list_versions(exp.id)
         assert len(versions) == 1
+        run_obj = run._get_obj()
+        fixed_version = versions[0]
+        used_version.append(fixed_version)
+        assert (
+            run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version
+        )
 
         # To avoid the same timestamp hash, we wait for 1 second
         time.sleep(1)
@@ -191,6 +208,13 @@ async def log_metric(value: float):
 
         versions = exp._runtime._artifact.list_versions(exp.id)
         assert len(versions) == 2
+        fixed_version = find_unused_version(used_version, versions)
+        used_version.append(fixed_version)
+        run_obj = run._get_obj()
+        assert (
+            run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version
+        )
+
         time.sleep(1)
 
         run = trial.start_run(lambda: log_metric(0.98))
@@ -198,6 +222,13 @@
 
         versions = exp._runtime._artifact.list_versions(exp.id)
         assert len(versions) == 3
+        run_obj = run._get_obj()
+
+        fixed_version = find_unused_version(used_version, versions)
+        used_version.append(fixed_version)
+        assert (
+            run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version
+        )
 
         trial.done()
 
diff --git a/tests/unit/experiment/test_craft_exp.py b/tests/unit/experiment/test_craft_exp.py
index 2ef8f79..9a76871 100644
--- a/tests/unit/experiment/test_craft_exp.py
+++ b/tests/unit/experiment/test_craft_exp.py
@@ -286,6 +286,7 @@ async def test_craft_experiment_with_hierarchy_timeout():
     trial = trial._get_obj()
     assert trial.status == Status.COMPLETED
 
+
 @pytest.mark.asyncio
 async def test_craft_experiment_with_hierarchy_timeout_2():
     init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
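
Reviewer sketch, not part of the patch: a minimal example of how a caller could read back the artifact address that log_metrics() now records, assuming a `run` handle obtained from trial.start_run(...) as in tests/integration/test_log.py; the helper name `artifact_address_of` is hypothetical.

# Reviewer sketch, not part of the patch. Assumes a `run` handle from
# trial.start_run(...) as in the integration test; `artifact_address_of`
# is a hypothetical helper name.
from alphatrion.log.log import ARTIFACT_PATH


def artifact_address_of(run) -> str | None:
    """Return the "<project_id>/<repo_name>:<version>" address stored by
    log_metrics() when a checkpoint was pushed, or None if no checkpoint
    was recorded for this run."""
    run_obj = run._get_obj()
    meta = run_obj.meta or {}
    return meta.get(ARTIFACT_PATH)

Note that update_run() currently replaces the run's meta dict wholesale (see the TODO added in sql.py), so after a checkpoint the address may be the only key left in meta.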