Commit f736935

Move specialized gemm kernels and Triton allocator into NVIDIA backend; add iluvatar backend support
1 parent 14d1aa1

14 files changed, 5,421 additions and 2,425 deletions (two of the changed files are captured below)
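Per the commit message, two things happen at once: CUDA-specific pieces (the specialized gemm kernels, plus the Triton allocator hook removed from __init__.py below) move into the NVIDIA backend, and an iluvatar backend joins the vendor dispatch. As a rough sketch of that dispatch pattern, with module paths that are illustrative assumptions rather than this repo's actual layout:

    import importlib

    # Hypothetical vendor table; the real backend package names in flag_blas
    # may differ from these illustrative paths.
    _BACKENDS = {
        "nvidia": "flag_blas.runtime.backend._nvidia",
        "iluvatar": "flag_blas.runtime.backend._iluvatar",
    }


    def load_backend(vendor: str):
        # Loading vendor modules lazily keeps "import flag_blas" free of
        # device-specific side effects, which is what the __init__.py diff
        # below does for the Triton allocator.
        return importlib.import_module(_BACKENDS[vendor])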

benchmark/test_gemm_perf.py (4 additions, 4 deletions)
@@ -107,7 +107,7 @@ def gems_sgemm_wrapper(
     alpha_ptr,
     beta_ptr,
 ):
-    flag_blas.ops.sgemm(
+    flag_blas.sgemm(
         transa,
         transb,
         m,
@@ -197,7 +197,7 @@ def gems_hgemm_wrapper(
     alpha_ptr,
     beta_ptr,
 ):
-    flag_blas.ops.hgemm(
+    flag_blas.hgemm(
         transa,
         transb,
         m,
@@ -287,7 +287,7 @@ def gems_bfgemm_wrapper(
     alpha_ptr,
     beta_ptr,
 ):
-    flag_blas.ops.bfgemm(
+    flag_blas.bfgemm(
         transa,
         transb,
         m,
@@ -372,7 +372,7 @@ def gems_fp8gemm_wrapper(
     alpha_ptr,
     beta_ptr,
 ):
-    flag_blas.ops.fp8gemm(
+    flag_blas.fp8gemm(
         transa,
         transb,
         m,
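All four edits in this file are the same rename: the benchmark now reaches the gemm entry points from the package root rather than through the ops submodule. That works because __init__.py (next file) re-exports flag_blas.ops with a wildcard import, so both spellings should resolve to the same callable. A quick sanity check, assuming the package is installed and sgemm is exported by flag_blas.ops:

    import flag_blas
    import flag_blas.ops

    # The wildcard re-export in __init__.py binds the ops names at package
    # level, so the old and new spellings name the same function object.
    assert flag_blas.sgemm is flag_blas.ops.sgemm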

src/flag_blas/__init__.py (4 additions, 11 deletions)
@@ -2,20 +2,13 @@
 flag_blas - BLAS operations implemented with Triton
 """
 
-import torch
-import triton
-from packaging import version
-
-
-def _alloc_fn(size, alignment, stream):
-    return torch.empty(size, device="cuda", dtype=torch.int8)
+import warnings
 
-
-triton.set_allocator(_alloc_fn)
+import torch
 
 from flag_blas import runtime
-from flag_blas import testing
-from flag_blas.ops import *
+from flag_blas import testing  # noqa: F401
+from flag_blas.ops import *  # noqa: F401,F403
 from flag_blas.config import aten_patch_list, resolve_user_setting
 from flag_blas.runtime.register import Register
 
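This hunk deletes the module-level allocator registration (along with the now-unused triton and packaging imports), so importing flag_blas no longer touches CUDA unconditionally. Per the commit message the registration moves into the NVIDIA backend; a minimal sketch of the relocated hook, reusing the removed code verbatim but with the module placement assumed (this diff only shows the removal side):

    # Hypothetical NVIDIA-backend module; run when that backend initializes,
    # not at package import time as before.
    import torch
    import triton


    def _alloc_fn(size, alignment, stream):
        # Triton calls this hook for scratch buffers; it must return a tensor
        # of at least `size` bytes on the active CUDA device.
        return torch.empty(size, device="cuda", dtype=torch.int8)


    triton.set_allocator(_alloc_fn)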