Commit a78ac40

desertfire authored and pytorchmergebot committed
[AOTI] Add _weight_int4pack_mm to the C shim fallback list (pytorch#151059)
Summary: As title

Pull Request resolved: pytorch#151059
Approved by: https://github.com/yushangdi
1 parent 12281f9 · commit a78ac40
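For context, here is a minimal sketch of the scenario this change enables: AOT-compiling a module that calls torch._weight_int4pack_mm. The quantization steps mirror the new test below; using torch._export.aot_compile as the AOTInductor entry point and the concrete sizes are illustrative assumptions, not part of this commit.

import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor

q_group, m, k, n = 32, 32, 64, 64  # illustrative sizes, matching the test below

class Int4Mm(torch.nn.Module):  # hypothetical example module
    def __init__(self, weight, scales_and_zeros):
        super().__init__()
        self.weight = weight
        self.scales_and_zeros = scales_and_zeros

    def forward(self, a):
        # The op this commit adds to AOTInductor's C shim fallback list.
        return torch._weight_int4pack_mm(
            a, self.weight, q_group, self.scales_and_zeros
        )

a = torch.rand(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.rand(k, n, device="cuda", dtype=torch.bfloat16)
# Group-quantize the weight to 4 bits, then pack it into the layout the op expects.
b_int32, scales_and_zeros = _group_quantize_tensor(b, n_bit=4, q_group_size=q_group)
b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)

# With this commit, the call lowers to the generated C shim fallback
# (aoti_torch_cuda__weight_int4pack_mm) instead of failing to compile.
so_path = torch._export.aot_compile(Int4Mm(b_int4pack, scales_and_zeros), (a,))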

File tree

3 files changed: +40 -0 lines changed

test/inductor/test_aot_inductor.py

Lines changed: 38 additions & 0 deletions
@@ -33,6 +33,7 @@
     skipCUDAIf,
 )
 from torch.testing._internal.common_quantization import (
+    _group_quantize_tensor,
     skip_if_no_torchvision,
     skipIfNoFBGEMM,
 )
@@ -42,6 +43,7 @@
     IS_FBCODE,
     IS_MACOS,
     IS_WINDOWS,
+    parametrize,
     skipIfRocm,
     skipIfXpu,
     TEST_WITH_ROCM,
@@ -4936,6 +4938,42 @@ def forward(self, x, y):
         )
         self.check_model(Model(), example_inputs)
 
+    @skipIfXpu(
+        msg="aten::convert_weight_to_int4pack is not currently implemented for XPU"
+    )
+    @parametrize("m", [32])
+    @parametrize("n", [64])
+    @parametrize("q_group", [32, 64])
+    @parametrize("num_groups", [1, 2])
+    def test__weight_int4pack_mm(self, m, n, q_group, num_groups):
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Model(torch.nn.Module):
+            def __init__(self, weight, scale_and_zeros) -> None:
+                super().__init__()
+                self.weight = weight
+                self.scale_and_zeros = scale_and_zeros
+
+            def forward(self, a):
+                return torch._weight_int4pack_mm(
+                    a, self.weight, q_group, self.scale_and_zeros
+                )
+
+        def convert_weight_to_int4pack(b):
+            b_int32, b_scales_and_zeros = _group_quantize_tensor(
+                b, n_bit=4, q_group_size=q_group
+            )
+            b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)
+            return b_int4pack, b_scales_and_zeros
+
+        k = q_group * num_groups
+        a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
+        b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
+        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b)
+        model = Model(b_int4pack, b_scales_and_zeros_f32)
+        self.check_model(model, (a,))
+
     def test_assert_tensor_meta(self):
         class Module(torch.nn.Module):
             def forward(self, x):

torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTe
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTensorHandle input_gates, AtenTensorHandle hidden_gates, AtenTensorHandle cx, AtenTensorHandle* input_bias, AtenTensorHandle* hidden_bias, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

torchgen/aoti/fallback_ops.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@
     "aten._thnn_fused_lstm_cell.default",
     "aten._to_sparse.default",
     "aten._trilinear.default",
+    "aten._weight_int4pack_mm.default",
     "aten._weight_int8pack_mm.default",
     "aten.adaptive_max_pool2d_backward.default",
     "aten.adaptive_max_pool2d.default",
