diff --git a/Makefile b/Makefile index 37e7b03..ff89b74 100644 --- a/Makefile +++ b/Makefile @@ -55,4 +55,4 @@ seed-cleanup: build-dashboard: rm -rf alphatrion/static/* - cd dashboard && npm install && npm run build \ No newline at end of file + cd dashboard && npm install && npm run build diff --git a/alphatrion/__init__.py b/alphatrion/__init__.py index 597fa4a..9320a36 100644 --- a/alphatrion/__init__.py +++ b/alphatrion/__init__.py @@ -1,18 +1,23 @@ -from alphatrion.experiment.craft_exp import CraftExperiment +from alphatrion.experiment.experiment import ( + CheckpointConfig, + Experiment, + ExperimentConfig, + MonitorMode, +) from alphatrion.log.log import log_artifact, log_metrics, log_params from alphatrion.metadata.sql_models import Status +from alphatrion.project.project import Project from alphatrion.runtime.runtime import init from alphatrion.tracing.tracing import task, workflow -from alphatrion.trial.trial import CheckpointConfig, MonitorMode, Trial, TrialConfig __all__ = [ "init", "log_artifact", "log_params", "log_metrics", - "CraftExperiment", - "Trial", - "TrialConfig", + "Project", + "Experiment", + "ExperimentConfig", "CheckpointConfig", "MonitorMode", "Status", diff --git a/alphatrion/artifact/artifact.py b/alphatrion/artifact/artifact.py index 34a40af..1964a41 100644 --- a/alphatrion/artifact/artifact.py +++ b/alphatrion/artifact/artifact.py @@ -8,8 +8,8 @@ class Artifact: - def __init__(self, project_id: str, insecure: bool = False): - self._project_id = project_id + def __init__(self, team_id: str, insecure: bool = False): + self._team_id = team_id self._url = os.environ.get(consts.ARTIFACT_REGISTRY_URL) self._url = self._url.replace("https://", "").replace("http://", "") self._client = oras.client.OrasClient( @@ -48,7 +48,7 @@ def push( raise ValueError("No files to push.") url = self._url if self._url.endswith("/") else f"{self._url}/" - path = f"{self._project_id}/{repo_name}:{version}" + path = f"{self._team_id}/{repo_name}:{version}" target = f"{url}{path}" try: @@ -61,7 +61,7 @@ def push( # TODO: should we store it in the metadb instead? def list_versions(self, repo_name: str) -> list[str]: url = self._url if self._url.endswith("/") else f"{self._url}/" - target = f"{url}{self._project_id}/{repo_name}" + target = f"{url}{self._team_id}/{repo_name}" try: tags = self._client.get_tags(target) return tags @@ -70,7 +70,7 @@ def list_versions(self, repo_name: str) -> list[str]: def delete(self, repo_name: str, versions: str | list[str]): url = self._url if self._url.endswith("/") else f"{self._url}/" - target = f"{url}{self._project_id}/{repo_name}" + target = f"{url}{self._team_id}/{repo_name}" try: self._client.delete_tags(target, tags=versions) diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py deleted file mode 100644 index db1aa2b..0000000 --- a/alphatrion/experiment/base.py +++ /dev/null @@ -1,134 +0,0 @@ -import uuid -from abc import ABC, abstractmethod -from dataclasses import dataclass - -from pydantic import BaseModel, Field - -from alphatrion.runtime.runtime import global_runtime -from alphatrion.trial import trial -from alphatrion.utils import context - - -class ExperimentConfig(BaseModel): - """Configuration for a Experiment.""" - - max_execution_seconds: int = Field( - default=-1, - description="Maximum execution seconds for the experiment. \ - Once exceeded, the experiment and all its trials will be cancelled. \ - Default is -1 (no limit).", - ) - - -@dataclass -class Experiment(ABC): - """ - Base Experiment class. 
One instance one experiment, multiple trials. - """ - - __slots__ = ("_runtime", "_id", "_trials", "_config") - - @classmethod - @abstractmethod - def setup( - cls, - name: str, - description: str | None = None, - meta: dict | None = None, - config: ExperimentConfig | None = None, - ) -> "Experiment": - """Return a new experiment.""" - ... - - def __init__(self, config: ExperimentConfig | None = None): - self._runtime = global_runtime() - # All trials in this experiment, key is trial_id, value is Trial instance. - self._trials: dict[int, trial.Trial] = {} - self._config = config or ExperimentConfig() - self._runtime.current_exp = self - - self._context = context.Context( - cancel_func=self._stop, - timeout=self._config.max_execution_seconds - if self._config.max_execution_seconds > 0 - else None, - ) - - @property - def id(self): - return self._id - - def get_trial(self, id: int) -> trial.Trial | None: - return self._trials.get(id) - - async def __aenter__(self): - if self._id is None: - raise RuntimeError("Experiment is not set. Did you call start()?") - - exp = self._get() - if exp is None: - raise RuntimeError(f"Experiment {self._id} not found in the database.") - - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - self.done() - - # done() is safe to call multiple times. - def done(self): - self._cancel() - - def _cancel(self): - return self._context.cancel() - - def _stop(self): - for t in list(self._trials.values()): - t.done() - self._trials = dict() - # Set to None at the end of the experiment because - # it will be used in trial.done(). - self._runtime.current_exp = None - - def register_trial(self, id: uuid.UUID, instance: trial.Trial): - self._trials[id] = instance - - def unregister_trial(self, id: uuid.UUID): - self._trials.pop(id, None) - - def _create( - self, - name: str, - description: str | None = None, - meta: dict | None = None, - ): - """ - :param name: the name of the experiment. - :param description: the description of the experiment - :param meta: the metadata of the experiment - - :return: the experiment ID - """ - - self._id = self._runtime._metadb.create_exp( - name=name, - description=description, - project_id=self._runtime._project_id, - meta=meta, - ) - return self._id - - def _get(self): - return self._runtime._metadb.get_exp(exp_id=self._id) - - def _get_by_name(self, name: str, project_id: str): - return self._runtime._metadb.get_exp_by_name(name=name, project_id=project_id) - - def delete(self): - exp = self._get() - if exp is None: - return - - self._runtime._metadb.delete_exp(exp_id=self._exp.id) - # TODO: Should we make this optional as a parameter? - tags = self._runtime._artifact.list_versions(repo_name=str(self._exp.id)) - self._runtime._artifact.delete(experiment_name=exp.name, versions=tags) diff --git a/alphatrion/experiment/craft_exp.py b/alphatrion/experiment/craft_exp.py deleted file mode 100644 index 4bc4fa0..0000000 --- a/alphatrion/experiment/craft_exp.py +++ /dev/null @@ -1,69 +0,0 @@ -from alphatrion.experiment.base import Experiment, ExperimentConfig -from alphatrion.trial.trial import Trial, TrialConfig - - -class CraftExperiment(Experiment): - """ - Craft experiment implementation. - - This experiment class offers methods to manage the experiment lifecycle flexibly. - Opposite to other experiment classes, you need to call all these methods yourself. 
- """ - - def __init__(self, config: ExperimentConfig | None = None): - super().__init__(config=config) - - @classmethod - def setup( - cls, - name: str, - description: str | None = None, - meta: dict | None = None, - config: ExperimentConfig | None = None, - ) -> "CraftExperiment": - """ - Setup the experiment. If the name already exists in the same project, - it will refer to the existing experiment instead of creating a new one. - """ - - exp = CraftExperiment(config) - exp_obj = exp._get_by_name(name=name, project_id=exp._runtime._project_id) - - # If experiment with the same name exists in the project, use it. - if exp_obj: - exp._id = exp_obj.uuid - else: - exp._create( - name=name, - description=description, - meta=meta, - ) - - return exp - - def start_trial( - self, - name: str, - description: str | None = None, - meta: dict | None = None, - params: dict | None = None, - config: TrialConfig | None = None, - ) -> Trial: - """ - start_trial starts a new trial in this experiment. - You need to call trial.cancel() to stop the trial for proper cleanup, - unless it's a timeout trial. - Or you can use 'async with exp.start_trial(...) as trial', which will - automatically stop the trial at the end of the context. - - :params description: the description of the trial - :params meta: the metadata of the trial - :params config: the configuration of the trial - - :return: the Trial instance - """ - - trial = Trial(exp_id=self._id, config=config) - trial._start(name=name, description=description, meta=meta, params=params) - self.register_trial(id=trial.id, instance=trial) - return trial diff --git a/alphatrion/trial/trial.py b/alphatrion/experiment/experiment.py similarity index 79% rename from alphatrion/trial/trial.py rename to alphatrion/experiment/experiment.py index ab9c92f..b9b7e15 100644 --- a/alphatrion/trial/trial.py +++ b/alphatrion/experiment/experiment.py @@ -12,7 +12,7 @@ from alphatrion.utils import context # Used in log/log.py to log params/metrics -current_trial_id = contextvars.ContextVar("current_trial_id", default=None) +current_exp_id = contextvars.ContextVar("current_exp_id", default=None) class CheckpointConfig(BaseModel): @@ -56,24 +56,24 @@ class MonitorMode(enum.Enum): MIN = "min" -class TrialConfig(BaseModel): - """Configuration for a Trial.""" +class ExperimentConfig(BaseModel): + """Configuration for a Experiment.""" max_execution_seconds: int = Field( default=-1, - description="Maximum execution seconds for the trial. \ - Trial timeout will override experiment timeout if both are set. \ + description="Maximum execution seconds for the Experiment. \ + Experiment timeout will override project timeout if both are set. \ Default is -1 (no limit).", ) early_stopping_runs: int = Field( default=-1, description="Number of runs with no improvement \ - after which the trial will be stopped. Default is -1 (no early stopping). \ + after which the Experiment will be stopped. Default is -1 (no early stopping). \ Count each time when calling log_metrics with the monitored metric.", ) - max_runs_per_trial: int = Field( + max_runs_per_experiment: int = Field( default=-1, - description="Maximum number of runs for each trial. \ + description="Maximum number of runs for each Experiment. \ Default is -1 (no limit). 
Count by the finished runs.", ) monitor_metric: str | None = Field( @@ -90,11 +90,11 @@ class TrialConfig(BaseModel): ) target_metric_value: float | None = Field( default=None, - description="If specified, the trial will stop when \ + description="If specified, the Experiment will stop when \ the monitored metric reaches this target value. \ - If monitor_mode is 'max', the trial will stop when \ + If monitor_mode is 'max', the Experiment will stop when \ the metric >= target_metric_value. If monitor_mode is 'min', \ - the trial will stop when the metric <= target_metric_value. \ + the Experiment will stop when the metric <= target_metric_value. \ Default is None (no target).", ) checkpoint: CheckpointConfig = Field( @@ -122,15 +122,15 @@ def metric_must_be_valid(self): return self -class Trial: +class Experiment: __slots__ = ( "_id", - "_exp_id", + "_proj_id", "_config", "_runtime", "_context", "_token", - # _meta stores the runtime meta information of the trial. + # _meta stores the runtime meta information of the experiment. # * best_metrics: dict of best metric values, used for checkpointing and # early stopping. When the workload(e.g. Pod) restarts, the meta info # will be lost and start from scratch. Then once some features like @@ -142,15 +142,15 @@ class Trial: "_runs", # Only work when early_stopping_runs > 0 "_early_stopping_counter", - # Only work when max_runs_per_trial > 0 + # Only work when max_runs_per_experiment > 0 "_total_runs_counter", - # Whether the trial is ended with error. + # Whether the Experiment is ended with error. "_err", ) - def __init__(self, exp_id: int, config: TrialConfig | None = None): - self._exp_id = exp_id - self._config = config or TrialConfig() + def __init__(self, proj_id: uuid.UUID, config: ExperimentConfig | None = None): + self._proj_id = proj_id + self._config = config or ExperimentConfig() self._runtime = global_runtime() self._construct_meta() self._runs = dict[uuid.UUID, Run]() @@ -164,7 +164,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): self.done() if self._token: - current_trial_id.reset(self._token) + current_exp_id.reset(self._token) @property def id(self) -> uuid.UUID: @@ -173,7 +173,8 @@ def id(self) -> uuid.UUID: def _construct_meta(self): self._meta = dict() - # TODO: if restart from existing trial, load the best_metrics from database. + # TODO: if restart from existing experiment, + # load the best_metrics from database. if self._config.monitor_mode == MonitorMode.MAX: self._meta["best_metrics"] = {self._config.monitor_metric: float("-inf")} elif self._config.monitor_mode == MonitorMode.MIN: @@ -181,7 +182,7 @@ def _construct_meta(self): else: raise ValueError(f"Invalid monitor_mode: {self._config.monitor_mode}") - def config(self) -> TrialConfig: + def config(self) -> ExperimentConfig: return self._config def should_checkpoint_on_best(self, metric_key: str, metric_value: float) -> bool: @@ -275,15 +276,16 @@ def _timeout(self) -> int | None: # Make sure you have termination condition, either by timeout or by calling cancel() # Before we have logic like once all the tasks are done, we'll call the cancel() - # automatically, however, this is unpredictable because some tasks may be waiting - # for external events, so we leave it to the user to decide when to stop the trial. + # automatically, however, this is unpredictable because some tasks may wait for + # external events, so we leave it to the user to decide when to stop the experiment. 
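    # Illustrative sketch (the config value below is an assumption): an Experiment
    # created with max_execution_seconds set will unblock wait() once the timeout
    # fires; otherwise the caller has to end it explicitly, e.g.
    #
    #   exp = proj.start_experiment(
    #       name="demo", config=ExperimentConfig(max_execution_seconds=3600)
    #   )
    #   await exp.wait()  # returns after the timeout fires or done() is called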
async def wait(self): await self._context.wait() def is_done(self) -> bool: return self._context.cancelled() - # If the name is same in the same experiment, it will refer to the existing trial. + # If the name is same in the same experiment, + # it will refer to the existing experiment. def _start( self, name: str, @@ -291,16 +293,16 @@ def _start( meta: dict | None = None, params: dict | None = None, ): - trial_obj = self._runtime._metadb.get_trial_by_name( - trial_name=name, experiment_id=self._exp_id + exp_obj = self._runtime._metadb.get_exp_by_name( + name=name, project_id=self._proj_id ) - # FIXME: what if the existing trial is completed, will lead to confusion? - if trial_obj: - self._id = trial_obj.uuid + # FIXME: what if the existing Experiment is completed, will lead to confusion? + if exp_obj: + self._id = exp_obj.uuid else: - self._id = self._runtime._metadb.create_trial( - project_id=self._runtime._project_id, - experiment_id=self._exp_id, + self._id = self._runtime._metadb.create_experiment( + team_id=self._runtime._team_id, + project_id=self._proj_id, name=name, description=description, meta=meta, @@ -313,13 +315,13 @@ def _start( timeout=self._timeout(), ) - # We don't reset the trial id context var here, because - # each trial runs in its own context. - self._token = current_trial_id.set(self._id) + # We don't reset the Experiment id context var here, because + # each experiment runs in its own context. + self._token = current_exp_id.set(self._id) # done function should be called manually as a pair of start - # FIXME: watch for system signals to cancel the trial gracefully, - # or it could lead to trial not being marked as completed. + # FIXME: watch for system signals to cancel the Experiment gracefully, + # or it could lead to experiment not being marked as completed. def done(self): self._cancel() @@ -331,35 +333,35 @@ def _cancel(self): self._context.cancel() def _stop(self): - trial = self._runtime._metadb.get_trial(trial_id=self._id) - if trial is not None and trial.status not in FINISHED_STATUS: + exp = self._runtime._metadb.get_experiment(experiment_id=self._id) + if exp is not None and exp.status not in FINISHED_STATUS: duration = ( - datetime.now(UTC) - trial.created_at.replace(tzinfo=UTC) + datetime.now(UTC) - exp.created_at.replace(tzinfo=UTC) ).total_seconds() status = Status.COMPLETED if self._err: status = Status.FAILED - self._runtime._metadb.update_trial( - trial_id=self._id, status=status, duration=duration + self._runtime._metadb.update_experiment( + experiment_id=self._id, status=status, duration=duration ) - self._runtime.current_exp.unregister_trial(self._id) + self._runtime.current_proj.unregister_exp(self._id) for run in self._runs.values(): run.cancel() self._runs.clear() def _get_obj(self): - return self._runtime.metadb.get_trial(trial_id=self._id) + return self._runtime._metadb.get_experiment(experiment_id=self._id) def start_run(self, call_func: callable) -> Run: - """Start a new run for the trial. + """Start a new run for the Experiment. :param call_func: a callable function that returns a coroutine. It must be a async and lambda function. 
:return: the Run instance.""" - run = Run(trial_id=self._id) + run = Run(exp_id=self.id) run.start(call_func) self._runs[run.id] = run @@ -376,7 +378,7 @@ def _post_run(self, run: Run): run.done() if ( - self._config.max_runs_per_trial > 0 - and self._total_runs_counter >= self._config.max_runs_per_trial + self._config.max_runs_per_experiment > 0 + and self._total_runs_counter >= self._config.max_runs_per_experiment ): self.done() diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py index bedfb0f..54e2d5f 100644 --- a/alphatrion/log/log.py +++ b/alphatrion/log/log.py @@ -1,8 +1,8 @@ from collections.abc import Callable +from alphatrion.experiment.experiment import current_exp_id from alphatrion.run.run import current_run_id from alphatrion.runtime.runtime import global_runtime -from alphatrion.trial.trial import current_trial_id from alphatrion.utils import time as utime ARTIFACT_PATH = "artifact_path" @@ -34,14 +34,12 @@ async def log_artifact( return log_artifact_in_sync( paths=paths, version=version, - pre_save_hook=pre_save_hook, ) def log_artifact_in_sync( paths: str | list[str], version: str = "latest", - pre_save_hook: Callable | None = None, ) -> str: """ Log artifacts (files) to the artifact registry (synchronous version). @@ -68,34 +66,33 @@ def log_artifact_in_sync( "Set ENABLE_ARTIFACT_STORAGE=true in the environment variables." ) - # We use experiment ID as the repo name rather than the experiment name, - # because experiment name is not unique - exp = runtime.current_exp - if exp is None: - raise RuntimeError("No running experiment found in the current context.") - - return runtime._artifact.push(repo_name=str(exp.id), paths=paths, version=version) + # We use project ID as the repo name rather than the project name, + # because project name is not unique + proj = runtime.current_proj + if proj is None: + raise RuntimeError("No running project found in the current context.") + return runtime._artifact.push(repo_name=str(proj.id), paths=paths, version=version) # log_params is used to save a set of parameters, which is a dict of key-value pairs. -# should be called after starting a trial. +# should be called after starting a Experiment. async def log_params(params: dict): - trial_id = current_trial_id.get() - if trial_id is None: - raise RuntimeError("log_params must be called inside a Trial.") + exp_id = current_exp_id.get() + if exp_id is None: + raise RuntimeError("log_params must be called inside a Experiment.") runtime = global_runtime() # TODO: should we upload to the artifact as well? - # current_trial_id is protect by contextvar, so it's safe to use in async - runtime._metadb.update_trial( - trial_id=trial_id, + # current_exp_id is protect by contextvar, so it's safe to use in async + runtime._metadb.update_experiment( + experiment_id=exp_id, params=params, ) # log_metrics is used to log a set of metrics at once, # metric key must be string, value must be float. -# If save_on_best is enabled in the trial config, and the metric is the best metric -# so far, the trial will checkpoint the current data. +# If save_on_best is enabled in the experiment config, and the metric is the best metric +# so far, the experiment will checkpoint the current data. # # Note: log_metrics can only be called inside a Run, because it needs a run_id. 
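# Illustrative sketch (hypothetical coroutine name): the task created by
# Experiment.start_run() inherits current_exp_id and current_run_id through
# contextvars, so log_metrics can resolve both IDs, e.g.
#
#   async def train_step():
#       await log_metrics({"loss": 0.42})
#
#   run = exp.start_run(lambda: train_step())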
async def log_metrics(metrics: dict[str, float]): @@ -104,15 +101,15 @@ async def log_metrics(metrics: dict[str, float]): raise RuntimeError("log_metrics must be called inside a Run.") runtime = global_runtime() - exp = runtime.current_exp + proj = runtime.current_proj - trial_id = current_trial_id.get() - if trial_id is None: - raise RuntimeError("log_metrics must be called inside a Trial.") + exp_id = current_exp_id.get() + if exp_id is None: + raise RuntimeError("log_metrics must be called inside a Experiment.") - trial = exp.get_trial(id=trial_id) - if trial is None: - raise RuntimeError(f"Trial {trial_id} not found in the database.") + exp = proj.get_experiment(id=exp_id) + if exp is None: + raise RuntimeError(f"Experiment {exp_id} not found in the database.") # track if any metric is the best metric should_checkpoint = False @@ -122,28 +119,28 @@ async def log_metrics(metrics: dict[str, float]): runtime._metadb.create_metric( key=key, value=value, - project_id=runtime._project_id, - experiment_id=exp.id, - trial_id=trial_id, + team_id=runtime._team_id, + project_id=proj.id, + experiment_id=exp_id, run_id=run_id, ) # TODO: should we save the checkpoint path for the best metric? # Always call the should_checkpoint_on_best first because # it also updates the best metric. - should_checkpoint |= trial.should_checkpoint_on_best( + should_checkpoint |= exp.should_checkpoint_on_best( metric_key=key, metric_value=value ) - should_early_stop |= trial.should_early_stop(metric_key=key, metric_value=value) - should_stop_on_target |= trial.should_stop_on_target_metric( + should_early_stop |= exp.should_early_stop(metric_key=key, metric_value=value) + should_stop_on_target |= exp.should_stop_on_target_metric( metric_key=key, metric_value=value ) if should_checkpoint: address = await log_artifact( - paths=trial.config().checkpoint.path, + paths=exp.config().checkpoint.path, version=utime.now_2_hash(), - pre_save_hook=trial.config().checkpoint.pre_save_hook, + pre_save_hook=exp.config().checkpoint.pre_save_hook, ) runtime._metadb.update_run( run_id=run_id, @@ -151,4 +148,4 @@ async def log_metrics(metrics: dict[str, float]): ) if should_early_stop or should_stop_on_target: - trial.done() + exp.done() diff --git a/alphatrion/metadata/base.py b/alphatrion/metadata/base.py index 6142fb8..d53e0bc 100644 --- a/alphatrion/metadata/base.py +++ b/alphatrion/metadata/base.py @@ -1,51 +1,57 @@ import uuid from abc import ABC, abstractmethod -from alphatrion.metadata.sql_models import Model, Trial +from alphatrion.metadata.sql_models import Experiment, Model class MetaStore(ABC): """Base class for all metadata storage backends.""" @abstractmethod - def get_project(self, project_id: uuid.UUID): + def create_team( + self, name: str, description: str | None = None, meta: dict | None = None + ) -> uuid.UUID: + raise NotImplementedError("Subclasses must implement this method.") + + @abstractmethod + def get_team(self, team_id: uuid.UUID): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def create_exp( + def create_project( self, name: str, - project_id: uuid.UUID, + team_id: uuid.UUID, description: str | None = None, meta: dict | None = None, ) -> int: raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def delete_exp(self, exp_id: uuid.UUID): + def delete_project(self, project_id: uuid.UUID): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def update_exp(self, exp_id: uuid.UUID, **kwargs): + def 
update_project(self, project_id: uuid.UUID, **kwargs): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def get_exp(self, exp_id: uuid.UUID): + def get_project(self, project_id: uuid.UUID): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def get_exp_by_name(self, name: str, project_id: uuid.UUID): + def get_proj_by_name(self, name: str, team_id: uuid.UUID): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def list_exps(self, project_id: uuid.UUID, page: int, page_size: int): + def list_projects(self, team_id: uuid.UUID, page: int, page_size: int): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod def create_model( self, name: str, - project_id: uuid.UUID, + team_id: uuid.UUID, version: str = "latest", description: str | None = None, meta: dict | None = None, @@ -61,7 +67,7 @@ def get_model(self, model_id: uuid.UUID) -> Model | None: raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def list_models(self, project_id: uuid.UUID, page: int, page_size: int): + def list_models(self, team_id: uuid.UUID, page: int, page_size: int): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod @@ -69,10 +75,10 @@ def delete_model(self, model_id: uuid.UUID): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def create_trial( + def create_experiment( self, + team_id: uuid.UUID, project_id: uuid.UUID, - experiment_id: uuid.UUID, name: str, description: str | None = None, meta: dict | None = None, @@ -81,22 +87,22 @@ def create_trial( raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def get_trial(self, trial_id: uuid.UUID) -> Trial | None: + def get_experiment(self, exp_id: uuid.UUID) -> Experiment | None: raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def get_trial_by_name(self, name: str, experiment_id: uuid.UUID) -> Trial | None: + def get_exp_by_name(self, name: str, project_id: uuid.UUID) -> Experiment | None: raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def update_trial(self, trial_id: uuid.UUID, **kwargs): + def update_experiment(self, experiment_id: uuid.UUID, **kwargs): raise NotImplementedError("Subclasses must implement this method.") @abstractmethod def create_run( self, + team_id: uuid.UUID, project_id: uuid.UUID, - trial_id: uuid.UUID, experiment_id: uuid.UUID, meta: dict | None = None, ) -> int: @@ -105,9 +111,9 @@ def create_run( @abstractmethod def create_metric( self, + team_id: uuid.UUID, project_id: uuid.UUID, experiment_id: uuid.UUID, - trial_id: uuid.UUID, run_id: uuid.UUID, key: str, value: float, @@ -115,5 +121,5 @@ def create_metric( raise NotImplementedError("Subclasses must implement this method.") @abstractmethod - def list_metrics_by_trial_id(self, trial_id: uuid.UUID) -> list[dict]: + def list_metrics_by_experiment_id(self, experiment_id: uuid.UUID) -> list[dict]: raise NotImplementedError("Subclasses must implement this method.") diff --git a/alphatrion/metadata/sql.py b/alphatrion/metadata/sql.py index ad7d7f4..00d7cbd 100644 --- a/alphatrion/metadata/sql.py +++ b/alphatrion/metadata/sql.py @@ -12,7 +12,7 @@ Project, Run, Status, - Trial, + Team, ) @@ -26,137 +26,135 @@ def __init__(self, db_url: str, init_tables: bool = False): # Mostly used in tests. 
Base.metadata.create_all(self._engine) - def create_project( + def create_team( self, name: str, description: str | None = None, meta: dict | None = None ) -> uuid.UUID: session = self._session() - new_project = Project( + new_team = Team( name=name, description=description, meta=meta, ) - session.add(new_project) + session.add(new_team) session.commit() - project_id = new_project.uuid + team_id = new_team.uuid session.close() - return project_id + return team_id - def get_project(self, project_id: uuid.UUID) -> Project | None: + def get_team(self, team_id: uuid.UUID) -> Team | None: session = self._session() - project = ( - session.query(Project) - .filter(Project.uuid == project_id, Project.is_del == 0) - .first() + team = ( + session.query(Team).filter(Team.uuid == team_id, Team.is_del == 0).first() ) session.close() - return project + return team - def list_projects(self, page: int, page_size: int) -> list[Project]: + def list_teams(self, page: int, page_size: int) -> list[Team]: session = self._session() - projects = ( - session.query(Project) - .filter(Project.is_del == 0) + teams = ( + session.query(Team) + .filter(Team.is_del == 0) .offset(page * page_size) .limit(page_size) .all() ) session.close() - return projects + return teams - def create_exp( + def create_project( self, name: str, - project_id: uuid.UUID, + team_id: uuid.UUID, description: str | None = None, meta: dict | None = None, ) -> uuid.UUID: session = self._session() - new_exp = Experiment( + new_proj = Project( name=name, - project_id=project_id, + team_id=team_id, description=description, meta=meta, ) - session.add(new_exp) + session.add(new_proj) session.commit() - exp_id = new_exp.uuid + exp_id = new_proj.uuid session.close() return exp_id - # Soft delete the experiment now. In the future, we may implement hard delete. - def delete_exp(self, exp_id: uuid.UUID): + # Soft delete the project now. + def delete_project(self, project_id: uuid.UUID): session = self._session() - exp = ( - session.query(Experiment) - .filter(Experiment.uuid == exp_id, Experiment.is_del == 0) + proj = ( + session.query(Project) + .filter(Project.uuid == project_id, Project.is_del == 0) .first() ) - if exp: - exp.is_del = 1 + if proj: + proj.is_del = 1 session.commit() session.close() # We don't support append-only update, the complete fields should be provided. - def update_exp(self, exp_id: uuid.UUID, **kwargs) -> None: + def update_project(self, project_id: uuid.UUID, **kwargs) -> None: session = self._session() - exp = ( - session.query(Experiment) - .filter(Experiment.uuid == exp_id, Experiment.is_del == 0) + proj = ( + session.query(Project) + .filter(Project.uuid == project_id, Project.is_del == 0) .first() ) - if exp: + if proj: for key, value in kwargs.items(): - setattr(exp, key, value) + setattr(proj, key, value) session.commit() session.close() - # get_exp will ignore the deleted experiments. - def get_exp(self, exp_id: uuid.UUID) -> Experiment | None: + # get function will ignore the deleted ones. 
+ def get_project(self, project_id: uuid.UUID) -> Project | None: session = self._session() - exp = ( - session.query(Experiment) - .filter(Experiment.uuid == exp_id, Experiment.is_del == 0) + proj = ( + session.query(Project) + .filter(Project.uuid == project_id, Project.is_del == 0) .first() ) session.close() - return exp + return proj - def get_exp_by_name(self, name: str, project_id: uuid.UUID) -> Experiment | None: + def get_proj_by_name(self, name: str, team_id: uuid.UUID) -> Project | None: session = self._session() - exp = ( - session.query(Experiment) + proj = ( + session.query(Project) .filter( - Experiment.name == name, - Experiment.project_id == project_id, - Experiment.is_del == 0, + Project.name == name, + Project.team_id == team_id, + Project.is_del == 0, ) .first() ) session.close() - return exp + return proj - # paginate the experiments in case of too many experiments. - def list_exps( - self, project_id: uuid.UUID, page: int = 0, page_size: int = 10 - ) -> list[Experiment]: + # paginate the projects in case of too many projects. + def list_projects( + self, team_id: uuid.UUID, page: int = 0, page_size: int = 10 + ) -> list[Project]: session = self._session() - exps = ( - session.query(Experiment) - .filter(Experiment.project_id == project_id, Experiment.is_del == 0) + projects = ( + session.query(Project) + .filter(Project.team_id == team_id, Project.is_del == 0) .offset(page * page_size) .limit(page_size) .all() ) session.close() - return exps + return projects def create_model( self, name: str, - project_id: uuid.UUID, + team_id: uuid.UUID, version: str = "latest", description: str | None = None, meta: dict | None = None, @@ -164,7 +162,7 @@ def create_model( session = self._session() new_model = Model( name=name, - project_id=project_id, + team_id=team_id, version=version, description=description, meta=meta, @@ -223,106 +221,104 @@ def delete_model(self, model_id: uuid.UUID): session.commit() session.close() - def create_trial( + def create_experiment( self, name: str, + team_id: uuid.UUID, project_id: uuid.UUID, - experiment_id: uuid.UUID, description: str | None = None, meta: dict | None = None, params: dict | None = None, status: Status = Status.PENDING, ) -> uuid.UUID: session = self._session() - new_trial = Trial( + new_exp = Experiment( + team_id=team_id, project_id=project_id, - experiment_id=experiment_id, name=name, description=description, meta=meta, params=params, status=status, ) - session.add(new_trial) + session.add(new_exp) session.commit() - trial_id = new_trial.uuid + exp_id = new_exp.uuid session.close() - return trial_id + return exp_id - def get_trial(self, trial_id: uuid.UUID) -> Trial | None: + def get_experiment(self, experiment_id: uuid.UUID) -> Experiment | None: session = self._session() - trial = ( - session.query(Trial) - .filter(Trial.uuid == trial_id, Trial.is_del == 0) + exp = ( + session.query(Experiment) + .filter(Experiment.uuid == experiment_id, Experiment.is_del == 0) .first() ) session.close() - return trial + return exp - # TODO: should we use join to get the trial by experiment name? - def get_trial_by_name( - self, trial_name: str, experiment_id: uuid.UUID - ) -> Trial | None: - # make sure the experiment exists - exp = self.get_exp(experiment_id) - if exp is None: + # Different project may have the same experiment name. 
+ def get_exp_by_name(self, name: str, project_id: uuid.UUID) -> Experiment | None: + # make sure the project exists + proj = self.get_project(project_id) + if proj is None: return None session = self._session() trial = ( - session.query(Trial) + session.query(Experiment) .filter( - Trial.name == trial_name, - Trial.experiment_id == experiment_id, - Trial.is_del == 0, + Experiment.name == name, + Experiment.project_id == project_id, + Experiment.is_del == 0, ) .first() ) session.close() return trial - def list_trials_by_experiment_id( - self, experiment_id: uuid.UUID, page: int = 0, page_size: int = 10 - ) -> list[Trial]: + def list_exps_by_project_id( + self, project_id: uuid.UUID, page: int = 0, page_size: int = 10 + ) -> list[Experiment]: session = self._session() - trials = ( - session.query(Trial) - .filter(Trial.experiment_id == experiment_id, Trial.is_del == 0) + exps = ( + session.query(Experiment) + .filter(Experiment.project_id == project_id, Experiment.is_del == 0) .offset(page * page_size) .limit(page_size) .all() ) session.close() - return trials + return exps - def update_trial(self, trial_id: uuid.UUID, **kwargs) -> None: + def update_experiment(self, experiment_id: uuid.UUID, **kwargs) -> None: session = self._session() - trial = ( - session.query(Trial) - .filter(Trial.uuid == trial_id, Trial.is_del == 0) + exp = ( + session.query(Experiment) + .filter(Experiment.uuid == experiment_id, Experiment.is_del == 0) .first() ) - if trial: + if exp: for key, value in kwargs.items(): - setattr(trial, key, value) + setattr(exp, key, value) session.commit() session.close() def create_run( self, + team_id: uuid.UUID, project_id: uuid.UUID, experiment_id: uuid.UUID, - trial_id: uuid.UUID, meta: dict | None = None, status: Status = Status.PENDING, ) -> uuid.UUID: session = self._session() new_run = Run( project_id=project_id, + team_id=team_id, experiment_id=experiment_id, - trial_id=trial_id, meta=meta, status=status, ) @@ -349,13 +345,13 @@ def get_run(self, run_id: uuid.UUID) -> Run | None: session.close() return run - def list_runs_by_trial_id( - self, trial_id: uuid.UUID, page: int = 0, page_size: int = 10 + def list_runs_by_exp_id( + self, exp_id: uuid.UUID, page: int = 0, page_size: int = 10 ) -> list[Run]: session = self._session() runs = ( session.query(Run) - .filter(Run.trial_id == trial_id, Run.is_del == 0) + .filter(Run.experiment_id == exp_id, Run.is_del == 0) .offset(page * page_size) .limit(page_size) .all() @@ -365,18 +361,18 @@ def list_runs_by_trial_id( def create_metric( self, + team_id: uuid.UUID, project_id: uuid.UUID, experiment_id: uuid.UUID, - trial_id: uuid.UUID, run_id: uuid.UUID, key: str, value: float, ) -> uuid.UUID: session = self._session() new_metric = Metric( + team_id=team_id, project_id=project_id, experiment_id=experiment_id, - trial_id=trial_id, run_id=run_id, key=key, value=value, @@ -387,13 +383,13 @@ def create_metric( session.close() return new_metric_id - def list_metrics_by_trial_id( - self, trial_id: uuid.UUID, page: int = 0, page_size: int = 10 + def list_metrics_by_experiment_id( + self, experiment_id: uuid.UUID, page: int = 0, page_size: int = 10 ) -> list[Metric]: session = self._session() metrics = ( session.query(Metric) - .filter(Metric.trial_id == trial_id) + .filter(Metric.experiment_id == experiment_id) .offset(page * page_size) .limit(page_size) .all() diff --git a/alphatrion/metadata/sql_models.py b/alphatrion/metadata/sql_models.py index 7ea7736..7110945 100644 --- a/alphatrion/metadata/sql_models.py +++ 
b/alphatrion/metadata/sql_models.py @@ -30,13 +30,13 @@ class Status(enum.IntEnum): FINISHED_STATUS = [Status.COMPLETED, Status.FAILED, Status.CANCELLED] -class Project(Base): - __tablename__ = "projects" +class Team(Base): + __tablename__ = "teams" uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = Column(String, nullable=False) description = Column(String, nullable=True) - meta = Column(JSON, nullable=True, comment="Additional metadata for the project") + meta = Column(JSON, nullable=True, comment="Additional metadata for the team") created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) updated_at = Column( @@ -47,26 +47,15 @@ class Project(Base): is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted") -class ExperimentType(enum.IntEnum): - UNKNOWN = 0 - CRAFT_EXPERIMENT = 1 - - -# Define the Experiment model for SQLAlchemy -class Experiment(Base): - __tablename__ = "experiments" +# Define the Project model for SQLAlchemy +class Project(Base): + __tablename__ = "projects" uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = Column(String, nullable=False) description = Column(String, nullable=True) - project_id = Column(UUID(as_uuid=True), nullable=False) - meta = Column(JSON, nullable=True, comment="Additional metadata for the experiment") - kind = Column( - Integer, - default=ExperimentType.CRAFT_EXPERIMENT, - nullable=False, - comment="Type of the experiment", - ) + team_id = Column(UUID(as_uuid=True), nullable=False) + meta = Column(JSON, nullable=True, comment="Additional metadata for the project") created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) updated_at = Column( @@ -77,16 +66,27 @@ class Experiment(Base): is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted") -class Trial(Base): - __tablename__ = "trials" +class ExperimentType(enum.IntEnum): + UNKNOWN = 0 + CRAFT_EXPERIMENT = 1 + + +class Experiment(Base): + __tablename__ = "experiments" uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + team_id = Column(UUID(as_uuid=True), nullable=False) project_id = Column(UUID(as_uuid=True), nullable=False) - experiment_id = Column(UUID(as_uuid=True), nullable=False) name = Column(String, nullable=False) description = Column(String, nullable=True) meta = Column(JSON, nullable=True, comment="Additional metadata for the trial") params = Column(JSON, nullable=True, comment="Parameters for the experiment") + kind = Column( + Integer, + default=ExperimentType.CRAFT_EXPERIMENT, + nullable=False, + comment="Type of the experiment", + ) duration = Column(Float, default=0.0, comment="Duration of the trial in seconds") status = Column( Integer, @@ -110,9 +110,9 @@ class Run(Base): __tablename__ = "runs" uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + team_id = Column(UUID(as_uuid=True), nullable=False) project_id = Column(UUID(as_uuid=True), nullable=False) experiment_id = Column(UUID(as_uuid=True), nullable=False) - trial_id = Column(UUID(as_uuid=True), nullable=False) meta = Column(JSON, nullable=True, comment="Additional metadata for the run") status = Column( Integer, @@ -138,7 +138,7 @@ class Model(Base): uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = Column(String, nullable=False, unique=True) description = Column(String, nullable=True) - project_id = Column(UUID(as_uuid=True), nullable=False) + team_id = Column(UUID(as_uuid=True), nullable=False) version = 
Column(String, nullable=False) meta = Column(JSON, nullable=True, comment="Additional metadata for the model") @@ -151,15 +151,15 @@ class Model(Base): is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted") -# TODO: key, project_id, experiment_id, trial_id, run_id should be unique together +# TODO: key, team_id, project_id, experiment_id, run_id should be unique together class Metric(Base): __tablename__ = "metrics" uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) key = Column(String, nullable=False) value = Column(Float, nullable=False) + team_id = Column(UUID(as_uuid=True), nullable=False) project_id = Column(UUID(as_uuid=True), nullable=False) experiment_id = Column(UUID(as_uuid=True), nullable=False) - trial_id = Column(UUID(as_uuid=True), nullable=False) run_id = Column(UUID(as_uuid=True), nullable=False) created_at = Column(DateTime(timezone=True), default=datetime.now(UTC)) diff --git a/alphatrion/model/model.py b/alphatrion/model/model.py index 0fc618d..ed03b78 100644 --- a/alphatrion/model/model.py +++ b/alphatrion/model/model.py @@ -8,13 +8,13 @@ def __init__(self, runtime: Runtime): def create( self, name: str, - project_id: str, + team_id: str, description: str | None = None, meta: dict | None = None, ): return self._runtime._metadb.create_model( name=name, - project_id=project_id, + team_id=team_id, description=description, meta=meta, ) diff --git a/alphatrion/trial/__init__.py b/alphatrion/project/__init__.py similarity index 100% rename from alphatrion/trial/__init__.py rename to alphatrion/project/__init__.py diff --git a/alphatrion/project/project.py b/alphatrion/project/project.py new file mode 100644 index 0000000..0aa81a5 --- /dev/null +++ b/alphatrion/project/project.py @@ -0,0 +1,175 @@ +import uuid + +from pydantic import BaseModel, Field + +from alphatrion.experiment import experiment +from alphatrion.runtime.runtime import global_runtime +from alphatrion.utils import context + + +class ProjectConfig(BaseModel): + """Configuration for a Project.""" + + max_execution_seconds: int = Field( + default=-1, + description="Maximum execution seconds for the project. \ + Once exceeded, the project and all its experiments will be cancelled. \ + Default is -1 (no limit).", + ) + + +class Project: + """ + Project represents a collection of experiments. + """ + + def __init__(self, config: ProjectConfig | None = None): + self._runtime = global_runtime() + # All experiments in this project, + # key is experiment_id, value is Experiment instance. + self._experiments: dict[int, experiment.Experiment] = {} + self._config = config or ProjectConfig() + self._runtime.current_proj = self + + self._context = context.Context( + cancel_func=self._stop, + timeout=self._config.max_execution_seconds + if self._config.max_execution_seconds > 0 + else None, + ) + + @property + def id(self): + return self._id + + def get_experiment(self, id: int) -> experiment.Experiment | None: + return self._experiments.get(id) + + async def __aenter__(self): + if self._id is None: + raise RuntimeError("Project is not set. Did you call start()?") + + project = self._get() + if project is None: + raise RuntimeError(f"Project {self._id} not found in the database.") + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.done() + + # done() is safe to call multiple times. 
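    # Illustrative lifecycle sketch (assumes init() was called with a team_id):
    #
    #   proj = Project.setup(name="demo-project")
    #   async with proj:
    #       async with proj.start_experiment(name="demo-exp") as exp:
    #           ...
    #
    # Both context managers call done() on exit; calling done() again afterwards
    # is a no-op because the underlying context is already cancelled.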
+ def done(self): + self._cancel() + + def _cancel(self): + return self._context.cancel() + + def _stop(self): + for t in list(self._experiments.values()): + t.done() + self._experiments = dict() + # Set to None at the end of the project because + # it will be used in experiment.done(). + self._runtime.current_proj = None + + def register_exp(self, id: uuid.UUID, instance: experiment.Experiment): + self._experiments[id] = instance + + def unregister_exp(self, id: uuid.UUID): + self._experiments.pop(id, None) + + def _create( + self, + name: str, + description: str | None = None, + meta: dict | None = None, + ): + """ + :param name: the name of the project. + :param description: the description of the project + :param meta: the metadata of the project + + :return: the project ID + """ + + self._id = self._runtime._metadb.create_project( + name=name, + description=description, + team_id=self._runtime._team_id, + meta=meta, + ) + return self._id + + def _get(self): + return self._runtime._metadb.get_project(project_id=self._id) + + def _get_by_name(self, name: str): + return self._runtime._metadb.get_proj_by_name( + name=name, team_id=self._runtime._team_id + ) + + def delete(self): + proj = self._get() + if proj is None: + return + + self._runtime._metadb.delete_project(project_id=self._id) + # TODO: Should we make this optional as a parameter? + tags = self._runtime._artifact.list_versions(repo_name=str(self._id)) + self._runtime._artifact.delete(repo_name=str(self._id), versions=tags) + + @classmethod + def setup( + cls, + name: str, + description: str | None = None, + meta: dict | None = None, + config: ProjectConfig | None = None, + ) -> "Project": + """ + Set up the project. If the name already exists in the same team, + it will refer to the existing project instead of creating a new one. + """ + + proj = Project(config) + proj_obj = proj._get_by_name(name=name) + + # If a project with the same name exists in the team, use it. + if proj_obj: + proj._id = proj_obj.uuid + else: + proj._create( + name=name, + description=description, + meta=meta, + ) + + return proj + + def start_experiment( + self, + name: str, + description: str | None = None, + meta: dict | None = None, + params: dict | None = None, + config: experiment.ExperimentConfig | None = None, + ) -> experiment.Experiment: + """ + start_experiment starts a new experiment in this project. + You need to call experiment.done() to stop the experiment for proper cleanup, + unless the experiment has a timeout configured. + Or you can use 'async with proj.start_experiment(...) as experiment', which will + automatically stop the experiment at the end of the context.
+ + :params description: the description of the experiment + :params meta: the metadata of the experiment + :params config: the configuration of the Experiment + + :return: the Experiment instance + """ + + exp = experiment.Experiment(proj_id=self._id, config=config) + exp._start(name=name, description=description, meta=meta, params=params) + self.register_exp(id=exp.id, instance=exp) + return exp diff --git a/alphatrion/run/run.py b/alphatrion/run/run.py index c350666..a400140 100644 --- a/alphatrion/run/run.py +++ b/alphatrion/run/run.py @@ -9,11 +9,11 @@ class Run: - __slots__ = ("_id", "_task", "_runtime", "_trial_id") + __slots__ = ("_id", "_task", "_runtime", "_exp_id") - def __init__(self, trial_id: uuid.UUID): + def __init__(self, exp_id: uuid.UUID): self._runtime = global_runtime() - self._trial_id = trial_id + self._exp_id = exp_id @property def id(self) -> uuid.UUID: @@ -24,9 +24,9 @@ def _get_obj(self): def start(self, call_func: callable) -> None: self._id = self._runtime._metadb.create_run( - project_id=self._runtime._project_id, - experiment_id=self._runtime.current_exp.id, - trial_id=self._trial_id, + team_id=self._runtime._team_id, + project_id=self._runtime.current_proj.id, + experiment_id=self._exp_id, status=Status.RUNNING, ) @@ -34,7 +34,7 @@ def start(self, call_func: callable) -> None: token = current_run_id.set(self.id) try: # The created task will also inherit the current context, - # including the current_trial_id, current_run_id context var. + # including the current_exp_id, current_run_id context var. self._task = asyncio.create_task(call_func()) finally: current_run_id.reset(token) diff --git a/alphatrion/runtime/runtime.py b/alphatrion/runtime/runtime.py index e0664d8..d1dbbab 100644 --- a/alphatrion/runtime/runtime.py +++ b/alphatrion/runtime/runtime.py @@ -10,7 +10,7 @@ def init( - project_id: uuid.UUID, + team_id: uuid.UUID, artifact_insecure: bool = False, init_tables: bool = False, ): @@ -24,7 +24,7 @@ def init( """ global __RUNTIME__ __RUNTIME__ = Runtime( - project_id=project_id, + team_id=team_id, artifact_insecure=artifact_insecure, init_tables=init_tables, ) @@ -37,37 +37,35 @@ def global_runtime(): # Runtime contains all kinds of clients, e.g., metadb client, artifact client, etc. -# Stateful information will also be stored here, e.g., current running experiment ID. +# Stateful information will also be stored here, e.g., current running Project. class Runtime: - __slots__ = ("_project_id", "_metadb", "_artifact", "__current_exp") + __slots__ = ("_team_id", "_metadb", "_artifact", "__current_proj") def __init__( self, - project_id: str, + team_id: uuid.UUID, artifact_insecure: bool = False, init_tables: bool = False, ): - self._project_id = project_id + self._team_id = team_id self._metadb = SQLStore( os.getenv(consts.METADATA_DB_URL), init_tables=init_tables ) if self.artifact_storage_enabled(): - self._artifact = Artifact( - project_id=self._project_id, insecure=artifact_insecure - ) + self._artifact = Artifact(team_id=self._team_id, insecure=artifact_insecure) def artifact_storage_enabled(self) -> bool: return os.getenv(consts.ENABLE_ARTIFACT_STORAGE, "true").lower() == "true" - # current_exp is the current running experiment. + # current_proj is the current running Project. 
@property - def current_exp(self): - return self.__current_exp + def current_proj(self): + return self.__current_proj - @current_exp.setter - def current_exp(self, value) -> None: - self.__current_exp = value + @current_proj.setter + def current_proj(self, value) -> None: + self.__current_proj = value @property def metadb(self) -> SQLStore: diff --git a/alphatrion/server/graphql/resolvers.py b/alphatrion/server/graphql/resolvers.py index 0bc66e2..f212126 100644 --- a/alphatrion/server/graphql/resolvers.py +++ b/alphatrion/server/graphql/resolvers.py @@ -11,39 +11,76 @@ Metric, Project, Run, - Trial, + Team, ) class GraphQLResolvers: @staticmethod - def list_projects(page: int = 0, page_size: int = 10) -> list[Project]: + def list_teams(page: int = 0, page_size: int = 10) -> list[Team]: metadb = runtime.graphql_runtime().metadb - projects = metadb.list_projects(page=page, page_size=page_size) + teams = metadb.list_teams(page=page, page_size=page_size) + return [ + Team( + id=t.uuid, + name=t.name, + description=t.description, + meta=t.meta, + created_at=t.created_at, + updated_at=t.updated_at, + ) + for t in teams + ] + + @staticmethod + def get_team(id: str) -> Team | None: + metadb = runtime.graphql_runtime().metadb + team = metadb.get_team(team_id=uuid.UUID(id)) + if team: + return Team( + id=team.uuid, + name=team.name, + description=team.description, + meta=team.meta, + created_at=team.created_at, + updated_at=team.updated_at, + ) + return None + + @staticmethod + def list_projects( + team_id: str, page: int = 0, page_size: int = 10 + ) -> list[Project]: + metadb = runtime.graphql_runtime().metadb + projects = metadb.list_projects( + team_id=uuid.UUID(team_id), page=page, page_size=page_size + ) return [ Project( - id=p.uuid, - name=p.name, - description=p.description, - meta=p.meta, - created_at=p.created_at, - updated_at=p.updated_at, + id=proj.uuid, + team_id=proj.team_id, + name=proj.name, + description=proj.description, + meta=proj.meta, + created_at=proj.created_at, + updated_at=proj.updated_at, ) - for p in projects + for proj in projects ] @staticmethod def get_project(id: str) -> Project | None: metadb = runtime.graphql_runtime().metadb - project = metadb.get_project(project_id=uuid.UUID(id)) - if project: + proj = metadb.get_project(project_id=uuid.UUID(id)) + if proj: return Project( - id=project.uuid, - name=project.name, - description=project.description, - meta=project.meta, - created_at=project.created_at, - updated_at=project.updated_at, + id=proj.uuid, + team_id=proj.team_id, + name=proj.name, + description=proj.description, + meta=proj.meta, + created_at=proj.created_at, + updated_at=proj.updated_at, ) return None @@ -52,34 +89,42 @@ def list_experiments( project_id: str, page: int = 0, page_size: int = 10 ) -> list[Experiment]: metadb = runtime.graphql_runtime().metadb - exps = metadb.list_exps( + exps = metadb.list_exps_by_project_id( project_id=uuid.UUID(project_id), page=page, page_size=page_size ) return [ Experiment( - id=exp.uuid, - project_id=exp.project_id, - name=exp.name, - description=exp.description, - meta=exp.meta, - kind=GraphQLExperimentTypeEnum[GraphQLExperimentType(exp.kind).name], - created_at=exp.created_at, - updated_at=exp.updated_at, + id=e.uuid, + team_id=e.team_id, + project_id=e.project_id, + name=e.name, + description=e.description, + meta=e.meta, + params=e.params, + duration=e.duration, + status=GraphQLStatusEnum[Status(e.status).name], + kind=GraphQLExperimentTypeEnum[GraphQLExperimentType(e.kind).name], + created_at=e.created_at, + 
updated_at=e.updated_at, ) - for exp in exps + for e in exps ] @staticmethod def get_experiment(id: str) -> Experiment | None: metadb = runtime.graphql_runtime().metadb - exp = metadb.get_exp(exp_id=uuid.UUID(id)) + exp = metadb.get_experiment(experiment_id=uuid.UUID(id)) if exp: return Experiment( id=exp.uuid, + team_id=exp.team_id, project_id=exp.project_id, name=exp.name, description=exp.description, meta=exp.meta, + params=exp.params, + duration=exp.duration, + status=GraphQLStatusEnum[Status(exp.status).name], kind=GraphQLExperimentTypeEnum[GraphQLExperimentType(exp.kind).name], created_at=exp.created_at, updated_at=exp.updated_at, @@ -87,60 +132,15 @@ def get_experiment(id: str) -> Experiment | None: return None @staticmethod - def list_trials( - experiment_id: str, page: int = 0, page_size: int = 10 - ) -> list[Trial]: - metadb = runtime.graphql_runtime().metadb - trials = metadb.list_trials_by_experiment_id( - experiment_id=uuid.UUID(experiment_id), page=page, page_size=page_size - ) - return [ - Trial( - id=t.uuid, - experiment_id=t.experiment_id, - project_id=t.project_id, - name=t.name, - description=t.description, - meta=t.meta, - params=t.params, - duration=t.duration, - status=GraphQLStatusEnum[Status(t.status).name], - created_at=t.created_at, - updated_at=t.updated_at, - ) - for t in trials - ] - - @staticmethod - def get_trial(id: str) -> Trial | None: - metadb = runtime.graphql_runtime().metadb - trial = metadb.get_trial(trial_id=uuid.UUID(id)) - if trial: - return Trial( - id=trial.uuid, - experiment_id=trial.experiment_id, - project_id=trial.project_id, - name=trial.name, - description=trial.description, - meta=trial.meta, - params=trial.params, - duration=trial.duration, - status=GraphQLStatusEnum[Status(trial.status).name], - created_at=trial.created_at, - updated_at=trial.updated_at, - ) - return None - - @staticmethod - def list_runs(trial_id: str, page: int = 0, page_size: int = 10) -> list[Run]: + def list_runs(experiment_id: str, page: int = 0, page_size: int = 10) -> list[Run]: metadb = runtime.graphql_runtime().metadb - runs = metadb.list_runs_by_trial_id( - trial_id=uuid.UUID(trial_id), page=page, page_size=page_size + runs = metadb.list_runs_by_exp_id( + exp_id=uuid.UUID(experiment_id), page=page, page_size=page_size ) return [ Run( id=r.uuid, - trial_id=r.trial_id, + team_id=r.team_id, project_id=r.project_id, experiment_id=r.experiment_id, meta=r.meta, @@ -157,7 +157,7 @@ def get_run(id: str) -> Run | None: if run: return Run( id=run.uuid, - trial_id=run.trial_id, + team_id=run.team_id, project_id=run.project_id, experiment_id=run.experiment_id, meta=run.meta, @@ -167,21 +167,21 @@ def get_run(id: str) -> Run | None: return None @staticmethod - def list_trial_metrics( - trial_id: str, page: int = 0, page_size: int = 10 + def list_exp_metrics( + experiment_id: str, page: int = 0, page_size: int = 10 ) -> list[Metric]: metadb = runtime.graphql_runtime().metadb - metrics = metadb.list_metrics_by_trial_id( - trial_id=uuid.UUID(trial_id), page=page, page_size=page_size + metrics = metadb.list_metrics_by_experiment_id( + experiment_id=uuid.UUID(experiment_id), page=page, page_size=page_size ) return [ Metric( id=m.uuid, key=m.key, value=m.value, + team_id=m.team_id, project_id=m.project_id, experiment_id=m.experiment_id, - trial_id=m.trial_id, run_id=m.run_id, created_at=m.created_at, ) diff --git a/alphatrion/server/graphql/schema.py b/alphatrion/server/graphql/schema.py index 3118778..2b9b897 100644 --- a/alphatrion/server/graphql/schema.py +++ 
b/alphatrion/server/graphql/schema.py @@ -1,20 +1,30 @@ import strawberry from alphatrion.server.graphql.resolvers import GraphQLResolvers -from alphatrion.server.graphql.types import Experiment, Metric, Project, Run, Trial +from alphatrion.server.graphql.types import Experiment, Metric, Project, Run, Team @strawberry.type class Query: - projects: list[Project] = strawberry.field(resolver=GraphQLResolvers.list_projects) - project: Project | None = strawberry.field(resolver=GraphQLResolvers.get_project) + teams: list[Team] = strawberry.field(resolver=GraphQLResolvers.list_teams) + team: Team | None = strawberry.field(resolver=GraphQLResolvers.get_team) @strawberry.field - def experiments( + def projects( self, - project_id: str, + team_id: str, page: int = 0, page_size: int = 10, + ) -> list[Project]: + return GraphQLResolvers.list_projects( + team_id=team_id, page=page, page_size=page_size + ) + + project: Project | None = strawberry.field(resolver=GraphQLResolvers.get_project) + + @strawberry.field + def experiments( + self, project_id: str, page: int = 0, page_size: int = 10 ) -> list[Experiment]: return GraphQLResolvers.list_experiments( project_id=project_id, page=page, page_size=page_size @@ -25,25 +35,15 @@ def experiments( ) @strawberry.field - def trials( - self, experiment_id: str, page: int = 0, page_size: int = 10 - ) -> list[Trial]: - return GraphQLResolvers.list_trials( - experiment_id=experiment_id, page=page, page_size=page_size - ) - - trial: Trial | None = strawberry.field(resolver=GraphQLResolvers.get_trial) - - @strawberry.field - def runs(self, trial_id: str, page: int = 0, page_size: int = 10) -> list[Run]: + def runs(self, experiment_id: str, page: int = 0, page_size: int = 10) -> list[Run]: return GraphQLResolvers.list_runs( - trial_id=trial_id, page=page, page_size=page_size + experiment_id=experiment_id, page=page, page_size=page_size ) run: Run | None = strawberry.field(resolver=GraphQLResolvers.get_run) trial_metrics: list[Metric] = strawberry.field( - resolver=GraphQLResolvers.list_trial_metrics + resolver=GraphQLResolvers.list_exp_metrics ) diff --git a/alphatrion/server/graphql/types.py b/alphatrion/server/graphql/types.py index 67c85c6..5226e6e 100644 --- a/alphatrion/server/graphql/types.py +++ b/alphatrion/server/graphql/types.py @@ -6,7 +6,7 @@ @strawberry.type -class Project: +class Team: id: strawberry.ID name: str | None description: str | None @@ -15,20 +15,11 @@ class Project: updated_at: datetime -class GraphQLExperimentType(Enum): - UNKNOWN = 0 - CRAFT_EXPERIMENT = 1 - - -GraphQLExperimentTypeEnum = strawberry.enum(GraphQLExperimentType) - - @strawberry.type -class Experiment: +class Project: id: strawberry.ID - project_id: strawberry.ID | None + team_id: strawberry.ID name: str | None - kind: GraphQLExperimentTypeEnum description: str | None meta: JSON | None created_at: datetime @@ -47,13 +38,22 @@ class GraphQLStatus(Enum): GraphQLStatusEnum = strawberry.enum(GraphQLStatus) +class GraphQLExperimentType(Enum): + UNKNOWN = 0 + CRAFT_EXPERIMENT = 1 + + +GraphQLExperimentTypeEnum = strawberry.enum(GraphQLExperimentType) + + @strawberry.type -class Trial: +class Experiment: id: strawberry.ID - experiment_id: strawberry.ID + team_id: strawberry.ID project_id: strawberry.ID name: str description: str | None + kind: GraphQLExperimentTypeEnum meta: JSON | None params: JSON | None duration: float @@ -65,7 +65,7 @@ class Trial: @strawberry.type class Run: id: strawberry.ID - trial_id: strawberry.ID + team_id: strawberry.ID project_id: strawberry.ID 
experiment_id: strawberry.ID meta: JSON | None @@ -78,8 +78,8 @@ class Metric: id: strawberry.ID key: str | None value: float | None + team_id: strawberry.ID project_id: strawberry.ID experiment_id: strawberry.ID - trial_id: strawberry.ID run_id: strawberry.ID created_at: datetime diff --git a/hack/seed.py b/hack/seed.py index b1548d4..a23b585 100644 --- a/hack/seed.py +++ b/hack/seed.py @@ -20,7 +20,7 @@ Project, Run, Status, - Trial, + Team, ) load_dotenv() @@ -47,8 +47,8 @@ def make_json_serializable(obj): return obj -def generate_project() -> Project: - return Project( +def generate_team() -> Team: + return Team( uuid=uuid.uuid4(), name=fake.bs().title(), description=fake.catch_phrase(), @@ -58,22 +58,22 @@ def generate_project() -> Project: ) -def generate_experiment(projects: list[Project]) -> Experiment: - return Experiment( +def generate_project(teams: list[Team]) -> Project: + return Project( name=fake.bs().title(), description=fake.catch_phrase(), meta=make_json_serializable( fake.pydict(nb_elements=3, variable_nb_elements=True) ), - project_id=random.choice(projects).uuid, + team_id=random.choice(teams).uuid, ) -def generate_trial(exps: list[Experiment]) -> Trial: - exp = random.choice(exps) - return Trial( - project_id=exp.project_id, - experiment_id=exp.uuid, +def generate_experiment(projects: list[Project]) -> Experiment: + proj = random.choice(projects) + return Experiment( + team_id=proj.team_id, + project_id=proj.uuid, name=fake.bs().title(), description=fake.catch_phrase(), meta=make_json_serializable( @@ -87,12 +87,12 @@ def generate_trial(exps: list[Experiment]) -> Trial: ) -def generate_run(trials: list[Trial]) -> Run: - trial = random.choice(trials) +def generate_run(exps: list[Experiment]) -> Run: + exp = random.choice(exps) return Run( - project_id=trial.project_id, - experiment_id=trial.experiment_id, - trial_id=trial.uuid, + team_id=exp.team_id, + project_id=exp.project_id, + experiment_id=exp.uuid, meta=make_json_serializable( fake.pydict(nb_elements=2, variable_nb_elements=True) ), @@ -103,9 +103,9 @@ def generate_run(trials: list[Trial]) -> Run: def generate_metric(runs: list[Run]) -> Metric: run = random.choice(runs) return Metric( + team_id=run.team_id, project_id=run.project_id, experiment_id=run.experiment_id, - trial_id=run.trial_id, run_id=run.uuid, key=random.choice(["accuracy", "loss", "precision", "fitness"]), value=random.uniform(0, 1), @@ -113,39 +113,37 @@ def generate_metric(runs: list[Run]) -> Metric: def seed_all( - num_projects: int, - num_exps_per_project: int, - num_trials_per_exp: int, - num_runs_per_trial: int, + num_teams: int, + num_projs_per_team: int, + num_exps_per_proj: int, + num_runs_per_exp: int, num_metrics_per_run: int, ): Base.metadata.create_all(bind=engine) print("🌱 generating seeds ...") - projects = [generate_project() for _ in range(num_projects)] - session.add_all(projects) + teams = [generate_team() for _ in range(num_teams)] + session.add_all(teams) session.commit() - experiments = [ - generate_experiment(projects) - for _ in range(num_exps_per_project) - for _ in range(len(projects)) + projs = [ + generate_project(teams) + for _ in range(num_projs_per_team) + for _ in range(len(teams)) ] - session.add_all(experiments) + session.add_all(projs) session.commit() - trials = [ - generate_trial(experiments) - for _ in range(num_trials_per_exp) - for _ in range(len(experiments)) + exps = [ + generate_experiment(projs) + for _ in range(num_exps_per_proj) + for _ in range(len(projs)) ] - session.add_all(trials) + 
session.add_all(exps) session.commit() runs = [ - generate_run(trials) - for _ in range(num_runs_per_trial) - for _ in range(len(trials)) + generate_run(exps) for _ in range(num_runs_per_exp) for _ in range(len(exps)) ] session.add_all(runs) session.commit() @@ -165,7 +163,7 @@ def cleanup(): print("🧹 cleaning up seeded data ...") session.query(Metric).delete() session.query(Run).delete() - session.query(Trial).delete() + session.query(Team).delete() session.query(Experiment).delete() session.query(Project).delete() session.commit() @@ -183,10 +181,10 @@ def cleanup(): cleanup() elif action == "seed": seed_all( - num_projects=3, - num_exps_per_project=10, - num_trials_per_exp=10, - num_runs_per_trial=20, + num_teams=3, + num_projs_per_team=10, + num_exps_per_proj=10, + num_runs_per_exp=20, num_metrics_per_run=30, ) else: diff --git a/migrations/versions/03410247c6b7_remove_step_from_trial.py b/migrations/versions/03410247c6b7_remove_step_from_trial.py deleted file mode 100644 index 50e2923..0000000 --- a/migrations/versions/03410247c6b7_remove_step_from_trial.py +++ /dev/null @@ -1,32 +0,0 @@ -"""remove step from trial - -Revision ID: 03410247c6b7 -Revises: c89fc8504699 -Create Date: 2026-01-16 15:17:18.057002 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = '03410247c6b7' -down_revision: Union[str, Sequence[str], None] = 'c89fc8504699' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('metrics', 'step') - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.add_column('metrics', sa.Column('step', sa.INTEGER(), autoincrement=False, nullable=False)) - # ### end Alembic commands ### diff --git a/migrations/versions/c89fc8504699_add_project_meta.py b/migrations/versions/c89fc8504699_add_project_meta.py deleted file mode 100644 index 8ce347b..0000000 --- a/migrations/versions/c89fc8504699_add_project_meta.py +++ /dev/null @@ -1,32 +0,0 @@ -"""add project meta - -Revision ID: c89fc8504699 -Revises: 648be86800d3 -Create Date: 2025-12-01 15:40:10.591556 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = 'c89fc8504699' -down_revision: Union[str, Sequence[str], None] = '648be86800d3' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.add_column('projects', sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the project')) - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_column('projects', 'meta') - # ### end Alembic commands ### diff --git a/migrations/versions/648be86800d3_init.py b/migrations/versions/c98aa69beda7_init_schema.py similarity index 89% rename from migrations/versions/648be86800d3_init.py rename to migrations/versions/c98aa69beda7_init_schema.py index 353d62a..3c11fad 100644 --- a/migrations/versions/648be86800d3_init.py +++ b/migrations/versions/c98aa69beda7_init_schema.py @@ -1,8 +1,8 @@ -"""init +"""init schema -Revision ID: 648be86800d3 +Revision ID: c98aa69beda7 Revises: -Create Date: 2025-11-27 18:12:04.027865 +Create Date: 2026-01-29 01:10:13.441947 """ from typing import Sequence, Union @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. -revision: str = '648be86800d3' +revision: str = 'c98aa69beda7' down_revision: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -23,11 +23,15 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table('experiments', sa.Column('uuid', sa.UUID(), nullable=False), + sa.Column('team_id', sa.UUID(), nullable=False), + sa.Column('project_id', sa.UUID(), nullable=False), sa.Column('name', sa.String(), nullable=False), sa.Column('description', sa.String(), nullable=True), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the experiment'), + sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the trial'), + sa.Column('params', sa.JSON(), nullable=True, comment='Parameters for the experiment'), sa.Column('kind', sa.Integer(), nullable=False, comment='Type of the experiment'), + sa.Column('duration', sa.Float(), nullable=True, comment='Duration of the trial in seconds'), + sa.Column('status', sa.Integer(), nullable=False, comment='Status of the trial, 0: UNKNOWN, 1: PENDING, 2: RUNNING, 9: COMPLETED, 10: CANCELLED, 11: FAILED'), sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), sa.Column('is_del', sa.Integer(), nullable=True, comment='0 for not deleted, 1 for deleted'), @@ -37,11 +41,10 @@ def upgrade() -> None: sa.Column('uuid', sa.UUID(), nullable=False), sa.Column('key', sa.String(), nullable=False), sa.Column('value', sa.Float(), nullable=False), + sa.Column('team_id', sa.UUID(), nullable=False), sa.Column('project_id', sa.UUID(), nullable=False), sa.Column('experiment_id', sa.UUID(), nullable=False), - sa.Column('trial_id', sa.UUID(), nullable=False), sa.Column('run_id', sa.UUID(), nullable=False), - sa.Column('step', sa.Integer(), nullable=False), sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), sa.PrimaryKeyConstraint('uuid') ) @@ -49,7 +52,7 @@ def upgrade() -> None: sa.Column('uuid', sa.UUID(), nullable=False), sa.Column('name', sa.String(), nullable=False), sa.Column('description', sa.String(), nullable=True), - sa.Column('project_id', sa.UUID(), nullable=False), + sa.Column('team_id', sa.UUID(), nullable=False), sa.Column('version', sa.String(), nullable=False), sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the model'), sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), @@ -62,6 +65,8 @@ def upgrade() -> None: sa.Column('uuid', sa.UUID(), nullable=False), sa.Column('name', sa.String(), nullable=False), sa.Column('description', sa.String(), nullable=True), + sa.Column('team_id', 
sa.UUID(), nullable=False), + sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the project'), sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), sa.Column('is_del', sa.Integer(), nullable=True, comment='0 for not deleted, 1 for deleted'), @@ -69,9 +74,9 @@ def upgrade() -> None: ) op.create_table('runs', sa.Column('uuid', sa.UUID(), nullable=False), + sa.Column('team_id', sa.UUID(), nullable=False), sa.Column('project_id', sa.UUID(), nullable=False), sa.Column('experiment_id', sa.UUID(), nullable=False), - sa.Column('trial_id', sa.UUID(), nullable=False), sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the run'), sa.Column('status', sa.Integer(), nullable=False, comment='Status of the run, 0: UNKNOWN, 1: PENDING, 2: RUNNING, 9: COMPLETED, 10: CANCELLED, 11: FAILED'), sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), @@ -79,16 +84,11 @@ def upgrade() -> None: sa.Column('is_del', sa.Integer(), nullable=True, comment='0 for not deleted, 1 for deleted'), sa.PrimaryKeyConstraint('uuid') ) - op.create_table('trials', + op.create_table('teams', sa.Column('uuid', sa.UUID(), nullable=False), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('experiment_id', sa.UUID(), nullable=False), sa.Column('name', sa.String(), nullable=False), sa.Column('description', sa.String(), nullable=True), - sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the trial'), - sa.Column('params', sa.JSON(), nullable=True, comment='Parameters for the experiment'), - sa.Column('duration', sa.Float(), nullable=True, comment='Duration of the trial in seconds'), - sa.Column('status', sa.Integer(), nullable=False, comment='Status of the trial, 0: UNKNOWN, 1: PENDING, 2: RUNNING, 9: COMPLETED, 10: CANCELLED, 11: FAILED'), + sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the team'), sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), sa.Column('is_del', sa.Integer(), nullable=True, comment='0 for not deleted, 1 for deleted'), @@ -100,7 +100,7 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table('trials') + op.drop_table('teams') op.drop_table('runs') op.drop_table('projects') op.drop_table('models') diff --git a/tests/integration/server/test_graphql_query.py b/tests/integration/server/test_graphql_query.py index a6b4713..9543b9b 100644 --- a/tests/integration/server/test_graphql_query.py +++ b/tests/integration/server/test_graphql_query.py @@ -1,3 +1,5 @@ +# ruff: noqa: E501 + # test query from graphql endpoint import uuid @@ -7,14 +9,14 @@ from alphatrion.server.graphql.schema import schema -def test_query_single_project(): +def test_query_single_team(): init(init_tables=True) metadb = graphql_runtime().metadb - id = metadb.create_project(name="Test Project", description="A project for testing") + id = metadb.create_team(name="Test Team", description="A team for testing") query = f""" query {{ - project(id: "{id}") {{ + team(id: "{id}") {{ id name description @@ -29,23 +31,24 @@ def test_query_single_project(): variable_values={}, ) assert response.errors is None - assert response.data["project"]["id"] == str(id) - assert response.data["project"]["name"] == "Test Project" + assert response.data["team"]["id"] == str(id) + assert response.data["team"]["name"] == "Test Team" -def test_query_projects(): +def test_query_teams(): init(init_tables=True) + metadb = graphql_runtime().metadb - _ = metadb.create_project( - name="Test Project1", description="A project for testing", meta={"foo": "bar"} + _ = metadb.create_team( + name="Test Team1", description="A team for testing", meta={"foo": "bar"} ) - _ = metadb.create_project( - name="Test Project2", description="A project for testing", meta={"baz": 123} + _ = metadb.create_team( + name="Test Team2", description="Another team for testing", meta={"baz": 123} ) query = """ query { - projects { + teams { id name description @@ -60,28 +63,28 @@ def test_query_projects(): variable_values={}, ) assert response.errors is None - assert len(response.data["projects"]) >= 2 + assert len(response.data["teams"]) >= 2 -def test_query_single_experiment(): +def test_query_single_project(): init(init_tables=True) - project_id = uuid.uuid4() + + team_id = uuid.uuid4() metadb = graphql_runtime().metadb - id = metadb.create_exp( - name="Test Experiment", - description="A experiment for testing", - project_id=project_id, + id = metadb.create_project( + name="Test Project", + description="A project for testing", + team_id=team_id, ) query = f""" query {{ - experiment(id: "{id}") {{ + project(id: "{id}") {{ id - projectId + teamId name description meta - kind createdAt updatedAt }} @@ -92,39 +95,38 @@ def test_query_single_experiment(): variable_values={}, ) assert response.errors is None - assert response.data["experiment"]["id"] == str(id) - assert response.data["experiment"]["name"] == "Test Experiment" + assert response.data["project"]["id"] == str(id) + assert response.data["project"]["name"] == "Test Project" -def test_query_experiments(): +def test_query_projects(): init(init_tables=True) - project_id = uuid.uuid4() + team_id = uuid.uuid4() metadb = graphql_runtime().metadb - _ = metadb.create_exp( - name="Test Experiment1", - description="A experiment for testing", - project_id=project_id, + _ = metadb.create_project( + name="Test Project1", + description="A project for testing", + team_id=team_id, ) - _ = metadb.create_exp( - name="Test Experiment2", - description="A experiment for testing", - project_id=project_id, + _ = metadb.create_project( + name="Test Project2", + description="A project for testing", + team_id=team_id, ) - _ 
= metadb.create_exp( - name="Test Experiment2", - description="A experiment for testing", - project_id=uuid.uuid4(), + _ = metadb.create_project( + name="Test Project2", + description="A project for testing", + team_id=uuid.uuid4(), ) query = f""" query {{ - experiments(projectId: "{project_id}", page: 0, pageSize: 10) {{ + projects(teamId: "{team_id}", page: 0, pageSize: 10) {{ id - projectId + teamId name description meta - kind createdAt updatedAt }} @@ -135,33 +137,34 @@ def test_query_experiments(): variable_values={}, ) assert response.errors is None - assert len(response.data["experiments"]) == 2 + assert len(response.data["projects"]) == 2 -def test_query_single_trial(): +def test_query_single_exp(): init(init_tables=True) + team_id = uuid.uuid4() project_id = uuid.uuid4() - experiment_id = uuid.uuid4() metadb = graphql_runtime().metadb - trial_id = metadb.create_trial( - name="Test Trial", + exp_id = metadb.create_experiment( + name="Test Experiment", + team_id=team_id, project_id=project_id, - experiment_id=experiment_id, status=Status.RUNNING, meta={}, ) query = f""" query {{ - trial(id: "{trial_id}") {{ + experiment(id: "{exp_id}") {{ id + teamId projectId - experimentId meta params duration status + kind createdAt updatedAt }} @@ -172,38 +175,38 @@ def test_query_single_trial(): variable_values={}, ) assert response.errors is None - assert "trial" in response.data - assert response.data["trial"]["id"] == str(trial_id) - assert response.data["trial"]["experimentId"] == str(experiment_id) - assert response.data["trial"]["projectId"] == str(project_id) + assert "experiment" in response.data + assert response.data["experiment"]["id"] == str(exp_id) + assert response.data["experiment"]["projectId"] == str(project_id) -def test_query_trials(): +def test_query_experiments(): init(init_tables=True) + team_id = uuid.uuid4() project_id = uuid.uuid4() - experiment_id = uuid.uuid4() metadb = graphql_runtime().metadb - _ = metadb.create_trial( - name="Test Trial1", - experiment_id=experiment_id, + _ = metadb.create_experiment( + name="Test Experiment1", + team_id=team_id, project_id=project_id, ) - _ = metadb.create_trial( - name="Test Trial2", - experiment_id=experiment_id, + _ = metadb.create_experiment( + name="Test Experiment2", + team_id=team_id, project_id=project_id, ) query = f""" query {{ - trials(experimentId: "{experiment_id}", page: 0, pageSize: 10) {{ + experiments(projectId: "{project_id}", page: 0, pageSize: 10) {{ id + teamId projectId - experimentId name description params duration + kind status createdAt updatedAt @@ -215,26 +218,26 @@ def test_query_trials(): variable_values={}, ) assert response.errors is None - assert len(response.data["trials"]) == 2 + assert len(response.data["experiments"]) == 2 def test_query_single_run(): init(init_tables=True) + team_id = uuid.uuid4() project_id = uuid.uuid4() - trial_id = uuid.uuid4() exp_id = uuid.uuid4() metadb = graphql_runtime().metadb run_id = metadb.create_run( + team_id=team_id, project_id=project_id, experiment_id=exp_id, - trial_id=trial_id, ) response = schema.execute_sync( f""" query {{ run(id: "{run_id}") {{ id - trialId + teamId projectId experimentId meta @@ -247,35 +250,35 @@ def test_query_single_run(): ) assert response.errors is None assert response.data["run"]["id"] == str(run_id) + assert response.data["run"]["teamId"] == str(team_id) assert response.data["run"]["projectId"] == str(project_id) assert response.data["run"]["experimentId"] == str(exp_id) - assert response.data["run"]["trialId"] == str(trial_id) 
def test_query_runs(): init(init_tables=True) + team_id = uuid.uuid4() project_id = uuid.uuid4() exp_id = uuid.uuid4() - trial_id = uuid.uuid4() metadb = graphql_runtime().metadb _ = metadb.create_run( + team_id=team_id, project_id=project_id, experiment_id=exp_id, - trial_id=trial_id, ) _ = metadb.create_run( + team_id=team_id, project_id=project_id, experiment_id=exp_id, - trial_id=trial_id, ) query = f""" query {{ - runs(trialId: "{trial_id}", page: 0, pageSize: 10) {{ + runs(experimentId: "{exp_id}", page: 0, pageSize: 10) {{ id - trialId - experimentId + teamId projectId + experimentId meta status createdAt @@ -292,36 +295,36 @@ def test_query_runs(): def test_query_trial_metrics(): init(init_tables=True) + team_id = uuid.uuid4() project_id = uuid.uuid4() experiment_id = uuid.uuid4() - trial_id = uuid.uuid4() metadb = graphql_runtime().metadb _ = metadb.create_metric( + team_id=team_id, project_id=project_id, experiment_id=experiment_id, - trial_id=trial_id, run_id=uuid.uuid4(), key="accuracy", value=0.95, ) _ = metadb.create_metric( + team_id=team_id, project_id=project_id, experiment_id=experiment_id, - trial_id=trial_id, run_id=uuid.uuid4(), key="accuracy", value=0.95, ) query = f""" query {{ - trialMetrics(trialId: "{trial_id}") {{ + trialMetrics(experimentId: "{experiment_id}") {{ id key value + teamId projectId experimentId - trialId runId createdAt }} @@ -334,6 +337,6 @@ def test_query_trial_metrics(): assert response.errors is None assert len(response.data["trialMetrics"]) == 2 for metric in response.data["trialMetrics"]: + assert metric["teamId"] == str(team_id) assert metric["projectId"] == str(project_id) assert metric["experimentId"] == str(experiment_id) - assert metric["trialId"] == str(trial_id) diff --git a/tests/integration/test_artifact.py b/tests/integration/test_artifact.py index 6088922..604a21b 100644 --- a/tests/integration/test_artifact.py +++ b/tests/integration/test_artifact.py @@ -11,14 +11,14 @@ @pytest.fixture def artifact(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) artifact = global_runtime()._artifact yield artifact def test_push_with_files(artifact): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) @@ -41,7 +41,7 @@ def test_push_with_files(artifact): def test_push_with_folder(artifact): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) diff --git a/tests/integration/test_craft_experiment.py b/tests/integration/test_craft_experiment.py index 783bc26..1150887 100644 --- a/tests/integration/test_craft_experiment.py +++ b/tests/integration/test_craft_experiment.py @@ -7,33 +7,33 @@ @pytest.mark.asyncio -async def test_integration_craft_experiment(): - trial_id = None +async def test_integration_project(): + exp_id = None async def fake_work(duration: int): await asyncio.sleep(duration) print("duration done:", duration) - async with alpha.CraftExperiment.setup( - name="integration_test_exp", - description="Integration test for CraftExperiment", - meta={"test_case": "integration_craft_experiment"}, - ) as exp: - async with exp.start_trial( - name="integration_test_trial", - description="Trial for integration test", - meta={"trial_case": 
"integration_craft_trial"}, - config=alpha.TrialConfig(max_runs_per_trial=2), - ) as trial: - trial_id = trial.id - - trial.start_run(lambda: fake_work(1)) - trial.start_run(lambda: fake_work(2)) - trial.start_run(lambda: fake_work(4)) - trial.start_run(lambda: fake_work(5)) - trial.start_run(lambda: fake_work(6)) - - await trial.wait() + async with alpha.Project.setup( + name="integration_test_project", + description="Integration test for Project", + meta={"test_case": "integration_project"}, + ) as proj: + async with proj.start_experiment( + name="integration_test_experiment", + description="Experiment for integration test", + meta={"experiment_case": "integration_project_experiment"}, + config=alpha.ExperimentConfig(max_runs_per_experiment=2), + ) as exp: + exp_id = exp.id + + exp.start_run(lambda: fake_work(1)) + exp.start_run(lambda: fake_work(2)) + exp.start_run(lambda: fake_work(4)) + exp.start_run(lambda: fake_work(5)) + exp.start_run(lambda: fake_work(6)) + + await exp.wait() runtime = global_runtime() @@ -41,7 +41,7 @@ async def fake_work(duration: int): # Or the result below will always be right. await asyncio.sleep(1) - runs = runtime.metadb.list_runs_by_trial_id(trial_id=trial_id) + runs = runtime.metadb.list_runs_by_exp_id(exp_id=exp_id) assert len(runs) == 5 completed_runs = [run for run in runs if run.status == alpha.Status.COMPLETED] assert len(completed_runs) == 2 diff --git a/tests/integration/test_log.py b/tests/integration/test_log.py index 13e6e1a..9166c2d 100644 --- a/tests/integration/test_log.py +++ b/tests/integration/test_log.py @@ -8,24 +8,24 @@ import pytest import alphatrion as alpha +from alphatrion.experiment.experiment import current_exp_id from alphatrion.log.log import ARTIFACT_PATH from alphatrion.metadata.sql_models import Status -from alphatrion.trial.trial import current_trial_id @pytest.mark.asyncio async def test_log_artifact(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - async with alpha.CraftExperiment.setup( - name="log_artifact_exp", + async with alpha.Project.setup( + name="log_artifact_project", description="Context manager test", meta={"key": "value"}, - ) as exp: - trial = exp.start_trial(name="first-trial") + ) as proj: + exp = proj.start_experiment(name="first-exp") - exp_obj = exp._runtime._metadb.get_exp(exp_id=exp._id) - assert exp_obj is not None + proj_obj = proj._runtime._metadb.get_project(project_id=proj._id) + assert proj_obj is not None with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) @@ -35,7 +35,7 @@ async def test_log_artifact(): f.write("This is file1.") await alpha.log_artifact(paths="file1.txt", version="v1") - versions = exp._runtime._artifact.list_versions(exp_obj.uuid) + versions = exp._runtime._artifact.list_versions(proj_obj.uuid) assert "v1" in versions with open("file1.txt", "w") as f: @@ -43,76 +43,80 @@ async def test_log_artifact(): # push folder instead await alpha.log_artifact(paths=["file1.txt"], version="v2") - versions = exp._runtime._artifact.list_versions(exp_obj.uuid) + versions = exp._runtime._artifact.list_versions(proj_obj.uuid) assert "v2" in versions exp._runtime._artifact.delete( - repo_name=exp_obj.uuid, + repo_name=proj_obj.uuid, versions=["v1", "v2"], ) - versions = exp._runtime._artifact.list_versions(exp_obj.uuid) + versions = exp._runtime._artifact.list_versions(proj_obj.uuid) assert len(versions) == 0 - trial.done() + exp.done() - got_exp = 
exp._runtime._metadb.get_exp(exp_id=exp._id) - assert got_exp is not None - assert got_exp.name == "log_artifact_exp" + got_proj = proj._runtime._metadb.get_project(project_id=proj._id) + assert got_proj is not None + assert got_proj.name == "log_artifact_project" - got_trial = exp._runtime._metadb.get_trial(trial_id=trial._id) - assert got_trial is not None - assert got_trial.name == "first-trial" - assert got_trial.status == Status.COMPLETED + got_exp = proj._runtime._metadb.get_experiment(experiment_id=exp.id) + assert got_exp is not None + assert got_exp.name == "first-exp" + assert got_exp.status == Status.COMPLETED @pytest.mark.asyncio async def test_log_params(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - async with alpha.CraftExperiment.setup(name="log_params_exp") as exp: - trial = exp.start_trial(name="first-trial", params={"param1": 0.1}) + async with alpha.Project.setup(name="log_params_proj") as proj: + exp = proj.start_experiment(name="first-exp", params={"param1": 0.1}) - new_trial = exp._runtime._metadb.get_trial(trial_id=trial.id) - assert new_trial is not None - assert new_trial.params == {"param1": 0.1} + new_exp = proj._runtime._metadb.get_experiment(experiment_id=exp.id) + assert new_exp is not None + assert new_exp.params == {"param1": 0.1} params = {"param1": 0.2} await alpha.log_params(params=params) - new_trial = exp._runtime._metadb.get_trial(trial_id=trial.id) - assert new_trial is not None - assert new_trial.params == {"param1": 0.2} - assert new_trial.status == Status.RUNNING - assert current_trial_id.get() == trial.id + new_exp = exp._runtime._metadb.get_experiment(experiment_id=exp.id) + assert new_exp is not None + assert new_exp.params == {"param1": 0.2} + assert new_exp.status == Status.RUNNING + assert current_exp_id.get() == exp.id - trial.done() + exp.done() - trial = exp.start_trial(name="second-trial", params={"param1": 0.1}) - assert current_trial_id.get() == trial.id - trial.done() + exp = proj.start_experiment(name="second-exp", params={"param1": 0.1}) + assert current_exp_id.get() == exp.id + exp.done() @pytest.mark.asyncio async def test_log_metrics(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) async def log_metric(metrics: dict): await alpha.log_metrics(metrics) - async with alpha.CraftExperiment.setup(name="log_metrics_exp") as exp: - trial = exp.start_trial(name="first-trial", params={"param1": 0.1}) + async with alpha.Project.setup(name="log_metrics_exp") as proj: + exp = proj.start_experiment(name="first-exp", params={"param1": 0.1}) - new_trial = exp._runtime._metadb.get_trial(trial_id=trial._id) - assert new_trial is not None - assert new_trial.params == {"param1": 0.1} + new_exp = exp._runtime._metadb.get_experiment(experiment_id=exp.id) + assert new_exp is not None + assert new_exp.params == {"param1": 0.1} - metrics = exp._runtime._metadb.list_metrics_by_trial_id(trial_id=trial._id) + metrics = proj._runtime._metadb.list_metrics_by_experiment_id( + experiment_id=exp.id + ) assert len(metrics) == 0 - run = trial.start_run(lambda: log_metric({"accuracy": 0.95, "loss": 0.1})) + run = exp.start_run(lambda: log_metric({"accuracy": 0.95, "loss": 0.1})) await run.wait() - metrics = exp._runtime._metadb.list_metrics_by_trial_id(trial_id=trial._id) + metrics = exp._runtime._metadb.list_metrics_by_experiment_id( + 
experiment_id=exp.id + ) assert len(metrics) == 2 assert metrics[0].key == "accuracy" assert metrics[0].value == 0.95 @@ -122,10 +126,12 @@ async def log_metric(metrics: dict): assert run_id_1 is not None assert metrics[0].run_id == metrics[1].run_id - run = trial.start_run(lambda: log_metric({"accuracy": 0.96})) + run = exp.start_run(lambda: log_metric({"accuracy": 0.96})) await run.wait() - metrics = exp._runtime._metadb.list_metrics_by_trial_id(trial_id=trial._id) + metrics = proj._runtime._metadb.list_metrics_by_experiment_id( + experiment_id=exp.id + ) assert len(metrics) == 3 assert metrics[2].key == "accuracy" assert metrics[2].value == 0.96 @@ -133,13 +139,13 @@ async def log_metric(metrics: dict): assert run_id_2 is not None assert run_id_2 != run_id_1 - trial.done() + exp.done() @pytest.mark.asyncio async def test_log_metrics_with_save_on_max(): - project_id = uuid.uuid4() - alpha.init(project_id=project_id, artifact_insecure=True, init_tables=True) + team_id = uuid.uuid4() + alpha.init(team_id=team_id, artifact_insecure=True, init_tables=True) async def log_metric(value: float): await alpha.log_metrics({"accuracy": value}) @@ -150,11 +156,11 @@ def find_unused_version(used_versions, all_versions): return v return None - async with alpha.CraftExperiment.setup( + async with alpha.Project.setup( name="log_metrics_with_save_on_max", description="Context manager test", meta={"key": "value"}, - ) as exp: + ) as proj: with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) file = "file.txt" @@ -163,9 +169,9 @@ def pre_save_hook(): with open(file, "a") as f: f.write("This is pre_save_hook modified file.\n") - trial = exp.start_trial( - name="trial-with-save_on_best", - config=alpha.TrialConfig( + exp = proj.start_experiment( + name="exp-with-save_on_best", + config=alpha.ExperimentConfig( checkpoint=alpha.CheckpointConfig( enabled=True, path=tmpdir, @@ -181,19 +187,19 @@ def pre_save_hook(): with open(file, "w") as f: f.write("This is file.\n") - run = trial.start_run(lambda: log_metric(0.90)) + run = exp.start_run(lambda: log_metric(0.90)) await run.wait() # We need this because the returned version is unordered. 
used_version = [] - versions = exp._runtime._artifact.list_versions(exp.id) + versions = proj._runtime._artifact.list_versions(proj.id) assert len(versions) == 1 run_obj = run._get_obj() fixed_version = versions[0] used_version.append(fixed_version) assert ( - run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version + run_obj.meta[ARTIFACT_PATH] == f"{team_id}/{proj.id}:" + fixed_version ) with open(file) as f: assert len(f.readlines()) == 2 @@ -201,25 +207,25 @@ def pre_save_hook(): # To avoid the same timestamp hash, we wait for 1 second time.sleep(1) - run = trial.start_run(lambda: log_metric(0.78)) + run = exp.start_run(lambda: log_metric(0.78)) await run.wait() - versions = exp._runtime._artifact.list_versions(exp.id) + versions = proj._runtime._artifact.list_versions(proj.id) assert len(versions) == 1 time.sleep(1) - run = trial.start_run(lambda: log_metric(0.91)) + run = exp.start_run(lambda: log_metric(0.91)) await run.wait() - versions = exp._runtime._artifact.list_versions(exp.id) + versions = proj._runtime._artifact.list_versions(proj.id) assert len(versions) == 2 fixed_version = find_unused_version(used_version, versions) used_version.append(fixed_version) run_obj = run._get_obj() assert ( - run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version + run_obj.meta[ARTIFACT_PATH] == f"{team_id}/{proj.id}:" + fixed_version ) with open(file) as f: @@ -227,42 +233,42 @@ def pre_save_hook(): time.sleep(1) - run = trial.start_run(lambda: log_metric(0.98)) + run = exp.start_run(lambda: log_metric(0.98)) await run.wait() - versions = exp._runtime._artifact.list_versions(exp.id) + versions = proj._runtime._artifact.list_versions(proj.id) assert len(versions) == 3 run_obj = run._get_obj() fixed_version = find_unused_version(used_version, versions) used_version.append(fixed_version) assert ( - run_obj.meta[ARTIFACT_PATH] == f"{project_id}/{exp.id}:" + fixed_version + run_obj.meta[ARTIFACT_PATH] == f"{team_id}/{proj.id}:" + fixed_version ) with open(file) as f: assert len(f.readlines()) == 4 - trial.done() + exp.done() @pytest.mark.asyncio async def test_log_metrics_with_save_on_min(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) async def log_metric(value: float): await alpha.log_metrics({"accuracy": value}) - async with alpha.CraftExperiment.setup( + async with alpha.Project.setup( name="log_metrics_with_save_on_min", description="Context manager test", meta={"key": "value"}, - ) as exp: + ) as proj: with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) - trial = exp.start_trial( - name="trial-with-save_on_best", - config=alpha.TrialConfig( + exp = proj.start_experiment( + name="exp-with-save_on_best", + config=alpha.ExperimentConfig( checkpoint=alpha.CheckpointConfig( enabled=True, path=tmpdir, @@ -277,42 +283,42 @@ async def log_metric(value: float): with open(file1, "w") as f: f.write("This is file1.") - run = trial.start_run(lambda: log_metric(0.30)) + run = exp.start_run(lambda: log_metric(0.30)) await run.wait() - versions = exp._runtime._artifact.list_versions(exp.id) + versions = proj._runtime._artifact.list_versions(proj.id) assert len(versions) == 1 # To avoid the same timestamp hash, we wait for 1 second time.sleep(1) - run = trial.start_run(lambda: log_metric(0.58)) + run = exp.start_run(lambda: log_metric(0.58)) await run.wait() - versions = exp._runtime._artifact.list_versions(exp.id) + versions = 
proj._runtime._artifact.list_versions(proj.id) assert len(versions) == 1 time.sleep(1) - run = trial.start_run(lambda: log_metric(0.21)) + run = exp.start_run(lambda: log_metric(0.21)) await run.wait() - versions = exp._runtime._artifact.list_versions(exp.id) + versions = exp._runtime._artifact.list_versions(proj.id) assert len(versions) == 2 time.sleep(1) - task = trial.start_run(lambda: log_metric(0.18)) + task = exp.start_run(lambda: log_metric(0.18)) await task.wait() - versions = exp._runtime._artifact.list_versions(exp.id) + versions = proj._runtime._artifact.list_versions(str(proj.id)) assert len(versions) == 3 - trial.done() + exp.done() @pytest.mark.asyncio async def test_log_metrics_with_early_stopping(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) @@ -321,38 +327,40 @@ async def fake_sleep(value: float): await asyncio.sleep(100) await alpha.log_metrics({"accuracy": value}) - async with alpha.CraftExperiment.setup( - name="log_metrics_with_early_stopping" - ) as exp: - async with exp.start_trial( - name="trial-with-early-stopping", - config=alpha.TrialConfig( + async with alpha.Project.setup(name="log_metrics_with_early_stopping") as proj: + async with proj.start_experiment( + name="exp-with-early-stopping", + config=alpha.ExperimentConfig( monitor_metric="accuracy", early_stopping_runs=2, ), - ) as trial: - trial.start_run(lambda: fake_work(0.5)) - trial.start_run(lambda: fake_work(0.6)) - trial.start_run(lambda: fake_work(0.2)) - trial.start_run(lambda: fake_work(0.7)) - trial.start_run(lambda: fake_sleep(0.2)) + ) as exp: + exp.start_run(lambda: fake_work(0.5)) + exp.start_run(lambda: fake_work(0.6)) + exp.start_run(lambda: fake_work(0.2)) + exp.start_run(lambda: fake_work(0.7)) + exp.start_run(lambda: fake_sleep(0.2)) # The first run that is worse than 0.6 - trial.start_run(lambda: fake_work(0.4)) + exp.start_run(lambda: fake_work(0.4)) # The second run that is worse than 0.6, should trigger early stopping - trial.start_run(lambda: fake_work(0.1)) - trial.start_run(lambda: fake_work(0.2)) + exp.start_run(lambda: fake_work(0.1)) + exp.start_run(lambda: fake_work(0.2)) # trigger early stopping - await trial.wait() + await exp.wait() assert ( - len(trial._runtime._metadb.list_metrics_by_trial_id(trial_id=trial.id)) + len( + exp._runtime._metadb.list_metrics_by_experiment_id( + experiment_id=exp.id + ) + ) == 6 ) @pytest.mark.asyncio async def test_log_metrics_with_early_stopping_never_triggered(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) @@ -361,26 +369,30 @@ async def fake_sleep(value: float): await asyncio.sleep(value) await alpha.log_metrics({"accuracy": value}) - async with alpha.CraftExperiment.setup( + async with alpha.Project.setup( name="log_metrics_with_both_early_stopping_and_timeout" - ) as exp: - async with exp.start_trial( - name="trial-with-early-stopping", - config=alpha.TrialConfig( + ) as proj: + async with proj.start_experiment( + name="exp-with-early-stopping", + config=alpha.ExperimentConfig( monitor_metric="accuracy", early_stopping_runs=3, max_execution_seconds=3, ), - ) as trial: + ) as exp: start_time = datetime.now() - trial.start_run(lambda: fake_work(1)) - 
trial.start_run(lambda: fake_work(2)) - trial.start_run(lambda: fake_sleep(2)) + exp.start_run(lambda: fake_work(1)) + exp.start_run(lambda: fake_work(2)) + exp.start_run(lambda: fake_sleep(2)) # running in parallel. - await trial.wait() + await exp.wait() assert ( - len(trial._runtime._metadb.list_metrics_by_trial_id(trial_id=trial.id)) + len( + exp._runtime._metadb.list_metrics_by_experiment_id( + experiment_id=exp.id + ) + ) == 3 ) assert datetime.now() - start_time >= timedelta(seconds=3) @@ -388,34 +400,36 @@ async def fake_sleep(value: float): @pytest.mark.asyncio async def test_log_metrics_with_max_run_number(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) - async with alpha.CraftExperiment.setup( - name="log_metrics_with_max_run_number" - ) as exp: - async with exp.start_trial( - name="trial-with-max-run-number", - config=alpha.TrialConfig( + async with alpha.Project.setup(name="log_metrics_with_max_run_number") as proj: + async with proj.start_experiment( + name="exp-with-max-run-number", + config=alpha.ExperimentConfig( monitor_metric="accuracy", - max_runs_per_trial=5, + max_runs_per_experiment=5, ), - ) as trial: - while not trial.is_done(): - run = trial.start_run(lambda: fake_work(1)) + ) as exp: + while not exp.is_done(): + run = exp.start_run(lambda: fake_work(1)) await run.wait() assert ( - len(trial._runtime._metadb.list_metrics_by_trial_id(trial_id=trial.id)) + len( + exp._runtime._metadb.list_metrics_by_experiment_id( + experiment_id=exp.id + ) + ) == 5 ) @pytest.mark.asyncio async def test_log_metrics_with_max_target_meet(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) @@ -424,31 +438,33 @@ async def fake_sleep(value: float): await asyncio.sleep(10) await alpha.log_metrics({"accuracy": value}) - async with alpha.CraftExperiment.setup( - name="log_metrics_with_max_target_meet" - ) as exp: - async with exp.start_trial( - name="trial-with-max-target-meet", - config=alpha.TrialConfig( + async with alpha.Project.setup(name="log_metrics_with_max_target_meet") as proj: + async with proj.start_experiment( + name="exp-with-max-target-meet", + config=alpha.ExperimentConfig( monitor_metric="accuracy", target_metric_value=0.9, ), - ) as trial: - trial.start_run(lambda: fake_work(0.5)) - trial.start_run(lambda: fake_work(0.3)) - trial.start_run(lambda: fake_sleep(0.4)) - trial.start_run(lambda: fake_work(0.9)) - await trial.wait() + ) as exp: + exp.start_run(lambda: fake_work(0.5)) + exp.start_run(lambda: fake_work(0.3)) + exp.start_run(lambda: fake_sleep(0.4)) + exp.start_run(lambda: fake_work(0.9)) + await exp.wait() assert ( - len(trial._runtime._metadb.list_metrics_by_trial_id(trial_id=trial.id)) + len( + exp._runtime._metadb.list_metrics_by_experiment_id( + experiment_id=exp.id + ) + ) == 3 ) @pytest.mark.asyncio async def test_log_metrics_with_min_target_meet(): - alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) @@ -457,24 +473,26 @@ async def fake_sleep(value: float): await asyncio.sleep(3) await 
alpha.log_metrics({"accuracy": value}) - async with alpha.CraftExperiment.setup( - name="log_metrics_with_min_target_meet" - ) as exp: - async with exp.start_trial( - name="trial-with-min-target-meet", - config=alpha.TrialConfig( + async with alpha.Project.setup(name="log_metrics_with_min_target_meet") as proj: + async with proj.start_experiment( + name="exp-with-min-target-meet", + config=alpha.ExperimentConfig( monitor_metric="accuracy", target_metric_value=0.2, monitor_mode=alpha.MonitorMode.MIN, ), - ) as trial: - trial.start_run(lambda: fake_work(0.5)) - trial.start_run(lambda: fake_work(0.3)) - trial.start_run(lambda: fake_sleep(0.4)) - trial.start_run(lambda: fake_work(0.2)) - await trial.wait() + ) as exp: + exp.start_run(lambda: fake_work(0.5)) + exp.start_run(lambda: fake_work(0.3)) + exp.start_run(lambda: fake_sleep(0.4)) + exp.start_run(lambda: fake_work(0.2)) + await exp.wait() assert ( - len(trial._runtime._metadb.list_metrics_by_trial_id(trial_id=trial.id)) + len( + exp._runtime._metadb.list_metrics_by_experiment_id( + experiment_id=exp.id + ) + ) == 3 ) diff --git a/tests/integration/test_tracing.py b/tests/integration/test_tracing.py index 7f99f10..872ec8d 100644 --- a/tests/integration/test_tracing.py +++ b/tests/integration/test_tracing.py @@ -52,9 +52,9 @@ async def joke_workflow(): @pytest.mark.asyncio async def test_workflow(): - async with alpha.CraftExperiment.setup("demo_joke_workflow") as exp: - async with exp.start_trial("demo_joke_trial") as trial: - task = trial.start_run(lambda: joke_workflow()) - await task.wait() + async with alpha.Project.setup("demo_joke_workflow") as proj: + async with proj.start_experiment("demo_joke_experiment") as exp: + run = exp.start_run(lambda: joke_workflow()) + await run.wait() - assert exp.get_trial(trial.id) is None + assert proj.get_experiment(exp.id) is None diff --git a/tests/unit/artifact/test_artifact.py b/tests/unit/artifact/test_artifact.py index a4353ac..791f363 100644 --- a/tests/unit/artifact/test_artifact.py +++ b/tests/unit/artifact/test_artifact.py @@ -10,7 +10,7 @@ @pytest.fixture def artifact(): - init(project_id=uuid.uuid4(), artifact_insecure=True) + init(team_id=uuid.uuid4(), artifact_insecure=True) artifact = global_runtime()._artifact yield artifact diff --git a/tests/unit/experiment/test_craft_exp.py b/tests/unit/experiment/test_craft_exp.py deleted file mode 100644 index 9a76871..0000000 --- a/tests/unit/experiment/test_craft_exp.py +++ /dev/null @@ -1,313 +0,0 @@ -import asyncio -import random -import uuid -from datetime import datetime, timedelta - -import pytest - -from alphatrion.experiment.craft_exp import CraftExperiment, ExperimentConfig -from alphatrion.metadata.sql_models import Status -from alphatrion.runtime.runtime import global_runtime, init -from alphatrion.trial.trial import Trial, TrialConfig, current_trial_id - - -@pytest.mark.asyncio -async def test_craft_experiment(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async with CraftExperiment.setup( - name="context_exp", - description="Context manager test", - meta={"key": "value"}, - ) as exp: - exp1 = exp._get() - assert exp1 is not None - assert exp1.name == "context_exp" - assert exp1.description == "Context manager test" - - trial = exp.start_trial(name="first-trial") - trial_obj = trial._get_obj() - assert trial_obj is not None - assert trial_obj.name == "first-trial" - - trial.done() - - trial_obj = trial._get_obj() - assert trial_obj.duration is not None - assert trial_obj.status == Status.COMPLETED - 
- -@pytest.mark.asyncio -async def test_craft_experiment_with_done(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - trial_id = None - async with CraftExperiment.setup( - name="context_exp", - description="Context manager test", - meta={"key": "value"}, - ) as exp: - trial = exp.start_trial(name="first-trial") - trial_id = trial.id - - # exit the exp context, trial should be done automatically - trial_obj = global_runtime()._metadb.get_trial(trial_id=trial_id) - assert trial_obj.duration is not None - assert trial_obj.status == Status.COMPLETED - - -@pytest.mark.asyncio -async def test_craft_experiment_with_done_with_err(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - trial_id = None - async with CraftExperiment.setup( - name="context_exp", - description="Context manager test", - meta={"key": "value"}, - ) as exp: - trial = exp.start_trial(name="first-trial") - trial_id = trial.id - trial.done_with_err() - - # exit the exp context, trial should be done automatically - trial_obj = global_runtime()._metadb.get_trial(trial_id=trial_id) - assert trial_obj.duration is not None - assert trial_obj.status == Status.FAILED - - -@pytest.mark.asyncio -async def test_craft_experiment_with_no_context(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async def fake_work(trial: Trial): - await asyncio.sleep(3) - trial.done() - - exp = CraftExperiment.setup(name="no_context_exp") - async with exp.start_trial(name="first-trial") as trial: - trial.start_run(lambda: fake_work(trial)) - await trial.wait() - - trial_obj = trial._get_obj() - assert trial_obj.duration is not None - assert trial_obj.status == Status.COMPLETED - - -@pytest.mark.asyncio -async def test_create_experiment_with_trial(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - trial_id = None - async with CraftExperiment.setup(name="context_exp") as exp: - async with exp.start_trial(name="first-trial") as trial: - trial_obj = trial._get_obj() - assert trial_obj is not None - assert trial_obj.name == "first-trial" - trial_id = current_trial_id.get() - - trial_obj = exp._runtime._metadb.get_trial(trial_id=trial_id) - assert trial_obj.status == Status.COMPLETED - - -@pytest.mark.asyncio -async def test_create_experiment_with_trial_wait(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async def fake_work(trial: Trial): - await asyncio.sleep(3) - trial.done() - - trial_id = None - async with CraftExperiment.setup(name="context_exp") as exp: - async with exp.start_trial(name="first-trial") as trial: - trial_id = current_trial_id.get() - start_time = datetime.now() - - asyncio.create_task(fake_work(trial)) - assert datetime.now() - start_time <= timedelta(seconds=1) - - await trial.wait() - assert datetime.now() - start_time >= timedelta(seconds=3) - - trial_obj = exp._runtime._metadb.get_trial(trial_id=trial_id) - assert trial_obj.status == Status.COMPLETED - - -@pytest.mark.asyncio -async def test_create_experiment_with_run(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async def fake_work(cancel_func: callable, trial_id: uuid.UUID): - assert current_trial_id.get() == trial_id - await asyncio.sleep(3) - cancel_func() - - async with ( - CraftExperiment.setup(name="context_exp") as exp, - exp.start_trial(name="first-trial") as trial, - ): - start_time = datetime.now() - - trial.start_run(lambda: fake_work(trial.done, trial.id)) - assert len(trial._runs) == 1 - - 
trial.start_run(lambda: fake_work(trial.done, trial.id)) - assert len(trial._runs) == 2 - - await trial.wait() - assert datetime.now() - start_time >= timedelta(seconds=3) - assert len(trial._runs) == 0 - - -@pytest.mark.asyncio -async def test_create_experiment_with_run_cancelled(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async def fake_work(timeout: int): - await asyncio.sleep(timeout) - - async with ( - CraftExperiment.setup(name="context_exp") as exp, - exp.start_trial( - name="first-trial", config=TrialConfig(max_execution_seconds=2) - ) as trial, - ): - run_0 = trial.start_run(lambda: fake_work(1)) - run_1 = trial.start_run(lambda: fake_work(4)) - run_2 = trial.start_run(lambda: fake_work(5)) - run_3 = trial.start_run(lambda: fake_work(6)) - # At this point, 4 runs are started. - assert len(trial._runs) == 4 - await trial.wait() - - run_0_obj = run_0._get_obj() - assert run_0_obj.status == Status.COMPLETED - run_1_obj = run_1._get_obj() - assert run_1_obj.status == Status.CANCELLED - run_2_obj = run_2._get_obj() - assert run_2_obj.status == Status.CANCELLED - run_3_obj = run_3._get_obj() - assert run_3_obj.status == Status.CANCELLED - - -@pytest.mark.asyncio -async def test_craft_experiment_with_max_execution_seconds(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async with CraftExperiment.setup( - name="context_exp", - description="Context manager test", - meta={"key": "value"}, - ) as exp: - trial = exp.start_trial( - name="first-trial", config=TrialConfig(max_execution_seconds=2) - ) - await trial.wait() - assert trial.is_done() - - trial = trial._get_obj() - assert trial.status == Status.COMPLETED - - -@pytest.mark.asyncio -async def test_craft_experiment_with_multi_trials_in_parallel(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async def fake_work(): - exp = global_runtime().current_exp - - duration = random.randint(1, 5) - trial = exp.start_trial( - name="first-trial", config=TrialConfig(max_execution_seconds=duration) - ) - # double check current trial id. - assert trial.id == current_trial_id.get() - - await trial.wait() - assert trial.is_done() - # we don't reset the current trial id. 
- assert trial.id == current_trial_id.get() - - trial = trial._get_obj() - assert trial.status == Status.COMPLETED - - async with CraftExperiment.setup( - name="context_exp", - description="Context manager test", - meta={"key": "value"}, - ): - await asyncio.gather( - fake_work(), - fake_work(), - fake_work(), - ) - print("All trials finished.") - - -@pytest.mark.asyncio -async def test_craft_experiment_with_config(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async with CraftExperiment.setup( - name="context_exp", - description="Context manager test", - meta={"key": "value"}, - config=ExperimentConfig(max_execution_seconds=2), - ) as exp: - trial = exp.start_trial(name="first-trial") - await trial.wait() - assert trial.is_done() - - trial = trial._get_obj() - assert trial.status == Status.COMPLETED - - -@pytest.mark.asyncio -async def test_craft_experiment_with_hierarchy_timeout(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - async with CraftExperiment.setup( - name="context_exp", - description="Context manager test", - meta={"key": "value"}, - config=ExperimentConfig(max_execution_seconds=2), - ) as exp: - start_time = datetime.now() - trial = exp.start_trial( - name="first-trial", config=TrialConfig(max_execution_seconds=5) - ) - await trial.wait() - assert trial.is_done() - - assert (datetime.now() - start_time).total_seconds() >= 2 - assert (datetime.now() - start_time).total_seconds() < 5 - - trial = trial._get_obj() - assert trial.status == Status.COMPLETED - - -@pytest.mark.asyncio -async def test_craft_experiment_with_hierarchy_timeout_2(): - init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) - - start_time = datetime.now() - - async with CraftExperiment.setup( - name="context_exp", - description="Context manager test", - meta={"key": "value"}, - config=ExperimentConfig(max_execution_seconds=5), - ) as exp: - trial = exp.start_trial( - name="first-trial", config=TrialConfig(max_execution_seconds=2) - ) - await trial.wait() - assert trial.is_done() - - assert (datetime.now() - start_time).total_seconds() >= 2 - - trial = trial._get_obj() - assert trial.status == Status.COMPLETED - - assert (datetime.now() - start_time).total_seconds() < 5 diff --git a/tests/unit/trial/test_trial.py b/tests/unit/experiment/test_experimant.py similarity index 75% rename from tests/unit/trial/test_trial.py rename to tests/unit/experiment/test_experimant.py index 2344d95..45e51c8 100644 --- a/tests/unit/trial/test_trial.py +++ b/tests/unit/experiment/test_experimant.py @@ -7,57 +7,64 @@ import faker import pytest -from alphatrion.trial.trial import CheckpointConfig, Trial, TrialConfig +from alphatrion.experiment.experiment import ( + CheckpointConfig, + Experiment, + ExperimentConfig, +) +from alphatrion.runtime.runtime import init -class TestTrial(unittest.IsolatedAsyncioTestCase): +class TestExperiment(unittest.IsolatedAsyncioTestCase): @pytest.mark.asyncio async def test_timeout(self): test_cases = [ { "name": "No timeout", - "config": TrialConfig(), + "config": ExperimentConfig(), "created": False, "expected": None, }, { "name": "Positive timeout", - "config": TrialConfig(max_execution_seconds=10), + "config": ExperimentConfig(max_execution_seconds=10), "created": False, "expected": 10, }, { "name": "Zero timeout", - "config": TrialConfig(max_execution_seconds=0), + "config": ExperimentConfig(max_execution_seconds=0), "created": False, "expected": 0, }, { "name": "Negative timeout", - "config": 
TrialConfig(max_execution_seconds=-5), + "config": ExperimentConfig(max_execution_seconds=-5), "created": False, "expected": None, }, { "name": "With started_at, positive timeout", - "config": TrialConfig(max_execution_seconds=5), + "config": ExperimentConfig(max_execution_seconds=5), "created": True, "expected": 3, }, ] + init(team_id=uuid.uuid4(), init_tables=True) + for case in test_cases: with self.subTest(name=case["name"]): - trial = Trial(exp_id=uuid.uuid4(), config=case["config"]) - trial._start(name=faker.Faker().word()) + exp = Experiment(proj_id=uuid.uuid4(), config=case["config"]) + exp._start(name=faker.Faker().word()) if case["created"]: time.sleep(2) # simulate elapsed time self.assertEqual( - trial._timeout(), case["config"].max_execution_seconds - 2 + exp._timeout(), case["config"].max_execution_seconds - 2 ) else: - self.assertEqual(trial._timeout(), case["expected"]) + self.assertEqual(exp._timeout(), case["expected"]) def test_config(self): test_cases = [ @@ -96,13 +103,15 @@ def test_config(self): }, ] + init(team_id=uuid.uuid4(), init_tables=True) + for case in test_cases: with self.subTest(name=case["name"]): if case["error"]: with self.assertRaises(ValueError): - Trial( - exp_id=1, - config=TrialConfig( + Experiment( + proj_id=uuid.uuid4(), + config=ExperimentConfig( monitor_metric=case["config"].get( "monitor_metric", None ), @@ -117,9 +126,9 @@ def test_config(self): ), ) else: - _ = Trial( - exp_id=1, - config=TrialConfig( + _ = Experiment( + proj_id=uuid.uuid4(), + config=ExperimentConfig( monitor_metric=case["config"].get("monitor_metric", None), checkpoint=CheckpointConfig( save_on_best=case["config"].get( diff --git a/tests/unit/metadata/test_sql.py b/tests/unit/metadata/test_sql.py index c8c226b..d8a0373 100644 --- a/tests/unit/metadata/test_sql.py +++ b/tests/unit/metadata/test_sql.py @@ -12,96 +12,103 @@ def db(): yield db -def test_create_exp(db): - project_id = uuid.uuid4() - id = db.create_exp("test_exp", project_id, "test description", {"key": "value"}) - exp = db.get_exp(id) - assert exp is not None - assert exp.name == "test_exp" - assert exp.project_id == project_id - assert exp.description == "test description" - assert exp.meta == {"key": "value"} - assert exp.uuid is not None - - -def test_delete_exp(db): - id = db.create_exp("test_exp", uuid.uuid4(), "test description", {"key": "value"}) - db.delete_exp(id) - exp = db.get_exp(id) - assert exp is None +def test_create_project(db): + team_id = uuid.uuid4() + id = db.create_project("test_proj", team_id, "test description", {"key": "value"}) + proj = db.get_project(id) + assert proj is not None + assert proj.name == "test_proj" + assert proj.team_id == team_id + assert proj.description == "test description" + assert proj.meta == {"key": "value"} + assert proj.uuid is not None + + +def test_delete_project(db): + id = db.create_project( + "test_proj", uuid.uuid4(), "test description", {"key": "value"} + ) + db.delete_project(id) + proj = db.get_project(id) + assert proj is None -def test_update_exp(db): - id = db.create_exp("test_exp", uuid.uuid4(), "test description", {"key": "value"}) - db.update_exp(id, name="new_name") - exp = db.get_exp(id) - assert exp.name == "new_name" +def test_update_project(db): + id = db.create_project( + "test_proj", uuid.uuid4(), "test description", {"key": "value"} + ) + db.update_project(id, name="new_name") + proj = db.get_project(id) + assert proj.name == "new_name" -def test_list_exps(db): - project_id1 = uuid.uuid4() - project_id2 = uuid.uuid4() - 
db.create_exp("exp1", project_id1, None, None) - db.create_exp("exp2", project_id1, None, None) - db.create_exp("exp3", project_id2, None, None) +def test_list_projects(db): + team_id1 = uuid.uuid4() + team_id2 = uuid.uuid4() + db.create_project("proj1", team_id1, None, None) + db.create_project("proj2", team_id1, None, None) + db.create_project("proj3", team_id2, None, None) - exps = db.list_exps(project_id1, 0, 10) - assert len(exps) == 2 + projs = db.list_projects(team_id1, 0, 10) + assert len(projs) == 2 - exps = db.list_exps(project_id2, 0, 10) - assert len(exps) == 1 + projs = db.list_projects(team_id2, 0, 10) + assert len(projs) == 1 - exps = db.list_exps(uuid.uuid4(), 0, 10) - assert len(exps) == 0 + projs = db.list_projects(uuid.uuid4(), 0, 10) + assert len(projs) == 0 -def test_create_trial(db): - project_id = uuid.uuid4() - exp_id = db.create_exp("test_exp", project_id, "test description") - trial_id = db.create_trial( - experiment_id=exp_id, - project_id=project_id, - name="test-trial", +def test_create_experiment(db): + team_id = uuid.uuid4() + proj_id = db.create_project( + "test_proj", team_id, "test description", {"key": "value"} + ) + exp_id = db.create_experiment( + team_id=team_id, + project_id=proj_id, + name="test-exp", params={"lr": 0.01}, ) - trial = db.get_trial(trial_id) - assert trial is not None - assert trial.experiment_id == exp_id - assert trial.name == "test-trial" - assert trial.status == Status.PENDING - assert trial.meta is None - assert trial.params == {"lr": 0.01} - - -def test_update_trial(db): - project_id = uuid.uuid4() - exp_id = db.create_exp("test_exp", project_id, "test description") - trial_id = db.create_trial( - experiment_id=exp_id, project_id=project_id, name="test-trial" + exp = db.get_experiment(exp_id) + assert exp is not None + assert exp.project_id == proj_id + assert exp.name == "test-exp" + assert exp.status == Status.PENDING + assert exp.meta is None + assert exp.params == {"lr": 0.01} + + +def test_update_experiment(db): + team_id = uuid.uuid4() + proj_id = db.create_project("test_proj", team_id, "test description") + exp_id = db.create_experiment( + team_id=team_id, + project_id=proj_id, + name="test_exp", + description="test description", ) - trial = db.get_trial(trial_id) - assert trial.status == Status.PENDING - assert trial.meta is None + exp = db.get_experiment(exp_id) + assert exp.status == Status.PENDING + assert exp.meta is None - db.update_trial(trial_id, status=Status.RUNNING, meta={"note": "started"}) - trial = db.get_trial(trial_id) - assert trial.status == Status.RUNNING - assert trial.meta == {"note": "started"} + db.update_experiment(exp_id, status=Status.RUNNING, meta={"note": "started"}) + exp = db.get_experiment(exp_id) + assert exp.status == Status.RUNNING + assert exp.meta == {"note": "started"} def test_create_metric(db): - project_id = uuid.uuid4() - exp_id = db.create_exp("test_exp", project_id, "test description") - trial_id = db.create_trial( - experiment_id=exp_id, project_id=project_id, name="test-trial" - ) - run_id = db.create_run( - trial_id=trial_id, project_id=project_id, experiment_id=exp_id + team_id = uuid.uuid4() + proj_id = db.create_project( + "test_proj", team_id, "test description", {"key": "value"} ) - db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.95) - db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.85) + exp_id = db.create_experiment(team_id=team_id, project_id=proj_id, name="test-exp") + run_id = db.create_run(team_id=team_id, project_id=proj_id, 
experiment_id=exp_id) + db.create_metric(team_id, proj_id, exp_id, run_id, "accuracy", 0.95) + db.create_metric(team_id, proj_id, exp_id, run_id, "accuracy", 0.85) - metrics = db.list_metrics_by_trial_id(trial_id) + metrics = db.list_metrics_by_experiment_id(exp_id) assert len(metrics) == 2 assert metrics[0].key == "accuracy" assert metrics[0].value == 0.95 diff --git a/tests/unit/model/test_model.py b/tests/unit/model/test_model.py index 885bc36..fdd3315 100644 --- a/tests/unit/model/test_model.py +++ b/tests/unit/model/test_model.py @@ -3,21 +3,20 @@ import pytest from alphatrion.model.model import Model -from alphatrion.runtime.runtime import Runtime +from alphatrion.runtime.runtime import global_runtime, init @pytest.fixture def model(): - runtime = Runtime(project_id="test_project", init_tables=True) + init(team_id=uuid.uuid4(), init_tables=True) + runtime = global_runtime() model = Model(runtime=runtime) yield model def test_model(model): - project_id = uuid.uuid4() - id = model.create( - "test_model", project_id, "A test model", {"tags": {"foo": "bar"}} - ) + team_id = uuid.uuid4() + id = model.create("test_model", team_id, "A test model", {"tags": {"foo": "bar"}}) model1 = model.get(id) assert model1 is not None assert model1.name == "test_model" diff --git a/tests/unit/project/test_project.py b/tests/unit/project/test_project.py new file mode 100644 index 0000000..3a807fe --- /dev/null +++ b/tests/unit/project/test_project.py @@ -0,0 +1,320 @@ +import asyncio +import random +import uuid +from datetime import datetime, timedelta + +import pytest + +from alphatrion.experiment import experiment +from alphatrion.metadata.sql_models import Status +from alphatrion.project.project import Project, ProjectConfig +from alphatrion.runtime.runtime import global_runtime, init + + +@pytest.mark.asyncio +async def test_project(): + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + async with Project.setup( + name="context_proj", + description="Context manager test", + meta={"key": "value"}, + ) as proj: + proj1 = proj._get() + assert proj1 is not None + assert proj1.name == "context_proj" + assert proj1.description == "Context manager test" + + exp = proj.start_experiment(name="first-experiment") + exp_obj = exp._get_obj() + assert exp_obj is not None + assert exp_obj.name == "first-experiment" + + exp.done() + + exp_obj = exp._get_obj() + assert exp_obj.duration is not None + assert exp_obj.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_project_with_done(): + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + exp_id = None + async with Project.setup( + name="context_proj", + description="Context manager test", + meta={"key": "value"}, + ) as proj: + exp = proj.start_experiment(name="first-experiment") + exp_id = exp.id + + # exit the exp context, trial should be done automatically + exp_obj = global_runtime()._metadb.get_experiment(experiment_id=exp_id) + assert exp_obj.duration is not None + assert exp_obj.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_project_with_done_with_err(): + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + exp_id = None + async with Project.setup( + name="context_proj", + description="Context manager test", + meta={"key": "value"}, + ) as proj: + exp = proj.start_experiment(name="first-experiment") + exp_id = exp.id + exp.done_with_err() + + # exit the proj context, trial should be done automatically + exp_obj = 
global_runtime()._metadb.get_experiment(experiment_id=exp_id) + assert exp_obj.duration is not None + assert exp_obj.status == Status.FAILED + + +@pytest.mark.asyncio +async def test_project_with_no_context(): + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + async def fake_work(exp: experiment.Experiment): + await asyncio.sleep(3) + exp.done() + + proj = Project.setup(name="no_context_proj") + async with proj.start_experiment(name="first-trial") as exp: + exp.start_run(lambda: fake_work(exp)) + await exp.wait() + + exp_obj = exp._get_obj() + assert exp_obj.duration is not None + assert exp_obj.status == Status.COMPLETED + + exp.done() + + +@pytest.mark.asyncio +async def test_project_with_exp(): + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + exp_id = None + async with Project.setup(name="context_proj") as proj: + async with proj.start_experiment(name="first-exp") as exp: + exp_obj = exp._get_obj() + assert exp_obj is not None + assert exp_obj.name == "first-exp" + exp_id = experiment.current_exp_id.get() + + exp_obj = proj._runtime._metadb.get_experiment(experiment_id=exp_id) + assert exp_obj.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_create_project_with_exp_wait(): + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + async def fake_work(exp: experiment.Experiment): + await asyncio.sleep(3) + exp.done() + + exp_id = None + async with Project.setup(name="context_proj") as proj: + async with proj.start_experiment(name="first-experiment") as exp: + exp_id = experiment.current_exp_id.get() + start_time = datetime.now() + + asyncio.create_task(fake_work(exp)) + assert datetime.now() - start_time <= timedelta(seconds=1) + + await exp.wait() + assert datetime.now() - start_time >= timedelta(seconds=3) + + exp_obj = exp._runtime._metadb.get_experiment(experiment_id=exp_id) + assert exp_obj.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_create_project_with_run(): + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + async def fake_work(cancel_func: callable, exp_id: uuid.UUID): + assert experiment.current_exp_id.get() == exp_id + await asyncio.sleep(3) + cancel_func() + + async with ( + Project.setup(name="context_proj") as proj, + proj.start_experiment(name="first-experiment") as exp, + ): + start_time = datetime.now() + + exp.start_run(lambda: fake_work(exp.done, exp.id)) + assert len(exp._runs) == 1 + + exp.start_run(lambda: fake_work(exp.done, exp.id)) + assert len(exp._runs) == 2 + + await exp.wait() + assert datetime.now() - start_time >= timedelta(seconds=3) + assert len(exp._runs) == 0 + + +@pytest.mark.asyncio +async def test_create_project_with_run_cancelled(): + init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + async def fake_work(timeout: int): + await asyncio.sleep(timeout) + + async with ( + Project.setup(name="context_proj") as proj, + proj.start_experiment( + name="first-experiment", + config=experiment.ExperimentConfig(max_execution_seconds=2), + ) as exp, + ): + run_0 = exp.start_run(lambda: fake_work(1)) + run_1 = exp.start_run(lambda: fake_work(4)) + run_2 = exp.start_run(lambda: fake_work(5)) + run_3 = exp.start_run(lambda: fake_work(6)) + # At this point, 4 runs are started. 
+        assert len(exp._runs) == 4
+        await exp.wait()
+
+        run_0_obj = run_0._get_obj()
+        assert run_0_obj.status == Status.COMPLETED
+        run_1_obj = run_1._get_obj()
+        assert run_1_obj.status == Status.CANCELLED
+        run_2_obj = run_2._get_obj()
+        assert run_2_obj.status == Status.CANCELLED
+        run_3_obj = run_3._get_obj()
+        assert run_3_obj.status == Status.CANCELLED
+
+
+@pytest.mark.asyncio
+async def test_create_project_with_max_execution_seconds():
+    init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
+
+    async with Project.setup(
+        name="context_proj",
+        description="Context manager test",
+        meta={"key": "value"},
+    ) as proj:
+        exp = proj.start_experiment(
+            name="first-experiment",
+            config=experiment.ExperimentConfig(max_execution_seconds=2),
+        )
+        await exp.wait()
+        assert exp.is_done()
+
+        exp = exp._get_obj()
+        assert exp.status == Status.COMPLETED
+
+
+@pytest.mark.asyncio
+async def test_project_with_multi_experiments_in_parallel():
+    init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
+
+    async def fake_work():
+        proj = global_runtime().current_proj
+
+        duration = random.randint(1, 5)
+        exp = proj.start_experiment(
+            name="first-experiment",
+            config=experiment.ExperimentConfig(max_execution_seconds=duration),
+        )
+        # double-check the current experiment id.
+        assert exp.id == experiment.current_exp_id.get()
+
+        await exp.wait()
+        assert exp.is_done()
+        # we don't reset the current experiment id.
+        assert exp.id == experiment.current_exp_id.get()
+
+        exp = exp._get_obj()
+        assert exp.status == Status.COMPLETED
+
+    async with Project.setup(
+        name="context_proj",
+        description="Context manager test",
+        meta={"key": "value"},
+    ):
+        await asyncio.gather(
+            fake_work(),
+            fake_work(),
+            fake_work(),
+        )
+        print("All experiments finished.")
+
+
+@pytest.mark.asyncio
+async def test_project_with_config():
+    init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
+
+    async with Project.setup(
+        name="context_proj",
+        description="Context manager test",
+        meta={"key": "value"},
+        config=ProjectConfig(max_execution_seconds=2),
+    ) as proj:
+        exp = proj.start_experiment(name="first-experiment")
+        await exp.wait()
+        assert exp.is_done()
+
+        exp_obj = exp._get_obj()
+        assert exp_obj.status == Status.COMPLETED
+
+
+@pytest.mark.asyncio
+async def test_project_with_hierarchy_timeout():
+    init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
+
+    async with Project.setup(
+        name="context_proj",
+        description="Context manager test",
+        meta={"key": "value"},
+        config=ProjectConfig(max_execution_seconds=2),
+    ) as proj:
+        start_time = datetime.now()
+        exp = proj.start_experiment(
+            name="first-experiment",
+            config=experiment.ExperimentConfig(max_execution_seconds=5),
+        )
+        await exp.wait()
+        assert exp.is_done()
+
+        assert (datetime.now() - start_time).total_seconds() >= 2
+        assert (datetime.now() - start_time).total_seconds() < 5
+
+        exp_obj = exp._get_obj()
+        assert exp_obj.status == Status.COMPLETED
+
+
+@pytest.mark.asyncio
+async def test_project_with_hierarchy_timeout_2():
+    init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
+
+    start_time = datetime.now()
+
+    async with Project.setup(
+        name="context_proj",
+        description="Context manager test",
+        meta={"key": "value"},
+        config=ProjectConfig(max_execution_seconds=5),
+    ) as proj:
+        exp = proj.start_experiment(
+            name="first-experiment",
+            config=experiment.ExperimentConfig(max_execution_seconds=2),
+        )
+        await exp.wait()
+        assert exp.is_done()
+
+        assert (datetime.now() - start_time).total_seconds() >= 2
+
+        exp_obj = exp._get_obj()
+        assert exp_obj.status == Status.COMPLETED
+
+    assert (datetime.now() - start_time).total_seconds() < 5
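
For reviewers, the new surface exercised by these tests reduces to the pattern sketched below: the old CraftExperiment/Trial pair becomes Project/Experiment, and init() now takes a team_id. This is a minimal, illustrative sketch distilled from the tests above, not code from this diff; the project name, experiment name, and the 2-second timeout are placeholders, and it assumes the same environment the test suite prepares via init().

    import asyncio
    import uuid

    from alphatrion.experiment import experiment
    from alphatrion.project.project import Project
    from alphatrion.runtime.runtime import init


    async def main():
        # Assumed setup, mirroring the tests: init() takes team_id (not the old project_id).
        init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)

        async with Project.setup(name="demo_proj") as proj:
            # Experiments replace the old Trials; each experiment can still be bounded with
            # max_execution_seconds, and the project can add its own limit via ProjectConfig.
            async with proj.start_experiment(
                name="demo-experiment",
                config=experiment.ExperimentConfig(max_execution_seconds=2),
            ) as exp:
                # start_run takes a zero-argument callable that returns a coroutine.
                exp.start_run(lambda: asyncio.sleep(1))
                await exp.wait()


    asyncio.run(main())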