@@ -80,7 +80,10 @@ def asdict(self):
80
80
81
81
82
82
def calculate_tflops (
83
- config : ExperimentConfig , time_us : float , is_backward : bool = False
83
+ config : ExperimentConfig ,
84
+ time_us : float ,
85
+ is_backward : bool = False ,
86
+ sparsity : float = 0.0 ,
84
87
) -> float :
85
88
"""
86
89
Calculate TFLOPS for scaled dot product attention.
@@ -89,6 +92,7 @@ def calculate_tflops(
89
92
- config: The experiment configuration
90
93
- time_us: The execution time in microseconds
91
94
- is_backward: Whether to calculate for backward pass (includes gradient computation)
95
+ - sparsity: Sparsity factor between 0.0 and 1.0, where 0.0 means no sparsity and 1.0 means fully sparse
92
96
93
97
Returns:
94
98
- TFLOPS value
@@ -99,6 +103,9 @@ def calculate_tflops(
99
103
N = config .kv_seq_len
100
104
D = config .head_dim
101
105
106
+ # Calculate density factor (1.0 - sparsity)
107
+ density = 1.0 - sparsity
108
+
102
109
# Forward pass FLOPs
103
110
qk_flops = (
104
111
M * N * D * 2
@@ -110,6 +117,9 @@ def calculate_tflops(
110
117
111
118
total_flops = B * H * (qk_flops + softmax_flops + av_flops )
112
119
120
+ # Apply density factor to account for sparsity
121
+ total_flops *= density
122
+
113
123
# For the backward pass, flash uses 2.5x more FLOPs than the forward pass, so we apply that factor here
114
124
if is_backward :
115
125
total_flops *= 2.5
@@ -168,8 +178,11 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
168
178
)
169
179
170
180
# Calculate TFLOPS for forward and backward passes
171
- forward_tflops = calculate_tflops (config , forward_time )
172
- backward_tflops = calculate_tflops (config , backward_time , is_backward = True )
181
+ sparsity = 0.5 if is_causal else 0.0
182
+ forward_tflops = calculate_tflops (config , forward_time , sparsity = sparsity )
183
+ backward_tflops = calculate_tflops (
184
+ config , backward_time , is_backward = True , sparsity = sparsity
185
+ )
173
186
174
187
return ExperimentResults (
175
188
forward_time = forward_time ,
0 commit comments