3 changes: 1 addition & 2 deletions alphatrion/log/log.py
@@ -105,7 +105,6 @@ async def log_metrics(metrics: dict[str, float]):
should_checkpoint = False
should_early_stop = False
should_stop_on_target = False
step = trial.increment_step()
for key, value in metrics.items():
runtime._metadb.create_metric(
key=key,
@@ -114,7 +113,6 @@ async def log_metrics(metrics: dict[str, float]):
experiment_id=exp.id,
trial_id=trial_id,
run_id=run_id,
step=step,
)

# TODO: should we save the checkpoint path for the best metric?
@@ -128,6 +126,7 @@ async def log_metrics(metrics: dict[str, float]):
metric_key=key, metric_value=value
)

# TODO: we should save the artifact path in the metrics as well.
if should_checkpoint:
await log_artifact(
paths=trial.config().checkpoint.path,
1 change: 0 additions & 1 deletion alphatrion/metadata/base.py
@@ -111,7 +111,6 @@ def create_metric(
run_id: uuid.UUID,
key: str,
value: float,
step: int | None = None,
) -> int:
raise NotImplementedError("Subclasses must implement this method.")

2 changes: 0 additions & 2 deletions alphatrion/metadata/sql.py
@@ -370,7 +370,6 @@ def create_metric(
run_id: uuid.UUID,
key: str,
value: float,
step: int,
) -> uuid.UUID:
session = self._session()
new_metric = Metric(
@@ -380,7 +379,6 @@
run_id=run_id,
key=key,
value=value,
step=step,
)
session.add(new_metric)
session.commit()
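With the step column gone from the Metric rows, chronological ordering of a trial's metrics presumably falls back to created_at. A minimal sketch of how list_metrics_by_trial_id (the query used in the tests below) might order results, assuming SQLAlchemy and the Metric model from sql_models.py; the filter and session handling here are illustrative only, not the repo's actual implementation:

def list_metrics_by_trial_id(self, trial_id: uuid.UUID) -> list[Metric]:
    # Sketch only: order by created_at now that there is no step column.
    session = self._session()
    return (
        session.query(Metric)
        .filter(Metric.trial_id == trial_id)
        .order_by(Metric.created_at)
        .all()
    )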
2 changes: 1 addition & 1 deletion alphatrion/metadata/sql_models.py
@@ -151,6 +151,7 @@ class Model(Base):
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")


# TODO: key, project_id, experiment_id, trial_id, run_id should be unique together
class Metric(Base):
__tablename__ = "metrics"

@@ -161,5 +162,4 @@ class Metric(Base):
experiment_id = Column(UUID(as_uuid=True), nullable=False)
trial_id = Column(UUID(as_uuid=True), nullable=False)
run_id = Column(UUID(as_uuid=True), nullable=False)
step = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
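The TODO above asks for key, project_id, experiment_id, trial_id, and run_id to be unique together. A minimal sketch of that constraint in SQLAlchemy, assuming no duplicate rows exist yet; the constraint name is illustrative:

from sqlalchemy import UniqueConstraint

class Metric(Base):
    __tablename__ = "metrics"
    __table_args__ = (
        UniqueConstraint(
            "key", "project_id", "experiment_id", "trial_id", "run_id",
            name="uq_metrics_key_per_run",
        ),
    )

Note that test_create_metric below writes the same key twice for one run, so the constraint as stated would reject that pattern; either the constraint needs relaxing or the test would have to change.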
1 change: 0 additions & 1 deletion alphatrion/server/graphql/resolvers.py
@@ -183,7 +183,6 @@ def list_trial_metrics(
experiment_id=m.experiment_id,
trial_id=m.trial_id,
run_id=m.run_id,
step=m.step,
created_at=m.created_at,
)
for m in metrics
1 change: 0 additions & 1 deletion alphatrion/server/graphql/types.py
@@ -82,5 +82,4 @@ class Metric:
experiment_id: strawberry.ID
trial_id: strawberry.ID
run_id: strawberry.ID
step: int
created_at: datetime
7 changes: 0 additions & 7 deletions alphatrion/trial/trial.py
@@ -122,8 +122,6 @@ class Trial:
"_exp_id",
"_config",
"_runtime",
# step is used to track the round, e.g. the step in metric logging.
"_step",
"_context",
"_token",
# _meta stores the runtime meta information of the trial.
@@ -148,7 +146,6 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
self._exp_id = exp_id
self._config = config or TrialConfig()
self._runtime = global_runtime()
self._step = 0
self._construct_meta()
self._runs = dict[uuid.UUID, Run]()
self._early_stopping_counter = 0
Expand Down Expand Up @@ -350,10 +347,6 @@ def _stop(self):
def _get_obj(self):
return self._runtime.metadb.get_trial(trial_id=self._id)

def increment_step(self) -> int:
self._step += 1
return self._step

def start_run(self, call_func: callable) -> Run:
"""Start a new run for the trial.
:param call_func: a callable function that returns a coroutine.
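Since increment_step is removed, callers no longer pass or derive a per-metric round; ordering is implicit in insertion time. A minimal usage sketch from inside an async training routine, where train_one_epoch, model, data, and num_epochs are hypothetical:

# Sketch only: log_metrics is called once per round with no explicit step.
for epoch in range(num_epochs):
    accuracy, loss = train_one_epoch(model, data)  # hypothetical helper
    await log_metrics({"accuracy": accuracy, "loss": loss})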
1 change: 0 additions & 1 deletion dashboard/src/components/trials/trial-detail.tsx
@@ -38,7 +38,6 @@ function MetricsChart({ metrics }: { metrics: Metric[] }) {
metricsByKey[key].push(m);
});

Object.values(metricsByKey).forEach((arr) => arr.sort((a, b) => a.step - b.step));
const keys = Object.keys(metricsByKey);

return (
1 change: 0 additions & 1 deletion dashboard/src/graphql/queries.ts
@@ -139,7 +139,6 @@ export const LIST_TRIAL_METRICS = `
experimentId
trialId
runId
step
createdAt
}
}
1 change: 0 additions & 1 deletion dashboard/src/types/index.ts
@@ -58,7 +58,6 @@ export interface Metric {
experimentId: string;
trialId: string;
runId: string;
step: number;
createdAt: string;
}

1 change: 0 additions & 1 deletion hack/seed.py
@@ -109,7 +109,6 @@ def generate_metric(runs: list[Run]) -> Metric:
run_id=run.uuid,
key=random.choice(["accuracy", "loss", "precision", "fitness"]),
value=random.uniform(0, 1),
step=random.randint(1, 1000),
)


32 changes: 32 additions & 0 deletions migrations/versions/03410247c6b7_remove_step_from_trial.py
@@ -0,0 +1,32 @@
"""remove step from trial

Revision ID: 03410247c6b7
Revises: c89fc8504699
Create Date: 2026-01-16 15:17:18.057002

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '03410247c6b7'
down_revision: Union[str, Sequence[str], None] = 'c89fc8504699'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('metrics', 'step')
# ### end Alembic commands ###


def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('metrics', sa.Column('step', sa.INTEGER(), autoincrement=False, nullable=False))
# ### end Alembic commands ###
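One caveat with the auto-generated downgrade: it re-adds step as NOT NULL with no server default, which will fail on a metrics table that already contains rows. A safer sketch, assuming PostgreSQL and a backfill value of 0:

def downgrade() -> None:
    """Downgrade schema."""
    # Add the column with a server default so existing rows are backfilled,
    # then drop the default to match the previous schema.
    op.add_column(
        'metrics',
        sa.Column('step', sa.INTEGER(), nullable=False, server_default='0'),
    )
    op.alter_column('metrics', 'step', server_default=None)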
3 changes: 0 additions & 3 deletions tests/integration/server/test_graphql_query.py
@@ -304,7 +304,6 @@ def test_query_trial_metrics():
run_id=uuid.uuid4(),
key="accuracy",
value=0.95,
step=0,
)
_ = metadb.create_metric(
project_id=project_id,
@@ -313,7 +312,6 @@ def test_query_trial_metrics():
run_id=uuid.uuid4(),
key="accuracy",
value=0.95,
step=1,
)
query = f"""
query {{
@@ -325,7 +323,6 @@ def test_query_trial_metrics():
experimentId
trialId
runId
step
createdAt
}}
}}
3 changes: 0 additions & 3 deletions tests/integration/test_log.py
@@ -115,10 +115,8 @@ async def log_metric(metrics: dict):
assert len(metrics) == 2
assert metrics[0].key == "accuracy"
assert metrics[0].value == 0.95
assert metrics[0].step == 1
assert metrics[1].key == "loss"
assert metrics[1].value == 0.1
assert metrics[1].step == 1
run_id_1 = metrics[0].run_id
assert run_id_1 is not None
assert metrics[0].run_id == metrics[1].run_id
@@ -130,7 +128,6 @@ async def log_metric(metrics: dict):
assert len(metrics) == 3
assert metrics[2].key == "accuracy"
assert metrics[2].value == 0.96
assert metrics[2].step == 2
run_id_2 = metrics[2].run_id
assert run_id_2 is not None
assert run_id_2 != run_id_1
4 changes: 2 additions & 2 deletions tests/unit/metadata/test_sql.py
@@ -98,8 +98,8 @@ def test_create_metric(db):
run_id = db.create_run(
trial_id=trial_id, project_id=project_id, experiment_id=exp_id
)
db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.95, 1)
db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.85, 2)
db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.95)
db.create_metric(project_id, exp_id, trial_id, run_id, "accuracy", 0.85)

metrics = db.list_metrics_by_trial_id(trial_id)
assert len(metrics) == 2