forked from NVIDIA/TransformerEngine
-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathrun_tp_overlap_ut.sh
More file actions
62 lines (44 loc) · 3.8 KB
/
Copy pathrun_tp_overlap_ut.sh
File metadata and controls
62 lines (44 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/bin/bash
#
# Launch run_gemm_with_overlap.py under torchrun for one comm/GEMM overlap
# variant, printing the exact command line before running it.
#
# Usage: run_tp_overlap_ut.sh [MODE]
#
# MODE (default: pipeline_rs_ce) is one of:
#   bulk_rs      bulk_rs_ce      bulk_ag      bulk_ag_ce
#   pipeline_rs  pipeline_rs_ce
#   ring_rs      ring_rs_ce      ring_ag      ring_ag_ce
#
# Naming: "rs"/"ag" is the value passed to --comm-type, "bulk" adds
# --bulk-overlap, "ring" (ring_exchange) adds --p2p, "pipeline" adds neither,
# and the "_ce" suffix adds --use-ce.
#
# The test script path below is relative, so invoke this from the directory
# that contains tests/.
#
# Exit status: 2 for an unknown MODE, otherwise the exit status of torchrun.

set -euo pipefail

readonly TP_SIZE=8
readonly WARMUP_ITERS=10
readonly TIMING_ITERS=20
readonly TEST_SCRIPT=tests/pytorch/distributed/run_gemm_with_overlap.py

# default
#BATCH_SIZE=2
#SEQ_LENGTH=512
#NUM_HEADS=12
#HEAD_SIZE=64

# set engine of musaMemcpyAsync, 1 for DMA, 2 for TDM, 3 for CE
# export MUSA_MEMCPY_PATH=2

# llama3 70B
BATCH_SIZE=1
SEQ_LENGTH=8192
NUM_HEADS=64
HEAD_SIZE=128

# llama3 405B
# BATCH_SIZE=1
# SEQ_LENGTH=8192
# NUM_HEADS=128
# HEAD_SIZE=128

# With no argument this selects the variant the script has always run:
# pipeline overlap rs over ce.
mode=${1:-pipeline_rs_ce}

comm_type=
overlap_flag=
use_ce=0

case "$mode" in
  bulk_rs)
    comm_type=rs
    overlap_flag=--bulk-overlap
    ;;
  bulk_rs_ce)
    comm_type=rs
    overlap_flag=--bulk-overlap
    use_ce=1
    ;;
  bulk_ag)
    comm_type=ag
    overlap_flag=--bulk-overlap
    ;;
  bulk_ag_ce)
    comm_type=ag
    overlap_flag=--bulk-overlap
    use_ce=1
    # set engine of musaMemcpyAsync to CE for bulk ag
    export MUSA_MEMCPY_PATH=3
    ;;
  pipeline_rs)
    comm_type=rs
    ;;
  pipeline_rs_ce)
    comm_type=rs
    use_ce=1
    ;;
  ring_rs)
    comm_type=rs
    overlap_flag=--p2p
    ;;
  ring_rs_ce)
    comm_type=rs
    overlap_flag=--p2p
    use_ce=1
    ;;
  ring_ag)
    comm_type=ag
    overlap_flag=--p2p
    ;;
  ring_ag_ce)
    comm_type=ag
    overlap_flag=--p2p
    use_ce=1
    ;;
  *)
    printf '%s: unknown mode: %s\n' "${0##*/}" "$mode" >&2
    printf 'valid modes: %s\n' \
      'bulk_rs bulk_rs_ce bulk_ag bulk_ag_ce pipeline_rs pipeline_rs_ce ring_rs ring_rs_ce ring_ag ring_ag_ce' >&2
    exit 2
    ;;
esac

# Build the command as an argv array so it can be executed directly, with no
# eval and no re-parsing of a flattened string. Argument order is kept stable:
# comm flags, common flags, optional --use-ce, then --check-numerics.
cmd=(torchrun "--nproc-per-node=${TP_SIZE}" "$TEST_SCRIPT" --comm-type "$comm_type")
if [[ -n "$overlap_flag" ]]; then
  cmd+=("$overlap_flag")
fi
cmd+=(
  --verbose
  --dtype bf16
  --batch-size "$BATCH_SIZE"
  --seq-length "$SEQ_LENGTH"
  --num-heads "$NUM_HEADS"
  --head-dim "$HEAD_SIZE"
  --warmup-iters "$WARMUP_ITERS"
  --timing-iters "$TIMING_ITERS"
)
if (( use_ce )); then
  cmd+=(--use-ce)
fi
cmd+=(--check-numerics)

# Show the command on stdout (space-joined), then run it; its exit status
# becomes the script's exit status.
printf '%s\n' "${cmd[*]}"
"${cmd[@]}"