Commit 31478cc

feat: add where/expand/conv operators for the Cambricon (MLU) platform
1 parent 385ce12 commit 31478cc

File tree

14 files changed: +775 additions, -23 deletions

src/04kernel/src/collectors/conv.cc

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include "kernel/collectors/conv.h"
+#include "../kernels/conv/cnnl_kernel.hh"
 #include "../kernels/conv/cudnn_kernel.hh"
 
 namespace refactor::kernel {
@@ -23,6 +24,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = ConvCnnl::build(poolAttrs, x, w, b, y); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }

src/04kernel/src/collectors/where.cc

Lines changed: 7 additions & 1 deletion
@@ -1,11 +1,12 @@
 #include "kernel/collectors/where.h"
+#include "../kernels/where/cnnl_kernel.hh"
 #include "../kernels/where/cpu_kernel.hh"
 #include "../kernels/where/where_cuda.hh"
 
 namespace refactor::kernel {
 
     std::vector<KernelBox>
-    WhereCollector::filter(TensorRefs inputs, TensorRefs) const {
+    WhereCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
@@ -18,6 +19,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = WhereCnnl::build(inputs, outputs); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }
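
Both collectors follow the same dispatch contract: the platform-specific `build()` validates the configuration and returns an empty `KernelBox` when it cannot handle it, so `filter()` quietly yields no candidate instead of failing. A minimal self-contained model of that contract, with stand-in types rather than the repository's own:

#include <memory>
#include <utility>
#include <vector>

struct Kernel {};                          // stand-in for refactor::kernel::Kernel
using KernelBox = std::unique_ptr<Kernel>; // assumed to behave like the real KernelBox

// Models WhereCnnl::build / ConvCnnl::build: return nullptr when the
// dtype/shape combination is not supported on the target.
KernelBox buildIfSupported(bool supported) {
    return supported ? std::make_unique<Kernel>() : nullptr;
}

// Models the switch-case pattern added above: only non-null boxes are collected,
// so an unsupported case simply produces an empty candidate list.
std::vector<KernelBox> filter(bool supported) {
    std::vector<KernelBox> ans;
    if (auto ptr = buildIfSupported(supported); ptr) {
        ans.emplace_back(std::move(ptr));
    }
    return ans;
}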

src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc

Lines changed: 3 additions & 3 deletions
@@ -107,10 +107,10 @@
         CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NHWC2NCHW, 4, permuteOut));
 
         auto handle = res.fetchOrStore<CnnlContext>()->handle;
-        auto xTransSize = cnnlGetTensorElementNum(d->inDescTrans) * sizeof(info.dtX);
+        auto xTransSize = cnnlGetTensorElementNum(d->inDescTrans) * info.dtX.size();
         size_t workspaceSize;
         CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->inDesc, d->NCHW2NHWC, &workspaceSize));
-        size_t totalWorkspaceSize = xTransSize + workspaceSize;
+        size_t totalWorkspaceSize = xTransSize * 2 + workspaceSize;
 
         res.fetchOrStore<CnnlContext>();
         auto routine = [d = std::move(d),
@@ -129,7 +129,7 @@
 
             void *xTrans = workspace;
            void *yTrans = xTrans + xTransSize;
-            void *cursor = yTrans + workspaceSize;
+            void *cursor = yTrans + xTransSize;
 
             // transpose NCHW input to NHWC
             CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->inDesc, x,
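
The fix above does two things: it sizes the staging buffer from the element type's byte width (`info.dtX.size()`) rather than `sizeof` the `DataType` tag, and it reserves room for two staging buffers so the cursor handed to CNNL no longer overlaps `yTrans`. A minimal sketch of the resulting layout, with illustrative names that are not part of the commit:

#include <cstddef>
#include <cstdint>

// Illustrative partition of the workspace as the fixed code lays it out:
// [ xTrans : xTransSize ][ yTrans : xTransSize ][ CNNL transpose scratch : transposeWorkspaceSize ]
struct BnWorkspaceLayout {
    void *xTrans;       // NHWC copy of the input
    void *yTrans;       // NHWC result before transposing back to NCHW
    void *cursor;       // scratch passed to cnnlTranspose_v2
    std::size_t total;  // xTransSize * 2 + transposeWorkspaceSize
};

BnWorkspaceLayout partitionBnWorkspace(void *workspace,
                                       std::size_t xTransSize,
                                       std::size_t transposeWorkspaceSize) {
    auto base = static_cast<std::uint8_t *>(workspace);
    return {base,
            base + xTransSize,
            base + 2 * xTransSize,
            xTransSize * 2 + transposeWorkspaceSize};
}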

src/04kernel/src/kernels/conv/cnnl_kernel.cc

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
+#include "cnnl_kernel.hh"
+
+#ifdef USE_BANG
+#include "../../utilities/bang/cnnl_context.hh"
+#include "../../utilities/bang/cnnl_functions.h"
+#include "../expand/cnnl_kernel.hh"
+#include "hardware/functions.h"
+#endif
+
+namespace refactor::kernel {
+    using K = ConvCnnl;
+
+    K::ConvCnnl(decltype(info) info_) noexcept
+        : Kernel(), info(std::move(info_)) {}
+
+    auto K::build(PoolAttributes const &poolAttributes,
+                  Tensor const &x,
+                  Tensor const &w,
+                  std::optional<std::reference_wrapper<Tensor const>> b,
+                  Tensor const &y) -> KernelBox {
+        static const std::unordered_set<decltype(DataType::internal)>
+            SET{DataType::FP16, DataType::BF16, DataType::F32, DataType::F64, DataType::I8};
+#ifndef USE_BANG
+        return nullptr;
+#endif
+
+        auto dt = x.dataType;
+        if (!SET.contains(dt) || w.dataType != dt || y.dataType != dt) {
+            return nullptr;
+        }
+
+        std::optional<ExpandInfoCnnl> biasExpand = std::nullopt;
+        if (b) {
+            ASSERT(b->get().shape[0] == y.shape[1], "");
+            std::vector<dim_t> input(y.rank(), 1);
+            input[1] = y.shape[1];
+            biasExpand.emplace(ExpandInfoCnnl(
+                b->get().dataType,
+                slice(input.data(), input.size()),
+                slice(y.shape.data(), y.rank())));
+        }
+
+        // group is not supported
+        if (w.rank() != 4 || poolAttributes.rank() != 2) {
+            return nullptr;
+        }
+        auto d = poolAttributes.dilations(),
+             p = poolAttributes.pads(),
+             s = poolAttributes.strides();
+        return std::make_unique<K>(decltype(info){
+            dt,
+            {
+                static_cast<int>(x.shape[0]),
+                static_cast<int>(x.shape[1]),
+                static_cast<int>(x.shape[2]),
+                static_cast<int>(x.shape[3]),
+            },
+            {
+                static_cast<int>(w.shape[0]),
+                static_cast<int>(w.shape[1]),
+                static_cast<int>(w.shape[2]),
+                static_cast<int>(w.shape[3]),
+            },
+            {
+                static_cast<int>(y.shape[0]),
+                static_cast<int>(y.shape[1]),
+                static_cast<int>(y.shape[2]),
+                static_cast<int>(y.shape[3]),
+            },
+            {d[0], d[1]},
+            {p[0], p[1], p[2], p[3]},
+            {s[0], s[1]},
+            std::move(biasExpand),
+        });
+    }
+
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing conv using CNNL";
+    }
+
+#ifdef USE_BANG
+
+    auto ConvCnnl::lower(Resources &res) const -> RoutineWorkspace {
+        using namespace cnnl;
+        using namespace runtime;
+
+        // RAII for closure
+        struct Descriptors {
+            cnnlTensorDescriptor_t x, y, w;
+            cnnlTensorDescriptor_t xTrans, yTrans, wTrans;
+            cnnlTransposeDescriptor_t NCHW2NHWC, NHWC2NCHW;
+            cnnlConvolutionDescriptor_t conv;
+            cnnlConvolutionForwardAlgo_t algo;
+            // std::optional<ExtraPadding> extraPadding;
+            std::optional<Routine> biasExpand;
+            bool f32;
+
+            Descriptors(decltype(f32) f32_)
+                :// extraPadding(std::nullopt),
+                  biasExpand(std::nullopt),
+                  f32(f32_) {
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&x));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&y));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&w));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&xTrans));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&yTrans));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&wTrans));
+                CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NCHW2NHWC));
+                CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NHWC2NCHW));
+                CNNL_ASSERT(cnnlCreateConvolutionDescriptor(&conv));
+            }
+            ~Descriptors() noexcept(false) {
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(x));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(y));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(w));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(xTrans));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(yTrans));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(wTrans));
+                CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NCHW2NHWC));
+                CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NHWC2NCHW));
+                CNNL_ASSERT(cnnlDestroyConvolutionDescriptor(conv));
+            }
+
+            Descriptors(const Descriptors &) = delete;
+            Descriptors(Descriptors &&) = delete;
+        };
+        auto d = std::make_shared<Descriptors>(info.dt != DataType::F64);
+        // d->extraPadding = ExtraPadding::build(info.dt, info.xShape, info.pad);
+        if (info.biasExpand) {
+            d->biasExpand = ExpandCnnl(*info.biasExpand).lower(res).routine;
+        }
+        int xs[]{
+            info.xShape[0],
+            info.xShape[1],
+            info.xShape[2] + std::abs(info.pad[0] - info.pad[2]),
+            info.xShape[3] + std::abs(info.pad[1] - info.pad[3]),
+        };
+
+        auto NHWC = [](const int shape[]) -> std::vector<int> {
+            return {
+                shape[0], shape[2], shape[3], shape[1]};
+        };
+
+        std::vector<int> xsNHWC = NHWC(xs);
+        std::vector<int> wsNHWC = NHWC(info.wShape);
+        std::vector<int> ysNHWC = NHWC(info.yShape);
+
+        setCnnlTensor(d->x, info.dt, slice(xs, 4));
+        setCnnlTensor(d->y, info.dt, slice(info.yShape, 4));
+        setCnnlTensor(d->w, info.dt, slice(info.wShape, 4));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->xTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, xsNHWC.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->yTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, ysNHWC.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->wTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, wsNHWC.data()));
+
+        auto xTransSize = cnnlGetTensorElementNum(d->xTrans) * info.dt.size();
+        auto yTransSize = cnnlGetTensorElementNum(d->yTrans) * info.dt.size();
+        auto wTransSize = cnnlGetTensorElementNum(d->wTrans) * info.dt.size();
+
+        int permuteIn[4] = {0, 2, 3, 1};
+        int permuteOut[4] = {0, 3, 1, 2};
+        CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NCHW2NHWC, 4, permuteIn));
+        CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NHWC2NCHW, 4, permuteOut));
+
+        size_t xWorkspaceSize, yWorkspaceSize, wWorkspaceSize, convWorkspaceSize;
+        auto handle = res.fetchOrStore<CnnlContext>()->handle;
+        CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->x, d->NCHW2NHWC, &xWorkspaceSize));
+        CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->w, d->NCHW2NHWC, &wWorkspaceSize));
+        CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->yTrans, d->NHWC2NCHW, &yWorkspaceSize));
+
+        // clang-format off
+        auto computation = info.dt == DataType::F64 ? DataType::F64
+                         : info.dt == DataType::I8  ? DataType::I32
+                         : DataType::F32;
+        // clang-format on
+        auto group = xs[1] / info.wShape[1];
+        CNNL_ASSERT(cnnlSetConvolutionDescriptor(d->conv, 4, info.pad, info.stride, info.dilation, group, cnnlDataTypeConvert(computation)));
+        CNNL_ASSERT(cnnlGetConvolutionForwardAlgorithm(
+            handle, d->conv, d->xTrans, d->wTrans, d->yTrans,
+            CNNL_CONVOLUTION_FWD_FASTEST, &d->algo));
+
+        CNNL_ASSERT(cnnlGetConvolutionForwardWorkspaceSize(
+            handle, d->xTrans, d->wTrans, d->yTrans, NULL,
+            d->conv, d->algo, &convWorkspaceSize));
+
+        // if (d->extraPadding) {
+        //     workspaceSize = hardware::alignBytes(workspaceSize, 256);
+        // }
+
+        size_t workspaceSize = xTransSize + yTransSize + wTransSize + std::max({xWorkspaceSize, wWorkspaceSize, yWorkspaceSize, convWorkspaceSize});
+
+        res.fetchOrStore<CnnlContext>();
+        auto routine = [d, xTransSize, yTransSize, wTransSize,
+                        xWorkspaceSize, wWorkspaceSize,
+                        yWorkspaceSize, convWorkspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            auto handle = res.fetchOrStore<CnnlContext>()->handle;
+            void const *x = inputs[0], *w = inputs[1];
+            void *y = outputs[0];
+            // if (auto f = d->extraPadding; f) {
+            //     x = (*f)(x, reinterpret_cast<uint8_t *>(workspace) + workspaceSize);
+            // }
+            // if (auto f = d->biasExpand; f) {
+            //     (*f)(res, workspace, inputs + 2, outputs);
+            // }
+
+            void *xTrans = workspace;
+            void *wTrans = xTrans + xTransSize;
+            void *yTrans = wTrans + wTransSize;
+            void *opWorkspace = yTrans + yTransSize;
+
+            // transpose NCHW input to NHWC
+            CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->x, x,
+                                         d->xTrans, xTrans, opWorkspace, xWorkspaceSize));
+            CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->w, w,
+                                         d->wTrans, wTrans, opWorkspace, wWorkspaceSize));
+
+            // build alpha/beta for double
+            auto a = d->f32 ? factor<fp32_t>(1) : factor<fp64_t>(1),
+                 b = d->f32
+                         ? factor<fp32_t>(d->biasExpand ? 1 : 0)
+                         : factor<fp64_t>(d->biasExpand ? 1 : 0);
+            CNNL_ASSERT(cnnlConvolutionForward(
+                handle,
+                d->conv, d->algo, &a,
+                d->xTrans, xTrans, d->wTrans, wTrans,
+                NULL, NULL, opWorkspace, convWorkspaceSize,
+                &b, d->yTrans, yTrans));
+
+            // transpose NHWC intermediates to NCHW
+            CNNL_ASSERT(cnnlTranspose_v2(handle, d->NHWC2NCHW, d->yTrans, yTrans,
+                                         d->y, y, opWorkspace, yWorkspaceSize));
+        };
+        return {std::move(routine), workspaceSize};
+    }
+
+#endif
+
+}// namespace refactor::kernel
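
To make the buffer arithmetic in `lower` easier to follow, here is a self-contained sketch of the workspace partition it sets up: three NHWC staging buffers followed by one shared scratch region sized to the largest of the transpose and convolution workspace requirements. Names are illustrative; this is a reader's aid, not part of the commit:

#include <algorithm>
#include <cstddef>
#include <cstdint>

struct ConvWorkspaceLayout {
    void *xTrans;        // NHWC copy of the input
    void *wTrans;        // NHWC copy of the filter
    void *yTrans;        // NHWC convolution result, transposed back to NCHW at the end
    void *opWorkspace;   // shared scratch for cnnlTranspose_v2 and cnnlConvolutionForward
    std::size_t total;   // what ConvCnnl::lower reports as workspaceSize
};

ConvWorkspaceLayout partitionConvWorkspace(void *workspace,
                                           std::size_t xTransSize, std::size_t wTransSize,
                                           std::size_t yTransSize,
                                           std::size_t xWs, std::size_t wWs,
                                           std::size_t yWs, std::size_t convWs) {
    auto base = static_cast<std::uint8_t *>(workspace);
    std::size_t scratch = std::max({xWs, wWs, yWs, convWs});
    return {base,
            base + xTransSize,
            base + xTransSize + wTransSize,
            base + xTransSize + wTransSize + yTransSize,
            xTransSize + wTransSize + yTransSize + scratch};
}

A single scratch region suffices because the transpose and convolution calls are issued sequentially on the same handle, so their workspaces are never needed at the same time.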

src/04kernel/src/kernels/conv/cnnl_kernel.hh

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+#ifndef KERNEL_CONV_CNNL_KERNEL_HH
+#define KERNEL_CONV_CNNL_KERNEL_HH
+
+#include "../../kernels/expand/cnnl_kernel.hh"
+#include "kernel/attributes/pool_attributes.h"
+#include "kernel/kernel.h"
+#include <optional>
+
+namespace refactor::kernel {
+
+    /// @brief Use `cnnlConvolutionForward`.
+    ///        It only supports 4D tensors.
+    struct ConvCnnl final : public Kernel {
+        struct {
+            DataType dt;
+            int xShape[4],
+                wShape[4],
+                yShape[4],
+                dilation[2],
+                pad[4],
+                stride[2];
+            std::optional<ExpandInfoCnnl> biasExpand;
+        } info;
+
+        explicit ConvCnnl(decltype(info)) noexcept;
+
+        static KernelBox build(PoolAttributes const &,
+                               Tensor const &,
+                               Tensor const &,
+                               std::optional<std::reference_wrapper<Tensor const>>,
+                               Tensor const &);
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_BANG
+        RoutineWorkspace lower(Resources &) const final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_CONV_CNNL_KERNEL_HH
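
For the optional bias, `build` (in the .cc above) records an `ExpandInfoCnnl` that broadcasts a `[1, C, 1, 1]` tensor to the output shape; the `biasExpand` member declared here carries that plan into `lower`. A self-contained sketch of just the shape computation, using a hypothetical helper name (ExpandInfoCnnl itself is not reproduced):

#include <cstdint>
#include <vector>

// Mirrors the bias-expand setup in ConvCnnl::build: the bias holds one value per
// output channel and is broadcast across N, H and W of the output tensor.
std::vector<std::uint32_t> biasExpandInputShape(const std::vector<std::uint32_t> &yShape) {
    std::vector<std::uint32_t> input(yShape.size(), 1); // all-ones shape, same rank as y
    input[1] = yShape[1];                               // keep the channel dimension
    return input;                                       // e.g. y = {8, 64, 32, 32} -> {1, 64, 1, 1}
}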
