Skip to content

Commit 665f188

Browse files
authored
Merge pull request #86 from apoorvkh/support-type-checking
Support type checking
2 parents 42fa349 + 3ad386a commit 665f188

File tree

18 files changed

+401
-448
lines changed

18 files changed

+401
-448
lines changed

README.md

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Requires: Linux (+ SSH & shared filesystem if using multiple machines)
2828
Dummy distributed training function:
2929

3030
```python
31+
from __future__ import annotations
3132
import os
3233
import torch
3334
import torch.nn as nn
@@ -59,26 +60,24 @@ Launching training with `torchrunx`:
5960
```python
6061
import torchrunx
6162

62-
results = torchrunx.launch(
63-
func = train,
64-
kwargs = dict(
65-
model = nn.Linear(10, 10),
66-
num_steps = 10
67-
),
68-
#
63+
results = torchrunx.Launcher(
6964
hostnames = ["localhost", "second_machine"],
7065
workers_per_host = 2
66+
).run(
67+
train,
68+
model = nn.Linear(10, 10),
69+
num_steps = 10
7170
)
7271

7372
trained_model: nn.Module = results.rank(0)
7473
torch.save(trained_model.state_dict(), "output/model.pth")
7574
```
7675

7776
**See examples where we fine-tune LLMs (e.g. GPT-2 on WikiText) using:**
78-
- [Accelerate](https://torchrun.xyz/examples/accelerate.html)
79-
- [HF Transformers](https://torchrun.xyz/examples/transformers.html)
77+
- [Transformers](https://torchrun.xyz/examples/transformers.html)
8078
- [DeepSpeed](https://torchrun.xyz/examples/deepspeed.html)
8179
- [PyTorch Lightning](https://torchrun.xyz/examples/lightning.html)
80+
- [Accelerate](https://torchrun.xyz/examples/accelerate.html)
8281

8382
**Refer to our [API](https://torchrun.xyz/api.html) and [Advanced Usage Guide](https://torchrun.xyz/advanced.html) for many more capabilities!**
8483

@@ -118,4 +117,4 @@ torch.save(trained_model.state_dict(), "output/model.pth")
118117
> - Automatic detection of SLURM environments.
119118
> - Start multi-node training from Python notebooks!
120119
121-
**On our [roadmap](https://github.com/apoorvkh/torchrunx/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement): higher-order parallelism, support for debuggers, fuller typing, and more!**
120+
**On our [roadmap](https://github.com/apoorvkh/torchrunx/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement): higher-order parallelism, support for debuggers, and more!**

docs/conf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@
2424
"sphinx_toolbox.github",
2525
]
2626

27+
maximum_signature_line_length = 90
2728
autodoc_member_order = "bysource"
28-
autodoc_typehints = "description"
29-
autodoc_typehints_description_target = "documented"
3029

31-
intersphinx_mapping = {'python': ('https://docs.python.org/3', None)}
30+
intersphinx_mapping = {
31+
'python': ('https://docs.python.org/3.9', None),
32+
}
3233

3334
from docs.linkcode_github import generate_linkcode_resolve_fn
3435
linkcode_resolve = generate_linkcode_resolve_fn(project, github_username, github_repository)

docs/source/api.md

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,6 @@
11
# API
22

33
```{eval-rst}
4-
.. autofunction:: torchrunx.launch(func, args, kwargs, ...)
5-
```
6-
7-
We provide the {obj}`torchrunx.Launcher` class as an alias to {obj}`torchrunx.launch`.
8-
9-
```{eval-rst}
10-
.. autoclass:: torchrunx.Launcher
11-
:members:
12-
```
13-
14-
## Results
15-
16-
```{eval-rst}
17-
.. autoclass:: torchrunx.LaunchResult
4+
.. automodule:: torchrunx
185
:members:
196
```
20-
21-
## Exceptions
22-
23-
```{eval-rst}
24-
.. autoexception:: torchrunx.AgentFailedError
25-
```
26-
27-
```{eval-rst}
28-
.. autoexception:: torchrunx.WorkerFailedError
29-
```

docs/source/features/cli.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# CLI Integration
22

3-
We can use {mod}`torchrunx.Launcher` to populate arguments from the CLI (e.g. with [tyro](https://brentyi.github.io/tyro/)):
3+
We can automatically populate {mod}`torchrunx.Launcher` arguments using most CLI tools (those that generate interfaces from Data Classes, e.g. [tyro](https://brentyi.github.io/tyro/)):
44

55
```python
6-
import torchrunx as trx
6+
import torchrunx
77
import tyro
88

99
def distributed_function():
10-
pass
10+
...
1111

1212
if __name__ == "__main__":
13-
launcher = tyro.cli(trx.Launcher)
13+
launcher = tyro.cli(torchrunx.Launcher)
1414
launcher.run(distributed_function)
1515
```
1616

pyproject.toml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ authors = [
1111
]
1212
description = "Automatically initialize distributed PyTorch environments"
1313
readme = "README.md"
14-
license = {file = "LICENSE"}
14+
license = { file = "LICENSE" }
1515
urls = { Repository = "https://github.com/apoorvkh/torchrunx.git", Documentation = "https://torchrun.xyz" }
1616
requires-python = ">=3.9"
1717
dependencies = [
@@ -21,12 +21,17 @@ dependencies = [
2121
# torch.distributed depends on numpy
2222
# torch<=2.2 needs numpy<2
2323
"numpy>=1.20",
24+
"typing-extensions>=4.9.0",
2425
]
2526
[dependency-groups]
2627
dev = ["ruff==0.9.5", "pyright[nodejs]==1.1.393", "pytest==8.3.4"]
2728
test-extras = ["submitit", "transformers"]
28-
docs = ["sphinx==7.4.7", "furo==2024.8.6", "myst-parser==3.0.1", "sphinx-toolbox==3.8.2"]
29-
29+
docs = [
30+
"sphinx==7.4.7",
31+
"furo==2024.8.6",
32+
"myst-parser==3.0.1",
33+
"sphinx-toolbox==3.8.2",
34+
]
3035

3136
[tool.ruff]
3237
include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
@@ -36,6 +41,8 @@ src = ["src", "tests"]
3641
[tool.ruff.lint]
3742
select = ["ALL"]
3843
ignore = [
44+
"TC003", # no type checking blocks for stdlib
45+
"D104", # package docstrings
3946
"ANN401", # self / cls / Any annotations
4047
"BLE001", # blind exceptions
4148
"TD", # todo syntax

src/torchrunx/__init__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
"""API for our torchrunx library."""
2-
3-
from .launcher import Launcher, LaunchResult, launch
1+
from .launcher import DEFAULT_ENV_VARS_FOR_COPY, Launcher, LaunchResult
42
from .utils.errors import AgentFailedError, WorkerFailedError
53

6-
__all__ = [
7-
"AgentFailedError",
8-
"LaunchResult",
4+
__all__ = [ # noqa: RUF022
5+
"DEFAULT_ENV_VARS_FOR_COPY",
96
"Launcher",
7+
"LaunchResult",
8+
"AgentFailedError",
109
"WorkerFailedError",
11-
"launch",
1210
]
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
"""Utilities for integrations with other libraries."""

0 commit comments

Comments
 (0)