Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimitze DepthToSpace for mode = DCR, blocksize = 4 and width = 8x #23399

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 132 additions & 12 deletions onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,121 @@
return Status::OK();
}

template <typename T>
static Status DepthSpaceDCRBlock4Width8XOpCpuImpl(OpKernelContext* ctx,
const Tensor& input, Tensor& output,
const int64_t batch,
const int64_t input_depth,
const int64_t input_height,
const int64_t input_width) {
std::vector<size_t> permutations = {0, 3, 4, 1, 5, 2};

Check warning on line 152 in onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc:152: Add #include <vector> for vector<> [build/include_what_you_use] [4]
constexpr int blocksize = 4;
constexpr int internal_rank = 6;
int64_t internal_output_depth = input_depth / blocksize / blocksize;
const TensorShape internal_input_shape = TensorShape{batch, blocksize, blocksize,
internal_output_depth, input_height, input_width};
const TensorShape internal_output_shape = TensorShape{batch, internal_output_depth,
input_height, blocksize,
input_width, blocksize};
const int64_t number_of_elements = internal_input_shape.Size();
const auto& internal_output_dims = internal_output_shape.GetDims();

InlinedVector<size_t> stride(internal_rank);
for (size_t i = 0; i < internal_rank; i++) {
size_t inpdim = permutations[i];
if (inpdim + 1 < internal_rank)
stride[i] = onnxruntime::narrow<size_t>(internal_input_shape.SizeFromDimension(inpdim + 1));
else
stride[i] = 1;
}

InlinedVector<size_t> internal_output_stride(internal_rank);
internal_output_stride[internal_rank - 1] = 1;
for (int64_t i = internal_rank - 2; i >= 0; --i) {
internal_output_stride[i] = internal_output_stride[i + 1] * internal_output_dims[i + 1];
}

const auto* input_data = reinterpret_cast<const T*>(input.DataRaw());
auto* output_data = reinterpret_cast<T*>(output.MutableDataRaw());

Status status = Status::OK();

concurrency::ThreadPool::TryParallelFor(
ctx->GetOperatorThreadPool(), static_cast<std::ptrdiff_t>(number_of_elements),
{static_cast<float>(sizeof(uint8_t)), static_cast<float>(sizeof(uint8_t)), 1.0F},
[&internal_output_stride, input_data, &stride, output_data](std::ptrdiff_t first,
std::ptrdiff_t last) {
constexpr int chunk_size = 32;

ORT_ENFORCE((first < last) && (first % chunk_size == 0) && (last % chunk_size == 0));

/// The loop is unrolled by 32 for the code below:
/// for (std::ptrdiff_t i = first; i < last; ++i) {
/// int d0 = static_cast<int>(i / internal_output_stride[0]);
/// int d1 = static_cast<int>((i % internal_output_stride[0]) / internal_output_stride[1]);
/// int d2 = static_cast<int>((i % internal_output_stride[1]) / internal_output_stride[2]);
/// int d3 = static_cast<int>((i % internal_output_stride[2]) / internal_output_stride[3]);
/// int d4 = static_cast<int>((i % internal_output_stride[3]) / internal_output_stride[4] /* blocksize = 4 */);
/// int d5 = static_cast<int>(i % internal_output_stride[4] /* blocksize = 4 */);
/// const T* source = input_data + (d0 * stride[0] +
/// d1 * stride[1] +
/// d2 * stride[2] +
/// d3 * stride[3] +
/// d4 * stride[4] /* 1 */ +
/// d5 * stride[5]);
/// T* target = output_data + i;
/// *target = *source;
/// }
Comment on lines +193 to +209

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
for (std::ptrdiff_t i = first; i < last; i += chunk_size) {
int d0 = static_cast<int>(i / internal_output_stride[0]);
int d1 = static_cast<int>((i % internal_output_stride[0]) / internal_output_stride[1]);
int d2 = static_cast<int>((i % internal_output_stride[1]) / internal_output_stride[2]);
int d3 = static_cast<int>((i % internal_output_stride[2]) / internal_output_stride[3]);
int d4 = static_cast<int>((i % internal_output_stride[3]) / 4 /* blocksize = internal_output_stride[4] */);
const T* source = input_data + (d0 * stride[0] +
d1 * stride[1] +
d2 * stride[2] +
d3 * stride[3] +
d4 * 1 /* stride[4] */);
T* target = output_data + i;
*(target + 0) = *(source + ((0 / 4) * 1) + ((0 % 4) * stride[5]));
*(target + 1) = *(source + ((1 / 4) * 1) + ((1 % 4) * stride[5]));
*(target + 2) = *(source + ((2 / 4) * 1) + ((2 % 4) * stride[5]));
*(target + 3) = *(source + ((3 / 4) * 1) + ((3 % 4) * stride[5]));
*(target + 4) = *(source + ((4 / 4) * 1) + ((4 % 4) * stride[5]));
*(target + 5) = *(source + ((5 / 4) * 1) + ((5 % 4) * stride[5]));
*(target + 6) = *(source + ((6 / 4) * 1) + ((6 % 4) * stride[5]));
*(target + 7) = *(source + ((7 / 4) * 1) + ((7 % 4) * stride[5]));
*(target + 8) = *(source + ((8 / 4) * 1) + ((8 % 4) * stride[5]));
*(target + 9) = *(source + ((9 / 4) * 1) + ((9 % 4) * stride[5]));
*(target + 10) = *(source + ((10 / 4) * 1) + ((10 % 4) * stride[5]));
*(target + 11) = *(source + ((11 / 4) * 1) + ((11 % 4) * stride[5]));
*(target + 12) = *(source + ((12 / 4) * 1) + ((12 % 4) * stride[5]));
*(target + 13) = *(source + ((13 / 4) * 1) + ((13 % 4) * stride[5]));
*(target + 14) = *(source + ((14 / 4) * 1) + ((14 % 4) * stride[5]));
*(target + 15) = *(source + ((15 / 4) * 1) + ((15 % 4) * stride[5]));
*(target + 16) = *(source + ((16 / 4) * 1) + ((16 % 4) * stride[5]));
*(target + 17) = *(source + ((17 / 4) * 1) + ((17 % 4) * stride[5]));
*(target + 18) = *(source + ((18 / 4) * 1) + ((18 % 4) * stride[5]));
*(target + 19) = *(source + ((19 / 4) * 1) + ((19 % 4) * stride[5]));
*(target + 20) = *(source + ((20 / 4) * 1) + ((20 % 4) * stride[5]));
*(target + 21) = *(source + ((21 / 4) * 1) + ((21 % 4) * stride[5]));
*(target + 22) = *(source + ((22 / 4) * 1) + ((22 % 4) * stride[5]));
*(target + 23) = *(source + ((23 / 4) * 1) + ((23 % 4) * stride[5]));
*(target + 24) = *(source + ((24 / 4) * 1) + ((24 % 4) * stride[5]));
*(target + 25) = *(source + ((25 / 4) * 1) + ((25 % 4) * stride[5]));
*(target + 26) = *(source + ((26 / 4) * 1) + ((26 % 4) * stride[5]));
*(target + 27) = *(source + ((27 / 4) * 1) + ((27 % 4) * stride[5]));
*(target + 28) = *(source + ((28 / 4) * 1) + ((28 % 4) * stride[5]));
*(target + 29) = *(source + ((29 / 4) * 1) + ((29 % 4) * stride[5]));
*(target + 30) = *(source + ((30 / 4) * 1) + ((30 % 4) * stride[5]));
*(target + 31) = *(source + ((31 / 4) * 1) + ((31 % 4) * stride[5]));
}
});

return status;
}

Status DepthToSpace::Compute(OpKernelContext* context) const {
const auto* tensor_pointer = context->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
Expand Down Expand Up @@ -199,18 +314,23 @@
onnxruntime::narrow<std::ptrdiff_t>(input_width),
onnxruntime::narrow<std::ptrdiff_t>(blocksize_));
} else if (input.IsDataType<uint8_t>()) {
SpaceDepthOpCpuImpl<uint8_t>(input, output, permutation,
onnxruntime::narrow<std::ptrdiff_t>(batch),
onnxruntime::narrow<std::ptrdiff_t>(dim1),
onnxruntime::narrow<std::ptrdiff_t>(blocksize_),
onnxruntime::narrow<std::ptrdiff_t>(dim3),
onnxruntime::narrow<std::ptrdiff_t>(input_height),
onnxruntime::narrow<std::ptrdiff_t>(input_width),
onnxruntime::narrow<std::ptrdiff_t>(input_depth / blocksize_ / blocksize_),
onnxruntime::narrow<std::ptrdiff_t>(input_height),
onnxruntime::narrow<std::ptrdiff_t>(blocksize_),
onnxruntime::narrow<std::ptrdiff_t>(input_width),
onnxruntime::narrow<std::ptrdiff_t>(blocksize_));
if (is_dcr_ && (blocksize_ == 4) && (input_width % 8 == 0)) {
ORT_RETURN_IF_ERROR(DepthSpaceDCRBlock4Width8XOpCpuImpl<uint8_t>(context, input, output,
batch, input_depth, input_height, input_width));
} else {
SpaceDepthOpCpuImpl<uint8_t>(input, output, permutation,
onnxruntime::narrow<std::ptrdiff_t>(batch),
onnxruntime::narrow<std::ptrdiff_t>(dim1),
onnxruntime::narrow<std::ptrdiff_t>(blocksize_),
onnxruntime::narrow<std::ptrdiff_t>(dim3),
onnxruntime::narrow<std::ptrdiff_t>(input_height),
onnxruntime::narrow<std::ptrdiff_t>(input_width),
onnxruntime::narrow<std::ptrdiff_t>(input_depth / blocksize_ / blocksize_),
onnxruntime::narrow<std::ptrdiff_t>(input_height),
onnxruntime::narrow<std::ptrdiff_t>(blocksize_),
onnxruntime::narrow<std::ptrdiff_t>(input_width),
onnxruntime::narrow<std::ptrdiff_t>(blocksize_));
}
} else {
// user will not see this as the kernel doesn't claim support for types other than float and double
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input type in DepthToSpace op: ", input.DataType());
Expand Down
Loading