Commit 91bdcc0 (parent a8899e4): add shims for pytorch stable

4 files changed: +516 -0 lines

Lines changed: 139 additions & 0 deletions
# PyTorch Compatibility Shim System (Experimental)

This document describes the experimental compatibility shim system that allows TorchTitan to run on both PyTorch nightly and stable releases (e.g., PyTorch 2.8.0).

## Overview

The shim system is implemented in `torchtitan/experiments/compat/compat.py` and automatically patches missing PyTorch APIs when the package is imported. This allows developers using stable PyTorch releases to use TorchTitan without requiring PyTorch nightly.

## How It Works

The compatibility system uses two approaches:

### 1. Import Hook for Missing Modules

For completely missing modules (like `torch.distributed.checkpoint._consolidate_hf_safetensors`), a custom meta path finder intercepts imports and provides shim modules with stub implementations.
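The import-hook approach can be sketched as follows. This is an illustrative, torch-free reconstruction of the pattern, not the actual `compat.py` code; the `_ShimFinder` and `_ShimLoader` names are assumptions.

```python
import importlib.abc
import importlib.machinery
import sys
from types import ModuleType

# Registry of dotted module name -> factory returning a ready-made shim module.
_SHIM_FACTORIES = {}


class _ShimLoader(importlib.abc.Loader):
    def create_module(self, spec):
        # Build the shim module from its registered factory.
        return _SHIM_FACTORIES[spec.name]()

    def exec_module(self, module):
        pass  # the factory already populated the module


class _ShimFinder(importlib.abc.MetaPathFinder):
    """Meta path finder that serves shim modules for registered names only."""

    def find_spec(self, fullname, path, target=None):
        if fullname not in _SHIM_FACTORIES:
            return None  # defer to the normal import machinery
        return importlib.machinery.ModuleSpec(fullname, _ShimLoader())


def register_shim(fullname, factory):
    _SHIM_FACTORIES[fullname] = factory


# Install ahead of the standard finders; only registered names are intercepted.
sys.meta_path.insert(0, _ShimFinder())
```

Because `find_spec` returns `None` for unregistered names, the finder is invisible to every other import in the process.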
### 2. Runtime Patching for Missing Classes

For existing modules that are missing specific classes (like `DefaultStager` in `torch.distributed.checkpoint.staging`), the shim system directly adds the missing classes to the existing module at import time.
## Automatic Activation

The shim system is automatically activated when you import the `torchtitan` package:

```python
import torchtitan  # Shims are installed automatically
```

This happens in `torchtitan/__init__.py`, which imports `torchtitan.experiments.compat` before anything else.
## Currently Shimmed APIs

### 1. Checkpoint Consolidation (`torch.distributed.checkpoint._consolidate_hf_safetensors`)

- `consolidate_safetensor_files` - Raises NotImplementedError
- `consolidate_safetensors_files_on_every_rank` - Raises NotImplementedError

**Note:** HuggingFace checkpoint export requires PyTorch nightly.
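As a sketch, a stub for one of these entry points might look like the following; the error text here is illustrative, not the actual message emitted by `compat.py`:

```python
# Nightly-only entry points are replaced by functions that fail loudly
# with a message pointing at the workaround.
def consolidate_safetensors_files_on_every_rank(*args, **kwargs):
    raise NotImplementedError(
        "HuggingFace checkpoint consolidation requires PyTorch nightly. "
        "Set checkpoint.last_save_in_hf = false or upgrade PyTorch."
    )
```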
### 2. Checkpoint Staging (`torch.distributed.checkpoint.staging`)

- `StagingOptions` - Simple placeholder for staging configuration
- `DefaultStager` - Falls back to `BlockingAsyncStager` if available

### 3. Pipeline Schedules (`torch.distributed.pipelining.schedules`)

- `ScheduleDualPipeV` - Raises NotImplementedError if instantiated

**Note:** Use a different pipeline schedule if you hit this error.
### 4. Flex Attention (`torch.nn.attention.flex_attention`)

- `AuxOutput` - NamedTuple for auxiliary flex_attention outputs

### 5. Checkpoint Wrapper (`torch.distributed.algorithms._checkpoint.checkpoint_wrapper`)

- Wraps the `checkpoint_wrapper` function to filter out the `early_stop` parameter, which is not available in PyTorch 2.8.0
- The `early_stop` parameter is silently ignored in stable PyTorch
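A minimal sketch of this parameter-filtering pattern (illustrative; the real shim in `compat.py` may differ):

```python
import functools


def make_compat_checkpoint_wrapper(real_checkpoint_wrapper):
    """Return a wrapper that drops early_stop before delegating.

    On PyTorch 2.8.0, checkpoint_wrapper has no early_stop parameter,
    so the shim discards it rather than letting the call fail.
    """
    @functools.wraps(real_checkpoint_wrapper)
    def wrapper(*args, **kwargs):
        kwargs.pop("early_stop", None)  # silently ignored on stable PyTorch
        return real_checkpoint_wrapper(*args, **kwargs)

    return wrapper
```

`functools.wraps` preserves the original function's name and docstring, so code that introspects `checkpoint_wrapper` keeps working.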
## Adding New Shims

If you encounter a new missing API when using stable PyTorch, you can add a shim as follows:

1. **For missing modules:** Add a factory function to `torchtitan/experiments/compat/compat.py` and register it with `register_shim()`

```python
from types import ModuleType

def _shim_new_module():
    module = ModuleType('torch.some.missing.module')
    # Add functions/classes to the module here
    return module

# In install_shims():
register_shim('torch.some.missing.module', _shim_new_module)
```
2. **For missing classes in existing modules:** Add a function that patches the existing module

```python
def _shim_existing_module():
    from torch.some import existing_module

    class MissingClass:
        # Implementation or stub
        pass

    existing_module.MissingClass = MissingClass
    return existing_module

# In install_shims():
_shim_existing_module()
```
## Testing

To verify the shim system works:

```bash
# Should succeed with PyTorch 2.8.0
python -c "import torchtitan; print('Shims loaded successfully')"

# Try importing a shimmed module
python -c "from torch.distributed.checkpoint._consolidate_hf_safetensors import consolidate_safetensors_files_on_every_rank"

# Run the test suite
python -m torchtitan.experiments.compat.test_compat
```
## Known Limitations

1. **HuggingFace Checkpoint Export:** Not supported in stable PyTorch. Set `checkpoint.last_save_in_hf = false` in your config.

2. **ScheduleDualPipeV:** Not available in stable PyTorch. Use a different pipeline schedule.

3. **Async Checkpoint Staging:** Limited functionality with the shim; some advanced features may not work.
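For the first limitation, the corresponding setting in a TorchTitan TOML config file would look like the fragment below (the `[checkpoint]` section name is inferred from the `checkpoint.last_save_in_hf` key above):

```toml
[checkpoint]
last_save_in_hf = false
```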
## Version Compatibility

- **PyTorch Nightly:** All features work natively; shims are harmless
- **PyTorch 2.8.0:** Tested and working, with the limitations noted above
- **Older versions:** May require additional shims
## Philosophy

The shim system follows these principles:

1. **Simple and Transparent:** Easy to understand and extend
2. **Fail-Fast:** Unsupported features raise clear errors explaining their limitations
3. **Non-Intrusive:** Works automatically, without code changes
4. **Compatible:** Harmless when used with PyTorch nightly
## Troubleshooting

If you encounter an import error:

1. Check whether it is a PyTorch API that is missing in your version
2. Add a shim following the patterns in `torchtitan/experiments/compat/compat.py`
3. Test that both stable and nightly PyTorch work with your shim

For feature limitations, the error messages will guide you to either:

- Upgrade to PyTorch nightly
- Use an alternative feature
- Disable the feature in your configuration

## Experimental Status

This compatibility system is experimental and may change in future releases. It is designed to help users who cannot use PyTorch nightly (e.g., due to stability requirements or deployment constraints).
Lines changed: 20 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
PyTorch compatibility shims for non-nightly versions.

This experimental module provides compatibility between PyTorch nightly and stable releases
by shimming missing modules and functions.

Usage:
    import torchtitan.experiments.compat  # noqa: F401

The shims are automatically installed when this module is imported.
"""

# Import compat to auto-install shims
from . import compat  # noqa: F401
```
