[Training Camp] Add CPU and CUDA Clip, Gather, Where, ReduceMin, ReduceMax, ReduceMean #159

Open · wants to merge 5 commits into base: dev
6 changes: 6 additions & 0 deletions include/infini_operators.h
@@ -3,17 +3,23 @@
#include "ops/attention/attention.h"
#include "ops/avg_pool/avg_pool.h"
#include "ops/causal_softmax/causal_softmax.h"
#include "ops/clip/clip.h"
#include "ops/global_avg_pool/global_avg_pool.h"
#include "ops/expand/expand.h"
#include "ops/gather/gather.h"
#include "ops/gemm/gemm.h"
#include "ops/conv/conv.h"
#include "ops/matmul/matmul.h"
#include "ops/max_pool/max_pool.h"
#include "ops/mlp/mlp.h"
#include "ops/random_sample/random_sample.h"
#include "ops/rearrange/rearrange.h"
#include "ops/reduce_max/reduce_max.h"
#include "ops/reduce_mean/reduce_mean.h"
#include "ops/reduce_min/reduce_min.h"
#include "ops/relu/relu.h"
#include "ops/rms_norm/rms_norm.h"
#include "ops/rotary_embedding/rotary_embedding.h"
#include "ops/swiglu/swiglu.h"
#include "ops/where/where.h"
#include "tensor/tensor_descriptor.h"
28 changes: 28 additions & 0 deletions include/ops/clip/clip.h
@@ -0,0 +1,28 @@
#ifndef CLIP_H
#define CLIP_H

#include "../../export.h"
#include "../../operators.h"
#include <optional>

typedef struct ClipDescriptor {
Device device;
} ClipDescriptor;

typedef ClipDescriptor *infiniopClipDescriptor_t;

__C __export infiniopStatus_t infiniopCreateClipDescriptor(infiniopHandle_t handle,
infiniopClipDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
float* lower_bound,
float* upper_bound);

__C __export infiniopStatus_t infiniopClip(infiniopClipDescriptor_t desc,
void *y,
void const *x,
void *stream);

__C __export infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc);

#endif
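
For orientation, a minimal host-side call sequence for the new Clip API is sketched below. This is not part of the diff: handle, the tensor descriptors and the device buffers d_x/d_y are assumed to exist, and STATUS_SUCCESS is taken to be the success code from operators.h. Passing nullptr for a bound leaves that side unclamped, matching the CPU/CUDA implementations further down.

// Sketch only, not part of the diff: assumes handle, y_desc, x_desc, d_y, d_x already exist.
float lower = -0.1f, upper = 0.1f;            // nullptr would mean "no bound on this side"
infiniopClipDescriptor_t clip_desc;
if (infiniopCreateClipDescriptor(handle, &clip_desc, y_desc, x_desc, &lower, &upper) == STATUS_SUCCESS) {
    infiniopClip(clip_desc, d_y, d_x, /*stream=*/nullptr);   // stream may be a cudaStream_t on GPU
    infiniopDestroyClipDescriptor(clip_desc);
}
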
28 changes: 28 additions & 0 deletions include/ops/gather/gather.h
@@ -0,0 +1,28 @@
#ifndef GATHER_H
#define GATHER_H

#include "../../export.h"
#include "../../operators.h"

typedef struct GatherDescriptor {
Device device;
} GatherDescriptor;

typedef GatherDescriptor *infiniopGatherDescriptor_t;

__C __export infiniopStatus_t infiniopCreateGatherDescriptor(infiniopHandle_t handle,
infiniopGatherDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input,
infiniopTensorDescriptor_t indices,
uint64_t axis);

__C __export infiniopStatus_t infiniopGather(infiniopGatherDescriptor_t desc,
void *output,
void const *input,
void const *indices,
void *stream);

__C __export infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc);

#endif
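
The output shape implied by the shape checks in gather_cpu.cc and gather.cc further down is input.shape[:axis] + indices.shape + input.shape[axis+1:]. A small self-contained sketch of that rule follows; gather_output_shape is a hypothetical helper written for illustration, not part of this PR.

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the shape validation in cpuCreateGatherDescriptor.
std::vector<uint64_t> gather_output_shape(const std::vector<uint64_t> &input,
                                          const std::vector<uint64_t> &indices,
                                          uint64_t axis) {
    std::vector<uint64_t> out(input.begin(), input.begin() + axis);   // dims before axis
    out.insert(out.end(), indices.begin(), indices.end());            // the axis dim is replaced by indices' shape
    out.insert(out.end(), input.begin() + axis + 1, input.end());     // dims after axis
    return out;                                                       // ndim = input.ndim + indices.ndim - 1
}
// e.g. input (8, 8, 8, 8, 8), indices (8, 8), axis = 2  ->  output (8, 8, 8, 8, 8, 8)
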
27 changes: 27 additions & 0 deletions include/ops/reduce_max/reduce_max.h
@@ -0,0 +1,27 @@
#ifndef REDUCE_MAX_H
#define REDUCE_MAX_H

#include "../../export.h"
#include "../../operators.h"

typedef struct ReduceMaxDescriptor {
Device device;
} ReduceMaxDescriptor;

typedef ReduceMaxDescriptor *infiniopReduceMaxDescriptor_t;

__C __export infiniopStatus_t infiniopCreateReduceMaxDescriptor(infiniopHandle_t handle,
infiniopReduceMaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim);

__C __export infiniopStatus_t infiniopGetReduceMaxWorkspaceSize(infiniopReduceMaxDescriptor_t desc, uint64_t *size);

__C __export infiniopStatus_t infiniopReduceMax(infiniopReduceMaxDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void *stream);

__C __export infiniopStatus_t infiniopDestroyReduceMaxDescriptor(infiniopReduceMaxDescriptor_t desc);


#endif
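
The three reduce operators share the same create, query workspace, run, destroy pattern. A minimal sketch for ReduceMax is given below (ReduceMean and ReduceMin further down are identical apart from the names); handle, the tensor descriptors, the device buffers and the allocator are assumptions, not part of the diff.

// Sketch only, not part of the diff: assumes handle, y_desc, x_desc, d_y, d_x exist.
int axes[] = {1, 2};                                    // reduce over dims 1 and 2
infiniopReduceMaxDescriptor_t rmax_desc;
infiniopCreateReduceMaxDescriptor(handle, &rmax_desc, y_desc, x_desc, axes, /*axes_ndim=*/2);

uint64_t ws_size = 0;
infiniopGetReduceMaxWorkspaceSize(rmax_desc, &ws_size);
void *workspace = alloc_on_device(ws_size);             // hypothetical allocator for the target device

infiniopReduceMax(rmax_desc, workspace, ws_size, d_y, d_x, /*stream=*/nullptr);
infiniopDestroyReduceMaxDescriptor(rmax_desc);
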
27 changes: 27 additions & 0 deletions include/ops/reduce_mean/reduce_mean.h
@@ -0,0 +1,27 @@
#ifndef REDUCE_MEAN_H
#define REDUCE_MEAN_H

#include "../../export.h"
#include "../../operators.h"

typedef struct ReduceMeanDescriptor {
Device device;
} ReduceMeanDescriptor;

typedef ReduceMeanDescriptor *infiniopReduceMeanDescriptor_t;

__C __export infiniopStatus_t infiniopCreateReduceMeanDescriptor(infiniopHandle_t handle,
infiniopReduceMeanDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim);

__C __export infiniopStatus_t infiniopGetReduceMeanWorkspaceSize(infiniopReduceMeanDescriptor_t desc, uint64_t *size);

__C __export infiniopStatus_t infiniopReduceMean(infiniopReduceMeanDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void *stream);

__C __export infiniopStatus_t infiniopDestroyReduceMeanDescriptor(infiniopReduceMeanDescriptor_t desc);


#endif
27 changes: 27 additions & 0 deletions include/ops/reduce_min/reduce_min.h
@@ -0,0 +1,27 @@
#ifndef REDUCE_MIN_H
#define REDUCE_MIN_H

#include "../../export.h"
#include "../../operators.h"

typedef struct ReduceMinDescriptor {
Device device;
} ReduceMinDescriptor;

typedef ReduceMinDescriptor *infiniopReduceMinDescriptor_t;

__C __export infiniopStatus_t infiniopCreateReduceMinDescriptor(infiniopHandle_t handle,
infiniopReduceMinDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim);

__C __export infiniopStatus_t infiniopGetReduceMinWorkspaceSize(infiniopReduceMinDescriptor_t desc, uint64_t *size);

__C __export infiniopStatus_t infiniopReduceMin(infiniopReduceMinDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void *stream);

__C __export infiniopStatus_t infiniopDestroyReduceMinDescriptor(infiniopReduceMinDescriptor_t desc);


#endif
29 changes: 29 additions & 0 deletions include/ops/where/where.h
@@ -0,0 +1,29 @@
#ifndef WHERE_H
#define WHERE_H

#include "../../export.h"
#include "../../operators.h"

typedef struct WhereDescriptor {
Device device;
} WhereDescriptor;

typedef WhereDescriptor *infiniopWhereDescriptor_t;

__C __export infiniopStatus_t infiniopCreateWhereDescriptor(infiniopHandle_t handle,
infiniopWhereDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t condition);

__C __export infiniopStatus_t infiniopWhere(infiniopWhereDescriptor_t desc,
void *output,
void const *x,
void const *y,
void const *condition,
void *stream);

__C __export infiniopStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc);

#endif
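
Note the argument order: the condition descriptor comes last in the create call and the condition buffer comes last in the compute call, while the Python tests below pass the condition first to torch.where. A minimal call-sequence sketch, with handle, descriptors and device buffers assumed to exist:

// Sketch only, not part of the diff: output[i] = condition[i] ? x[i] : y[i], with the broadcast
// shapes exercised by the tests below (e.g. condition (1,) against x/y (2, 2)).
infiniopWhereDescriptor_t where_desc;
infiniopCreateWhereDescriptor(handle, &where_desc, out_desc, x_desc, y_desc, cond_desc);
infiniopWhere(where_desc, d_out, d_x, d_y, d_cond, /*stream=*/nullptr);
infiniopDestroyWhereDescriptor(where_desc);
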
233 changes: 233 additions & 0 deletions operatorspy/tests/clip.py
@@ -0,0 +1,233 @@
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
CTensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
rearrange_tensor,
create_workspace,
)

from operatorspy.tests.test_utils import get_args, synchronize_device
import torch

PROFILE = True
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

class ClipDescriptor(Structure):
_fields_ = [("device", c_int32)]


infiniopClipDescriptor_t = POINTER(ClipDescriptor)

def clip(x, lower_bound, upper_bound):
# use `is not None` so a bound of 0.0 is not mistaken for "no bound"
return torch.clamp_max(torch.clamp_min(x, lower_bound if lower_bound is not None else torch.finfo(x.dtype).min),
upper_bound if upper_bound is not None else torch.finfo(x.dtype).max)


def test(
lib,
handle,
torch_device,
x_shape,
lower_bound,
upper_bound,
dtype=torch.float16,
):
print(
f"Testing Clip on {torch_device} with x_shape:{x_shape} lower_bound:{lower_bound} upper_bound:{upper_bound} dtype:{dtype}"
)

x = torch.randn(x_shape, dtype=dtype, device=torch_device)
ans = clip(x, lower_bound, upper_bound)
y = torch.zeros(ans.shape, dtype=dtype, device=torch_device)


x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)


descriptor = infiniopClipDescriptor_t()
check_error(
lib.infiniopCreateClipDescriptor(
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
ctypes.byref(c_float(lower_bound)) if lower_bound is not None else None,
ctypes.byref(c_float(upper_bound)) if upper_bound is not None else None,
)
)


# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate()


check_error(
lib.infiniopClip(
descriptor,
y_tensor.data,
x_tensor.data,
None,
)
)

assert torch.allclose(y, ans, atol=0, rtol=0)
# ans_ = ans.cpu().numpy().flatten()
# y_ = y.cpu().numpy().flatten()
# print(ans_)
# print(y_)
# atol = max(abs(ans_ - y_))
# rtol = atol / max(abs(y_) + 1e-8)

# print(f"atol: {atol}, rtol: {rtol}")

if PROFILE:
for i in range(NUM_PRERUN):
_ = clip(x, lower_bound, upper_bound)
synchronize_device(torch_device)
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = clip(x, lower_bound, upper_bound)
synchronize_device(torch_device)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" pytorch time: {elapsed * 1000 :6f} ms")
for i in range(NUM_PRERUN):
check_error(
lib.infiniopClip(
descriptor,
y_tensor.data,
x_tensor.data,
None,
)
)
synchronize_device(torch_device)
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopClip(
descriptor,
y_tensor.data,
x_tensor.data,
None,
)
)
synchronize_device(torch_device)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed * 1000 :6f} ms")

check_error(lib.infiniopDestroyClipDescriptor(descriptor))


def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)

for (
x_shape,
lower_bound,
upper_bound,
dtype,
) in test_cases:
test(
lib,
handle,
"cpu",
x_shape,
lower_bound,
upper_bound,
dtype,
)

destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)

for (
x_shape,
lower_bound,
upper_bound,
dtype,
) in test_cases:
test(
lib,
handle,
"cuda",
x_shape,
lower_bound,
upper_bound,
dtype,
)

destroy_handle(lib, handle)


if __name__ == "__main__":
test_cases = [
# x_shape, lower_bound, upper_bound, test_dtype
((2, 2), -0.1, 0.1, torch.float32),
((2, 2), 0.1, -0.1, torch.float32),
# ((2, 2), None, None, torch.float32),
((2, 2), None, 0.1, torch.float32),
((2, 2), 0.1, None, torch.float32),
((2048, 2048), -0.1, 0.1, torch.float32),

((2, 2), -0.1, 0.1, torch.float16),
((2, 2), 0.1, -0.1, torch.float16),
# ((2, 2), None, None, torch.float16),
((2, 2), None, 0.1, torch.float16),
((2, 2), 0.1, None, torch.float16),
((2048, 2048), -0.1, 0.1, torch.float16),
]
args = get_args()
lib = open_lib()

lib.infiniopCreateClipDescriptor.restype = c_int32
lib.infiniopCreateClipDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopClipDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
POINTER(c_float),
POINTER(c_float),
]


lib.infiniopClip.restype = c_int32
lib.infiniopClip.argtypes = [
infiniopClipDescriptor_t,
c_void_p,
c_void_p,
c_void_p,
]

lib.infiniopDestroyClipDescriptor.restype = c_int32
lib.infiniopDestroyClipDescriptor.argtypes = [
infiniopClipDescriptor_t,
]

if args.profile:
PROFILE = True
if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if not (args.cpu or args.cuda):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
241 changes: 241 additions & 0 deletions operatorspy/tests/gather.py
@@ -0,0 +1,241 @@
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
CTensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
rearrange_tensor,
create_workspace,
)

from operatorspy.tests.test_utils import get_args, synchronize_device
import torch

PROFILE = True
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

class GatherDescriptor(Structure):
_fields_ = [("device", c_int32)]


infiniopGatherDescriptor_t = POINTER(GatherDescriptor)

def gather(rank, axis, inputTensor, indexTensor):
indices = [slice(None)] * rank
indices[axis] = indexTensor
outTensor = inputTensor[tuple(indices)]
return outTensor


def test(
lib,
handle,
torch_device,
input_shape,
index_shape,
axis,
dtype=torch.float16,
):
print(
f"Testing Gather on {torch_device} with input_shape:{input_shape} indices_shape:{index_shape} axis:{axis} dtype:{dtype}"
)

input = torch.randn(input_shape, dtype=dtype, device=torch_device)
index = torch.randint(0, input.shape[axis], index_shape, device=torch_device).to(torch.int32)
ans = gather(len(input_shape), axis, input, index)
output = torch.zeros(ans.shape, dtype=dtype, device=torch_device)


input_tensor = to_tensor(input, lib)
index_tensor = to_tensor(index, lib)
output_tensor = to_tensor(output, lib)

descriptor = infiniopGatherDescriptor_t()
check_error(
lib.infiniopCreateGatherDescriptor(
handle,
ctypes.byref(descriptor),
output_tensor.descriptor,
input_tensor.descriptor,
index_tensor.descriptor,
axis,
)
)

# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
input_tensor.descriptor.contents.invalidate()
index_tensor.descriptor.contents.invalidate()
output_tensor.descriptor.contents.invalidate()


check_error(
lib.infiniopGather(
descriptor,
output_tensor.data,
input_tensor.data,
index_tensor.data,
None,
)
)

assert torch.allclose(output, ans, atol=0, rtol=0)
# ans_ = ans.cpu().numpy().flatten()
# output_ = output.cpu().numpy().flatten()
# atol = max(abs(ans_ - output_))
# rtol = atol / max(abs(output_) + 1e-8)

# print(f"atol: {atol}, rtol: {rtol}")

if PROFILE:
for i in range(NUM_PRERUN):
_ = gather(len(input_shape), axis, input, index)
synchronize_device(torch_device)
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = gather(len(input_shape), axis, input, index)
synchronize_device(torch_device)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" pytorch time: {elapsed * 1000 :6f} ms")
for i in range(NUM_PRERUN):
check_error(
lib.infiniopGather(
descriptor,
output_tensor.data,
input_tensor.data,
index_tensor.data,
None,
)
)
synchronize_device(torch_device)
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopGather(
descriptor,
output_tensor.data,
input_tensor.data,
index_tensor.data,
None,
)
)
synchronize_device(torch_device)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed * 1000 :6f} ms")

check_error(lib.infiniopDestroyGatherDescriptor(descriptor))


def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)

for (
input_shape,
index_shape,
axis,
dtype,
) in test_cases:
test(
lib,
handle,
"cpu",
input_shape,
index_shape,
axis,
dtype,
)

destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)

for (
input_shape,
index_shape,
axis,
dtype,
) in test_cases:
test(
lib,
handle,
"cuda",
input_shape,
index_shape,
axis,
dtype,
)

destroy_handle(lib, handle)


if __name__ == "__main__":
test_cases = [
# input_shape , index_shape, axis, test_dtype
((64, 64), (64, 64), 0, torch.float32),
((64, 64), (64, 64), 1, torch.float32),
((8, 8, 8, 8, 8), (8, 8), 0, torch.float32),
((8, 8, 8, 8, 8), (8, 8), 2, torch.float32),
((1024, 1024, 1024), (1, ), 1, torch.float32),
((2048, 2048), (128, 128), 0, torch.float32),
((2048, 2048), (128, 128), 1, torch.float32),

((64, 64), (64, 64), 0, torch.float16),
((64, 64), (64, 64), 1, torch.float16),
((8, 8, 8, 8, 8), (8, 8), 0, torch.float16),
((8, 8, 8, 8, 8), (8, 8), 2, torch.float16),
((1024, 1024, 1024), (1, ), 1, torch.float16),
((2048, 2048), (128, 128), 0, torch.float16),
((2048, 2048), (128, 128), 1, torch.float16),
]
args = get_args()
lib = open_lib()

lib.infiniopCreateGatherDescriptor.restype = c_int32
lib.infiniopCreateGatherDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopGatherDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_uint64,
]


lib.infiniopGather.restype = c_int32
lib.infiniopGather.argtypes = [
infiniopGatherDescriptor_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]

lib.infiniopDestroyGatherDescriptor.restype = c_int32
lib.infiniopDestroyGatherDescriptor.argtypes = [
infiniopGatherDescriptor_t,
]

if args.profile:
PROFILE = True
if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if not (args.cpu or args.cuda):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
217 changes: 217 additions & 0 deletions operatorspy/tests/reduce_max.py
@@ -0,0 +1,217 @@

from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
)

from operatorspy.tests.test_utils import get_args
import torch
import math
from torch.nn import functional as F
from typing import List, Tuple

# constant controlling whether to profile the pytorch and lib functions
# NOTE: a synchronization call (e.g. cudaDeviceSynchronize() for CUDA) must be added
# manually around the lib calls for meaningful timings
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class ReduceMaxDescriptor(Structure):
_fields_ = [("device", c_int32)]


infiniopReduceMaxDescriptor_t = POINTER(ReduceMaxDescriptor)

def reduce_max(x, axes, keepdim=False):
return torch.amax(x, dim=axes, keepdim=keepdim)

# convert a python tuple to a ctypes int pointer
def tuple_to_int_p(py_tuple: Tuple):
array = ctypes.c_int * len(py_tuple)
data_array = array(*py_tuple)
return ctypes.cast(data_array, ctypes.POINTER(ctypes.c_int))


def test(
lib,
handle,
torch_device,
x_shape,
axes,
keepdim,
tensor_dtype=torch.float16,
):
print(
f"Testing ReduceMax on {torch_device} with x_shape: {x_shape}, axes:{axes}, keepdim:{keepdim}, dtype:{tensor_dtype}"
)
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
ans = reduce_max(x, axes, keepdim)
y = torch.zeros(ans.shape, dtype=tensor_dtype).to(torch_device)

# print(f'y_shape: {y.shape}')

for i in range(NUM_PRERUN if PROFILE else 1):
ans = reduce_max(x, axes, keepdim)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = reduce_max(x, axes, keepdim)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f"pytorch time: {elapsed :8f}")


x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)
descriptor = infiniopReduceMaxDescriptor_t()

# print("!")

check_error(
lib.infiniopCreateReduceMaxDescriptor(
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
tuple_to_int_p(axes),
len(axes)
)
)
# print("!")

# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate()

workspaceSize = ctypes.c_uint64(0)
check_error(
lib.infiniopGetReduceMaxWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
)
workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))


for i in range(NUM_PRERUN if PROFILE else 1):
check_error(
lib.infiniopReduceMax(
descriptor,
workspace_ptr,
workspaceSize,
y_tensor.data,
x_tensor.data,
None,
)
)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopReduceMax(
descriptor,
workspace_ptr,
workspaceSize,
y_tensor.data,
x_tensor.data,
None,
)
)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed :8f}")

assert torch.allclose(y, ans, atol=0, rtol=0)
# ans_ = ans.cpu().numpy().flatten()
# y_ = y.cpu().numpy().flatten()
# atol = max(abs(ans_ - y_))
# rtol = atol / max(abs(y_) + 1e-8)
# print(f"ans: {ans_}")
# print(f"y: {y_}")

# print(f"atol: {atol}, rtol: {rtol}")
check_error(lib.infiniopDestroyReduceMaxDescriptor(descriptor))


def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for x_shape, axes, keepdim, tensor_dtype in test_cases:
test(lib, handle, "cpu", x_shape, axes, keepdim, tensor_dtype)
destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for x_shape, axes, keepdim, tensor_dtype in test_cases:
test(lib, handle, "cuda", x_shape, axes, keepdim, tensor_dtype)
destroy_handle(lib, handle)



if __name__ == "__main__":
test_cases = [
# x_shape, axes, keepdim, dtype
((2, 2, 3, 4), (0,), False, torch.float32),
((2, 2, 3, 4), (0,), True, torch.float32),
((2, 2, 3, 4), (1,), False, torch.float32),
((2, 2, 3, 4), (1, 2), False, torch.float32),
((2, 2, 3, 4), (1, 3), False, torch.float32),
((2, 2, 3, 4), (0, 1, 2, 3), False, torch.float32),
((2, 2, 3, 4, 5), (0, 3), False, torch.float32),
((64, 64, 64, 64), (0, 2), False, torch.float32),

((2, 2, 3, 4), (0,), False, torch.float16),
((2, 2, 3, 4), (0,), True, torch.float16),
((2, 2, 3, 4), (1,), False, torch.float16),
((2, 2, 3, 4), (1, 2), False, torch.float16),
((2, 2, 3, 4), (1, 3), False, torch.float16),
((2, 2, 3, 4), (0, 1, 2, 3), False, torch.float16),
((2, 2, 3, 4, 5), (0, 3), False, torch.float16),
((64, 64, 64, 64), (0, 2), False, torch.float16),
]
args = get_args()
lib = open_lib()
lib.infiniopCreateReduceMaxDescriptor.restype = c_int32
lib.infiniopCreateReduceMaxDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopReduceMaxDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
POINTER(ctypes.c_int),
c_uint64,
]
lib.infiniopReduceMax.restype = c_int32
lib.infiniopReduceMax.argtypes = [
infiniopReduceMaxDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyReduceMaxDescriptor.restype = c_int32
lib.infiniopDestroyReduceMaxDescriptor.argtypes = [
infiniopReduceMaxDescriptor_t,
]

if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if not (args.cpu or args.cuda):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
220 changes: 220 additions & 0 deletions operatorspy/tests/reduce_mean.py
@@ -0,0 +1,220 @@

from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
)

from operatorspy.tests.test_utils import get_args
import torch
import math
from torch.nn import functional as F
from typing import List, Tuple

# constant controlling whether to profile the pytorch and lib functions
# NOTE: a synchronization call (e.g. cudaDeviceSynchronize() for CUDA) must be added
# manually around the lib calls for meaningful timings
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class ReduceMeanDescriptor(Structure):
_fields_ = [("device", c_int32)]


infiniopReduceMeanDescriptor_t = POINTER(ReduceMeanDescriptor)

def reduce_mean(x, axes, keepdim=False):
return torch.mean(x, dim=axes, keepdim=keepdim)

# convert a python tuple to a ctypes int pointer
def tuple_to_int_p(py_tuple: Tuple):
array = ctypes.c_int * len(py_tuple)
data_array = array(*py_tuple)
return ctypes.cast(data_array, ctypes.POINTER(ctypes.c_int))


def test(
lib,
handle,
torch_device,
x_shape,
axes,
keepdim,
tensor_dtype=torch.float16,
):
print(
f"Testing ReduceMean on {torch_device} with x_shape: {x_shape}, axes:{axes}, keepdim:{keepdim}, dtype:{tensor_dtype}"
)
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
ans = reduce_mean(x, axes, keepdim)
y = torch.zeros(ans.shape, dtype=tensor_dtype).to(torch_device)

# print(f'y_shape: {y.shape}')

for i in range(NUM_PRERUN if PROFILE else 1):
ans = reduce_mean(x, axes, keepdim)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = reduce_mean(x, axes, keepdim)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f"pytorch time: {elapsed :8f}")


x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)
descriptor = infiniopReduceMeanDescriptor_t()

# print("!")

check_error(
lib.infiniopCreateReduceMeanDescriptor(
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
tuple_to_int_p(axes),
len(axes)
)
)
# print("!")

# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate()

workspaceSize = ctypes.c_uint64(0)
check_error(
lib.infiniopGetReduceMeanWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
)
workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))


for i in range(NUM_PRERUN if PROFILE else 1):
check_error(
lib.infiniopReduceMean(
descriptor,
workspace_ptr,
workspaceSize,
y_tensor.data,
x_tensor.data,
None,
)
)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopReduceMean(
descriptor,
workspace_ptr,
workspaceSize,
y_tensor.data,
x_tensor.data,
None,
)
)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed :8f}")

if(tensor_dtype == torch.float16):
assert torch.allclose(y, ans, atol=0, rtol=1e-3)
else:
assert torch.allclose(y, ans, atol=0, rtol=1e-5)
# ans_ = ans.cpu().numpy().flatten()
# y_ = y.cpu().numpy().flatten()
# atol = max(abs(ans_ - y_))
# rtol = atol / max(abs(y_) + 1e-8)
# # print(f"ans: {ans_}")
# # print(f"y: {y_}")

# print(f"atol: {atol}, rtol: {rtol}")
check_error(lib.infiniopDestroyReduceMeanDescriptor(descriptor))


def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for x_shape, axes, keepdim, tensor_dtype in test_cases:
test(lib, handle, "cpu", x_shape, axes, keepdim, tensor_dtype)
destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for x_shape, axes, keepdim, tensor_dtype in test_cases:
test(lib, handle, "cuda", x_shape, axes, keepdim, tensor_dtype)
destroy_handle(lib, handle)



if __name__ == "__main__":
test_cases = [
# x_shape, axes, keepdim, dtype
((2, 2, 3, 4), (0,), False, torch.float32),
((2, 2, 3, 4), (0,), True, torch.float32),
((2, 2, 3, 4), (1,), False, torch.float32),
((2, 2, 3, 4), (1, 2), False, torch.float32),
((2, 2, 3, 4), (1, 3), False, torch.float32),
((2, 2, 3, 4), (0, 1, 2, 3), False, torch.float32),
((2, 2, 3, 4, 5), (0, 3), False, torch.float32),
((64, 64, 64, 64), (0, 2), False, torch.float32),

((2, 2, 3, 4), (0,), False, torch.float16),
((2, 2, 3, 4), (0,), True, torch.float16),
((2, 2, 3, 4), (1,), False, torch.float16),
((2, 2, 3, 4), (1, 2), False, torch.float16),
((2, 2, 3, 4), (1, 3), False, torch.float16),
((2, 2, 3, 4), (0, 1, 2, 3), False, torch.float16),
((2, 2, 3, 4, 5), (0, 3), False, torch.float16),
((64, 64, 64, 64), (0, 2), False, torch.float16),
]
args = get_args()
lib = open_lib()
lib.infiniopCreateReduceMeanDescriptor.restype = c_int32
lib.infiniopCreateReduceMeanDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopReduceMeanDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
POINTER(ctypes.c_int),
c_uint64,
]
lib.infiniopReduceMean.restype = c_int32
lib.infiniopReduceMean.argtypes = [
infiniopReduceMeanDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyReduceMeanDescriptor.restype = c_int32
lib.infiniopDestroyReduceMeanDescriptor.argtypes = [
infiniopReduceMeanDescriptor_t,
]

if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if not (args.cpu or args.cuda):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
220 changes: 220 additions & 0 deletions operatorspy/tests/reduce_min.py
@@ -0,0 +1,220 @@
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
)

from operatorspy.tests.test_utils import get_args
import torch
import math
from torch.nn import functional as F
from typing import List, Tuple

# constant controlling whether to profile the pytorch and lib functions
# NOTE: a synchronization call (e.g. cudaDeviceSynchronize() for CUDA) must be added
# manually around the lib calls for meaningful timings
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class ReduceMinDescriptor(Structure):
_fields_ = [("device", c_int32)]


infiniopReduceMinDescriptor_t = POINTER(ReduceMinDescriptor)

def reduce_min(x, axes, keepdim=False):
return torch.amin(x, dim=axes, keepdim=keepdim)

# convert a python tuple to a ctypes int pointer
def tuple_to_int_p(py_tuple: Tuple):
array = ctypes.c_int * len(py_tuple)
data_array = array(*py_tuple)
return ctypes.cast(data_array, ctypes.POINTER(ctypes.c_int))


def test(
lib,
handle,
torch_device,
x_shape,
axes,
keepdim,
tensor_dtype=torch.float16,
):
print(
f"Testing ReduceMin on {torch_device} with x_shape: {x_shape}, axes:{axes}, keepdim:{keepdim}, dtype:{tensor_dtype}"
)
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
ans = reduce_min(x, axes, keepdim)
y = torch.zeros(ans.shape, dtype=tensor_dtype).to(torch_device)

# print(f'y_shape: {y.shape}')

for i in range(NUM_PRERUN if PROFILE else 1):
ans = reduce_min(x, axes, keepdim)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = reduce_min(x, axes, keepdim)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f"pytorch time: {elapsed :8f}")


x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)
descriptor = infiniopReduceMinDescriptor_t()

# print("!")

check_error(
lib.infiniopCreateReduceMinDescriptor(
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
tuple_to_int_p(axes),
len(axes)
)
)
# print("!")

# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate()

workspaceSize = ctypes.c_uint64(0)
check_error(
lib.infiniopGetReduceMinWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
)
workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))


for i in range(NUM_PRERUN if PROFILE else 1):
check_error(
lib.infiniopReduceMin(
descriptor,
workspace_ptr,
workspaceSize,
y_tensor.data,
x_tensor.data,
None,
)
)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopReduceMin(
descriptor,
workspace_ptr,
workspaceSize,
y_tensor.data,
x_tensor.data,
None,
)
)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed :8f}")

assert torch.allclose(y, ans, atol=0, rtol=0)
# ans_ = ans.cpu().numpy().flatten()
# y_ = y.cpu().numpy().flatten()
# atol = max(abs(ans_ - y_))
# rtol = atol / max(abs(y_) + 1e-8)
# print(f"ans: {ans_}")
# print(f"y: {y_}")

# print(f"atol: {atol}, rtol: {rtol}")
check_error(lib.infiniopDestroyReduceMinDescriptor(descriptor))


def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for x_shape, axes, keepdim, tensor_dtype in test_cases:
test(lib, handle, "cpu", x_shape, axes, keepdim, tensor_dtype)
destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for x_shape, axes, keepdim, tensor_dtype in test_cases:
test(lib, handle, "cuda", x_shape, axes, keepdim, tensor_dtype)
destroy_handle(lib, handle)



if __name__ == "__main__":
test_cases = [
# x_shape, axes, keepdim, dtype
((2, 3, 4), (1,), False, torch.float32),
((2, 3, 4), (1,), True, torch.float32),
((2, 2, 3, 4), (0,), False, torch.float32),
((2, 2, 3, 4), (0,), True, torch.float32),
((2, 2, 3, 4), (1,), False, torch.float32),
((2, 2, 3, 4), (1, 2), False, torch.float32),
((2, 2, 3, 4), (1, 3), False, torch.float32),
((2, 2, 3, 4), (0, 1, 2, 3), False, torch.float32),
((2, 2, 3, 4, 5), (0, 3), False, torch.float32),
((64, 64, 64, 64), (0, 2), False, torch.float32),

((2, 3, 4), (1,), False, torch.float16),
((2, 3, 4), (1,), True, torch.float16),
((2, 2, 3, 4), (0,), False, torch.float16),
((2, 2, 3, 4), (0,), True, torch.float16),
((2, 2, 3, 4), (1,), False, torch.float16),
((2, 2, 3, 4), (1, 2), False, torch.float16),
((2, 2, 3, 4), (1, 3), False, torch.float16),
((2, 2, 3, 4), (0, 1, 2, 3), False, torch.float16),
((2, 2, 3, 4, 5), (0, 3), False, torch.float16),
((64, 64, 64, 64), (0, 2), False, torch.float16),
]
args = get_args()
lib = open_lib()
lib.infiniopCreateReduceMinDescriptor.restype = c_int32
lib.infiniopCreateReduceMinDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopReduceMinDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
POINTER(ctypes.c_int),
c_uint64,
]
lib.infiniopReduceMin.restype = c_int32
lib.infiniopReduceMin.argtypes = [
infiniopReduceMinDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyReduceMinDescriptor.restype = c_int32
lib.infiniopDestroyReduceMinDescriptor.argtypes = [
infiniopReduceMinDescriptor_t,
]

if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if not (args.cpu or args.cuda):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
253 changes: 253 additions & 0 deletions operatorspy/tests/where.py
@@ -0,0 +1,253 @@
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
CTensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
rearrange_tensor,
create_workspace,
)

from operatorspy.tests.test_utils import get_args, synchronize_device
import torch

PROFILE = True
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

class WhereDescriptor(Structure):
_fields_ = [("device", c_int32)]


infiniopWhereDescriptor_t = POINTER(WhereDescriptor)

def where(condition, x, y):
return torch.where(condition, x, y)


def test(
lib,
handle,
torch_device,
x_shape,
y_shape,
condition_shape,
dtype=torch.float16,
):
print(
f"Testing Where on {torch_device} with x_shape:{x_shape} y_shape:{y_shape} condition_shape:{condition_shape} dtype:{dtype}"
)

x = torch.randn(x_shape, dtype=dtype, device=torch_device)
y = torch.randn(y_shape, dtype=dtype, device=torch_device)
condition = torch.randint(0, 2, condition_shape, device=torch_device).to(dtype=torch.uint8)
ans = where(condition.to(torch.bool), x, y)

output = torch.zeros(ans.shape, dtype=dtype, device=torch_device)

x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)
condition_tensor = to_tensor(condition, lib)
output_tensor = to_tensor(output, lib)


descriptor = infiniopWhereDescriptor_t()
check_error(
lib.infiniopCreateWhereDescriptor(
handle,
ctypes.byref(descriptor),
output_tensor.descriptor,
x_tensor.descriptor,
y_tensor.descriptor,
condition_tensor.descriptor,
)
)


# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate()
condition_tensor.descriptor.contents.invalidate()
output_tensor.descriptor.contents.invalidate()


check_error(
lib.infiniopWhere(
descriptor,
output_tensor.data,
x_tensor.data,
y_tensor.data,
condition_tensor.data,
None,
)
)

assert torch.allclose(output, ans, atol=0, rtol=0)
# ans_ = ans.cpu().numpy().flatten()
# output_ = output.cpu().numpy().flatten()
# print(ans_)
# print(output_)
# atol = max(abs(ans_ - output_))
# rtol = atol / max(abs(output_) + 1e-8)

# print(f"atol: {atol}, rtol: {rtol}")

if PROFILE:
for i in range(NUM_PRERUN):
_ = where(condition.to(torch.bool), x, y)
synchronize_device(torch_device)
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = where(condition.to(torch.bool), x, y)
synchronize_device(torch_device)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" pytorch time: {elapsed * 1000 :6f} ms")
for i in range(NUM_PRERUN):
check_error(
lib.infiniopWhere(
descriptor,
output_tensor.data,
x_tensor.data,
y_tensor.data,
condition_tensor.data,
None,
)
)
synchronize_device(torch_device)
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopWhere(
descriptor,
output_tensor.data,
x_tensor.data,
y_tensor.data,
condition_tensor.data,
None,
)
)
synchronize_device(torch_device)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed * 1000 :6f} ms")

check_error(lib.infiniopDestroyWhereDescriptor(descriptor))


def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)

for (
x_shape,
y_shape,
condition_shape,
dtype,
) in test_cases:
test(
lib,
handle,
"cpu",
x_shape,
y_shape,
condition_shape,
dtype,
)

destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)

for (
x_shape,
y_shape,
condition_shape,
dtype,
) in test_cases:
test(
lib,
handle,
"cuda",
x_shape,
y_shape,
condition_shape,
dtype,
)

destroy_handle(lib, handle)


if __name__ == "__main__":
test_cases = [
# x_shape, y_shape, condition_shape, dtype
((2, 2), (2, 2), (2, 2), torch.float32),
((10,), (10,), (10,), torch.float32),
((1,), (2, 2), (2, 2), torch.float32),
((2, 2), (1,), (2, 2), torch.float32),
((2, 2), (2, 2), (1,), torch.float32),
((1, ), (1, ), (2, 2), torch.float32),
((1, ), (2, 2), (1, ), torch.float32),
((2, 2), (1, ), (1, ), torch.float32),
((1024, 1024), (1024, 1024), (1024, 1024), torch.float32),

((2, 2), (2, 2), (2, 2), torch.float16),
((10,), (10,), (10,), torch.float16),
((1,), (2, 2), (2, 2), torch.float16),
((2, 2), (1,), (2, 2), torch.float16),
((2, 2), (2, 2), (1,), torch.float16),
((1, ), (1, ), (2, 2), torch.float16),
((1, ), (2, 2), (1, ), torch.float16),
((2, 2), (1, ), (1, ), torch.float16),
((1024, 1024), (1024, 1024), (1024, 1024), torch.float16),
]
args = get_args()
lib = open_lib()

lib.infiniopCreateWhereDescriptor.restype = c_int32
lib.infiniopCreateWhereDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopWhereDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]


lib.infiniopWhere.restype = c_int32
lib.infiniopWhere.argtypes = [
infiniopWhereDescriptor_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]

lib.infiniopDestroyWhereDescriptor.restype = c_int32
lib.infiniopDestroyWhereDescriptor.argtypes = [
infiniopWhereDescriptor_t,
]

if args.profile:
PROFILE = True
if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if not (args.cpu or args.cuda):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
96 changes: 96 additions & 0 deletions src/ops/clip/cpu/clip_cpu.cc
@@ -0,0 +1,96 @@
#include "clip_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../utils.h"

infiniopStatus_t cpuCreateClipDescriptor(infiniopHandle_t,
ClipCpuDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
float *lower_bound,
float *upper_bound) {
if (x->ndim != y->ndim) {
return STATUS_BAD_TENSOR_SHAPE;
}
for (uint64_t i = 0; i < x->ndim; i++) {
if (x->shape[i] != y->shape[i]) {
return STATUS_BAD_TENSOR_SHAPE;
}
}

if (!is_contiguous(y) || !is_contiguous(x)) {
return STATUS_BAD_TENSOR_STRIDES;
}

if (y->dt != x->dt) {
return STATUS_BAD_TENSOR_DTYPE;
}

uint64_t data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies<uint64_t>());

bool has_lower_bound = lower_bound != nullptr;
bool has_upper_bound = upper_bound != nullptr;
float lower_bound_ = has_lower_bound ? *lower_bound : std::numeric_limits<float>::lowest();
float upper_bound_ = has_upper_bound ? *upper_bound : std::numeric_limits<float>::max();

*desc_ptr = new ClipCpuDescriptor{
DevCpu,
y->dt,
data_size,
has_lower_bound,
lower_bound_,
has_upper_bound,
upper_bound_,
};
return STATUS_SUCCESS;
}

infiniopStatus_t cpuDestroyClipDescriptor(ClipCpuDescriptor_t desc) {
delete desc;
return STATUS_SUCCESS;
}

template<typename Tdata>
infiniopStatus_t clip_cpu(ClipCpuDescriptor_t desc, void *y, void const *x) {
auto x_ = reinterpret_cast<Tdata const *>(x);
auto y_ = reinterpret_cast<Tdata *>(y);
auto lower_bound_ = desc->lower_bound;
auto upper_bound_ = desc->upper_bound;
auto data_size_ = desc->data_size;

if constexpr (std::is_same<Tdata, uint16_t>::value) {
if (!desc->has_lower_bound && !desc->has_upper_bound) {
std::memcpy(y_, x_, data_size_ * sizeof(Tdata));
} else {
#pragma omp parallel for
for (uint64_t i = 0; i < data_size_; i++) {
float x_val = f16_to_f32(x_[i]);
x_val = (desc->has_lower_bound && x_val < lower_bound_) ? lower_bound_ : x_val;
x_val = (desc->has_upper_bound && x_val > upper_bound_) ? upper_bound_ : x_val;
y_[i] = f32_to_f16(x_val);
}
}
} else {
if (!desc->has_lower_bound && !desc->has_upper_bound) {
std::memcpy(y_, x_, data_size_ * sizeof(Tdata));
} else {
#pragma omp parallel for
for (uint64_t i = 0; i < data_size_; i++) {
Tdata x_val = x_[i];
x_val = (desc->has_lower_bound && x_val < lower_bound_) ? lower_bound_ : x_val;
x_val = (desc->has_upper_bound && x_val > upper_bound_) ? upper_bound_ : x_val;
y_[i] = x_val;
}
}
}
return STATUS_SUCCESS;
}

infiniopStatus_t cpuClip(ClipCpuDescriptor_t desc, void *y, void const *x) {
if (desc->dtype == F16) {
return clip_cpu<uint16_t>(desc, y, x);
}
if (desc->dtype == F32) {
return clip_cpu<float>(desc, y, x);
}
return STATUS_BAD_TENSOR_DTYPE;
}
34 changes: 34 additions & 0 deletions src/ops/clip/cpu/clip_cpu.h
@@ -0,0 +1,34 @@
#ifndef __CPU_CLIP_H__
#define __CPU_CLIP_H__

#include "operators.h"
#include <cstring>
#include <numeric>
#include <optional>

struct ClipCpuDescriptor {
Device device;
DT dtype;
uint64_t data_size;
bool has_lower_bound;
float lower_bound;
bool has_upper_bound;
float upper_bound;
};

typedef struct ClipCpuDescriptor *ClipCpuDescriptor_t;

infiniopStatus_t cpuCreateClipDescriptor(infiniopHandle_t,
ClipCpuDescriptor_t *,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
float* lower_bound,
float* upper_bound);

infiniopStatus_t cpuClip(ClipCpuDescriptor_t desc,
void *y,
void const *x);

infiniopStatus_t cpuDestroyClipDescriptor(ClipCpuDescriptor_t desc);

#endif
55 changes: 55 additions & 0 deletions src/ops/clip/cuda/clip.cc
@@ -0,0 +1,55 @@
#include "clip.cuh"
#include "../../../devices/cuda/common_cuda.h"
#include "../../utils.h"
#include <optional>

infiniopStatus_t cudaCreateClipDescriptor(CudaHandle_t handle,
ClipCudaDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
float* lower_bound,
float* upper_bound) {
if (x->ndim != y->ndim) {
return STATUS_BAD_TENSOR_SHAPE;
}
for (uint64_t i = 0; i < x->ndim; i++) {
if (x->shape[i] != y->shape[i]) {
return STATUS_BAD_TENSOR_SHAPE;
}
}

if (!is_contiguous(y) || !is_contiguous(x)) {
return STATUS_BAD_TENSOR_STRIDES;
}

if (y->dt != x->dt) {
return STATUS_BAD_TENSOR_DTYPE;
}

uint64_t data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies<uint64_t>());

bool has_lower_bound = lower_bound != nullptr;
bool has_upper_bound = upper_bound != nullptr;
float lower_bound_ = has_lower_bound ? *lower_bound : std::numeric_limits<float>::lowest();
float upper_bound_ = has_upper_bound ? *upper_bound : std::numeric_limits<float>::max();

*desc_ptr = new ClipCudaDescriptor{
DevNvGpu,
y->dt,
handle->device_id,
data_size,
has_lower_bound,
lower_bound_,
has_upper_bound,
upper_bound_,
static_cast<uint64_t>(handle->prop.maxGridSize[0]),
};


return STATUS_SUCCESS;
}

infiniopStatus_t cudaDestroyClipDescriptor(ClipCudaDescriptor_t desc) {
delete desc;
return STATUS_SUCCESS;
}
65 changes: 65 additions & 0 deletions src/ops/clip/cuda/clip.cu
@@ -0,0 +1,65 @@
#include "../../../devices/cuda/common_cuda.h"
#include "../../utils.h"
#include "clip.cuh"
#include "status.h"

template<typename Tdata>
__global__ void clip(
Tdata const *x,
Tdata *y,
uint64_t data_size,
bool has_lower_bound,
bool has_upper_bound,
float lower_bound,
float upper_bound,
uint64_t offset) {
const uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx >= data_size) return;

Tdata value = x[idx];
if constexpr (std::is_same<Tdata, half>::value) {
value = (has_lower_bound && __hlt(value, __float2half(lower_bound))) ? __float2half(lower_bound) : value;
value = (has_upper_bound && __hgt(value, __float2half(upper_bound))) ? __float2half(upper_bound) : value;
y[idx] = value;
} else {
value = (has_lower_bound && value < lower_bound) ? lower_bound : value;
value = (has_upper_bound && value > upper_bound) ? upper_bound : value;
y[idx] = value;
}
}

template<typename Tdata>
infiniopStatus_t clip_nv_gpu(
ClipCudaDescriptor_t desc,
void const *x,
void *y,
cudaStream_t stream) {
const uint64_t data_size = desc->data_size;
const uint64_t max_grid_size = desc->max_grid_size;

auto x_ = reinterpret_cast<const Tdata *>(x);
auto y_ = reinterpret_cast<Tdata *>(y);

dim3 blockDims(std::min(static_cast<uint64_t>(256), data_size));
dim3 gridDims(std::min(ROUND_UP_DIV(data_size, blockDims.x), max_grid_size));
uint64_t step = gridDims.x * blockDims.x;

#pragma unroll
for (uint64_t offset = 0; offset < data_size; offset += step) {
clip<<<gridDims, blockDims, 0, stream>>>(
x_, y_, data_size, desc->has_lower_bound, desc->has_upper_bound, desc->lower_bound, desc->upper_bound, offset);
}

return STATUS_SUCCESS;
}

infiniopStatus_t cudaClip(ClipCudaDescriptor_t desc, void *y, const void *x, void *stream) {
checkCudaError(cudaSetDevice(desc->device_id));
if (desc->dtype == F16) {
return clip_nv_gpu<half>(desc, reinterpret_cast<const half *>(x), reinterpret_cast<half *>(y), reinterpret_cast<cudaStream_t>(stream));
}
if (desc->dtype == F32) {
return clip_nv_gpu<float>(desc, reinterpret_cast<const float *>(x), reinterpret_cast<float *>(y), reinterpret_cast<cudaStream_t>(stream));
}
return STATUS_BAD_TENSOR_DTYPE;
}
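
clip_nv_gpu caps the grid at prop.maxGridSize[0] and covers any remaining elements by relaunching the kernel with a growing offset. A small self-contained sketch of how many launches that loop performs, assuming the 256-thread blocks used above and data_size > 0 (illustration only, not part of the diff):

#include <algorithm>
#include <cstdint>

// Illustration only: launch count of the offset loop in clip_nv_gpu.
uint64_t clip_launch_count(uint64_t data_size, uint64_t max_grid_size) {
    uint64_t block = std::min<uint64_t>(256, data_size);
    uint64_t grid  = std::min<uint64_t>((data_size + block - 1) / block, max_grid_size);  // ROUND_UP_DIV
    uint64_t step  = grid * block;                 // elements covered per launch
    return (data_size + step - 1) / step;          // 1 unless data_size exceeds max_grid_size * block
}
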
38 changes: 38 additions & 0 deletions src/ops/clip/cuda/clip.cuh
@@ -0,0 +1,38 @@
#ifndef __CUDA_CLIP_H__
#define __CUDA_CLIP_H__

#include "../../../devices/cuda/common_cuda.h"
#include "../../../devices/cuda/cuda_handle.h"
#include "operators.h"
#include <cuda_fp16.h>
#include <numeric>
#include <optional>

struct ClipCudaDescriptor {
Device device;
DT dtype;
int device_id;
uint64_t data_size;
bool has_lower_bound;
float lower_bound;
bool has_upper_bound;
float upper_bound;
uint64_t max_grid_size;
};

typedef struct ClipCudaDescriptor *ClipCudaDescriptor_t;

infiniopStatus_t cudaCreateClipDescriptor(CudaHandle_t,
ClipCudaDescriptor_t *,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
float* lower_bound,
float* upper_bound);

infiniopStatus_t cudaClip(ClipCudaDescriptor_t desc,
void *y, void const *x,
void *stream);

infiniopStatus_t cudaDestroyClipDescriptor(ClipCudaDescriptor_t desc);

#endif
78 changes: 78 additions & 0 deletions src/ops/clip/operator.cc
@@ -0,0 +1,78 @@
#include "../utils.h"
#include "operators.h"
#include "ops/clip/clip.h"

#ifdef ENABLE_CPU
#include "cpu/clip_cpu.h"
#endif
#ifdef ENABLE_NV_GPU
#include "../../devices/cuda/cuda_handle.h"
#include "cuda/clip.cuh"
#endif

__C infiniopStatus_t infiniopCreateClipDescriptor(
infiniopHandle_t handle,
infiniopClipDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
float* lower_bound,
float* upper_bound) {
switch (handle->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuCreateClipDescriptor(handle, (ClipCpuDescriptor_t *) desc_ptr, y, x, lower_bound, upper_bound);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaCreateClipDescriptor((CudaHandle_t) handle, (ClipCudaDescriptor_t *) desc_ptr, y, x, lower_bound, upper_bound);
}

#endif
#ifdef ENABLE_CAMBRICON_MLU
// TODO
#endif
default:
return STATUS_BAD_DEVICE;
}
return STATUS_BAD_DEVICE;
}

__C infiniopStatus_t infiniopClip(
infiniopClipDescriptor_t desc,
void *y,
void const *x,
void *stream) {
switch (desc->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuClip((ClipCpuDescriptor_t) desc, y, x);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaClip((ClipCudaDescriptor_t) desc, y, x, stream);
}

#endif
default:
return STATUS_BAD_DEVICE;
}
return STATUS_BAD_DEVICE;
}

__C infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc) {
switch (desc->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroyClipDescriptor((ClipCpuDescriptor_t) desc);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaDestroyClipDescriptor((ClipCudaDescriptor_t) desc);
}

#endif
default:
return STATUS_BAD_DEVICE;
}
return STATUS_BAD_DEVICE;
}
119 changes: 119 additions & 0 deletions src/ops/gather/cpu/gather_cpu.cc
@@ -0,0 +1,119 @@
#include "gather_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../utils.h"

infiniopStatus_t cpuCreateGatherDescriptor(infiniopHandle_t,
GatherCpuDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input,
infiniopTensorDescriptor_t indices,
uint64_t axis) {
if (indices->dt != I32 && indices->dt != I64) {
return STATUS_BAD_TENSOR_DTYPE;
}
if (output->dt != input->dt) {
return STATUS_BAD_TENSOR_DTYPE;
}

if (output->ndim != input->ndim + indices->ndim - 1) {
return STATUS_BAD_TENSOR_SHAPE;
}
for (int i = 0; i < output->ndim; i++) {
if (i < axis) {
if (output->shape[i] != input->shape[i]) {
return STATUS_BAD_TENSOR_SHAPE;
}
} else if (i < axis + indices->ndim) {
if (output->shape[i] != indices->shape[i - axis]) {
return STATUS_BAD_TENSOR_SHAPE;
}
} else {
if (output->shape[i] != input->shape[i - indices->ndim + 1]) {
return STATUS_BAD_TENSOR_SHAPE;
}
}
}

if (!is_contiguous(output) || !is_contiguous(input) || !is_contiguous(indices)) {
return STATUS_BAD_TENSOR_STRIDES;
}

// axis is passed as uint64_t, so it can never be negative and needs no normalization
uint64_t axis_tmp = axis;

uint64_t pre_size = std::accumulate(input->shape, input->shape + axis_tmp, 1ULL, std::multiplies<uint64_t>());
uint64_t post_size = std::accumulate(input->shape + axis_tmp + 1, input->shape + input->ndim, 1ULL, std::multiplies<uint64_t>());
uint64_t indices_size = std::accumulate(indices->shape, indices->shape + indices->ndim, 1ULL, std::multiplies<uint64_t>());
uint64_t axis_size = input->shape[axis_tmp];

*desc_ptr = new GatherCpuDescriptor{
DevCpu,
output->dt,
indices->dt,
pre_size,
axis_size,
indices_size,
post_size,
};
return STATUS_SUCCESS;
}

infiniopStatus_t cpuDestroyGatherDescriptor(GatherCpuDescriptor_t desc) {
delete desc;
return STATUS_SUCCESS;
}

template<typename Tdata, typename Tind>
infiniopStatus_t gather_cpu(GatherCpuDescriptor_t desc, void *output, void const *input, void const *indices) {
auto input_ = reinterpret_cast<Tdata const *>(input);
auto output_ = reinterpret_cast<Tdata *>(output);
auto indices_ = reinterpret_cast<Tind const *>(indices);

uint64_t pre_size = desc->pre_size;
uint64_t post_size = desc->post_size;
uint64_t indices_size = desc->indices_size;
uint64_t axis_size = desc->axis_size;
if (post_size == 1) {
#pragma omp parallel for collapse(2)
for (uint64_t i = 0; i < pre_size; i++) {
for (uint64_t j = 0; j < indices_size; j++) {
uint64_t output_offset = i * indices_size * post_size + j * post_size;
uint64_t input_offset = i * axis_size * post_size + indices_[j] * post_size;
output_[output_offset] = input_[input_offset];
}
}
} else {
#pragma omp parallel for collapse(2)
for (uint64_t i = 0; i < pre_size; i++) {
for (uint64_t j = 0; j < indices_size; j++) {
uint64_t output_offset = i * indices_size * post_size + j * post_size;
uint64_t input_offset = i * axis_size * post_size + indices_[j] * post_size;
std::memcpy(output_ + output_offset, input_ + input_offset, post_size * sizeof(Tdata));
}
}
}

return STATUS_SUCCESS;
}

infiniopStatus_t cpuGather(GatherCpuDescriptor_t desc, void *output, void const *input, void const *indices, void *stream) {
if (desc->dtype == F16) {
if (desc->indices_dtype == I32) {
return gather_cpu<uint16_t, int32_t>(desc, output, input, indices);
}
if (desc->indices_dtype == I64) {
return gather_cpu<uint16_t, int64_t>(desc, output, input, indices);
}
}
if (desc->dtype == F32) {
if (desc->indices_dtype == I32) {
return gather_cpu<float, int32_t>(desc, output, input, indices);
}
if (desc->indices_dtype == I64) {
return gather_cpu<float, int64_t>(desc, output, input, indices);
}
}
return STATUS_BAD_TENSOR_DTYPE;
}
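
The CPU kernel flattens the input around the gather axis into (pre_size, axis_size, post_size) and copies post_size-element rows. A runnable illustration of the offset arithmetic for input shape (2, 3, 4), indices {1, 0}, axis = 1 follows (not part of the diff):

#include <cstdint>
#include <cstdio>

int main() {
    // pre_size = 2 (dims before axis), axis_size = 3, post_size = 4 (dims after axis)
    const uint64_t pre_size = 2, axis_size = 3, post_size = 4, indices_size = 2;
    const uint64_t indices[] = {1, 0};
    for (uint64_t i = 0; i < pre_size; ++i) {
        for (uint64_t j = 0; j < indices_size; ++j) {
            uint64_t in_off  = i * axis_size * post_size + indices[j] * post_size;    // source row
            uint64_t out_off = i * indices_size * post_size + j * post_size;          // destination row
            std::printf("copy %llu elems: input[%llu..] -> output[%llu..]\n",
                        (unsigned long long) post_size,
                        (unsigned long long) in_off,
                        (unsigned long long) out_off);
        }
    }
    return 0;
}
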
35 changes: 35 additions & 0 deletions src/ops/gather/cpu/gather_cpu.h
@@ -0,0 +1,35 @@
#ifndef __CPU_GATHER_H__
#define __CPU_GATHER_H__

#include "operators.h"
#include <cstring>
#include <numeric>

struct GatherCpuDescriptor {
Device device;
DT dtype;
DT indices_dtype;
uint64_t pre_size;
uint64_t axis_size;
uint64_t indices_size;
uint64_t post_size;
};

typedef struct GatherCpuDescriptor *GatherCpuDescriptor_t;

infiniopStatus_t cpuCreateGatherDescriptor(infiniopHandle_t,
GatherCpuDescriptor_t *,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input,
infiniopTensorDescriptor_t indices,
uint64_t axis);

infiniopStatus_t cpuGather(GatherCpuDescriptor_t desc,
void *output,
void const *input,
void const *indices,
void *stream);

infiniopStatus_t cpuDestroyGatherDescriptor(GatherCpuDescriptor_t desc);

#endif
82 changes: 82 additions & 0 deletions src/ops/gather/cuda/gather.cc
@@ -0,0 +1,82 @@
#include "gather.cuh"
#include "../../../devices/cuda/common_cuda.h"
#include "../../utils.h"

infiniopStatus_t cudaCreateGatherDescriptor(CudaHandle_t handle,
GatherCudaDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input,
infiniopTensorDescriptor_t indices,
uint64_t axis) {
if (indices->dt != I32 && indices->dt != I64) {
return STATUS_BAD_TENSOR_DTYPE;
}
if (output->dt != input->dt) {
return STATUS_BAD_TENSOR_DTYPE;
}

if (output->ndim != input->ndim + indices->ndim - 1) {
return STATUS_BAD_TENSOR_SHAPE;
}
for (int i = 0; i < output->ndim; i++) {
if (i < axis) {
if (output->shape[i] != input->shape[i]) {
return STATUS_BAD_TENSOR_SHAPE;
}
} else if (i < axis + indices->ndim) {
if (output->shape[i] != indices->shape[i - axis]) {
return STATUS_BAD_TENSOR_SHAPE;
}
} else {
if (output->shape[i] != input->shape[i - indices->ndim + 1]) {
return STATUS_BAD_TENSOR_SHAPE;
}
}
}

if (!is_contiguous(output) || !is_contiguous(input) || !is_contiguous(indices)) {
return STATUS_BAD_TENSOR_STRIDES;
}

// axis is passed as uint64_t, so it can never be negative and needs no normalization
uint64_t axis_tmp = axis;

uint64_t output_size = std::accumulate(output->shape, output->shape + output->ndim, 1ULL, std::multiplies<uint64_t>());
uint64_t pre_size = std::accumulate(input->shape, input->shape + axis_tmp, 1ULL, std::multiplies<uint64_t>());
uint64_t post_size = std::accumulate(input->shape + axis_tmp + 1, input->shape + input->ndim, 1ULL, std::multiplies<uint64_t>());
uint64_t indices_size = std::accumulate(indices->shape, indices->shape + indices->ndim, 1ULL, std::multiplies<uint64_t>());
uint64_t axis_size = input->shape[axis_tmp];

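    // Pick the 2-D launch layout whose x extent is largest:
    //   0 -> x = pre*indices,  y = post
    //   1 -> x = indices*post, y = pre
    //   2 -> x = pre*post,     y = indices
    // so the bulk of the work lands on the wide grid dimension.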
int kernel_type = 0;
uint64_t sizes[3] = {pre_size * indices_size, indices_size * post_size, pre_size * post_size};
for(int i = 1; i < 3; i++) {
if(sizes[i] > sizes[kernel_type]) {
kernel_type = i;
}
}


*desc_ptr = new GatherCudaDescriptor{
DevNvGpu,
output->dt,
indices->dt,
handle->device_id,
output_size,
pre_size,
axis_size,
indices_size,
post_size,
kernel_type,
static_cast<uint64_t>(handle->prop.maxGridSize[0]),
};


return STATUS_SUCCESS;
}

infiniopStatus_t cudaDestroyGatherDescriptor(GatherCudaDescriptor_t desc) {
delete desc;
return STATUS_SUCCESS;
}
182 changes: 182 additions & 0 deletions src/ops/gather/cuda/gather.cu
@@ -0,0 +1,182 @@
#include "../../../devices/cuda/common_cuda.h"
#include "../../utils.h"
#include "gather.cuh"

constexpr int num_elem_per_thread = 2;

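// gather_0: thread x walks the flattened (pre, index) space and thread y walks the post
// dimension; each thread handles num_elem_per_thread consecutive x entries.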
template<typename Tdata, typename Tind>
__global__ void gather_0(
Tdata const *data,
Tind const *indices,
uint64_t indices_size,
Tdata *output,
uint64_t pre_size,
uint64_t axis_size,
uint64_t post_size,
uint64_t offset_x,
uint64_t offset_y) {
const uint64_t x = blockIdx.x * blockDim.x + threadIdx.x + offset_x;
const uint64_t y = blockIdx.y * blockDim.y + threadIdx.y + offset_y;
if (x >= pre_size * indices_size || y >= post_size) return;

for (uint64_t i = 0; i < num_elem_per_thread; i++) {
const uint64_t idx = x * num_elem_per_thread + i;
if (idx >= pre_size * indices_size) return;

const uint64_t pre_idx = idx / indices_size;
const uint64_t indices_idx = idx % indices_size;
const uint64_t post_idx = y;

const uint64_t data_idx = pre_idx * axis_size * post_size + indices[indices_idx] * post_size + post_idx;
const uint64_t output_idx = pre_idx * indices_size * post_size + indices_idx * post_size + post_idx;
output[output_idx] = data[data_idx];
}
}

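// gather_1: thread x walks the flattened (index, post) space and thread y walks the pre
// dimension.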
template<typename Tdata, typename Tind>
__global__ void gather_1(
Tdata const *data,
Tind const *indices,
uint64_t indices_size,
Tdata *output,
uint64_t pre_size,
uint64_t axis_size,
uint64_t post_size,
uint64_t offset_x,
uint64_t offset_y) {
const uint64_t x = blockIdx.x * blockDim.x + threadIdx.x + offset_x;
const uint64_t y = blockIdx.y * blockDim.y + threadIdx.y + offset_y;
if (x >= indices_size * post_size || y >= pre_size) return;

for (uint64_t i = 0; i < num_elem_per_thread; i++) {
const uint64_t idx = x * num_elem_per_thread + i;
if (idx >= indices_size * post_size) return;

const uint64_t pre_idx = y;
const uint64_t indices_idx = idx / post_size;
const uint64_t post_idx = idx % post_size;

const uint64_t data_idx = pre_idx * axis_size * post_size + indices[indices_idx] * post_size + post_idx;
const uint64_t output_idx = pre_idx * indices_size * post_size + indices_idx * post_size + post_idx;
output[output_idx] = data[data_idx];
}
}

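// gather_2: thread x walks the flattened (pre, post) space and thread y walks the index
// dimension.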
template<typename Tdata, typename Tind>
__global__ void gather_2(
Tdata const *data,
Tind const *indices,
uint64_t indices_size,
Tdata *output,
uint64_t pre_size,
uint64_t axis_size,
uint64_t post_size,
uint64_t offset_x,
uint64_t offset_y) {
const uint64_t x = blockIdx.x * blockDim.x + threadIdx.x + offset_x;
const uint64_t y = blockIdx.y * blockDim.y + threadIdx.y + offset_y;
if (x >= pre_size * post_size || y >= indices_size) return;

for (uint64_t i = 0; i < num_elem_per_thread; i++) {
const uint64_t idx = x * num_elem_per_thread + i;
if (idx >= pre_size * post_size) return;

const uint64_t pre_idx = idx / post_size;
const uint64_t indices_idx = y;
const uint64_t post_idx = idx % post_size;

const uint64_t data_idx = pre_idx * axis_size * post_size + indices[indices_idx] * post_size + post_idx;
const uint64_t output_idx = pre_idx * indices_size * post_size + indices_idx * post_size + post_idx;
output[output_idx] = data[data_idx];
}
}

template<typename Tdata, typename Tind>
infiniopStatus_t gather_nv_gpu(GatherCudaDescriptor_t desc, void *output, void const *input, void const *indices, void *stream) {
if (desc->output_size == 0) {
return STATUS_SUCCESS;
}

const auto input_ = reinterpret_cast<Tdata const *>(input);
const auto output_ = reinterpret_cast<Tdata *>(output);
const auto indices_ = reinterpret_cast<Tind const *>(indices);
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);

uint64_t pre_size = desc->pre_size;
uint64_t axis_size = desc->axis_size;
uint64_t post_size = desc->post_size;
uint64_t indices_size = desc->indices_size;

int kernel_type = desc->kernel_type;
switch (kernel_type) {
case 0: {
dim3 block_size(std::min(static_cast<uint64_t>(64), pre_size * indices_size / num_elem_per_thread),
std::min(static_cast<uint64_t>(16), post_size));
dim3 grid_size(std::min(ROUND_UP_DIV(pre_size * indices_size / num_elem_per_thread, block_size.x), desc->max_grid_size),
std::min(ROUND_UP_DIV(post_size, block_size.y), desc->max_grid_size));
uint64_t step_x = grid_size.x * block_size.x;
uint64_t step_y = grid_size.y * block_size.y;
#pragma unroll
for(uint64_t i = 0; i < pre_size * indices_size; i += step_x) {
for(uint64_t j = 0; j < post_size; j += step_y) {
gather_0<Tdata, Tind><<<grid_size, block_size, 0, cuda_stream>>>(
input_, indices_, indices_size, output_, pre_size, axis_size, post_size, i, j);
}
}
}
break;
case 1: {
dim3 block_size(std::min(static_cast<uint64_t>(64), indices_size * post_size / num_elem_per_thread),
std::min(static_cast<uint64_t>(16), pre_size));
dim3 grid_size(std::min(ROUND_UP_DIV(indices_size * post_size / num_elem_per_thread, block_size.x), desc->max_grid_size),
std::min(ROUND_UP_DIV(pre_size, block_size.y), desc->max_grid_size));
uint64_t step_x = grid_size.x * block_size.x;
uint64_t step_y = grid_size.y * block_size.y;
#pragma unroll
for(uint64_t i = 0; i < indices_size * post_size; i += step_x) {
for(uint64_t j = 0; j < pre_size; j += step_y) {
gather_1<Tdata, Tind><<<grid_size, block_size, 0, cuda_stream>>>(
input_, indices_, indices_size, output_, pre_size, axis_size, post_size, i, j);
}
}
}
break;
case 2: {
dim3 block_size(std::min(static_cast<uint64_t>(64), pre_size * post_size / num_elem_per_thread),
std::min(static_cast<uint64_t>(16), indices_size));
dim3 grid_size(std::min(ROUND_UP_DIV(pre_size * post_size / num_elem_per_thread, block_size.x), desc->max_grid_size),
std::min(ROUND_UP_DIV(indices_size, block_size.y), desc->max_grid_size));
uint64_t step_x = grid_size.x * block_size.x;
uint64_t step_y = grid_size.y * block_size.y;
#pragma unroll
for(uint64_t i = 0; i < pre_size * post_size; i += step_x) {
for(uint64_t j = 0; j < indices_size; j += step_y) {
gather_2<Tdata, Tind><<<grid_size, block_size, 0, cuda_stream>>>(
input_, indices_, indices_size, output_, pre_size, axis_size, post_size, i, j);
}
}
}
}
return STATUS_SUCCESS;
}

infiniopStatus_t cudaGather(GatherCudaDescriptor_t desc, void *output, void const *input, void const *indices, void *stream) {
checkCudaError(cudaSetDevice(desc->device_id));
if (desc->dtype == F16) {
if (desc->indices_dtype == I32) {
return gather_nv_gpu<half, int32_t>(desc, output, input, indices, stream);
}
if (desc->indices_dtype == I64) {
return gather_nv_gpu<half, int64_t>(desc, output, input, indices, stream);
}
}
if (desc->dtype == F32) {
if (desc->indices_dtype == I32) {
return gather_nv_gpu<float, int32_t>(desc, output, input, indices, stream);
}
if (desc->indices_dtype == I64) {
return gather_nv_gpu<float, int64_t>(desc, output, input, indices, stream);
}
}
return STATUS_BAD_TENSOR_DTYPE;
}
39 changes: 39 additions & 0 deletions src/ops/gather/cuda/gather.cuh
@@ -0,0 +1,39 @@
#ifndef __CUDA_GATHER_H__
#define __CUDA_GATHER_H__

#include "../../../devices/cuda/common_cuda.h"
#include "../../../devices/cuda/cuda_handle.h"
#include "operators.h"
#include <cuda_fp16.h>
#include <numeric>

struct GatherCudaDescriptor {
Device device;
DT dtype;
DT indices_dtype;
int device_id;
uint64_t output_size;
uint64_t pre_size;
uint64_t axis_size;
uint64_t indices_size;
uint64_t post_size;
int kernel_type;
uint64_t max_grid_size;
};

typedef struct GatherCudaDescriptor *GatherCudaDescriptor_t;

infiniopStatus_t cudaCreateGatherDescriptor(CudaHandle_t,
GatherCudaDescriptor_t *,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input,
infiniopTensorDescriptor_t indices,
uint64_t axis);

infiniopStatus_t cudaGather(GatherCudaDescriptor_t desc,
void *output, void const *input,
void const *indices, void *stream);

infiniopStatus_t cudaDestroyGatherDescriptor(GatherCudaDescriptor_t desc);

#endif
79 changes: 79 additions & 0 deletions src/ops/gather/operator.cc
@@ -0,0 +1,79 @@
#include "../utils.h"
#include "operators.h"
#include "ops/gather/gather.h"

#ifdef ENABLE_CPU
#include "cpu/gather_cpu.h"
#endif
#ifdef ENABLE_NV_GPU
#include "../../devices/cuda/cuda_handle.h"
#include "cuda/gather.cuh"
#endif

__C infiniopStatus_t infiniopCreateGatherDescriptor(
infiniopHandle_t handle,
infiniopGatherDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input,
infiniopTensorDescriptor_t indices,
uint64_t axis) {
switch (handle->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuCreateGatherDescriptor(handle, (GatherCpuDescriptor_t *) desc_ptr, output, input, indices, axis);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaCreateGatherDescriptor((CudaHandle_t) handle, (GatherCudaDescriptor_t *) desc_ptr, output, input, indices, axis);
}

#endif
#ifdef ENABLE_CAMBRICON_MLU
// TODO
#endif
default:
return STATUS_BAD_DEVICE;
}
return STATUS_BAD_DEVICE;
}

__C infiniopStatus_t infiniopGather(
infiniopGatherDescriptor_t desc,
void *output,
void const *input,
void const *indices,
void *stream) {
switch (desc->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuGather((GatherCpuDescriptor_t) desc, output, input, indices, stream);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaGather((GatherCudaDescriptor_t) desc, output, input, indices, stream);
}

#endif
default:
return STATUS_BAD_DEVICE;
}
return STATUS_BAD_DEVICE;
}

__C infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc) {
switch (desc->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroyGatherDescriptor((GatherCpuDescriptor_t) desc);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaDestroyGatherDescriptor((GatherCudaDescriptor_t) desc);
}

#endif
default:
return STATUS_BAD_DEVICE;
}
return STATUS_BAD_DEVICE;
}
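// Usage sketch (illustrative only, not part of this patch): a hypothetical call sequence,
// assuming a handle, tensor descriptors and device buffers already exist.
//
//   infiniopGatherDescriptor_t desc;
//   infiniopCreateGatherDescriptor(handle, &desc, out_desc, in_desc, idx_desc, /*axis=*/0);
//   infiniopGather(desc, d_out, d_in, d_idx, stream);
//   infiniopDestroyGatherDescriptor(desc);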
294 changes: 294 additions & 0 deletions src/ops/reduce/cpu/reduce_cpu.cc
@@ -0,0 +1,294 @@
#include "reduce_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../utils.h"

infiniopStatus_t cpuCreateReduceDescriptor(infiniopHandle_t,
ReduceCpuDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim,
int reduce_op) {
if (!is_contiguous(y) || !is_contiguous(x)) {
return STATUS_BAD_TENSOR_STRIDES;
}

if (y->dt != x->dt) {
return STATUS_BAD_TENSOR_DTYPE;
}

    bool use_1Dreduce = axes_ndim == x->ndim; // all axes are reduced, so treat it as a 1-D reduce and avoid out_ndim == 0
bool is_y_scalar = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies<uint64_t>()) == 1;
if(use_1Dreduce != is_y_scalar) {
return STATUS_BAD_TENSOR_SHAPE;
}

std::vector<int> axes_, out_of_axes;

uint64_t ndim = x->ndim;
uint64_t out_ndim = ndim - axes_ndim;
uint64_t axes_size = 1;
uint64_t out_size = 1;

for(uint64_t i = 0; i < ndim; i++) {
bool is_axis = false;
for(uint64_t j = 0; j < axes_ndim; j++) {
if (axes[j] == i) {
is_axis = true;
break;
}
}
if (is_axis) {
axes_size *= x->shape[i];
axes_.emplace_back(i);
} else {
out_size *= x->shape[i];
out_of_axes.emplace_back(i);
}
}
sort(axes_.begin(), axes_.end());
sort(out_of_axes.begin(), out_of_axes.end());

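    // Row-major strides within the reduced-axes index space and the kept-axes index space;
    // the CPU reduce loops use them to decode flat counters back into per-dimension coordinates.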
std::vector<int> axes_strides(axes_ndim, 1);
std::vector<int> out_strides(out_ndim, 1);

for(uint64_t i = 0; i < axes_ndim; i++) {
for(uint64_t j = i + 1; j < axes_ndim; j++) {
            axes_strides[i] *= x->shape[axes_[j]];// use the sorted axes_ so strides match the decode order in reduce_cpu()
}
}
for(uint64_t i = 0; i < out_ndim; i++) {
for(uint64_t j = i + 1; j < out_ndim; j++) {
out_strides[i] *= x->shape[out_of_axes[j]];
}
}

int64_t *strides = new int64_t[ndim];
std::memcpy(strides, x->strides, ndim * sizeof(int64_t));


*desc_ptr = new ReduceCpuDescriptor{
DevCpu,
x->dt,
use_1Dreduce,
axes_,
out_of_axes,
ndim,
axes_ndim,
out_ndim,
axes_size,
out_size,
strides,
axes_strides,
out_strides,
reduce_op
};

return STATUS_SUCCESS;
}

infiniopStatus_t cpuDestroyReduceDescriptor(ReduceCpuDescriptor_t desc) {
delete[] (desc->strides);
delete desc;
return STATUS_SUCCESS;
}


template<typename Tdata>
infiniopStatus_t reduce_cpu_1D(ReduceCpuDescriptor_t desc, void *y, void const *x) {
auto x_ = reinterpret_cast<Tdata const *>(x);
auto y_ = reinterpret_cast<Tdata *>(y);
auto data_size_ = desc->out_size * desc->axes_size;
auto reduce_op_ = desc->reduce_op;

if constexpr (std::is_same<Tdata, uint16_t>::value) {
switch (reduce_op_) {
case 0: { // ReduceMin
float result = std::numeric_limits<float>::max();
for (uint64_t i = 0; i < data_size_; i++) {
result = std::min(result, f16_to_f32(x_[i]));
}
y_[0] = f32_to_f16(result);
}
break;
case 1: { // ReduceMax
float result = std::numeric_limits<float>::lowest();
for (uint64_t i = 0; i < data_size_; i++) {
result = std::max(result, f16_to_f32(x_[i]));
}
y_[0] = f32_to_f16(result);
}
break;
case 2: { // ReduceMean
float sum = 0;
for (uint64_t i = 0; i < data_size_; i++) {
sum += f16_to_f32(x_[i]);
}
y_[0] = f32_to_f16(sum / data_size_);
}
break;
}
} else {
switch (reduce_op_) {
case 0: { // ReduceMin
Tdata result = std::numeric_limits<Tdata>::max();
for (uint64_t i = 0; i < data_size_; i++) {
result = std::min(result, x_[i]);
}
y_[0] = result;
}
break;
case 1: { // ReduceMax
Tdata result = std::numeric_limits<Tdata>::lowest();
for (uint64_t i = 0; i < data_size_; i++) {
result = std::max(result, x_[i]);
}
y_[0] = result;
}
break;
case 2: { // ReduceMean
Tdata sum = 0;
for (uint64_t i = 0; i < data_size_; i++) {
sum += x_[i];
}
y_[0] = sum / data_size_;
}
break;
}
}
return STATUS_SUCCESS;
}

template<typename Tdata>
infiniopStatus_t reduce_cpu(ReduceCpuDescriptor_t desc, void *y, void const *x) {
auto x_ = reinterpret_cast<Tdata const *>(x);
auto y_ = reinterpret_cast<Tdata *>(y);
auto axes_ = desc->axes;
auto out_of_axes_ = desc->out_of_axes;
auto ndim_ = desc->ndim;
auto axes_ndim_ = desc->axes_ndim;
auto out_ndim_ = desc->out_ndim;
auto axes_size_ = desc->axes_size;
auto out_size_ = desc->out_size;
auto strides_ = desc->strides;
auto axes_strides_ = desc->axes_strides;
auto out_strides_ = desc->out_strides;
auto reduce_op_ = desc->reduce_op;

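    // For each kept-axes position i: decode i into coordinates via out_strides, map them to an
    // element offset with the original tensor strides, then accumulate over every reduced-axes
    // position j decoded the same way via axes_strides.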
#pragma omp parallel for
for(uint64_t i = 0; i < out_size_; i++) {
uint64_t idx = 0;
uint64_t temp_i = i;
for(uint64_t j = 0; j < out_ndim_; j++) {
idx += temp_i / out_strides_[j] * strides_[out_of_axes_[j]];
temp_i %= out_strides_[j];
}

float result;
switch (reduce_op_) {
case 0: { // ReduceMin
result = std::numeric_limits<float>::max();
}
break;
case 1: { // ReduceMax
result = std::numeric_limits<float>::lowest();
}
break;
case 2: { // ReduceMean
result = 0;
}
break;
}

for(uint64_t j = 0; j < axes_size_; j++) {
uint64_t idx_ = idx;
uint64_t temp_j = j;
for(uint64_t k = 0; k < axes_ndim_; k++) {
idx_ += temp_j / axes_strides_[k] * strides_[axes_[k]];
temp_j %= axes_strides_[k];
}

if constexpr (std::is_same<Tdata, uint16_t>::value) {
switch (reduce_op_) {
case 0: { // ReduceMin
result = std::min(result, f16_to_f32(x_[idx_]));
}
break;
case 1: { // ReduceMax
result = std::max(result, f16_to_f32(x_[idx_]));
}
break;
case 2: { // ReduceMean
result += f16_to_f32(x_[idx_]);
}
break;
}
} else {
switch (reduce_op_) {
case 0: { // ReduceMin
result = std::min(result, x_[idx_]);
}
break;
case 1: { // ReduceMax
result = std::max(result, x_[idx_]);
}
break;
case 2: { // ReduceMean
result += x_[idx_];
}
break;
}
}
}
if constexpr (std::is_same<Tdata, uint16_t>::value) {
switch (reduce_op_) {
case 0: { // ReduceMin
y_[i] = f32_to_f16(result);
}
break;
case 1: { // ReduceMax
y_[i] = f32_to_f16(result);
}
break;
case 2: { // ReduceMean
y_[i] = f32_to_f16(result / axes_size_);
}
break;
}
} else {
switch (reduce_op_) {
case 0: { // ReduceMin
y_[i] = result;
}
break;
case 1: { // ReduceMax
y_[i] = result;
}
break;
case 2: { // ReduceMean
y_[i] = result / axes_size_;
}
break;
}
}
}
return STATUS_SUCCESS;
}

infiniopStatus_t cpuReduce(ReduceCpuDescriptor_t desc, void *y, void const *x, void *stream) {
if (desc->dtype == F16) {
if(desc->use_1Dreduce) {
return reduce_cpu_1D<uint16_t>(desc, y, x);
} else {
return reduce_cpu<uint16_t>(desc, y, x);
}
}
if (desc->dtype == F32) {
if(desc->use_1Dreduce) {
return reduce_cpu_1D<float>(desc, y, x);
} else {
return reduce_cpu<float>(desc, y, x);
}
}
return STATUS_BAD_TENSOR_DTYPE;
}
43 changes: 43 additions & 0 deletions src/ops/reduce/cpu/reduce_cpu.h
@@ -0,0 +1,43 @@
#ifndef __CPU_REDUCE_H__
#define __CPU_REDUCE_H__

#include "operators.h"
#include <cstring>
#include <vector>
#include <numeric>

struct ReduceCpuDescriptor {
Device device;
DT dtype;
bool use_1Dreduce;
std::vector<int> axes;
std::vector<int> out_of_axes;
uint64_t ndim;
uint64_t axes_ndim;
uint64_t out_ndim;
uint64_t axes_size;
uint64_t out_size;
int64_t *strides;
std::vector<int> axes_strides;
std::vector<int> out_strides;
int reduce_op;
};

typedef struct ReduceCpuDescriptor *ReduceCpuDescriptor_t;

infiniopStatus_t cpuCreateReduceDescriptor(infiniopHandle_t,
ReduceCpuDescriptor_t *,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim,
int reduce_op);

infiniopStatus_t cpuReduce(ReduceCpuDescriptor_t desc,
void *y,
void const *x,
void *stream);

infiniopStatus_t cpuDestroyReduceDescriptor(ReduceCpuDescriptor_t desc);

#endif
141 changes: 141 additions & 0 deletions src/ops/reduce/cuda/reduce.cc
@@ -0,0 +1,141 @@
#include "reduce.cuh"
#include "../../../devices/cuda/common_cuda.h"
#include "../../utils.h"

infiniopStatus_t cudaCreateReduceDescriptor(CudaHandle_t handle,
ReduceCudaDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim,
int reduce_op) {
if (y->dt != F16 && y->dt != F32) {
return STATUS_BAD_TENSOR_DTYPE;
}
if (y->dt != x->dt) {
return STATUS_BAD_TENSOR_DTYPE;
}

if (!is_contiguous(y) || !is_contiguous(x)) {
return STATUS_BAD_TENSOR_STRIDES;
}

uint64_t ndim = x->ndim;
if (ndim > 8) {
return STATUS_BAD_TENSOR_SHAPE;
}

int x_shape[CUDNN_DIM_MAX];
for (uint64_t i = 0; i < ndim; ++i) {
x_shape[i] = static_cast<int>(x->shape[i]);
}

int x_strides[CUDNN_DIM_MAX];
for (uint64_t i = 0; i < ndim; ++i) {
x_strides[i] = static_cast<int>(x->strides[i]);
}

int y_shape[CUDNN_DIM_MAX];
for (uint64_t i = 0; i < ndim; ++i) {
bool is_axis = false;
for (uint64_t j = 0; j < axes_ndim; ++j) {
if (axes[j] == i) {
is_axis = true;
break;
}
}
if (is_axis) {
y_shape[i] = 1;
} else {
y_shape[i] = static_cast<int>(x->shape[i]);
}
}

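    // The output keeps the input rank with reduced dimensions collapsed to 1; build contiguous
    // row-major strides for that shape.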
int y_strides[CUDNN_DIM_MAX];
int stride = 1;
for (uint64_t i = ndim; i > 0; --i) {
y_strides[i - 1] = stride;
stride *= y_shape[i - 1];
}

CREATE_CHECK_ERROR(auto dt = dataTypeMap[x->dt], dt, -1, STATUS_BAD_PARAM);
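    // Map the element type to a cuDNN descriptor type; FP16 is kept only on devices with native
    // half support (compute capability >= 5.3), otherwise the descriptors fall back to FP32.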
cudnnDataType_t cudnn_dt = [&] {
switch (dt) {
case CUDNN_DATA_HALF:
if (handle->compute_capability_major > 5 || (handle->compute_capability_major == 5 && handle->compute_capability_minor >= 3)) {
return CUDNN_DATA_HALF;
}
return CUDNN_DATA_FLOAT;
case CUDNN_DATA_BFLOAT16:
case CUDNN_DATA_FLOAT:
return CUDNN_DATA_FLOAT;
case CUDNN_DATA_DOUBLE:
return CUDNN_DATA_DOUBLE;
default:
return CUDNN_DATA_INT32;
}
}();

cudnnTensorDescriptor_t x_desc;
checkCudnnError(cudnnCreateTensorDescriptor(&x_desc));
cudnnTensorDescriptor_t y_desc;
checkCudnnError(cudnnCreateTensorDescriptor(&y_desc));
if (ndim > 4) {
checkCudnnError(cudnnSetTensorNdDescriptor(x_desc, cudnn_dt, ndim, x_shape, x_strides));
checkCudnnError(cudnnSetTensorNdDescriptor(y_desc, cudnn_dt, ndim, y_shape, y_strides));
} else {
int x_shape_[4] = {1, 1, 1, 1};
int y_shape_[4] = {1, 1, 1, 1};
for (int i = 0; i < ndim; ++i) {
x_shape_[4 - i - 1] = x_shape[ndim - i - 1];
y_shape_[4 - i - 1] = y_shape[ndim - i - 1];
}
checkCudnnError(cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dt, x_shape_[0], x_shape_[1], x_shape_[2], x_shape_[3]));
checkCudnnError(cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dt, y_shape_[0], y_shape_[1], y_shape_[2], y_shape_[3]));
}

cudnnReduceTensorDescriptor_t reduce_desc;
checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduce_desc));
cudnnReduceTensorOp_t reduce_op_ = [&] {
switch (reduce_op) {
case 0:
return CUDNN_REDUCE_TENSOR_MIN;
case 1:
return CUDNN_REDUCE_TENSOR_MAX;
            default: // 2 == mean; a default also guarantees the lambda always returns a value
                return CUDNN_REDUCE_TENSOR_AVG;
}
}();
checkCudnnError(cudnnSetReduceTensorDescriptor(reduce_desc, reduce_op_, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));

uint64_t workspace_size = 0;
if (use_cudnn(handle->cudnn_handles_t, handle->device_id, nullptr,
[&](cudnnHandle_t handle) { return cudnnGetReductionWorkspaceSize(handle, reduce_desc, x_desc, y_desc, &workspace_size); }) != CUDNN_STATUS_SUCCESS) {
return STATUS_EXECUTION_FAILED;
}
*desc_ptr = new ReduceCudaDescriptor{
DevNvGpu,
y->dt,
handle->device_id,
handle->cudnn_handles_t,
x_desc,
y_desc,
reduce_desc,
workspace_size};

return STATUS_SUCCESS;
}

infiniopStatus_t cudaGetReduceWorkspaceSize(ReduceCudaDescriptor_t desc, uint64_t *size) {
*size = desc->workspace_size;
return STATUS_SUCCESS;
}

infiniopStatus_t cudaDestroyReduceDescriptor(ReduceCudaDescriptor_t desc) {
checkCudnnError(cudnnDestroyReduceTensorDescriptor(desc->reduce_desc));
checkCudnnError(cudnnDestroyTensorDescriptor(desc->y_desc));
checkCudnnError(cudnnDestroyTensorDescriptor(desc->x_desc));
desc->cudnn_handles_t = nullptr;
delete desc;
return STATUS_SUCCESS;
}
24 changes: 24 additions & 0 deletions src/ops/reduce/cuda/reduce.cu
@@ -0,0 +1,24 @@
#include "../../../devices/cuda/common_cuda.h"
#include "../../utils.h"
#include "reduce.cuh"

infiniopStatus_t reduce_nv_gpu(ReduceCudaDescriptor_t desc, void *workspace, uint64_t workspace_size,
void *y, void const *x, void *stream) {
checkCudaError(cudaSetDevice(desc->device_id));
float alpha = 1.f;
float beta = 0.f;

checkCudnnError(use_cudnn(desc->cudnn_handles_t, desc->device_id, (cudaStream_t) stream,
[&](cudnnHandle_t handle) { return cudnnReduceTensor(handle, desc->reduce_desc, nullptr, 0, workspace, (size_t)workspace_size, &alpha, desc->x_desc, x, &beta, desc->y_desc, y); }));
return STATUS_SUCCESS;
}

infiniopStatus_t cudaReduce(ReduceCudaDescriptor_t desc,
void *workspace, uint64_t workspace_size,
void *y, void const *x,
void *stream) {
if (desc->dtype == F16 || desc->dtype == F32) {
return reduce_nv_gpu(desc, workspace, workspace_size, y, x, stream);
}
return STATUS_BAD_TENSOR_DTYPE;
}
42 changes: 42 additions & 0 deletions src/ops/reduce/cuda/reduce.cuh
@@ -0,0 +1,42 @@
#ifndef __CUDA_REDUCE_H__
#define __CUDA_REDUCE_H__

#include "../../../devices/cuda/common_cuda.h"
#include "../../../devices/cuda/cuda_handle.h"
#include "operators.h"
#include <cuda_fp16.h>
#include <numeric>
#include <cudnn.h>

struct ReduceCudaDescriptor {
Device device;
DT dtype;
int device_id;
std::shared_ptr<Pool<cudnnHandle_t>> cudnn_handles_t;
cudnnTensorDescriptor_t x_desc;
cudnnTensorDescriptor_t y_desc;
cudnnReduceTensorDescriptor_t reduce_desc;
uint64_t workspace_size;
};

typedef struct ReduceCudaDescriptor *ReduceCudaDescriptor_t;

infiniopStatus_t cudaCreateReduceDescriptor(CudaHandle_t,
ReduceCudaDescriptor_t *,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim,
int reduce_op);

infiniopStatus_t cudaGetReduceWorkspaceSize(ReduceCudaDescriptor_t desc,
uint64_t *size);

infiniopStatus_t cudaReduce(ReduceCudaDescriptor_t desc,
void *workspace, uint64_t workspace_size,
void *y, void const *x,
void *stream);

infiniopStatus_t cudaDestroyReduceDescriptor(ReduceCudaDescriptor_t desc);

#endif
88 changes: 88 additions & 0 deletions src/ops/reduce/operator.cc
@@ -0,0 +1,88 @@
#include "../utils.h"
#include "operators.h"
#include "reduce.h"

#ifdef ENABLE_CPU
#include "cpu/reduce_cpu.h"
#endif
#ifdef ENABLE_NV_GPU
#include "../../devices/cuda/cuda_handle.h"
#include "cuda/reduce.cuh"
#endif

__C infiniopStatus_t infiniopCreateReduceDescriptor(
infiniopHandle_t handle,
infiniopReduceDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim,
int reduce_op) {
switch (handle->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuCreateReduceDescriptor(handle, (ReduceCpuDescriptor_t *) desc_ptr, y, x, axes, axes_ndim, reduce_op);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaCreateReduceDescriptor((CudaHandle_t) handle, (ReduceCudaDescriptor_t *) desc_ptr, y, x, axes, axes_ndim, reduce_op);
}

#endif
}
return STATUS_BAD_DEVICE;
}

__C infiniopStatus_t infiniopGetReduceWorkspaceSize(infiniopReduceDescriptor_t desc, uint64_t *size) {
switch (desc->device) {
#ifdef ENABLE_CPU
case DevCpu: {
return STATUS_SUCCESS;
}
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaGetReduceWorkspaceSize((ReduceCudaDescriptor_t) desc, size);
}

#endif

default:
return STATUS_BAD_DEVICE;
}
return STATUS_BAD_DEVICE;
}

__C infiniopStatus_t infiniopReduce(infiniopReduceDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void *stream) {
switch (desc->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuReduce((ReduceCpuDescriptor_t) desc, y, x, stream);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu:
return cudaReduce((ReduceCudaDescriptor_t) desc, workspace, workspace_size, y, x, stream);

#endif
default:
return STATUS_BAD_DEVICE;
}
return STATUS_BAD_DEVICE;
}

__C infiniopStatus_t infiniopDestroyReduceDescriptor(infiniopReduceDescriptor_t desc) {
switch (desc->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroyReduceDescriptor((ReduceCpuDescriptor_t) desc);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaDestroyReduceDescriptor((ReduceCudaDescriptor_t) desc);
}

#endif
}
return STATUS_BAD_DEVICE;
}
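// Usage sketch (illustrative only, not part of this patch), assuming a handle, descriptors and
// buffers already exist; reduce_op follows the CPU/cuDNN mapping in this patch: 0 = min, 1 = max, 2 = mean.
//
//   infiniopReduceDescriptor_t desc;
//   int axes[] = {1};
//   infiniopCreateReduceDescriptor(handle, &desc, y_desc, x_desc, axes, 1, /*reduce_op=*/2);
//   uint64_t ws_size = 0;
//   infiniopGetReduceWorkspaceSize(desc, &ws_size);
//   infiniopReduce(desc, workspace, ws_size, d_y, d_x, stream);
//   infiniopDestroyReduceDescriptor(desc);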
28 changes: 28 additions & 0 deletions src/ops/reduce/reduce.h
@@ -0,0 +1,28 @@
#ifndef REDUCE_H
#define REDUCE_H

#include "export.h"
#include "operators.h"

typedef struct ReduceDescriptor {
Device device;
} ReduceDescriptor;

typedef ReduceDescriptor *infiniopReduceDescriptor_t;

__C infiniopStatus_t infiniopCreateReduceDescriptor(infiniopHandle_t handle,
infiniopReduceDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
int *axes,
uint64_t axes_ndim,
int reduce_op);

__C infiniopStatus_t infiniopGetReduceWorkspaceSize(infiniopReduceDescriptor_t desc, uint64_t *size);

__C infiniopStatus_t infiniopReduce(infiniopReduceDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void *stream);

__C infiniopStatus_t infiniopDestroyReduceDescriptor(infiniopReduceDescriptor_t desc);


#endif