
Hqq support #21


Closed
wants to merge 17 commits into from
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_machete.py
@@ -156,7 +156,8 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
size_m=a.shape[0],
size_n=w_ref.shape[1],
size_k=w_ref.shape[0],
- is_k_full=True))))
+ is_k_full=True,
+ is_zp_float=False))))

# machete
timers.append(
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
@@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
297 changes: 220 additions & 77 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu

Large diffs are not rendered by default.
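
The new is_zp_float flag threaded through this PR selects a kernel path in which the group-wise zero-points are floating-point tensors laid out like the scales (as HQQ produces), instead of the packed integer zero-points used by the existing AWQ path. A minimal Python sketch of the dequantization the flag toggles, using hypothetical tensor names and mirroring the reference computation in the new test below:

import torch

def dequant_int_zp(q, scale, zp_int):
    # Existing AWQ-style path: integer zero-point packed next to the weights.
    return (q.float() - zp_int.float()) * scale

def dequant_float_zp(q, scale, zp_float):
    # HQQ-style path (is_zp_float=True): the zero-point is itself a float
    # tensor, so it is stored and permuted like the scales (see the new
    # FragZPF type in marlin_dtypes.cuh) rather than packed as integers.
    return (q.float() - zp_float) * scale

Both paths compute (q - zp) * s; the flag only changes how zp is stored and loaded.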

2 changes: 2 additions & 0 deletions csrc/quantization/gptq_marlin/marlin_dtypes.cuh
@@ -24,6 +24,7 @@ class ScalarType<half> {
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;
using FragZPF = Vec<half2, 1>;

static __device__ float inline num2float(const half x) {
return __half2float(x);
@@ -53,6 +54,7 @@ class ScalarType<nv_bfloat16> {
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
using FragZPF = Vec<nv_bfloat162, 1>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
@@ -197,7 +197,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file

// gptq_marlin repack from GPTQ.
86 changes: 85 additions & 1 deletion tests/kernels/test_marlin_gemm.py
@@ -29,6 +29,7 @@
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import scalar_types

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@@ -226,7 +227,7 @@ def test_gptq_marlin_gemm(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
- a_input.shape[1], is_k_full, False, use_fp32_reduce),
+ a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)

output = ops.gptq_marlin_gemm(
@@ -244,6 +245,7 @@ def test_gptq_marlin_gemm(
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

@@ -441,6 +443,7 @@ def test_awq_marlin_gemm(
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

@@ -451,6 +454,87 @@ def test_awq_marlin_gemm(
assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", [64])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
k_chunk,
n_chunk,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

quant_type = scalar_types.uint4

a_input = rand_data((size_m, size_k))
dev = a_input.device

b_weight = torch.randint(0,
10, (size_n, size_k),
dtype=torch.uint8,
device=dev)
scale = rand_data((size_n, size_k // group_size))
zero = rand_data((size_n, size_k // group_size))

gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
4).to(dev)
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
group_size).to(dev)

g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

output = ops.gptq_marlin_gemm(
a_input,
marlin_w_q,
marlin_s,
marlin_zp,
g_idx,
g_idx_sort_indices,
workspace.scratch,
quant_type,
a_input.shape[0],
b_weight.shape[0],
a_input.shape[1],
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=True,
)

b_flat = b_weight.reshape(-1, group_size)
zp_flat = zero.reshape(-1, 1)
s_flat = scale.reshape(-1, 1)
dequant = (b_flat - zp_flat) * s_flat

output_ref = torch.matmul(a_input,
dequant.reshape(b_weight.shape).transpose(1, 0))

torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)

assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("qqq"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
8 changes: 5 additions & 3 deletions vllm/_custom_ops.py
@@ -335,7 +335,8 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
- use_fp32_reduce: bool = False) -> torch.Tensor:
+ use_fp32_reduce: bool = False,
+ is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@register_fake("_C::ggml_dequantize")
@@ -578,11 +579,12 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
- use_fp32_reduce: bool = False) -> torch.Tensor:
+ use_fp32_reduce: bool = False,
+ is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
- has_zp, use_fp32_reduce)
+ has_zp, use_fp32_reduce, is_zp_float)


# fp8 marlin
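
From the Python side, the HQQ path is reached by setting both has_zp and the new is_zp_float. The sketch below mirrors the call in test_hqq_marlin_gemm; the repack/permute tensor preparation is elided and the variable names come from that test:

# Shapes/variables as prepared in test_hqq_marlin_gemm (setup elided).
output = ops.gptq_marlin_gemm(
    a_input, marlin_w_q, marlin_s, marlin_zp,
    g_idx, g_idx_sort_indices, workspace.scratch,
    scalar_types.uint4,
    size_m, size_n, size_k,
    is_k_full=True,
    has_zp=True,             # zero-points are present...
    use_fp32_reduce=use_fp32_reduce,
    is_zp_float=True,        # ...and are float tensors (HQQ), not packed ints
)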
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -21,6 +21,7 @@
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
@@ -48,6 +49,7 @@
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,