From 60f2bc237695eccdbbc21d1630fa137c40ea1d05 Mon Sep 17 00:00:00 2001 From: Alexander Zhipa Date: Sun, 29 Jun 2025 22:32:20 -0400 Subject: [PATCH] feat: add custom named_resources support (#1085) --- docs/source/advanced.rst | 27 +++++++++++++++++++++++ torchx/specs/__init__.py | 10 ++++++--- torchx/specs/test/named_resources_test.py | 17 ++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 61615ee18..e803f460f 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -153,6 +153,33 @@ resource can then be used in the following manner: test_app("gpu_x2") +Alternatively, you can define custom named resources in a Python module and point +to it using the ``TORCHX_CUSTOM_NAMED_RESOURCES`` environment variable: + +.. code-block:: python + + # my_resources.py + from torchx.specs import Resource + + def gpu_x8_efa() -> Resource: + return Resource(cpu=100, gpu=8, memMB=819200, devices={"vpc.amazonaws.com/efa": 1}) + + def cpu_x32() -> Resource: + return Resource(cpu=32, gpu=0, memMB=131072) + + NAMED_RESOURCES = { + "gpu_x8_efa": gpu_x8_efa, + "cpu_x32": cpu_x32, + } + +Then set the environment variable: + +.. code-block:: bash + + export TORCHX_CUSTOM_NAMED_RESOURCES=my_resources + +This allows you to use your custom resources without creating a package with entry points. + Registering Custom Components ------------------------------- diff --git a/torchx/specs/__init__.py b/torchx/specs/__init__.py index c31eb9365..b7ecb207d 100644 --- a/torchx/specs/__init__.py +++ b/torchx/specs/__init__.py @@ -12,6 +12,8 @@ scheduler or pipeline adapter. """ import difflib + +import os from typing import Callable, Dict, Mapping, Optional from torchx.specs.api import ( @@ -63,8 +65,10 @@ GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr( "torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={} ) -FB_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr( - "torchx.specs.fb.named_resources", "NAMED_RESOURCES", default={} +CUSTOM_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr( + os.environ.get("TORCHX_CUSTOM_NAMED_RESOURCES", "torchx.specs.fb.named_resources"), + "NAMED_RESOURCES", + default={}, ) @@ -75,7 +79,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]: for name, resource in { **GENERIC_NAMED_RESOURCES, **AWS_NAMED_RESOURCES, - **FB_NAMED_RESOURCES, + **CUSTOM_NAMED_RESOURCES, **resource_methods, }.items(): materialized_resources[name] = resource diff --git a/torchx/specs/test/named_resources_test.py b/torchx/specs/test/named_resources_test.py index f03705b8d..2632b416e 100644 --- a/torchx/specs/test/named_resources_test.py +++ b/torchx/specs/test/named_resources_test.py @@ -8,6 +8,7 @@ # pyre-strict +import os import unittest from unittest.mock import MagicMock, patch @@ -47,3 +48,19 @@ def test_named_resources_library(self, mock_named_resources: MagicMock) -> None: def test_null_and_missing_named_resources(self) -> None: self.assertEqual(named_resources["NULL"], NULL_RESOURCE) self.assertEqual(named_resources["MISSING"], NULL_RESOURCE) + + def test_custom_named_resources_env_var(self) -> None: + import sys + + mock_module = type(sys)("test_module") + mock_module.NAMED_RESOURCES = {"test_resource": mock_resource} + + with patch.dict(sys.modules, {"test_module": mock_module}): + with patch( + "torchx.specs.CUSTOM_NAMED_RESOURCES", mock_module.NAMED_RESOURCES + ): + import torchx.specs + + factories = torchx.specs._load_named_resources() + + self.assertIn("test_resource", factories)