diff --git a/README.md b/README.md
index e760c53b..99bb3786 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,7 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
   - Your ROCm version must be compatible with Triton.
 - Intel CPU and Intel GPU:
   - Your torch and intel_extension_for_pytorch package version should at least 2.4 for optimized performance.
+  - Alternatively, you can rely on Triton kernels for the GPU; in that case you need to install [intel-xpu-backend-for-triton](https://github.com/intel/intel-xpu-backend-for-triton) along with compatible torch and transformers versions. The easiest way is to use the [pre-built wheels](https://github.com/intel/intel-xpu-backend-for-triton?tab=readme-ov-file#install-pytorch-and-triton-from-nightly-wheels).
 
 ### Install from PyPi
 
diff --git a/awq/models/base.py b/awq/models/base.py
index f879c921..e63e903e 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -29,7 +29,7 @@
     exclude_layers_to_not_quantize,
     try_import,
 )
-from awq.utils.utils import get_best_device, ipex_available
+from awq.utils.utils import get_best_device, ipex_available, triton_available
 from transformers import (
     AutoConfig,
     PreTrainedModel,
@@ -499,7 +499,8 @@ def from_quantized(
         )
 
         best_device = get_best_device()
-        use_ipex = use_ipex or best_device in ["cpu", "xpu:0"]
+        if best_device == "cpu" or (best_device == "xpu:0" and not triton_available):
+            use_ipex = True
         if use_ipex and not ipex_available:
             raise ImportError(
                 "Please install intel_extension_for_pytorch with "
diff --git a/awq/modules/linear/gemm.py b/awq/modules/linear/gemm.py
index 7ee89cc8..4d014992 100644
--- a/awq/modules/linear/gemm.py
+++ b/awq/modules/linear/gemm.py
@@ -14,9 +14,8 @@
 try:
     from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton
 
-    # covers both CUDA and ROCm
-    if torch.cuda.is_available():
-        TRITON_AVAILABLE = True
+    # covers CUDA, ROCm and XPU. If we can import triton, then we can use it.
+    TRITON_AVAILABLE = True
 except ImportError:
     TRITON_AVAILABLE = False
 
diff --git a/awq/modules/triton/gemm.py b/awq/modules/triton/gemm.py
index 5657ca2c..cea82dff 100644
--- a/awq/modules/triton/gemm.py
+++ b/awq/modules/triton/gemm.py
@@ -20,6 +20,12 @@
 AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
 
+def get_same_device_cm(t):
+    if t.device.type == 'xpu':
+        return torch.xpu.device(t.device.index)
+    else:
+        return torch.cuda.device(t.device.index)
+
 
 @triton.jit
 def awq_dequantize_kernel(
@@ -280,7 +286,7 @@ def awq_dequantize_triton(
         triton.cdiv(X, META["BLOCK_SIZE_X"]),
         triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
     )
-    with torch.cuda.device(qweight.device.index):
+    with get_same_device_cm(qweight):
         awq_dequantize_kernel[grid](
             qweight,
             scales,
@@ -333,7 +339,7 @@ def awq_gemm_triton(
 
     # A = input, B = qweight, C = result
     # A = M x K, B = K x N, C = M x N
-    with torch.cuda.device(qweight.device.index):
+    with get_same_device_cm(qweight):
         awq_gemm_kernel[grid](
             input,
             qweight,
diff --git a/awq/utils/utils.py b/awq/utils/utils.py
index 738019a3..1dddfeb2 100644
--- a/awq/utils/utils.py
+++ b/awq/utils/utils.py
@@ -5,6 +5,12 @@
 
 ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None
 
+try:
+    import triton
+    triton_available = True
+except ImportError:
+    triton_available = False
+
 
 def get_module_by_name_suffix(model, module_name: str):
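
Taken together, these changes let the Triton kernels run on Intel XPU: Triton availability is now probed by a plain `import` instead of `torch.cuda.is_available()`, and each kernel launch enters the device context manager that matches the tensor's device type. Below is a minimal standalone sketch of that dispatch logic, using the `triton_available` flag and `get_same_device_cm` helper introduced by the patch; the `hasattr(torch, "xpu")` guard and the `__main__` demo are illustrative assumptions, not part of the diff.

```python
import torch

# Probe Triton by import: this covers CUDA, ROCm and XPU builds alike.
try:
    import triton  # noqa: F401
    triton_available = True
except ImportError:
    triton_available = False


def get_same_device_cm(t: torch.Tensor):
    """Return the device context manager matching the tensor's device type."""
    if t.device.type == "xpu":
        # Intel GPU: torch.xpu mirrors the torch.cuda device-context API (torch >= 2.4).
        return torch.xpu.device(t.device.index)
    # CUDA and ROCm both go through torch.cuda.
    return torch.cuda.device(t.device.index)


if __name__ == "__main__":
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        dev = "xpu"
    elif torch.cuda.is_available():
        dev = "cuda"
    else:
        dev = "cpu"
    print("triton importable:", triton_available)
    if dev != "cpu":
        qweight = torch.empty(1, device=dev)
        with get_same_device_cm(qweight):
            # A Triton kernel launched here targets qweight's device.
            print("kernel launches target", qweight.device)
```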