Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions nemo_run/core/execution/skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
try:
import sky
import sky.task as skyt
from sky.utils import status_lib
from sky import backends
from sky.utils import status_lib

_SKYPILOT_AVAILABLE = True
except ImportError:
Expand All @@ -62,7 +62,8 @@ class SkypilotExecutor(Executor):
gpus="A10G",
gpus_per_node=devices,
container_image="nvcr.io/nvidia/nemo:dev",
cloud="kubernetes",
infra="k8s/my-context",
network_tier="best",
cluster_name="nemo_tester",
file_mounts={
"nemo_run.whl": "nemo_run.whl"
Expand Down Expand Up @@ -105,6 +106,8 @@ class SkypilotExecutor(Executor):
idle_minutes_to_autostop: Optional[int] = None
torchrun_nproc_per_node: Optional[int] = None
cluster_config_overrides: Optional[dict[str, Any]] = None
infra: Optional[str] = None
network_tier: Optional[str] = None
packager: Packager = field(default_factory=lambda: GitArchivePackager()) # type: ignore # noqa: F821

def __post_init__(self):
Expand All @@ -114,6 +117,13 @@ def __post_init__(self):
assert isinstance(self.packager, GitArchivePackager), (
"Only GitArchivePackager is currently supported for SkypilotExecutor."
)
if self.infra is not None:
assert self.cloud is None, "Cannot specify both `infra` and `cloud` parameters."
assert self.region is None, "Cannot specify both `infra` and `region` parameters."
assert self.zone is None, "Cannot specify both `infra` and `zone` parameters."
logger.info(
"`cloud` is deprecated and will be removed in a future version. Use `infra` instead."
)

@classmethod
def parse_app(cls: Type["SkypilotExecutor"], app_id: str) -> tuple[str, str, int]:
Expand Down Expand Up @@ -173,6 +183,8 @@ def parse_attr(attr: str):
"memory",
"instance_type",
"use_spot",
"infra",
"network_tier",
"image_id",
"disk_size",
"disk_tier",
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ lepton = "nemo_run.run.torchx_backend.schedulers.lepton:create_scheduler"

[project.optional-dependencies]
skypilot = [
"skypilot[kubernetes]>=0.9.2",
"skypilot[kubernetes]>=0.10.0",
]
skypilot-all = [
"skypilot[all]>=0.9.2",
"skypilot[all]>=0.10.0",
]
ray = [
"kubernetes"
Expand Down
40 changes: 40 additions & 0 deletions test/core/execution/test_skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class MockClusterNotUpError(Exception):
"sky.core": sky_core_mock,
"sky.skylet.job_lib": job_lib_mock,
"sky.utils.common_utils": common_utils_mock,
"sky.resources": MagicMock(),
}

# Also mock the sky_exceptions module with our mock exception
Expand Down Expand Up @@ -141,6 +142,33 @@ def test_init_non_git_packager(self, mock_skypilot_imports):
packager=non_git_packager,
)

def test_init_with_infra_and_cloud_fails(self, mock_skypilot_imports):
    """Constructing an executor with both `infra` and `cloud` must raise AssertionError."""
    # `match` is applied as a regex search against the assertion message.
    expected_message = "Cannot specify both `infra` and `cloud` parameters."
    with pytest.raises(AssertionError, match=expected_message):
        SkypilotExecutor(infra="my-infra", cloud="aws")

def test_init_with_infra_and_region_fails(self, mock_skypilot_imports):
    """Constructing an executor with both `infra` and `region` must raise AssertionError."""
    # `match` is applied as a regex search against the assertion message.
    expected_message = "Cannot specify both `infra` and `region` parameters."
    with pytest.raises(AssertionError, match=expected_message):
        SkypilotExecutor(infra="my-infra", region="us-west-2")

def test_init_with_infra_and_zone_fails(self, mock_skypilot_imports):
    """Constructing an executor with both `infra` and `zone` must raise AssertionError."""
    # `match` is applied as a regex search against the assertion message.
    expected_message = "Cannot specify both `infra` and `zone` parameters."
    with pytest.raises(AssertionError, match=expected_message):
        SkypilotExecutor(infra="my-infra", zone="us-west-2a")

def test_parse_app(self, mock_skypilot_imports):
app_id = "app___cluster-name___task-name___123"
cluster, task, job_id = SkypilotExecutor.parse_app(app_id)
Expand Down Expand Up @@ -228,6 +256,18 @@ def test_to_resources_with_none_string(self, mock_resources, mock_skypilot_impor
assert config["cloud"] is None
assert config["any_of"][1]["region"] is None

@patch("sky.resources.Resources")
def test_to_resources_with_infra_and_network_tier(self, mock_resources, mock_skypilot_imports):
    """`to_resources` forwards `infra` and `network_tier` to `Resources.from_yaml_config`."""
    SkypilotExecutor(infra="k8s/my-context", network_tier="best").to_resources()

    from_yaml_config = mock_resources.from_yaml_config
    from_yaml_config.assert_called_once()

    # The resources config is the first positional argument of the recorded call.
    positional_args, _ = from_yaml_config.call_args
    passed_config = positional_args[0]
    assert passed_config["infra"] == "k8s/my-context"
    assert passed_config["network_tier"] == "best"

@patch("sky.core.status")
@patch("sky.core.queue")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.parse_app")
Expand Down
Loading
Loading