diff --git a/config.yml b/config.yml index f06ab9f2..3624ef3e 100644 --- a/config.yml +++ b/config.yml @@ -48,12 +48,6 @@ globus: uuid: 9032dd3a-e841-4687-a163-2720da731b5b name: alcf_home832 - nersc_test: - root_path: /global/cfs/cdirs/als/data_mover/share/dabramov - uri: nersc.gov - uuid: d40248e6-d874-4f7b-badd-2c06c16f1a58 - name: nersc_test - nersc_alsdev: root_path: /global/homes/a/alsdev/test_directory/ uri: nersc.gov @@ -72,6 +66,24 @@ globus: uuid: d40248e6-d874-4f7b-badd-2c06c16f1a58 name: nersc832_alsdev_scratch + nersc832_alsdev_pscratch_raw: + root_path: /pscratch/sd/a/alsdev/8.3.2/raw + uri: nersc.gov + uuid: d40248e6-d874-4f7b-badd-2c06c16f1a58 + name: nersc832_alsdev_pscratch_raw + + nersc832_alsdev_pscratch_scratch: + root_path: /pscratch/sd/a/alsdev/8.3.2/scratch + uri: nersc.gov + uuid: d40248e6-d874-4f7b-badd-2c06c16f1a58 + name: nersc832_alsdev_pscratch_scratch + + nersc832_alsdev_recon_scripts: + root_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_reconstruction_scripts + uri: nersc.gov + uuid: d40248e6-d874-4f7b-badd-2c06c16f1a58 + name: nersc832_alsdev_recon_scripts + nersc832: root_path: /global/cfs/cdirs/als/data_mover/8.3.2 uri: nersc.gov @@ -95,6 +107,14 @@ globus: client_id: ${GLOBUS_CLIENT_ID} client_secret: ${GLOBUS_CLIENT_SECRET} +harbor_images832: + recon_image: tomorecon_nersc_mpi_hdf5@sha256:cc098a2cfb6b1632ea872a202c66cb7566908da066fd8f8c123b92fa95c2a43c + multires_image: tomorecon_nersc_mpi_hdf5@sha256:cc098a2cfb6b1632ea872a202c66cb7566908da066fd8f8c123b92fa95c2a43c + +ghcr_images832: + recon_image: ghcr.io/als-computing/microct:master + multires_image: ghcr.io/als-computing/microct:master + prefect: deployments: - type_spec: new_file_832 diff --git a/create_deployments_832_nersc.sh b/create_deployments_832_nersc.sh new file mode 100755 index 00000000..29886aa2 --- /dev/null +++ b/create_deployments_832_nersc.sh @@ -0,0 +1,20 @@ +export $(grep -v '^#' .env | xargs) + +# create 'nersc_flow_pool' +prefect work-pool create 
'nersc_flow_pool' +prefect work-pool create 'nersc_prune_pool' + +# nersc_flow_pool + # in docker-compose.yaml: + # command: prefect agent start --pool "nersc_flow_pool" +prefect deployment build ./orchestration/flows/bl832/nersc.py:nersc_recon_flow -n nersc_recon_flow -p nersc_flow_pool -q nersc_recon_flow_queue +prefect deployment apply nersc_recon_flow-deployment.yaml + +# nersc_prune_pool + # in docker-compose.yaml: + # command: prefect agent start --pool "nersc_prune_pool" +prefect deployment build ./orchestration/flows/bl832/prune.py:prune_nersc832_alsdev_pscratch_raw -n prune_nersc832_alsdev_pscratch_raw -p nersc_prune_pool -q prune_nersc832_pscratch_queue +prefect deployment apply prune_nersc832_alsdev_pscratch_raw-deployment.yaml + +prefect deployment build ./orchestration/flows/bl832/prune.py:prune_nersc832_alsdev_pscratch_scratch -n prune_nersc832_alsdev_pscratch_scratch -p nersc_prune_pool -q prune_nersc832_pscratch_queue +prefect deployment apply prune_nersc832_alsdev_pscratch_scratch-deployment.yaml diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py new file mode 100644 index 00000000..a1809686 --- /dev/null +++ b/orchestration/_tests/test_sfapi_flow.py @@ -0,0 +1,286 @@ +# orchestration/_tests/test_sfapi_flow.py + +from pathlib import Path +import pytest +from unittest.mock import MagicMock, patch, mock_open +from uuid import uuid4 + +from prefect.blocks.system import Secret +from prefect.testing.utilities import prefect_test_harness + + +@pytest.fixture(autouse=True, scope="session") +def prefect_test_fixture(): + """ + A pytest fixture that automatically sets up and tears down the Prefect test harness + for the entire test session. It creates and saves test secrets and configurations + required for Globus integration. 
+ + Yields: + None + """ + with prefect_test_harness(): + globus_client_id = Secret(value=str(uuid4())) + globus_client_id.save(name="globus-client-id") + globus_client_secret = Secret(value=str(uuid4())) + globus_client_secret.save(name="globus-client-secret") + + yield + + +# ---------------------------- +# Tests for create_sfapi_client +# ---------------------------- + + +def test_create_sfapi_client_success(): + """ + Test successful creation of the SFAPI client. + """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + # Mock data for client_id and client_secret files + mock_client_id = 'value' + mock_client_secret = '{"key": "value"}' + + # Create separate mock_open instances for each file + mock_open_client_id = mock_open(read_data=mock_client_id) + mock_open_client_secret = mock_open(read_data=mock_client_secret) + + with patch("orchestration.flows.bl832.nersc.os.getenv") as mock_getenv, \ + patch("orchestration.flows.bl832.nersc.os.path.isfile") as mock_isfile, \ + patch("builtins.open", side_effect=[ + mock_open_client_id.return_value, + mock_open_client_secret.return_value + ]), \ + patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key") as mock_import_key, \ + patch("orchestration.flows.bl832.nersc.Client") as MockClient: + + # Mock environment variables + mock_getenv.side_effect = lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret" + }.get(x, None) + + # Mock file existence + mock_isfile.return_value = True + + # Mock JsonWebKey.import_key to return a mock secret + mock_import_key.return_value = "mock_secret" + + # Create the client + client = NERSCTomographyHPCController.create_sfapi_client() + + # Assert that Client was instantiated with 'value' and 'mock_secret' + MockClient.assert_called_once_with("value", "mock_secret") + + # Assert that the returned client is the mocked client + assert client == MockClient.return_value, "Client should be the mocked 
sfapi_client.Client instance" + + +def test_create_sfapi_client_missing_paths(): + """ + Test creation of the SFAPI client with missing credential paths. + """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + with patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None): + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController.create_sfapi_client() + + +def test_create_sfapi_client_missing_files(): + """ + Test creation of the SFAPI client with missing credential files. + """ + with ( + # Mock environment variables + patch( + "orchestration.flows.bl832.nersc.os.getenv", + side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret" + }.get(x, None) + ), + + # Mock file existence to simulate missing files + patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) + ): + # Import the module after applying patches to ensure mocks are in place + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + # Expect a FileNotFoundError due to missing credential files + with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController.create_sfapi_client() + +# ---------------------------- +# Fixture for Mocking SFAPI Client +# ---------------------------- + + +@pytest.fixture +def mock_sfapi_client(): + """ + Mock the sfapi_client.Client class with necessary methods. 
+ """ + with patch("orchestration.flows.bl832.nersc.Client") as MockClient: + mock_client_instance = MockClient.return_value + + # Mock the user method + mock_user = MagicMock() + mock_user.name = "testuser" + mock_client_instance.user.return_value = mock_user + + # Mock the compute method to return a mocked compute object + mock_compute = MagicMock() + mock_job = MagicMock() + mock_job.jobid = "12345" + mock_job.state = "COMPLETED" + mock_compute.submit_job.return_value = mock_job + mock_client_instance.compute.return_value = mock_compute + + yield mock_client_instance + + +# ---------------------------- +# Fixture for Mocking Config832 +# ---------------------------- + +@pytest.fixture +def mock_config832(): + """ + Mock the Config832 class to provide necessary configurations. + """ + with patch("orchestration.flows.bl832.nersc.Config832") as MockConfig: + mock_config = MockConfig.return_value + mock_config.harbor_images832 = { + "recon_image": "mock_recon_image", + "multires_image": "mock_multires_image", + } + mock_config.apps = {"als_transfer": "some_config"} + yield mock_config + + +# ---------------------------- +# Tests for NERSCTomographyHPCController +# ---------------------------- + +def test_reconstruct_success(mock_sfapi_client, mock_config832): + """ + Test successful reconstruction job submission. 
+ """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from sfapi_client.compute import Machine + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + file_path = "path/to/file.h5" + + with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): + result = controller.reconstruct(file_path=file_path) + + # Verify that compute was called with Machine.perlmutter + mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) + + # Verify that submit_job was called once + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + + # Verify that complete was called on the job + mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + + # Assert that the method returns True + assert result is True, "reconstruct should return True on successful job completion." + + +def test_reconstruct_submission_failure(mock_sfapi_client, mock_config832): + """ + Test reconstruction job submission failure. + """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + file_path = "path/to/file.h5" + + # Simulate submission failure + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") + + with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): + result = controller.reconstruct(file_path=file_path) + + # Assert that the method returns False + assert result is False, "reconstruct should return False on submission failure." + + +def test_build_multi_resolution_success(mock_sfapi_client, mock_config832): + """ + Test successful multi-resolution job submission. 
+ """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from sfapi_client.compute import Machine + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + file_path = "path/to/file.h5" + + with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): + result = controller.build_multi_resolution(file_path=file_path) + + # Verify that compute was called with Machine.perlmutter + mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) + + # Verify that submit_job was called once + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + + # Verify that complete was called on the job + mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + + # Assert that the method returns True + assert result is True, "build_multi_resolution should return True on successful job completion." + + +def test_build_multi_resolution_submission_failure(mock_sfapi_client, mock_config832): + """ + Test multi-resolution job submission failure. + """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + file_path = "path/to/file.h5" + + # Simulate submission failure + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") + + with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): + result = controller.build_multi_resolution(file_path=file_path) + + # Assert that the method returns False + assert result is False, "build_multi_resolution should return False on submission failure." + + +def test_job_submission(mock_sfapi_client): + """ + Test job submission and status updates. 
+ """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from sfapi_client.compute import Machine + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=MagicMock()) + file_path = "path/to/file.h5" + + # Mock Path to extract file and folder names + with patch.object(Path, 'parent', new_callable=MagicMock) as mock_parent, \ + patch.object(Path, 'stem', new_callable=MagicMock) as mock_stem: + mock_parent.name = "to" + mock_stem.return_value = "file" + + with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): + controller.reconstruct(file_path=file_path) + + # Verify that compute was called with Machine.perlmutter + mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) + + # Verify that submit_job was called once + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + + # Verify the returned job has the expected attributes + submitted_job = mock_sfapi_client.compute.return_value.submit_job.return_value + assert submitted_job.jobid == "12345", "Job ID should match the mock job ID." + assert submitted_job.state == "COMPLETED", "Job state should be COMPLETED." diff --git a/orchestration/_tests/test_transfer_controller.py b/orchestration/_tests/test_transfer_controller.py new file mode 100644 index 00000000..15b07f54 --- /dev/null +++ b/orchestration/_tests/test_transfer_controller.py @@ -0,0 +1,306 @@ +# test_transfer_controller.py + +import pytest +from pytest_mock import MockFixture +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import globus_sdk +from prefect.blocks.system import Secret +from prefect.testing.utilities import prefect_test_harness + +from .test_globus import MockTransferClient + + +@pytest.fixture(autouse=True, scope="session") +def prefect_test_fixture(): + """ + A pytest fixture that automatically sets up and tears down the Prefect test harness + for the entire test session. 
It creates and saves test secrets and configurations + required for Globus integration. + + Yields: + None + """ + with prefect_test_harness(): + # Create ephemeral secrets in the local Prefect test database + globus_client_id = Secret(value=str(uuid4())) + globus_client_id.save(name="globus-client-id") + + globus_client_secret = Secret(value=str(uuid4())) + globus_client_secret.save(name="globus-client-secret") + + yield + + +@pytest.fixture(scope="session") +def transfer_controller_module(): + """ + Defer importing orchestration.transfer_controller until after + the prefect_test_fixture is loaded. This prevents Prefect from + trying to load secrets at import time. + """ + from orchestration.transfer_controller import ( + FileSystemEndpoint, + GlobusTransferController, + SimpleTransferController, + get_transfer_controller, + CopyMethod, + ) + return { + "FileSystemEndpoint": FileSystemEndpoint, + "GlobusTransferController": GlobusTransferController, + "SimpleTransferController": SimpleTransferController, + "get_transfer_controller": get_transfer_controller, + "CopyMethod": CopyMethod, + } + + +class MockEndpoint: + def __init__(self, root_path, uuid_value=None): + self.root_path = root_path + self.uuid = uuid_value or str(uuid4()) + self.uri = f"mock_endpoint_uri_{self.uuid}" + + +@pytest.fixture +def mock_config832(): + """ + Mock the Config832 class to provide necessary configurations. + """ + with patch("orchestration.flows.bl832.nersc.Config832") as MockConfig: + mock_config = MockConfig.return_value + mock_config.endpoints = { + "alcf832_raw": MockEndpoint("/alcf832_raw"), + } + mock_config.tc = MockTransferClient() + yield mock_config + + +@pytest.fixture +def mock_globus_endpoint(): + """ + A pytest fixture that returns a mocked GlobusEndpoint. + If your orchestration.globus.transfer also loads secrets at import, + you may need to similarly defer that import behind another fixture. 
+ """ + from orchestration.globus.transfer import GlobusEndpoint + endpoint = GlobusEndpoint( + name="mock_globus_endpoint", + root_path="/mock_globus_root/", + uuid="mock_endpoint_id", + uri="mock_endpoint_uri" + ) + return endpoint + + +@pytest.fixture +def mock_file_system_endpoint(transfer_controller_module): + """ + A pytest fixture that returns a FileSystemEndpoint instance. + """ + FileSystemEndpoint = transfer_controller_module["FileSystemEndpoint"] + endpoint = FileSystemEndpoint( + name="mock_filesystem_endpoint", + root_path="/mock_fs_root" + ) + return endpoint + + +class MockSecret: + value = str(uuid4()) + + +# -------------------------------------------------------------------------- +# Tests for get_transfer_controller +# -------------------------------------------------------------------------- + +def test_get_transfer_controller_globus(mock_config832, transfer_controller_module): + CopyMethod = transfer_controller_module["CopyMethod"] + get_transfer_controller = transfer_controller_module["get_transfer_controller"] + GlobusTransferController = transfer_controller_module["GlobusTransferController"] + + controller = get_transfer_controller(CopyMethod.GLOBUS, mock_config832) + assert isinstance(controller, GlobusTransferController), ( + "Expected GlobusTransferController for GLOBUS transfer type." + ) + + +def test_get_transfer_controller_simple(mock_config832, transfer_controller_module): + CopyMethod = transfer_controller_module["CopyMethod"] + get_transfer_controller = transfer_controller_module["get_transfer_controller"] + SimpleTransferController = transfer_controller_module["SimpleTransferController"] + + controller = get_transfer_controller(CopyMethod.SIMPLE, mock_config832) + assert isinstance(controller, SimpleTransferController), ( + "Expected SimpleTransferController for SIMPLE transfer type." 
+ ) + + +def test_get_transfer_controller_invalid(mock_config832, transfer_controller_module): + get_transfer_controller = transfer_controller_module["get_transfer_controller"] + with pytest.raises(ValueError, match="Invalid transfer type"): + get_transfer_controller("invalid_type", mock_config832) + + +# -------------------------------------------------------------------------- +# Tests for GlobusTransferController +# -------------------------------------------------------------------------- + +def test_globus_transfer_controller_copy_success( + mock_config832, mock_globus_endpoint, mocker: MockFixture, transfer_controller_module +): + """ + Test a successful copy() operation using GlobusTransferController. + We mock start_transfer to return True. + """ + GlobusTransferController = transfer_controller_module["GlobusTransferController"] + MockSecretClass = MockSecret + + # Patch any Secret.load calls to avoid real Prefect Cloud calls + mocker.patch('prefect.blocks.system.Secret.load', return_value=MockSecretClass()) + + with patch("orchestration.transfer_controller.start_transfer", return_value=True) as mock_start_transfer: + controller = GlobusTransferController(mock_config832) + result = controller.copy( + file_path="some_dir/test_file.txt", + source=mock_globus_endpoint, + destination=mock_globus_endpoint, + ) + + assert result is True, "Expected True when transfer completes successfully." + mock_start_transfer.assert_called_once() + + # Verify arguments passed to start_transfer + _, called_kwargs = mock_start_transfer.call_args + assert called_kwargs["source_endpoint"] == mock_globus_endpoint + assert called_kwargs["dest_endpoint"] == mock_globus_endpoint + assert "max_wait_seconds" in called_kwargs + + +def test_globus_transfer_controller_copy_failure( + mock_config832, mock_globus_endpoint, mocker: MockFixture, transfer_controller_module +): + """ + Test a failing copy() operation using GlobusTransferController. 
+ We mock start_transfer to return False, indicating a transfer failure. + """ + GlobusTransferController = transfer_controller_module["GlobusTransferController"] + MockSecretClass = MockSecret + + mocker.patch('prefect.blocks.system.Secret.load', return_value=MockSecretClass()) + + with patch("orchestration.transfer_controller.start_transfer", return_value=False) as mock_start_transfer: + controller = GlobusTransferController(mock_config832) + result = controller.copy( + file_path="some_dir/test_file.txt", + source=mock_globus_endpoint, + destination=mock_globus_endpoint, + ) + assert result is False, "Expected False when transfer fails." + mock_start_transfer.assert_called_once() + + +def test_globus_transfer_controller_copy_exception( + mock_config832, mock_globus_endpoint, transfer_controller_module +): + """ + Test copy() operation that raises a TransferAPIError exception. + """ + GlobusTransferController = transfer_controller_module["GlobusTransferController"] + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.reason = "Bad Request" + + with patch( + "orchestration.transfer_controller.start_transfer", + side_effect=globus_sdk.services.transfer.errors.TransferAPIError(mock_response, "Mocked Error") + ) as mock_start_transfer: + controller = GlobusTransferController(mock_config832) + result = controller.copy( + file_path="some_dir/test_file.txt", + source=mock_globus_endpoint, + destination=mock_globus_endpoint, + ) + assert result is False, "Expected False when TransferAPIError is raised." 
+ mock_start_transfer.assert_called_once() + + +# -------------------------------------------------------------------------- +# Tests for SimpleTransferController +# -------------------------------------------------------------------------- + +def test_simple_transfer_controller_no_file_path( + mock_config832, mock_file_system_endpoint, transfer_controller_module +): + SimpleTransferController = transfer_controller_module["SimpleTransferController"] + controller = SimpleTransferController(mock_config832) + result = controller.copy( + file_path="", + source=mock_file_system_endpoint, + destination=mock_file_system_endpoint, + ) + assert result is False, "Expected False when no file_path is provided." + + +def test_simple_transfer_controller_no_source_or_destination(mock_config832, transfer_controller_module): + SimpleTransferController = transfer_controller_module["SimpleTransferController"] + controller = SimpleTransferController(mock_config832) + result = controller.copy( + file_path="test.txt", + source=None, + destination=None, + ) + assert result is False, "Expected False when either source or destination is None." + + +def test_simple_transfer_controller_copy_success( + mock_config832, mock_file_system_endpoint, transfer_controller_module +): + SimpleTransferController = transfer_controller_module["SimpleTransferController"] + with patch("os.system", return_value=0) as mock_os_system: + controller = SimpleTransferController(mock_config832) + result = controller.copy( + file_path="some_dir/test_file.txt", + source=mock_file_system_endpoint, + destination=mock_file_system_endpoint, + ) + + assert result is True, "Expected True when os.system returns 0." + mock_os_system.assert_called_once() + command_called = mock_os_system.call_args[0][0] + assert "cp -r" in command_called, "Expected cp command in os.system call." 
+ + +def test_simple_transfer_controller_copy_failure( + mock_config832, mock_file_system_endpoint, transfer_controller_module +): + SimpleTransferController = transfer_controller_module["SimpleTransferController"] + with patch("os.system", return_value=1) as mock_os_system: + controller = SimpleTransferController(mock_config832) + result = controller.copy( + file_path="some_dir/test_file.txt", + source=mock_file_system_endpoint, + destination=mock_file_system_endpoint, + ) + + assert result is False, "Expected False when os.system returns non-zero." + mock_os_system.assert_called_once() + command_called = mock_os_system.call_args[0][0] + assert "cp -r" in command_called, "Expected cp command in os.system call." + + +def test_simple_transfer_controller_copy_exception( + mock_config832, mock_file_system_endpoint, transfer_controller_module +): + SimpleTransferController = transfer_controller_module["SimpleTransferController"] + with patch("os.system", side_effect=Exception("Mocked cp error")) as mock_os_system: + controller = SimpleTransferController(mock_config832) + result = controller.copy( + file_path="some_dir/test_file.txt", + source=mock_file_system_endpoint, + destination=mock_file_system_endpoint, + ) + + assert result is False, "Expected False when an exception is raised during copy." 
+ mock_os_system.assert_called_once() diff --git a/orchestration/flows/bl832/alcf.py b/orchestration/flows/bl832/alcf.py index 3836b8ef..73490015 100644 --- a/orchestration/flows/bl832/alcf.py +++ b/orchestration/flows/bl832/alcf.py @@ -3,6 +3,7 @@ import os from pathlib import Path import time +# from typing import Optional from globus_compute_sdk import Client, Executor from globus_compute_sdk.serialize import CombinedCode @@ -12,10 +13,44 @@ from prefect.blocks.system import JSON, Secret from orchestration.flows.bl832.config import Config832 +from orchestration.flows.bl832.job_controller import TomographyHPCController from orchestration.globus.transfer import GlobusEndpoint, start_transfer from orchestration.prefect import schedule_prefect_flow +class ALCFTomographyHPCController(TomographyHPCController): + """ + Implementation of TomographyHPCController for ALCF. + Methods here leverage Globus Compute for processing tasks. + + TODO: Refactor ALCF reconstruction flow into this class. + + Args: + TomographyHPCController (ABC): Abstract class for tomography HPC controllers. + """ + + def __init__(self) -> None: + pass + + def reconstruct( + self, + file_path: str = "", + ) -> bool: + + # uses Globus Compute to reconstruct the tomography + # TODO: Refactor ALCF reconstruction code into this class. + + pass + + def build_multi_resolution( + self, + file_path: str = "", + ) -> bool: + # uses Globus Compute to build multi-resolution tomography + # TODO: Refactor ALCF multi-res zarr code into this class. 
+ pass + + @task(name="transfer_data_to_alcf") def transfer_data_to_alcf( file_path: str, diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 57de5f46..ff19a9c3 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -14,10 +14,13 @@ def __init__(self) -> None: self.data832_raw = self.endpoints["data832_raw"] self.data832_scratch = self.endpoints["data832_scratch"] self.nersc832 = self.endpoints["nersc832"] - self.nersc_test = self.endpoints["nersc_test"] self.nersc_alsdev = self.endpoints["nersc_alsdev"] self.nersc832_alsdev_raw = self.endpoints["nersc832_alsdev_raw"] self.nersc832_alsdev_scratch = self.endpoints["nersc832_alsdev_scratch"] + self.nersc832_alsdev_pscratch_raw = self.endpoints["nersc832_alsdev_pscratch_raw"] + self.nersc832_alsdev_pscratch_scratch = self.endpoints["nersc832_alsdev_pscratch_scratch"] + self.nersc832_alsdev_recon_scripts = self.endpoints["nersc832_alsdev_recon_scripts"] self.alcf832_raw = self.endpoints["alcf832_raw"] self.alcf832_scratch = self.endpoints["alcf832_scratch"] self.scicat = config["scicat"] + self.ghcr_images832 = config["ghcr_images832"] diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py new file mode 100644 index 00000000..53af1145 --- /dev/null +++ b/orchestration/flows/bl832/job_controller.py @@ -0,0 +1,116 @@ +from abc import ABC, abstractmethod +from dotenv import load_dotenv +from enum import Enum +import logging + +from orchestration.flows.bl832.config import Config832 + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +load_dotenv() + + +class TomographyHPCController(ABC): + """ + Abstract class for tomography HPC controllers. + Provides interface methods for reconstruction and building multi-resolution datasets. 
+ + Args: + ABC: Abstract Base Class + """ + def __init__( + self, + config: Config832 + ) -> None: + self.config = config + + @abstractmethod + def reconstruct( + self, + file_path: str = "", + ) -> bool: + """Perform tomography reconstruction + + :param file_path: Path to the file to reconstruct. + :return: True if successful, False otherwise. + """ + pass + + @abstractmethod + def build_multi_resolution( + self, + file_path: str = "", + ) -> bool: + """Generate multi-resolution version of reconstructed tomography + + :param file_path: Path to the file for which to build multi-resolution data. + :return: True if successful, False otherwise. + """ + pass + + +class HPC(Enum): + """ + Enum representing different HPC environments. + Use enum names as strings to identify HPC sites, ensuring a standard set of values. + + Members: + ALCF: Argonne Leadership Computing Facility + NERSC: National Energy Research Scientific Computing Center + """ + ALCF = "ALCF" + NERSC = "NERSC" + OLCF = "OLCF" + + +def get_controller( + hpc_type: HPC, + config: Config832 +) -> TomographyHPCController: + """ + Factory function that returns an HPC controller instance for the given HPC environment. + + :param hpc_type: A string identifying the HPC environment (e.g., 'ALCF', 'NERSC'). + :return: An instance of a TomographyHPCController subclass corresponding to the given HPC environment. + :raises ValueError: If an invalid or unsupported HPC type is specified. 
+ """ + if not isinstance(hpc_type, HPC): + raise ValueError(f"Invalid HPC type provided: {hpc_type}") + + if not config: + raise ValueError("Config object is required.") + + if hpc_type == HPC.ALCF: + from orchestration.flows.bl832.alcf import ALCFTomographyHPCController + return ALCFTomographyHPCController() + elif hpc_type == HPC.NERSC: + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + return NERSCTomographyHPCController( + client=NERSCTomographyHPCController.create_sfapi_client(), + config=config + ) + elif hpc_type == HPC.OLCF: + # TODO: Implement OLCF controller + pass + else: + raise ValueError(f"Unsupported HPC type: {hpc_type}") + + +def do_it_all() -> None: + controller = get_controller("ALCF") + controller.reconstruct() + controller.build_multi_resolution() + + file_path = "" + controller = get_controller("NERSC") + controller.reconstruct( + file_path=file_path, + ) + controller.build_multi_resolution( + file_path=file_path, + ) + + +if __name__ == "__main__": + do_it_all() + logger.info("Done.") diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py new file mode 100644 index 00000000..95e8aa31 --- /dev/null +++ b/orchestration/flows/bl832/nersc.py @@ -0,0 +1,495 @@ +import datetime +from dotenv import load_dotenv +import json +import logging +import os +from pathlib import Path +import re +import time + +from authlib.jose import JsonWebKey +from prefect import flow +from prefect.blocks.system import JSON +from sfapi_client import Client +from sfapi_client.compute import Machine + +from orchestration.flows.bl832.config import Config832 +from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController +from orchestration.transfer_controller import get_transfer_controller, CopyMethod +from orchestration.prefect import schedule_prefect_flow + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +load_dotenv() + + +class 
class NERSCTomographyHPCController(TomographyHPCController):
    """
    Implementation for a NERSC-based tomography HPC controller.

    Submits reconstruction and multi-resolution jobs to NERSC's Perlmutter
    machine via the SFAPI (Slurm batch jobs running podman-hpc containers).
    """

    def __init__(
        self,
        client: Client,
        config: Config832
    ) -> None:
        super().__init__(config)
        self.client = client

    @staticmethod
    def create_sfapi_client() -> Client:
        """Create and return an SFAPI client for NERSC.

        Credentials are read from the files named by the environment
        variables PATH_NERSC_CLIENT_ID and PATH_NERSC_PRI_KEY.

        Raises:
            ValueError: if either environment variable is unset.
            FileNotFoundError: if either credential file does not exist.
        """
        # When generating the SFAPI Key in Iris, make sure to select "asldev" as the user!
        # Otherwise, the key will not have the necessary permissions to access the data.
        client_id_path = os.getenv("PATH_NERSC_CLIENT_ID")
        client_secret_path = os.getenv("PATH_NERSC_PRI_KEY")

        if not client_id_path or not client_secret_path:
            logger.error("NERSC credentials paths are missing.")
            raise ValueError("Missing NERSC credentials paths.")
        if not os.path.isfile(client_id_path) or not os.path.isfile(client_secret_path):
            logger.error("NERSC credential files are missing.")
            raise FileNotFoundError("NERSC credential files are missing.")

        with open(client_id_path, "r") as f:
            client_id = f.read()
        with open(client_secret_path, "r") as f:
            client_secret = JsonWebKey.import_key(json.loads(f.read()))

        try:
            client = Client(client_id, client_secret)
            logger.info("NERSC client created successfully.")
            return client
        except Exception as e:
            logger.error(f"Failed to create NERSC client: {e}")
            raise

    def _submit_and_wait(self, job_script: str, label: str) -> bool:
        """Submit *job_script* to Perlmutter and block until it completes.

        Factors out the submit/poll/recover sequence previously duplicated in
        reconstruct() and build_multi_resolution(). Handles the SFAPI race in
        which a freshly submitted job transiently reports "Job not found" by
        re-attaching to the job id and waiting again.

        Args:
            job_script: Complete Slurm batch script text.
            label: Short human-readable job description for log messages.

        Returns:
            True if the job completed, False otherwise.
        """
        try:
            logger.info(f"Submitting {label} job script to Perlmutter.")
            perlmutter = self.client.compute(Machine.perlmutter)
            job = perlmutter.submit_job(job_script)
            logger.info(f"Submitted job ID: {job.jobid}")

            try:
                job.update()
            except Exception as update_err:
                logger.warning(f"Initial job update failed, continuing: {update_err}")

            # Give Slurm time to register the job before polling its state.
            time.sleep(60)
            logger.info(f"Job {job.jobid} current state: {job.state}")

            job.complete()  # Wait until the job completes
            logger.info(f"{label} job completed successfully.")
            return True

        except Exception as e:
            # Was logger.info in the original reconstruct(); an error path
            # should be logged at warning level.
            logger.warning(f"Error during job submission or completion: {e}")
            match = re.search(r"Job not found:\s*(\d+)", str(e))
            if not match:
                # Unknown error: cannot recover
                return False

            jobid = match.group(1)
            logger.info(f"Attempting to recover job {jobid}.")
            try:
                job = self.client.perlmutter.job(jobid=jobid)
                time.sleep(30)
                job.complete()
                logger.info(f"{label} job completed successfully after recovery.")
                return True
            except Exception as recovery_err:
                logger.error(f"Failed to recover job {jobid}: {recovery_err}")
                return False

    def reconstruct(
        self,
        file_path: str = "",
    ) -> bool:
        """Use NERSC for tomography reconstruction.

        Args:
            file_path: Dataset path relative to the raw data root, expected
                as "<folder>/<file>.h5".

        Returns:
            True if the reconstruction job completed, False otherwise.
        """
        logger.info("Starting NERSC reconstruction process.")

        user = self.client.user()

        raw_path = self.config.nersc832_alsdev_raw.root_path
        logger.info(f"{raw_path=}")

        recon_image = self.config.ghcr_images832["recon_image"]
        logger.info(f"{recon_image=}")

        recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path
        logger.info(f"{recon_scripts_dir=}")

        scratch_path = self.config.nersc832_alsdev_scratch.root_path
        logger.info(f"{scratch_path=}")

        pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}"
        logger.info(f"{pscratch_path=}")

        path = Path(file_path)
        # Path.parent.name is already "" for a bare file name, so no extra
        # fallback is needed.
        folder_name = path.parent.name
        file_name = f"{path.stem}.h5"

        logger.info(f"File name: {file_name}")
        logger.info(f"Folder name: {folder_name}")

        # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately
        # Note: If q=debug, there is no minimum time limit
        # However, if q=preempt, there is a minimum time limit of 2 hours. Otherwise the job won't run.
        # The realtime queue can only be used for select accounts (e.g. ALS)
        job_script = f"""#!/bin/bash
#SBATCH -q realtime
#SBATCH -A als
#SBATCH -C cpu
#SBATCH --job-name=tomo_recon_{folder_name}_{file_name}
#SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out
#SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err
#SBATCH -N 1
#SBATCH --ntasks-per-node 1
#SBATCH --cpus-per-task 64
#SBATCH --time=0:15:00
#SBATCH --exclusive

date
echo "Creating directory {pscratch_path}/8.3.2/raw/{folder_name}"
mkdir -p {pscratch_path}/8.3.2/raw/{folder_name}
mkdir -p {pscratch_path}/8.3.2/scratch/{folder_name}

echo "Copying file {raw_path}/{folder_name}/{file_name} to {pscratch_path}/8.3.2/raw/{folder_name}/"
cp {raw_path}/{folder_name}/{file_name} {pscratch_path}/8.3.2/raw/{folder_name}
if [ $? -ne 0 ]; then
    echo "Failed to copy data to pscratch."
    exit 1
fi

chmod -R 2775 {pscratch_path}/8.3.2

echo "Verifying copied files..."
ls -l {pscratch_path}/8.3.2/raw/{folder_name}/

echo "Running reconstruction container..."
srun podman-hpc run \
--volume {recon_scripts_dir}/sfapi_reconstruction.py:/alsuser/sfapi_reconstruction.py \
--volume {pscratch_path}/8.3.2:/alsdata \
--volume {pscratch_path}/8.3.2:/alsuser/ \
{recon_image} \
bash -c "python sfapi_reconstruction.py {file_name} {folder_name}"
date
"""
        return self._submit_and_wait(job_script, "reconstruction")

    def build_multi_resolution(
        self,
        file_path: str = "",
    ) -> bool:
        """Use NERSC to make a multiresolution (Zarr) version of tomography results.

        Args:
            file_path: Dataset path relative to the raw data root, expected
                as "<folder>/<file>.h5".

        Returns:
            True if the multi-resolution job completed, False otherwise.
        """
        logger.info("Starting NERSC multiresolution process.")

        user = self.client.user()

        multires_image = self.config.ghcr_images832["multires_image"]
        logger.info(f"{multires_image=}")

        recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path
        logger.info(f"{recon_scripts_dir=}")

        scratch_path = self.config.nersc832_alsdev_scratch.root_path
        logger.info(f"{scratch_path=}")

        pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}"
        logger.info(f"{pscratch_path=}")

        path = Path(file_path)
        folder_name = path.parent.name
        file_name = path.stem

        # Paths are relative to the container working directory bind mount.
        recon_path = f"scratch/{folder_name}/rec{file_name}/"
        logger.info(f"{recon_path=}")

        raw_path = f"raw/{folder_name}/{file_name}.h5"
        logger.info(f"{raw_path=}")

        # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately
        job_script = f"""#!/bin/bash
#SBATCH -q realtime
#SBATCH -A als
#SBATCH -C cpu
#SBATCH --job-name=tomo_multires_{folder_name}_{file_name}
#SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out
#SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err
#SBATCH -N 1
#SBATCH --ntasks-per-node 1
#SBATCH --cpus-per-task 64
#SBATCH --time=0:15:00
#SBATCH --exclusive

date

echo "Running multires container..."
srun podman-hpc run \
--volume {recon_scripts_dir}/tiff_to_zarr.py:/alsuser/tiff_to_zarr.py \
--volume {pscratch_path}/8.3.2:/alsdata \
--volume {pscratch_path}/8.3.2:/alsuser/ \
{multires_image} \
bash -c "python tiff_to_zarr.py {recon_path} --raw_file {raw_path}"

date
"""
        return self._submit_and_wait(job_script, "multi-resolution")
def schedule_pruning(
    config: Config832,
    raw_file_path: str,
    tiff_file_path: str,
    zarr_file_path: str
) -> None:
    """Schedule delayed deletion flows for intermediate reconstruction data.

    Retention policy (delays come from the "pruning-config" Prefect JSON block):
      - data832/scratch : delete_data832_files_after_days (~14 days)
      - nersc/pscratch  : delete_nersc832_files_after_days (~1 day)
      - nersc832 CFS scratch : never pruned here

    Scheduling failures are logged and do not abort the remaining schedules.

    Note: the original declared ``-> bool`` but never returned a value; the
    annotation now matches the actual behavior.
    """
    pruning_config = JSON.load("pruning-config").value
    data832_delay = datetime.timedelta(days=pruning_config["delete_data832_files_after_days"])
    nersc832_delay = datetime.timedelta(days=pruning_config["delete_nersc832_files_after_days"])

    # data832_delay, nersc832_delay = datetime.timedelta(minutes=1), datetime.timedelta(minutes=1)

    def _schedule(location: str, relative_path: str, source_endpoint, check_endpoint, delay) -> None:
        """Schedule one prune_<location> deployment run; log-and-continue on failure."""
        logger.info(f"Scheduling delete from {location}: {relative_path}")
        try:
            schedule_prefect_flow(
                deployment_name=f"prune_{location}/prune_{location}",
                flow_run_name=f"delete {location}: {Path(relative_path).name}",
                parameters={
                    "relative_path": relative_path,
                    "source_endpoint": source_endpoint,
                    "check_endpoint": check_endpoint,
                },
                duration_from_now=delay,
            )
        except Exception as e:
            logger.error(f"Failed to schedule prune task: {e}")

    # data832 copies are verified against the NERSC CFS scratch copy before deletion.
    _schedule("data832_scratch", tiff_file_path,
              config.data832_scratch, config.nersc832_alsdev_scratch, data832_delay)
    _schedule("data832_scratch", zarr_file_path,
              config.data832_scratch, config.nersc832_alsdev_scratch, data832_delay)

    # pscratch copies are deleted without a check endpoint.
    _schedule("nersc832_alsdev_pscratch_raw", raw_file_path,
              config.nersc832_alsdev_pscratch_raw, None, nersc832_delay)
    _schedule("nersc832_alsdev_pscratch_scratch", tiff_file_path,
              config.nersc832_alsdev_pscratch_scratch, None, nersc832_delay)
    _schedule("nersc832_alsdev_pscratch_scratch", zarr_file_path,
              config.nersc832_alsdev_pscratch_scratch, None, nersc832_delay)
@flow(name="nersc_recon_flow")
def nersc_recon_flow(
    file_path: str,
    config: Config832,
) -> bool:
    """
    Perform tomography reconstruction on NERSC, copy results toward data832,
    and schedule pruning of the intermediate copies.

    :param file_path: Path of the raw .h5 file, relative to the raw data root.
    :param config: Beamline 8.3.2 configuration (endpoints, images, clients).
    :return: True when both the reconstruction and multi-resolution jobs
        succeeded, False otherwise.
    """
    logger.info(f"Starting NERSC reconstruction flow for {file_path=}")
    controller = get_controller(
        hpc_type=HPC.NERSC,
        config=config
    )
    nersc_reconstruction_success = controller.reconstruct(
        file_path=file_path,
    )
    nersc_multi_res_success = controller.build_multi_resolution(
        file_path=file_path,
    )

    path = Path(file_path)
    folder_name = path.parent.name
    file_name = path.stem

    tiff_file_path = f"{folder_name}/rec{file_name}"
    zarr_file_path = f"{folder_name}/rec{file_name}.zarr"

    logger.info(f"{tiff_file_path=}")
    logger.info(f"{zarr_file_path=}")

    # Transfer reconstructed data
    logger.info("Preparing transfer.")
    transfer_controller = get_transfer_controller(
        transfer_type=CopyMethod.GLOBUS,
        config=config
    )

    logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.")
    transfer_controller.copy(
        file_path=tiff_file_path,
        source=config.nersc832_alsdev_pscratch_scratch,
        destination=config.nersc832_alsdev_scratch
    )

    transfer_controller.copy(
        file_path=zarr_file_path,
        source=config.nersc832_alsdev_pscratch_scratch,
        destination=config.nersc832_alsdev_scratch
    )

    logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832")
    # NOTE(review): the log line above says the source is the CFS scratch
    # copy, but the source endpoint below is still pscratch. Confirm which
    # is intended and align the two (pscratch is purged aggressively).
    transfer_controller.copy(
        file_path=tiff_file_path,
        source=config.nersc832_alsdev_pscratch_scratch,
        destination=config.data832_scratch
    )

    transfer_controller.copy(
        file_path=zarr_file_path,
        source=config.nersc832_alsdev_pscratch_scratch,
        destination=config.data832_scratch
    )

    logger.info("Scheduling pruning tasks.")
    schedule_pruning(
        config=config,
        raw_file_path=file_path,
        tiff_file_path=tiff_file_path,
        zarr_file_path=zarr_file_path
    )

    # TODO: Ingest into SciCat
    return nersc_reconstruction_success and nersc_multi_res_success


if __name__ == "__main__":
    nersc_recon_flow(
        file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5",
        config=Config832()
    )
+ """ + + def __init__(self) -> None: + pass + + def reconstruct( + self, + file_path: str = "", + ) -> bool: + # TODO: Implement tomography reconstruction at OLCF + # https://docs.olcf.ornl.gov/ace_testbed/defiant_quick_start_guide.html#running-jobs + pass + + def build_multi_resolution( + self, + file_path: str = "", + ) -> bool: + # TODO: Implement building multi-resolution datasets at OLCF + pass diff --git a/orchestration/flows/bl832/prune.py b/orchestration/flows/bl832/prune.py index 44b87be3..1de05085 100644 --- a/orchestration/flows/bl832/prune.py +++ b/orchestration/flows/bl832/prune.py @@ -143,5 +143,33 @@ def prune_nersc832_alsdev_scratch( config=config) +@flow(name="prune_nersc832_alsdev_pscratch_raw") +def prune_nersc832_alsdev_pscratch_raw( + relative_path: str, + source_endpoint: GlobusEndpoint, + check_endpoint: Union[GlobusEndpoint, None] = None, + config=None, +): + prune_files( + relative_path=relative_path, + source_endpoint=source_endpoint, + check_endpoint=check_endpoint, + config=config) + + +@flow(name="prune_nersc832_alsdev_pscratch_scratch") +def prune_nersc832_alsdev_pscratch_scratch( + relative_path: str, + source_endpoint: GlobusEndpoint, + check_endpoint: Union[GlobusEndpoint, None] = None, + config=None, +): + prune_files( + relative_path=relative_path, + source_endpoint=source_endpoint, + check_endpoint=check_endpoint, + config=config) + + if __name__ == "__main__": prune_nersc832_alsdev_scratch("BLS-00564_dyparkinson/") diff --git a/orchestration/globus/transfer.py b/orchestration/globus/transfer.py index 812d83da..f6947065 100644 --- a/orchestration/globus/transfer.py +++ b/orchestration/globus/transfer.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from datetime import datetime, timezone, timedelta from dateutil import parser +import json import logging import os from pathlib import Path @@ -43,6 +44,13 @@ def full_path(self, path_suffix: str): path = Path(self.root_path) / path_suffix return str(path) + def to_dict(self) -> 
def deprecated_method(message: str = None):
    """Decorator factory that emits a DeprecationWarning on each call of the
    wrapped method.

    Args:
        message: Optional custom warning text. Defaults to the standard
            NerscClient deprecation notice.

    Bug fixed: the original declared ``message`` as a REQUIRED parameter but
    every call site uses ``@deprecated_method()`` with no arguments, which
    raised TypeError at class-definition (import) time; the parameter was
    also silently ignored by the wrapper. It is now optional and honored.
    """
    default_message = (
        "NerscClient() is deprecated and will be removed in a future version. Use the official NERSC "
        "sfapi_client module instead: https://nersc.github.io/sfapi_client/"
    )

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                message=message if message is not None else default_message,
                category=DeprecationWarning,
                stacklevel=2
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator
class TransferEndpoint(ABC):
    """
    Abstract base class for transfer endpoints.

    Attributes:
        name: A human-readable or reference name for the endpoint.
        root_path: Root path or base directory for this endpoint.
    """

    def __init__(
        self,
        name: str,
        root_path: str
    ) -> None:
        # Plain attributes. The original also defined methods name() and
        # root_path(), but these instance attributes shadowed them, making
        # the methods uncallable dead code ("str" object is not callable);
        # they have been removed. Callers use attribute access (e.g.
        # `source.name`), which is unchanged.
        self.name = name
        self.root_path = root_path


class FileSystemEndpoint(TransferEndpoint):
    """
    A local file-system endpoint addressed by a root directory path.
    """

    def __init__(
        self,
        name: str,
        root_path: str
    ) -> None:
        super().__init__(name, root_path)

    def full_path(
        self,
        path_suffix: str
    ) -> str:
        """
        Construct the full path by appending *path_suffix* to root_path.

        Args:
            path_suffix (str): The relative path to append. A single leading
                "/" is stripped so the suffix is always treated as relative.

        Returns:
            str: The full absolute path.
        """
        if path_suffix.startswith("/"):
            path_suffix = path_suffix[1:]
        return f"{self.root_path.rstrip('/')}/{path_suffix}"
class GlobusTransferController(TransferController[GlobusEndpoint]):
    """
    Move data between Globus endpoints using the Globus Transfer service.
    """

    def __init__(
        self,
        config: Config832
    ) -> None:
        super().__init__(config)

    def copy(
        self,
        file_path: str = None,
        source: GlobusEndpoint = None,
        destination: GlobusEndpoint = None,
    ) -> bool:
        """
        Copy a file from a source endpoint to a destination endpoint.

        Args:
            file_path (str): Path of the file to copy, relative to each
                endpoint's root_path.
            source (GlobusEndpoint): The source endpoint.
            destination (GlobusEndpoint): The destination endpoint.

        Returns:
            bool: True if the transfer was successful, False otherwise.
        """
        # Guard: the original indexed file_path[0] and dereferenced the
        # endpoints without checking, crashing on empty/None arguments.
        if not file_path:
            logger.error("No file_path provided.")
            return False
        if not source or not destination:
            logger.error("Source or destination endpoint not provided.")
            return False

        logger.info(f"Transferring {file_path} from {source.name} to {destination.name}")

        if file_path[0] == "/":
            file_path = file_path[1:]

        source_path = os.path.join(source.root_path, file_path)
        dest_path = os.path.join(destination.root_path, file_path)
        logger.info(f"Transferring {source_path} to {dest_path}")

        # Start the timer
        start_time = time.time()
        try:
            success = start_transfer(
                transfer_client=self.config.tc,
                source_endpoint=source,
                source_path=source_path,
                dest_endpoint=destination,
                dest_path=dest_path,
                max_wait_seconds=600,
                logger=logger,
            )
            if success:
                logger.info("Transfer completed successfully.")
            else:
                logger.error("Transfer failed.")
            return success
        except globus_sdk.services.transfer.errors.TransferAPIError as e:
            logger.error(f"Failed to submit transfer: {e}")
            return False
        finally:
            # Bug fixed: the original had `return success` here. A return
            # inside `finally` swallows in-flight exceptions and overrides
            # the try/except return values; the finally block now only logs.
            elapsed_time = time.time() - start_time
            logger.info(f"Transfer process took {elapsed_time:.2f} seconds.")


class SimpleTransferController(TransferController[FileSystemEndpoint]):
    """
    Copy data within the same file system using the standard library.
    """

    def __init__(self, config: Config832) -> None:
        super().__init__(config)

    def copy(
        self,
        file_path: str = "",
        source: FileSystemEndpoint = None,
        destination: FileSystemEndpoint = None,
    ) -> bool:
        """
        Copy a file or directory from a source endpoint to a destination
        endpoint.

        Args:
            file_path (str): The path of the file to copy, relative to each
                endpoint's root_path.
            source (FileSystemEndpoint): The source endpoint.
            destination (FileSystemEndpoint): The destination endpoint.

        Returns:
            bool: True if the transfer was successful, False otherwise.
        """
        # Local stdlib import keeps the module's import block unchanged.
        import shutil

        if not file_path:
            logger.error("No file_path provided.")
            return False
        if not source or not destination:
            logger.error("Source or destination endpoint not provided.")
            return False

        logger.info(f"Transferring {file_path} from {source.name} to {destination.name}")

        if file_path.startswith("/"):
            file_path = file_path[1:]

        source_path = os.path.join(source.root_path, file_path)
        dest_path = os.path.join(destination.root_path, file_path)
        logger.info(f"Transferring {source_path} to {dest_path}")

        # Start the timer
        start_time = time.time()
        try:
            # shutil instead of os.system("cp -r '...' '...'"): portable,
            # no shell quoting/injection issues, and failures raise with a
            # real error message instead of a bare exit code.
            if os.path.isdir(source_path):
                shutil.copytree(source_path, dest_path, dirs_exist_ok=True)
            else:
                parent = os.path.dirname(dest_path)
                if parent:
                    os.makedirs(parent, exist_ok=True)
                shutil.copy2(source_path, dest_path)
            logger.info("Transfer completed successfully.")
            return True
        except OSError as e:
            logger.error(f"Transfer failed: {e}")
            return False
        finally:
            # Stop the timer and calculate the duration
            elapsed_time = time.time() - start_time
            logger.info(f"Transfer process took {elapsed_time:.2f} seconds.")
+ """ + if transfer_type == CopyMethod.GLOBUS: + return GlobusTransferController(config) + elif transfer_type == CopyMethod.SIMPLE: + return SimpleTransferController(config) + else: + raise ValueError(f"Invalid transfer type: {transfer_type}") + + +if __name__ == "__main__": + config = Config832() + transfer_type = CopyMethod.GLOBUS + globus_transfer_controller = get_transfer_controller(transfer_type, config) + globus_transfer_controller.copy( + file_path="dabramov/test.txt", + source=config.alcf832_raw, + destination=config.alcf832_scratch + ) + + simple_transfer_controller = get_transfer_controller(CopyMethod.SIMPLE, config) + success = simple_transfer_controller.copy( + file_path="test.rtf", + source=FileSystemEndpoint("source", "/Users/david/Documents/copy_test/test_source/"), + destination=FileSystemEndpoint("destination", "/Users/david/Documents/copy_test/test_destination/") + ) + + if success: + logger.info("Simple transfer succeeded.") + else: + logger.error("Simple transfer failed.") diff --git a/requirements.txt b/requirements.txt index 88c2dd61..7252202b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ globus-sdk>=3.0 h5py httpx>=0.22.0 -numpy +numpy==1.23.2 pillow python-dotenv prefect==2.19.5 diff --git a/scripts/cancel_sfapi_job.py b/scripts/cancel_sfapi_job.py new file mode 100644 index 00000000..53dec051 --- /dev/null +++ b/scripts/cancel_sfapi_job.py @@ -0,0 +1,38 @@ +from dotenv import load_dotenv +import json +import logging +import os + +from authlib.jose import JsonWebKey +from sfapi_client import Client +from sfapi_client.compute import Machine + + +load_dotenv() +logger = logging.getLogger(__name__) + +client_id_path = os.getenv("PATH_NERSC_CLIENT_ID") +client_secret_path = os.getenv("PATH_NERSC_PRI_KEY") + +if not client_id_path or not client_secret_path: + logger.error("NERSC credentials paths are missing.") + raise ValueError("Missing NERSC credentials paths.") +if not os.path.isfile(client_id_path) or not 
os.path.isfile(client_secret_path): + logger.error("NERSC credential files are missing.") + raise FileNotFoundError("NERSC credential files are missing.") + +client_id = None +client_secret = None +with open(client_id_path, "r") as f: + client_id = f.read() + +with open(client_secret_path, "r") as f: + client_secret = JsonWebKey.import_key(json.loads(f.read())) + +with Client(client_id, client_secret) as client: + perlmutter = client.compute(Machine.perlmutter) + # job = perlmutter.submit_job(job_path) + jobs = perlmutter.jobs(user="dabramov") + for job in jobs: + logger.info(f"Cancelling job: {job.jobid}") + job.cancel()