
Commit c6fc827

User-facing Python API
This commit reimplements the user-facing Python API and addresses numerous issues:

1. significant overhead on the Python side, due in particular to caching with Python string keys based on input tensor sizes;
2. significant overhead on the tensor allocation side, because of multiple transformations on the ATen to TensorInfo path plus range inference every time;
3. implicit behaviors everywhere, which gave a distorted view of how to use TC properly (see issue #196); this is also contrary to PEP-20;
4. one cache file per TC and size combination, which makes it very cumbersome to use caching to files;
5. every Python-side API call being implemented in terms of `TCFunction.apply(self.cu, tc_info, kwargs, *inputs)`, which puts the PyTorch Python autograd abstraction on the critical path to everything; this is neither portable nor future-proof, since with PyTorch 0.5 we want to integrate TC at the C++ level, after measuring 100+us just to get into a torch.backward;
6. benchmark overhead of the abstraction, which was significantly more expensive than the CUDA kernels themselves in low-latency mode;
7. the `tc.define` abstraction only allowing one TC function for `forward` and one TC function for `backward`, with a very clunky interface that needs a function to map tensors to the inputs of `backward`.

After design discussions with @apaszke it quickly became clear that using Python language features was a significantly better way to go. This rewrite also kills overhead everywhere and removes as much implicit behavior as possible.

A little implicit behavior remains in how MappingOptions are generated in TC class function calls on the first compilation (i.e. behavior similar to cudnn benchmark, except here we JIT-compile). To make that behavior as explicit as possible, we require the user to pass a generating function which returns a MappingOptions object. We provide 3 such functions:

1. always return naive options, for debug mode;
2. load from cache or fail, for use with a compilation cache file;
3. autotune backed by a cache, with the same API as the `tc.autotune` function.

In the future we can add more behaviors depending on user feedback. For now it seems safer to let the user pick whatever behavior they prefer.
1 parent 38e8429 commit c6fc827

22 files changed (+1772, -1517 lines)
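As an illustration of the explicit style described in the commit message, here is a minimal sketch of defining and running a TC with a MappingOptions-generating function. It only relies on calls that appear in this commit's updated docs (`tc.define`, `tc.make_autotuned_options_factory`); the matmul TC, the tensor sizes, and the use of `tc.TunerConfig()` with default settings are illustrative assumptions, not code from this commit.

    import torch
    import tensor_comprehensions as tc

    # A TC string can hold several definitions; tc.define returns an object
    # exposing one callable per definition.
    mm = """
    def matmul(float(M, K) A, float(K, N) B) -> (C) {
        C(m, n) +=! A(m, r_k) * B(r_k, n)
    }
    """

    # The MappingOptions-generating function is passed explicitly: here we
    # autotune on the first compilation for a given size and reuse the tuned
    # options afterwards. tc.TunerConfig() with defaults is an assumption.
    T = tc.define(
        mm,
        tc.make_autotuned_options_factory(
            starting_options='naive',
            tuner_config=tc.TunerConfig()))

    A = torch.randn(128, 256, device='cuda')
    B = torch.randn(256, 512, device='cuda')

    C = T.matmul(A, B)  # first call JIT-compiles (and tunes) for these sizes
    C = T.matmul(A, B)  # later calls with the same sizes reuse the compiled kernel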

.jenkins/build.sh

Lines changed: 1 addition & 2 deletions
@@ -67,9 +67,8 @@ conda install -y -c nicolasvasilache caffe2
 WITH_CAFFE2=ON CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda CLANG_PREFIX=$(${CONDA_PREFIX}/bin/llvm-config --prefix) BUILD_TYPE=Release ./build.sh

 python setup.py install
-./test_python/run_test.sh

-for f in $(find ./python/examples -name "*.py"); do
+for f in $(find ./python/ -name "*.py"); do
   python $f -v
 done

docs/source/coding_conventions.rst

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ Filter non-rectangular regions with data-dependencies
 -----------------------------------------------------

 TC semantics are restricted to (hyper-)rectangular iteration spaces.
-This is a hard requirement to ensure range inference is non-ambiguous (see inference_).
+This is a hard requirement to ensure range inference is non-ambiguous (see :ref:`inference`).
 To simulate non-rectangular iteration spaces, one can use the following:

 .. code::

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
 # ones.
 extensions = [
     'sphinx.ext.autodoc',
+    'sphinx_autodoc_typehints',
     'sphinx.ext.doctest',
     'sphinx.ext.intersphinx',
     'sphinx.ext.todo',
Lines changed: 48 additions & 158 deletions
@@ -1,161 +1,51 @@
 Autograd with TC
 ================

-We provide the TC integration with PyTorch `autograd` so that it is easy to write
-a training layer with TC and be able to run backwards as well if the layer is part
-of a network. We do not support double backwards right now. In order to write a
-training layer with TC, you need to follow the steps below:
-
-1. Define your TC language that has two definitions: one for the forward layer and the other for the backward layer and pass it to :code:`tc.define` call. In addition, also pass :code:`training=True` and the name of the backward TC :code:`backward`.
-
-2. Create the Input Variables and Parameters. For example, weights should be marked as Parameters and the inputs tensors as Variables.
-
-3. Run the layer and get the output of forward pass.
-
-4. To see that the backward call works fine, you can call backward on the outputs.
-
-Let's see one example to demonstrate the steps above:
-
-Examples
---------
-
-.. code-block:: python
-
-    import tensor_comprehensions as tc
-    import torch
-    from torch.autograd import Variable
-    from torch.nn.parameter import Parameter
-    CONV_LANG = """
-    def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{
-        O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw)
-    }}
-    def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O) -> (d_I, d_W1) {{
-        d_I(n, c, h, w) +=! d_O( n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw)
-        d_W1(m, c, kh, kw) +=! d_O(r_n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w)
-    }}
-    """
-    N, C, H, W, O, kH, kW, sH, sW = 32, 4, 56, 56, 16, 1, 1, 1, 1
-    convolution = tc.define(CONV_LANG, training=True, name="convolution", backward="convolution_grad", constants={"sh":sH, "sw":sW})
-    I = Variable(torch.randn(N, C, H, W).cuda(), requires_grad=True)
-    W = Parameter(torch.randn(O, C, kH, kW).cuda())
-    out = convolution(I, W)
-    out[0].sum().backward()
-
-.. note::
-
-    Please note the usage of :code:`.cuda()` i.e. tensor data is declared as the CUDA
-    type. Applying :code:`Variable` on the tensor data essentially allows the layer to be
-    part of computations graph and if :code:`Variable(torch.rand(), requires_grad=True).cuda()`
-    is done, then the grad will be available for the `Variable.cuda()` and not the actual `Variable/Tensor`.
-
-
-Specifying CudaMappingOptions
---------------------------
-
-We highly recommend passing the mapping options when running the kernel.
-See :ref:`must_pass_options` for more details. When running the training layer,
-you can pass the options for forward and backward layer separately or you can
-pass the same options for them. In case you want to pass different options for
-them, the example for that would be:
-
-.. code-block:: python
-
-    import tensor_comprehensions as tc
-    import torch
-    from torch.autograd import Variable
-    from torch.nn.parameter import Parameter
-    CONV_LANG = """
-    def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{
-        O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw)
-    }}
-    def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O) -> (d_I, d_W1) {{
-        d_I(n, c, h, w) +=! d_O( n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw)
-        d_W1(m, c, kh, kw) +=! d_O(r_n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w)
-    }}
-    """
-    N, C, H, W, O, kH, kW, sH, sW = 32, 4, 56, 56, 16, 1, 1, 1, 1
-    convolution = tc.define(CONV_LANG, training=True, name="convolution", backward="convolution_grad", constants={"sh":sH, "sw":sW})
-    I = Variable(torch.randn(N, C, H, W).cuda(), requires_grad=True)
-    W = Parameter(torch.randn(O, C, kH, kW).cuda())
-    out = convolution(I, W, options=[tc.CudaMappingOptions("conv"), tc.CudaMappingOptions("group_conv")])
-    out[0].sum().backward()
-
-In order to obtain options via autotuning for backward and forward layer, keep reading further.
-
-
-Autotuning training layer
--------------------------
-
-You can autotune a training layer easily. The forward and backward layers will
-be tuned separately in order to ensure maximal performance. Please read :ref:`pytorch_autotune_layers`
-for how to set autotuner parameters. We will see how to autotune a training
-layer, save cache and run the layer with help of examples:
-
-You can either cache to default options or to a file (also see :ref:`autotuner_cache_choices`).
-Let's see how to cache options to file when we tune a training layer.
-
-.. code-block:: python
-
-    import tensor_comprehensions as tc
-    import torch
-    CONV_LANG = """
-    def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{
-        O(n, m, h, w) +=! I(n, r_c, {sh} * h + r_kh, {sw} * w + r_kw) * W1(m, r_c, r_kh, r_kw)
-    }}
-    def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O) -> (d_I, d_W1) {{
-        d_I(n, c, h, w) +=! d_O( n, r_m, {sh} * h - r_kh, {sw} * w - r_kw) * W1(r_m, c, r_kh, r_kw)
-        d_W1(m, c, kh, kw) +=! d_O(r_n, m, {sh} * r_h - kh, {sw} * r_w - kw) * I(r_n, c, r_h, r_w)
-    }}
-    """
-    N, C, H, W, O, kH, kW, sH, sW = 32, 4, 56, 56, 16, 1, 1, 1, 1
-    convolution = tc.define(CONV_LANG, training=True, name="convolution", backward="convolution_grad", constants={"sh":sH, "sw":sW})
-    I, W1 = torch.randn(N, C, H, W).cuda(), torch.randn(O, C, kH, kW).cuda()
-    convolution.autotune(I, W, cache="convolution_train.tc")
-    out = convolution(I, W)
-    out[0].sum().backward()
-
-You will find a cache file created: :code:`convolution_train.options` has
-options for the forward layer and :code:`convolution_train_backward.options` file
-has options for the grad layer.
-
-Reordering grad outputs
------------------------
-
-In the backward pass, TC uses the list of input tensors in the forward pass and appends
-the output tensors list to it. This is treated as the input to the backward TC definition.
-However, sometimes, the forward layer TC might have some temporary variable for which we don't
-need gradient in the backward TC. In such cases, users can use :code:`reorder_function`. See
-the example below for how to use it:
-
-.. code-block:: python
-
-    import tensor_comprehensions as tc
-    import torch
-    LANG = """
-    def convolution(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (tmp, O) {
-        tmp(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
-        O(n, m, h, w) = tmp(n, m, h, w) + B(m)
-    }
-    def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B, float(N, M, H, W) d_O)
-    -> (d_I, d_W1, d_B) {
-        d_I(n, c, h, w) +=! d_O( n, r_m, h - r_kh, w - r_kw) * W1(r_m, c, r_kh, r_kw)
-        d_W1(m, c, kh, kw) +=! d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w)
-        d_B(m) +=! d_O(n, m, h, w)
-    }
-    """
-
-    # since the forward layer produces two outputs, one is temporary which is
-    # not needed in the forward pass, we can reorder the grad_outputs as we want.
-    # So, here we return the output grad that we actually use in backwards TC.
-    def reorder():
-        def reorder_function(grad_outputs):
-            return [grad_outputs[1]]
-        return reorder_function
-
-    N, C, H, W, M, kH, kW, sH, sW = 32, 4, 56, 56, 16, 1, 1, 1, 1
-    convolution = tc.define(LANG, training=True, name="convolution", backward="convolution_grad")
-    I = Variable(torch.randn(N, C, H, W).cuda(), requires_grad=True)
-    W = Parameter(torch.randn(M, C, kH, kW).cuda())
-    B = Parameter(torch.randn(M).cuda())
-    out = convolution(I, W, B, reorder_function=reorder())
-    out[0].sum().backward()
+To create a :code:`torch.autograd` function backed by TC one can just use the
+:func:`make_autograd` helper function:
+
+.. code-block:: python
+
+    conv = """
+    def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
+        O(n, m, h, w) +=!
+            I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
+    }
+    def convolution_igrad(float(M,C,KH,KW) W1, float(N,M,H,W) d_O)
+        -> (d_I)
+    {
+        d_I(n, c, h, w) +=!
+            d_O( n, r_m, h - r_kh, w - r_kw) * W1(r_m, c, r_kh, r_kw)
+    }
+    def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1)
+    {
+        d_W1(m, c, kh, kw) +=!
+            d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w)
+    }
+    """
+
+    N, C, H, W, O, kH, kW = 32, 4, 56, 56, 16, 1, 1
+    T = tc.define(
+        conv,
+        tc.make_autotuned_options_factory(
+            starting_options='naive',
+            tuner_config=tuner_config))
+    I, W = (
+        torch.randn(N, C, H, W, device='cuda', requires_grad=True),
+        torch.randn(O, C, kH, kW, device='cuda', requires_grad=True))
+
+    def convolution_backward(I, W, d_O):
+        d_I = T.convolution_igrad(W, d_O)
+        d_O = T.convolution_wgrad(I, d_O)
+        return (d_I, d_O)
+
+    convolution_function = tc.make_autograd(
+        T.convolution, convolution_backward)
+
+    # First occurrence triggers tuning
+    out = convolution_function(I, W)
+    out.sum().backward()
+
+    # Subsequent occurrences do not
+    out = convolution_function(I, W)
+    out.sum().backward()
