
Commit 3ba3a46

c-p-i-o and svekars authored
[doc] add small example to flight recorder tutorial (#3163)
* [doc] add small example to flight recorder tutorial --------- Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 540bd0c commit 3ba3a46

1 file changed: +94 -0 lines changed

prototype_source/flight_recorder_tutorial.rst

Lines changed: 94 additions & 0 deletions
@@ -202,6 +202,100 @@ Caveat: tabulate module is needed, so you might need pip install it first.
python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]

An End-to-End Example
---------------------
To demonstrate the use of Flight Recorder, we will use a small program in which we induce mismatched collectives.
In this example, ``rank0`` is programmed to issue one additional collective.
The Flight Recorder dump files are saved to the ``/tmp`` directory.
For demonstration purposes, we named this program ``crash.py``.

.. note::
    This is a simplified example. Real-world scenarios typically involve more complexity.

.. code:: python
    :caption: A crashing example

    import torch
    import torch.distributed as dist
    import os
    from datetime import timedelta

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size <= 8, "world size must be less than or equal to 8"
    os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
    os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
    os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
    device = torch.device(f"cuda:{local_rank}")
    print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

    # Initialize the process group with a small timeout so that jobs fail quickly
    dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

    a = torch.full((3, 4), float(local_rank), device=device)
    # Write some collectives to populate Flight Recorder data
    for i in range(2):
        print(f"calling allreduce on {local_rank=}")
        f = dist.all_reduce(a)

    # rank0 is doing an additional collective
    if local_rank == 0:
        print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
        b = torch.full((4, 5), float(local_rank), device=device)
        f = dist.all_reduce(b)

    for i in range(2):
        print(f"calling allreduce on {local_rank=}")
        f = dist.all_reduce(a)

    torch.cuda.synchronize(device=device)
    print(f"{local_rank=} exiting")

To run this program, use ``torchrun``:

.. code:: bash

    torchrun --nnodes=1 --nproc_per_node=2 crash.py

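In ``crash.py`` the Flight Recorder environment variables are set via ``os.environ`` for convenience. They can also be exported in the shell before launching, since ``torchrun`` propagates the environment to the workers. The sketch below simply repeats the same three variables from the script:

.. code:: bash

    # Same Flight Recorder settings as in crash.py, exported before launch
    export TORCH_NCCL_DEBUG_INFO_TEMP_FILE="/tmp/trace_"
    export TORCH_NCCL_DUMP_ON_TIMEOUT="1"
    export TORCH_NCCL_TRACE_BUFFER_SIZE="2000"
    torchrun --nnodes=1 --nproc_per_node=2 crash.py
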
You should see two files in the ``/tmp`` directory:

.. code:: bash

    $ ls /tmp/trace*
    # Expected output
    /tmp/trace_0  /tmp/trace_1

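Before running the analyzer, you can also peek inside a single dump directly from Python. This is only a rough sketch under the assumption that each per-rank dump is a pickled Python object (which is how the analyzer loads them); the exact schema may differ across PyTorch releases, so the snippet only prints whatever is present:

.. code:: python

    # Hedged sketch: assumes each per-rank dump is a pickled Python object.
    # The exact schema may vary across PyTorch releases, so only print what exists.
    import pickle

    with open("/tmp/trace_0", "rb") as f:
        dump = pickle.load(f)

    if isinstance(dump, dict):
        print("top-level keys:", sorted(dump.keys()))
        entries = dump.get("entries", [])
        print(f"recorded collective entries: {len(entries)}")
    else:
        print("unexpected dump type:", type(dump))
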
Finally, to analyze these two files, we use the ``torchfrtrace`` command:

.. code:: bash

    torchfrtrace --prefix "trace_" /tmp/

The output of the trace command is meant to be human-readable and includes information about the
set of collectives that caused the failure. The output for the command above is shown below.
We can clearly see that rank 1 did not join the ``all_reduce`` collective.

.. code-block:: bash

    $ torchfrtrace --prefix "trace_" /tmp/
    Not all ranks joining collective 5 at entry 4
    group info: 0:default_pg
    collective: nccl:all_reduce
    missing ranks: {1}
    input sizes: [[3, 4]]
    output sizes: [[3, 4]]
    expected ranks: 2
    collective state: scheduled
    collective stack trace:
        all_reduce at /home/cpio/local/pytorch/torch/distributed/distributed_c10d.py:2696
        wrapper at /home/cpio/local/pytorch/torch/distributed/c10d_logger.py:83
        <module> at /home/cpio/test/crash.py:44

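The filtering flags shown at the top of this tutorial (``-j``, ``--selected-ranks``) can be combined with ``--prefix`` to narrow the report to the ranks of interest; as noted earlier, the tabulated view requires the ``tabulate`` module. A hedged example, assuming your PyTorch version accepts these flags together:

.. code:: bash

    # Print only the recorded entries for ranks 0 and 1
    torchfrtrace --prefix "trace_" /tmp/ -j --selected-ranks 0 1
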

Conclusion
----------
In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.
