diff --git a/README.md b/README.md index ccd48a0..a181e4c 100644 --- a/README.md +++ b/README.md @@ -59,18 +59,18 @@ ServerPWD: alphatr1on Below is a simple example with two approaches demonstrating how to create an experiment and log performance metrics. ```python -import alphatrion as alpha import uuid +from alphatrion import init, log_metrics, Project, CraftExperiment -# Better to use a fixed UUID to identify your team. -alpha.init(team_id=uuid.uuid4(), artifact_insecure=True) +# Better to use a fixed UUID for the team and user in real scenarios. +init(team_id=uuid.uuid4(), user_id=uuid.uuid4(), artifact_insecure=True) async def log(): # Run your code here then log metrics. - await alpha.log_metrics({"accuracy": 0.95}) + await log_metrics({"accuracy": 0.95}) -async with alpha.Project.setup(name="my_project"): - async with alpha.CraftExperiment.start(name="my_experiment") as exp: +async with Project.setup(name="my_project"): + async with CraftExperiment.start(name="my_experiment") as exp: task = exp.run(lambda: log()) await task.wait() ``` diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index d6ca969..af2c1ce 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -184,9 +184,10 @@ def _start( self._id = exp_obj.uuid else: self._id = self._runtime._metadb.create_experiment( + name=name, team_id=self._runtime._team_id, + user_id=self._runtime._user_id, project_id=proj.id, - name=name, description=description, meta=meta, params=params, diff --git a/alphatrion/metadata/base.py b/alphatrion/metadata/base.py index d53e0bc..6a3e163 100644 --- a/alphatrion/metadata/base.py +++ b/alphatrion/metadata/base.py @@ -1,7 +1,7 @@ import uuid from abc import ABC, abstractmethod -from alphatrion.metadata.sql_models import Experiment, Model +from alphatrion.metadata.sql_models import Experiment, Model, User class MetaStore(ABC): @@ -17,11 +17,32 @@ def create_team( def get_team(self, team_id: uuid.UUID): raise NotImplementedError("Subclasses must implement this method.") + @abstractmethod + def create_user( + self, + username: str, + email: str, + team_id: uuid.UUID, + meta: dict | None = None, + ) -> uuid.UUID: + raise NotImplementedError("Subclasses must implement this method.") + + @abstractmethod + def get_user(self, user_id: uuid.UUID) -> User | None: + raise NotImplementedError("Subclasses must implement this method.") + + @abstractmethod + def list_users( + self, team_id: uuid.UUID, page: int = 0, page_size: int = 10 + ) -> list[User]: + raise NotImplementedError("Subclasses must implement this method.") + @abstractmethod def create_project( self, name: str, team_id: uuid.UUID, + user_id: uuid.UUID, description: str | None = None, meta: dict | None = None, ) -> int: @@ -78,6 +99,7 @@ def delete_model(self, model_id: uuid.UUID): def create_experiment( self, team_id: uuid.UUID, + user_id: uuid.UUID, project_id: uuid.UUID, name: str, description: str | None = None, @@ -102,6 +124,7 @@ def update_experiment(self, experiment_id: uuid.UUID, **kwargs): def create_run( self, team_id: uuid.UUID, + user_id: uuid.UUID, project_id: uuid.UUID, experiment_id: uuid.UUID, meta: dict | None = None, diff --git a/alphatrion/metadata/sql.py b/alphatrion/metadata/sql.py index 00d7cbd..d9b9aef 100644 --- a/alphatrion/metadata/sql.py +++ b/alphatrion/metadata/sql.py @@ -13,6 +13,7 @@ Run, Status, Team, + User, ) @@ -26,6 +27,8 @@ def __init__(self, db_url: str, init_tables: bool = False): # Mostly used in tests. Base.metadata.create_all(self._engine) + # ---------- Team APIs ---------- + def create_team( self, name: str, description: str | None = None, meta: dict | None = None ) -> uuid.UUID: @@ -62,10 +65,74 @@ def list_teams(self, page: int, page_size: int) -> list[Team]: session.close() return teams + def get_team_by_user_id(self, user_id: uuid.UUID) -> Team | None: + session = self._session() + user = ( + session.query(User).filter(User.uuid == user_id, User.is_del == 0).first() + ) + if not user: + session.close() + return None + team = ( + session.query(Team) + .filter(Team.uuid == user.team_id, Team.is_del == 0) + .first() + ) + session.close() + return team + + # ---------- User APIs ---------- + + def create_user( + self, + username: str, + email: str, + team_id: uuid.UUID, + meta: dict | None = None, + ) -> uuid.UUID: + session = self._session() + new_user = User( + username=username, + team_id=team_id, + email=email, + meta=meta, + ) + session.add(new_user) + session.commit() + user_id = new_user.uuid + session.close() + + return user_id + + def get_user(self, user_id: uuid.UUID) -> User | None: + session = self._session() + user = ( + session.query(User).filter(User.uuid == user_id, User.is_del == 0).first() + ) + session.close() + return user + + def list_users( + self, team_id: uuid.UUID, page: int = 0, page_size: int = 10 + ) -> list[User]: + session = self._session() + users = ( + session.query(User) + .filter(User.team_id == team_id, User.is_del == 0) + .offset(page * page_size) + .limit(page_size) + .all() + ) + session.close() + return users + + # ---------- Project APIs ---------- + def create_project( self, name: str, team_id: uuid.UUID, + user_id: uuid.UUID, description: str | None = None, meta: dict | None = None, ) -> uuid.UUID: @@ -73,6 +140,7 @@ def create_project( new_proj = Project( name=name, team_id=team_id, + creator_id=user_id, description=description, meta=meta, ) @@ -151,6 +219,8 @@ def list_projects( session.close() return projects + # ---------- Model APIs ---------- + def create_model( self, name: str, @@ -221,10 +291,13 @@ def delete_model(self, model_id: uuid.UUID): session.commit() session.close() + # ---------- Experiment APIs ---------- + def create_experiment( self, name: str, team_id: uuid.UUID, + user_id: uuid.UUID, project_id: uuid.UUID, description: str | None = None, meta: dict | None = None, @@ -234,6 +307,7 @@ def create_experiment( session = self._session() new_exp = Experiment( team_id=team_id, + user_id=user_id, project_id=project_id, name=name, description=description, @@ -306,18 +380,23 @@ def update_experiment(self, experiment_id: uuid.UUID, **kwargs) -> None: session.commit() session.close() + # ---------- Run APIs ---------- + def create_run( self, team_id: uuid.UUID, + user_id: uuid.UUID, project_id: uuid.UUID, experiment_id: uuid.UUID, meta: dict | None = None, status: Status = Status.PENDING, ) -> uuid.UUID: session = self._session() + new_run = Run( project_id=project_id, team_id=team_id, + user_id=user_id, experiment_id=experiment_id, meta=meta, status=status, @@ -359,6 +438,8 @@ def list_runs_by_exp_id( session.close() return runs + # ---------- Metric APIs ---------- + def create_metric( self, team_id: uuid.UUID, diff --git a/alphatrion/metadata/sql_models.py b/alphatrion/metadata/sql_models.py index 7110945..367b098 100644 --- a/alphatrion/metadata/sql_models.py +++ b/alphatrion/metadata/sql_models.py @@ -47,6 +47,24 @@ class Team(Base): is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted") +class User(Base): + __tablename__ = "users" + + uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + username = Column(String, nullable=False, unique=True) + email = Column(String, nullable=False, unique=True) + team_id = Column(UUID(as_uuid=True), nullable=False) + meta = Column(JSON, nullable=True, comment="Additional metadata for the user") + + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + ) + is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted") + + # Define the Project model for SQLAlchemy class Project(Base): __tablename__ = "projects" @@ -55,6 +73,7 @@ class Project(Base): name = Column(String, nullable=False) description = Column(String, nullable=True) team_id = Column(UUID(as_uuid=True), nullable=False) + creator_id = Column(UUID(as_uuid=True), nullable=True) meta = Column(JSON, nullable=True, comment="Additional metadata for the project") created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) @@ -77,6 +96,7 @@ class Experiment(Base): uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) team_id = Column(UUID(as_uuid=True), nullable=False) project_id = Column(UUID(as_uuid=True), nullable=False) + user_id = Column(UUID(as_uuid=True), nullable=True) name = Column(String, nullable=False) description = Column(String, nullable=True) meta = Column(JSON, nullable=True, comment="Additional metadata for the trial") @@ -113,6 +133,7 @@ class Run(Base): team_id = Column(UUID(as_uuid=True), nullable=False) project_id = Column(UUID(as_uuid=True), nullable=False) experiment_id = Column(UUID(as_uuid=True), nullable=False) + user_id = Column(UUID(as_uuid=True), nullable=True) meta = Column(JSON, nullable=True, comment="Additional metadata for the run") status = Column( Integer, diff --git a/alphatrion/project/project.py b/alphatrion/project/project.py index 750d354..f781f3c 100644 --- a/alphatrion/project/project.py +++ b/alphatrion/project/project.py @@ -126,6 +126,7 @@ def _create( name=name, description=description, team_id=self._runtime._team_id, + user_id=self._runtime._user_id, meta=meta, ) return self._id diff --git a/alphatrion/run/run.py b/alphatrion/run/run.py index a400140..2e8c092 100644 --- a/alphatrion/run/run.py +++ b/alphatrion/run/run.py @@ -24,7 +24,8 @@ def _get_obj(self): def start(self, call_func: callable) -> None: self._id = self._runtime._metadb.create_run( - team_id=self._runtime._team_id, + team_id=self._runtime.team_id, + user_id=self._runtime.user_id, project_id=self._runtime.current_proj.id, experiment_id=self._exp_id, status=Status.RUNNING, diff --git a/alphatrion/runtime/runtime.py b/alphatrion/runtime/runtime.py index d1dbbab..e694d13 100644 --- a/alphatrion/runtime/runtime.py +++ b/alphatrion/runtime/runtime.py @@ -11,6 +11,7 @@ def init( team_id: uuid.UUID, + user_id: uuid.UUID, artifact_insecure: bool = False, init_tables: bool = False, ): @@ -25,6 +26,7 @@ def init( global __RUNTIME__ __RUNTIME__ = Runtime( team_id=team_id, + user_id=user_id, artifact_insecure=artifact_insecure, init_tables=init_tables, ) @@ -39,19 +41,22 @@ def global_runtime(): # Runtime contains all kinds of clients, e.g., metadb client, artifact client, etc. # Stateful information will also be stored here, e.g., current running Project. class Runtime: - __slots__ = ("_team_id", "_metadb", "_artifact", "__current_proj") + __slots__ = ("_user_id", "_team_id", "_metadb", "_artifact", "__current_proj") def __init__( self, team_id: uuid.UUID, + user_id: uuid.UUID, artifact_insecure: bool = False, init_tables: bool = False, ): - self._team_id = team_id self._metadb = SQLStore( os.getenv(consts.METADATA_DB_URL), init_tables=init_tables ) + self._user_id = user_id + self._team_id = team_id + if self.artifact_storage_enabled(): self._artifact = Artifact(team_id=self._team_id, insecure=artifact_insecure) @@ -70,3 +75,11 @@ def current_proj(self, value) -> None: @property def metadb(self) -> SQLStore: return self._metadb + + @property + def user_id(self) -> uuid.UUID: + return self._user_id + + @property + def team_id(self) -> uuid.UUID: + return self._team_id diff --git a/alphatrion/server/graphql/resolvers.py b/alphatrion/server/graphql/resolvers.py index f212126..d05604f 100644 --- a/alphatrion/server/graphql/resolvers.py +++ b/alphatrion/server/graphql/resolvers.py @@ -12,6 +12,7 @@ Project, Run, Team, + User, ) @@ -47,6 +48,22 @@ def get_team(id: str) -> Team | None: ) return None + @staticmethod + def get_user(id: str) -> User | None: + metadb = runtime.graphql_runtime().metadb + user = metadb.get_user(user_id=uuid.UUID(id)) + if user: + return User( + id=user.uuid, + username=user.username, + email=user.email, + team_id=user.team_id, + meta=user.meta, + created_at=user.created_at, + updated_at=user.updated_at, + ) + return None + @staticmethod def list_projects( team_id: str, page: int = 0, page_size: int = 10 @@ -59,6 +76,7 @@ def list_projects( Project( id=proj.uuid, team_id=proj.team_id, + creator_id=proj.creator_id, name=proj.name, description=proj.description, meta=proj.meta, @@ -76,6 +94,7 @@ def get_project(id: str) -> Project | None: return Project( id=proj.uuid, team_id=proj.team_id, + creator_id=proj.creator_id, name=proj.name, description=proj.description, meta=proj.meta, @@ -96,6 +115,7 @@ def list_experiments( Experiment( id=e.uuid, team_id=e.team_id, + user_id=e.user_id, project_id=e.project_id, name=e.name, description=e.description, @@ -118,6 +138,7 @@ def get_experiment(id: str) -> Experiment | None: return Experiment( id=exp.uuid, team_id=exp.team_id, + user_id=exp.user_id, project_id=exp.project_id, name=exp.name, description=exp.description, @@ -141,6 +162,7 @@ def list_runs(experiment_id: str, page: int = 0, page_size: int = 10) -> list[Ru Run( id=r.uuid, team_id=r.team_id, + user_id=r.user_id, project_id=r.project_id, experiment_id=r.experiment_id, meta=r.meta, @@ -158,6 +180,7 @@ def get_run(id: str) -> Run | None: return Run( id=run.uuid, team_id=run.team_id, + user_id=run.user_id, project_id=run.project_id, experiment_id=run.experiment_id, meta=run.meta, diff --git a/alphatrion/server/graphql/schema.py b/alphatrion/server/graphql/schema.py index 2b9b897..d55cdf0 100644 --- a/alphatrion/server/graphql/schema.py +++ b/alphatrion/server/graphql/schema.py @@ -1,7 +1,7 @@ import strawberry from alphatrion.server.graphql.resolvers import GraphQLResolvers -from alphatrion.server.graphql.types import Experiment, Metric, Project, Run, Team +from alphatrion.server.graphql.types import Experiment, Metric, Project, Run, Team, User @strawberry.type @@ -9,6 +9,8 @@ class Query: teams: list[Team] = strawberry.field(resolver=GraphQLResolvers.list_teams) team: Team | None = strawberry.field(resolver=GraphQLResolvers.get_team) + user: User | None = strawberry.field(resolver=GraphQLResolvers.get_user) + @strawberry.field def projects( self, diff --git a/alphatrion/server/graphql/types.py b/alphatrion/server/graphql/types.py index 5226e6e..35fd7ac 100644 --- a/alphatrion/server/graphql/types.py +++ b/alphatrion/server/graphql/types.py @@ -15,10 +15,22 @@ class Team: updated_at: datetime +@strawberry.type +class User: + id: strawberry.ID + username: str + email: str + team_id: strawberry.ID + meta: JSON | None + created_at: datetime + updated_at: datetime + + @strawberry.type class Project: id: strawberry.ID team_id: strawberry.ID + creator_id: strawberry.ID name: str | None description: str | None meta: JSON | None @@ -50,6 +62,7 @@ class GraphQLExperimentType(Enum): class Experiment: id: strawberry.ID team_id: strawberry.ID + user_id: strawberry.ID project_id: strawberry.ID name: str description: str | None @@ -66,6 +79,7 @@ class Experiment: class Run: id: strawberry.ID team_id: strawberry.ID + user_id: strawberry.ID project_id: strawberry.ID experiment_id: strawberry.ID meta: JSON | None diff --git a/hack/seed.py b/hack/seed.py index a23b585..40cf638 100644 --- a/hack/seed.py +++ b/hack/seed.py @@ -21,6 +21,7 @@ Run, Status, Team, + User, ) load_dotenv() @@ -58,21 +59,40 @@ def generate_team() -> Team: ) -def generate_project(teams: list[Team]) -> Project: +def generate_user(teams: list[Team]) -> User: + return User( + uuid=uuid.uuid4(), + username=fake.user_name(), + email=fake.email(), + team_id=random.choice(teams).uuid, + meta=make_json_serializable( + fake.pydict(nb_elements=3, variable_nb_elements=True) + ), + ) + + +def generate_project(users: list[User]) -> Project: + user = random.choice(users) + team = ( + session.query(Team).filter(Team.uuid == user.team_id, Team.is_del == 0).first() + ) return Project( name=fake.bs().title(), description=fake.catch_phrase(), meta=make_json_serializable( fake.pydict(nb_elements=3, variable_nb_elements=True) ), - team_id=random.choice(teams).uuid, + creator_id=user.uuid, + team_id=team.uuid, ) def generate_experiment(projects: list[Project]) -> Experiment: proj = random.choice(projects) + user_id = proj.creator_id return Experiment( team_id=proj.team_id, + user_id=user_id, project_id=proj.uuid, name=fake.bs().title(), description=fake.catch_phrase(), @@ -89,8 +109,10 @@ def generate_experiment(projects: list[Project]) -> Experiment: def generate_run(exps: list[Experiment]) -> Run: exp = random.choice(exps) + user_id = exp.user_id return Run( team_id=exp.team_id, + user_id=user_id, project_id=exp.project_id, experiment_id=exp.uuid, meta=make_json_serializable( @@ -114,6 +136,7 @@ def generate_metric(runs: list[Run]) -> Metric: def seed_all( num_teams: int, + num_users: int, num_projs_per_team: int, num_exps_per_proj: int, num_runs_per_exp: int, @@ -122,12 +145,17 @@ def seed_all( Base.metadata.create_all(bind=engine) print("🌱 generating seeds ...") + teams = [generate_team() for _ in range(num_teams)] session.add_all(teams) session.commit() + users = [generate_user(teams) for _ in range(num_users)] + session.add_all(users) + session.commit() + projs = [ - generate_project(teams) + generate_project(users) for _ in range(num_projs_per_team) for _ in range(len(teams)) ] @@ -182,6 +210,7 @@ def cleanup(): elif action == "seed": seed_all( num_teams=3, + num_users=15, num_projs_per_team=10, num_exps_per_proj=10, num_runs_per_exp=20, diff --git a/migrations/versions/c98aa69beda7_init_schema.py b/migrations/versions/3c290addfa4c_init_schema.py similarity index 84% rename from migrations/versions/c98aa69beda7_init_schema.py rename to migrations/versions/3c290addfa4c_init_schema.py index 3c11fad..9da1350 100644 --- a/migrations/versions/c98aa69beda7_init_schema.py +++ b/migrations/versions/3c290addfa4c_init_schema.py @@ -1,8 +1,8 @@ """init schema -Revision ID: c98aa69beda7 +Revision ID: 3c290addfa4c Revises: -Create Date: 2026-01-29 01:10:13.441947 +Create Date: 2026-01-31 11:57:13.880349 """ from typing import Sequence, Union @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. -revision: str = 'c98aa69beda7' +revision: str = '3c290addfa4c' down_revision: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -25,6 +25,7 @@ def upgrade() -> None: sa.Column('uuid', sa.UUID(), nullable=False), sa.Column('team_id', sa.UUID(), nullable=False), sa.Column('project_id', sa.UUID(), nullable=False), + sa.Column('user_id', sa.UUID(), nullable=True), sa.Column('name', sa.String(), nullable=False), sa.Column('description', sa.String(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the trial'), @@ -66,6 +67,7 @@ def upgrade() -> None: sa.Column('name', sa.String(), nullable=False), sa.Column('description', sa.String(), nullable=True), sa.Column('team_id', sa.UUID(), nullable=False), + sa.Column('creator_id', sa.UUID(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the project'), sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), @@ -77,6 +79,7 @@ def upgrade() -> None: sa.Column('team_id', sa.UUID(), nullable=False), sa.Column('project_id', sa.UUID(), nullable=False), sa.Column('experiment_id', sa.UUID(), nullable=False), + sa.Column('user_id', sa.UUID(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the run'), sa.Column('status', sa.Integer(), nullable=False, comment='Status of the run, 0: UNKNOWN, 1: PENDING, 2: RUNNING, 9: COMPLETED, 10: CANCELLED, 11: FAILED'), sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), @@ -94,12 +97,26 @@ def upgrade() -> None: sa.Column('is_del', sa.Integer(), nullable=True, comment='0 for not deleted, 1 for deleted'), sa.PrimaryKeyConstraint('uuid') ) + op.create_table('users', + sa.Column('uuid', sa.UUID(), nullable=False), + sa.Column('username', sa.String(), nullable=False), + sa.Column('email', sa.String(), nullable=False), + sa.Column('team_id', sa.UUID(), nullable=False), + sa.Column('meta', sa.JSON(), nullable=True, comment='Additional metadata for the user'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('is_del', sa.Integer(), nullable=True, comment='0 for not deleted, 1 for deleted'), + sa.PrimaryKeyConstraint('uuid'), + sa.UniqueConstraint('email'), + sa.UniqueConstraint('username') + ) # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('users') op.drop_table('teams') op.drop_table('runs') op.drop_table('projects') diff --git a/tests/integration/server/test_graphql_query.py b/tests/integration/server/test_graphql_query.py index 9543b9b..99576d9 100644 --- a/tests/integration/server/test_graphql_query.py +++ b/tests/integration/server/test_graphql_query.py @@ -65,16 +65,52 @@ def test_query_teams(): assert response.errors is None assert len(response.data["teams"]) >= 2 +def test_query_user(): + init(init_tables=True) + + metadb = graphql_runtime().metadb + team_id = metadb.create_team( + name="Test Team", description="A team for testing", meta={"foo": "bar"} + ) + + user_id = metadb.create_user( + username="tester", email="tester@inftyai.com", team_id=team_id, meta={"foo": "bar"} + ) + + query = f""" + query {{ + user(id: "{user_id}") {{ + id + username + email + meta + teamId + createdAt + updatedAt + }} + }} + """ + response = schema.execute_sync( + query, + variable_values={}, + ) + assert response.errors is None + assert response.data["user"]["username"] == "tester" + assert response.data["user"]["email"] == "tester@inftyai.com" + assert response.data["user"]["teamId"] == str(team_id) + assert response.data["user"]["meta"] == {"foo": "bar"} def test_query_single_project(): init(init_tables=True) team_id = uuid.uuid4() + user_id = uuid.uuid4() metadb = graphql_runtime().metadb id = metadb.create_project( name="Test Project", description="A project for testing", team_id=team_id, + user_id=user_id, ) query = f""" @@ -102,21 +138,26 @@ def test_query_single_project(): def test_query_projects(): init(init_tables=True) team_id = uuid.uuid4() + user_id = uuid.uuid4() metadb = graphql_runtime().metadb + _ = metadb.create_project( name="Test Project1", description="A project for testing", team_id=team_id, + user_id=user_id, ) _ = metadb.create_project( name="Test Project2", description="A project for testing", team_id=team_id, + user_id=user_id, ) _ = metadb.create_project( - name="Test Project2", + name="Test Project3", description="A project for testing", team_id=uuid.uuid4(), + user_id=user_id, ) query = f""" @@ -143,12 +184,14 @@ def test_query_projects(): def test_query_single_exp(): init(init_tables=True) team_id = uuid.uuid4() + user_id = uuid.uuid4() project_id = uuid.uuid4() metadb = graphql_runtime().metadb exp_id = metadb.create_experiment( name="Test Experiment", team_id=team_id, + user_id=user_id, project_id=project_id, status=Status.RUNNING, meta={}, @@ -184,15 +227,18 @@ def test_query_experiments(): init(init_tables=True) team_id = uuid.uuid4() project_id = uuid.uuid4() + user_id = uuid.uuid4() metadb = graphql_runtime().metadb _ = metadb.create_experiment( name="Test Experiment1", team_id=team_id, + user_id=user_id, project_id=project_id, ) _ = metadb.create_experiment( name="Test Experiment2", team_id=team_id, + user_id=user_id, project_id=project_id, ) @@ -224,11 +270,13 @@ def test_query_experiments(): def test_query_single_run(): init(init_tables=True) team_id = uuid.uuid4() + user_id = uuid.uuid4() project_id = uuid.uuid4() exp_id = uuid.uuid4() metadb = graphql_runtime().metadb run_id = metadb.create_run( team_id=team_id, + user_id=user_id, project_id=project_id, experiment_id=exp_id, ) @@ -258,16 +306,19 @@ def test_query_single_run(): def test_query_runs(): init(init_tables=True) team_id = uuid.uuid4() + user_id = uuid.uuid4() project_id = uuid.uuid4() exp_id = uuid.uuid4() metadb = graphql_runtime().metadb _ = metadb.create_run( team_id=team_id, + user_id=user_id, project_id=project_id, experiment_id=exp_id, ) _ = metadb.create_run( team_id=team_id, + user_id=user_id, project_id=project_id, experiment_id=exp_id, ) diff --git a/tests/integration/test_artifact.py b/tests/integration/test_artifact.py index 604a21b..f72ba36 100644 --- a/tests/integration/test_artifact.py +++ b/tests/integration/test_artifact.py @@ -11,14 +11,24 @@ @pytest.fixture def artifact(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) artifact = global_runtime()._artifact yield artifact def test_push_with_files(artifact): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) @@ -41,7 +51,12 @@ def test_push_with_files(artifact): def test_push_with_folder(artifact): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) diff --git a/tests/integration/test_log.py b/tests/integration/test_log.py index 0230a28..5037bd6 100644 --- a/tests/integration/test_log.py +++ b/tests/integration/test_log.py @@ -15,7 +15,12 @@ @pytest.mark.asyncio async def test_log_artifact(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async with alpha.Project.setup( name="log_artifact_project", @@ -67,7 +72,12 @@ async def test_log_artifact(): @pytest.mark.asyncio async def test_log_params(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async with alpha.Project.setup(name="log_params_proj") as proj: exp = alpha.CraftExperiment.start(name="first-exp", params={"param1": 0.1}) @@ -94,7 +104,12 @@ async def test_log_params(): @pytest.mark.asyncio async def test_log_metrics(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def log_metric(metrics: dict): await alpha.log_metrics(metrics) @@ -145,7 +160,10 @@ async def log_metric(metrics: dict): @pytest.mark.asyncio async def test_log_metrics_with_save_on_max(): team_id = uuid.uuid4() - alpha.init(team_id=team_id, artifact_insecure=True, init_tables=True) + user_id = uuid.uuid4() + alpha.init( + team_id=team_id, user_id=user_id, artifact_insecure=True, init_tables=True + ) async def log_metric(value: float): await alpha.log_metrics({"accuracy": value}) @@ -253,7 +271,12 @@ def pre_save_hook(): @pytest.mark.asyncio async def test_log_metrics_with_save_on_min(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def log_metric(value: float): await alpha.log_metrics({"accuracy": value}) @@ -318,7 +341,12 @@ async def log_metric(value: float): @pytest.mark.asyncio async def test_log_metrics_with_early_stopping(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) @@ -360,7 +388,12 @@ async def fake_sleep(value: float): @pytest.mark.asyncio async def test_log_metrics_with_early_stopping_never_triggered(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) @@ -400,7 +433,12 @@ async def fake_sleep(value: float): @pytest.mark.asyncio async def test_log_metrics_with_max_run_number(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) @@ -429,7 +467,12 @@ async def fake_work(value: float): @pytest.mark.asyncio async def test_log_metrics_with_max_target_meet(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) @@ -464,7 +507,12 @@ async def fake_sleep(value: float): @pytest.mark.asyncio async def test_log_metrics_with_min_target_meet(): - alpha.init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + alpha.init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(value: float): await alpha.log_metrics({"accuracy": value}) diff --git a/tests/unit/artifact/test_artifact.py b/tests/unit/artifact/test_artifact.py index 791f363..03f5fc9 100644 --- a/tests/unit/artifact/test_artifact.py +++ b/tests/unit/artifact/test_artifact.py @@ -10,7 +10,7 @@ @pytest.fixture def artifact(): - init(team_id=uuid.uuid4(), artifact_insecure=True) + init(team_id=uuid.uuid4(), user_id=uuid.uuid4(), artifact_insecure=True) artifact = global_runtime()._artifact yield artifact diff --git a/tests/unit/experiment/test_experimant.py b/tests/unit/experiment/test_experimant.py index 81c3db6..9131757 100644 --- a/tests/unit/experiment/test_experimant.py +++ b/tests/unit/experiment/test_experimant.py @@ -52,16 +52,17 @@ async def test_timeout(self): }, ] - init(team_id=uuid.uuid4(), init_tables=True) + init(team_id=uuid.uuid4(), user_id=uuid.uuid4(), init_tables=True) for case in test_cases: with self.subTest(name=case["name"]): - proj = Project.setup( name=faker.Faker().word(), description="Test Project", ) - exp = CraftExperiment.start(name=faker.Faker().word(), config=case["config"]) + exp = CraftExperiment.start( + name=faker.Faker().word(), config=case["config"] + ) if case["created"]: time.sleep(2) # simulate elapsed time @@ -110,7 +111,7 @@ def test_config(self): }, ] - init(team_id=uuid.uuid4(), init_tables=True) + init(team_id=uuid.uuid4(), user_id=uuid.uuid4(), init_tables=True) for case in test_cases: with self.subTest(name=case["name"]): diff --git a/tests/unit/metadata/test_sql.py b/tests/unit/metadata/test_sql.py index d8a0373..669a8cb 100644 --- a/tests/unit/metadata/test_sql.py +++ b/tests/unit/metadata/test_sql.py @@ -14,11 +14,15 @@ def db(): def test_create_project(db): team_id = uuid.uuid4() - id = db.create_project("test_proj", team_id, "test description", {"key": "value"}) + user_id = uuid.uuid4() + id = db.create_project( + "test_proj", team_id, user_id, "test description", {"key": "value"} + ) proj = db.get_project(id) assert proj is not None assert proj.name == "test_proj" assert proj.team_id == team_id + assert proj.creator_id == user_id assert proj.description == "test description" assert proj.meta == {"key": "value"} assert proj.uuid is not None @@ -26,7 +30,7 @@ def test_create_project(db): def test_delete_project(db): id = db.create_project( - "test_proj", uuid.uuid4(), "test description", {"key": "value"} + "test_proj", uuid.uuid4(), uuid.uuid4(), "test description", {"key": "value"} ) db.delete_project(id) proj = db.get_project(id) @@ -35,7 +39,7 @@ def test_delete_project(db): def test_update_project(db): id = db.create_project( - "test_proj", uuid.uuid4(), "test description", {"key": "value"} + "test_proj", uuid.uuid4(), uuid.uuid4(), "test description", {"key": "value"} ) db.update_project(id, name="new_name") proj = db.get_project(id) @@ -45,9 +49,10 @@ def test_update_project(db): def test_list_projects(db): team_id1 = uuid.uuid4() team_id2 = uuid.uuid4() - db.create_project("proj1", team_id1, None, None) - db.create_project("proj2", team_id1, None, None) - db.create_project("proj3", team_id2, None, None) + user_id = uuid.uuid4() + db.create_project("proj1", team_id1, user_id, None, None) + db.create_project("proj2", team_id1, user_id, None, None) + db.create_project("proj3", team_id2, user_id, None, None) projs = db.list_projects(team_id1, 0, 10) assert len(projs) == 2 @@ -61,11 +66,13 @@ def test_list_projects(db): def test_create_experiment(db): team_id = uuid.uuid4() + user_id = uuid.uuid4() proj_id = db.create_project( - "test_proj", team_id, "test description", {"key": "value"} + "test_proj", team_id, user_id, "test description", {"key": "value"} ) exp_id = db.create_experiment( team_id=team_id, + user_id=user_id, project_id=proj_id, name="test-exp", params={"lr": 0.01}, @@ -81,9 +88,11 @@ def test_create_experiment(db): def test_update_experiment(db): team_id = uuid.uuid4() - proj_id = db.create_project("test_proj", team_id, "test description") + user_id = uuid.uuid4() + proj_id = db.create_project("test_proj", team_id, user_id, "test description") exp_id = db.create_experiment( team_id=team_id, + user_id=user_id, project_id=proj_id, name="test_exp", description="test description", @@ -100,11 +109,16 @@ def test_update_experiment(db): def test_create_metric(db): team_id = uuid.uuid4() + user_id = uuid.uuid4() proj_id = db.create_project( - "test_proj", team_id, "test description", {"key": "value"} + "test_proj", team_id, user_id, "test description", {"key": "value"} + ) + exp_id = db.create_experiment( + team_id=team_id, user_id=user_id, project_id=proj_id, name="test-exp" + ) + run_id = db.create_run( + team_id=team_id, user_id=user_id, project_id=proj_id, experiment_id=exp_id ) - exp_id = db.create_experiment(team_id=team_id, project_id=proj_id, name="test-exp") - run_id = db.create_run(team_id=team_id, project_id=proj_id, experiment_id=exp_id) db.create_metric(team_id, proj_id, exp_id, run_id, "accuracy", 0.95) db.create_metric(team_id, proj_id, exp_id, run_id, "accuracy", 0.85) diff --git a/tests/unit/model/test_model.py b/tests/unit/model/test_model.py index fdd3315..f196837 100644 --- a/tests/unit/model/test_model.py +++ b/tests/unit/model/test_model.py @@ -8,7 +8,7 @@ @pytest.fixture def model(): - init(team_id=uuid.uuid4(), init_tables=True) + init(team_id=uuid.uuid4(), user_id=uuid.uuid4(), init_tables=True) runtime = global_runtime() model = Model(runtime=runtime) yield model diff --git a/tests/unit/project/test_project.py b/tests/unit/project/test_project.py index a1f243a..49d0cb4 100644 --- a/tests/unit/project/test_project.py +++ b/tests/unit/project/test_project.py @@ -14,7 +14,12 @@ @pytest.mark.asyncio async def test_project(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async with Project.setup( name="context_proj", @@ -40,7 +45,12 @@ async def test_project(): @pytest.mark.asyncio async def test_project_with_done(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) exp_id = None async with Project.setup( @@ -59,7 +69,12 @@ async def test_project_with_done(): @pytest.mark.asyncio async def test_project_with_done_with_err(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) exp_id = None async with Project.setup( @@ -79,7 +94,12 @@ async def test_project_with_done_with_err(): @pytest.mark.asyncio async def test_project_with_no_context(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(exp: experiment.Experiment): await asyncio.sleep(3) @@ -99,7 +119,12 @@ async def fake_work(exp: experiment.Experiment): @pytest.mark.asyncio async def test_project_with_exp(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) exp_id = None async with Project.setup(name="context_proj") as proj: @@ -115,7 +140,12 @@ async def test_project_with_exp(): @pytest.mark.asyncio async def test_create_project_with_exp_wait(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(exp: experiment.Experiment): await asyncio.sleep(3) @@ -139,7 +169,12 @@ async def fake_work(exp: experiment.Experiment): @pytest.mark.asyncio async def test_create_project_with_run(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(cancel_func: callable, exp_id: uuid.UUID): assert experiment.current_exp_id.get() == exp_id @@ -165,7 +200,12 @@ async def fake_work(cancel_func: callable, exp_id: uuid.UUID): @pytest.mark.asyncio async def test_create_project_with_run_cancelled(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(timeout: int): await asyncio.sleep(timeout) @@ -197,7 +237,12 @@ async def fake_work(timeout: int): @pytest.mark.asyncio async def test_create_project_with_max_execution_seconds(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async with Project.setup( name="context_proj", @@ -217,7 +262,12 @@ async def test_create_project_with_max_execution_seconds(): @pytest.mark.asyncio async def test_project_with_multi_trials_in_parallel(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async def fake_work(): duration = random.randint(1, 5) @@ -251,7 +301,12 @@ async def fake_work(): @pytest.mark.asyncio async def test_project_with_config(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async with Project.setup( name="context_proj", @@ -269,7 +324,12 @@ async def test_project_with_config(): @pytest.mark.asyncio async def test_project_with_hierarchy_timeout(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) async with Project.setup( name="context_proj", @@ -294,7 +354,12 @@ async def test_project_with_hierarchy_timeout(): @pytest.mark.asyncio async def test_project_with_hierarchy_timeout_2(): - init(team_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + init( + team_id=uuid.uuid4(), + user_id=uuid.uuid4(), + artifact_insecure=True, + init_tables=True, + ) start_time = datetime.now()