Skip to content

Commit 1577ee6

Browse files
authored
component_integration_tests: use compute_world instead of dummy_app (#710)
1 parent cc6b386 commit 1577ee6

File tree

4 files changed

+37
-21
lines changed

4 files changed

+37
-21
lines changed

dev-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ hydra-core
1313
ipython
1414
kfp==1.8.9
1515
mlflow-skinny
16-
moto==4.1.3
16+
moto==4.1.6
1717
pyre-extensions
1818
pyre-check
1919
pytest

scripts/component_integration_tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ def _mock_aws_batch() -> None:
166166
This sets up a mock AWS batch backend that uses Docker to execute the jobs
167167
locally.
168168
"""
169+
# set up the docker network so DNS works correctly
170+
from torchx.schedulers.docker_scheduler import ensure_network, NETWORK
171+
172+
ensure_network()
173+
os.environ.setdefault("MOTO_DOCKER_NETWORK_NAME", NETWORK)
174+
169175
from moto import mock_batch, mock_ec2, mock_ecs, mock_iam, mock_logs
170176

171177
mock_batch().__enter__()

torchx/components/integration_tests/component_provider.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,15 @@ def tearDown(self) -> None:
3737
class DDPComponentProvider(ComponentProvider):
3838
def get_app_def(self) -> AppDef:
3939
return dist_components.ddp(
40-
script="torchx/components/integration_tests/test/dummy_app.py",
40+
m="torchx.examples.apps.compute_world_size.main",
4141
name="ddp-trainer",
4242
image=self._image,
4343
cpu=1,
4444
j="2x2",
4545
max_retries=3,
46+
env={
47+
"LOGLEVEL": "INFO",
48+
},
4649
)
4750

4851

torchx/schedulers/docker_scheduler.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444

4545
if TYPE_CHECKING:
46+
from docker import DockerClient
4647
from docker.models.containers import Container
4748

4849
log: logging.Logger = logging.getLogger(__name__)
@@ -94,6 +95,30 @@ def has_docker() -> bool:
9495
return False
9596

9697

98+
def ensure_network(client: Optional["DockerClient"] = None) -> None:
99+
"""
100+
This creates the torchx docker network. Multi-process safe.
101+
"""
102+
import filelock
103+
from docker.errors import APIError
104+
105+
if client is None:
106+
import docker
107+
108+
client = docker.from_env()
109+
110+
lock_path = os.path.join(tempfile.gettempdir(), "torchx_docker_network_lock")
111+
112+
# Docker networks.create check_duplicate has a race condition so we need
113+
# to do client side locking to ensure only one network is created.
114+
with filelock.FileLock(lock_path, timeout=10):
115+
try:
116+
client.networks.create(name=NETWORK, driver="bridge", check_duplicate=True)
117+
except APIError as e:
118+
if "already exists" not in str(e):
119+
raise
120+
121+
97122
class DockerOpts(TypedDict, total=False):
98123
copy_env: Optional[List[str]]
99124

@@ -145,24 +170,6 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
145170
def __init__(self, session_name: str) -> None:
146171
super().__init__("docker", session_name)
147172

148-
def _ensure_network(self) -> None:
149-
import filelock
150-
from docker.errors import APIError
151-
152-
client = self._docker_client
153-
lock_path = os.path.join(tempfile.gettempdir(), "torchx_docker_network_lock")
154-
155-
# Docker networks.create check_duplicate has a race condition so we need
156-
# to do client side locking to ensure only one network is created.
157-
with filelock.FileLock(lock_path, timeout=10):
158-
try:
159-
client.networks.create(
160-
name=NETWORK, driver="bridge", check_duplicate=True
161-
)
162-
except APIError as e:
163-
if "already exists" not in str(e):
164-
raise
165-
166173
def schedule(self, dryrun_info: AppDryRunInfo[DockerJob]) -> str:
167174
client = self._docker_client
168175

@@ -180,7 +187,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[DockerJob]) -> str:
180187
except Exception as e:
181188
log.warning(f"failed to pull image {image}, falling back to local: {e}")
182189

183-
self._ensure_network()
190+
ensure_network(self._docker_client)
184191

185192
for container in req.containers:
186193
client.containers.run(

0 commit comments

Comments
 (0)