Commit

cuda kernel fix
jahatef committed Jan 29, 2025
1 parent a001467 commit 27c35e2
Showing 2 changed files with 11 additions and 8 deletions.
9 changes: 8 additions & 1 deletion configs/rwkv/24B-pp-scaling.yml
@@ -1,6 +1,6 @@
 {
   # Parallelism is not yet supported for rwkv
-  "pipe_parallel_size": 8,
+  "pipe_parallel_size": 16,
   "model_parallel_size": 1,
 
   "num_layers": 352,
@@ -80,6 +80,13 @@
   "hidden_dropout": 0,
   "attention_dropout": 0,
 
+  "comms_logger": {
+    "enabled": true,
+    "verbose": true,
+    "prof_all": true,
+    "debug": true
+  },
+
   # precision settings
   "bf16": {
     "bf16": true,
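The comms_logger block added above enables DeepSpeed's communication logging for the run (verbose per-operation logging plus profiling of all collectives). As a rough usage sketch, not part of this commit and assuming DeepSpeed's documented deepspeed.comm.log_summary() API, the recorded operations can be summarized from a training loop:

    import deepspeed.comm as dist

    # model_engine and data_loader are hypothetical stand-ins for a
    # deepspeed.initialize()-wrapped model and its input pipeline.
    for step, batch in enumerate(data_loader):
        loss = model_engine(batch)
        model_engine.backward(loss)
        model_engine.step()
        if step % 100 == 0:
            dist.log_summary()  # prints a table of the communication ops logged so far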
10 changes: 3 additions & 7 deletions megatron/model/rwkv/v6/rwkv.py
@@ -404,21 +404,17 @@ def __init__(self, neox_args, init_method, layer_number):
         if neox_args.hidden_dropout > 0:
             self.drop1 = nn.Dropout(p=neox_args.hidden_dropout)
 
-        if layer_number == 0:
+        if layer_number % neox_args.pipe_parallel_size == 0:
             if not self.neox_args.rwkv_fla:
                 global wkv_cuda
                 """
                 Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not.
                 """
                 wkv_cuda = load(
                     name="wkv6",
-                    #sources=[
-                    #    "megatron/model/rwkv/v6/cuda/wkv6_op.cpp",
-                    #    f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",
-                    #],
                     sources=[
-                        "megatron/model/rwkv/v6/hip/wkv6_op.cpp",
-                        f"megatron/model/rwkv/v6/hip/wkv6_hip.hip",
+                        "megatron/model/rwkv/v6/cuda/wkv6_op.cpp",
+                        f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",
                     ],
                     verbose=True,
                     extra_cuda_cflags=[
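The rwkv.py change does two things: the kernel sources point back at the CUDA files instead of the HIP port, and the JIT build is now triggered on any layer whose index is a multiple of pipe_parallel_size rather than only on layer 0, so with the sizes configured above every pipeline stage ends up building the kernel, not just the stage that holds the first layer. For reference, here is a minimal sketch of the runtime build this code relies on, using torch.utils.cpp_extension.load; the head_size value and the compile flags are illustrative assumptions, since the real call takes its values from neox_args:

    from torch.utils.cpp_extension import load

    head_size = 64  # illustrative; NeoX feeds run-time values like this into the build

    wkv_cuda = load(
        name="wkv6",
        sources=[
            "megatron/model/rwkv/v6/cuda/wkv6_op.cpp",
            "megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",
        ],
        verbose=True,
        extra_cuda_cflags=[
            "-O3",
            f"-D_N_={head_size}",  # example of a run-time variable baked into the kernel at build time
        ],
    )

    # The returned module exposes the compiled ops that the RWKV layer calls.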
