diff --git a/include/infini_operators.h b/include/infini_operators.h index 9a5a2555..77b25eb0 100644 --- a/include/infini_operators.h +++ b/include/infini_operators.h @@ -3,10 +3,11 @@ #include "ops/attention/attention.h" #include "ops/avg_pool/avg_pool.h" #include "ops/causal_softmax/causal_softmax.h" -#include "ops/global_avg_pool/global_avg_pool.h" +#include "ops/conv/conv.h" +#include "ops/conv_act/conv_act.h" #include "ops/expand/expand.h" #include "ops/gemm/gemm.h" -#include "ops/conv/conv.h" +#include "ops/global_avg_pool/global_avg_pool.h" #include "ops/matmul/matmul.h" #include "ops/max_pool/max_pool.h" #include "ops/mlp/mlp.h" diff --git a/include/ops/activations.h b/include/ops/activations.h new file mode 100644 index 00000000..6152c216 --- /dev/null +++ b/include/ops/activations.h @@ -0,0 +1,25 @@ +#ifndef __ACTIVATIONS_H__ +#define __ACTIVATIONS_H__ + +/** + * @brief Specifies the type of activation function + */ +typedef enum InfiniActivationMode { + // activation functions + INFINI_ACTIVATION_IDENTITY = 0, + INFINI_ACTIVATION_RELU = 1, + INFINI_ACTIVATION_LEAKY_RELU = 2, + INFINI_ACTIVATION_CLIPPED_RELU = 3, + INFINI_ACTIVATION_SIGMOID = 4, + INFINI_ACTIVATION_HEAVISIDE_STEP = 5, + INFINI_ACTIVATION_ELU = 6, + INFINI_ACTIVATION_GELU = 7, + INFINI_ACTIVATION_SIN = 8, + INFINI_ACTIVATION_TANH = 9, + + // Count + // NOTE: new activation functions should add before "Count" + INFINI_ACTIVATION_COUNT, +} InfiniActivationMode_t; + +#endif diff --git a/include/ops/conv/conv.h b/include/ops/conv/conv.h index 12e1b289..1f023436 100644 --- a/include/ops/conv/conv.h +++ b/include/ops/conv/conv.h @@ -15,14 +15,15 @@ __C __export infiniopStatus_t infiniopCreateConvDescriptor(infiniopHandle_t hand infiniopTensorDescriptor_t y, infiniopTensorDescriptor_t x, infiniopTensorDescriptor_t w, - void *pads, - void *strides, - void *dilations, + infiniopTensorDescriptor_t b, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, uint64_t n); __C __export infiniopStatus_t infiniopGetConvWorkspaceSize(infiniopConvDescriptor_t desc, uint64_t *size); -__C __export infiniopStatus_t infiniopConv(infiniopConvDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void *stream); +__C __export infiniopStatus_t infiniopConv(infiniopConvDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void const *b, void *stream); __C __export infiniopStatus_t infiniopDestroyConvDescriptor(infiniopConvDescriptor_t desc); diff --git a/include/ops/conv_act/conv_act.h b/include/ops/conv_act/conv_act.h new file mode 100644 index 00000000..d8a19f0c --- /dev/null +++ b/include/ops/conv_act/conv_act.h @@ -0,0 +1,55 @@ +#ifndef CONV_ACT_H +#define CONV_ACT_H + +#include "../../export.h" +#include "../../operators.h" +#include "../activations.h" +#include + +typedef struct ConvActParam { + /** + * Used by: + * - INFINI_ACTIVATION_CLIPPED_RELU: as its clipping ceiling + */ + double clip_coef; + /** + * Used by: + * - INFINI_ACTIVATION_LEAKY_RELU: as its slope for x < 0 + * - INFINI_ACTIVATION_ELU: alpha * (exp(x) - 1.) 
for x < 0 + */ + double alpha; + /** + * Used by: + * - INFINI_ACTIVATION_GELU: as its approximation switch + */ + const char *approximate; + +} ConvActParam_t; + +typedef struct ConvActDescriptor { + Device device; +} ConvActDescriptor; + +typedef ConvActDescriptor *infiniopConvActDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateConvActDescriptor(infiniopHandle_t handle, + infiniopConvActDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + infiniopTensorDescriptor_t b, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n, + InfiniActivationMode_t activation_mode, + ConvActParam_t act_params); + +__C __export infiniopStatus_t infiniopGetConvActWorkspaceSize(infiniopConvActDescriptor_t desc, uint64_t *size); + +__C __export infiniopStatus_t infiniopConvAct(infiniopConvActDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void const *b, void *stream); + +__C __export infiniopStatus_t infiniopDestroyConvActDescriptor(infiniopConvActDescriptor_t desc); + + +#endif diff --git a/operatorspy/tests/conv.py b/operatorspy/tests/conv.py index 7e7ea953..0c30b8c7 100644 --- a/operatorspy/tests/conv.py +++ b/operatorspy/tests/conv.py @@ -38,23 +38,26 @@ class ConvDescriptor(Structure): infiniopConvDescriptor_t = POINTER(ConvDescriptor) -def conv(x, w, stride, padding, dilation): - match len(x.shape) - 2: - case 1: - return F.conv1d( - x, w, stride=stride, padding=padding, dilation=dilation - ) - case 2: - return F.conv2d( - x, w, stride=stride, padding=padding, dilation=dilation - ) - case 3: - return F.conv3d( - x, w, stride=stride, padding=padding, dilation=dilation - ) - case _: - print("Error: Pytorch -> Unsupported tensor dimension") - return None +def conv(x, w, b, stride, padding, dilation): + ndim = len(x.shape) - 2 + conv_func_map = { + 1: F.conv1d, + 2: F.conv2d, + 3: F.conv3d + } + + if ndim not in conv_func_map: + print("Error: Pytorch -> Unsupported tensor dimension") + return None + + # Select the appropriate convolution function + conv_func = conv_func_map[ndim] + + if PROFILE: + ans = conv_func(x, w, b, stride=stride, padding=padding, dilation=dilation) + torch.cuda.synchronize() + return ans + return conv_func(x, w, b, stride=stride, padding=padding, dilation=dilation) # infer the shape of the output given the inputs for a N-ary convolution @@ -95,30 +98,33 @@ def test( pads, strides, dilations, - tensor_stride=None, + add_bias, tensor_dtype=torch.float16, ): assert len(pads) == len(strides) == len(dilations) print( - f"Testing Conv on {torch_device} with x_shape: {x_shape}, w_shape: {w_shape}, b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, x_stride: {tensor_stride} dtype:{tensor_dtype}" + f"Testing Conv on {torch_device} with x_shape: {x_shape}, w_shape: {w_shape}, add_bias: {add_bias}, " + f"b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, dtype:{tensor_dtype}" ) x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) w = torch.rand(w_shape, dtype=tensor_dtype).to(torch_device) + b = torch.round((torch.rand(w_shape[0], dtype=tensor_dtype).to(torch_device) * 2 - 1) * 1000) / 1000 if add_bias else None y = torch.zeros( inferShape(x.shape, w.shape, pads, strides, dilations), dtype=tensor_dtype ).to(torch_device) for i in range(NUM_PRERUN if PROFILE else 1): - ans = conv(x, w, strides, pads, dilations) + ans = conv(x, w, b, strides, pads, dilations) if 
PROFILE: start_time = time.time() for i in range(NUM_ITERATIONS): - _ = conv(x, w, strides, pads, dilations) + _ = conv(x, w, b, strides, pads, dilations) elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f"pytorch time: {elapsed :6f}") x_tensor = to_tensor(x, lib) w_tensor = to_tensor(w, lib) + b_tensor = to_tensor(b, lib) if b is not None else None y_tensor = to_tensor(y, lib) descriptor = infiniopConvDescriptor_t() @@ -129,6 +135,7 @@ def test( y_tensor.descriptor, x_tensor.descriptor, w_tensor.descriptor, + b_tensor.descriptor if b_tensor else None, tuple_to_void_p(pads), tuple_to_void_p(strides), tuple_to_void_p(dilations), @@ -157,6 +164,7 @@ def test( y_tensor.data, x_tensor.data, w_tensor.data, + b_tensor.data if b_tensor else None, None, ) ) @@ -171,6 +179,7 @@ def test( y_tensor.data, x_tensor.data, w_tensor.data, + b_tensor.data if b_tensor else None, None, ) ) @@ -187,18 +196,18 @@ def test( def test_cpu(lib, test_cases): device = DeviceEnum.DEVICE_CPU handle = create_handle(lib, device) - for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases: - test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16) - test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32) + for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases: + test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16) + test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32) destroy_handle(lib, handle) def test_cuda(lib, test_cases): device = DeviceEnum.DEVICE_CUDA handle = create_handle(lib, device) - for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases: - test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16) - test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32) + for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases: + test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16) + test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32) destroy_handle(lib, handle) @@ -207,22 +216,30 @@ def test_bang(lib, test_cases): device = DeviceEnum.DEVICE_BANG handle = create_handle(lib, device) - for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases: - test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16) - test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32) + for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases: + test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16) + test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32) destroy_handle(lib, handle) if __name__ == "__main__": test_cases = [ - # x_shape, w_shape, pads, strides, dilations, x_strides + # x_shape, w_shape, pads, strides, dilations, add_bias ( (32, 3, 4), (32, 3, 5), (1,), (1,), (1,), - None, + False, + ), + ( + (3, 7, 4), + (3, 7, 5), + (1,), + (1,), + (1,), + True, ), ( (1, 3, 4, 4), @@ -230,7 +247,7 @@ def test_bang(lib, test_cases): (1, 1), (1, 2), (2, 1), - None, + True, ), ( (32, 3, 128, 128), @@ -238,7 +255,7 @@ def test_bang(lib, 
test_cases): (2, 2), (2, 2), (1, 1), - None, + False, ), ( (1, 1, 4, 4, 4), @@ -246,7 +263,7 @@ def test_bang(lib, test_cases): (1, 1, 1), (1, 1, 1), (1, 1, 1), - None, + True, ), ( (32, 3, 32, 32, 32), @@ -254,7 +271,7 @@ def test_bang(lib, test_cases): (3, 2, 2), (4, 3, 3), (2, 2, 1), - None, + False, ), ] args = get_args() @@ -266,6 +283,7 @@ def test_bang(lib, test_cases): infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, c_void_p, c_void_p, c_void_p, @@ -280,6 +298,7 @@ def test_bang(lib, test_cases): c_void_p, c_void_p, c_void_p, + c_void_p, ] lib.infiniopDestroyConvDescriptor.restype = c_int32 lib.infiniopDestroyConvDescriptor.argtypes = [ diff --git a/operatorspy/tests/conv_act.py b/operatorspy/tests/conv_act.py new file mode 100644 index 00000000..1f62533e --- /dev/null +++ b/operatorspy/tests/conv_act.py @@ -0,0 +1,342 @@ +from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_double +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch +import math +import ctypes +from torch.nn import functional as F +from typing import List, Tuple + +# constant for control whether profile the pytorch and lib functions +# NOTE: need to manually add synchronization function to the lib function, +# e.g., cudaDeviceSynchronize() for CUDA +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +class ConvActDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopConvActDescriptor_t = POINTER(ConvActDescriptor) + + +def convAct(x, w, bias, stride, padding, dilation, mode): + ndim = len(x.shape) - 2 + conv_func_map = { + 1: F.conv1d, + 2: F.conv2d, + 3: F.conv3d + } + activation_func_map = { + 0: lambda x: x, # Identity + 1: F.relu, # ReLU activation + 2: torch.sigmoid # Sigmoid activation + } + + if ndim not in conv_func_map: + print("Error: Pytorch -> Unsupported tensor dimension") + return None + + if mode not in activation_func_map: + print("Error: Unsupported activation mode") + return None + + # Select the appropriate convolution function + conv_func = conv_func_map[ndim] + + if PROFILE: + ans = conv_func(x, w, bias=bias, stride=stride, padding=padding, dilation=dilation) + torch.cuda.synchronize() + return activation_func_map[mode](ans) + + ans = conv_func(x, w, bias=bias, stride=stride, padding=padding, dilation=dilation) + return activation_func_map[mode](ans) + + +# infer the shape of the output given the inputs for a N-ary convolution +def inferShape( + x_shape: List[int], + w_shape: List[int], + pads: List[int], + strides: List[int], + dilations: List[int], +) -> Tuple[int, ...]: + assert ( + len(x_shape) == len(w_shape) == len(pads) + 2 == len(dilations) + 2 == len(strides) + 2 + ), "x and w should have the same length; pads, strides, and dilatinos should have the same length; the length of pads should be that of x - 2" + output_dims = [ + math.floor( + (x_shape[i+2] + 2 * pads[i] - dilations[i] * (w_shape[i+2] - 1) - 1) + / strides[i] + + 1 + ) + for i in range(len(pads)) + ] + return (x_shape[0], w_shape[0]) + tuple(output_dims) + + +# convert a python tuple to a ctype void pointer +def tuple_to_void_p(py_tuple: Tuple): + array = ctypes.c_int64 * len(py_tuple) + 
data_array = array(*py_tuple) + return ctypes.cast(data_array, ctypes.c_void_p) + + +def test( + lib, + handle, + torch_device, + x_shape, + w_shape, + pads, + strides, + dilations, + add_bias, + mode, + clip_coef=0.0, + tensor_dtype=torch.float16, +): + assert len(pads) == len(strides) == len(dilations) + print( + f"Testing ConvAct on {torch_device} with x_shape: {x_shape}, w_shape: {w_shape}, add_bias: {add_bias} b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, mode: {mode}, clip_coef: {clip_coef} dtype:{tensor_dtype}" + ) + x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) + w = torch.rand(w_shape, dtype=tensor_dtype).to(torch_device) + b = torch.round((torch.rand(w_shape[0], dtype=tensor_dtype).to(torch_device) * 2 - 1) * 1000) / 1000 if add_bias else None + y = torch.zeros( + inferShape(x.shape, w.shape, pads, strides, dilations), dtype=tensor_dtype + ).to(torch_device) + + for i in range(NUM_PRERUN if PROFILE else 1): + ans = convAct(x, w, b, strides, pads, dilations, mode) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = convAct(x, w, b, strides, pads, dilations, mode) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :6f}") + + + x_tensor = to_tensor(x, lib) + w_tensor = to_tensor(w, lib) + b_tensor = to_tensor(b, lib) if b is not None else None + y_tensor = to_tensor(y, lib) + descriptor = infiniopConvActDescriptor_t() + + check_error( + lib.infiniopCreateConvActDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + x_tensor.descriptor, + w_tensor.descriptor, + b_tensor.descriptor if b_tensor else None, + tuple_to_void_p(pads), + tuple_to_void_p(strides), + tuple_to_void_p(dilations), + len(pads), + mode, + clip_coef, + ) + ) + workspaceSize = ctypes.c_uint64(0) + check_error( + lib.infiniopGetConvActWorkspaceSize(descriptor, ctypes.byref(workspaceSize)) + ) + workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device) + workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8)) + + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopConvAct( + descriptor, + workspace_ptr, + workspaceSize, + y_tensor.data, + x_tensor.data, + w_tensor.data, + b_tensor.data if b_tensor else None, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + check_error( + lib.infiniopConvAct( + descriptor, + workspace_ptr, + workspaceSize, + y_tensor.data, + x_tensor.data, + w_tensor.data, + b_tensor.data if b_tensor else None, + None, + ) + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f" lib time: {elapsed :6f}") + + if (tensor_dtype == torch.float16): + assert torch.allclose(y, ans, atol=1e-5, rtol=1e-2, equal_nan=True) + else: + assert torch.allclose(y, ans, atol=1e-7, rtol=1e-3, equal_nan=True) + check_error(lib.infiniopDestroyConvActDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, w_shape, pads, strides, dilations, add_bias, mode in test_cases: + test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, add_bias, mode, tensor_dtype=torch.float16) + test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, add_bias, mode, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, w_shape, pads, strides, 
dilations, add_bias, mode in test_cases: + test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, add_bias, mode, tensor_dtype=torch.float16) + test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, add_bias, mode, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + for x_shape, w_shape, pads, strides, dilations, add_bias, mode in test_cases: + test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, add_bias, mode, tensor_dtype=torch.float16) + test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, add_bias, mode, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # x_shape, w_shape, pads, strides, dilations, add_bias, activation_mode + ( + (2, 2, 4), + (2, 2, 2), + (0,), + (1,), + (1,), + True, + 0, + ), + ( + (32, 3, 4), + (32, 3, 5), + (1,), + (1,), + (1,), + False, + 0, + ), + ( + (3, 7, 4), + (7, 7, 2), + (0,), + (1,), + (1,), + False, + 0, + ), + ( + (1, 3, 4, 4), + (2, 3, 3, 3), + (1, 1), + (1, 2), + (2, 1), + True, + 1, + ), + ( + (32, 3, 128, 128), + (64, 3, 5, 5), + (2, 2), + (2, 2), + (1, 1), + True, + 0, + ), + ( + (1, 1, 4, 4, 4), + (1, 1, 5, 5, 5), + (1, 1, 1), + (1, 1, 1), + (1, 1, 1), + False, + 1, + ), + ( + (3, 3, 32, 32, 32), + (6, 3, 5, 5, 5), + (3, 2, 2), + (4, 3, 3), + (2, 2, 1), + True, + 0, + ), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateConvActDescriptor.restype = c_int32 + lib.infiniopCreateConvActDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopConvActDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_uint64, + c_int32, + c_double, + ] + lib.infiniopConvAct.restype = c_int32 + lib.infiniopConvAct.argtypes = [ + infiniopConvActDescriptor_t, + c_void_p, + c_uint64, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyConvActDescriptor.restype = c_int32 + lib.infiniopDestroyConvActDescriptor.argtypes = [ + infiniopConvActDescriptor_t, + ] + + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") diff --git a/src/devices/cpu/common_cpu.cc b/src/devices/cpu/common_cpu.cc index 7fb9e5d8..643cf3ae 100644 --- a/src/devices/cpu/common_cpu.cc +++ b/src/devices/cpu/common_cpu.cc @@ -91,6 +91,7 @@ uint64_t getPaddedSize(uint64_t ndim, uint64_t *shape, uint64_t const *pads) { void getPaddedShape(uint64_t ndim, uint64_t const *shape, uint64_t const *pads, uint64_t *padded_shape) { memcpy(padded_shape, shape, ndim * sizeof(uint64_t)); +#pragma unroll for (size_t i = 2; i < ndim; ++i) { padded_shape[i] += 2 * pads[i - 2]; } diff --git a/src/ops/conv/cpu/conv_cpu.h b/src/ops/conv/cpu/conv_cpu.h deleted file mode 100644 index 48a91990..00000000 --- a/src/ops/conv/cpu/conv_cpu.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef __CPU_CONV_H__ -#define __CPU_CONV_H__ - -#include "../../../devices/cpu/common_cpu.h" -#include "operators.h" -#include -#include -#include - -struct ConvCpuDescriptor { - Device device; - DT dtype; - uint64_t ndim; - uint64_t y_size; - uint64_t padded_x_size; - uint64_t const *x_shape; - uint64_t const *w_shape; - uint64_t const *y_shape; - 
uint64_t const *pads; - int64_t const *strides; - uint64_t const *dilations; -}; - -typedef struct ConvCpuDescriptor *ConvCpuDescriptor_t; - -infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t, - ConvCpuDescriptor_t *, - infiniopTensorDescriptor_t y, - infiniopTensorDescriptor_t x, - infiniopTensorDescriptor_t w, - void const *pads, - void const *strides, - void const *dilations, - uint64_t n); - -infiniopStatus_t cpuGetConvWorkspaceSize(ConvCpuDescriptor_t desc, uint64_t *size); - -infiniopStatus_t cpuConv(ConvCpuDescriptor_t desc, - void *workspace, uint64_t workspace_size, - void *y, void const *x, void const *w, - void *stream); - -infiniopStatus_t cpuDestroyConvDescriptor(ConvCpuDescriptor_t desc); - -#endif diff --git a/src/ops/conv/cuda/conv.cuh b/src/ops/conv/cuda/conv.cuh deleted file mode 100644 index 36f22e90..00000000 --- a/src/ops/conv/cuda/conv.cuh +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef __CUDA_CONV_H__ -#define __CUDA_CONV_H__ - -#include "../../../devices/cuda/common_cuda.h" -#include "../../../devices/cuda/cuda_handle.h" -#include "operators.h" -#include - -struct ConvCudaDescriptor { - Device device; - DT dtype; - int device_id; - std::shared_ptr> cudnn_handles_t; - cudnnTensorDescriptor_t const x_desc; - cudnnFilterDescriptor_t const w_desc; - cudnnTensorDescriptor_t const y_desc; - cudnnConvolutionDescriptor_t const op_desc; - cudnnConvolutionFwdAlgo_t algo; - const float alpha; - const float beta; - uint64_t workspace_size; -}; - -typedef struct ConvCudaDescriptor *ConvCudaDescriptor_t; - -infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t, - ConvCudaDescriptor_t *, - infiniopTensorDescriptor_t y, - infiniopTensorDescriptor_t x, - infiniopTensorDescriptor_t w, - void const *pads, - void const *strides, - void const *dilations, - uint64_t n); - -infiniopStatus_t cudaGetConvWorkspaceSize(ConvCudaDescriptor_t desc, uint64_t *size); - -infiniopStatus_t cudaConv(ConvCudaDescriptor_t desc, - void *workspace, uint64_t workspace_size, - void *y, void const *x, void const *w, - void *stream); - -infiniopStatus_t cudaDestroyConvDescriptor(ConvCudaDescriptor_t desc); - -#endif diff --git a/src/ops/conv/operator.cc b/src/ops/conv/operator.cc index 306527e5..d7e19992 100644 --- a/src/ops/conv/operator.cc +++ b/src/ops/conv/operator.cc @@ -1,96 +1,82 @@ +#include "../conv_base/conv_base.h" #include "../utils.h" -#include "operators.h" #include "ops/conv/conv.h" +#include "ops/conv_act/conv_act.h" -#ifdef ENABLE_CPU -#include "cpu/conv_cpu.h" -#endif -#ifdef ENABLE_NV_GPU -#include "../../devices/cuda/cuda_handle.h" -#include "cuda/conv.cuh" -#endif +struct _ConvDescriptor { + Device device; + infiniopConvBaseDescriptor_t conv_base_desc; + infiniopConvActDescriptor_t conv_act_desc; +}; -__C infiniopStatus_t infiniopCreateConvDescriptor( - infiniopHandle_t handle, - infiniopConvDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t y, - infiniopTensorDescriptor_t x, - infiniopTensorDescriptor_t w, - void *pads, - void *strides, - void *dilations, - uint64_t n) { - switch (handle->device) { -#ifdef ENABLE_CPU - case DevCpu: - return cpuCreateConvDescriptor(handle, (ConvCpuDescriptor_t *) desc_ptr, y, x, w, pads, strides, dilations, n); -#endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: { - return cudaCreateConvDescriptor((CudaHandle_t) handle, (ConvCudaDescriptor_t *) desc_ptr, y, x, w, pads, strides, dilations, n); - } +typedef struct _ConvDescriptor *_ConvDescriptor_t; -#endif -#ifdef ENABLE_CAMBRICON_MLU - // TODO -#endif +__C infiniopStatus_t 
infiniopCreateConvDescriptor(infiniopHandle_t handle, + infiniopConvDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + infiniopTensorDescriptor_t b, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n) { + infiniopConvBaseDescriptor_t conv_base_desc = nullptr; + infiniopConvActDescriptor_t conv_act_desc = nullptr; + if (!b) { + CHECK_STATUS(infiniopCreateConvBaseDescriptor(handle, &conv_base_desc, y, x, w, pads, strides, dilations, n), STATUS_SUCCESS); + } else { + ConvActParam_t act_params; + CHECK_STATUS(infiniopCreateConvActDescriptor(handle, &conv_act_desc, y, x, w, b, pads, strides, dilations, n, INFINI_ACTIVATION_IDENTITY, act_params), STATUS_SUCCESS); } - return STATUS_BAD_DEVICE; + + // create descriptor + *(_ConvDescriptor_t *) desc_ptr = new _ConvDescriptor{ + handle->device, + conv_base_desc, + conv_act_desc, + }; + + return STATUS_SUCCESS; } __C infiniopStatus_t infiniopGetConvWorkspaceSize(infiniopConvDescriptor_t desc, uint64_t *size) { - switch (desc->device) { -#ifdef ENABLE_CPU - case DevCpu: - return cpuGetConvWorkspaceSize((ConvCpuDescriptor_t) desc, size); -#endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: { - return cudaGetConvWorkspaceSize((ConvCudaDescriptor_t) desc, size); - } - -#endif -#ifdef ENABLE_CAMBRICON_MLU - // TODO -#endif + _ConvDescriptor_t _conv_desc = (_ConvDescriptor_t) desc; + if (_conv_desc->conv_base_desc) { + CHECK_STATUS(infiniopGetConvBaseWorkspaceSize(_conv_desc->conv_base_desc, size), STATUS_SUCCESS); + } else { + CHECK_STATUS(infiniopGetConvActWorkspaceSize(_conv_desc->conv_act_desc, size), STATUS_SUCCESS); } - return STATUS_BAD_DEVICE; + return STATUS_SUCCESS; } -__C infiniopStatus_t infiniopConv(infiniopConvDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void *stream) { - switch (desc->device) { -#ifdef ENABLE_CPU - case DevCpu: - return cpuConv((ConvCpuDescriptor_t) desc, workspace, workspace_size, y, x, w, stream); -#endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: { - return cudaConv((ConvCudaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream); +__C infiniopStatus_t infiniopConv(infiniopConvDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *y, + void const *x, + void const *w, + void const *b, + void *stream) { + _ConvDescriptor_t _conv_desc = (_ConvDescriptor_t) desc; + if (_conv_desc->conv_base_desc) { + CHECK_STATUS(infiniopConvBase(_conv_desc->conv_base_desc, workspace, workspace_size, y, x, w, stream), STATUS_SUCCESS); + } else { + if (!b) { + WARN("The bias descriptor has been initialized, but no bias data is provided. 
The computation will proceed as if there is no bias and continue as far as possible."); } - -#endif -#ifdef ENABLE_CAMBRICON_MLU - // TODO -#endif + CHECK_STATUS(infiniopConvAct(_conv_desc->conv_act_desc, workspace, workspace_size, y, x, w, b, stream), STATUS_SUCCESS); } - return STATUS_BAD_DEVICE; + return STATUS_SUCCESS; } __C infiniopStatus_t infiniopDestroyConvDescriptor(infiniopConvDescriptor_t desc) { - switch (desc->device) { -#ifdef ENABLE_CPU - case DevCpu: - return cpuDestroyConvDescriptor((ConvCpuDescriptor_t) desc); -#endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: { - return cudaDestroyConvDescriptor((ConvCudaDescriptor_t) desc); - } - -#endif -#ifdef ENABLE_CAMBRICON_MLU - // TODO -#endif + _ConvDescriptor_t _conv_desc = (_ConvDescriptor_t) desc; + if (_conv_desc->conv_base_desc) { + CHECK_STATUS(infiniopDestroyConvBaseDescriptor(_conv_desc->conv_base_desc), STATUS_SUCCESS); + } else { + CHECK_STATUS(infiniopDestroyConvActDescriptor(_conv_desc->conv_act_desc), STATUS_SUCCESS); } - return STATUS_BAD_DEVICE; + delete desc; + return STATUS_SUCCESS; } diff --git a/src/ops/conv_act/cpu/conv_act_cpu.cc b/src/ops/conv_act/cpu/conv_act_cpu.cc new file mode 100644 index 00000000..3f919b43 --- /dev/null +++ b/src/ops/conv_act/cpu/conv_act_cpu.cc @@ -0,0 +1,349 @@ +#include "conv_act_cpu.h" +#include "../../utils.h" + +// get the total number of elements in arr +inline uint64_t getTotalSize(const uint64_t *arr, uint64_t ndim) { + return std::accumulate(arr, arr + ndim, 1ULL, std::multiplies()); +} + +// check if padding is needed +inline bool requirePadding(uint64_t const *pads, uint64_t ndim) { + return std::any_of(pads, pads + ndim - 2, + [](uint64_t pad) { return pad > 0; }); +} + +// check if bias is needed +template +inline bool requireBias(const Tdata *b, uint64_t length) { + return std::any_of(b, b + length, [](const Tdata &bias) { return bias != 0; }); +} + +template +Tdata relu(const Tdata &x) { + return (x > Tdata(0)) ? 
x : Tdata(0); +} + +template +Tdata sigmoid(const Tdata &x) { + return Tdata(1) / (Tdata(1) + std::exp(-x)); +} + +infiniopStatus_t cpuCreateConvActDescriptor(infiniopHandle_t, + ConvActCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + infiniopTensorDescriptor_t b, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n, + InfiniActivationMode_t activation_mode, + ConvActParam_t act_params) { + uint64_t ndim = y->ndim; + if (ndim < 3 || ndim != x->ndim || ndim != w->ndim || n != ndim - 2) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (x->shape[0] != y->shape[0] || w->shape[0] != y->shape[1] || x->shape[1] != w->shape[1]) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != F16 && y->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (y->dt != x->dt || y->dt != w->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (b) { + if (b->ndim != 1 || b->shape[0] != w->shape[0]) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != b->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + } + // check if the activation_mode is valid + if (activation_mode < 0 || activation_mode >= INFINI_ACTIVATION_COUNT) { + return STATUS_BAD_PARAM; + } + // check if the activation_mode is currently supported by this platform + switch (activation_mode) { + case INFINI_ACTIVATION_IDENTITY: + case INFINI_ACTIVATION_RELU: + case INFINI_ACTIVATION_SIGMOID: + break; + default: + return STATUS_BAD_PARAM; + } + + uint64_t y_size = getTotalSize(y->shape, ndim); + uint64_t padded_x_size = requirePadding(pads, ndim) ? getPaddedSize(ndim, x->shape, pads) : 0; + uint64_t *x_shape = new uint64_t[ndim]; + uint64_t *w_shape = new uint64_t[ndim]; + uint64_t *b_shape = new uint64_t[1]; + uint64_t *y_shape = new uint64_t[ndim]; + uint64_t *pads_ = new uint64_t[n]; + int64_t *strides_ = new int64_t[n]; + uint64_t *dilations_ = new uint64_t[n]; + memcpy(x_shape, x->shape, ndim * sizeof(*x->shape)); + memcpy(w_shape, w->shape, ndim * sizeof(*w->shape)); + memcpy(b_shape, &(w->shape[0]), sizeof(*w->shape)); + memcpy(y_shape, y->shape, ndim * sizeof(*y->shape)); + memcpy(pads_, pads, n * sizeof(*pads)); + memcpy(strides_, strides, n * sizeof(*strides)); + memcpy(dilations_, dilations, n * sizeof(*dilations)); + + uint64_t *padded_shape = nullptr; + + if (padded_x_size > 0) { + padded_shape = new uint64_t[ndim]; + getPaddedShape(ndim, x_shape, pads, padded_shape); + } + + *desc_ptr = new ConvActCpuDescriptor{ + DevCpu, + y->dt, + ndim, + y_size, + padded_x_size, + padded_shape, + x_shape, + w_shape, + b_shape, + y_shape, + pads_, + strides_, + dilations_, + activation_mode, + b == nullptr, + act_params, + }; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuGetConvActWorkspaceSize(ConvActCpuDescriptor_t desc, uint64_t *size) { + *size = desc->padded_x_size * desc->dtype.size; + if (desc->dtype == F16) { + *size += desc->y_size * sizeof(float); + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuDestroyConvActDescriptor(ConvActCpuDescriptor_t desc) { + delete[] desc->x_shape; + delete[] desc->w_shape; + delete[] desc->b_shape; + delete[] desc->y_shape; + delete[] desc->pads; + delete[] desc->strides; + delete[] desc->dilations; + delete[] desc->padded_shape; + delete desc; + return STATUS_SUCCESS; +} + +// initialize the padded input with the data from the original input +template +void fillPaddedInput(ConvActCpuDescriptor_t desc, uint64_t const *padded_x_shape, + Tdata *padded_x, Tdata const *x, + uint64_t const *pads, uint64_t 
x_index, + uint64_t padded_x_index, uint64_t ndim) { + const auto x_shape = desc->x_shape[ndim]; + const auto padded_x_shape_ = padded_x_shape[ndim]; + const auto x_base_index = x_index * x_shape; + const auto padded_x_base_index = padded_x_index * padded_x_shape_ + + (x_shape == padded_x_shape_ ? 0 : pads[ndim - 2]); +#pragma omp parallel for + for (size_t i = 0; i < x_shape; ++i) { + // base case (last dimension) + if (ndim == desc->ndim - 1) { + padded_x[padded_x_base_index + i] = x[x_base_index + i]; + } + // recursive case + else { + fillPaddedInput(desc, padded_x_shape, padded_x, x, pads, x_base_index + i, + padded_x_base_index + i, ndim + 1); + } + } +} + +// Recursive convolution function +template +void _applyConvAct(ConvActCpuDescriptor_t desc, Ydata *y, Xdata const *x, + Xdata const *w, uint64_t const *x_shape, + uint64_t x_index, uint64_t w_index, uint64_t y_index, + uint64_t ndim) { + const auto dim_size = x_shape[ndim]; + const auto kernel_size = desc->w_shape[ndim]; + const auto dilation = desc->dilations[ndim - 2]; + const auto stride = desc->strides[ndim - 2]; + const auto steps = + (dim_size - dilation * (kernel_size - 1) - 1) / stride + 1; + x_index *= dim_size; + w_index *= kernel_size; + y_index *= desc->y_shape[ndim]; + + // perform all the convolutions along this axis + for (size_t i = 0; i < steps; ++i, ++y_index) { +// perform a single convolution +#pragma unroll + for (size_t k = 0; k < kernel_size; ++k) { + // calculate the current indices + const auto curr_x_index = x_index + i * stride + k * dilation; + const auto curr_w_index = w_index + k; + + // base case (last dimension) + if (ndim == desc->ndim - 1) { + if constexpr (std::is_same_v) { + y[y_index] += f16_to_f32(x[curr_x_index]) * f16_to_f32(w[curr_w_index]); + } else { + y[y_index] += x[curr_x_index] * w[curr_w_index]; + } + } + // recursive case + else { + _applyConvAct(desc, y, x, w, x_shape, curr_x_index, curr_w_index, + y_index, ndim + 1); + } + } + } +} + +// add bias b to the output y +template +void addBias(Ydata *y, const Xdata *b, uint64_t batch_size, uint64_t out_channel_size, uint64_t in_channel_size, uint64_t num_channel_elements) { +#pragma omp parallel for collapse(2) + // batch + for (size_t i = 0; i < batch_size; ++i) { + // output channel + for (size_t j = 0; j < out_channel_size; ++j) { + uint64_t y_index = (i * in_channel_size + j) * num_channel_elements; + + // Add the bias to the output channel for the current batch + for (size_t yi = 0; yi < num_channel_elements; ++yi) { + if constexpr (std::is_same_v && std::is_same_v) { + y[y_index + yi] += f16_to_f32(b[j]); + } else { + y[y_index + yi] += b[j]; + } + } + } + } +} + +// apply activation function given the mode on the array arr with length n +template +void applyActivation(Tdata *arr, uint64_t n, InfiniActivationMode_t mode) { + if (mode != INFINI_ACTIVATION_IDENTITY) { + std::function activation_fn = [](Tdata &value) { value = value; }; + + switch (mode) { + case INFINI_ACTIVATION_RELU: + activation_fn = [](Tdata &value) { value = relu(value); }; + break; + case INFINI_ACTIVATION_SIGMOID: + activation_fn = [](Tdata &value) { value = sigmoid(value); }; + break; + default: + break; + } + +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + activation_fn(arr[i]); + } + } +} + +template +void applyConvAct(ConvActCpuDescriptor_t desc, Ydata *y, Xdata const *x, + Xdata const *w, Xdata const *b, uint64_t const *x_shape) { + const auto y_num_channel_elements = + getTotalSize(desc->y_shape + 2, desc->ndim - 2); + bool biasRequired 
= b && !desc->bias_is_optional && requireBias(b, desc->b_shape[0]); + +#pragma omp parallel for collapse(2) schedule(dynamic) + // batch + for (size_t i = 0; i < x_shape[0]; ++i) { + + // output channel + for (size_t j = 0; j < desc->w_shape[0]; ++j) { + uint64_t y_index = i * desc->y_shape[1] + j; + + // input channel + for (size_t k = 0; k < x_shape[1]; ++k) { + uint64_t x_index = i * x_shape[1] + k; + uint64_t w_index = j * desc->w_shape[1] + k; + _applyConvAct(desc, y, x, w, x_shape, x_index, w_index, y_index, 2); + } + } + } + + if (biasRequired) { + addBias(y, b, x_shape[0], desc->w_shape[0], desc->y_shape[1], y_num_channel_elements); + } + + // apply activation function + applyActivation(y, desc->y_size, desc->mode); +} + +template +void _conv_bias_act_cpu(ConvActCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, + Ydata *y, Xdata const *x, Xdata const *w, Xdata const *b) { + if (desc->padded_x_size > 0) { + auto padded_x = reinterpret_cast(workspace); + std::fill(padded_x, padded_x + desc->padded_x_size, 0); + fillPaddedInput(desc, desc->padded_shape, padded_x, x, desc->pads, 0, 0, 0); + applyConvAct(desc, y, padded_x, w, b, desc->padded_shape); + } else { + applyConvAct(desc, y, x, w, b, desc->x_shape); + } +} + +// Convolution function +template +infiniopStatus_t conv_bias_act_cpu(ConvActCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, void const *b) { + auto y_ = reinterpret_cast(y); + auto x_ = reinterpret_cast(x); + auto w_ = reinterpret_cast(w); + auto b_ = reinterpret_cast(b); + std::fill(y_, y_ + desc->y_size, 0); + _conv_bias_act_cpu(desc, workspace, workspace_size, y_, x_, w_, b_); + return STATUS_SUCCESS; +} + +// sepcial case for fp16 (uint16_t) +template<> +infiniopStatus_t conv_bias_act_cpu(ConvActCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, void const *b) { + auto y_ = reinterpret_cast(workspace); + auto x_ = reinterpret_cast(x); + auto w_ = reinterpret_cast(w); + auto b_ = reinterpret_cast(b); + std::fill(y_, y_ + desc->y_size, 0); + + _conv_bias_act_cpu(desc, y_ + desc->y_size, workspace_size, y_, x_, w_, b_); + + // copy data from y_ to y + auto y_16 = reinterpret_cast(y); +#pragma omp parallel for + for (size_t i = 0; i < desc->y_size; ++i) { + y_16[i] = f32_to_f16(y_[i]); + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuConvAct(ConvActCpuDescriptor_t desc, + void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, + void const *b, void *stream) { + if (desc->dtype == F16) { + return conv_bias_act_cpu(desc, workspace, workspace_size, y, x, w, b); + } + if (desc->dtype == F32) { + return conv_bias_act_cpu(desc, workspace, workspace_size, y, x, w, b); + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/conv_act/cpu/conv_act_cpu.h b/src/ops/conv_act/cpu/conv_act_cpu.h new file mode 100644 index 00000000..0116b57b --- /dev/null +++ b/src/ops/conv_act/cpu/conv_act_cpu.h @@ -0,0 +1,56 @@ +#ifndef __CPU_CONV_ACT_H__ +#define __CPU_CONV_ACT_H__ + +#include "../../../devices/cpu/common_cpu.h" +#include "operators.h" +#include "ops/conv_act/conv_act.h" +#include +#include +#include +#include +#include + +struct ConvActCpuDescriptor { + Device device; + DT dtype; + uint64_t ndim; + uint64_t y_size; + uint64_t padded_x_size; + uint64_t const *padded_shape; + uint64_t const *x_shape; + uint64_t const *w_shape; + uint64_t const *b_shape; + uint64_t const *y_shape; + uint64_t const *pads; + int64_t const 
*strides; + uint64_t const *dilations; + InfiniActivationMode_t mode; + bool bias_is_optional; + ConvActParam_t act_params; +}; + +typedef struct ConvActCpuDescriptor *ConvActCpuDescriptor_t; + +infiniopStatus_t cpuCreateConvActDescriptor(infiniopHandle_t, + ConvActCpuDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + infiniopTensorDescriptor_t b, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n, + InfiniActivationMode_t activation_mode, + ConvActParam_t act_params); + +infiniopStatus_t cpuGetConvActWorkspaceSize(ConvActCpuDescriptor_t desc, uint64_t *size); + +infiniopStatus_t cpuConvAct(ConvActCpuDescriptor_t desc, + void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, + void const *b, void *stream); + +infiniopStatus_t cpuDestroyConvActDescriptor(ConvActCpuDescriptor_t desc); + +#endif diff --git a/src/ops/conv_act/cuda/conv_act.cc b/src/ops/conv_act/cuda/conv_act.cc new file mode 100644 index 00000000..88656b33 --- /dev/null +++ b/src/ops/conv_act/cuda/conv_act.cc @@ -0,0 +1,214 @@ +#include "conv_act.cuh" +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" + +infiniopStatus_t cudaCreateConvActDescriptor(CudaHandle_t handle, + ConvActCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + infiniopTensorDescriptor_t b, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n, + InfiniActivationMode_t activation_mode, + ConvActParam_t act_params) { + uint64_t ndim = y->ndim; + if (ndim < 3 || ndim != x->ndim || ndim != w->ndim || n != ndim - 2) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (x->shape[0] != y->shape[0] || w->shape[0] != y->shape[1] || x->shape[1] != w->shape[1]) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != F16 && y->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (y->dt != x->dt || y->dt != w->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (b) { + if (b->ndim != 1 || b->shape[0] != w->shape[0]) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != b->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + } + // check if the activation_mode is valid + if (activation_mode < 0 || activation_mode >= INFINI_ACTIVATION_COUNT) { + return STATUS_BAD_PARAM; + } + cudnnActivationMode_t act_mode = [&] { + switch (activation_mode) { + case INFINI_ACTIVATION_IDENTITY: + return CUDNN_ACTIVATION_IDENTITY; + case INFINI_ACTIVATION_RELU: + return CUDNN_ACTIVATION_RELU; + default: + return CUDNN_ACTIVATION_SIGMOID; + } + }(); + // cudnnConvolutionBiasActivationForward() currently only supports identity and relu activations + if (act_mode != CUDNN_ACTIVATION_IDENTITY && act_mode != CUDNN_ACTIVATION_RELU) { + return STATUS_BAD_PARAM; + } + + const auto new_ndim = std::max(4UL, ndim); + const auto new_n = std::max(2UL, n); + // convert pads, strides, dilations into int32[] + int32_t pad[new_n]; + int32_t stride[new_n]; + int32_t dilation[new_n]; + int32_t x_shape[new_ndim]; + int32_t w_shape[new_ndim]; + int32_t b_shape[new_ndim]; + int32_t y_shape[new_ndim]; +#pragma unroll + for (size_t i = 0; i < new_n; ++i) { + pad[i] = i < n ? static_cast(pads[i]) : 0; + stride[i] = i < n ? static_cast(strides[i]) : 1; + dilation[i] = i < n ? static_cast(dilations[i]) : 1; + } +#pragma unroll + for (size_t i = 0; i < new_ndim; ++i) { + x_shape[i] = i < ndim ? static_cast(x->shape[i]) : 1; + w_shape[i] = i < ndim ? 
static_cast(w->shape[i]) : 1; + b_shape[i] = i == 1 ? static_cast(w->shape[0]) : 1; + y_shape[i] = i < ndim ? static_cast(y->shape[i]) : 1; + } + + // get the data types of the tensors and the conv operator + CREATE_CHECK_ERROR(auto tensor_dt = dataTypeMap[x->dt], tensor_dt, -1, STATUS_BAD_PARAM); + cudnnDataType_t conv_op_dt = [&] { + switch (tensor_dt) { + case CUDNN_DATA_HALF: + case CUDNN_DATA_BFLOAT16: + case CUDNN_DATA_FLOAT: + return CUDNN_DATA_FLOAT; + case CUDNN_DATA_DOUBLE: + return CUDNN_DATA_DOUBLE; + default: + return CUDNN_DATA_INT32; + } + }(); + + // create and set tensor descriptors for x + cudnnTensorDescriptor_t x_desc; + checkCudnnError(cudnnCreateTensorDescriptor(&x_desc)); + checkCudnnError(cudnnSetTensorNdDescriptorEx(x_desc, CUDNN_TENSOR_NCHW, static_cast(tensor_dt), new_ndim, x_shape)); + + // create and set tensor descriptors for w + cudnnFilterDescriptor_t w_desc; + checkCudnnError(cudnnCreateFilterDescriptor(&w_desc)); + checkCudnnError(cudnnSetFilterNdDescriptor(w_desc, static_cast(tensor_dt), CUDNN_TENSOR_NCHW, new_ndim, w_shape)); + + // create and set conv operator descriptor + cudnnConvolutionDescriptor_t op_desc; + checkCudnnError(cudnnCreateConvolutionDescriptor(&op_desc)); + checkCudnnError(cudnnSetConvolutionNdDescriptor( + op_desc, new_ndim - 2, pad, stride, dilation, CUDNN_CROSS_CORRELATION, + conv_op_dt)); + + cudnnSetConvolutionMathType(op_desc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); + + // create and set tensor descriptors for y + cudnnTensorDescriptor_t y_desc; + int outDim[new_ndim]; + checkCudnnError(cudnnGetConvolutionNdForwardOutputDim(op_desc, x_desc, w_desc, new_ndim, outDim)); + checkCudnnError(cudnnCreateTensorDescriptor(&y_desc)); + checkCudnnError(cudnnSetTensorNdDescriptorEx(y_desc, CUDNN_TENSOR_NCHW, static_cast(tensor_dt), new_ndim, y_shape)); + + // create the activation descriptor + cudnnActivationDescriptor_t act_desc; + checkCudnnError(cudnnCreateActivationDescriptor(&act_desc)); + checkCudnnError(cudnnSetActivationDescriptor(act_desc, act_mode, CUDNN_NOT_PROPAGATE_NAN, act_params.clip_coef)); + + // create the bias descriptor + cudnnTensorDescriptor_t b_desc; + checkCudnnError(cudnnCreateTensorDescriptor(&b_desc)); + checkCudnnError(cudnnSetTensorNdDescriptorEx(b_desc, CUDNN_TENSOR_NCHW, static_cast(tensor_dt), new_ndim, b_shape)); + + // get the best algorithm and the required workspace + cudnnConvolutionFwdAlgo_t algo; + size_t workspace_size = 0; + + if (act_mode == CUDNN_ACTIVATION_IDENTITY) { + algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + checkCudnnError(use_cudnn(handle->cudnn_handles_t, handle->device_id, nullptr, + [&](cudnnHandle_t handle) { return cudnnGetConvolutionForwardWorkspaceSize(handle, x_desc, w_desc, op_desc, y_desc, algo, &workspace_size); })); + } else {// tuning + int requestedAlgoCount = 1; + checkCudnnError(use_cudnn(handle->cudnn_handles_t, handle->device_id, nullptr, + [&](cudnnHandle_t handle) { return cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &requestedAlgoCount); })); + int algoCounts = 0; + int chosenAlgoIndex = 0; + bool chosen = false; + + cudnnConvolutionFwdAlgoPerf_t perf_results[requestedAlgoCount]; + checkCudnnError(use_cudnn(handle->cudnn_handles_t, handle->device_id, nullptr, + [&](cudnnHandle_t handle) { return cudnnFindConvolutionForwardAlgorithm(handle, x_desc, w_desc, op_desc, y_desc, requestedAlgoCount, &algoCounts, perf_results); })); + if (algoCounts < 1) { + return STATUS_EXECUTION_FAILED; + } + for (int i = 0; i < algoCounts; ++i) { + if 
(use_cudnn(handle->cudnn_handles_t, handle->device_id, nullptr, + [&](cudnnHandle_t handle) { return cudnnGetConvolutionForwardWorkspaceSize(handle, x_desc, w_desc, op_desc, y_desc, perf_results[i].algo, &workspace_size); }) == CUDNN_STATUS_SUCCESS) { + chosenAlgoIndex = i; + chosen = true; + break; + } + } + if (!chosen) { + return STATUS_EXECUTION_FAILED; + } + algo = perf_results[chosenAlgoIndex].algo; + } + + // if bias is not given, add the workspace size needed by the optional bias + uint64_t bias_size = 0; + if (!b) { + bias_size = w->shape[0] * w->dt.size; + workspace_size += bias_size; + } + + const float alpha = 1.0f; + const float beta = 0.0f; + + *desc_ptr = new ConvActCudaDescriptor{ + DevNvGpu, + y->dt, + handle->device_id, + handle->cudnn_handles_t, + x_desc, + w_desc, + b_desc, + y_desc, + op_desc, + act_desc, + algo, + alpha, + beta, + workspace_size, + bias_size, + }; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaGetConvActWorkspaceSize(ConvActCudaDescriptor_t desc, uint64_t *size) { + *size = desc->workspace_size; + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroyConvActDescriptor(ConvActCudaDescriptor_t desc) { + checkCudnnError(cudnnDestroyActivationDescriptor(desc->act_desc)); + checkCudnnError(cudnnDestroyConvolutionDescriptor(desc->op_desc)); + checkCudnnError(cudnnDestroyTensorDescriptor(desc->y_desc)); + checkCudnnError(cudnnDestroyTensorDescriptor(desc->b_desc)); + checkCudnnError(cudnnDestroyFilterDescriptor(desc->w_desc)); + checkCudnnError(cudnnDestroyTensorDescriptor(desc->x_desc)); + desc->cudnn_handles_t = nullptr; + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/conv_act/cuda/conv_act.cu b/src/ops/conv_act/cuda/conv_act.cu new file mode 100644 index 00000000..f27162b6 --- /dev/null +++ b/src/ops/conv_act/cuda/conv_act.cu @@ -0,0 +1,32 @@ +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +#include "conv_act.cuh" + +infiniopStatus_t conv_bias_act_nv_gpu(ConvActCudaDescriptor_t desc, void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, void const *b, void *stream) { + checkCudaError(cudaSetDevice(desc->device_id)); + void const *b_ = b; + if (!b || desc->bias_size != 0) { + b_ = reinterpret_cast(reinterpret_cast(workspace) + desc->workspace_size - desc->bias_size); + checkCudaErrorWithCode(cudaMemset((void *) b_, 0, desc->bias_size), STATUS_EXECUTION_FAILED); + } + void *workspace_ = (desc->bias_size == 0 || desc->workspace_size > desc->bias_size) ? 
workspace : nullptr; + checkCudnnError(use_cudnn(desc->cudnn_handles_t, desc->device_id, (cudaStream_t) stream, + [&](cudnnHandle_t handle) { return cudnnConvolutionBiasActivationForward(handle, &desc->alpha, + desc->x_desc, x, desc->w_desc, w, desc->op_desc, desc->algo, workspace_, workspace_size - desc->bias_size, + &desc->beta, desc->y_desc, y, desc->b_desc, b_, desc->act_desc, desc->y_desc, y); })); + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaConvAct(ConvActCudaDescriptor_t desc, + void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, + void const *b, void *stream) { + if (workspace_size < desc->workspace_size) { + return STATUS_INSUFFICIENT_WORKSPACE; + } + if (desc->dtype == F16 || desc->dtype == F32) { + return conv_bias_act_nv_gpu(desc, workspace, workspace_size, y, x, w, b, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/conv_act/cuda/conv_act.cuh b/src/ops/conv_act/cuda/conv_act.cuh new file mode 100644 index 00000000..d8d55421 --- /dev/null +++ b/src/ops/conv_act/cuda/conv_act.cuh @@ -0,0 +1,53 @@ +#ifndef __CUDA_CONV_ACT_H__ +#define __CUDA_CONV_ACT_H__ + +#include "../../../devices/cuda/common_cuda.h" +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include "ops/conv_act/conv_act.h" +#include +#include + +struct ConvActCudaDescriptor { + Device device; + DT dtype; + int device_id; + std::shared_ptr> cudnn_handles_t; + cudnnTensorDescriptor_t const x_desc; + cudnnFilterDescriptor_t const w_desc; + cudnnTensorDescriptor_t const b_desc; + cudnnTensorDescriptor_t const y_desc; + cudnnConvolutionDescriptor_t const op_desc; + cudnnActivationDescriptor_t const act_desc; + cudnnConvolutionFwdAlgo_t algo; + const float alpha; + const float beta; + uint64_t workspace_size; + uint64_t bias_size; +}; + +typedef struct ConvActCudaDescriptor *ConvActCudaDescriptor_t; + +infiniopStatus_t cudaCreateConvActDescriptor(CudaHandle_t, + ConvActCudaDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + infiniopTensorDescriptor_t b, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n, + InfiniActivationMode_t activation_mode, + ConvActParam_t act_params); + +infiniopStatus_t cudaGetConvActWorkspaceSize(ConvActCudaDescriptor_t desc, uint64_t *size); + +infiniopStatus_t cudaConvAct(ConvActCudaDescriptor_t desc, + void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, + void const *b, void *stream); + +infiniopStatus_t cudaDestroyConvActDescriptor(ConvActCudaDescriptor_t desc); + +#endif diff --git a/src/ops/conv_act/operator.cc b/src/ops/conv_act/operator.cc new file mode 100644 index 00000000..df17b83f --- /dev/null +++ b/src/ops/conv_act/operator.cc @@ -0,0 +1,99 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/conv_act/conv_act.h" + +#ifdef ENABLE_CPU +#include "cpu/conv_act_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/conv_act.cuh" +#endif + +__C infiniopStatus_t infiniopCreateConvActDescriptor( + infiniopHandle_t handle, + infiniopConvActDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + infiniopTensorDescriptor_t b, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n, + InfiniActivationMode_t activation_mode, + ConvActParam_t act_params) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return 
cpuCreateConvActDescriptor(handle, (ConvActCpuDescriptor_t *) desc_ptr, y, x, w, b, pads, strides, dilations, n, activation_mode, act_params); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateConvActDescriptor((CudaHandle_t) handle, (ConvActCudaDescriptor_t *) desc_ptr, y, x, w, b, pads, strides, dilations, n, activation_mode, act_params); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopGetConvActWorkspaceSize(infiniopConvActDescriptor_t desc, uint64_t *size) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuGetConvActWorkspaceSize((ConvActCpuDescriptor_t) desc, size); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaGetConvActWorkspaceSize((ConvActCudaDescriptor_t) desc, size); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopConvAct(infiniopConvActDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void const *b, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuConvAct((ConvActCpuDescriptor_t) desc, workspace, workspace_size, y, x, w, b, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaConvAct((ConvActCudaDescriptor_t) desc, workspace, workspace_size, y, x, w, b, stream); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyConvActDescriptor(infiniopConvActDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyConvActDescriptor((ConvActCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroyConvActDescriptor((ConvActCudaDescriptor_t) desc); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} diff --git a/src/ops/conv_base/conv_base.h b/src/ops/conv_base/conv_base.h new file mode 100644 index 00000000..aef6fbc4 --- /dev/null +++ b/src/ops/conv_base/conv_base.h @@ -0,0 +1,30 @@ +#ifndef CONV_BASE_H +#define CONV_BASE_H + +#include "export.h" +#include "operators.h" + +typedef struct ConvBaseDescriptor { + Device device; +} ConvBaseDescriptor; + +typedef ConvBaseDescriptor *infiniopConvBaseDescriptor_t; + +__C infiniopStatus_t infiniopCreateConvBaseDescriptor(infiniopHandle_t handle, + infiniopConvBaseDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n); + +__C infiniopStatus_t infiniopGetConvBaseWorkspaceSize(infiniopConvBaseDescriptor_t desc, uint64_t *size); + +__C infiniopStatus_t infiniopConvBase(infiniopConvBaseDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void *stream); + +__C infiniopStatus_t infiniopDestroyConvBaseDescriptor(infiniopConvBaseDescriptor_t desc); + + +#endif diff --git a/src/ops/conv/cpu/conv_cpu.cc b/src/ops/conv_base/cpu/conv_base_cpu.cc similarity index 73% rename from src/ops/conv/cpu/conv_cpu.cc rename to src/ops/conv_base/cpu/conv_base_cpu.cc index 2646c482..f05e5d11 100644 --- a/src/ops/conv/cpu/conv_cpu.cc +++ b/src/ops/conv_base/cpu/conv_base_cpu.cc @@ -1,4 +1,4 @@ -#include "conv_cpu.h" +#include "conv_base_cpu.h" #include "../../utils.h" // get the total number of elements in arr @@ -12,15 +12,15 @@ inline bool 
requirePadding(uint64_t const *pads, uint64_t ndim) { [](uint64_t pad) { return pad > 0; }); } -infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t, - ConvCpuDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t y, - infiniopTensorDescriptor_t x, - infiniopTensorDescriptor_t w, - void const *pads, - void const *strides, - void const *dilations, - uint64_t n) { +infiniopStatus_t cpuCreateConvBaseDescriptor(infiniopHandle_t, + ConvBaseCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n) { uint64_t ndim = y->ndim; if (ndim < 3 || ndim != x->ndim || ndim != w->ndim) { return STATUS_BAD_TENSOR_SHAPE; @@ -36,33 +36,37 @@ infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t, } uint64_t y_size = getTotalSize(y->shape, ndim); - const auto pads_ = reinterpret_cast(pads); - uint64_t padded_x_size = requirePadding(pads_, ndim) ? getPaddedSize(ndim, x->shape, pads_) : 0; + uint64_t padded_x_size = requirePadding(pads, ndim) ? getPaddedSize(ndim, x->shape, pads) : 0; uint64_t *x_shape = new uint64_t[ndim]; uint64_t *w_shape = new uint64_t[ndim]; uint64_t *y_shape = new uint64_t[ndim]; - uint64_t *pad_ = new uint64_t[n]; + uint64_t *pads_ = new uint64_t[n]; int64_t *strides_ = new int64_t[n]; uint64_t *dilations_ = new uint64_t[n]; memcpy(x_shape, x->shape, ndim * sizeof(uint64_t)); memcpy(w_shape, w->shape, ndim * sizeof(uint64_t)); memcpy(y_shape, y->shape, ndim * sizeof(uint64_t)); - for (size_t i = 0; i < n; ++i) { - pad_[i] = pads_[i]; - strides_[i] = reinterpret_cast(strides)[i]; - dilations_[i] = reinterpret_cast(dilations)[i]; + memcpy(pads_, pads, n * sizeof(*pads)); + memcpy(strides_, strides, n * sizeof(*strides)); + memcpy(dilations_, dilations, n * sizeof(*dilations)); + + uint64_t *padded_shape = nullptr; + if (padded_x_size > 0) { + padded_shape = new uint64_t[ndim]; + getPaddedShape(ndim, x_shape, pads_, padded_shape); } - *desc_ptr = new ConvCpuDescriptor{ + *desc_ptr = new ConvBaseCpuDescriptor{ DevCpu, y->dt, ndim, y_size, padded_x_size, + padded_shape, x_shape, w_shape, y_shape, - pad_, + pads_, strides_, dilations_, }; @@ -70,7 +74,7 @@ infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t, return STATUS_SUCCESS; } -infiniopStatus_t cpuGetConvWorkspaceSize(ConvCpuDescriptor_t desc, uint64_t *size) { +infiniopStatus_t cpuGetConvBaseWorkspaceSize(ConvBaseCpuDescriptor_t desc, uint64_t *size) { *size = desc->padded_x_size * desc->dtype.size; if (desc->dtype == F16) { *size += desc->y_size * sizeof(float); @@ -78,7 +82,8 @@ infiniopStatus_t cpuGetConvWorkspaceSize(ConvCpuDescriptor_t desc, uint64_t *siz return STATUS_SUCCESS; } -infiniopStatus_t cpuDestroyConvDescriptor(ConvCpuDescriptor_t desc) { +infiniopStatus_t cpuDestroyConvBaseDescriptor(ConvBaseCpuDescriptor_t desc) { + delete[] desc->padded_shape; delete[] desc->x_shape; delete[] desc->w_shape; delete[] desc->y_shape; @@ -91,7 +96,7 @@ infiniopStatus_t cpuDestroyConvDescriptor(ConvCpuDescriptor_t desc) { // initialize the padded input with the data from the original input template -void fillPaddedInput(ConvCpuDescriptor_t desc, uint64_t const *padded_x_shape, +void fillPaddedInput(ConvBaseCpuDescriptor_t desc, uint64_t const *padded_x_shape, Tdata *padded_x, Tdata const *x, uint64_t const *pads, uint64_t x_index, uint64_t padded_x_index, uint64_t ndim) { @@ -116,7 +121,7 @@ void fillPaddedInput(ConvCpuDescriptor_t desc, uint64_t const *padded_x_shape, // 
Recursive convolution function template<typename Xdata, typename Ydata> -void _applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x, +void _applyConv(ConvBaseCpuDescriptor_t desc, Ydata *y, Xdata const *x, Xdata const *w, uint64_t const *x_shape, uint64_t x_index, uint64_t w_index, uint64_t y_index, uint64_t ndim) { @@ -132,6 +137,7 @@ void _applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x, // perform all the convolutions along this axis for (size_t i = 0; i < steps; ++i, ++y_index) { +#pragma unroll // perform a single convolution for (size_t k = 0; k < kernel_size; ++k) { // calculate the current indices @@ -140,7 +146,7 @@ void _applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x, // base case (last dimension) if (ndim == desc->ndim - 1) { - if (desc->dtype == F16) { + if constexpr (std::is_same_v<Xdata, uint16_t>) { y[y_index] += f16_to_f32(x[curr_x_index]) * f16_to_f32(w[curr_w_index]); } else { y[y_index] += x[curr_x_index] * w[curr_w_index]; @@ -156,12 +162,12 @@ void _applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x, } template<typename Xdata, typename Ydata> -void applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x, +void applyConv(ConvBaseCpuDescriptor_t desc, Ydata *y, Xdata const *x, Xdata const *w, uint64_t const *x_shape) { const auto y_num_channel_elements = getTotalSize(desc->y_shape + 2, desc->ndim - 2); -#pragma omp parallel for collapse(2) +#pragma omp parallel for collapse(2) schedule(dynamic) // batch for (size_t i = 0; i < x_shape[0]; ++i) { @@ -180,16 +186,13 @@ void applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x, } template<typename Xdata, typename Ydata> -void _conv_cpu(ConvCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, +void _conv_cpu(ConvBaseCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, Ydata *y, Xdata const *x, Xdata const *w) { if (desc->padded_x_size > 0) { auto padded_x = reinterpret_cast<Xdata *>(workspace); - std::vector<uint64_t> padded_shape_(desc->ndim); - auto padded_shape = padded_shape_.data(); std::fill(padded_x, padded_x + desc->padded_x_size, 0); - getPaddedShape(desc->ndim, desc->x_shape, desc->pads, padded_shape); - fillPaddedInput(desc, padded_shape, padded_x, x, desc->pads, 0, 0, 0); - applyConv(desc, y, padded_x, w, padded_shape); + fillPaddedInput(desc, desc->padded_shape, padded_x, x, desc->pads, 0, 0, 0); + applyConv(desc, y, padded_x, w, desc->padded_shape); } else { applyConv(desc, y, x, w, desc->x_shape); } @@ -197,7 +200,7 @@ void _conv_cpu(ConvCpuDescriptor_t desc, void *workspace, uint64_t workspace_siz // Convolution function template<typename Xdata, typename Ydata> -infiniopStatus_t conv_cpu(ConvCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, +infiniopStatus_t conv_cpu(ConvBaseCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w) { auto y_ = reinterpret_cast<Ydata *>(y); auto x_ = reinterpret_cast<Xdata const *>(x); @@ -209,7 +212,7 @@ infiniopStatus_t conv_cpu(ConvCpuDescriptor_t desc, void *workspace, uint64_t wo // special case for fp16 (uint16_t) template<> -infiniopStatus_t conv_cpu(ConvCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, +infiniopStatus_t conv_cpu(ConvBaseCpuDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w) { auto y_ = reinterpret_cast<float *>(workspace); auto x_ = reinterpret_cast<uint16_t const *>(x); @@ -227,10 +230,10 @@ infiniopStatus_t conv_cpu(ConvCpuDescriptor_t desc, void *workspace, u return STATUS_SUCCESS; } -infiniopStatus_t cpuConv(ConvCpuDescriptor_t desc, - void *workspace, uint64_t workspace_size, - void *y, void const *x, void const *w, - void *stream) { +infiniopStatus_t 
cpuConvBase(ConvBaseCpuDescriptor_t desc, + void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, + void *stream) { if (desc->dtype == F16) { return conv_cpu(desc, workspace, workspace_size, y, x, w); } diff --git a/src/ops/conv_base/cpu/conv_base_cpu.h b/src/ops/conv_base/cpu/conv_base_cpu.h new file mode 100644 index 00000000..3b65559c --- /dev/null +++ b/src/ops/conv_base/cpu/conv_base_cpu.h @@ -0,0 +1,46 @@ +#ifndef __CPU_CONV_BASE_H__ +#define __CPU_CONV_BASE_H__ + +#include "../../../devices/cpu/common_cpu.h" +#include "operators.h" +#include +#include +#include + +struct ConvBaseCpuDescriptor { + Device device; + DT dtype; + uint64_t ndim; + uint64_t y_size; + uint64_t padded_x_size; + uint64_t const *padded_shape; + uint64_t const *x_shape; + uint64_t const *w_shape; + uint64_t const *y_shape; + uint64_t const *pads; + int64_t const *strides; + uint64_t const *dilations; +}; + +typedef struct ConvBaseCpuDescriptor *ConvBaseCpuDescriptor_t; + +infiniopStatus_t cpuCreateConvBaseDescriptor(infiniopHandle_t, + ConvBaseCpuDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n); + +infiniopStatus_t cpuGetConvBaseWorkspaceSize(ConvBaseCpuDescriptor_t desc, uint64_t *size); + +infiniopStatus_t cpuConvBase(ConvBaseCpuDescriptor_t desc, + void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, + void *stream); + +infiniopStatus_t cpuDestroyConvBaseDescriptor(ConvBaseCpuDescriptor_t desc); + +#endif diff --git a/src/ops/conv/cuda/conv.cc b/src/ops/conv_base/cuda/conv_base.cc similarity index 82% rename from src/ops/conv/cuda/conv.cc rename to src/ops/conv_base/cuda/conv_base.cc index 2ccabfda..c2e48f95 100644 --- a/src/ops/conv/cuda/conv.cc +++ b/src/ops/conv_base/cuda/conv_base.cc @@ -1,16 +1,16 @@ -#include "conv.cuh" +#include "conv_base.cuh" #include "../../../devices/cuda/common_cuda.h" #include "../../utils.h" -infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle, - ConvCudaDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t y, - infiniopTensorDescriptor_t x, - infiniopTensorDescriptor_t w, - void const *pads, - void const *strides, - void const *dilations, - uint64_t n) { +infiniopStatus_t cudaCreateConvBaseDescriptor(CudaHandle_t handle, + ConvBaseCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n) { uint64_t ndim = y->ndim; if (ndim < 3 || ndim != x->ndim || ndim != w->ndim) { return STATUS_BAD_TENSOR_SHAPE; @@ -25,7 +25,7 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle, return STATUS_BAD_TENSOR_DTYPE; } - const uint64_t new_ndim = std::max(ndim, (uint64_t)4); + const uint64_t new_ndim = std::max(ndim, (uint64_t) 4); // convert pads, strides, dilations into int32[] int32_t *pad = new int32_t[new_ndim]; int32_t *stride = new int32_t[new_ndim]; @@ -33,13 +33,10 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle, int32_t *x_shape = new int32_t[new_ndim]; int32_t *w_shape = new int32_t[new_ndim]; int32_t *y_shape = new int32_t[new_ndim]; - auto pads_ = reinterpret_cast(pads); - auto strides_ = reinterpret_cast(strides); - auto dilations_ = reinterpret_cast(dilations); for (size_t i = 0; i < new_ndim; ++i) { - pad[i] = i < ndim - 2 ? 
static_cast(pads_[i]) : 0; - stride[i] = i < ndim - 2 ? static_cast(strides_[i]) : 1; - dilation[i] = i < ndim - 2 ? static_cast(dilations_[i]) : 1; + pad[i] = i < ndim - 2 ? static_cast(pads[i]) : 0; + stride[i] = i < ndim - 2 ? static_cast(strides[i]) : 1; + dilation[i] = i < ndim - 2 ? static_cast(dilations[i]) : 1; x_shape[i] = i < ndim ? static_cast(x->shape[i]) : 1; w_shape[i] = i < ndim ? static_cast(w->shape[i]) : 1; y_shape[i] = i < ndim ? static_cast(y->shape[i]) : 1; @@ -93,6 +90,8 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle, checkCudnnError(cudnnCreateTensorDescriptor(&y_desc)); checkCudnnError(cudnnSetTensorNdDescriptorEx(y_desc, CUDNN_TENSOR_NCHW, static_cast(tensor_dt), new_ndim, y_shape)); + cudnnSetConvolutionMathType(op_desc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); + // tuning: get the best algorithm int requestedAlgoCount = 1; checkCudnnError(use_cudnn(handle->cudnn_handles_t, handle->device_id, nullptr, @@ -123,7 +122,7 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle, const float alpha = 1.0f; const float beta = 0.0f; - *desc_ptr = new ConvCudaDescriptor{ + *desc_ptr = new ConvBaseCudaDescriptor{ DevNvGpu, y->dt, handle->device_id, @@ -147,12 +146,12 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle, return STATUS_SUCCESS; } -infiniopStatus_t cudaGetConvWorkspaceSize(ConvCudaDescriptor_t desc, uint64_t *size) { +infiniopStatus_t cudaGetConvBaseWorkspaceSize(ConvBaseCudaDescriptor_t desc, uint64_t *size) { *size = desc->workspace_size; return STATUS_SUCCESS; } -infiniopStatus_t cudaDestroyConvDescriptor(ConvCudaDescriptor_t desc) { +infiniopStatus_t cudaDestroyConvBaseDescriptor(ConvBaseCudaDescriptor_t desc) { checkCudnnError(cudnnDestroyConvolutionDescriptor(desc->op_desc)); checkCudnnError(cudnnDestroyTensorDescriptor(desc->y_desc)); checkCudnnError(cudnnDestroyFilterDescriptor(desc->w_desc)); diff --git a/src/ops/conv/cuda/conv.cu b/src/ops/conv_base/cuda/conv_base.cu similarity index 71% rename from src/ops/conv/cuda/conv.cu rename to src/ops/conv_base/cuda/conv_base.cu index 3f15843b..2ebf89d4 100644 --- a/src/ops/conv/cuda/conv.cu +++ b/src/ops/conv_base/cuda/conv_base.cu @@ -1,8 +1,8 @@ #include "../../../devices/cuda/common_cuda.h" #include "../../utils.h" -#include "conv.cuh" +#include "conv_base.cuh" -infiniopStatus_t conv_nv_gpu(ConvCudaDescriptor_t desc, void *workspace, uint64_t workspace_size, +infiniopStatus_t conv_nv_gpu(ConvBaseCudaDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void *stream) { checkCudaError(cudaSetDevice(desc->device_id)); checkCudnnError(use_cudnn(desc->cudnn_handles_t, desc->device_id, (cudaStream_t) stream, @@ -12,10 +12,10 @@ infiniopStatus_t conv_nv_gpu(ConvCudaDescriptor_t desc, void *workspace, uint64_ return STATUS_SUCCESS; } -infiniopStatus_t cudaConv(ConvCudaDescriptor_t desc, - void *workspace, uint64_t workspace_size, - void *y, void const *x, void const *w, - void *stream) { +infiniopStatus_t cudaConvBase(ConvBaseCudaDescriptor_t desc, + void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, + void *stream) { if (desc->dtype == F16 || desc->dtype == F32) { return conv_nv_gpu(desc, workspace, workspace_size, y, x, w, stream); } diff --git a/src/ops/conv_base/cuda/conv_base.cuh b/src/ops/conv_base/cuda/conv_base.cuh new file mode 100644 index 00000000..f40a9413 --- /dev/null +++ b/src/ops/conv_base/cuda/conv_base.cuh @@ -0,0 +1,45 @@ +#ifndef __CUDA_CONV_H__ +#define __CUDA_CONV_H__ + 
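+// Illustrative usage sketch of this CUDA ConvBase interface (assumes the caller has
+// already created `handle`, the `y`/`x`/`w` tensor descriptors, the `pads`/`strides`/
+// `dilations` arrays of length `n`, the device buffers `y_data`/`x_data`/`w_data`, and
+// a `stream`; `CHECK` stands in for whatever infiniopStatus_t handling the caller uses):
+//
+//   ConvBaseCudaDescriptor_t desc;
+//   CHECK(cudaCreateConvBaseDescriptor(handle, &desc, y, x, w, pads, strides, dilations, n));
+//   uint64_t workspace_size = 0;
+//   CHECK(cudaGetConvBaseWorkspaceSize(desc, &workspace_size));
+//   void *workspace = nullptr;
+//   cudaMalloc(&workspace, workspace_size);
+//   CHECK(cudaConvBase(desc, workspace, workspace_size, y_data, x_data, w_data, stream));
+//   cudaFree(workspace);
+//   CHECK(cudaDestroyConvBaseDescriptor(desc));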
+#include "../../../devices/cuda/common_cuda.h" +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include + +struct ConvBaseCudaDescriptor { + Device device; + DT dtype; + int device_id; + std::shared_ptr> cudnn_handles_t; + cudnnTensorDescriptor_t const x_desc; + cudnnFilterDescriptor_t const w_desc; + cudnnTensorDescriptor_t const y_desc; + cudnnConvolutionDescriptor_t const op_desc; + cudnnConvolutionFwdAlgo_t algo; + const float alpha; + const float beta; + uint64_t workspace_size; +}; + +typedef struct ConvBaseCudaDescriptor *ConvBaseCudaDescriptor_t; + +infiniopStatus_t cudaCreateConvBaseDescriptor(CudaHandle_t, + ConvBaseCudaDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n); + +infiniopStatus_t cudaGetConvBaseWorkspaceSize(ConvBaseCudaDescriptor_t desc, uint64_t *size); + +infiniopStatus_t cudaConvBase(ConvBaseCudaDescriptor_t desc, + void *workspace, uint64_t workspace_size, + void *y, void const *x, void const *w, + void *stream); + +infiniopStatus_t cudaDestroyConvBaseDescriptor(ConvBaseCudaDescriptor_t desc); + +#endif diff --git a/src/ops/conv_base/operator.cc b/src/ops/conv_base/operator.cc new file mode 100644 index 00000000..5c0ac798 --- /dev/null +++ b/src/ops/conv_base/operator.cc @@ -0,0 +1,96 @@ +#include "../utils.h" +#include "conv_base.h" +#include "operators.h" + +#ifdef ENABLE_CPU +#include "cpu/conv_base_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/conv_base.cuh" +#endif + +__C infiniopStatus_t infiniopCreateConvBaseDescriptor( + infiniopHandle_t handle, + infiniopConvBaseDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t w, + uint64_t const *pads, + int64_t const *strides, + uint64_t const *dilations, + uint64_t n) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateConvBaseDescriptor(handle, (ConvBaseCpuDescriptor_t *) desc_ptr, y, x, w, pads, strides, dilations, n); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateConvBaseDescriptor((CudaHandle_t) handle, (ConvBaseCudaDescriptor_t *) desc_ptr, y, x, w, pads, strides, dilations, n); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopGetConvBaseWorkspaceSize(infiniopConvBaseDescriptor_t desc, uint64_t *size) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuGetConvBaseWorkspaceSize((ConvBaseCpuDescriptor_t) desc, size); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaGetConvBaseWorkspaceSize((ConvBaseCudaDescriptor_t) desc, size); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopConvBase(infiniopConvBaseDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuConvBase((ConvBaseCpuDescriptor_t) desc, workspace, workspace_size, y, x, w, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaConvBase((ConvBaseCudaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t 
infiniopDestroyConvBaseDescriptor(infiniopConvBaseDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyConvBaseDescriptor((ConvBaseCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroyConvBaseDescriptor((ConvBaseCudaDescriptor_t) desc); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} diff --git a/src/ops/utils.h b/src/ops/utils.h index b48cf419..4fb092de 100644 --- a/src/ops/utils.h +++ b/src/ops/utils.h @@ -28,6 +28,15 @@ inline void assert_true(int expr, const char *msg, const char *file, int line) { printf("Error at %s:%d - %s\n", __FILE__, __LINE__, #EXPR); \ exit(EXIT_FAILURE) +#define WARN(msg) \ + do { \ + if constexpr (std::is_same_v<std::decay_t<decltype(msg)>, const char *>) { \ + fprintf(stderr, "\033[33mWarning: %s\033[0m\n", msg); \ + } else { \ + fprintf(stderr, "\033[33mWarning: %s\033[0m\n", #msg); \ + } \ + } while (0) + #define ROUND_UP_DIV(x, y) ((x + y - 1) / y) #define CHECK_ERROR(call, target, errCode) \