Skip to content

Commit

Permalink
replace importing the wf with ast parsing
Browse files Browse the repository at this point in the history
Signed-off-by: Ayush Kamat <[email protected]>
  • Loading branch information
ayushkamat committed Feb 3, 2025
1 parent 5493f7e commit a354e6a
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 45 deletions.
126 changes: 126 additions & 0 deletions src/latch_cli/centromere/ast_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import ast
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Literal, Optional

import click


@dataclass
class FlyteObject:
type: Literal["task", "workflow"]
name: str
dockerfile: Optional[Path] = None


def is_task_decorator(decorator_name: str) -> bool:
return decorator_name in {
# og
"small_task",
"medium_task",
"large_task",
# og gpu
"small_gpu_task",
"large_gpu_task",
# custom
"custom_task",
"custom_memory_optimized_task",
# nf
"nextflow_runtime_task",
# l40s gpu
"g6e_xlarge_task",
"g6e_2xlarge_task",
"g6e_4xlarge_task",
"g6e_8xlarge_task",
"g6e_12xlarge_task",
"g6e_16xlarge_task",
"g6e_24xlarge_task",
# v100 gpu
"v100_x1_task",
"v100_x4_task",
"v100_x8_task",
}


class Visitor(ast.NodeVisitor):
def __init__(self, file: Path):
self.file = file
self.flyte_objects: list[FlyteObject] = []

def visit_FunctionDef(self, node: ast.FunctionDef):
if len(node.decorator_list) == 0:
return self.generic_visit(node)

for decorator in node.decorator_list:
if isinstance(decorator, ast.Name):
if decorator.id == "workflow":
self.flyte_objects.append(FlyteObject("workflow", node.name))
elif is_task_decorator(decorator.id):
self.flyte_objects.append(FlyteObject("task", node.name))

elif isinstance(decorator, ast.Call):
func = decorator.func
assert isinstance(func, ast.Name)

if not is_task_decorator(func.id) and func.id != "workflow":
continue

if func.id == "workflow":
self.flyte_objects.append(FlyteObject("workflow", node.name))
continue

dockerfile: Optional[Path] = None
for kw in decorator.keywords:
if kw.arg != "dockerfile":
continue

try:
dockerfile = Path(ast.literal_eval(kw.value)).resolve()
except ValueError as e:
click.secho(
dedent(f"""\
There was an issue parsing the `dockerfile` argument for task `{node.name}` in {self.file}.
Note that values passed to `dockerfile` must be string literals.
"""),
fg="red",
)

raise click.exceptions.Exit(1) from e

if not dockerfile.exists():
click.secho(
f"The `dockerfile` value ({dockerfile}) for task `{node.name}` in {self.file} does not exist.",
fg="red",
)

raise click.exceptions.Exit(1)

self.flyte_objects.append(FlyteObject("task", node.name, dockerfile))

return self.generic_visit(node)


def get_flyte_objects(file: Path) -> list[FlyteObject]:
res = []
if file.is_dir():
for child in file.iterdir():
res.extend(get_flyte_objects(child))

return res

assert file.is_file()
if file.suffix != ".py":
return res

v = Visitor(file.resolve())

try:
parsed = ast.parse(file.read_text(), filename=file)
except SyntaxError as e:
click.secho(f"There is a syntax error in {file}: {e}", fg="red")
raise click.exceptions.Exit(1) from e

v.visit(parsed)

return v.flyte_objects
52 changes: 27 additions & 25 deletions src/latch_cli/centromere/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,14 @@
import paramiko
import paramiko.util
from docker.transport import SSHHTTPAdapter
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteEntities
from flytekit.core.workflow import PythonFunctionWorkflow

import latch_cli.tinyrequests as tinyrequests
from latch.utils import account_id_from_token, current_workspace, retrieve_or_login
from latch_cli.centromere.ast_parsing import get_flyte_objects
from latch_cli.centromere.utils import (
RemoteConnInfo,
_construct_dkr_client,
_construct_ssh_client,
_import_flyte_objects,
)
from latch_cli.constants import docker_image_name_illegal_pat, latch_constants
from latch_cli.docker_utils import get_default_dockerfile
Expand Down Expand Up @@ -176,8 +173,8 @@ def __init__(

if self.workflow_type == WorkflowType.latchbiosdk:
try:
_import_flyte_objects([self.pkg_root], module_name=self.wf_module)
except ModuleNotFoundError:
flyte_objects = get_flyte_objects(self.pkg_root / self.wf_module)
except ModuleNotFoundError as e:
click.secho(
dedent(
f"""
Expand All @@ -189,14 +186,23 @@ def __init__(
),
fg="red",
)
raise click.exceptions.Exit(1)
raise click.exceptions.Exit(1) from e

wf_name: Optional[str] = None

name_path = pkg_root / latch_constants.pkg_workflow_name
if name_path.exists():
wf_name = name_path.read_text().strip()

if wf_name is None:
for obj in flyte_objects:
if obj.type != "workflow":
continue

for entity in FlyteEntities.entities:
if isinstance(entity, PythonFunctionWorkflow):
self.workflow_name = entity.name
wf_name = obj.name
break

if not hasattr(self, "workflow_name"):
if wf_name is None:
click.secho(
dedent("""\
Unable to locate workflow code. If you are a registering a Snakemake project, make sure to pass the Snakefile path with the --snakefile flag.
Expand All @@ -205,21 +211,17 @@ def __init__(
)
raise click.exceptions.Exit(1)

name_path = pkg_root / latch_constants.pkg_workflow_name
if name_path.exists():
self.workflow_name = name_path.read_text().strip()
self.workflow_name = wf_name

for entity in FlyteEntities.entities:
if isinstance(entity, PythonTask):
if (
hasattr(entity, "dockerfile_path")
and entity.dockerfile_path is not None
):
self.container_map[entity.name] = _Container(
dockerfile=entity.dockerfile_path,
image_name=self.task_image_name(entity.name),
pkg_dir=entity.dockerfile_path.parent,
)
for obj in flyte_objects:
if obj.type != "task" or obj.dockerfile is None:
continue

self.container_map[obj.name] = _Container(
dockerfile=obj.dockerfile,
image_name=self.task_image_name(obj.name),
pkg_dir=obj.dockerfile.parent,
)

elif self.workflow_type == WorkflowType.snakemake:
assert snakefile is not None
Expand Down
29 changes: 10 additions & 19 deletions src/latch_cli/centromere/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import builtins
import contextlib
import functools
import os
import random
import string
Expand All @@ -12,10 +11,8 @@
from typing import Callable, Iterator, List, Optional, TypeVar

import docker
import docker.errors
import paramiko
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.tools import module_loader
from typing_extensions import ParamSpec

from latch_cli.constants import latch_constants
Expand All @@ -42,6 +39,10 @@ def _add_sys_paths(paths: List[Path]) -> Iterator[None]:


def _import_flyte_objects(paths: List[Path], module_name: str = "wf"):
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.tools import module_loader

with _add_sys_paths(paths):

class FakeModule(ModuleType):
Expand Down Expand Up @@ -76,20 +77,15 @@ def __new__(*args, **kwargs):
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
try:
return real_import(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level,
name, globals=globals, locals=locals, fromlist=fromlist, level=level
)
except (ModuleNotFoundError, AttributeError) as e:
except (ModuleNotFoundError, AttributeError):
return FakeModule(name)

# Temporary ctx tells lytekit to skip local execution when
# inspecting objects
fap = FileAccessProvider(
local_sandbox_dir=tempfile.mkdtemp(prefix="foo"),
raw_output_prefix="bar",
local_sandbox_dir=tempfile.mkdtemp(prefix="foo"), raw_output_prefix="bar"
)
tmp_context = FlyteContext(fap, inspect_objects_only=True)

Expand Down Expand Up @@ -201,9 +197,7 @@ def _construct_ssh_client(
raise ConnectionError("unable to create connection to jump host")

sock = gateway_transport.open_channel(
kind="direct-tcpip",
dest_addr=(remote_conn_info.ip, 22),
src_addr=("", 0),
kind="direct-tcpip", dest_addr=(remote_conn_info.ip, 22), src_addr=("", 0)
)
else:
sock = None
Expand All @@ -214,10 +208,7 @@ def _construct_ssh_client(
ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
ssh.connect(
remote_conn_info.ip,
username=remote_conn_info.username,
sock=sock,
pkey=pkey,
remote_conn_info.ip, username=remote_conn_info.username, sock=sock, pkey=pkey
)

transport = ssh.get_transport()
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a354e6a

Please sign in to comment.