diff --git a/mojo_opset/backends/ttx/kernels/npu/__init__.py b/mojo_opset/backends/ttx/kernels/npu/__init__.py index bd5c5108..bb487f87 100755 --- a/mojo_opset/backends/ttx/kernels/npu/__init__.py +++ b/mojo_opset/backends/ttx/kernels/npu/__init__.py @@ -49,6 +49,7 @@ from .swa import swa_paged_prefill_impl from .swiglu import swiglu_bwd_impl from .swiglu import swiglu_fwd_impl +from .int8_gemm import int8_gemm_dequant_impl, prepare_b_impl # triton-dist based comm kernels (requires triton_dist + shmem packages) allgather_gemm_impl = None @@ -131,4 +132,6 @@ "allgather_gemm_impl", "gemm_allreduce_impl", "gemm_reduce_scatter_impl", + "int8_gemm_dequant_impl", + "prepare_b_impl", ] diff --git a/mojo_opset/backends/ttx/kernels/npu/int8_gemm.py b/mojo_opset/backends/ttx/kernels/npu/int8_gemm.py index 88ae0e1a..81e35439 100644 --- a/mojo_opset/backends/ttx/kernels/npu/int8_gemm.py +++ b/mojo_opset/backends/ttx/kernels/npu/int8_gemm.py @@ -110,7 +110,7 @@ def _pad_to(x, mult): return ((x + mult - 1) // mult) * mult -def prepare_b(b: torch.Tensor) -> torch.Tensor: +def prepare_b_impl(b: torch.Tensor) -> torch.Tensor: """Transpose B to (N, K) row-major and pad to block boundaries. For inference: weight B is fixed, call once and reuse.