Xpu ut/dtypes #2797

Draft · wants to merge 5 commits into main
57 changes: 29 additions & 28 deletions test/dtypes/test_affine_quantized.py
@@ -39,23 +39,26 @@
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.testing.utils import skip_if_no_cuda, skip_if_no_gemlite, skip_if_rocm
from torchao.testing.utils import skip_if_no_gemlite, skip_if_rocm, skip_if_xpu
from torchao.utils import (
check_cpu_version,
check_xpu_version,
is_fbcode,
is_ROCM,
is_sm_at_least_89,
is_sm_at_least_90,
auto_detect_device,
)

is_cusparselt_available = (
hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available()
)

_DEVICE = auto_detect_device()


def get_quantization_functions(
do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
do_sparse: bool, do_int4: bool, device: str = _DEVICE, int4_zp_int: bool = False
):
base_functions = [
int8_weight_only(),
@@ -113,9 +116,9 @@ class TestAffineQuantized(TestCase):
["xpu"] if torch.xpu.is_available() else []
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

def test_tensor_core_layout_transpose(self):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
@@ -177,7 +180,7 @@ def _apply(module, config_or_subclass_inserter):
ql = _apply(linear, apply_quant)
ql.to(device)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

def test_register_new_dispatch(self):
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -214,10 +217,10 @@ def apply_uint6_weight_only_quant(linear):
)
return linear

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
apply_uint6_weight_only_quant(linear)

example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
with self.assertRaisesRegex(
AssertionError, "dispatching to my impl for uint6 weight only quant"
):
@@ -240,13 +243,14 @@ def test_print_quantized_module(self):
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

@common_utils.parametrize(
"apply_quant", get_quantization_functions(False, True, "cuda", False)
"apply_quant", get_quantization_functions(False, True, _DEVICE, False)
)
@skip_if_xpu("XPU enablement is in Progress")
def test_test_copy__apply(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)

if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
@@ -257,20 +261,20 @@ def test_test_copy__apply(self, apply_quant):
ql = apply_quant(linear)
ql2 = apply_quant(linear2)

example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
output = ql(example_input)
ql2.weight.copy_(ql.weight)
ql2.bias = ql.bias
output2 = ql2(example_input)
self.assertEqual(output, output2)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

@common_utils.parametrize(
"apply_quant", get_quantization_functions(False, True, "cuda", False)
"apply_quant", get_quantization_functions(False, True, _DEVICE, False)
)
def test_copy__mismatch_metadata(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device=_DEVICE)

if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
@@ -344,10 +348,10 @@ def test_alias(self, device, dtype):
quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
_ = dummy.weight[...]

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.bfloat16])
@skip_if_no_cuda()
@skip_if_rocm("ROCm enablement in progress")
@skip_if_xpu("xpu enablement in progress")
def test_slice_int4wo(self, device, dtype):
# in_feature not divisible by 1024
# out_feature not divisible by 8
@@ -358,9 +362,7 @@ def test_slice_int4wo(self, device, dtype):
_ = dummy.weight.narrow(0, 0, 64)
_ = dummy.weight.narrow(1, 0, 128)

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
@skip_if_no_cuda()
@skip_if_no_gemlite()
def test_slice_gemlite(self, device, dtype):
# in_feature not divisible by 1024
@@ -441,7 +443,7 @@ def dequant(input_layer, in_features, orig_shape):
)
self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_matmul(self, device, dtype):
x = torch.randn(53, 2048)
@@ -458,14 +460,14 @@ def test_matmul(self, device, dtype):
# make sure it runs
torch.matmul(x, w.t())

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.bfloat16])
@skip_if_no_cuda()
@skip_if_rocm("ROCm enablement in progress")
@skip_if_xpu("XPU enablement in progress")
def test_slice_and_copy_int4wo(self, device, dtype):
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE)
)
quantize_(l, Int4WeightOnlyConfig())
param = l.weight
@@ -482,7 +484,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
assert param.data.dequantize()[0][0] == 0

# dummy_l has random input (shouldn't be 0)
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
dummy_l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
quantize_(dummy_l, Int4WeightOnlyConfig())
quantized = dummy_l.weight
quantized = quantized.narrow(0, 0, 512)
@@ -492,9 +494,8 @@ def test_slice_and_copy_int4wo(self, device, dtype):
# making sure param.data is updated
assert param.data.dequantize()[0][0] != 0

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.bfloat16])
@skip_if_no_cuda()
@skip_if_rocm("ROCm enablement in progress")
def test_mm_int4wo(self, device, dtype):
weight = torch.randn(512, 1024).to(device).to(dtype)
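
A note on the new _DEVICE plumbing: auto_detect_device() is imported from torchao.utils, but its body is not part of this file's diff. A minimal sketch of what such a helper presumably does (prefer CUDA, then XPU, otherwise fall back to CPU); the actual implementation added elsewhere in this PR may differ:

import torch

def auto_detect_device() -> str:
    # Illustrative sketch only; not the real torchao.utils.auto_detect_device.
    # Prefer CUDA, then XPU, and fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

# The tests then build modules and inputs on this device, e.g.:
# linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=auto_detect_device())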
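
Likewise, skip_if_xpu is imported from torchao.testing.utils and is not shown in this diff. Assuming it mirrors the existing skip_if_rocm decorator, a plausible sketch looks like this:

import functools
import unittest

import torch

def skip_if_xpu(message="XPU enablement in progress"):
    # Illustrative sketch only; the real torchao.testing.utils.skip_if_xpu may differ.
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            # Skip the test when an XPU device is present, since the code path
            # under test is not yet enabled there.
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                raise unittest.SkipTest(message)
            return test_fn(*args, **kwargs)
        return wrapper
    return decorator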