Commit 3ac3fc9

[MLU] add bilinear and bilinear_grad
1 parent 6057ef4 commit 3ac3fc9

1 file changed: 352 additions & 0 deletions

@@ -0,0 +1,352 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kernels/funcs/elementwise_utils.h"
16+
#include "kernels/funcs/mlu_baseop.h"
17+
#include "kernels/funcs/mlu_funcs.h"
18+
#include "kernels/funcs/reduce_op.h"
19+
#include "paddle/phi/kernels/funcs/slice_utils.h"
20+
21+
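// Forward declarations of kernels defined elsewhere in this plugin; they are
// reused below to slice weight/dout and to scatter per-column results back
// into the outputs.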
namespace custom_kernel {

template <typename T, typename Context>
void SetTensorValueKernel(const Context& dev_ctx,
                          const phi::DenseTensor& x,
                          const phi::DenseTensor& value,
                          const phi::IntArray& starts,
                          const phi::IntArray& ends,
                          const phi::IntArray& steps,
                          const std::vector<int64_t>& axes,
                          const std::vector<int64_t>& decrease_axes,
                          const std::vector<int64_t>& none_axes,
                          phi::DenseTensor* out);

template <typename T, typename Context>
void StridedSliceRawKernel(const Context& dev_ctx,
                           const phi::DenseTensor& x,
                           const std::vector<int>& axes,
                           const phi::IntArray& starts,
                           const phi::IntArray& ends,
                           const phi::IntArray& strides,
                           const std::vector<int>& infer_flags,
                           const std::vector<int>& decrease_axis,
                           phi::DenseTensor* out);

template <typename T, typename Context>
void BilinearKernel(const Context& dev_ctx,
                    const phi::DenseTensor& x,
                    const phi::DenseTensor& y,
                    const phi::DenseTensor& weight,
                    const paddle::optional<phi::DenseTensor>& bias,
                    phi::DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);

  auto batch_size = x.dims()[0];
  auto weight_dims = weight.dims();
  int out_dim = weight_dims[0];
  auto x_dim = weight_dims[1];
  auto y_dim = weight_dims[2];

  // Intermediate buffer holding Input(X) multiplied by the i-th weight
  // slice: left_mul = X * Weight_i.
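  // Overall, each output column is the bilinear form
  //   out[b, i] = sum_{j,k} x[b, j] * weight[i, j, k] * y[b, k] (+ bias[i]),
  // computed as reduce_sum((x * weight[i]) * y, axis=1) for each i below.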
  phi::DenseTensor left_mul;
  left_mul.Resize(phi::make_ddim({batch_size, y_dim}));
  dev_ctx.template Alloc<T>(&left_mul);

  MLUCnnlTensorDesc x_desc(x, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
  MLUCnnlTensorDesc y_desc(y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
  MLUCnnlTensorDesc weight_desc(weight, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
  MLUCnnlTensorDesc left_mul_desc(
      left_mul, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());

  phi::DenseTensor output_mat_slice;
  output_mat_slice.Resize(phi::make_ddim({batch_size}));

  phi::DenseTensor out_temp;
  out_temp.Resize(out->dims());
  dev_ctx.template Alloc<T>(&out_temp);
  FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), &out_temp);

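  // out_temp accumulates the result one column at a time; column i is filled
  // via SetTensorValueKernel at the end of each loop iteration.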
  for (int64_t i = 0; i < out_dim; ++i) {
    phi::DenseTensor weight_slice;
    weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
    dev_ctx.template Alloc<T>(&weight_slice);
    MLUCnnlTensorDesc weight_slice_desc(weight_slice);

    phi::DenseTensor matmul_out;
    matmul_out.Resize(phi::make_ddim({batch_size, y_dim}));
    dev_ctx.template Alloc<T>(&matmul_out);
    MLUCnnlTensorDesc matmul_out_desc(matmul_out);
    int64_t next_i = i + 1;
    int64_t value = 1;
    const phi::IntArray& starts_indices = {i};
    const phi::IntArray& ends_indices = {next_i};
    const phi::IntArray& strides_indices = {value};
    std::vector<int> infer_flags(1);
    std::vector<int> decrease_axis;
    std::vector<int> axes = {0};
    custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
                                                     weight,
                                                     axes,
                                                     starts_indices,
                                                     ends_indices,
                                                     strides_indices,
                                                     infer_flags,
                                                     decrease_axis,
                                                     &weight_slice);

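    // weight_slice = weight[i] viewed as [x_dim, y_dim];
    // left_mul = x * weight[i], shape [batch_size, y_dim].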
    MLUCnnl::Matmul(dev_ctx,
                    false,
                    false,
                    x_desc.get(),
                    GetBasePtr(&x),
                    weight_slice_desc.get(),
                    GetBasePtr(&weight_slice),
                    left_mul_desc.get(),
                    GetBasePtr(&left_mul));

    int axis = -1;
    MLUOpTensorKernel<T>(
        dev_ctx, left_mul, y, axis, CNNL_OP_TENSOR_MUL, &matmul_out);

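    // matmul_out = left_mul * y (elementwise); reducing it over axis 1 below
    // yields the i-th output column.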
    phi::DenseTensor sum_out;
    sum_out.Resize({batch_size});
    const std::vector<int64_t>& dims = {1};
    MLUReduceOp<T>(dev_ctx,
                   matmul_out,
                   dims,
                   false, /*keep_dim*/
                   false, /*reduce_all*/
                   "reduce_sum",
                   &sum_out);

    // Scatter the per-column result into column i of out_temp.
    std::vector<int64_t> sum_axes = {1};
    std::vector<int64_t> decrease_axes;
    std::vector<int64_t> none_axes;
    custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
                                                    out_temp,
                                                    sum_out,
                                                    starts_indices,
                                                    ends_indices,
                                                    strides_indices,
                                                    sum_axes,
                                                    decrease_axes,
                                                    none_axes,
                                                    &output_mat_slice);
  }

  if (bias.get_ptr()) {
    phi::DenseTensor new_bias = bias.get();
    int axis = -1;
    MLUOpTensorKernel<T>(
        dev_ctx, out_temp, new_bias, axis, CNNL_OP_TENSOR_ADD, out);
  } else {
    TensorCopy(dev_ctx, out_temp, false, out);
  }
}

template <typename T, typename Context>
void BilinearGradKernel(const Context& dev_ctx,
                        const phi::DenseTensor& x,
                        const phi::DenseTensor& y,
                        const phi::DenseTensor& weight,
                        const phi::DenseTensor& dout,
                        phi::DenseTensor* dx,
                        phi::DenseTensor* dy,
                        phi::DenseTensor* dweight,
                        phi::DenseTensor* dbias) {
  auto batch_size = x.dims()[0];
  auto weight_dims = weight.dims();
  int out_dim = weight_dims[0];
  auto x_dim = weight_dims[1];
  auto y_dim = weight_dims[2];

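  // Gradients accumulated below, with W_i = weight[i] of shape [x_dim, y_dim]:
  //   dx += (dout[:, i] * y) . W_i^T        for each i
  //   dy += (dout[:, i] * x) . W_i          for each i
  //   dweight[i] = (dout[:, i] * x)^T . y
  //   dbias = reduce_sum(dout, axis=0)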
  // Intermediate variable used to compute Output(Y@Grad) and Output(Weight@Grad).
  phi::DenseTensor x_scale;
  x_scale.Resize(phi::make_ddim({batch_size, x_dim}));
  dev_ctx.template Alloc<T>(&x_scale);

  // Intermediate variable used to compute Output(X@Grad).
  phi::DenseTensor y_scale;
  y_scale.Resize(phi::make_ddim({batch_size, y_dim}));
  dev_ctx.template Alloc<T>(&y_scale);

  if (dx) {
    dev_ctx.template Alloc<T>(dx);
    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dx);
  }
  if (dy) {
    dev_ctx.template Alloc<T>(dy);
    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dy);
  }
  if (dweight) {
    dev_ctx.template Alloc<T>(dweight);
    FillMLUTensorWithHostValue(dev_ctx, static_cast<T>(0.0f), dweight);
  }

  if (dx || dy || dweight) {
    // dx/dy may be null here, so size the temporaries from x/y instead of
    // dereferencing the output pointers.
    phi::DenseTensor dx_temp;
    dx_temp.Resize(x.dims());
    dev_ctx.template Alloc<T>(&dx_temp);
    MLUCnnlTensorDesc dx_temp_desc(dx_temp);

    phi::DenseTensor dy_temp;
    dy_temp.Resize(y.dims());
    dev_ctx.template Alloc<T>(&dy_temp);
    MLUCnnlTensorDesc dy_temp_desc(dy_temp);

    phi::DenseTensor dweight_temp;
    dweight_temp.Resize(phi::make_ddim({x_dim, y_dim}));
    dev_ctx.template Alloc<T>(&dweight_temp);
    MLUCnnlTensorDesc dweight_temp_desc(dweight_temp);

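    // The loop below accumulates dx/dy across columns and writes dweight[i]
    // per iteration; dx_temp/dy_temp hold each column's contribution.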
    for (int64_t i = 0; i < out_dim; ++i) {
      phi::DenseTensor weight_slice;
      weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
      dev_ctx.template Alloc<T>(&weight_slice);
      int64_t next_i = i + 1;
      int64_t value = 1;
      const phi::IntArray& starts_indices = {i};
      const phi::IntArray& ends_indices = {next_i};
      const phi::IntArray& strides_indices = {value};
      std::vector<int> infer_flags(1);
      std::vector<int> decrease_axis;
      std::vector<int> axes = {0};
      // weight_slice = weight[i], shape [x_dim, y_dim].
      custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
                                                       weight,
                                                       axes,
                                                       starts_indices,
                                                       ends_indices,
                                                       strides_indices,
                                                       infer_flags,
                                                       decrease_axis,
                                                       &weight_slice);
      weight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
      MLUCnnlTensorDesc weight_slice_desc(weight_slice);
      MLUCnnlTensorDesc x_scale_desc(x_scale);
      MLUCnnlTensorDesc y_scale_desc(y_scale);
      MLUCnnlTensorDesc y_desc(y);

      // dout[:, i]
      std::vector<int> dout_axes = {1};
      phi::DenseTensor dout_mat_slice;
      dout_mat_slice.Resize(phi::make_ddim({batch_size}));
      custom_kernel::StridedSliceRawKernel<T, Context>(dev_ctx,
                                                       dout,
                                                       dout_axes,
                                                       starts_indices,
                                                       ends_indices,
                                                       strides_indices,
                                                       infer_flags,
                                                       decrease_axis,
                                                       &dout_mat_slice);
      if (dx) {
        int axis = -1;
        dout_mat_slice.Resize({batch_size, 1});
        MLUCnnlTensorDesc dout_mat_slice_desc(dout_mat_slice);
        // y_scale = dout[:, i] * y (broadcast over columns).
        MLUOpTensorKernel<T>(
            dev_ctx, dout_mat_slice, y, axis, CNNL_OP_TENSOR_MUL, &y_scale);
        // dx += y_scale * weight[i]^T.
        MLUCnnl::Matmul(dev_ctx,
                        false,
                        true,
                        y_scale_desc.get(),
                        GetBasePtr(&y_scale),
                        weight_slice_desc.get(),
                        GetBasePtr(&weight_slice),
                        dx_temp_desc.get(),
                        GetBasePtr(&dx_temp));
        MLUOpTensorKernel<T>(
            dev_ctx, dx_temp, *dx, axis, CNNL_OP_TENSOR_ADD, dx);
      }
      if (dy || dweight) {
        int axis = -1;
        dout_mat_slice.Resize({batch_size, 1});
        MLUCnnlTensorDesc dout_mat_slice_desc(dout_mat_slice);
        // x_scale = dout[:, i] * x (broadcast over columns).
        MLUOpTensorKernel<T>(
            dev_ctx, dout_mat_slice, x, axis, CNNL_OP_TENSOR_MUL, &x_scale);
        if (dy) {
          // dy += x_scale * weight[i].
          MLUCnnl::Matmul(dev_ctx,
                          false,
                          false,
                          x_scale_desc.get(),
                          GetBasePtr(&x_scale),
                          weight_slice_desc.get(),
                          GetBasePtr(&weight_slice),
                          dy_temp_desc.get(),
                          GetBasePtr(&dy_temp));
          MLUOpTensorKernel<T>(
              dev_ctx, dy_temp, *dy, axis, CNNL_OP_TENSOR_ADD, dy);
        }
        if (dweight) {
          // dweight[i] = x_scale^T * y.
          MLUCnnl::Matmul(dev_ctx,
                          true,
                          false,
                          x_scale_desc.get(),
                          GetBasePtr(&x_scale),
                          y_desc.get(),
                          GetBasePtr(&y),
                          dweight_temp_desc.get(),
                          GetBasePtr(&dweight_temp));

          // Write dweight_temp into dweight[i].
          std::vector<int64_t> dweight_axes = {0};
          std::vector<int64_t> decrease_axes;
          std::vector<int64_t> none_axes;
          phi::DenseTensor dweight_slice;
          dweight_slice.Resize(phi::make_ddim({x_dim, y_dim}));
          dev_ctx.template Alloc<T>(&dweight_slice);
          MLUCnnlTensorDesc dweight_slice_desc(dweight_slice);
          custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
                                                          *dweight,
                                                          dweight_temp,
                                                          starts_indices,
                                                          ends_indices,
                                                          strides_indices,
                                                          dweight_axes,
                                                          decrease_axes,
                                                          none_axes,
                                                          &dweight_slice);
        }
      }
    }
  }

  // Calculate the gradient of Input(Bias); it depends only on dout, so it is
  // computed even when no other gradient is requested.
  if (dbias) {
    dev_ctx.template Alloc<T>(dbias);
    const std::vector<int64_t>& dims = {0};
    MLUReduceOp<T>(dev_ctx,
                   dout,
                   dims,
                   false, /*keep_dim*/
                   false, /*reduce_all*/
                   "reduce_sum",
                   dbias);
  }
}

}  // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(
    bilinear, mlu, ALL_LAYOUT, custom_kernel::BilinearKernel, float, double) {}

PD_REGISTER_PLUGIN_KERNEL(bilinear_grad,
                          mlu,
                          ALL_LAYOUT,
                          custom_kernel::BilinearGradKernel,
                          float,
                          double) {}
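A minimal CPU oracle is handy when verifying these kernels against expected
outputs. The sketch below is illustrative only and not part of the commit; the
function name BilinearReference and the raw row-major float buffers are
assumptions made for the example:

#include <cstdint>

// CPU oracle for the bilinear forward pass: x is [B, J], y is [B, K],
// weight is [I, J, K], bias is [I] (may be nullptr), out is [B, I].
// Buffers are assumed row-major and contiguous.
void BilinearReference(const float* x, const float* y, const float* weight,
                       const float* bias, float* out,
                       int64_t B, int64_t I, int64_t J, int64_t K) {
  for (int64_t b = 0; b < B; ++b) {
    for (int64_t i = 0; i < I; ++i) {
      float acc = 0.0f;
      // out[b, i] = sum_{j,k} x[b, j] * weight[i, j, k] * y[b, k]
      for (int64_t j = 0; j < J; ++j) {
        for (int64_t k = 0; k < K; ++k) {
          acc += x[b * J + j] * weight[(i * J + j) * K + k] * y[b * K + k];
        }
      }
      out[b * I + i] = acc + (bias != nullptr ? bias[i] : 0.0f);
    }
  }
}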
