Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,18 @@ ServerPWD: alphatr1on
Below is a simple example with two approaches demonstrating how to create an experiment and log performance metrics.

```python
import alphatrion as alpha
import uuid
from alphatrion import init, log_metrics, Project, CraftExperiment

# Better to use a fixed UUID to identify your team.
alpha.init(team_id=uuid.uuid4(), artifact_insecure=True)
# Better to use a fixed UUID for the team and user in real scenarios.
init(team_id=uuid.uuid4(), user_id=uuid.uuid4(), artifact_insecure=True)

async def log():
# Run your code here then log metrics.
await alpha.log_metrics({"accuracy": 0.95})
await log_metrics({"accuracy": 0.95})

async with alpha.Project.setup(name="my_project"):
async with alpha.CraftExperiment.start(name="my_experiment") as exp:
async with Project.setup(name="my_project"):
async with CraftExperiment.start(name="my_experiment") as exp:
task = exp.run(lambda: log())
await task.wait()
```
Expand Down
3 changes: 2 additions & 1 deletion alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,10 @@ def _start(
self._id = exp_obj.uuid
else:
self._id = self._runtime._metadb.create_experiment(
name=name,
team_id=self._runtime._team_id,
user_id=self._runtime._user_id,
project_id=proj.id,
name=name,
description=description,
meta=meta,
params=params,
Expand Down
25 changes: 24 additions & 1 deletion alphatrion/metadata/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from abc import ABC, abstractmethod

from alphatrion.metadata.sql_models import Experiment, Model
from alphatrion.metadata.sql_models import Experiment, Model, User


class MetaStore(ABC):
Expand All @@ -17,11 +17,32 @@ def create_team(
def get_team(self, team_id: uuid.UUID):
raise NotImplementedError("Subclasses must implement this method.")

@abstractmethod
def create_user(
self,
username: str,
email: str,
team_id: uuid.UUID,
meta: dict | None = None,
) -> uuid.UUID:
raise NotImplementedError("Subclasses must implement this method.")

@abstractmethod
def get_user(self, user_id: uuid.UUID) -> User | None:
raise NotImplementedError("Subclasses must implement this method.")

@abstractmethod
def list_users(
self, team_id: uuid.UUID, page: int = 0, page_size: int = 10
) -> list[User]:
raise NotImplementedError("Subclasses must implement this method.")

@abstractmethod
def create_project(
self,
name: str,
team_id: uuid.UUID,
user_id: uuid.UUID,
description: str | None = None,
meta: dict | None = None,
) -> int:
Expand Down Expand Up @@ -78,6 +99,7 @@ def delete_model(self, model_id: uuid.UUID):
def create_experiment(
self,
team_id: uuid.UUID,
user_id: uuid.UUID,
project_id: uuid.UUID,
name: str,
description: str | None = None,
Expand All @@ -102,6 +124,7 @@ def update_experiment(self, experiment_id: uuid.UUID, **kwargs):
def create_run(
self,
team_id: uuid.UUID,
user_id: uuid.UUID,
project_id: uuid.UUID,
experiment_id: uuid.UUID,
meta: dict | None = None,
Expand Down
81 changes: 81 additions & 0 deletions alphatrion/metadata/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Run,
Status,
Team,
User,
)


Expand All @@ -26,6 +27,8 @@ def __init__(self, db_url: str, init_tables: bool = False):
# Mostly used in tests.
Base.metadata.create_all(self._engine)

# ---------- Team APIs ----------

def create_team(
self, name: str, description: str | None = None, meta: dict | None = None
) -> uuid.UUID:
Expand Down Expand Up @@ -62,17 +65,82 @@ def list_teams(self, page: int, page_size: int) -> list[Team]:
session.close()
return teams

def get_team_by_user_id(self, user_id: uuid.UUID) -> Team | None:
session = self._session()
user = (
session.query(User).filter(User.uuid == user_id, User.is_del == 0).first()
)
if not user:
session.close()
return None
team = (
session.query(Team)
.filter(Team.uuid == user.team_id, Team.is_del == 0)
.first()
)
session.close()
return team

# ---------- User APIs ----------

def create_user(
self,
username: str,
email: str,
team_id: uuid.UUID,
meta: dict | None = None,
) -> uuid.UUID:
session = self._session()
new_user = User(
username=username,
team_id=team_id,
email=email,
meta=meta,
)
session.add(new_user)
session.commit()
user_id = new_user.uuid
session.close()

return user_id

def get_user(self, user_id: uuid.UUID) -> User | None:
session = self._session()
user = (
session.query(User).filter(User.uuid == user_id, User.is_del == 0).first()
)
session.close()
return user

def list_users(
self, team_id: uuid.UUID, page: int = 0, page_size: int = 10
) -> list[User]:
session = self._session()
users = (
session.query(User)
.filter(User.team_id == team_id, User.is_del == 0)
.offset(page * page_size)
.limit(page_size)
.all()
)
session.close()
return users

# ---------- Project APIs ----------

def create_project(
self,
name: str,
team_id: uuid.UUID,
user_id: uuid.UUID,
description: str | None = None,
meta: dict | None = None,
) -> uuid.UUID:
session = self._session()
new_proj = Project(
name=name,
team_id=team_id,
creator_id=user_id,
description=description,
meta=meta,
)
Expand Down Expand Up @@ -151,6 +219,8 @@ def list_projects(
session.close()
return projects

# ---------- Model APIs ----------

def create_model(
self,
name: str,
Expand Down Expand Up @@ -221,10 +291,13 @@ def delete_model(self, model_id: uuid.UUID):
session.commit()
session.close()

# ---------- Experiment APIs ----------

def create_experiment(
self,
name: str,
team_id: uuid.UUID,
user_id: uuid.UUID,
project_id: uuid.UUID,
description: str | None = None,
meta: dict | None = None,
Expand All @@ -234,6 +307,7 @@ def create_experiment(
session = self._session()
new_exp = Experiment(
team_id=team_id,
user_id=user_id,
project_id=project_id,
name=name,
description=description,
Expand Down Expand Up @@ -306,18 +380,23 @@ def update_experiment(self, experiment_id: uuid.UUID, **kwargs) -> None:
session.commit()
session.close()

# ---------- Run APIs ----------

def create_run(
self,
team_id: uuid.UUID,
user_id: uuid.UUID,
project_id: uuid.UUID,
experiment_id: uuid.UUID,
meta: dict | None = None,
status: Status = Status.PENDING,
) -> uuid.UUID:
session = self._session()

new_run = Run(
project_id=project_id,
team_id=team_id,
user_id=user_id,
experiment_id=experiment_id,
meta=meta,
status=status,
Expand Down Expand Up @@ -359,6 +438,8 @@ def list_runs_by_exp_id(
session.close()
return runs

# ---------- Metric APIs ----------

def create_metric(
self,
team_id: uuid.UUID,
Expand Down
21 changes: 21 additions & 0 deletions alphatrion/metadata/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ class Team(Base):
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")


class User(Base):
__tablename__ = "users"

uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
username = Column(String, nullable=False, unique=True)
email = Column(String, nullable=False, unique=True)
team_id = Column(UUID(as_uuid=True), nullable=False)
meta = Column(JSON, nullable=True, comment="Additional metadata for the user")

created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")


# Define the Project model for SQLAlchemy
class Project(Base):
__tablename__ = "projects"
Expand All @@ -55,6 +73,7 @@ class Project(Base):
name = Column(String, nullable=False)
description = Column(String, nullable=True)
team_id = Column(UUID(as_uuid=True), nullable=False)
creator_id = Column(UUID(as_uuid=True), nullable=True)
meta = Column(JSON, nullable=True, comment="Additional metadata for the project")

created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
Expand All @@ -77,6 +96,7 @@ class Experiment(Base):
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
team_id = Column(UUID(as_uuid=True), nullable=False)
project_id = Column(UUID(as_uuid=True), nullable=False)
user_id = Column(UUID(as_uuid=True), nullable=True)
name = Column(String, nullable=False)
description = Column(String, nullable=True)
meta = Column(JSON, nullable=True, comment="Additional metadata for the trial")
Expand Down Expand Up @@ -113,6 +133,7 @@ class Run(Base):
team_id = Column(UUID(as_uuid=True), nullable=False)
project_id = Column(UUID(as_uuid=True), nullable=False)
experiment_id = Column(UUID(as_uuid=True), nullable=False)
user_id = Column(UUID(as_uuid=True), nullable=True)
meta = Column(JSON, nullable=True, comment="Additional metadata for the run")
status = Column(
Integer,
Expand Down
1 change: 1 addition & 0 deletions alphatrion/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _create(
name=name,
description=description,
team_id=self._runtime._team_id,
user_id=self._runtime._user_id,
meta=meta,
)
return self._id
Expand Down
3 changes: 2 additions & 1 deletion alphatrion/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def _get_obj(self):

def start(self, call_func: callable) -> None:
self._id = self._runtime._metadb.create_run(
team_id=self._runtime._team_id,
team_id=self._runtime.team_id,
user_id=self._runtime.user_id,
project_id=self._runtime.current_proj.id,
experiment_id=self._exp_id,
status=Status.RUNNING,
Expand Down
17 changes: 15 additions & 2 deletions alphatrion/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

def init(
team_id: uuid.UUID,
user_id: uuid.UUID,
artifact_insecure: bool = False,
init_tables: bool = False,
):
Expand All @@ -25,6 +26,7 @@ def init(
global __RUNTIME__
__RUNTIME__ = Runtime(
team_id=team_id,
user_id=user_id,
artifact_insecure=artifact_insecure,
init_tables=init_tables,
)
Expand All @@ -39,19 +41,22 @@ def global_runtime():
# Runtime contains all kinds of clients, e.g., metadb client, artifact client, etc.
# Stateful information will also be stored here, e.g., current running Project.
class Runtime:
__slots__ = ("_team_id", "_metadb", "_artifact", "__current_proj")
__slots__ = ("_user_id", "_team_id", "_metadb", "_artifact", "__current_proj")

def __init__(
self,
team_id: uuid.UUID,
user_id: uuid.UUID,
artifact_insecure: bool = False,
init_tables: bool = False,
):
self._team_id = team_id
self._metadb = SQLStore(
os.getenv(consts.METADATA_DB_URL), init_tables=init_tables
)

self._user_id = user_id
self._team_id = team_id

if self.artifact_storage_enabled():
self._artifact = Artifact(team_id=self._team_id, insecure=artifact_insecure)

Expand All @@ -70,3 +75,11 @@ def current_proj(self, value) -> None:
@property
def metadb(self) -> SQLStore:
return self._metadb

@property
def user_id(self) -> uuid.UUID:
return self._user_id

@property
def team_id(self) -> uuid.UUID:
return self._team_id
Loading
Loading