Skip to content

Commit cf73c8b

Browse files
aivanoufacebook-github-bot
authored andcommitted
Rename torch_dist_role script_args to args, script_envs to env (#64)
Summary: Pull Request resolved: #64 Rename torch_dist_role ``script_args`` to ``args``, ``script_envs`` to ``env`` Reviewed By: kiukchung Differential Revision: D29152204 fbshipit-source-id: a0a234d7b10c86dc073504c8633b692524df8cb8
1 parent 8513edf commit cf73c8b

File tree

6 files changed

+31
-31
lines changed

6 files changed

+31
-31
lines changed

torchx/cli/test/cmd_describe_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def get_test_app(self) -> AppDef:
2121
"elastic_trainer",
2222
image="trainer_fbpkg",
2323
entrypoint="trainer.par",
24-
script_args=["--arg1", "foo"],
24+
args=["--arg1", "foo"],
2525
resource=resource,
2626
num_replicas=2,
2727
nnodes="2:3",

torchx/components/base/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def torch_dist_role(
3232
entrypoint: str,
3333
resource: Union[str, Resource] = NULL_RESOURCE,
3434
base_image: Optional[str] = None,
35-
script_args: Optional[List[str]] = None,
36-
script_envs: Optional[Dict[str, str]] = None,
35+
args: Optional[List[str]] = None,
36+
env: Optional[Dict[str, str]] = None,
3737
num_replicas: int = 1,
3838
max_retries: int = 0,
3939
port_map: Optional[Dict[str, int]] = None,
@@ -65,8 +65,8 @@ def torch_dist_role(
6565
entrypoint: Script or binary to launch
6666
resource: Resource specs that define the container properties. Predefined resources
6767
are supported as str arguments.
68-
script_args: Arguments to the script
69-
script_envs: Env. variables to the worker
68+
args: Arguments to the script
69+
env: Env. variables to the worker
7070
num_replicas: Number of replicas
7171
max_retries: Number of retries
7272
retry_policy: ``torchx.specs.api.RetryPolicy``
@@ -90,8 +90,8 @@ def torch_dist_role(
9090
entrypoint,
9191
resource,
9292
base_image,
93-
script_args,
94-
script_envs,
93+
args,
94+
env,
9595
num_replicas,
9696
max_retries,
9797
port_map or {},

torchx/components/base/roles.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def create_torch_dist_role(
1717
entrypoint: str,
1818
resource: Resource = NULL_RESOURCE,
1919
base_image: Optional[str] = None,
20-
script_args: Optional[List[str]] = None,
21-
script_envs: Optional[Dict[str, str]] = None,
20+
args: Optional[List[str]] = None,
21+
env: Optional[Dict[str, str]] = None,
2222
num_replicas: int = 1,
2323
max_retries: int = 0,
2424
port_map: Dict[str, int] = field(default_factory=dict),
@@ -54,7 +54,7 @@ def create_torch_dist_role(
5454
... image="<NONE>",
5555
... resource=NULL_RESOURCE,
5656
... entrypoint="my_train_script.py",
57-
... script_args=["--script_arg", "foo", "--another_arg", "bar"],
57+
... args=["--script_arg", "foo", "--another_arg", "bar"],
5858
... num_replicas=4, max_retries=1,
5959
... nproc_per_node=8, nnodes="2:4", max_restarts=3)
6060
... # effectively runs:
@@ -72,8 +72,8 @@ def create_torch_dist_role(
7272
entrypoint: User binary or python script that will be launched.
7373
resource: Resource that is requested by scheduler
7474
base_image: Optional base image, if schedulers support image overlay
75-
script_args: User provided arguments
76-
script_envs: Env. variables that will be set on worker process that runs entrypoint
75+
args: User provided arguments
76+
env: Env. variables that will be set on worker process that runs entrypoint
7777
num_replicas: Number of role replicas to run
7878
max_retries: Max number of retries
7979
port_map: Port mapping for the role
@@ -84,11 +84,11 @@ def create_torch_dist_role(
8484
Role object that launches user entrypoint via the torchelastic as proxy
8585
8686
"""
87-
script_args = script_args or []
88-
script_envs = script_envs or {}
87+
args = args or []
88+
env = env or {}
8989

9090
entrypoint_override = "python"
91-
args: List[str] = ["-m", "torch.distributed.launch"]
91+
torch_run_args: List[str] = ["-m", "torch.distributed.launch"]
9292

9393
launch_kwargs.setdefault("rdzv_backend", "etcd")
9494
launch_kwargs.setdefault("rdzv_id", macros.app_id)
@@ -98,14 +98,14 @@ def create_torch_dist_role(
9898
if isinstance(val, bool):
9999
# treat boolean kwarg as a flag
100100
if val:
101-
args += [f"--{arg}"]
101+
torch_run_args += [f"--{arg}"]
102102
else:
103-
args += [f"--{arg}", str(val)]
103+
torch_run_args += [f"--{arg}", str(val)]
104104
if not os.path.isabs(entrypoint) and not entrypoint.startswith(macros.img_root):
105105
# make entrypoint relative to {img_root} ONLY if it is not an absolute path
106106
entrypoint = os.path.join(macros.img_root, entrypoint)
107107

108-
args += [entrypoint, *script_args]
108+
args = [*torch_run_args, entrypoint, *args]
109109
return (
110110
Role(
111111
name,
@@ -114,7 +114,7 @@ def create_torch_dist_role(
114114
resource=resource,
115115
port_map=port_map,
116116
)
117-
.runs(entrypoint_override, *args, **script_envs)
117+
.runs(entrypoint_override, *args, **env)
118118
.replicas(num_replicas)
119119
.with_retry_policy(retry_policy, max_retries)
120120
)

torchx/components/base/test/lib_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def test_torch_dist_role_default(self) -> None:
3636
entrypoint="test_entry.py",
3737
base_image="test_base_image",
3838
resource=Resource(1, 1, 10),
39-
script_args=["arg1", "arg2"],
40-
script_envs={"FOO": "BAR"},
39+
args=["arg1", "arg2"],
40+
env={"FOO": "BAR"},
4141
nnodes=2,
4242
)
4343

torchx/components/base/test/roles_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def test_build_create_torch_dist_role(self) -> None:
2727
"elastic_trainer",
2828
image="test_image",
2929
entrypoint="/bin/echo",
30-
script_args=["hello", "world"],
31-
script_envs={"ENV_VAR_1": "FOOBAR"},
30+
args=["hello", "world"],
31+
env={"ENV_VAR_1": "FOOBAR"},
3232
port_map={"foo": 8080},
3333
nnodes="2:4",
3434
max_restarts=3,
@@ -65,7 +65,7 @@ def test_build_create_torch_dist_role_override_rdzv_params(self) -> None:
6565
"test_role",
6666
image="torch_image",
6767
entrypoint="user_script.py",
68-
script_args=["--script_arg", "foo"],
68+
args=["--script_arg", "foo"],
6969
nnodes="2:4",
7070
rdzv_backend="etcd",
7171
rdzv_id="foobar",
@@ -148,7 +148,7 @@ def test_json_serialization_factory(self) -> None:
148148
image="user_image",
149149
entrypoint="user_script.py",
150150
resource=resource,
151-
script_args=["--script_arg", "foo"],
151+
args=["--script_arg", "foo"],
152152
port_map={"tensorboard": 8080},
153153
nnodes="2:4",
154154
rdzv_backend="etcd",

torchx/specs/api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -772,20 +772,20 @@ def _create_args_parser(
772772
)
773773

774774
for param_name, parameter in parameters.items():
775-
script_args: Dict[str, Any] = {
775+
args: Dict[str, Any] = {
776776
"help": args_desc[param_name],
777777
"type": get_argparse_param_type(parameter),
778778
}
779779
if parameter.default != inspect.Parameter.empty:
780-
script_args["default"] = parameter.default
780+
args["default"] = parameter.default
781781
if parameter.kind == inspect._ParameterKind.VAR_POSITIONAL:
782-
script_args["nargs"] = argparse.REMAINDER
782+
args["nargs"] = argparse.REMAINDER
783783
arg_name = param_name
784784
else:
785785
arg_name = f"--{param_name}"
786-
if "default" not in script_args:
787-
script_args["required"] = True
788-
script_parser.add_argument(arg_name, **script_args)
786+
if "default" not in args:
787+
args["required"] = True
788+
script_parser.add_argument(arg_name, **args)
789789
return script_parser
790790

791791

0 commit comments

Comments
 (0)