Add ConvBiasAct #123

Open: wants to merge 9 commits into base: dev
5 changes: 3 additions & 2 deletions include/infini_operators.h
@@ -3,10 +3,11 @@
#include "ops/attention/attention.h"
#include "ops/avg_pool/avg_pool.h"
#include "ops/causal_softmax/causal_softmax.h"
#include "ops/global_avg_pool/global_avg_pool.h"
#include "ops/conv/conv.h"
#include "ops/conv_act/conv_act.h"
#include "ops/expand/expand.h"
#include "ops/gemm/gemm.h"
#include "ops/conv/conv.h"
#include "ops/global_avg_pool/global_avg_pool.h"
#include "ops/matmul/matmul.h"
#include "ops/max_pool/max_pool.h"
#include "ops/mlp/mlp.h"
25 changes: 25 additions & 0 deletions include/ops/activations.h
@@ -0,0 +1,25 @@
#ifndef __ACTIVATIONS_H__
#define __ACTIVATIONS_H__

/**
 * @brief Specifies the type of activation function
 */
typedef enum InfiniActivationMode {
    // activation functions
    INFINI_ACTIVATION_IDENTITY = 0,
    INFINI_ACTIVATION_RELU = 1,
    INFINI_ACTIVATION_LEAKY_RELU = 2,
    INFINI_ACTIVATION_CLIPPED_RELU = 3,
    INFINI_ACTIVATION_SIGMOID = 4,
    INFINI_ACTIVATION_HEAVISIDE_STEP = 5,
    INFINI_ACTIVATION_ELU = 6,
    INFINI_ACTIVATION_GELU = 7,
    INFINI_ACTIVATION_SIN = 8,
    INFINI_ACTIVATION_TANH = 9,

    // Count
    // NOTE: new activation functions should be added before "Count"
    INFINI_ACTIVATION_COUNT,
} InfiniActivationMode_t;

#endif
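
For reference, the element-wise math these modes imply can be sketched as below. This is an illustrative host-side helper only, not code from this PR; `alpha` and `clip` stand in for the `ConvActParam.alpha` and `ConvActParam.clip_coef` fields introduced in conv_act.h further down, and the backends are expected to fuse the equivalent math into their conv kernels.

#include <math.h>
#include "ops/activations.h"   /* InfiniActivationMode_t (path as used by infini_operators.h) */

/* Illustrative scalar reference only; GELU and the Heaviside step are omitted here. */
static double apply_activation(InfiniActivationMode_t mode, double x,
                               double alpha, double clip) {
    switch (mode) {
    case INFINI_ACTIVATION_IDENTITY:     return x;
    case INFINI_ACTIVATION_RELU:         return x > 0 ? x : 0;
    case INFINI_ACTIVATION_LEAKY_RELU:   return x > 0 ? x : alpha * x;
    case INFINI_ACTIVATION_CLIPPED_RELU: return fmin(fmax(x, 0.0), clip);
    case INFINI_ACTIVATION_SIGMOID:      return 1.0 / (1.0 + exp(-x));
    case INFINI_ACTIVATION_ELU:          return x > 0 ? x : alpha * (exp(x) - 1.0);
    case INFINI_ACTIVATION_SIN:          return sin(x);
    case INFINI_ACTIVATION_TANH:         return tanh(x);
    default:                             return x;
    }
}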
9 changes: 5 additions & 4 deletions include/ops/conv/conv.h
@@ -15,14 +15,15 @@ __C __export infiniopStatus_t infiniopCreateConvDescriptor(infiniopHandle_t hand
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t w,
void *pads,
void *strides,
void *dilations,
infiniopTensorDescriptor_t b,
uint64_t const *pads,
int64_t const *strides,
uint64_t const *dilations,
uint64_t n);

__C __export infiniopStatus_t infiniopGetConvWorkspaceSize(infiniopConvDescriptor_t desc, uint64_t *size);

__C __export infiniopStatus_t infiniopConv(infiniopConvDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void *stream);
__C __export infiniopStatus_t infiniopConv(infiniopConvDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void const *b, void *stream);

__C __export infiniopStatus_t infiniopDestroyConvDescriptor(infiniopConvDescriptor_t desc);

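The new `b` descriptor and data pointer add an optional bias to the existing conv operator. Judging from the test changes below (the bias has shape w_shape[0] and is forwarded to F.conv1d/2d/3d), the intended semantics appear to be one bias value per output channel, broadcast over every spatial position. A naive reference under that assumption, purely for illustration:

#include <stddef.h>

/* Sketch under the per-output-channel assumption: y has layout (n, c_out, spatial...),
   `spatial` is the flattened spatial size, and b holds c_out values. */
static void add_bias_reference(float *y, const float *b,
                               size_t n, size_t c_out, size_t spatial) {
    for (size_t in = 0; in < n; ++in)
        for (size_t ic = 0; ic < c_out; ++ic)
            for (size_t is = 0; is < spatial; ++is)
                y[(in * c_out + ic) * spatial + is] += b[ic];
}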
55 changes: 55 additions & 0 deletions include/ops/conv_act/conv_act.h
@@ -0,0 +1,55 @@
#ifndef CONV_ACT_H
#define CONV_ACT_H

#include "../../export.h"
#include "../../operators.h"
#include "../activations.h"
#include <cstddef>

typedef struct ConvActParam {
    /**
     * Used by:
     * - INFINI_ACTIVATION_CLIPPED_RELU: as its clipping ceiling
     */
    double clip_coef;
    /**
     * Used by:
     * - INFINI_ACTIVATION_LEAKY_RELU: as its slope for x < 0
     * - INFINI_ACTIVATION_ELU: alpha * (exp(x) - 1.) for x < 0
     */
    double alpha;
    /**
     * Used by:
     * - INFINI_ACTIVATION_GELU: as its approximation switch
     */
    const char *approximate;

} ConvActParam_t;

typedef struct ConvActDescriptor {
    Device device;
} ConvActDescriptor;

typedef ConvActDescriptor *infiniopConvActDescriptor_t;

__C __export infiniopStatus_t infiniopCreateConvActDescriptor(infiniopHandle_t handle,
infiniopConvActDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t w,
infiniopTensorDescriptor_t b,
uint64_t const *pads,
int64_t const *strides,
uint64_t const *dilations,
uint64_t n,
InfiniActivationMode_t activation_mode,
ConvActParam_t act_params);

__C __export infiniopStatus_t infiniopGetConvActWorkspaceSize(infiniopConvActDescriptor_t desc, uint64_t *size);

__C __export infiniopStatus_t infiniopConvAct(infiniopConvActDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void const *b, void *stream);

__C __export infiniopStatus_t infiniopDestroyConvActDescriptor(infiniopConvActDescriptor_t desc);


#endif
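
Putting the new API together, a typical call sequence would look like the sketch below. It only uses the four functions declared in this header; the handle, tensor descriptors, device buffers, and stream are assumed to have been created elsewhere, error handling is omitted, and `n` is assumed to be the number of spatial dimensions (2 for a 2-D convolution).

#include <stdint.h>
#include "ops/conv_act/conv_act.h"

/* Sketch only: fused Conv + bias + ReLU on an already-prepared set of descriptors. */
void conv_bias_relu_example(infiniopHandle_t handle,
                            infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc,
                            infiniopTensorDescriptor_t w_desc, infiniopTensorDescriptor_t b_desc,
                            void *y, void const *x, void const *w, void const *b,
                            void *workspace, uint64_t workspace_capacity, void *stream) {
    uint64_t pads[2]      = {1, 1};
    int64_t  strides[2]   = {1, 1};
    uint64_t dilations[2] = {1, 1};
    ConvActParam_t param  = {0};   /* clip_coef/alpha/approximate are unused by plain ReLU */

    infiniopConvActDescriptor_t desc;
    infiniopCreateConvActDescriptor(handle, &desc, y_desc, x_desc, w_desc, b_desc,
                                    pads, strides, dilations, 2,
                                    INFINI_ACTIVATION_RELU, param);

    uint64_t workspace_size = 0;
    infiniopGetConvActWorkspaceSize(desc, &workspace_size);
    if (workspace_size <= workspace_capacity) {
        infiniopConvAct(desc, workspace, workspace_size, y, x, w, b, stream);
    }
    infiniopDestroyConvActDescriptor(desc);
}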
91 changes: 55 additions & 36 deletions operatorspy/tests/conv.py
@@ -38,23 +38,26 @@ class ConvDescriptor(Structure):
infiniopConvDescriptor_t = POINTER(ConvDescriptor)


def conv(x, w, stride, padding, dilation):
    match len(x.shape) - 2:
        case 1:
            return F.conv1d(
                x, w, stride=stride, padding=padding, dilation=dilation
            )
        case 2:
            return F.conv2d(
                x, w, stride=stride, padding=padding, dilation=dilation
            )
        case 3:
            return F.conv3d(
                x, w, stride=stride, padding=padding, dilation=dilation
            )
        case _:
            print("Error: Pytorch -> Unsupported tensor dimension")
            return None
def conv(x, w, b, stride, padding, dilation):
    ndim = len(x.shape) - 2
    conv_func_map = {
        1: F.conv1d,
        2: F.conv2d,
        3: F.conv3d
    }

    if ndim not in conv_func_map:
        print("Error: Pytorch -> Unsupported tensor dimension")
        return None

    # Select the appropriate convolution function
    conv_func = conv_func_map[ndim]

    if PROFILE:
        ans = conv_func(x, w, b, stride=stride, padding=padding, dilation=dilation)
        torch.cuda.synchronize()
        return ans
    return conv_func(x, w, b, stride=stride, padding=padding, dilation=dilation)


# infer the shape of the output given the inputs for an N-ary convolution
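# Illustrative sketch (assumption, not code from this diff): the collapsed inferShape
# helper is expected to apply the standard convolution output-size formula per spatial
# dimension, out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1,
# keeping the batch size from x and the output channels from w.
def infer_shape_sketch(x_shape, w_shape, pads, strides, dilations):
    spatial = [
        (x + 2 * p - d * (k - 1) - 1) // s + 1
        for x, k, p, s, d in zip(x_shape[2:], w_shape[2:], pads, strides, dilations)
    ]
    return (x_shape[0], w_shape[0]) + tuple(spatial)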
@@ -95,30 +98,33 @@ def test(
    pads,
    strides,
    dilations,
    tensor_stride=None,
    add_bias,
    tensor_dtype=torch.float16,
):
    assert len(pads) == len(strides) == len(dilations)
    print(
        f"Testing Conv on {torch_device} with x_shape: {x_shape}, w_shape: {w_shape}, b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, x_stride: {tensor_stride} dtype:{tensor_dtype}"
        f"Testing Conv on {torch_device} with x_shape: {x_shape}, w_shape: {w_shape}, add_bias: {add_bias}, "
        f"b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, dtype:{tensor_dtype}"
    )
    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
    w = torch.rand(w_shape, dtype=tensor_dtype).to(torch_device)
    b = torch.round((torch.rand(w_shape[0], dtype=tensor_dtype).to(torch_device) * 2 - 1) * 1000) / 1000 if add_bias else None
    y = torch.zeros(
        inferShape(x.shape, w.shape, pads, strides, dilations), dtype=tensor_dtype
    ).to(torch_device)

    for i in range(NUM_PRERUN if PROFILE else 1):
        ans = conv(x, w, strides, pads, dilations)
        ans = conv(x, w, b, strides, pads, dilations)
    if PROFILE:
        start_time = time.time()
        for i in range(NUM_ITERATIONS):
            _ = conv(x, w, strides, pads, dilations)
            _ = conv(x, w, b, strides, pads, dilations)
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f"pytorch time: {elapsed :6f}")

    x_tensor = to_tensor(x, lib)
    w_tensor = to_tensor(w, lib)
    b_tensor = to_tensor(b, lib) if b is not None else None
    y_tensor = to_tensor(y, lib)
    descriptor = infiniopConvDescriptor_t()

@@ -129,6 +135,7 @@ def test(
            y_tensor.descriptor,
            x_tensor.descriptor,
            w_tensor.descriptor,
            b_tensor.descriptor if b_tensor else None,
            tuple_to_void_p(pads),
            tuple_to_void_p(strides),
            tuple_to_void_p(dilations),
@@ -157,6 +164,7 @@ def test(
            y_tensor.data,
            x_tensor.data,
            w_tensor.data,
            b_tensor.data if b_tensor else None,
            None,
        )
    )
@@ -171,6 +179,7 @@ def test(
            y_tensor.data,
            x_tensor.data,
            w_tensor.data,
            b_tensor.data if b_tensor else None,
            None,
        )
    )
@@ -187,18 +196,18 @@ def test(
def test_cpu(lib, test_cases):
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
    for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases:
        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16)
        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
    for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases:
        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16)
        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


Expand All @@ -207,54 +216,62 @@ def test_bang(lib, test_cases):

    device = DeviceEnum.DEVICE_BANG
    handle = create_handle(lib, device)
    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
    for x_shape, w_shape, pads, strides, dilations, add_bias in test_cases:
        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float16)
        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, add_bias, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


if __name__ == "__main__":
    test_cases = [
        # x_shape, w_shape, pads, strides, dilations, x_strides
        # x_shape, w_shape, pads, strides, dilations, add_bias
        (
            (32, 3, 4),
            (32, 3, 5),
            (1,),
            (1,),
            (1,),
            None,
            False,
        ),
        (
            (3, 7, 4),
            (3, 7, 5),
            (1,),
            (1,),
            (1,),
            True,
        ),
        (
            (1, 3, 4, 4),
            (2, 3, 3, 3),
            (1, 1),
            (1, 2),
            (2, 1),
            None,
            True,
        ),
        (
            (32, 3, 128, 128),
            (64, 3, 5, 5),
            (2, 2),
            (2, 2),
            (1, 1),
            None,
            False,
        ),
        (
            (1, 1, 4, 4, 4),
            (1, 1, 5, 5, 5),
            (1, 1, 1),
            (1, 1, 1),
            (1, 1, 1),
            None,
            True,
        ),
        (
            (32, 3, 32, 32, 32),
            (64, 3, 5, 5, 5),
            (3, 2, 2),
            (4, 3, 3),
            (2, 2, 1),
            None,
            False,
        ),
    ]
    args = get_args()
Expand All @@ -266,6 +283,7 @@ def test_bang(lib, test_cases):
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        c_void_p,
        c_void_p,
        c_void_p,
@@ -280,6 +298,7 @@ def test_bang(lib, test_cases):
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    lib.infiniopDestroyConvDescriptor.restype = c_int32
    lib.infiniopDestroyConvDescriptor.argtypes = [