Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions alphatrion/artifact/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def push(
repo_name: str,
paths: str | list[str],
version: str = "latest",
):
) -> str:
"""
Push files or all files in a folder to the artifact registry.
You can specify either files or folder, but not both.
Expand All @@ -48,13 +48,16 @@ def push(
raise ValueError("No files to push.")

url = self._url if self._url.endswith("/") else f"{self._url}/"
target = f"{url}{self._project_id}/{repo_name}:{version}"
path = f"{self._project_id}/{repo_name}:{version}"
target = f"{url}{path}"

try:
self._client.push(target, files=files_to_push)
except Exception as e:
raise RuntimeError("Failed to push artifacts") from e

return path

# TODO: should we store it in the metadb instead?
def list_versions(self, repo_name: str) -> list[str]:
url = self._url if self._url.endswith("/") else f"{self._url}/"
Expand Down
18 changes: 12 additions & 6 deletions alphatrion/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from alphatrion.trial.trial import current_trial_id
from alphatrion.utils import time as utime

ARTIFACT_PATH = "artifact_path"


async def log_artifact(
paths: str | list[str],
version: str = "latest",
):
) -> str:
"""
Log artifacts (files) to the artifact registry.

Expand All @@ -20,16 +22,16 @@ async def log_artifact(
:param version: the version (tag) to log the files
"""

return log_artifact_sync(
return log_artifact_in_sync(
paths=paths,
version=version,
)


def log_artifact_sync(
def log_artifact_in_sync(
paths: str | list[str],
version: str = "latest",
):
) -> str:
"""
Log artifacts (files) to the artifact registry (synchronous version).

Expand Down Expand Up @@ -61,7 +63,7 @@ def log_artifact_sync(
if exp is None:
raise RuntimeError("No running experiment found in the current context.")

runtime._artifact.push(repo_name=str(exp.id), paths=paths, version=version)
return runtime._artifact.push(repo_name=str(exp.id), paths=paths, version=version)


# log_params is used to save a set of parameters, which is a dict of key-value pairs.
Expand Down Expand Up @@ -128,10 +130,14 @@ async def log_metrics(metrics: dict[str, float]):

# TODO: we should save the artifact path in the metrics as well.
if should_checkpoint:
await log_artifact(
address = await log_artifact(
paths=trial.config().checkpoint.path,
version=utime.now_2_hash(),
)
runtime._metadb.update_run(
run_id=run_id,
meta={ARTIFACT_PATH: address},
)

if should_early_stop or should_stop_on_target:
trial.done()
1 change: 1 addition & 0 deletions alphatrion/metadata/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def update_run(self, run_id: uuid.UUID, **kwargs) -> None:
run = session.query(Run).filter(Run.uuid == run_id, Run.is_del == 0).first()
if run:
for key, value in kwargs.items():
# TODO: meta update should be merged instead of replaced
setattr(run, key, value)
session.commit()
session.close()
Expand Down
2 changes: 0 additions & 2 deletions alphatrion/utils/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,4 @@ def humanize_time(timestamp: str) -> str:
def now_2_hash() -> str:
timestamp = str(int(datetime.now(UTC).timestamp()))
unique_hash = hashlib.sha1(timestamp.encode()).hexdigest()[:7]

print("Generated hash from timestamp:", timestamp, "->", unique_hash)
return unique_hash
33 changes: 32 additions & 1 deletion tests/integration/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

import alphatrion as alpha
from alphatrion.log.log import ARTIFACT_PATH
from alphatrion.metadata.sql_models import Status
from alphatrion.trial.trial import current_trial_id

Expand Down Expand Up @@ -137,11 +138,18 @@ async def log_metric(metrics: dict):

@pytest.mark.asyncio
async def test_log_metrics_with_save_on_max():
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
project_id = uuid.uuid4()
alpha.init(project_id=project_id, artifact_insecure=True, init_tables=True)

async def log_metric(value: float):
await alpha.log_metrics({"accuracy": value})

def find_unused_version(used_versions, all_versions):
for v in all_versions:
if v not in used_versions:
return v
return None

async with alpha.CraftExperiment.setup(
name="log_metrics_with_save_on_max",
description="Context manager test",
Expand Down Expand Up @@ -171,8 +179,17 @@ async def log_metric(value: float):
run = trial.start_run(lambda: log_metric(0.90))
await run.wait()

# We need this because the returned version is unordered.
used_version = []

versions = exp._runtime._artifact.list_versions(exp.id)
assert len(versions) == 1
run_obj = run._get_obj()
fixed_version = versions[0]
used_version.append(fixed_version)
assert (
run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version
)

# To avoid the same timestamp hash, we wait for 1 second
time.sleep(1)
Expand All @@ -191,13 +208,27 @@ async def log_metric(value: float):
versions = exp._runtime._artifact.list_versions(exp.id)
assert len(versions) == 2

fixed_version = find_unused_version(used_version, versions)
used_version.append(fixed_version)
run_obj = run._get_obj()
assert (
run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version
)

time.sleep(1)

run = trial.start_run(lambda: log_metric(0.98))
await run.wait()

versions = exp._runtime._artifact.list_versions(exp.id)
assert len(versions) == 3
run_obj = run._get_obj()

fixed_version = find_unused_version(used_version, versions)
used_version.append(fixed_version)
assert (
run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version
)

trial.done()

Expand Down
1 change: 1 addition & 0 deletions tests/unit/experiment/test_craft_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ async def test_craft_experiment_with_hierarchy_timeout():
trial = trial._get_obj()
assert trial.status == Status.COMPLETED


@pytest.mark.asyncio
async def test_craft_experiment_with_hierarchy_timeout_2():
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
Expand Down
Loading