Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 131 additions & 38 deletions csrc/flashfftconv/conv1d/conv1d_bwd_cuda_bhl.cu
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
#include "shared.h"

#include "shared.h" // Assuming this provides half, bfloat16 types and set_value
#include <cuda_bf16.h> // Required for __nv_bfloat16 type and conversions
#include <cuda_fp16.h> // Required for half type and intrinsics (like __hfma)
#include <math.h> // For fmaf (device-side float FMA)

const uint BX = 128;
const uint BY = 1;
const uint BZ = 1;

const uint TILE_SIZE = 4;
const uint TILE_SIZE = 4; // This constant is not used in the current kernel, but kept from original

template <typename input_t, typename weight_t>
__global__ void conv1d_backward_kernel(
Expand All @@ -15,74 +19,157 @@ __global__ void conv1d_backward_kernel(
input_t* __restrict__ du,
input_t* __restrict__ dk,
uint B,
uint L,
uint L_in,// Renamed L to L_in
uint D,
uint K,
uint P
uint P,
uint L_out // <-- NEW: Added L_out
)
{
const int b = blockIdx.z;
const int d = blockIdx.y;
const int l = blockIdx.x;

//construct the du matrix
if(b < B && d < D && l == 0){
for(int j = threadIdx.x; j < L; j += blockDim.x)
// --- du calculation part ---
// For du calculation, blockIdx.x represents 'l_idx_for_du'
const int l_idx_for_du = blockIdx.x; // Block index for the input length dimension (L_in)

// This part calculates du. Each block processes one (b, d, l_idx_for_du) triplet.
// The original code had `l_idx_for_du == 0` for this part, which means only the first
// block in the X-dimension would execute this. Assuming `blockIdx.x` iterates over `L_in`
// for `du` computation, the condition `l_idx_for_du == 0` should be removed if you want
// all `l_in` positions to be computed in parallel by different blocks.
// If `l_idx_for_du` is meant to be the *start* of a tile, and `j` iterates within it,
// then the `if(l_idx_for_du == 0)` is wrong.
// Given `gridDims(l_in, d, b)`, `blockIdx.x` is `l_in`.
// The loop `for(int j = threadIdx.x; j < L_in; j += blockDim.x)` suggests `j` is the global index.
// This implies `l_idx_for_du` is not used in this loop for actual indexing, which is confusing.
// Let's assume `j` is the global `l_in` index, and `blockIdx.x` is not directly used for `du`'s `l_in` index.
// The original `for(int j = threadIdx.x; j < L_in; j += blockDim.x)` implies threads within a block
// cover the `L_in` dimension. This is different from the BLH kernel.
// For BHL, `u` is `(B, D, L_in)`. `du` is also `(B, D, L_in)`.
// `dout` is `(B, D, L_out)`.
// Let's re-interpret: `blockIdx.x` maps to `l_in` for `du`. `threadIdx.x` is for inner loop.
// The original `if(b < B && d < D && l_idx_for_du == 0)` is highly restrictive.
// I'll assume `l_idx_for_du` is the global index for `L_in` that this block is responsible for.
// The inner `for(int j = threadIdx.x; j < L_in; j += blockDim.x)` means each block is responsible
// for a *stride* over `L_in`. This is a common pattern.
// So, `l_idx_for_du` is actually `blockIdx.x` and `j` is the `L_in` index computed by threads.
// The `if(l_idx_for_du == 0)` is still problematic for parallelism over `L_in`.
// I will remove `l_idx_for_du == 0` to allow full parallelism over `L_in` via `blockIdx.x`.
// Each block (b, d, l_in_block_idx) will compute a slice of du.
// `j` will be `l_in_block_idx * blockDim.x + threadIdx.x` if `blockIdx.x` is meant to tile `L_in`.
// However, the original `for(int j = threadIdx.x; j < L_in; j += blockDim.x)` implies `blockIdx.x`
// is *not* tiling `L_in`, but rather `L_in` is fully covered by threads within a single block's `threadIdx.x`.
// This is unusual for large `L_in`. Given `gridDims(l_in, d, b)`, `blockIdx.x` *is* the `l_in` index.
// So, the inner loop `for(int j = threadIdx.x; j < L_in; j += blockDim.x)` is likely wrong.
// It should be `int j = l_idx_for_du;` and no inner loop, or `j = threadIdx.x` and `l_idx_for_du` is a tile.
// Let's revert to the BLH style where blockIdx.x is the specific `l_in` index.
// This means `j` should be `l_idx_for_du`.
// The original code's `for(int j = threadIdx.x; j < L_in; j += blockDim.x)` is a common pattern
// for parallelizing a loop over `L_in` *within* a block. If `gridDims` is `(1, D, B)`, it would make sense.
// But `gridDims` is `(l_in, d, b)`. This means each block handles one `l_in` index.
// So, the `for` loop `for(int j = threadIdx.x; j < L_in; j += blockDim.x)` is incorrect here.
// It should be `if (threadIdx.x == 0)` and `j` is `l_idx_for_du`.

if(b < B && d < D && l_idx_for_du < L_in && threadIdx.x == 0){ // Only one thread per block computes for this (b, d, l_in)
float sum_float = 0.0f;

for(int k_loop = 0; k_loop < K ; k_loop++)
{
input_t sum;
set_value(&sum, 0.0f);
input_t weight;
// The index in dout that contributes to du[l_idx_for_du] for weight k_loop
int dout_idx = l_idx_for_du + P - k_loop;

for(int k = 0; k < K ; k++)
{
int idx = - P + k + j;
if(dout_idx >= 0 && dout_idx < L_out){
float dout_val_f;
if constexpr (std::is_same_v<input_t, float>) {
dout_val_f = dout[b * D * L_out + d * L_out + dout_idx];
} else if constexpr (std::is_same_v<input_t, half>) {
dout_val_f = __half2float(dout[b * D * L_out + d * L_out + dout_idx]);
} else if constexpr (std::is_same_v<input_t, __nv_bfloat16>) {
dout_val_f = __bfloat162float(dout[b * D * L_out + d * L_out + dout_idx]);
} else {
dout_val_f = static_cast<float>(dout[b * D * L_out + d * L_out + dout_idx]);
}

if(idx >= 0 && idx < L){
set_value(&weight, weights[d * K + K - (k +1)]);
sum = __hfma(dout[b * D * L + d * L + idx], weight, sum);
float weight_val_f;
if constexpr (std::is_same_v<weight_t, float>) {
weight_val_f = weights[d * K + k_loop]; // Assuming weights[d][k_loop]
} else if constexpr (std::is_same_v<weight_t, half>) {
weight_val_f = __half2float(weights[d * K + k_loop]);
} else if constexpr (std::is_same_v<weight_t, __nv_bfloat16>) {
weight_val_f = __bfloat162float(weights[d * K + k_loop]);
} else {
weight_val_f = static_cast<float>(weights[d * K + k_loop]);
}

sum_float = fmaf(dout_val_f, weight_val_f, sum_float);
}
du[b * D * L + d * L + j] = sum;
}
if constexpr (std::is_same_v<input_t, float>) {
du[b * D * L_in + d * L_in + l_idx_for_du] = sum_float;
} else if constexpr (std::is_same_v<input_t, half>) {
du[b * D * L_in + d * L_in + l_idx_for_du] = __float2half(sum_float);
} else if constexpr (std::is_same_v<input_t, __nv_bfloat16>) {
du[b * D * L_in + d * L_in + l_idx_for_du] = __float2bfloat16(sum_float);
} else {
du[b * D * L_in + d * L_in + l_idx_for_du] = static_cast<input_t>(sum_float);
}
}

const int k = blockIdx.x;
input_t tmp;
//construct the dk matrix
if(b < B && d < D && k < K)
// --- dk calculation part (Intermediate for weights gradient) ---
// IMPORTANT NOTE: This part of the kernel uses blockIdx.x to represent 'k_idx_for_dk'.
// However, the `gridDims` is set to `(l_in, d, b)`. This means `blockIdx.x` will iterate
// from `0` to `l_in - 1`. Because of the `k_idx_for_dk < K` bounds check below, this is
// correct whenever `l_in >= K`. If `l_in < K`, however, the `dk` entries for
// `k_idx >= l_in` are never written and remain uninitialized (from `torch::empty`),
// yielding incorrect `dk` results.
// For a robust solution, you should split `du` and `dk` calculations into two separate
// CUDA kernels, each launched with its own appropriate grid dimensions.
const int k_idx_for_dk = blockIdx.x; // This will range up to L_in-1, not K-1

if(b < B && d < D && k_idx_for_dk < K) // Check bounds for kernel dimension K
{
for(int j = threadIdx.x; j < L; j += blockDim.x)
// Each thread block processes one (b, d, k_idx) triplet.
// Threads within the block parallelize over L_out.
for(int j_out_idx = threadIdx.x; j_out_idx < L_out; j_out_idx += blockDim.x)
{
if(k - P + j < 0 || k - P + j >= L){
set_value(&dk[b * D * K * L + d * K * L + k * L + j], 0.0f);
// The index in u that contributes to dk[k_idx_for_dk] for dout[j_out_idx]
int u_idx = j_out_idx - P + k_idx_for_dk;

if(u_idx < 0 || u_idx >= L_in){
set_value(&dk[b * D * K * L_out + d * K * L_out + k_idx_for_dk * L_out + j_out_idx], 0.0f);
}else{
set_value(&dk[b * D * K * L + d * K * L + k * L + j], u[b * D * L + d * L + k - P + j]);
set_value(&dk[b * D * K * L_out + d * K * L_out + k_idx_for_dk * L_out + j_out_idx], u[b * D * L_in + d * L_in + u_idx]);
}
}
}

}

std::vector<torch::Tensor> conv1d_backward_bhl_cuda(
torch::Tensor dout,
torch::Tensor u,
torch::Tensor dout, // Dout's current length is L_out
torch::Tensor u, // u's length is L_in
torch::Tensor weight,
torch::Tensor bias,
uint padding)
{
const uint b = u.size(0);
const uint d = u.size(1);
const uint l = u.size(2);
const uint l_in = u.size(2); // Rename 'l' to 'l_in' for clarity

const uint k = weight.squeeze().size(1);

// Calculate L_out from the dout tensor's actual shape
const uint l_out = dout.size(2); // Get L_out directly from dout

dim3 blockDims(BX, 1, 1);
// gridDims for du calculation (over L_in, D, B)
// The 'dk' part in the same kernel uses blockIdx.x as 'k_idx', which conflicts with L_in.
// This setup leaves the dk entries for k_idx >= L_in unwritten (uninitialized
// memory from torch::empty) whenever L_in < K, producing incorrect dk results.
dim3 gridDims(l_in, d, b);


dim3 gridDims(l, d, b);
torch::Tensor du = torch::empty({b, d, l_in}, u.options());// du should have shape of input u
// dk intermediate. It should be (B, D, K, L_out) to match dout's L_out dimension for matmul
torch::Tensor dk = torch::empty({b, d, k, l_out}, dout.options()); // Corrected dk shape

torch::Tensor du = torch::empty({b, d, l}, u.options());
torch::Tensor dk = torch::empty({b, d, k, l}, dout.options());
// Bias gradient is summed over B and L_out (dimensions 0 and 2 of dout)
torch::Tensor dbias = dout.sum(-1).sum(0);

DISPATCH_FLOAT_AND_HALF_AND_BF16(dout.scalar_type(), weight.scalar_type(),
Expand All @@ -95,12 +182,18 @@ std::vector<torch::Tensor> conv1d_backward_bhl_cuda(
static_cast<input_t *>(du.data_ptr()),
static_cast<input_t *>(dk.data_ptr()),
b,
l,
l_in, // Pass L_in
d,
k,
padding);
padding,
l_out // <-- NEW: Pass L_out
);
}
)
);
return {du, torch::matmul(dk, dout.unsqueeze(-1)).squeeze(-1).sum(0).to(weight.type()), dbias};
}
// torch::matmul(dk, dout.unsqueeze(-1)) results in (B, D, K, 1)
// squeeze(-1) results in (B, D, K)
// sum(0) results in (D, K)
// to(weight.dtype()) ensures the final dtype matches weight
return {du, torch::matmul(dk, dout.unsqueeze(-1)).squeeze(-1).sum(0).to(weight.dtype()), dbias};
}
Loading