Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/source/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,33 @@ resource can then be used in the following manner:

test_app("gpu_x2")

Alternatively, you can define custom named resources in a Python module and point
to it using the ``TORCHX_CUSTOM_NAMED_RESOURCES`` environment variable:

.. code-block:: python

# my_resources.py
from torchx.specs import Resource

def gpu_x8_efa() -> Resource:
return Resource(cpu=100, gpu=8, memMB=819200, devices={"vpc.amazonaws.com/efa": 1})

def cpu_x32() -> Resource:
return Resource(cpu=32, gpu=0, memMB=131072)

NAMED_RESOURCES = {
"gpu_x8_efa": gpu_x8_efa,
"cpu_x32": cpu_x32,
}

Then set the environment variable:

.. code-block:: bash

export TORCHX_CUSTOM_NAMED_RESOURCES=my_resources

This allows you to use your custom resources without creating a package with entry points.


Registering Custom Components
-------------------------------
Expand Down
10 changes: 7 additions & 3 deletions torchx/specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
scheduler or pipeline adapter.
"""
import difflib

import os
from typing import Callable, Dict, Mapping, Optional

from torchx.specs.api import (
Expand Down Expand Up @@ -63,8 +65,10 @@
GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
"torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={}
)
FB_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
"torchx.specs.fb.named_resources", "NAMED_RESOURCES", default={}
CUSTOM_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
os.environ.get("TORCHX_CUSTOM_NAMED_RESOURCES", "torchx.specs.fb.named_resources"),
"NAMED_RESOURCES",
default={},
)


Expand All @@ -75,7 +79,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
for name, resource in {
**GENERIC_NAMED_RESOURCES,
**AWS_NAMED_RESOURCES,
**FB_NAMED_RESOURCES,
**CUSTOM_NAMED_RESOURCES,
**resource_methods,
}.items():
materialized_resources[name] = resource
Expand Down
17 changes: 17 additions & 0 deletions torchx/specs/test/named_resources_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# pyre-strict


import os
import unittest
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -47,3 +48,19 @@ def test_named_resources_library(self, mock_named_resources: MagicMock) -> None:
def test_null_and_missing_named_resources(self) -> None:
self.assertEqual(named_resources["NULL"], NULL_RESOURCE)
self.assertEqual(named_resources["MISSING"], NULL_RESOURCE)

def test_custom_named_resources_env_var(self) -> None:
import sys

mock_module = type(sys)("test_module")
mock_module.NAMED_RESOURCES = {"test_resource": mock_resource}

with patch.dict(sys.modules, {"test_module": mock_module}):
with patch(
"torchx.specs.CUSTOM_NAMED_RESOURCES", mock_module.NAMED_RESOURCES
):
import torchx.specs

factories = torchx.specs._load_named_resources()

self.assertIn("test_resource", factories)