Skip to content

Commit 857b83d

Browse files
fix gemv test on avx512bf16 cpu (#1956)
1 parent c59334e commit 857b83d

2 files changed

Lines changed: 9 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ markers = [
101101
"deprecated: mark test as covering a deprecated feature",
102102
"slow: mark test as slow",
103103
]
104+
testpaths = ["tests"]
104105

105106
[tool.ruff]
106107
src = [

tests/test_functional.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -807,14 +807,20 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind):
807807
compress_statistics=double_quant,
808808
quant_storage=quant_storage,
809809
)
810+
811+
# dequant+F.linear reference path.
812+
C1 = torch.nn.functional.linear(A, F.dequantize_4bit(qB, state).to(dtype))
813+
814+
# original matmul reference path.
810815
C3 = torch.matmul(A, B.t())
816+
811817
# CPU requires convert weight packed for gemv
812818
if device == "cpu" and F.has_avx512bf16():
813819
qB, state = F._convert_weight_packed_for_cpu(qB, state)
814820
qB = qB.t()
821+
822+
# GEMV test
815823
C2 = F.gemv_4bit(A, qB.t(), state=state)
816-
# dequant+F.linear reference path
817-
C1 = torch.nn.functional.linear(A, F.dequantize_4bit(qB, state).to(dtype))
818824

819825
err1 = (C1 - C2).abs().float()
820826
err2 = (C3 - C2).abs().float()

0 commit comments

Comments
 (0)