Skip to content

Commit 127bd5a

Browse files
drisspgpytorchmergebot
authored andcommitted
Add sparsity (pytorch#148513)
Pull Request resolved: pytorch#148513 Approved by: https://github.com/danielvegamyhre
1 parent b4430c3 commit 127bd5a

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

benchmarks/transformer/sdpa.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def asdict(self):
8080

8181

8282
def calculate_tflops(
83-
config: ExperimentConfig, time_us: float, is_backward: bool = False
83+
config: ExperimentConfig,
84+
time_us: float,
85+
is_backward: bool = False,
86+
sparsity: float = 0.0,
8487
) -> float:
8588
"""
8689
Calculate TFLOPS for scaled dot product attention.
@@ -89,6 +92,7 @@ def calculate_tflops(
8992
- config: The experiment configuration
9093
- time_us: The execution time in microseconds
9194
- is_backward: Whether to calculate for backward pass (includes gradient computation)
95+
- sparsity: Sparsity factor between 0.0 and 1.0, where 0.0 means no sparsity and 1.0 means fully sparse
9296
9397
Returns:
9498
- TFLOPS value
@@ -99,6 +103,9 @@ def calculate_tflops(
99103
N = config.kv_seq_len
100104
D = config.head_dim
101105

106+
# Calculate density factor (1.0 - sparsity)
107+
density = 1.0 - sparsity
108+
102109
# Forward pass FLOPs
103110
qk_flops = (
104111
M * N * D * 2
@@ -110,6 +117,9 @@ def calculate_tflops(
110117

111118
total_flops = B * H * (qk_flops + softmax_flops + av_flops)
112119

120+
# Apply density factor to account for sparsity
121+
total_flops *= density
122+
113123
# For backward pass flash uses 2.5x more flops will use this
114124
if is_backward:
115125
total_flops *= 2.5
@@ -168,8 +178,11 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
168178
)
169179

170180
# Calculate TFLOPS for forward and backward passes
171-
forward_tflops = calculate_tflops(config, forward_time)
172-
backward_tflops = calculate_tflops(config, backward_time, is_backward=True)
181+
sparsity = 0.5 if is_causal else 0.0
182+
forward_tflops = calculate_tflops(config, forward_time, sparsity=sparsity)
183+
backward_tflops = calculate_tflops(
184+
config, backward_time, is_backward=True, sparsity=sparsity
185+
)
173186

174187
return ExperimentResults(
175188
forward_time=forward_time,

0 commit comments

Comments
 (0)