Caveat: the ``tabulate`` module is needed, so you might need to ``pip install`` it first.

.. code:: bash

   python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
   torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]

An End-to-End Example
------------------------------------
To demonstrate the use of Flight Recorder, we will use a small program in which we deliberately induce mismatched collectives.
In this example, ``rank0`` is programmed to do an additional collective.
The Flight Recorder dump files are saved to the ``/tmp`` directory.
For demonstration purposes, we named this program ``crash.py``.

.. note::
   This is a simplified example. In real-world scenarios, the process would involve additional complexity.

.. code:: python
   :caption: A crashing example

   import torch
   import torch.distributed as dist
   import os
   from datetime import timedelta

   local_rank = int(os.environ["LOCAL_RANK"])
   world_size = int(os.environ["WORLD_SIZE"])
   assert world_size <= 8, "world size must be less than or equal to 8"
   os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
   os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
   os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
   device = torch.device(f"cuda:{local_rank}")
   print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

   # Initialize the process group with a small timeout so that jobs fail quickly
   dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

   a = torch.full((3, 4), float(local_rank), device=device)
   # Write some collectives to populate Flight Recorder data
   for i in range(2):
       print(f"calling allreduce on {local_rank=}")
       f = dist.all_reduce(a)

   # rank0 is doing an additional collective
   if local_rank == 0:
       print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
       b = torch.full((4, 5), float(local_rank), device=device)
       f = dist.all_reduce(b)

   for i in range(2):
       print(f"calling allreduce on {local_rank=}")
       f = dist.all_reduce(a)

   torch.cuda.synchronize(device=device)
   print(f"{local_rank=} exiting")
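
The Flight Recorder settings do not have to be hard-coded in the script. As a minimal sketch, the same
environment variables used in ``crash.py`` above could instead be exported in the shell before launching;
the values below simply mirror the ones set in the script:

.. code:: bash

   # Prefix for the dump files; each rank appends its rank number to this path
   export TORCH_NCCL_DEBUG_INFO_TEMP_FILE=/tmp/trace_
   # Dump Flight Recorder data when a collective times out
   export TORCH_NCCL_DUMP_ON_TIMEOUT=1
   # Number of collective entries retained by Flight Recorder
   export TORCH_NCCL_TRACE_BUFFER_SIZE=2000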

To run this program, use ``torchrun``:

.. code:: bash

   torchrun --nnodes=1 --nproc_per_node=2 crash.py

Because rank 1 never joins the extra collective issued by rank 0, the job times out, and since
``TORCH_NCCL_DUMP_ON_TIMEOUT`` is set, each rank dumps its Flight Recorder buffer. You should see two
files in the ``/tmp`` directory:

.. code:: bash

   $ ls /tmp/trace*
   # Expected output
   /tmp/trace_0 /tmp/trace_1

Finally, to analyze these two files, we use the ``torchfrtrace`` command:

.. code:: bash

   torchfrtrace --prefix "trace_" /tmp/

The output of this command is meant to be human-readable. It includes information about the set of
collectives that caused the failure. The output for the command above is shown below, and we can
clearly see that rank 1 did not join the ``all_reduce`` collective.

.. code-block:: bash

   $ torchfrtrace --prefix "trace_" /tmp/
   Not all ranks joining collective 5 at entry 4
   group info: 0:default_pg
   collective: nccl:all_reduce
   missing ranks: {1}
   input sizes: [[3, 4]]
   output sizes: [[3, 4]]
   expected ranks: 2
   collective state: scheduled
   collective stack trace:
   all_reduce at /home/cpio/local/pytorch/torch/distributed/distributed_c10d.py:2696
   wrapper at /home/cpio/local/pytorch/torch/distributed/c10d_logger.py:83
   <module> at /home/cpio/test/crash.py:44
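
The same dump directory can also be summarized in tabular form with the ``-j`` flag shown earlier in this
tutorial. For example, assuming these flags can be combined with ``--prefix`` (the ``tabulate`` module
mentioned earlier is required for tabular output):

.. code:: bash

   torchfrtrace --prefix "trace_" /tmp/ -j --selected-ranks 0 1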
297
+
298
+
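
It is also possible to inspect a dump file programmatically. The sketch below assumes that each dump is a
standard pickle file containing a dictionary with an ``entries`` list (one element per recorded collective);
the exact keys may vary between PyTorch versions:

.. code:: python

   import pickle

   # Load the Flight Recorder dump written by rank 0
   with open("/tmp/trace_0", "rb") as f:
       dump = pickle.load(f)

   # Top-level keys of the dump (for example, the version, process group
   # configuration, and the list of recorded collectives)
   print(dump.keys())

   # Print a short summary of each recorded collective, keeping only the
   # keys that are actually present in this PyTorch version
   for i, entry in enumerate(dump.get("entries", [])):
       summary = {k: entry[k] for k in ("profiling_name", "state") if k in entry}
       print(i, summary)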

Conclusion
----------
In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.