# PyTorch Compatibility Shim System (Experimental)

This document describes the experimental compatibility shim system that allows TorchTitan to run on both PyTorch nightly and stable releases (e.g., PyTorch 2.8.0).

## Overview

The shim system is implemented in `torchtitan/experiments/compat/compat.py` and automatically patches missing PyTorch APIs when the package is imported. This lets developers on stable PyTorch releases use TorchTitan without requiring PyTorch nightly.
## How It Works

The compatibility system uses two approaches:

### 1. Import Hook for Missing Modules
For completely missing modules (like `torch.distributed.checkpoint._consolidate_hf_safetensors`), a custom meta path finder intercepts imports and provides shim modules with stub implementations.
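
A minimal sketch of the pattern, assuming a hypothetical registry `_SHIM_FACTORIES` and finder class `ShimFinder` (the real names in `compat.py` may differ):

```python
import importlib.abc
import importlib.machinery
import sys

# Hypothetical registry: module name -> factory returning a stub module
_SHIM_FACTORIES = {}

class ShimFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """Serves a registered stub module when Python tries to import it."""

    def find_spec(self, fullname, path, target=None):
        if fullname in _SHIM_FACTORIES:
            return importlib.machinery.ModuleSpec(fullname, self)
        return None  # let the normal import machinery handle it

    def create_module(self, spec):
        # The factory builds and populates the stub module
        return _SHIM_FACTORIES[spec.name]()

    def exec_module(self, module):
        pass  # nothing to execute; the factory already did the work

sys.meta_path.append(ShimFinder())
```

In this sketch the finder is appended after the default finders on `sys.meta_path`, so a real module (e.g., on PyTorch nightly) always wins and the shim only fires when the import would otherwise fail.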

### 2. Runtime Patching for Missing Classes
For existing modules that are missing specific classes (like `DefaultStager` in `torch.distributed.checkpoint.staging`), the shim system directly adds the missing classes to the existing module at import time.
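
In its simplest form this is guarded attribute assignment; a sketch (the class body is a stand-in, not the real shim):

```python
import torch.distributed.checkpoint.staging as staging

# Only patch when the attribute is actually missing, so nightly
# builds keep their native implementation.
if not hasattr(staging, "DefaultStager"):
    class DefaultStager:  # stand-in body for illustration
        pass

    staging.DefaultStager = DefaultStager
```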

## Automatic Activation

The shim system is automatically activated when you import the `torchtitan` package:

```python
import torchtitan  # Shims are installed automatically
```

This happens in `torchtitan/__init__.py`, which imports `torchtitan.experiments.compat` before anything else.
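
The relevant part of `torchtitan/__init__.py` amounts to something like this (a sketch; the actual file may differ):

```python
# Must run before any other torchtitan imports so the shims are in
# place by the time a submodule touches a missing PyTorch API.
import torchtitan.experiments.compat  # installs shims as a side effect
```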

## Currently Shimmed APIs

### 1. Checkpoint Consolidation (`torch.distributed.checkpoint._consolidate_hf_safetensors`)
- `consolidate_safetensor_files` - Raises `NotImplementedError`
- `consolidate_safetensors_files_on_every_rank` - Raises `NotImplementedError`

**Note:** HuggingFace checkpoint export requires PyTorch nightly.
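
These stubs simply fail loudly; a sketch of their shape (the error text is illustrative):

```python
def consolidate_safetensors_files_on_every_rank(*args, **kwargs):
    # Illustrative stub: this feature needs PyTorch nightly
    raise NotImplementedError(
        "HuggingFace checkpoint consolidation requires PyTorch nightly; "
        "set checkpoint.last_save_in_hf = false to avoid this code path."
    )
```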

### 2. Checkpoint Staging (`torch.distributed.checkpoint.staging`)
- `StagingOptions` - Simple placeholder for staging configuration
- `DefaultStager` - Falls back to `BlockingAsyncStager` if available

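One plausible shape for the fallback, assuming `BlockingAsyncStager` is importable from the same module (the real shim in `compat.py` may differ in detail):

```python
from torch.distributed.checkpoint.staging import BlockingAsyncStager

class StagingOptions:
    """Placeholder for the nightly staging configuration object."""
    def __init__(self, *args, **kwargs):
        pass

class DefaultStager(BlockingAsyncStager):
    # Accept nightly-style options but defer to the stable blocking stager
    def __init__(self, config=None):
        super().__init__()
        self.config = config
```
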
### 3. Pipeline Schedules (`torch.distributed.pipelining.schedules`)
- `ScheduleDualPipeV` - Raises `NotImplementedError` if instantiated

**Note:** Use a different pipeline schedule if you hit this error.

### 4. Flex Attention (`torch.nn.attention.flex_attention`)
- `AuxOutput` - NamedTuple for auxiliary flex_attention outputs

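A missing NamedTuple can be shimmed directly; the field below is an assumption for illustration, not necessarily the exact nightly definition:

```python
from typing import NamedTuple

import torch

class AuxOutput(NamedTuple):
    # Assumed field: log-sum-exp values returned alongside attention output
    lse: torch.Tensor
```
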
### 5. Checkpoint Wrapper (`torch.distributed.algorithms._checkpoint.checkpoint_wrapper`)
- Wraps the `checkpoint_wrapper` function to filter out the `early_stop` parameter, which is not available in PyTorch 2.8.0
- The `early_stop` parameter is silently ignored in stable PyTorch

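The wrapping pattern looks roughly like this (a sketch, not the literal code in `compat.py`):

```python
import functools

from torch.distributed.algorithms._checkpoint import (
    checkpoint_wrapper as _cw_module,
)

_original_checkpoint_wrapper = _cw_module.checkpoint_wrapper

@functools.wraps(_original_checkpoint_wrapper)
def _patched_checkpoint_wrapper(module, *args, early_stop=None, **kwargs):
    # Stable PyTorch 2.8.0 does not accept `early_stop`; drop it silently
    return _original_checkpoint_wrapper(module, *args, **kwargs)

_cw_module.checkpoint_wrapper = _patched_checkpoint_wrapper
```
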
## Adding New Shims

If you encounter a new missing API when using stable PyTorch, you can add a shim in one of two ways:

1. **For missing modules:** Add a factory function to `torchtitan/experiments/compat/compat.py` and register it with `register_shim()`:

```python
from types import ModuleType

def _shim_new_module():
    module = ModuleType('torch.some.missing.module')
    # Add stub functions/classes to the module here
    return module

# In install_shims():
register_shim('torch.some.missing.module', _shim_new_module)
```

2. **For missing classes in existing modules:** Add a function that patches the existing module:

```python
def _shim_existing_module():
    from torch.some import existing_module

    class MissingClass:
        # Implementation or stub
        pass

    # Attach the shim class to the already-imported module
    existing_module.MissingClass = MissingClass
    return existing_module

# In install_shims():
_shim_existing_module()
```

## Testing

To verify the shim system works:

```bash
# Should succeed with PyTorch 2.8.0
python -c "import torchtitan; print('Shims loaded successfully')"

# Try importing a shimmed module
python -c "from torch.distributed.checkpoint._consolidate_hf_safetensors import consolidate_safetensors_files_on_every_rank"

# Run the test suite
python -m torchtitan.experiments.compat.test_compat
```

## Known Limitations

1. **HuggingFace Checkpoint Export:** Not supported in stable PyTorch. Set `checkpoint.last_save_in_hf = false` in your config (see the snippet after this list).

2. **ScheduleDualPipeV:** Not available in stable PyTorch. Use a different pipeline schedule.

3. **Async Checkpoint Staging:** Limited functionality with the shim. Some advanced features may not work.

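For example, in a TOML config file the HuggingFace export setting from item 1 looks like this:

```toml
[checkpoint]
last_save_in_hf = false
```
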
## Version Compatibility

- **PyTorch Nightly:** All features work natively; the shims are harmless
- **PyTorch 2.8.0:** Tested and working, with the limitations noted above
- **Older versions:** May require additional shims

## Philosophy

The shim system follows these principles:

1. **Simple and Transparent:** Easy to understand and extend
2. **Fail-Fast:** Unsupported features raise clear errors explaining limitations
3. **Non-Intrusive:** Works automatically without code changes
4. **Compatible:** Harmless when used with PyTorch nightly

## Troubleshooting

If you encounter an import error:

1. Check if it's a PyTorch API that's missing in your version
2. Add a shim following the patterns in `torchtitan/experiments/compat/compat.py`
3. Test that both stable and nightly PyTorch work with your shim

For feature limitations, the error messages will guide you to one of the following:
- Upgrade to PyTorch nightly
- Use an alternative feature
- Disable the feature in your configuration

## Experimental Status

This compatibility system is experimental and may change in future releases. It is designed to help users who cannot use PyTorch nightly for various reasons (e.g., stability requirements, deployment constraints).