Enable triton on XPU devices (#695)
Egor-Krivov authored Jan 20, 2025
1 parent 9affc3e commit d2537f1
Showing 5 changed files with 20 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -49,6 +49,7 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
- Your ROCm version must be compatible with Triton.
- Intel CPU and Intel GPU:
  - Your torch and intel_extension_for_pytorch package versions should be at least 2.4 for optimized performance.
+  - Alternatively, you can rely on Triton kernels for the GPU. In that case, you'll need to install [intel-xpu-backend-for-triton](https://github.com/intel/intel-xpu-backend-for-triton) along with compatible torch and transformers packages. The easiest way is to use the [pre-built wheels](https://github.com/intel/intel-xpu-backend-for-triton?tab=readme-ov-file#install-pytorch-and-triton-from-nightly-wheels).

### Install from PyPi

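For context on the new Intel GPU note above, a minimal environment check (not part of this commit) that confirms Triton imports and that PyTorch sees an XPU device; `torch.xpu.is_available()` assumes an XPU-enabled PyTorch 2.4+ build:

import torch

try:
    import triton  # provided by intel-xpu-backend-for-triton on Intel GPUs
    print("triton version:", triton.__version__)
except ImportError:
    print("triton is not installed")

# the hasattr guard keeps this working on builds without XPU support
print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())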
5 changes: 3 additions & 2 deletions awq/models/base.py
@@ -29,7 +29,7 @@
exclude_layers_to_not_quantize,
try_import,
)
-from awq.utils.utils import get_best_device, ipex_available
+from awq.utils.utils import get_best_device, ipex_available, triton_available
from transformers import (
AutoConfig,
PreTrainedModel,
@@ -499,7 +499,8 @@ def from_quantized(
)

best_device = get_best_device()
-use_ipex = use_ipex or best_device in ["cpu", "xpu:0"]
+if best_device == "cpu" or (best_device == "xpu:0" and not triton_available):
+    use_ipex = True
if use_ipex and not ipex_available:
raise ImportError(
"Please install intel_extension_for_pytorch with "
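Distilled, the new fallback in `from_quantized` behaves like the sketch below (the helper name is illustrative, not from the repository): IPEX is forced on CPU, and on XPU only when no Triton backend is installed.

def should_use_ipex(best_device: str, use_ipex: bool, triton_available: bool) -> bool:
    # CPU always needs IPEX; XPU needs it only as a fallback when Triton is missing
    if best_device == "cpu" or (best_device == "xpu:0" and not triton_available):
        return True
    return use_ipex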
5 changes: 2 additions & 3 deletions awq/modules/linear/gemm.py
@@ -14,9 +14,8 @@
try:
from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton

-# covers both CUDA and ROCm
-if torch.cuda.is_available():
-    TRITON_AVAILABLE = True
+# covers CUDA, ROCm and XPU. If we can import triton, then we can use it.
+TRITON_AVAILABLE = True

except ImportError:
TRITON_AVAILABLE = False
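The old guard tied Triton availability to CUDA, which is why it never fired on Intel GPUs. A quick illustration (assuming an XPU-only machine with an XPU-enabled PyTorch build):

import torch

print(torch.cuda.is_available())  # False: no CUDA or ROCm device, so the old check
                                  # left TRITON_AVAILABLE at False
print(torch.xpu.is_available())   # True: the XPU Triton backend can still be used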
10 changes: 8 additions & 2 deletions awq/modules/triton/gemm.py
@@ -20,6 +20,12 @@

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

+def get_same_device_cm(t):
+    if t.device.type == 'xpu':
+        return torch.xpu.device(t.device.index)
+    else:
+        return torch.cuda.device(t.device.index)


@triton.jit
def awq_dequantize_kernel(
@@ -280,7 +286,7 @@ def awq_dequantize_triton(
triton.cdiv(X, META["BLOCK_SIZE_X"]),
triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
)
-with torch.cuda.device(qweight.device.index):
+with get_same_device_cm(qweight):
awq_dequantize_kernel[grid](
qweight,
scales,
@@ -333,7 +339,7 @@ def awq_gemm_triton(

# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = M x N
-with torch.cuda.device(qweight.device.index):
+with get_same_device_cm(qweight):
awq_gemm_kernel[grid](
input,
qweight,
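A usage sketch for the new helper (illustrative; assumes Triton and an XPU-enabled PyTorch are installed so the module imports): the returned context manager pins kernel launches to the tensor's own device, just as torch.cuda.device did on CUDA and ROCm.

import torch
from awq.modules.triton.gemm import get_same_device_cm

if hasattr(torch, "xpu") and torch.xpu.is_available():
    qweight = torch.zeros(8, 8, dtype=torch.int32, device="xpu")
    with get_same_device_cm(qweight):
        # any Triton kernel launched here targets qweight's XPU device
        out = torch.empty(8, 64, dtype=torch.float16, device=qweight.device)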
6 changes: 6 additions & 0 deletions awq/utils/utils.py
@@ -5,6 +5,12 @@


ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None
+try:
+    import triton as tl
+    triton_available = True
+except ImportError:
+    triton_available = False



def get_module_by_name_suffix(model, module_name: str):
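An illustrative guard showing how callers can combine the two module-level flags (this check is not in the repository; from_quantized above applies its own selection logic):

from awq.utils.utils import ipex_available, triton_available

if not (triton_available or ipex_available):
    raise ImportError("Install triton or intel_extension_for_pytorch to run AWQ kernels")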
