diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py index 88850a2..7b0c831 100644 --- a/alphatrion/log/log.py +++ b/alphatrion/log/log.py @@ -105,7 +105,6 @@ async def log_metrics(metrics: dict[str, float]): should_checkpoint = False should_early_stop = False should_stop_on_target = False - step = trial.increment_step() for key, value in metrics.items(): runtime._metadb.create_metric( key=key, @@ -114,7 +113,6 @@ async def log_metrics(metrics: dict[str, float]): experiment_id=exp.id, trial_id=trial_id, run_id=run_id, - step=step, ) # TODO: should we save the checkpoint path for the best metric? @@ -128,6 +126,7 @@ async def log_metrics(metrics: dict[str, float]): metric_key=key, metric_value=value ) + # TODO: we should save the artifact path in the metrics as well. if should_checkpoint: await log_artifact( paths=trial.config().checkpoint.path, diff --git a/alphatrion/metadata/base.py b/alphatrion/metadata/base.py index 23d952b..6142fb8 100644 --- a/alphatrion/metadata/base.py +++ b/alphatrion/metadata/base.py @@ -111,7 +111,6 @@ def create_metric( run_id: uuid.UUID, key: str, value: float, - step: int | None = None, ) -> int: raise NotImplementedError("Subclasses must implement this method.") diff --git a/alphatrion/metadata/sql.py b/alphatrion/metadata/sql.py index a7e8af8..a07f4ae 100644 --- a/alphatrion/metadata/sql.py +++ b/alphatrion/metadata/sql.py @@ -370,7 +370,6 @@ def create_metric( run_id: uuid.UUID, key: str, value: float, - step: int, ) -> uuid.UUID: session = self._session() new_metric = Metric( @@ -380,7 +379,6 @@ def create_metric( run_id=run_id, key=key, value=value, - step=step, ) session.add(new_metric) session.commit() diff --git a/alphatrion/metadata/sql_models.py b/alphatrion/metadata/sql_models.py index 49f6e24..7ea7736 100644 --- a/alphatrion/metadata/sql_models.py +++ b/alphatrion/metadata/sql_models.py @@ -151,6 +151,7 @@ class Model(Base): is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted") +# TODO: key, project_id, experiment_id, trial_id, run_id should be unique together class Metric(Base): __tablename__ = "metrics" @@ -161,5 +162,4 @@ class Metric(Base): experiment_id = Column(UUID(as_uuid=True), nullable=False) trial_id = Column(UUID(as_uuid=True), nullable=False) run_id = Column(UUID(as_uuid=True), nullable=False) - step = Column(Integer, nullable=False, default=0) created_at = Column(DateTime(timezone=True), default=datetime.now(UTC)) diff --git a/alphatrion/server/graphql/resolvers.py b/alphatrion/server/graphql/resolvers.py index 45d3015..0bc66e2 100644 --- a/alphatrion/server/graphql/resolvers.py +++ b/alphatrion/server/graphql/resolvers.py @@ -183,7 +183,6 @@ def list_trial_metrics( experiment_id=m.experiment_id, trial_id=m.trial_id, run_id=m.run_id, - step=m.step, created_at=m.created_at, ) for m in metrics diff --git a/alphatrion/server/graphql/types.py b/alphatrion/server/graphql/types.py index 6777016..67c85c6 100644 --- a/alphatrion/server/graphql/types.py +++ b/alphatrion/server/graphql/types.py @@ -82,5 +82,4 @@ class Metric: experiment_id: strawberry.ID trial_id: strawberry.ID run_id: strawberry.ID - step: int created_at: datetime diff --git a/alphatrion/trial/trial.py b/alphatrion/trial/trial.py index 32573ea..9ea27fd 100644 --- a/alphatrion/trial/trial.py +++ b/alphatrion/trial/trial.py @@ -122,8 +122,6 @@ class Trial: "_exp_id", "_config", "_runtime", - # step is used to track the round, e.g. the step in metric logging. - "_step", "_context", "_token", # _meta stores the runtime meta information of the trial. @@ -148,7 +146,6 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None): self._exp_id = exp_id self._config = config or TrialConfig() self._runtime = global_runtime() - self._step = 0 self._construct_meta() self._runs = dict[uuid.UUID, Run]() self._early_stopping_counter = 0 @@ -350,10 +347,6 @@ def _stop(self): def _get_obj(self): return self._runtime.metadb.get_trial(trial_id=self._id) - def increment_step(self) -> int: - self._step += 1 - return self._step - def start_run(self, call_func: callable) -> Run: """Start a new run for the trial. :param call_func: a callable function that returns a coroutine. diff --git a/dashboard/src/components/trials/trial-detail.tsx b/dashboard/src/components/trials/trial-detail.tsx index 695f365..2729593 100644 --- a/dashboard/src/components/trials/trial-detail.tsx +++ b/dashboard/src/components/trials/trial-detail.tsx @@ -38,7 +38,6 @@ function MetricsChart({ metrics }: { metrics: Metric[] }) { metricsByKey[key].push(m); }); - Object.values(metricsByKey).forEach((arr) => arr.sort((a, b) => a.step - b.step)); const keys = Object.keys(metricsByKey); return ( diff --git a/dashboard/src/graphql/queries.ts b/dashboard/src/graphql/queries.ts index feae479..ea3b08d 100644 --- a/dashboard/src/graphql/queries.ts +++ b/dashboard/src/graphql/queries.ts @@ -139,7 +139,6 @@ export const LIST_TRIAL_METRICS = ` experimentId trialId runId - step createdAt } } diff --git a/dashboard/src/types/index.ts b/dashboard/src/types/index.ts index d006578..37d26e6 100644 --- a/dashboard/src/types/index.ts +++ b/dashboard/src/types/index.ts @@ -58,7 +58,6 @@ export interface Metric { experimentId: string; trialId: string; runId: string; - step: number; createdAt: string; } diff --git a/hack/seed.py b/hack/seed.py index 7310ddd..b1548d4 100644 --- a/hack/seed.py +++ b/hack/seed.py @@ -109,7 +109,6 @@ def generate_metric(runs: list[Run]) -> Metric: run_id=run.uuid, key=random.choice(["accuracy", "loss", "precision", "fitness"]), value=random.uniform(0, 1), - step=random.randint(1, 1000), ) diff --git a/migrations/versions/03410247c6b7_remove_step_from_trial.py b/migrations/versions/03410247c6b7_remove_step_from_trial.py new file mode 100644 index 0000000..50e2923 --- /dev/null +++ b/migrations/versions/03410247c6b7_remove_step_from_trial.py @@ -0,0 +1,32 @@ +"""remove step from trial + +Revision ID: 03410247c6b7 +Revises: c89fc8504699 +Create Date: 2026-01-16 15:17:18.057002 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '03410247c6b7' +down_revision: Union[str, Sequence[str], None] = 'c89fc8504699' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('metrics', 'step') + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('metrics', sa.Column('step', sa.INTEGER(), autoincrement=False, nullable=False)) + # ### end Alembic commands ### diff --git a/tests/integration/server/test_graphql_query.py b/tests/integration/server/test_graphql_query.py index 7a78e62..a6b4713 100644 --- a/tests/integration/server/test_graphql_query.py +++ b/tests/integration/server/test_graphql_query.py @@ -304,7 +304,6 @@ def test_query_trial_metrics(): run_id=uuid.uuid4(), key="accuracy", value=0.95, - step=0, ) _ = metadb.create_metric( project_id=project_id, @@ -313,7 +312,6 @@ def test_query_trial_metrics(): run_id=uuid.uuid4(), key="accuracy", value=0.95, - step=1, ) query = f""" query {{ @@ -325,7 +323,6 @@ def test_query_trial_metrics(): experimentId trialId runId - step createdAt }} }} diff --git a/tests/integration/test_log.py b/tests/integration/test_log.py index 559a498..18c0dda 100644 --- a/tests/integration/test_log.py +++ b/tests/integration/test_log.py @@ -115,10 +115,8 @@ async def log_metric(metrics: dict): assert len(metrics) == 2 assert metrics[0].key == "accuracy" assert metrics[0].value == 0.95 - assert metrics[0].step == 1 assert metrics[1].key == "loss" assert metrics[1].value == 0.1 - assert metrics[1].step == 1 run_id_1 = metrics[0].run_id assert run_id_1 is not None assert metrics[0].run_id == metrics[1].run_id @@ -130,7 +128,6 @@ async def log_metric(metrics: dict): assert len(metrics) == 3 assert metrics[2].key == "accuracy" assert metrics[2].value == 0.96 - assert metrics[2].step == 2 run_id_2 = metrics[2].run_id assert run_id_2 is not None assert run_id_2 != run_id_1 diff --git a/tests/unit/metadata/test_sql.py b/tests/unit/metadata/test_sql.py index 36d9c90..c8c226b 100644 --- a/tests/unit/metadata/test_sql.py +++ b/tests/unit/metadata/test_sql.py @@ -98,8 +98,8 @@ def test_create_metric(db): run_id = db.create_run( trial_id=trial_id, project_id=project_id, experiment_id=exp_id ) - db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.95, 1) - db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.85, 2) + db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.95) + db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.85) metrics = db.list_metrics_by_trial_id(trial_id) assert len(metrics) == 2