diff --git a/include/infini_operators.h b/include/infini_operators.h index 9a5a2555..0ebeafc4 100644 --- a/include/infini_operators.h +++ b/include/infini_operators.h @@ -3,8 +3,10 @@ #include "ops/attention/attention.h" #include "ops/avg_pool/avg_pool.h" #include "ops/causal_softmax/causal_softmax.h" +#include "ops/clip/clip.h" #include "ops/global_avg_pool/global_avg_pool.h" #include "ops/expand/expand.h" +#include "ops/gather/gather.h" #include "ops/gemm/gemm.h" #include "ops/conv/conv.h" #include "ops/matmul/matmul.h" @@ -12,8 +14,12 @@ #include "ops/mlp/mlp.h" #include "ops/random_sample/random_sample.h" #include "ops/rearrange/rearrange.h" +#include "ops/reduce_max/reduce_max.h" +#include "ops/reduce_mean/reduce_mean.h" +#include "ops/reduce_min/reduce_min.h" #include "ops/relu/relu.h" #include "ops/rms_norm/rms_norm.h" #include "ops/rotary_embedding/rotary_embedding.h" #include "ops/swiglu/swiglu.h" +#include "ops/where/where.h" #include "tensor/tensor_descriptor.h" diff --git a/include/ops/clip/clip.h b/include/ops/clip/clip.h new file mode 100644 index 00000000..d07e67ea --- /dev/null +++ b/include/ops/clip/clip.h @@ -0,0 +1,29 @@ +#ifndef CLIP_H +#define CLIP_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ClipDescriptor { + Device device; +} ClipDescriptor; + +typedef ClipDescriptor *infiniopClipDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateClipDescriptor(infiniopHandle_t handle, + infiniopClipDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t min_desc, + infiniopTensorDescriptor_t max_desc); + +__C __export infiniopStatus_t infiniopClip(infiniopClipDescriptor_t desc, + void *output, + void *input, + void const *min, + void const *max, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/ops/gather/gather.h 
b/include/ops/gather/gather.h new file mode 100644 index 00000000..8c9a98cb --- /dev/null +++ b/include/ops/gather/gather.h @@ -0,0 +1,27 @@ +#ifndef GATHER_H +#define GATHER_H + +#include "../../export.h" +#include "../../operators.h" +typedef struct GatherDescriptor { + Device device; +} GatherDescriptor; + +typedef GatherDescriptor *infiniopGatherDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateGatherDescriptor(infiniopHandle_t handle, + infiniopGatherDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t data_desc, + infiniopTensorDescriptor_t index_desc, + int axis); + +__C __export infiniopStatus_t infiniopGather(infiniopGatherDescriptor_t desc, + void *output, + void *data, + void const *indices, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/ops/reduce_max/reduce_max.h b/include/ops/reduce_max/reduce_max.h new file mode 100644 index 00000000..48333836 --- /dev/null +++ b/include/ops/reduce_max/reduce_max.h @@ -0,0 +1,31 @@ +#ifndef REDUCE_MAX_H +#define REDUCE_MAX_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ReduceMaxDescriptor { + Device device; +} ReduceMaxDescriptor; + +typedef ReduceMaxDescriptor *infiniopReduceMaxDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReduceMaxDescriptor(infiniopHandle_t handle, + infiniopReduceMaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + int64_t const *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes); + + +__C __export infiniopStatus_t infiniopReduceMax(infiniopReduceMaxDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyReduceMaxDescriptor(infiniopReduceMaxDescriptor_t desc); + +#endif \ No newline at end of file diff --git 
a/include/ops/reduce_mean/reduce_mean.h b/include/ops/reduce_mean/reduce_mean.h new file mode 100644 index 00000000..f5480795 --- /dev/null +++ b/include/ops/reduce_mean/reduce_mean.h @@ -0,0 +1,30 @@ +#ifndef REDUCE_MEAN_H +#define REDUCE_MEAN_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ReduceMeanDescriptor { + Device device; +} ReduceMeanDescriptor; + +typedef ReduceMeanDescriptor *infiniopReduceMeanDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReduceMeanDescriptor(infiniopHandle_t handle, + infiniopReduceMeanDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + int64_t const *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes); + +__C __export infiniopStatus_t infiniopReduceMean(infiniopReduceMeanDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyReduceMeanDescriptor(infiniopReduceMeanDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/ops/reduce_min/reduce_min.h b/include/ops/reduce_min/reduce_min.h new file mode 100644 index 00000000..dcf02e99 --- /dev/null +++ b/include/ops/reduce_min/reduce_min.h @@ -0,0 +1,30 @@ +#ifndef REDUCE_MIN_H +#define REDUCE_MIN_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct ReduceMinDescriptor { + Device device; +} ReduceMinDescriptor; + +typedef ReduceMinDescriptor *infiniopReduceMinDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReduceMinDescriptor(infiniopHandle_t handle, + infiniopReduceMinDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + int64_t const *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes); + +__C __export infiniopStatus_t infiniopReduceMin(infiniopReduceMinDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream); + +__C 
__export infiniopStatus_t infiniopDestroyReduceMinDescriptor(infiniopReduceMinDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/ops/where/where.h b/include/ops/where/where.h new file mode 100644 index 00000000..c80ee33b --- /dev/null +++ b/include/ops/where/where.h @@ -0,0 +1,29 @@ +#ifndef WHERE_H +#define WHERE_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct WhereDescriptor { + Device device; +} WhereDescriptor; + +typedef WhereDescriptor *infiniopWhereDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateWhereDescriptor(infiniopHandle_t handle, + infiniopWhereDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t condition_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc); + +__C __export infiniopStatus_t infiniopWhere(infiniopWhereDescriptor_t desc, + void *output, + void *condition, + void *x, + void *y, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/tensor/tensor_descriptor.h b/include/tensor/tensor_descriptor.h index 2fb9fc1d..5476dfe3 100644 --- a/include/tensor/tensor_descriptor.h +++ b/include/tensor/tensor_descriptor.h @@ -5,6 +5,15 @@ #include "../tensor.h" #include "../status.h" +/** + * @brief 根据给定的参数,创建表示张量的描述符 + * @param desc_ptr 保存所创建的张量描述符 + * @param ndim 所表示张量的阶 + * @param shape_ 张量的形状 + * @param strides_ 张量的步长 + * @param datatype 张量元素的类型 + * @return 返回表示创建是否成功的状态码 + */ __C __export infiniopStatus_t infiniopCreateTensorDescriptor(infiniopTensorDescriptor_t *desc_ptr, uint64_t ndim, uint64_t const *shape_, int64_t const *strides_, DataLayout datatype); __C __export infiniopStatus_t infiniopDestroyTensorDescriptor(infiniopTensorDescriptor_t desc); diff --git a/operatorspy/tests/clip.py b/operatorspy/tests/clip.py new file mode 100644 index 00000000..1d86046e --- /dev/null +++ 
b/operatorspy/tests/clip.py @@ -0,0 +1,155 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p +import ctypes +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + CTensor, + DeviceEnum, + open_lib, + infiniopHandle_t, + infiniopTensorDescriptor_t, + to_tensor, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +from enum import Enum, auto +import torch + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE = auto() + +class ClipDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopClipDescriptor_t = POINTER(ClipDescriptor) + +def clip(x, min, max): + return torch.clamp(x, min, max) + +def test( + lib, + handle, + torch_device, + input_shape, + min_val, + max_val, + tensor_dtype=torch.float16, + inplace=Inplace.OUT_OF_PLACE, +): + print( + f"Testing Clip on {torch_device} with input_shape:{input_shape} min:{min_val} max:{max_val} dtype:{tensor_dtype} inplace: {inplace.name}" + ) + + # 随机生成输入张量 + input_data= torch.rand(input_shape, dtype=tensor_dtype).to(torch_device) + # 随机生成输出张量 + output_data = torch.rand(input_shape, dtype=tensor_dtype).to(torch_device) if inplace == Inplace.OUT_OF_PLACE else input_data + # 将 min 和 max 转换为标量张量 + min = torch.tensor(min_val, dtype=tensor_dtype).to(torch_device) if min_val is not None else None + max = torch.tensor(max_val, dtype=tensor_dtype).to(torch_device) if max_val is not None else None + + # 计算正确的结果(pytorch 的 clamp 不支持 min 和 max 均为空,跳过这种情况) + if (min_val is not None) or (max_val is not None): + ans = clip(input_data, min, max) + + # 将 torch 张量转换为 infiniop 张量 + input_tensor = to_tensor(input_data, lib) + output_tensor = to_tensor(output_data, lib) if inplace == Inplace.OUT_OF_PLACE else input_tensor + min_tensor = to_tensor(min, lib) if min is not None else None + max_tensor = to_tensor(max, lib) if max is not None else None + + # 创建 clip 算子描述符的指针 + 
descriptor = infiniopClipDescriptor_t() + # 创建 clip 算子描述符 + check_error( + lib.infiniopCreateClipDescriptor( + handle, + ctypes.byref(descriptor), + output_tensor.descriptor, + input_tensor.descriptor, + min_tensor.descriptor if min_tensor is not None else None, + max_tensor.descriptor if max_tensor is not None else None, + ) + ) + + # 标记输入和输出张量的形状和步长信息为无效,检查错误的实现 + input_tensor.descriptor.contents.invalidate() + output_tensor.descriptor.contents.invalidate() + + # 执行 clip 算子 + check_error( + lib.infiniopClip( + descriptor, + output_tensor.data, + input_tensor.data, + min_tensor.data if min_tensor is not None else None, + max_tensor.data if max_tensor is not None else None, + None + ) + ) + + # 检查结果是否正确(pytorch 的 clamp 不支持 min 和 max 均为空,跳过这种情况的比较,动态库支持这种操作) + if (min_val is not None) or (max_val is not None): + assert torch.allclose(output_data, ans, atol=0, rtol=0) + + # 销毁 clip 算子描述符 + check_error(lib.infiniopDestroyClipDescriptor(descriptor)) + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for input_shape, min, max, inplace in test_cases: + # 测试数据类型为 float16 + test(lib, handle, "cpu", input_shape, min, max, tensor_dtype=torch.float16, inplace=inplace) + # 测试数据类型为 float32 + test(lib, handle, "cpu", input_shape, min, max, tensor_dtype=torch.float32, inplace=inplace) + destroy_handle(lib, handle) + +if __name__ == "__main__": + test_cases = [ + # input_shape, min, max, inplace + ((), 0, 1, Inplace.OUT_OF_PLACE), # 测试传入标量时的正确性 + ((), 0.4, 0.6, Inplace.OUT_OF_PLACE), # 传入标量时可以正确执行 + ((1, 3), 0.4, 0.6, Inplace.OUT_OF_PLACE), # 传入标量时可以正确执行 + ((4, 3), 0.3, 0.6, Inplace.OUT_OF_PLACE), # min 和 max 均非空时可以正确执行 + ((4, 3), None, 0.6, Inplace.OUT_OF_PLACE), # min 为空时可以正确执行 + ((4, 3), 0.3, None, Inplace.OUT_OF_PLACE), # max 为空时可以正确执行 + ((4, 3), None, None, Inplace.OUT_OF_PLACE), # min 和 max 均为空时可以正确执行 + ((32, 20, 512), 0.4, 0.6, Inplace.OUT_OF_PLACE), # 较多元素的张量 + ((32, 20, 512), 0.4, 0.6, Inplace.INPLACE), # 
较多元素的张量可以原地执行 + ((2, 3, 4, 5, 6), 0.4, 0.6, Inplace.OUT_OF_PLACE) # 较多维度的张量 + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateClipDescriptor.restype = c_int32 + lib.infiniopCreateClipDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopClipDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_void_p, + ] + lib.infiniopClip.restype = c_int32 + lib.infiniopClip.argtypes = [ + infiniopClipDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyClipDescriptor.restype = c_int32 + lib.infiniopDestroyClipDescriptor.argtypes = [ + infiniopClipDescriptor_t, + ] + if args.cpu: + test_cpu(lib, test_cases) + if not (args.cpu): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") \ No newline at end of file diff --git a/operatorspy/tests/gather.py b/operatorspy/tests/gather.py new file mode 100644 index 00000000..f8bc22bd --- /dev/null +++ b/operatorspy/tests/gather.py @@ -0,0 +1,229 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p +import ctypes +import sys +import os + +# 将当前文件 ../../ 添加到环境变量中,这样就能使用 operatorspy 模块了 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + # operatorspy/devices.py 中的类型 + DeviceEnum, + # operatorspy/operators.py 中的函数和类型 + open_lib, + infiniopHandle_t, + infiniopTensorDescriptor_t, + # 下面是 operatorspy/utils.py 中的函数 + to_tensor, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch + + +class GatherDescriptor(Structure): + """定义与 C 语言中的 GatherDescriptor 类型相同的结构体 + """ + _fields_ = [("device", c_int32)] + + +# infiniopGatherDescriptor_t 类型是 C 语言中指向 GatherDescriptor 类型的指针 +infiniopGatherDescriptor_t = POINTER(GatherDescriptor) + +def gather(data, axis, index): + """使用 pytorch 实现 gather 操作 + Args: + data: 要获取的数据 + axis: 合并的维度 + index: 索引张量 + Returns: + 计算结果 + """ + # 计算合并后的输出张量形状 + output_ndim = data.ndim + 
index.ndim - 1 + output_shape = [0 for i in range(output_ndim)] + axis = axis if axis >= 0 else axis + data.ndim + for i in range(output_ndim): + if i < axis: + output_shape[i] = data.shape[i] + elif i >= axis + index.ndim: + output_shape[i] = data.shape[i - index.ndim + 1] + else: + output_shape[i] = index.shape[i - axis] + + # Allocate the output tensor + output = torch.zeros(output_shape, dtype=data.dtype) + # Number of elements in the output tensor + output_elements = output.numel() + # Multi-dimensional index used to walk the output tensor + output_indices = [0 for i in range(output_ndim)] + # Visit every element of the output tensor + for i in range(output_elements): + # Split output_indices around axis into the index-tensor and data-tensor coordinates + index_indices = [0 for i in range(index.ndim)] + data_indices = [0 for i in range(data.ndim)] + for j in range(output_ndim): + if j < axis: + data_indices[j] = output_indices[j] + elif j >= axis + index.ndim: + data_indices[j - index.ndim + 1] = output_indices[j] + else: + index_indices[j - axis] = output_indices[j] + # Read the index value and store it at position axis of data_indices + index_value = index[tuple(index_indices)] + data_indices[axis] = index_value.item() + # Fetch the value from data + data_value = data[tuple(data_indices)] + # Write the value into the output tensor + output[tuple(output_indices)] = data_value + + # print("index", tuple(output_indices), "value", output[tuple(output_indices)]) + # Advance output_indices to the next element (last dimension moves fastest) + for j in range(output_ndim - 1, -1, -1): + output_indices[j] += 1 + if output_indices[j] < output_shape[j]: + break + output_indices[j] = 0 + + return output + +def test( + lib, + handle, + torch_device, + output_shape, + data_shape, + index_shape, + axis, + data_dtype +): + """Run gather through the loaded shared library and compare the result with the PyTorch reference + + Args: + lib: shared library + handle: handle + torch_device: PyTorch device + output_shape: shape of the output tensor + data_shape: shape of the input tensor + index_shape: shape of the index tensor + axis: axis to gather along + data_dtype: data type of the input tensor + """ + print( + f"Testing Gather on {torch_device} with output_shape:{output_shape} data_shape:{data_shape} index_shape:{index_shape} axis: {axis} data_dtype:{data_dtype}" + ) + # Generate random data + data = 
torch.rand(data_shape, dtype=data_dtype).to(torch_device) + output = torch.zeros(output_shape, dtype=data_dtype).to(torch_device) + # index_shape 为标量时,直接作为索引 + if (type(index_shape) != tuple): + index = torch.tensor(index_shape, dtype=torch.int32).to(torch_device) + else: + index = torch.randint(-data_shape[axis], data_shape[axis], index_shape, dtype=torch.int32).to(torch_device) + + # 使用 pytorch 计算正确结果 + ans = gather(data, axis, index) + + # 将 pytorch 张量转换为 infiniop 张量 + data_tensor = to_tensor(data, lib) + output_tensor = to_tensor(output, lib) + index_tensor = to_tensor(index, lib) + descriptor = infiniopGatherDescriptor_t() + + check_error( + lib.infiniopCreateGatherDescriptor( + handle, + ctypes.byref(descriptor), + output_tensor.descriptor, + data_tensor.descriptor, + index_tensor.descriptor, + axis, + ) + ) + + # 置空所有张量的相关信息,避免在实际运算时直接使用这些信息 + data_tensor.descriptor.contents.invalidate() + index_tensor.descriptor.contents.invalidate() + output_tensor.descriptor.contents.invalidate() + check_error( + lib.infiniopGather(descriptor, output_tensor.data, data_tensor.data, index_tensor.data, None) + ) + + # 比较输出结果 + assert torch.allclose(output, ans, atol=0, rtol=0) + + # 销毁算子描述符 + check_error(lib.infiniopDestroyGatherDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + """测试 CPU 上的给定算子 + Args: + lib: 要测试的 CPU 算子动态库 + test_cases: 要执行的测试用例 + """ + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for output_shape, data_shape, index_shape, axis in test_cases: + # 测试数据类型为 float16 的数据 + test(lib, handle, "cpu", output_shape, data_shape, index_shape, axis, data_dtype=torch.float16) + # 测试数据类型为 float32 的数据 + test(lib, handle, "cpu", output_shape, data_shape, index_shape, axis, data_dtype=torch.float32) + destroy_handle(lib, handle) + +if __name__ == "__main__": + test_cases = [ + # output_shape, data_shape, index_shape, axis + ((5, 6, 3, 4), (2, 3, 4), (5, 6), 0), + ((2, 5, 6, 4), (2, 3, 4), (5, 6), 1), + ((2, 3, 5, 6), (2, 3, 4), (5, 6), 
2), + ((5, 6, 3, 4), (2, 3, 4), (5, 6), -3), # axis 为负的情况 + ((2, 5, 6, 4), (2, 3, 4), (5, 6), -2), + ((2, 3, 5, 6), (2, 3, 4), (5, 6), -1), + ((7, 8, 4, 5, 6), (3, 4, 5, 6), (7, 8), 0), # 较大尺寸的张量 + ((3, 7, 8, 5, 6), (3, 4, 5, 6), (7, 8), 1), + ((3, 4, 7, 8, 6), (3, 4, 5, 6), (7, 8), 2), + ((3, 4, 5, 7, 8), (3, 4, 5, 6), (7, 8), 3), + ((200, 32, 32, 7, 8), (200, 32, 32, 3), (7, 8), 3), + ((3, 4), (2, 3, 4), 0, 0), # index 为标量 + ((3, 4), (2, 3, 4), 1, 0), + ((2, 4), (2, 3, 4), 0, 1), + ((2, 4), (2, 3, 4), 1, 1), + ((2, 4), (2, 3, 4), 2, 1), + ((2, 3), (2, 3, 4), 0, 2), + ((2, 3), (2, 3, 4), 1, 2), + ((2, 3), (2, 3, 4), 2, 2), + ((2, 3), (2, 3, 4), 3, 2), + ((2, 3), (2, 3, 4), -1, 2), # index 为负的标量 + ((2, 3), (2, 3, 4), -2, 2), + ((2, 3), (2, 3, 4), -3, 2), + ((2, 3), (2, 3, 4), -4, 2) + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateGatherDescriptor.restype = c_int32 + lib.infiniopCreateGatherDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopGatherDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_int32, + ] + lib.infiniopGather.restype = c_int32 + lib.infiniopGather.argtypes = [ + infiniopGatherDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyGatherDescriptor.restype = c_int32 + lib.infiniopDestroyGatherDescriptor.argtypes = [infiniopGatherDescriptor_t] + if args.cpu: + test_cpu(lib, test_cases) + if not (args.cpu): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/reduce_max.py b/operatorspy/tests/reduce_max.py new file mode 100644 index 00000000..017bd22a --- /dev/null +++ b/operatorspy/tests/reduce_max.py @@ -0,0 +1,199 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_int64, c_size_t +import ctypes +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + 
infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch + +class ReduceMaxDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopReduceMaxDescriptor_t = POINTER(ReduceMaxDescriptor) + +def reduce_max(data, dim, keepdim): + return torch.amax(data, dim=dim, keepdim=keepdim) + +def test(lib, handle, device, reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes, tensor_dtype): + print( + f"Testing ReduceMax on {device} with reduced_shape:{reduced_shape} data_shape:{data_shape} axes:{axes} " \ + f"axes_size:{axes_size} keepdims:{keepdims} noop_with_empty_axes:{noop_with_empty_axes} dtype:{tensor_dtype}" + ) + + # 生成随机数据 + data = torch.randn(data_shape, dtype=tensor_dtype).to(device) + reduced = torch.randn(reduced_shape, dtype=tensor_dtype).to(device) + + # 调用 pytorch 的函数获取实际结果 + if isinstance(data_shape, tuple) and data_shape == (0,): + # 如果产生的是空集合,ans 应该为无穷小或最小值 + ans = torch.tensor(-float("inf"), dtype=tensor_dtype).to(device) + else: + if (axes is None) and (noop_with_empty_axes == 1): + # axes 为空,且 noop_with_empty_axes 为 1,应该返回原数组 + ans = data + else: + # 其他情况都直接调用 pytorch 的函数返回结果 + ans = reduce_max(data, + dim=axes, + keepdim=False if (keepdims==0) else True) + assert ans.shape == reduced.shape + + # 将 pytorch 的数据转换为 tensor + data_tensor = to_tensor(data, lib) + reduced_tensor = to_tensor(reduced, lib) + axes_ptr = None + if axes is not None: + axes_tensor = torch.tensor(axes, dtype=torch.int64).to(device) + axes_ptr = ctypes.cast(axes_tensor.data_ptr(), ctypes.POINTER(ctypes.c_int64)) + + # 创建 descriptor + descriptor = infiniopReduceMaxDescriptor_t() + check_error( + lib.infiniopCreateReduceMaxDescriptor( + handle, + ctypes.byref(descriptor), + reduced_tensor.descriptor, + data_tensor.descriptor, + axes_ptr, + axes_size, + keepdims if keepdims is not None else 1, + noop_with_empty_axes if noop_with_empty_axes is not 
None else 0, + ) + ) + + # 置空参数的相关信息 + data_tensor.descriptor.contents.invalidate() + reduced_tensor.descriptor.contents.invalidate() + + # 调用 infiniop 的函数 + check_error( + lib.infiniopReduceMax( + descriptor, + reduced_tensor.data, + data_tensor.data, + axes_ptr, + None, + ) + ) + if tensor_dtype == torch.float16: + assert torch.allclose(reduced, ans, atol=0, rtol=1e-3) + elif tensor_dtype == torch.float32: + assert torch.allclose(reduced, ans, atol=0, rtol=1e-5) + + # 销毁 descriptor + check_error(lib.infiniopDestroyReduceMaxDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes in test_cases: + test(lib, handle, "cpu", reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes, tensor_dtype=torch.float16) + test(lib, handle, "cpu", reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + +if __name__ == "__main__": + test_cases = [ + # reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes + ((1, 3, 4), (2, 3, 4), [0], 1, 1, 0), # 测试 keepdims + ((1, 3, 4), (2, 3, 4), [-3], 1, 1, 0), + + ((3, 4), (2, 3, 4), [0], 1, 0, 0), + ((3, 4), (2, 3, 4), [-3], 1, 0, 0), + + ((2, 1, 4), (2, 3, 4), [1], 1, 1, 0), + ((2, 1, 4), (2, 3, 4), [-2], 1, 1, 0), + + ((2, 4), (2, 3, 4), [1], 1, 0, 0), + ((2, 4), (2, 3, 4), [-2], 1, 0, 0), + + ((2, 3, 1), (2, 3, 4), [2], 1, 1, 0), + ((2, 3, 1), (2, 3, 4), [-1], 1, 1, 0), + + ((2, 3), (2, 3, 4), [2], 1, 0, 0), + ((2, 3), (2, 3, 4), [-1], 1, 0, 0), + + ((2, 3, 1), (2, 3, 4), [2], 1, 1, 1), # 指定了 axes 时,noop 不生效 + ((2, 3), (2, 3, 4), [2], 1, 0, 1), # 指定了 axes 时,noop 不生效 + + ((1, 1, 4), (2, 3, 4), [0, 1], 2, 1, 0), + ((4), (2, 3, 4), [0, 1], 2, 0, 0), + + ((2, 1, 1), (2, 3, 4), [1, 2], 2, 1, 0), + ((2), (2, 3, 4), [1, 2], 2, 0, 0), + + ((1, 1, 1, 5, 6), (2, 3, 4, 5, 6), [0, 
1, 2], 3, 1, 0), + ((5, 6), (2, 3, 4, 5, 6), [0, 1, 2], 3, 0, 0), + ((2, 1, 1, 1, 6), (2, 3, 4, 5, 6), [1, 2, 3], 3, 1, 0), + ((2, 6), (2, 3, 4, 5, 6), [1, 2, 3], 3, 0, 0), + ((2, 3, 1, 1, 1), (2, 3, 4, 5, 6), [2, 3, 4], 3, 1, 0), + ((2, 3), (2, 3, 4, 5, 6), [2, 3, 4], 3, 0, 0), + ((1, 3, 1, 5, 1), (2, 3, 4, 5, 6), [0, 2, 4], 3, 1, 0), + ((3, 5), (2, 3, 4, 5, 6), [0, 2, 4], 3, 0, 0), + + ((2, 1, 4), (2, 3, 4), [1], 1, None, None), # axes 非空,keepdims 默认为 1 + ((1, 1, 1), (2, 3, 4), None, 0, None, None), # axes 为空,keepdims 默认为 1,noop 默认为 0 + ((1, 1, 1), (2, 3, 4), None, 0, None, 0), # axes 为空,keepdims 默认为 1,noop 设置为 0 + ((2, 3, 4), (2, 3, 4), None, 0, None, 1), # axes 为空,keepdims 默认为 1,noop 设置为 1 + ((), (2, 3, 4), None, 0, 0, None), # axes 为空,keepdims 为 0,noop 默认为 0 + ((), (2, 3, 4), None, 0, 0, 0), # axes 为空,keepdims 为 0,noop 设置为 0 + ((2, 3, 4), (2, 3, 4), None, 0, 0, 1), # axes 为空,keepdims 为 0,noop 设置为 1,应该返回原数组 + ((1, 1, 1), (2, 3, 4), None, 0, 1, None), # axes 为空,keepdims 为 1,noop 默认为 0 + ((1, 1, 1), (2, 3, 4), None, 0, 1, 0), # axes 为空,keepdims 为 1,noop 设置为 0 + ((2, 3, 4), (2, 3, 4), None, 0, 1, 1), # axes 为空,keepdims 为 1,noop 设置为 1,应该返回原数组 + + ((), (), [], 0, 1, 0), # 输入标量 + ((), (), [], 0, 0, 0), + ((), (), None, 0, 1, 0), + ((), (), None, 0, 1, 1), + + ((), (0,), [], 0, 1, 0), # 空集合返回负无穷 + ((), (0,), [], 0, 0, 0), + ((), (0,), None, 0, 1, 0), + ((), (0,), None, 0, 1, 1), + + ((32, 1, 1, 224), (32, 3, 224, 224), [1, 2], 2, 1, 0), + ((32, 224), (32, 3, 224, 224), [1, 2], 2, 0, 0), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateReduceMaxDescriptor.restype = c_int32 + lib.infiniopCreateReduceMaxDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopReduceMaxDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + POINTER(c_int64), + c_size_t, + c_int32, + c_int32, + ] + lib.infiniopReduceMax.restype = c_int32 + lib.infiniopReduceMax.argtypes = [ + infiniopReduceMaxDescriptor_t, + c_void_p, + c_void_p, + POINTER(c_int64), 
+ c_void_p, + ] + lib.infiniopDestroyReduceMaxDescriptor.restype = c_int32 + lib.infiniopDestroyReduceMaxDescriptor.argtypes = [infiniopReduceMaxDescriptor_t] + + if args.cpu: + test_cpu(lib, test_cases) + if not (args.cpu): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") \ No newline at end of file diff --git a/operatorspy/tests/reduce_mean.py b/operatorspy/tests/reduce_mean.py new file mode 100644 index 00000000..7f53c1d5 --- /dev/null +++ b/operatorspy/tests/reduce_mean.py @@ -0,0 +1,204 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_int64, c_size_t +import ctypes +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch + +class ReduceMeanDescriptor(Structure): + _fields_ = [("device", c_int32)] + +infiniopReduceMeanDescriptor_t = POINTER(ReduceMeanDescriptor) + +def reduce_mean(data, dim, keepdim): + return torch.mean(data, dim=dim, keepdim=keepdim) + +def test(lib, handle, device, reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes, tensor_dtype): + print( + f"Testing ReduceMean on {device} with reduced_shape:{reduced_shape} data_shape:{data_shape} axes:{axes} " \ + f"axes_size:{axes_size} keepdims:{keepdims} noop_with_empty_axes:{noop_with_empty_axes} dtype:{tensor_dtype}" + ) + + # Generate random data + data = torch.randn(data_shape, dtype=tensor_dtype).to(device) + reduced = torch.randn(reduced_shape, dtype=tensor_dtype).to(device) + + # Compute the reference result with PyTorch + if isinstance(data_shape, tuple) and data_shape == (0,): + # Empty input: the reference result is NaN + ans = torch.tensor(float("nan"), dtype=tensor_dtype).to(device) + else: + if (axes is None) and (noop_with_empty_axes == 1): + # axes is None and noop_with_empty_axes is 1: the input should be returned unchanged + ans = 
data + else: + # 其他情况都直接调用 pytorch 的函数返回结果 + ans = reduce_mean(data, + dim=axes, + keepdim=False if (keepdims==0) else True) + assert ans.shape == reduced.shape + + # 将 pytorch 的数据转换为 tensor + data_tensor = to_tensor(data, lib) + reduced_tensor = to_tensor(reduced, lib) + axes_ptr = None + if axes is not None: + axes_tensor = torch.tensor(axes, dtype=torch.int64).to(device) + axes_ptr = ctypes.cast(axes_tensor.data_ptr(), ctypes.POINTER(ctypes.c_int64)) + + # 创建 descriptor,C 语言不支持默认参数,因此在传入参数时根据参数是否为 None 来决定传入的值 + descriptor = infiniopReduceMeanDescriptor_t() + check_error( + lib.infiniopCreateReduceMeanDescriptor( + handle, + ctypes.byref(descriptor), + reduced_tensor.descriptor, + data_tensor.descriptor, + axes_ptr, + axes_size, + keepdims if keepdims is not None else 1, + noop_with_empty_axes if noop_with_empty_axes is not None else 0, + ) + ) + + # 置空参数的相关信息 + data_tensor.descriptor.contents.invalidate() + reduced_tensor.descriptor.contents.invalidate() + + # 调用 infiniop 的函数 + check_error( + lib.infiniopReduceMean( + descriptor, + reduced_tensor.data, + data_tensor.data, + axes_ptr, + None + ) + ) + + if isinstance(data_shape, tuple) and data_shape == (0,): + # 如果是空集合,reduced 为 nan + assert reduced.ndim == 0 and torch.isnan(reduced).all() + else: + if tensor_dtype == torch.float16: + assert torch.allclose(reduced, ans, atol=1e-6, rtol=1e-3) + elif tensor_dtype == torch.float32: + assert torch.allclose(reduced, ans, atol=1e-6, rtol=1e-5) + + # 销毁 descriptor + check_error(lib.infiniopDestroyReduceMeanDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes in test_cases: + test(lib, handle, "cpu", reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes, tensor_dtype=torch.float16) + test(lib, handle, "cpu", reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes, 
tensor_dtype=torch.float32) + destroy_handle(lib, handle) + +if __name__ == "__main__": + test_cases = [ + # reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes + ((1, 3, 4), (2, 3, 4), [0], 1, 1, 0), # exercise keepdims + ((1, 3, 4), (2, 3, 4), [-3], 1, 1, 0), + + ((3, 4), (2, 3, 4), [0], 1, 0, 0), + ((3, 4), (2, 3, 4), [-3], 1, 0, 0), + + ((2, 1, 4), (2, 3, 4), [1], 1, 1, 0), + ((2, 1, 4), (2, 3, 4), [-2], 1, 1, 0), + + ((2, 4), (2, 3, 4), [1], 1, 0, 0), + ((2, 4), (2, 3, 4), [-2], 1, 0, 0), + + ((2, 3, 1), (2, 3, 4), [2], 1, 1, 0), + ((2, 3, 1), (2, 3, 4), [-1], 1, 1, 0), + + ((2, 3), (2, 3, 4), [2], 1, 0, 0), + ((2, 3), (2, 3, 4), [-1], 1, 0, 0), + + ((2, 3, 1), (2, 3, 4), [2], 1, 1, 1), # noop has no effect when axes is given + ((2, 3), (2, 3, 4), [2], 1, 0, 1), # noop has no effect when axes is given + + ((1, 1, 4), (2, 3, 4), [0, 1], 2, 1, 0), + ((4), (2, 3, 4), [0, 1], 2, 0, 0), + + ((2, 1, 1), (2, 3, 4), [1, 2], 2, 1, 0), + ((2), (2, 3, 4), [1, 2], 2, 0, 0), + + ((1, 1, 1, 5, 6), (2, 3, 4, 5, 6), [0, 1, 2], 3, 1, 0), + ((5, 6), (2, 3, 4, 5, 6), [0, 1, 2], 3, 0, 0), + ((2, 1, 1, 1, 6), (2, 3, 4, 5, 6), [1, 2, 3], 3, 1, 0), + ((2, 6), (2, 3, 4, 5, 6), [1, 2, 3], 3, 0, 0), + ((2, 3, 1, 1, 1), (2, 3, 4, 5, 6), [2, 3, 4], 3, 1, 0), + ((2, 3), (2, 3, 4, 5, 6), [2, 3, 4], 3, 0, 0), + ((1, 3, 1, 5, 1), (2, 3, 4, 5, 6), [0, 2, 4], 3, 1, 0), + ((3, 5), (2, 3, 4, 5, 6), [0, 2, 4], 3, 0, 0), + + ((2, 1, 4), (2, 3, 4), [1], 1, None, None), # axes given, keepdims defaults to 1 + ((1, 1, 1), (2, 3, 4), None, 0, None, None), # axes is None, keepdims defaults to 1, noop defaults to 0 + ((1, 1, 1), (2, 3, 4), None, 0, None, 0), # axes is None, keepdims defaults to 1, noop set to 0 + ((2, 3, 4), (2, 3, 4), None, 0, None, 1), # axes is None, keepdims defaults to 1, noop set to 1 + ((), (2, 3, 4), None, 0, 0, None), # axes is None, keepdims is 0, noop defaults to 0 + ((), (2, 3, 4), None, 0, 0, 0), # axes is None, keepdims is 0, noop set to 0 + ((2, 3, 4), (2, 3, 4), None, 0, 0, 1), # axes is None, keepdims is 0, noop set to 1: the input should be returned unchanged + ((1, 1, 1), (2, 3, 4), None, 0, 1, None), # axes is None, keepdims is 1, noop defaults to 0 + 
((1, 1, 1), (2, 3, 4), None, 0, 1, 0), # axes is None, keepdims is 1, noop set to 0 + ((2, 3, 4), (2, 3, 4), None, 0, 1, 1), # axes is None, keepdims is 1, noop set to 1: the input should be returned unchanged + + ((), (), [], 0, 1, 0), # scalar input + ((), (), [], 0, 0, 0), + ((), (), None, 0, 1, 0), + ((), (), None, 0, 1, 1), + + ((), (0,), [], 0, 1, 0), # empty input: the expected result is NaN + ((), (0,), [], 0, 0, 0), + ((), (0,), None, 0, 1, 0), + ((), (0,), None, 0, 1, 1), + + ((32, 1, 1, 224), (32, 3, 224, 224), [1, 2], 2, 1, 0), + ((32, 224), (32, 3, 224, 224), [1, 2], 2, 0, 0), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateReduceMeanDescriptor.restype = c_int32 + lib.infiniopCreateReduceMeanDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopReduceMeanDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + POINTER(c_int64), + c_size_t, + c_int32, + c_int32, + ] + lib.infiniopReduceMean.restype = c_int32 + lib.infiniopReduceMean.argtypes = [ + infiniopReduceMeanDescriptor_t, + c_void_p, + c_void_p, + POINTER(c_int64), + c_void_p, + ] + lib.infiniopDestroyReduceMeanDescriptor.restype = c_int32 + lib.infiniopDestroyReduceMeanDescriptor.argtypes = [infiniopReduceMeanDescriptor_t] + + if args.cpu: + test_cpu(lib, test_cases) + if not (args.cpu): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") \ No newline at end of file diff --git a/operatorspy/tests/reduce_min.py b/operatorspy/tests/reduce_min.py new file mode 100644 index 00000000..4e441673 --- /dev/null +++ b/operatorspy/tests/reduce_min.py @@ -0,0 +1,199 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_int64, c_size_t +import ctypes +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, +) + +from operatorspy.tests.test_utils import get_args +import torch + +class 
def test(lib, handle, device, reduced_shape, data_shape, axes, axes_size, keepdims, noop_with_empty_axes, tensor_dtype):
    """Run one ReduceMin case through the loaded library and compare it with torch.amin.

    Args:
        lib: loaded operator library (ctypes).
        handle: infiniop handle created for ``device``.
        device: torch device string the tensors are placed on.
        reduced_shape: expected shape of the output tensor.
        data_shape: shape of the input tensor.
        axes: list of axes to reduce, or None for "no axes given".
        axes_size: number of entries in ``axes`` (0 when ``axes`` is None or empty).
        keepdims: 0/1, or None to exercise the default (1).
        noop_with_empty_axes: 0/1, or None to exercise the default (0).
        tensor_dtype: torch.float16 or torch.float32.

    Raises:
        ValueError: if ``tensor_dtype`` has no comparison tolerance defined.
        AssertionError: if the shape or the values differ from the reference.
    """
    print(
        f"Testing ReduceMin on {device} with reduced_shape:{reduced_shape} data_shape:{data_shape} axes:{axes} " \
        f"axes_size:{axes_size} keepdims:{keepdims} noop_with_empty_axes:{noop_with_empty_axes} dtype:{tensor_dtype}"
    )

    # Relative tolerance per dtype; an unknown dtype must fail loudly instead of
    # silently skipping the value comparison.
    rtol_by_dtype = {torch.float16: 1e-3, torch.float32: 1e-5}
    if tensor_dtype not in rtol_by_dtype:
        raise ValueError(f"unsupported dtype for ReduceMin test: {tensor_dtype}")

    # Random input and a pre-allocated output buffer.
    data = torch.randn(data_shape, dtype=tensor_dtype).to(device)
    reduced = torch.randn(reduced_shape, dtype=tensor_dtype).to(device)

    # Reference result.
    if isinstance(data_shape, tuple) and data_shape == (0,):
        # Empty input: the minimum over an empty set is +inf.
        ans = torch.tensor(float("inf"), dtype=tensor_dtype).to(device)
    elif (axes is None) and (noop_with_empty_axes == 1):
        # No axes and noop_with_empty_axes set: the operator is an identity.
        ans = data
    else:
        # keepdims of None means "use the default", which is 1.
        ans = reduce_min(data,
                         dim=axes,
                         keepdim=False if (keepdims == 0) else True)
    assert ans.shape == reduced.shape

    # Wrap the torch tensors for the library.
    data_tensor = to_tensor(data, lib)
    reduced_tensor = to_tensor(reduced, lib)
    axes_ptr = None
    if axes is not None:
        # axes_tensor must stay referenced while axes_ptr is in use.
        axes_tensor = torch.tensor(axes, dtype=torch.int64).to(device)
        axes_ptr = ctypes.cast(axes_tensor.data_ptr(), ctypes.POINTER(ctypes.c_int64))

    # C has no default arguments, so None is mapped to the default here.
    descriptor = infiniopReduceMinDescriptor_t()
    check_error(
        lib.infiniopCreateReduceMinDescriptor(
            handle,
            ctypes.byref(descriptor),
            reduced_tensor.descriptor,
            data_tensor.descriptor,
            axes_ptr,
            axes_size,
            keepdims if keepdims is not None else 1,
            noop_with_empty_axes if noop_with_empty_axes is not None else 0,
        )
    )

    try:
        # Invalidate the tensor descriptors so the kernel cannot rely on them.
        data_tensor.descriptor.contents.invalidate()
        reduced_tensor.descriptor.contents.invalidate()

        # Run the operator under test.
        check_error(
            lib.infiniopReduceMin(
                descriptor,
                reduced_tensor.data,
                data_tensor.data,
                axes_ptr,
                None
            )
        )
        assert torch.allclose(reduced, ans, atol=0, rtol=rtol_by_dtype[tensor_dtype])
    finally:
        # Release the descriptor even when the comparison fails.
        check_error(lib.infiniopDestroyReduceMinDescriptor(descriptor))
4, 5, 6), [1, 2, 3], 3, 0, 0), + ((2, 3, 1, 1, 1), (2, 3, 4, 5, 6), [2, 3, 4], 3, 1, 0), + ((2, 3), (2, 3, 4, 5, 6), [2, 3, 4], 3, 0, 0), + ((1, 3, 1, 5, 1), (2, 3, 4, 5, 6), [0, 2, 4], 3, 1, 0), + ((3, 5), (2, 3, 4, 5, 6), [0, 2, 4], 3, 0, 0), + + ((2, 1, 4), (2, 3, 4), [1], 1, None, None), # axes 非空,keepdims 默认为 1 + ((1, 1, 1), (2, 3, 4), None, 0, None, None), # axes 为空,keepdims 默认为 1,noop 默认为 0 + ((1, 1, 1), (2, 3, 4), None, 0, None, 0), # axes 为空,keepdims 默认为 1,noop 设置为 0 + ((2, 3, 4), (2, 3, 4), None, 0, None, 1), # axes 为空,keepdims 默认为 1,noop 设置为 1 + ((), (2, 3, 4), None, 0, 0, None), # axes 为空,keepdims 为 0,noop 默认为 0 + ((), (2, 3, 4), None, 0, 0, 0), # axes 为空,keepdims 为 0,noop 设置为 0 + ((2, 3, 4), (2, 3, 4), None, 0, 0, 1), # axes 为空,keepdims 为 0,noop 设置为 1,应该返回原数组 + ((1, 1, 1), (2, 3, 4), None, 0, 1, None), # axes 为空,keepdims 为 1,noop 默认为 0 + ((1, 1, 1), (2, 3, 4), None, 0, 1, 0), # axes 为空,keepdims 为 1,noop 设置为 0 + ((2, 3, 4), (2, 3, 4), None, 0, 1, 1), # axes 为空,keepdims 为 1,noop 设置为 1,应该返回原数组 + + ((), (), [], 0, 1, 0), # 输入标量 + ((), (), [], 0, 0, 0), + ((), (), None, 0, 1, 0), + ((), (), None, 0, 1, 1), + + ((), (0,), [], 0, 1, 0), # 空集合返回负无穷 + ((), (0,), [], 0, 0, 0), + ((), (0,), None, 0, 1, 0), + ((), (0,), None, 0, 1, 1), + + ((32, 1, 1, 224), (32, 3, 224, 224), [1, 2], 2, 1, 0), + ((32, 224), (32, 3, 224, 224), [1, 2], 2, 0, 0), + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateReduceMinDescriptor.restype = c_int32 + lib.infiniopCreateReduceMinDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopReduceMinDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + POINTER(c_int64), + c_size_t, + c_int32, + c_int32, + ] + lib.infiniopReduceMin.restype = c_int32 + lib.infiniopReduceMin.argtypes = [ + infiniopReduceMinDescriptor_t, + c_void_p, + c_void_p, + POINTER(c_int64), + c_void_p, + ] + lib.infiniopDestroyReduceMinDescriptor.restype = c_int32 + lib.infiniopDestroyReduceMinDescriptor.argtypes = 
def test(
    lib,
    handle,
    torch_device,
    output_shape,
    x_shape,
    y_shape,
    tensor_dtype=torch.float16
):
    """Run one Where case through the loaded library and compare it with torch.where.

    The condition tensor is generated with ``output_shape``; ``x`` and ``y`` may
    have different shapes (the test cases rely on broadcasting to the output).

    Args:
        lib: loaded operator library (ctypes).
        handle: infiniop handle for the device under test.
        torch_device: torch device string the tensors are placed on.
        output_shape: shape of the output (and condition) tensor.
        x_shape: shape of the tensor selected where the condition is true.
        y_shape: shape of the tensor selected where the condition is false.
        tensor_dtype: element type of x, y and the output.

    Raises:
        AssertionError: if the library result differs from torch.where.
    """
    print(
        f"Testing Where on {torch_device} with output_shape:{output_shape} x_shape:{x_shape} y_shape:{y_shape} dtype:{tensor_dtype}"
    )
    # Random inputs; the condition is a uint8 tensor of 0/1 values.
    x_data = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
    y_data = torch.rand(y_shape, dtype=tensor_dtype).to(torch_device)
    output_data = torch.rand(output_shape, dtype=tensor_dtype).to(torch_device)
    condition_data = torch.randint(0, 2, output_shape, dtype=torch.uint8).to(torch_device)

    # Reference result computed by torch.
    ans = where(condition_data.bool(), x_data, y_data)

    # Wrap the torch tensors for the library.
    x_tensor = to_tensor(x_data, lib)
    y_tensor = to_tensor(y_data, lib)
    output_tensor = to_tensor(output_data, lib)
    condition_tensor = to_tensor(condition_data, lib)

    # Create the operator descriptor.
    descriptor = infiniopWhereDescriptor_t()
    check_error(
        lib.infiniopCreateWhereDescriptor(
            handle,
            ctypes.byref(descriptor),
            output_tensor.descriptor,
            condition_tensor.descriptor,
            x_tensor.descriptor,
            y_tensor.descriptor,
        )
    )

    try:
        # Invalidate the tensor descriptors so the kernel cannot rely on them.
        x_tensor.descriptor.contents.invalidate()
        y_tensor.descriptor.contents.invalidate()
        condition_tensor.descriptor.contents.invalidate()
        output_tensor.descriptor.contents.invalidate()

        # Run the operator under test.
        check_error(
            lib.infiniopWhere(
                descriptor,
                output_tensor.data,
                condition_tensor.data,
                x_tensor.data,
                y_tensor.data,
                None
            )
        )
        # Where only selects values, so the result must match exactly.
        assert torch.allclose(output_data, ans, atol=0, rtol=0)
    finally:
        # Release the descriptor even when the comparison fails.
        check_error(lib.infiniopDestroyWhereDescriptor(descriptor))
112, 112), (32, 256, 112, 112), (32, 256, 112, 112)), # 更多数据 + ((32, 256, 112, 112), (32, 256, 112, 1), (32, 256, 112, 112)), # 较多数据的广播 + ((2, 4, 3), (2, 1, 3), (4, 3)), + ((2, 3, 4, 5), (2, 3, 4, 5), (5,)), # 广播多个维度 + ((3, 2, 4, 5), (4, 5), (3, 2, 1, 1)), # 广播多个维度 + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateWhereDescriptor.restype = c_int32 + lib.infiniopCreateWhereDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopWhereDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopWhere.restype = c_int32 + lib.infiniopWhere.argtypes = [ + infiniopWhereDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyWhereDescriptor.restype = c_int32 + lib.infiniopDestroyWhereDescriptor.argtypes = [ + infiniopWhereDescriptor_t, + ] + if args.cpu: + test_cpu(lib, test_cases) + if not (args.cpu): + test_cpu(lib, test_cases) + print("\033[92mTest passed!\033[0m") \ No newline at end of file diff --git a/src/ops/clip/cpu/clip_cpu.cc b/src/ops/clip/cpu/clip_cpu.cc new file mode 100644 index 00000000..d64c9820 --- /dev/null +++ b/src/ops/clip/cpu/clip_cpu.cc @@ -0,0 +1,148 @@ +#include "clip_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + +infiniopStatus_t cpuCreateClipDescriptor(infiniopHandle_t, ClipCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t min_desc, + infiniopTensorDescriptor_t max_desc) { + + // 确定输入张量元素的类型满足要求 + if (!dtype_eq(input_desc->dt, F16) && !dtype_eq(input_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + // 确定输出张量和输入张量的类型是否一致 + if (!dtype_eq(output_desc->dt, input_desc->dt)) { + return STATUS_BAD_TENSOR_DTYPE; + } + // 确定输入和输出张量的 strides 是否合法 + if (!is_contiguous(output_desc) || !is_contiguous(input_desc)) { + return STATUS_BAD_TENSOR_STRIDES; + } + // 
确定输入张量和输出张量的维度是否一致 + if (output_desc->ndim != input_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + // 确定输入张量和输出张量的形状是否一致 + for (uint64_t i = 0; i < output_desc->ndim; ++i) { + if (output_desc->shape[i] != input_desc->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + // 确定 min_desc 和 max_desc 是否为空(两者均可以为空),如果非空,判断是否为标量或类型是否满足要求 + bool min_is_null = false; + bool max_is_null = false; + if (min_desc == nullptr) { + min_is_null = true; + } else { + if (min_desc->ndim != 0) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(min_desc->dt, input_desc->dt)) { + return STATUS_BAD_TENSOR_DTYPE; + } + } + if (max_desc == nullptr) { + max_is_null = true; + } else { + if (max_desc->ndim != 0) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(max_desc->dt, input_desc->dt)) { + return STATUS_BAD_TENSOR_DTYPE; + } + } + + // 计算创建 ClipDescriptor 所需的参数 + uint64_t n_dim = output_desc->ndim; + uint64_t data_size = + std::accumulate(output_desc->shape, output_desc->shape + n_dim, 1ULL, std::multiplies()); + + // 创建 ClipDescriptor + *desc_ptr = new ClipCpuDescriptor{ + DevCpu, + input_desc->dt, + n_dim, + data_size, + min_is_null, + max_is_null}; + + return STATUS_SUCCESS; +} + +/** + * @brief 实际执行裁剪操作的函数,可以根据模板实参,分别对 F16 和 F32 类型的数据执行裁剪 + * @tparam Tdata 要 clip 张量中的元素数据类型 + * @param desc ClipCpuDescriptor_t 类型的指针,指向裁剪操作的描述符 + * @param output 裁剪结果保存的位置 + * @param input 输入张量 + * @param min 裁剪的最小值 + * @param max 裁剪的最大值 + * @return + */ +template +infiniopStatus_t clip_cpu(ClipCpuDescriptor_t desc, void *output, void const *input, void const *min, + void const *max) { + + // 先将数据转换为对应类型的指针 + Tdata *output_data = reinterpret_cast(output); + Tdata const *input_data = reinterpret_cast(input); + + // 如果 min 和 max 都是空,直接复制 + if (desc->min_is_null && desc->max_is_null) { + std::memcpy(output_data, input_data, desc->data_size * desc->dtype.size); + return STATUS_SUCCESS; + } + + // 根据 min 和 max 是否为空决定是否需要转换裁剪的最小值和最大值 + Tdata min_value; + Tdata max_value; + if 
(!desc->min_is_null) { + min_value = *reinterpret_cast(min); + } + if (!desc->max_is_null) { + max_value = *reinterpret_cast(max); + } + // 遍历输入数据中的每个元素,执行裁剪操作 + for (uint64_t i = 0; i < desc->data_size; ++i) { + if constexpr (std::is_same::value) { + // 由于 F16 是用 uint16_t 表示的,所以不能使用 std::numeric_limits 获取 Tdata + // 类型的最大值和最小值,需要根据不同情况进行比较 + if (!desc->min_is_null && !desc->max_is_null) { + output_data[i] = f32_to_f16(std::min( + std::max(f16_to_f32(input_data[i]), f16_to_f32(min_value)), f16_to_f32(max_value))); + } else if (desc->min_is_null && !desc->max_is_null) { + output_data[i] = f32_to_f16(std::min(f16_to_f32(input_data[i]), f16_to_f32(max_value))); + } else if (!desc->min_is_null && desc->max_is_null) { + output_data[i] = f32_to_f16(std::max(f16_to_f32(input_data[i]), f16_to_f32(min_value))); + } + } else { + if (!desc->min_is_null && !desc->max_is_null) { + output_data[i] = std::min(std::max(input_data[i], min_value), max_value); + } else if (desc->min_is_null && !desc->max_is_null) { + output_data[i] = std::min(input_data[i], max_value); + } else if (!desc->min_is_null && desc->max_is_null) { + output_data[i] = std::max(input_data[i], min_value); + } + } + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuClip(ClipCpuDescriptor_t desc, void *output, void const *input, void const *min, + void const *max, void *stream) { + // 根据不同类型的张量类型,给 clip_cpu 函数传递不同的函数实参 + if (desc->dtype == F16) { + return clip_cpu(desc, output, input, min, max); + } + if (desc->dtype == F32) { + return clip_cpu(desc, output, input, min, max); + } + return STATUS_BAD_TENSOR_DTYPE; +} + +infiniopStatus_t cpuDestroyClipDescriptor(ClipCpuDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/clip/cpu/clip_cpu.h b/src/ops/clip/cpu/clip_cpu.h new file mode 100644 index 00000000..050bb244 --- /dev/null +++ b/src/ops/clip/cpu/clip_cpu.h @@ -0,0 +1,58 @@ +#ifndef __CPU_CLIP_H__ +#define __CPU_CLIP_H__ + +#include "operators.h" + 
/**
 * @brief Descriptor of the CPU clip operation
 */
struct ClipCpuDescriptor {
    Device device;
    DT dtype;
    uint64_t ndim;     // number of dimensions of the output tensor
    uint64_t data_size;// number of elements in the output tensor
    bool min_is_null;  // true when no lower bound was supplied
    bool max_is_null;  // true when no upper bound was supplied
};
typedef struct ClipCpuDescriptor *ClipCpuDescriptor_t;

/**
 * @brief Create a descriptor for clipping a tensor on the CPU
 * @param desc_ptr receives the created CPU clip descriptor
 * @param output_desc output tensor descriptor
 * @param input_desc input tensor descriptor
 * @param min_desc descriptor of the lower bound (scalar tensor; may be null)
 * @param max_desc descriptor of the upper bound (scalar tensor; may be null)
 * @return status indicating whether creation succeeded
 */
infiniopStatus_t cpuCreateClipDescriptor(infiniopHandle_t,
                                         ClipCpuDescriptor_t *desc_ptr,
                                         infiniopTensorDescriptor_t output_desc,
                                         infiniopTensorDescriptor_t input_desc,
                                         infiniopTensorDescriptor_t min_desc,
                                         infiniopTensorDescriptor_t max_desc);


/**
 * @brief Run the clip operation on the CPU
 * @param desc descriptor of the clip operation
 * @param output where the clipped result is stored
 * @param input input tensor data
 * @param min pointer to the scalar lower bound
 * @param max pointer to the scalar upper bound
 * @param stream unused
 * @return status of the operation
 */
infiniopStatus_t cpuClip(ClipCpuDescriptor_t desc,
                         void *output, void const *input,
                         void const *min, void const *max,
                         void *stream);

/**
 * @brief Destroy the given CPU clip descriptor
 * @param desc descriptor to destroy
 * @return status indicating whether destruction succeeded
 */
infiniopStatus_t cpuDestroyClipDescriptor(ClipCpuDescriptor_t desc);
cpuCreateClipDescriptor(handle, (ClipCpuDescriptor_t *)desc_ptr, output_desc, intput_desc, + min_desc, max_desc); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopClip(infiniopClipDescriptor_t desc, void *output, void *input, void const *min, + void const *max, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuClip((ClipCpuDescriptor_t)desc, output, input, min, max, stream); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyClipDescriptor((ClipCpuDescriptor_t)desc); +#endif + } + return STATUS_BAD_DEVICE; +} \ No newline at end of file diff --git a/src/ops/gather/cpu/gather_cpu.cc b/src/ops/gather/cpu/gather_cpu.cc new file mode 100644 index 00000000..1360b9ef --- /dev/null +++ b/src/ops/gather/cpu/gather_cpu.cc @@ -0,0 +1,233 @@ +#include "gather_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + +/** + * @brief 将 indices 表示的索引按照字典序递增 1,即在当前索引的基础上,转换为下一个索引 + * @param indices 当前的索引 + * @param shape 索引所确定张量的维度 + * @param ndim 张量的阶数 + */ +inline void incrementOne(uint64_t *indices, uint64_t const *shape, uint64_t ndim) { + // 每次优先从最后一维开始递增 + for (int64_t i = ndim - 1; i >= 0; --i) { + // 如果当前维度递增后没有超过该维度的最大值,则直接递增并返回 + if (++indices[i] != shape[i]) { + return; + } + // 如果递增后等于了最大值,则将该维度的索引置为 0,继续递增前一个维度 + indices[i] = 0; + } +} + +/** + * @brief 根据给定的索引和步长,计算该索引确定的元素在一维数组中的位置 + * @param indices 给定的索引 + * @param strides 每个维度的偏移步长 + * @param ndim 索引的总维度 + * @return 返回该索引确定的元素在一维数组中的位置 + */ +inline uint64_t compactToFlat(uint64_t const *indices, uint64_t const *strides, uint64_t ndim) { + return std::inner_product(indices, indices + ndim, strides, uint64_t(0)); +} + +infiniopStatus_t cpuCreateGatherDescriptor(infiniopHandle_t, + GatherCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + 
infiniopTensorDescriptor_t data_desc, + infiniopTensorDescriptor_t index_desc, + int axis) { + + // 先检查输入数据是否合法 + // 0. 获取每个参数的秩序 + uint64_t output_ndim = output_desc->ndim; + uint64_t data_ndim = data_desc->ndim; + uint64_t index_ndim = index_desc->ndim; + + // 1. 检查所有张量的步长都合法 + if (!is_contiguous(output_desc) || !is_contiguous(data_desc) || !is_contiguous(index_desc)) { + return STATUS_BAD_TENSOR_SHAPE; + } + // 2. 检查 axis 的值是否在 [-ndata_ndim, data_ndim - 1] 范围内 + if (axis < -static_cast(data_ndim) || + axis >= static_cast(data_ndim)) { + return STATUS_BAD_PARAM; + } + + // 3. 检查输入和输出数据类型符合需要 + if (output_desc->dt != F16 && output_desc->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (data_desc->dt != output_desc->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + // 4. 检查索引数据类型符合要求 + if (index_desc->dt != I32 && index_desc->dt != I64) { + return STATUS_BAD_TENSOR_DTYPE; + } + + // 5. 检查 data 和 output 的秩合法 + if (data_ndim < 1) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (output_ndim != index_ndim + data_ndim - 1) { + return STATUS_BAD_TENSOR_SHAPE; + } + + // 计算创建 GatherCpuDescriptor 所需的相关数据 + // 1. 将负的 axis 转换为正数 + uint64_t data_axis = axis < 0 ? (axis + data_ndim) : axis; + + // 2. 计算形状 + // 计算 output_shape 并与输入的 output_desc 比较是否相同 + uint64_t *output_shape = new uint64_t[output_ndim]; + for (uint64_t i = 0, j = 0; i < output_ndim;) { + // j 索引用于遍历 data 的维度 + if (j == data_axis) { + // 当 j 遍历到 data 的第 axis 索引处时,将 index 的 shape 插入 + for (uint64_t k = 0; k < index_ndim; ++k) { + output_shape[i++] = index_desc->shape[k]; + } + ++j; + } else { + output_shape[i++] = data_desc->shape[j++]; + } + } + // 比较 output_shape 和 output_desc->shape 是否相同 + for (uint64_t i = 0; i < output_ndim; ++i) { + if (output_shape[i] != output_desc->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + + // 保存 data_shape + uint64_t *data_shape = new uint64_t[data_ndim]; + std::copy(data_desc->shape, data_desc->shape + data_ndim, data_shape); + + // 3. 
计算所有的元素个数 size + uint64_t output_size = std::accumulate(output_shape, output_shape + output_ndim, uint64_t(1), std::multiplies()); + uint64_t indices_size = std::accumulate(index_desc->shape, index_desc->shape + index_ndim, uint64_t(1), std::multiplies()); + + // 4. 保存 data 和 index 的步长 + uint64_t *data_strides = new uint64_t[data_ndim]; + std::copy(data_desc->strides, data_desc->strides + data_ndim, data_strides); + uint64_t *index_strides = new uint64_t[index_ndim]; + std::copy(index_desc->strides, index_desc->strides + index_ndim, index_strides); + + // 5. 初始化所有的 indices + uint64_t *output_indices = new uint64_t[output_ndim]; + std::fill(output_indices, output_indices + output_ndim, uint64_t(0)); + uint64_t *data_indices = new uint64_t[data_ndim]; + std::fill(data_indices, data_indices + data_ndim, uint64_t(0)); + uint64_t *index_indices = new uint64_t[index_ndim]; + std::fill(index_indices, index_indices + index_ndim, uint64_t(0)); + + // 创建 GatherCpuDescriptor 描述符对象 + *desc_ptr = new GatherCpuDescriptor { + DevCpu, + output_desc->dt, + index_desc->dt, + output_ndim, + output_size, + output_shape, + output_indices, + data_ndim, + data_shape, + data_strides, + data_indices, + index_ndim, + index_strides, + index_indices, + indices_size, + data_axis + }; + + return STATUS_SUCCESS; +} + +template +infiniopStatus_t gather_cpu(GatherCpuDescriptor_t desc, void *output, void *data, void const *indices, void *stream) { + auto input_data = reinterpret_cast(data); + auto output_data = reinterpret_cast(output); + // 将输入的 indices 数据复制到新的数组中,以进行修改 + auto indices_data = reinterpret_cast(indices); + + // 取出 desc 中保存的相关信息 + uint64_t axis = desc->axis; + uint64_t data_ndim = desc->data_ndim; + auto output_indices = desc->output_indices; + auto data_indices = desc->data_indices; + auto index_indices = desc->index_indices; + + // 将对应索引的结果保存到 output 张量中 + for (uint64_t i = 0; i < desc->output_size; + ++i, incrementOne(output_indices, desc->output_shape, desc->output_ndim)) { + + // 
根据输出张量的 output_indices 索引确定在 data 中的索引 + for (uint64_t j = 0; j < desc->data_ndim; ++j) { + if (j < axis) { + // data 中小于 axis 部分的索引和 output_indices 中的索引相同 + data_indices[j] = output_indices[j]; + } else if (j == axis) { + // 等于 axis 部分的索引需要从 indices_data 中获取 + for (uint64_t k = 0; k < desc->index_ndim; ++k) { + index_indices[k] = output_indices[k + axis]; + } + // 根据 index_indices 索引获取 indices_data 中的值 + uint64_t cur_index_indices = compactToFlat(index_indices, desc->index_strides, desc->index_ndim); + // 获取 cur_index_indices 处的值 + Tindices element_index = indices_data[cur_index_indices]; + // 处理负值的情况 + if (element_index < -static_cast(desc->data_shape[desc->axis]) || + element_index >= static_cast(desc->data_shape[desc->axis])) { + return STATUS_BAD_PARAM; + } + if (element_index < 0) { + element_index += desc->data_shape[desc->axis]; + } + data_indices[j] = element_index; + } else { + // 大于 axis 部分的索引为原有的索引,向后偏移 index_ndim - 1 个位置 + data_indices[j] = output_indices[j + desc->index_ndim - 1]; + } + } + + // 根据 data_indices 索引获取 data 中的值 + uint64_t cur_data_indices = compactToFlat(data_indices, desc->data_strides, desc->data_ndim); + output_data[i] = input_data[cur_data_indices]; + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuGather(GatherCpuDescriptor_t desc, void *output, void *data, void const *indices, void *stream) { + if (desc->output_dtype == F16) { + if (desc->index_dtype == I32) { + return gather_cpu(desc, output, data, indices, stream); + } + if (desc->index_dtype == I64) { + return gather_cpu(desc, output, data, indices, stream); + } + } + if (desc->output_dtype == F32) { + if (desc->index_dtype == I32) { + return gather_cpu(desc, output, data, indices, stream); + } + if (desc->index_dtype == I64) { + return gather_cpu(desc, output, data, indices, stream); + } + } + return STATUS_BAD_TENSOR_DTYPE; +} + +infiniopStatus_t cpuDestroyGatherDescriptor(GatherCpuDescriptor_t desc) { + delete desc; + delete[] desc->output_shape; + delete[] 
desc->output_indices; + delete[] desc->data_shape; + delete[] desc->data_strides; + delete[] desc->data_indices; + delete[] desc->index_strides; + delete[] desc->index_indices; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/gather/cpu/gather_cpu.h b/src/ops/gather/cpu/gather_cpu.h new file mode 100644 index 00000000..ed0000b9 --- /dev/null +++ b/src/ops/gather/cpu/gather_cpu.h @@ -0,0 +1,76 @@ +#ifndef __CPU_GATHER_H__ +#define __CPU_GATHER_H__ + +#include "operators.h" + +/** + * @brief CPU Gather 操作的描述符 + */ +struct GatherCpuDescriptor { + Device device; // 设备类型 + DT output_dtype; // 数据中元素的类型 + DT index_dtype; // index 中元素的类型 + + uint64_t output_ndim; // 结果张量的维度 + uint64_t output_size; // 结果张量的元素数量 + uint64_t const *output_shape; // 结果张量的形状 + uint64_t *output_indices; // 用于在执行计算时索引 output 的每个元素 + + uint64_t data_ndim; // 输入张量的维度 + uint64_t *data_shape; // 输入张量的形状 + uint64_t const *data_strides; // 输入张量的偏移量步长(确定元素在张量中的位置) + uint64_t *data_indices; // 用于在计算时索引输入张量的每个元素 + + uint64_t index_ndim; // indices 张量的维度 + uint64_t const *index_strides; // indices 张量的偏移量步长(确定元素在张量中的位置) + uint64_t *index_indices; // 用于在计算时索引 indices 张量的每个元素 + uint64_t indices_size; // indices 张量的元素数量 + + uint64_t axis; // 索引的轴 +}; + +/** + * @brief CPU Gather 操作描述符的指针类型 + */ +typedef struct GatherCpuDescriptor *GatherCpuDescriptor_t; + +/** + * @brief 创建用于对张量执行 Gather 操作的 CPU 描述符 + * @param handle infiniop 句柄 + * @param desc_ptr 指向内部创建的 Gather 描述符的指针 + * @param output_desc 输出张量的描述符 + * @param data_desc 所操作张量的描述符 + * @param index_desc indices 张量描述符 + * @param axis 聚集时的轴 + * @return 返回表示创建是否成功的状态 + */ +infiniopStatus_t cpuCreateGatherDescriptor(infiniopHandle_t, + GatherCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t data_desc, + infiniopTensorDescriptor_t index_desc, + int axis); + +/** + * @brief 执行 Gather 操作 + * @param desc 指向内部创建的 Gather 描述符的指针 + * @param output 输出张量 + * @param data 所操作张量 + * @param indices 
indices 张量 + * @param stream 未使用参数 + * @return 返回操作执行的状态 + */ +infiniopStatus_t cpuGather(GatherCpuDescriptor_t desc, + void *output, + void *data, + void const *indices, + void *stream); + +/** + * @brief 销毁指定的 Gather 操作描述符 + * @param desc 要销毁的 Gather 操作描述符 + * @return 返回表示是否销毁成功的状态 + */ +infiniopStatus_t cpuDestroyGatherDescriptor(GatherCpuDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/src/ops/gather/operator.cc b/src/ops/gather/operator.cc new file mode 100644 index 00000000..9c1e11a1 --- /dev/null +++ b/src/ops/gather/operator.cc @@ -0,0 +1,46 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/gather/gather.h" + +#ifdef ENABLE_CPU +#include "cpu/gather_cpu.h" +#endif + +__C infiniopStatus_t infiniopCreateGatherDescriptor( + infiniopHandle_t handle, + infiniopGatherDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t data_desc, + infiniopTensorDescriptor_t index_desc, + int axis) { + + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateGatherDescriptor(handle, (GatherCpuDescriptor_t *)desc_ptr, output_desc, data_desc, index_desc, axis); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopGather(infiniopGatherDescriptor_t desc, void *output, void *data, + void const *indices, void *stream) { + + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuGather((GatherCpuDescriptor_t)desc, output, data, indices, stream); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyGatherDescriptor((GatherCpuDescriptor_t)desc); +#endif + } + return STATUS_BAD_DEVICE; +} \ No newline at end of file diff --git a/src/ops/reduce/cpu/reduce_cpu.cc b/src/ops/reduce/cpu/reduce_cpu.cc new file mode 100644 index 00000000..f4aeff7b --- /dev/null +++ b/src/ops/reduce/cpu/reduce_cpu.cc @@ 
-0,0 +1,399 @@ +#include "reduce_cpu.h" +#include "../../utils.h" + + +/** + * @brief 将 indices 表示的索引按照字典序递增 1,即在当前索引的基础上,转换为下一个索引 + * @param indices 当前的索引 + * @param shape 索引所确定张量的维度 + * @param ndim 张量的阶数 + */ +inline void incrementOne(uint64_t *indices, uint64_t const *shape, uint64_t ndim) { + // 每次优先从最后一维开始递增 + for (int64_t i = ndim - 1; i >= 0; --i) { + // 如果当前维度递增后没有超过该维度的最大值,则直接递增并返回 + if (++indices[i] != shape[i]) { + return; + } + // 如果递增后等于了最大值,则将该维度的索引置为 0,继续递增前一个维度 + indices[i] = 0; + } +} + +/** + * @brief 根据给定的索引和步长,计算该索引确定的元素在一维数组中的位置 + * @param indices 给定的索引 + * @param strides 每个维度的偏移步长 + * @param ndim 索引的总维度 + * @return 返回该索引确定的元素在一维数组中的位置 + */ +inline uint64_t compactToFlat(uint64_t const *indices, uint64_t const *strides, uint64_t ndim) { + return std::inner_product(indices, indices + ndim, strides, uint64_t(0)); +} + + +infiniopStatus_t cpuCreateReduceDescriptor(infiniopHandle_t handle, + ReduceCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + int64_t const *axes, + size_t axes_ndim, + int keepdims, + int noop_with_empty_axes, + int reduce_type) { + // 1. 检查 data 的类型和形状都正确 + if (data->dt != F16 && data->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (!is_contiguous(data)) { + return STATUS_BAD_TENSOR_SHAPE; + } + // 2. 检查 reduced 的类型和形状都正确 + if (reduced->dt != data->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + // 3. 如果要保留维度,reduced 应该可以广播到 data + if (keepdims && !isValidBroadcastShape(data, reduced)) { + return STATUS_BAD_TENSOR_SHAPE; + } + // 4. 如果不保留维度 + if (keepdims == 0) { + // 如果 axes 为空,noop_with_empty_axes 为 0,reduced 应该是个标量 + if (axes == nullptr && noop_with_empty_axes == 0 && reduced->ndim != 0) { + return STATUS_BAD_TENSOR_SHAPE; + } + // 如果 axes 非空,且 data 非标量时,reduced 应该是 data 去掉 axes 指定的维度 + if (axes != nullptr && data->ndim != 0 && reduced->ndim != data->ndim - axes_ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + // 计算创建 ReduceCpuDescriptor 所需的参数 + // 1. 
计算 data 的总元素个数 + uint64_t data_size = + std::accumulate(data->shape, data->shape + data->ndim, uint64_t(1), std::multiplies()); + // 2. 保存 data 的形状 + uint64_t *data_shape = new uint64_t[data->ndim]; + std::copy(data->shape, data->shape + data->ndim, data_shape); + // 3. 保存 data 的步长 + uint64_t *data_strides = new uint64_t[data->ndim]; + std::copy(data->strides, data->strides + data->ndim, data_strides); + // 4. 初始化 data_indices + uint64_t *data_indices = new uint64_t[data->ndim]; + std::fill(data_indices, data_indices + data->ndim, 0); + // 5. 保存 axes + int64_t *axes_data = nullptr; + if (axes != nullptr) { + axes_data = new int64_t[axes_ndim]; + std::copy(axes, axes + axes_ndim, axes_data); + // 将 axes_data 中的负数转换为正数 + for (size_t i = 0; i < axes_ndim; ++i) { + if (axes_data[i] < -static_cast(data->ndim) || + axes_data[i] >= static_cast(data->ndim)) { + return STATUS_BAD_PARAM; + } else if (axes_data[i] < 0) { + axes_data[i] += static_cast(data->ndim); + } + } + // 对 axes_data 排序,方便后续遍历 + std::sort(axes_data, axes_data + axes_ndim); + } + // 6. 初始化 axes_indices 的索引 + uint64_t *axes_indices = new uint64_t[axes_ndim]; + std::fill(axes_indices, axes_indices + axes_ndim, 0); + // 7. 初始化 axes_shape (axes 中所指定维度的 shape) 和 axes_size + uint64_t *axes_shape = new uint64_t[axes_ndim]; + for (size_t i = 0; i < axes_ndim; ++i) { + axes_shape[i] = data_shape[axes_data[i]]; + } + u_int64_t axes_size = + std::accumulate(axes_shape, axes_shape + axes_ndim, uint64_t(1), std::multiplies()); + // 8. 
初始化 reduced 的形状 + uint64_t *reduced_shape = new uint64_t[reduced->ndim]; + if (axes != nullptr && axes_ndim != 0) { + // axes 非空时,根据是否 keepdims 设置 reduced 的形状 + if (keepdims == 0) { + // 不保留维度时,reduced 的形状是 data 去掉 axes 指定的维度 + for (size_t i = 0, j = 0; i < data->ndim; ++i) { + if (i == axes_data[j]) { + ++j; + continue; + } else { + reduced_shape[i - j] = data_shape[i]; + } + } + } else { + // 保留维度时,reduced 中 axes 指定的维度为 1 + for (size_t i = 0, j = 0; i < data->ndim; ++i) { + if (i == axes_data[j]) { + reduced_shape[i] = 1; + ++j; + } else { + reduced_shape[i] = data_shape[i]; + } + } + } + } else { + // axes 为空时,reduced 的形状根据 noop_with_empty_axes 设置 + if (noop_with_empty_axes == 1) { + // noop_with_empty_axes 为 1,reduced 与 data 相同 + std::copy(data_shape, data_shape + data->ndim, reduced_shape); + } else { + // noop_with_empty_axes 为 0,表示对所有维度进行操作 + if (keepdims == 1) { + // 保留维度时,reduced 每个维度都为 1 + std::fill(reduced_shape, reduced_shape + reduced->ndim, 1); + } + // 不保留维度时,reduced 为标量,其 shape 的形状为 0 + } + } + // 9. 初始时根据 reduced 的形状计算总元素个数 + uint64_t reduced_size = + std::accumulate(reduced->shape, reduced->shape + reduced->ndim, uint64_t(1), std::multiplies()); + // 10. 初始化 reduced_indices + uint64_t *reduced_indices = new uint64_t[reduced->ndim]; + std::fill(reduced_indices, reduced_indices + reduced->ndim, 0); + + // 创建 ReduceCpuDescriptor + *desc_ptr = new ReduceCpuDescriptor{ + DevCpu, + data->dt, + data->ndim, + data_size, + data_shape, + data_strides, + data_indices, + axes_data, + axes_indices, + axes_shape, + axes_size, + axes_ndim, + reduced->ndim, + reduced_size, + reduced_shape, + reduced_indices, + keepdims, + noop_with_empty_axes, + reduce_type, + }; + return STATUS_SUCCESS; +} + +template +infiniopStatus_t reduce_cpu(ReduceCpuDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream) { + // 1. 
先将输入转换为对应类型的数据 + auto reduced_data = reinterpret_cast(reduced); + auto data_data = reinterpret_cast(data); + int64_t const *axes_data = desc->axes; + + // 2. 判断输入数据是否为空或标量 + if (desc->data_size == 1) { + // 如果只有一个元素,表示是一个标量,直接复制返回 + reduced_data[0] = data_data[0]; + return STATUS_SUCCESS; + } else if (desc->data_size == 0) { + // 如果元素个数为 0,表示是空集合,应该根据类型,返回无穷大 + if constexpr (std::is_same::value) { + if (desc->reduce_type == 1) { + // 返回 float32 的负无穷小 + reduced_data[0] = f32_to_f16(-std::numeric_limits::infinity()); + } else if (desc->reduce_type == 2) { + // 返回 float32 的正无穷大 + reduced_data[0] = f32_to_f16(std::numeric_limits::infinity()); + } else if (desc->reduce_type == 3) { + // 返回 float32 的 NaN + reduced_data[0] = f32_to_f16(std::numeric_limits::quiet_NaN()); + } + } else { + if (desc->reduce_type == 1) { + // 返回 float32 的负无穷小 + reduced_data[0] = -std::numeric_limits::infinity(); + } else if (desc->reduce_type == 2) { + // 返回 float32 的正无穷大 + reduced_data[0] = std::numeric_limits::infinity(); + } else if (desc->reduce_type == 3) { + // 返回 float32 的 NaN + reduced_data[0] = std::numeric_limits::quiet_NaN(); + } + } + return STATUS_SUCCESS; + } + + // 3. 判断是否返回原有数据 + if (axes_data == nullptr && desc->noop_with_empty_axes == 1) { + // 如果 axes_data 为空且 noop_with_empty_axes 为 true,则将 data 直接复制到 reduced 中后返回 + std::copy(data_data, data_data + desc->data_size, reduced_data); + return STATUS_SUCCESS; + } + + //4. 
判断是否是对所有维度执行操作 + bool is_reduce_all = false; + if (axes_data == nullptr && desc->noop_with_empty_axes == 0) { + // axes 为空且 noop_with_empty_axes 为 0 时,表示对所有维度进行操作 + is_reduce_all = true; + } else if (desc->axes_ndim == desc->data_ndim) { + // 如果 axes_data 的大小等于 data_ndim,则表示对所有维度执行操作 + is_reduce_all = true; + } + // 如果是对所有维度进行操作,直接遍历 data 的每个元素,计算后保存到 reduced_data[0] 中,直接返回 + if (is_reduce_all) { + float res; + for (size_t i = 0; i < desc->data_size; ++i) { + if constexpr (std::is_same::value) { + if (i == 0) { + res = f16_to_f32(data_data[i]); + continue; + } + if (desc->reduce_type == 1) { // 最大值 + res = std::max(res, f16_to_f32(data_data[i])); + } else if (desc->reduce_type == 2) { // 最小值 + res = std::min(res, f16_to_f32(data_data[i])); + } else if (desc->reduce_type == 3) { // 平均值 + res += f16_to_f32(data_data[i]); + } + } else { + if (i == 0) { + res = data_data[i]; + continue; + } + if (desc->reduce_type == 1) { // 最大值 + res = std::max(res, data_data[i]); + } else if (desc->reduce_type == 2) { // 最小值 + res = std::min(res, data_data[i]); + } else if (desc->reduce_type == 3) { // 平均值 + res += data_data[i]; + } + } + } + if constexpr (std::is_same::value) { + if (desc->reduce_type == 3) { // 平均值 + res /= desc->data_size; + reduced_data[0] = f32_to_f16(res); + } else { + reduced_data[0] = f32_to_f16(res); + } + } else { + if (desc->reduce_type == 3) { // 平均值 + res /= desc->data_size; + reduced_data[0] = res; + } else { + reduced_data[0] = res; + } + } + return STATUS_SUCCESS; + } + + // 5. 
对指定的维度进行操作,遍历 reduced_data 的每个索引,根据索引,从 data_data 中获取对应元素 + const auto &reduced_indices = desc->reduced_indices; + const auto &data_indices = desc->data_indices; + for (size_t i = 0; i < desc->reduced_size; + ++i, incrementOne(reduced_indices, desc->reduced_shape, desc->reduced_ndim)) { + + // 先将 reduced_indices 的非 reduce 索引保存到 data_indices 中 + if (desc->keepdims == 0) { + // 不保存 reduce 的维度时,reduced_indices 是所有不在 axes 中的索引 + for (size_t j = 0, k = 0, l = 0; j < desc->data_ndim; ++j) { + // l 遍历 axes_data,k 遍历 reduced_indices + if (l < desc->axes_ndim && j == axes_data[l]) { + ++l; + } else { + data_indices[j] = reduced_indices[k++]; + } + } + } else { + // 保留维度时,reduced_indices 是所有维度的索引,可以直接复制 + std::copy(reduced_indices, reduced_indices + desc->reduced_ndim, data_indices); + } + + // 遍历 axes 中指定的所有维度,与 reduced_indices 组合,构成 data_indices + float res; + const auto &axes_indices = desc->axes_indices; + for (size_t j = 0; j < desc->axes_size; + ++j, incrementOne(axes_indices, desc->axes_shape, desc->axes_ndim)) { + // 将 axes_indices 的索引保存到 axes 指定的 data_indices 对应的索引中 + for (size_t k = 0; k < desc->axes_ndim; ++k) { + data_indices[desc->axes[k]] = axes_indices[k]; + } + // 根据 data_indices 计算 data_data 的索引 + size_t data_index = compactToFlat(data_indices, desc->data_strides, desc->data_ndim); + + if (j == 0) { + // 如果是第一个元素,直接赋值给 res + if constexpr (std::is_same::value) { + res = f16_to_f32(data_data[data_index]); + } else { + res = data_data[data_index]; + } + continue; + } + + // 根据 reduce_type 计算 res + if constexpr (std::is_same::value) { + if (desc->reduce_type == 1) { // 最大值 + res = std::max(res, f16_to_f32(data_data[data_index])); + } else if (desc->reduce_type == 2) { // 最小值 + res = std::min(res, f16_to_f32(data_data[data_index])); + } else if (desc->reduce_type == 3) { // 平均值 + res += f16_to_f32(data_data[data_index]); + } + } else { + if (desc->reduce_type == 1) { // 最大值 + res = std::max(res, data_data[data_index]); + } else if (desc->reduce_type == 2) { // 
最小值 + res = std::min(res, data_data[data_index]); + } else if (desc->reduce_type == 3) { // 平均值 + res += data_data[data_index]; + } + } + } + // 将 res 保存到 reduced_data 中 + if constexpr (std::is_same::value) { + if (desc->reduce_type == 3) { // 平均值 + res /= desc->axes_size; + reduced_data[i] = f32_to_f16(res); + } else { + reduced_data[i] = f32_to_f16(res); + } + } else { + if (desc->reduce_type == 3) { // 平均值 + res /= desc->axes_size; + reduced_data[i] = res; + } else { + reduced_data[i] = res; + } + } + } + + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuReduce(ReduceCpuDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream) { + + if (desc->dtype == F16) { + return reduce_cpu(desc, reduced, data, axes, stream); + } + if (desc->dtype == F32) { + return reduce_cpu(desc, reduced, data, axes, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} + +infiniopStatus_t cpuDestroyReduceDescriptor(ReduceCpuDescriptor_t desc) { + delete desc; + delete[] desc->data_shape; + delete[] desc->data_strides; + delete[] desc->data_indices; + delete[] desc->axes_shape; + delete[] desc->axes; + delete[] desc->axes_indices; + delete[] desc->reduced_shape; + delete[] desc->reduced_indices; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/reduce/cpu/reduce_cpu.h b/src/ops/reduce/cpu/reduce_cpu.h new file mode 100644 index 00000000..430b59e4 --- /dev/null +++ b/src/ops/reduce/cpu/reduce_cpu.h @@ -0,0 +1,54 @@ +#ifndef __CPU_REDUCE_H__ +#define __CPU_REDUCE_H__ + +#include "../../../devices/cpu/common_cpu.h" +#include "operators.h" + +struct ReduceCpuDescriptor { + Device device; + DT dtype; + uint64_t data_ndim; // 输入数据的维度 + uint64_t data_size; // 输入数据的总元素个数 + uint64_t *data_shape; // 输入数据的形状 + uint64_t const *data_strides; // 输入数据的步长 + uint64_t *data_indices; // 用于根据 reduced_indices 遍历 data 的指定维度 + + int64_t const *axes; // 保存要进行 reduce 的维度 + uint64_t *axes_indices; // 用于遍历 axes 的索引 + uint64_t *axes_shape; // 所有 axes 
指定轴的形状,遍历 axes 时需要用到 + u_int64_t axes_size; // axes 指定的所有轴形状构成张量的总元素个数 + size_t axes_ndim; // axes 数组的长度 + + uint64_t reduced_ndim; // 输出数据的维度 + uint64_t reduced_size; // 输出数据的总元素个数 + uint64_t *reduced_shape; // 输出数据的形状 + uint64_t *reduced_indices; // 用于遍历输出数据的索引(初始化为全 0) + + int keepdims; // 是否保留减少后的维度,1(默认)表示保留,0 表示不保留 + + // 指定输入参数 axes 为空时的行为,1 表示返回输入数据,0(默认)表示返回对所有维度进行操作 + int noop_with_empty_axes; + int reduce_type; // 1: max, 2: min, 3: mean +}; + +typedef struct ReduceCpuDescriptor *ReduceCpuDescriptor_t; + +infiniopStatus_t cpuCreateReduceDescriptor(infiniopHandle_t handle, + ReduceCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + int64_t const *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes, + int reduce_type); + +infiniopStatus_t cpuReduce(ReduceCpuDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream); + + +infiniopStatus_t cpuDestroyReduceDescriptor(ReduceCpuDescriptor_t desc); +#endif \ No newline at end of file diff --git a/src/ops/reduce/operator.cc b/src/ops/reduce/operator.cc new file mode 100644 index 00000000..38888392 --- /dev/null +++ b/src/ops/reduce/operator.cc @@ -0,0 +1,60 @@ +#include "../utils.h" +#include "operators.h" +#include "reduce.h" + +#ifdef ENABLE_CPU +#include "cpu/reduce_cpu.h" +#endif + +__C infiniopStatus_t infiniopCreateReduceDescriptor( + infiniopHandle_t handle, + infiniopReduceDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + const int64_t *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes, + int reduce_type) { + + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateReduceDescriptor(handle, + (ReduceCpuDescriptor_t *)desc_ptr, + reduced, + data, + axes, + axes_size, + keepdims, + noop_with_empty_axes, + reduce_type); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopReduce( + 
infiniopReduceDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuReduce((ReduceCpuDescriptor_t)desc, reduced, data, axes, stream); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyReduceDescriptor(infiniopReduceDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyReduceDescriptor((ReduceCpuDescriptor_t)desc); +#endif + } + return STATUS_BAD_DEVICE; +} \ No newline at end of file diff --git a/src/ops/reduce/reduce.h b/src/ops/reduce/reduce.h new file mode 100644 index 00000000..d8fe3f79 --- /dev/null +++ b/src/ops/reduce/reduce.h @@ -0,0 +1,44 @@ +#ifndef REDUCE_H +#define REDUCE_H + +#include "export.h" +#include "operators.h" +#include + +typedef struct ReduceDescriptor { + Device device; +} ReduceDescriptor; + +typedef ReduceDescriptor *infiniopReduceDescriptor_t; + +/** + * @brief 根据传入的参数创建执行 Reduce 操作的描述符 + * @param handle infiniop 句柄 + * @param desc_ptr Reduce 描述符指针 + * @param reduced 结果张量描述符 + * @param data 输入张量描述符 + * @param axes 要进行 Reduce 的维度 + * @param keepdims 是否保留 Reduce 后的维度 + * @param noop_with_empty_axes 在 axes 为空时是否还执行 Reduce 操作 + * @param reduce_type 1: ReduceMax, 2: ReduceMin, 3: ReduceMean + * @return 描述符是否创建成功 + */ +__C infiniopStatus_t infiniopCreateReduceDescriptor(infiniopHandle_t handle, + infiniopReduceDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + const int64_t *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes, + int reduce_type); + +__C infiniopStatus_t infiniopReduce(infiniopReduceDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream); + +__C infiniopStatus_t infiniopDestroyReduceDescriptor(infiniopReduceDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/src/ops/reduce_max/operator.cc 
b/src/ops/reduce_max/operator.cc new file mode 100644 index 00000000..96ce104e --- /dev/null +++ b/src/ops/reduce_max/operator.cc @@ -0,0 +1,46 @@ +#include "../reduce/reduce.h" +#include "../utils.h" +#include "ops/reduce_max/reduce_max.h" + +struct _ReduceMaxDescriptor { + Device device; + infiniopReduceDescriptor_t reduce_desc; +}; + +typedef struct _ReduceMaxDescriptor *_ReduceMaxDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReduceMaxDescriptor(infiniopHandle_t handle, + infiniopReduceMaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + int64_t const *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes) { + infiniopReduceDescriptor_t reduce_desc; + CHECK_STATUS( + infiniopCreateReduceDescriptor(handle, &reduce_desc, reduced, data, axes, axes_size, keepdims, noop_with_empty_axes, 1), + STATUS_SUCCESS); + + *(_ReduceMaxDescriptor_t *)desc_ptr = new _ReduceMaxDescriptor{handle->device, reduce_desc}; + + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopReduceMax(infiniopReduceMaxDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream) { + auto _desc = (_ReduceMaxDescriptor_t)desc; + CHECK_STATUS(infiniopReduce(_desc->reduce_desc, reduced, data, axes, stream), + STATUS_SUCCESS); + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopDestroyReduceMaxDescriptor(infiniopReduceMaxDescriptor_t desc) { + auto _desc = (_ReduceMaxDescriptor_t)desc; + CHECK_STATUS(infiniopDestroyReduceDescriptor(_desc->reduce_desc), STATUS_SUCCESS); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/reduce_mean/operator.cc b/src/ops/reduce_mean/operator.cc new file mode 100644 index 00000000..8c8f2c62 --- /dev/null +++ b/src/ops/reduce_mean/operator.cc @@ -0,0 +1,46 @@ +#include "../reduce/reduce.h" +#include "../utils.h" +#include "ops/reduce_mean/reduce_mean.h" + +struct _ReduceMeanDescriptor { + Device device; + 
infiniopReduceDescriptor_t reduce_desc; +}; + +typedef struct _ReduceMeanDescriptor *_ReduceMeanDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReduceMeanDescriptor(infiniopHandle_t handle, + infiniopReduceMeanDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + infiniopTensorDescriptor_t data, + int64_t const *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes) { + infiniopReduceDescriptor_t reduce_desc; + CHECK_STATUS( + infiniopCreateReduceDescriptor(handle, &reduce_desc, reduced, data, axes, axes_size, keepdims, noop_with_empty_axes, 3), + STATUS_SUCCESS); + + *(_ReduceMeanDescriptor_t *)desc_ptr = new _ReduceMeanDescriptor{handle->device, reduce_desc}; + + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopReduceMean(infiniopReduceMeanDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream) { + auto _desc = (_ReduceMeanDescriptor_t)desc; + CHECK_STATUS(infiniopReduce(_desc->reduce_desc, reduced, data, axes, stream), + STATUS_SUCCESS); + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopDestroyReduceMeanDescriptor(infiniopReduceMeanDescriptor_t desc) { + auto _desc = (_ReduceMeanDescriptor_t)desc; + CHECK_STATUS(infiniopDestroyReduceDescriptor(_desc->reduce_desc), STATUS_SUCCESS); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/reduce_min/operator.cc b/src/ops/reduce_min/operator.cc new file mode 100644 index 00000000..c6fed2a0 --- /dev/null +++ b/src/ops/reduce_min/operator.cc @@ -0,0 +1,46 @@ +#include "../reduce/reduce.h" +#include "../utils.h" +#include "ops/reduce_min/reduce_min.h" + +struct _ReduceMinDescriptor { + Device device; + infiniopReduceDescriptor_t reduce_desc; +}; + +typedef struct _ReduceMinDescriptor *_ReduceMinDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateReduceMinDescriptor(infiniopHandle_t handle, + infiniopReduceMinDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t reduced, + 
infiniopTensorDescriptor_t data, + int64_t const *axes, + size_t axes_size, + int keepdims, + int noop_with_empty_axes) { + infiniopReduceDescriptor_t reduce_desc; + CHECK_STATUS( + infiniopCreateReduceDescriptor(handle, &reduce_desc, reduced, data, axes, axes_size, keepdims, noop_with_empty_axes, 2), + STATUS_SUCCESS); + + *(_ReduceMinDescriptor_t *)desc_ptr = new _ReduceMinDescriptor{handle->device, reduce_desc}; + + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopReduceMin(infiniopReduceMinDescriptor_t desc, + void *reduced, + void const *data, + int64_t const *axes, + void *stream) { + auto _desc = (_ReduceMinDescriptor_t)desc; + CHECK_STATUS(infiniopReduce(_desc->reduce_desc, reduced, data, axes, stream), + STATUS_SUCCESS); + return STATUS_SUCCESS; +} + +__C __export infiniopStatus_t infiniopDestroyReduceMinDescriptor(infiniopReduceMinDescriptor_t desc) { + auto _desc = (_ReduceMinDescriptor_t)desc; + CHECK_STATUS(infiniopDestroyReduceDescriptor(_desc->reduce_desc), STATUS_SUCCESS); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/where/cpu/where_cpu.cc b/src/ops/where/cpu/where_cpu.cc new file mode 100644 index 00000000..29c4fe93 --- /dev/null +++ b/src/ops/where/cpu/where_cpu.cc @@ -0,0 +1,163 @@ +#include "where_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + +/** + * @brief 将 indices 表示的索引按照字典序递增 1,即在当前索引的基础上,转换为下一个索引 + * @param indices 当前的索引 + * @param shape 索引所确定张量的维度 + * @param ndim 张量的阶数 + */ +inline void incrementOne(uint64_t *indices, uint64_t const *shape, uint64_t ndim) { + // 每次优先从最后一维开始递增 + for (int64_t i = ndim - 1; i >= 0; --i) { + // 如果当前维度递增后没有超过该维度的最大值,则直接递增并返回 + if (++indices[i] != shape[i]) { + return; + } + // 如果递增后等于了最大值,则将该维度的索引置为 0,继续递增前一个维度 + indices[i] = 0; + } +} + +/** + * @brief 根据给定的索引和步长,计算该索引确定的元素在一维数组中的位置 + * @param indices 给定的索引 + * @param strides 每个维度的偏移步长 + * @param ndim 索引的总维度 + * @return 返回该索引确定的元素在一维数组中的位置 + */ +inline uint64_t compactToFlat(uint64_t 
const *indices, uint64_t const *strides, uint64_t ndim) { + return std::inner_product(indices, indices + ndim, strides, uint64_t(0)); +} + +infiniopStatus_t cpuCreateWhereDescriptor(infiniopHandle_t, + WhereCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t condition_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + + // 先执行检查 + // 1. 检查 condition 的类型是否为 bool(按照要求,bool 类型使用 uint8_t 表示) + if (condition_desc->dt != U8) { + return STATUS_BAD_TENSOR_DTYPE; + } + // 2. 确定 output 的张量类型是否为 F16 或 F32 + if (output_desc->dt != F16 && output_desc->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + // 3. 检查 x、y 的类型是否与 output 一致 + if (x_desc->dt != output_desc->dt || y_desc->dt != output_desc->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + // 4. 检查 x、y、output、condition 张量的 strides 是否都合法 + if (!is_contiguous(x_desc) || + !is_contiguous(y_desc) || + !is_contiguous(output_desc) || + !is_contiguous(condition_desc)) { + return STATUS_BAD_TENSOR_STRIDES; + } + // 5. 检查输出张量能否通过 x 和 y 广播后得到 + if (!isValidBroadcastShape(x_desc, y_desc, output_desc)) { + return STATUS_BAD_TENSOR_SHAPE; + } + + // 计算创建 WhereDescriptor 所需的参数 + uint64_t ndim = output_desc->ndim; + // 1. output_data_size + uint64_t output_data_size = std::accumulate(output_desc->shape, output_desc->shape + ndim, + 1ULL, std::multiplies()); + // 2. output_shape,应该利用 out_desc 中的 shape 信息创建副本 + uint64_t *output_shape = new uint64_t[ndim]; + std::copy(output_desc->shape, output_desc->shape + ndim, output_shape); + // 3. x_strides 和 y_strides,使用广播后的形状计算,方便后续根据索引确定元素在实际一维数组中的位置 + uint64_t *x_strides = new uint64_t[ndim]; + uint64_t *y_strides = new uint64_t[ndim]; + for (size_t i = 0; i < ndim; ++i) { + x_strides[i] = + (i < ndim - x_desc->ndim || output_desc->shape[i] != x_desc->shape[i + x_desc->ndim - ndim]) + ? 
0 + : x_desc->strides[i + x_desc->ndim - ndim]; + y_strides[i] = + (i < ndim - y_desc->ndim || output_desc->shape[i] != y_desc->shape[i + y_desc->ndim - ndim]) + ? 0 + : y_desc->strides[i + y_desc->ndim - ndim]; + } + // 4. output_indices,用于索引输出张量的每个元素,初始化为全 0,表示从第一个位置开始遍历计算 + uint64_t *output_indices = new uint64_t[ndim]; + std::fill(output_indices, output_indices + ndim, 0); + + // 分配内存并创建 WhereDescriptor + *desc_ptr = new WhereCpuDescriptor { + DevCpu, + output_desc->dt, + ndim, + output_data_size, + output_shape, + x_strides, + y_strides, + output_indices + }; + + return STATUS_SUCCESS; +} + +/** + * @brief 为了可以根据元素类型直接获取对应类型的指针,因此设置一个模板函数,执行实际的操作 + * @tparam T 输入输出中元素类型 + * @param desc where 操作的描述符 + * @param output 输出张量 + * @param condition 条件张量 + * @param x 条件为真时的输入张量 + * @param y 条件为假时的输入张量 + * @param stream 未使用参数 + * @return 返回操作执行的结果 + */ +template +infiniopStatus_t where_cpu(WhereCpuDescriptor_t desc, void *output, void *condition, void *x, void *y, + void *stream) { + // 执行实际的 where 计算 + // 1. 先将参数转换为对应类型的指针 + auto x_data = reinterpret_cast(x); + auto y_data = reinterpret_cast(y); + auto condition_data = reinterpret_cast(condition); + auto output_data = reinterpret_cast(output); + const auto &indices = desc->output_indices; // 用于遍历输出张量每个元素的索引 + + // 2. 遍历每个元素,根据 condition 的值选择 x 或 y 的值,将结果写入 output_data + for (uint64_t i = 0; i < desc->output_data_size; ++i, incrementOne(indices, desc->output_shape, desc->ndim)) { + // 获取输入张量中索引对应的位置 + auto x_index = compactToFlat(indices, desc->x_strides, desc->ndim); + auto y_index = compactToFlat(indices, desc->y_strides, desc->ndim); + + // for (int j = 0; j < desc->ndim; ++j) { + // std::cout << indices[j] << " "; + // } + // std::cout << ": " << (condition_data[i] ? "true" : "false") << " " << x_index << " " << y_index << std::endl; + // 根据条件选择输入张量中的值 + output_data[i] = condition_data[i] ? 
x_data[x_index] : y_data[y_index]; + } + return STATUS_SUCCESS; + } + + +infiniopStatus_t cpuWhere(WhereCpuDescriptor_t desc, void *output, void *condition, void *x, void *y, + void *stream) { + if (desc->dtype == F16) { + return where_cpu(desc, output, condition, x, y, stream); + } + if (desc->dtype == F32) { + return where_cpu(desc, output, condition, x, y, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} + +infiniopStatus_t cpuDestroyWhereDescriptor(WhereCpuDescriptor_t desc) { + delete[] desc->output_shape; + delete[] desc->x_strides; + delete[] desc->y_strides; + delete[] desc->output_indices; + delete desc; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/where/cpu/where_cpu.h b/src/ops/where/cpu/where_cpu.h new file mode 100644 index 00000000..d3d74713 --- /dev/null +++ b/src/ops/where/cpu/where_cpu.h @@ -0,0 +1,36 @@ +#ifndef __CPU_WHERE_H__ +#define __CPU_WHERE_H__ + +#include "operators.h" +#include +#include + +struct WhereCpuDescriptor { + Device device; + DT dtype; // 输入和输出张量中元素的类型 + uint64_t ndim; // 输出张量的阶数 + uint64_t output_data_size; // 输出张量的元素数量 + uint64_t const *output_shape; // 结果张量的形状(两个输入广播后的形状) + uint64_t const *x_strides; // 第一个操作数的偏移量步长 + uint64_t const *y_strides; // 第二个操作数的偏移量步长 + uint64_t *output_indices; // 用于遍历结果张量时的索引数组 +}; + +typedef struct WhereCpuDescriptor *WhereCpuDescriptor_t; + +infiniopStatus_t cpuCreateWhereDescriptor(infiniopHandle_t, + WhereCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t condition_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc); + +infiniopStatus_t cpuWhere(WhereCpuDescriptor_t desc, + void *output, void *condition, + void *x, void *y, + void *stream); + +infiniopStatus_t cpuDestroyWhereDescriptor(WhereCpuDescriptor_t desc); + + +#endif \ No newline at end of file diff --git a/src/ops/where/operator.cc b/src/ops/where/operator.cc new file mode 100644 index 00000000..5c0fa229 --- /dev/null +++ 
b/src/ops/where/operator.cc @@ -0,0 +1,44 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/where/where.h" + +#ifdef ENABLE_CPU +#include "cpu/where_cpu.h" +#endif + +__C infiniopStatus_t infiniopCreateWhereDescriptor(infiniopHandle_t handle, + infiniopWhereDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t condition_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t y_desc) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateWhereDescriptor(handle, (WhereCpuDescriptor_t *)desc_ptr, output_desc, condition_desc, + x_desc, y_desc); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C __export infiniopStatus_t infiniopWhere(infiniopWhereDescriptor_t desc, void *output, void *condition, + void *x, void *y, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuWhere((WhereCpuDescriptor_t)desc, output, condition, x, y, stream); +#endif + } + return STATUS_BAD_DEVICE; +} + +__C __export infiniopStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyWhereDescriptor((WhereCpuDescriptor_t)desc); +#endif + } + return STATUS_BAD_DEVICE; +}