.. _cpp-custom-ops-tutorial-sycl:

Custom SYCL Operators
=====================

.. grid:: 2

   .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
      :class-card: card-prerequisites

      * How to integrate custom operators written in SYCL with PyTorch

   .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
      :class-card: card-prerequisites

      * PyTorch 2.8 or later
      * Basic understanding of SYCL programming

.. note::

   ``SYCL`` serves as the backend programming language for Intel GPUs (device label ``xpu``). For configuration details, see
   `Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html>`_. The Intel Compiler, which comes bundled with Intel Deep Learning Essentials, handles ``SYCL`` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.
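   On most systems this is done by sourcing the oneAPI environment script; the exact path depends on your installation, but a typical invocation is ``source /opt/intel/oneapi/setvars.sh``.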

PyTorch offers a large library of operators that work on Tensors (for example, ``torch.add`` and ``torch.sum``).
However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
blessed path to authoring a custom operator written in SYCL. Tutorials for C++ and CUDA operators are available in :ref:`cpp-custom-ops-tutorial`.

Follow the structure below to create a custom SYCL operator:

.. code-block:: text

   sycl_example/
   ├── setup.py
   ├── sycl_extension
   │   ├── __init__.py
   │   ├── muladd.sycl
   │   └── ops.py
   └── test_sycl_extension.py

Setting up the Build System
---------------------------

If you need to compile **SYCL** code (for example, ``.sycl`` files), use `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension>`_.
The setup process is very similar to C++/CUDA, except the compilation arguments need to be adjusted for SYCL.

Using ``SyclExtension``, the ``setup.py`` looks like this:

.. code-block:: python

   import os
   import glob

   import torch
   from setuptools import find_packages, setup
   from torch.utils.cpp_extension import BuildExtension, SyclExtension

   library_name = "sycl_extension"
   py_limited_api = True
   extra_compile_args = {
       "cxx": ["-O3",
               "-fdiagnostics-color=always",
               "-DPy_LIMITED_API=0x03090000"],
       "sycl": ["-O3"],
   }

   assert torch.xpu.is_available(), "XPU is not available, please check your environment"

   # Collect the SYCL source files.
   this_dir = os.path.dirname(os.path.abspath(__file__))
   extensions_dir = os.path.join(this_dir, library_name)
   sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))

   # Construct the extension.
   ext_modules = [
       SyclExtension(
           f"{library_name}._C",
           sources,
           extra_compile_args=extra_compile_args,
           py_limited_api=py_limited_api,
       )
   ]

   setup(
       name=library_name,
       packages=find_packages(),
       ext_modules=ext_modules,
       install_requires=["torch"],
       description="Simple example of a PyTorch SYCL extension",
       cmdclass={"build_ext": BuildExtension},
       options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
   )
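
With this ``setup.py`` in place, you can build and install the extension from the project root, for example with ``pip install --no-build-isolation -e .`` (build isolation is disabled so the build can see the ``torch`` you already have installed).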

Defining the custom op and adding backend implementations
---------------------------------------------------------

First, let's write a SYCL function that computes ``mymuladd``.

In order to use it from PyTorch's Python frontend, we need to register it
as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
bind the operator to Python.

The SYCL implementation of ``mymuladd`` is then registered in a separate
``TORCH_LIBRARY_IMPL`` block:

.. code-block:: cpp

   #include <c10/xpu/XPUStream.h>
   #include <sycl/sycl.hpp>
   #include <ATen/Operators.h>
   #include <torch/all.h>
   #include <torch/library.h>

   namespace sycl_extension {

   // MulAdd kernel: result = a * b + c
   static void muladd_kernel(
       int numel, const float* a, const float* b, float c, float* result,
       const sycl::nd_item<1>& item) {
     int idx = item.get_global_id(0);
     if (idx < numel) {
       result[idx] = a[idx] * b[idx] + c;
     }
   }

   class MulAddKernelFunctor {
   public:
     MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
         : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}

     void operator()(const sycl::nd_item<1>& item) const {
       muladd_kernel(numel, a, b, c, result, item);
     }

   private:
     int numel;
     const float* a;
     const float* b;
     float c;
     float* result;
   };

   at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
     TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
     TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
     TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
     TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");

     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = at::empty_like(a_contig);

     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* res_ptr = result.data_ptr<float>();
     int numel = a_contig.numel();

     // Launch enough work-groups of `threads` items each to cover all elements.
     sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
     constexpr int threads = 256;
     int blocks = (numel + threads - 1) / threads;

     queue.submit([&](sycl::handler& cgh) {
       cgh.parallel_for<MulAddKernelFunctor>(
           sycl::nd_range<1>(blocks * threads, threads),
           MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr));
     });

     return result;
   }

   // Defines the operator schema
   TORCH_LIBRARY(sycl_extension, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   }

   // ==================================================
   // Register the SYCL implementation to the Torch library
   // ==================================================
   TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
     m.impl("mymuladd", &mymuladd_xpu);
   }

   } // namespace sycl_extension
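
Note that ``queue.submit`` enqueues the kernel asynchronously and we do not ``wait()`` on it: the returned ``result`` tensor owns the output buffer, and later PyTorch operations issued on the same XPU stream are ordered after this kernel, in the same way PyTorch orders CUDA kernels on a stream.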

Create a Python Interface
-------------------------

Create a Python interface for our operator in the ``sycl_extension/ops.py`` file:

.. code-block:: python

   import torch
   from torch import Tensor

   __all__ = ["mymuladd"]

   def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
       """Performs a * b + c in an efficient fused kernel"""
       return torch.ops.sycl_extension.mymuladd.default(a, b, c)
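
Once the package is built and importable (see the next section), the wrapper can be called like any other PyTorch function. A minimal sketch, assuming an XPU device is available:

.. code-block:: python

   import torch
   import sycl_extension

   a = torch.randn(8, device="xpu")
   b = torch.randn(8, device="xpu")

   # Computes a * b + 1.0 with the custom SYCL kernel.
   out = sycl_extension.ops.mymuladd(a, b, 1.0)
   print(out)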

Initialize Package
------------------

Create the ``sycl_extension/__init__.py`` file to make the package importable:

.. code-block:: python

   import ctypes
   from pathlib import Path

   import torch

   current_dir = Path(__file__).parent.parent
   build_dir = current_dir / "build"
   so_files = list(build_dir.glob("**/*.so"))

   assert len(so_files) == 1, f"Expected exactly one .so file, found {len(so_files)}"

   with torch._ops.dl_open_guard():
       loaded_lib = ctypes.CDLL(str(so_files[0]))

   from . import ops

   __all__ = [
       "loaded_lib",
       "ops",
   ]
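
Importing the package loads the compiled library, which runs the ``TORCH_LIBRARY`` registration from the C++ code. One way to confirm that registration succeeded is to look the operator up through ``torch.ops`` (the ``_schema`` attribute used below is a private detail, shown only for illustration):

.. code-block:: python

   import torch
   import sycl_extension

   # Resolves only if the operator was registered successfully.
   op = torch.ops.sycl_extension.mymuladd.default
   # Expected: sycl_extension::mymuladd(Tensor a, Tensor b, float c) -> Tensor
   print(op._schema)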

Testing the SYCL extension operator
-----------------------------------

Use the following test to verify that the operator works correctly:

.. code-block:: python

   import unittest

   import torch
   from torch.testing._internal.common_utils import TestCase

   import sycl_extension

   def reference_muladd(a, b, c):
       return a * b + c

   class TestMyMulAdd(TestCase):
       def sample_inputs(self, device, *, requires_grad=False):
           def make_tensor(*size):
               return torch.randn(size, device=device, requires_grad=requires_grad)

           def make_nondiff_tensor(*size):
               return torch.randn(size, device=device, requires_grad=False)

           return [
               [make_tensor(3), make_tensor(3), 1],
               [make_tensor(20), make_tensor(20), 3.14],
               [make_tensor(20), make_nondiff_tensor(20), -123],
               [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
           ]

       def _test_correctness(self, device):
           samples = self.sample_inputs(device)
           for args in samples:
               result = sycl_extension.ops.mymuladd(*args)
               expected = reference_muladd(*args)
               torch.testing.assert_close(result, expected)

       @unittest.skipIf(not torch.xpu.is_available(), "requires Intel GPU")
       def test_correctness_xpu(self):
           self._test_correctness("xpu")

   if __name__ == "__main__":
       unittest.main()

This test checks the correctness of the custom operator by comparing its output against a reference implementation. Run it with ``python test_sycl_extension.py``.

Conclusion
----------

In this tutorial, we demonstrated how to implement and compile custom SYCL operators for PyTorch, using the inference operation ``muladd`` as an example. For adding backward support or enabling ``torch.compile`` compatibility, please refer to :ref:`cpp-custom-ops-tutorial`.