diff --git a/nemo_run/core/execution/skypilot.py b/nemo_run/core/execution/skypilot.py index d3823c43..84fe602f 100644 --- a/nemo_run/core/execution/skypilot.py +++ b/nemo_run/core/execution/skypilot.py @@ -65,7 +65,22 @@ class SkypilotExecutor(Executor): network_tier="best", cluster_name="nemo_tester", file_mounts={ - "nemo_run.whl": "nemo_run.whl" + "nemo_run.whl": "nemo_run.whl", + "/workspace/code": "/local/path/to/code", + }, + storage_mounts={ + "/workspace/outputs": { + "name": "my-training-outputs", + "store": "gcs", # or "s3", "azure", etc. + "mode": "MOUNT", + "persistent": True, + }, + "/workspace/checkpoints": { + "name": "model-checkpoints", + "store": "s3", + "mode": "MOUNT", + "persistent": True, + } }, setup=\"\"\" conda deactivate @@ -99,6 +114,7 @@ class SkypilotExecutor(Executor): disk_tier: Optional[Union[str, list[str]]] = None ports: Optional[tuple[str]] = None file_mounts: Optional[dict[str, str]] = None + storage_mounts: Optional[dict[str, dict[str, Any]]] = None # Can be str or dict configs cluster_name: Optional[str] = None setup: Optional[str] = None autodown: bool = False @@ -371,9 +387,22 @@ def to_task( envs=self.env_vars, num_nodes=self.num_nodes, ) + # Handle regular file mounts file_mounts = self.file_mounts or {} file_mounts["/nemo_run"] = self.job_dir task.set_file_mounts(file_mounts) + + # Handle storage mounts separately + if self.storage_mounts: + from sky.data import Storage + + storage_objects = {} + for mount_path, config in self.storage_mounts.items(): + # Create Storage object from config dict + storage_obj = Storage.from_yaml_config(config) + storage_objects[mount_path] = storage_obj + task.set_storage_mounts(storage_objects) + task.set_resources(self.to_resources()) if env_vars: diff --git a/test/core/execution/test_skypilot.py b/test/core/execution/test_skypilot.py index fe975049..5d35c39c 100644 --- a/test/core/execution/test_skypilot.py +++ b/test/core/execution/test_skypilot.py @@ -561,3 +561,160 @@ def test_to_task(self, mock_task, mock_skypilot_imports, executor): # Verify the returned task is our mock assert result == mock_task_instance + + @patch("sky.task.Task") + def test_to_task_with_storage_mounts(self, mock_task, mock_skypilot_imports): + # Create a mock task instance + mock_task_instance = MagicMock() + mock_task.return_value = mock_task_instance + mock_task_instance.set_file_mounts = MagicMock() + mock_task_instance.set_storage_mounts = MagicMock() + mock_task_instance.set_resources = MagicMock() + + # Mock sky.data.Storage + mock_storage_class = MagicMock() + mock_storage_obj = MagicMock() + mock_storage_class.from_yaml_config.return_value = mock_storage_obj + + executor = SkypilotExecutor( + container_image="test:latest", + storage_mounts={ + "/workspace/outputs": { + "name": "my-outputs", + "store": "gcs", + "mode": "MOUNT", + "persistent": True, + } + }, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + executor.job_dir = tmp_dir + + with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources: + mock_to_resources.return_value = MagicMock() + + with patch("sky.data.Storage", mock_storage_class): + executor.to_task("test_task") + + # Verify Storage.from_yaml_config was called with the config + mock_storage_class.from_yaml_config.assert_called_once_with( + { + "name": "my-outputs", + "store": "gcs", + "mode": "MOUNT", + "persistent": True, + } + ) + + # Verify set_storage_mounts was called with Storage objects + mock_task_instance.set_storage_mounts.assert_called_once() + storage_mounts_call = mock_task_instance.set_storage_mounts.call_args[0][0] + assert "/workspace/outputs" in storage_mounts_call + assert storage_mounts_call["/workspace/outputs"] == mock_storage_obj + + @patch("sky.task.Task") + def test_to_task_with_both_file_and_storage_mounts(self, mock_task, mock_skypilot_imports): + # Create a mock task instance + mock_task_instance = MagicMock() + mock_task.return_value = mock_task_instance + mock_task_instance.set_file_mounts = MagicMock() + mock_task_instance.set_storage_mounts = MagicMock() + mock_task_instance.set_resources = MagicMock() + + # Mock sky.data.Storage + mock_storage_class = MagicMock() + mock_storage_obj = MagicMock() + mock_storage_class.from_yaml_config.return_value = mock_storage_obj + + executor = SkypilotExecutor( + container_image="test:latest", + file_mounts={ + "/workspace/code": "/local/path/to/code", + }, + storage_mounts={ + "/workspace/outputs": { + "name": "my-outputs", + "store": "s3", + "mode": "MOUNT", + }, + "/workspace/checkpoints": { + "name": "my-checkpoints", + "store": "gcs", + "mode": "MOUNT_CACHED", + }, + }, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + executor.job_dir = tmp_dir + + with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources: + mock_to_resources.return_value = MagicMock() + + with patch("sky.data.Storage", mock_storage_class): + executor.to_task("test_task") + + # Verify file_mounts includes both user files and nemo_run + file_mounts_call = mock_task_instance.set_file_mounts.call_args[0][0] + assert "/workspace/code" in file_mounts_call + assert file_mounts_call["/workspace/code"] == "/local/path/to/code" + assert "/nemo_run" in file_mounts_call + assert file_mounts_call["/nemo_run"] == tmp_dir + + # Verify Storage.from_yaml_config was called for both storage mounts + assert mock_storage_class.from_yaml_config.call_count == 2 + + # Verify set_storage_mounts was called with both Storage objects + mock_task_instance.set_storage_mounts.assert_called_once() + storage_mounts_call = mock_task_instance.set_storage_mounts.call_args[0][0] + assert "/workspace/outputs" in storage_mounts_call + assert "/workspace/checkpoints" in storage_mounts_call + assert len(storage_mounts_call) == 2 + + @patch("sky.task.Task") + def test_to_task_without_storage_mounts(self, mock_task, mock_skypilot_imports): + # Test that set_storage_mounts is not called when storage_mounts is None + mock_task_instance = MagicMock() + mock_task.return_value = mock_task_instance + mock_task_instance.set_file_mounts = MagicMock() + mock_task_instance.set_storage_mounts = MagicMock() + mock_task_instance.set_resources = MagicMock() + + executor = SkypilotExecutor( + container_image="test:latest", + file_mounts={"/workspace/code": "/local/path"}, + storage_mounts=None, # Explicitly set to None + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + executor.job_dir = tmp_dir + + with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources: + mock_to_resources.return_value = MagicMock() + + executor.to_task("test_task") + + # Verify set_storage_mounts was NOT called + mock_task_instance.set_storage_mounts.assert_not_called() + + # Verify file_mounts still works + mock_task_instance.set_file_mounts.assert_called_once() + + def test_init_with_storage_mounts(self, mock_skypilot_imports): + # Test initialization with storage_mounts parameter + executor = SkypilotExecutor( + container_image="test:latest", + storage_mounts={ + "/workspace/data": { + "name": "training-data", + "store": "s3", + "mode": "MOUNT", + } + }, + ) + + assert executor.storage_mounts is not None + assert "/workspace/data" in executor.storage_mounts + assert executor.storage_mounts["/workspace/data"]["name"] == "training-data" + assert executor.storage_mounts["/workspace/data"]["store"] == "s3"