Softmax: refactoring for the CPU, MLU, and GPU platforms #129

Open · wants to merge 1 commit into base: dev
27 changes: 27 additions & 0 deletions include/ops/softmax/softmax.h
@@ -0,0 +1,27 @@
#ifndef SOFTMAX_H
#define SOFTMAX_H

#include "../../export.h"
#include "../../operators.h"

typedef struct SoftmaxDescriptor {
Device device;
} SoftmaxDescriptor;

typedef SoftmaxDescriptor *infiniopSoftmaxDescriptor_t;

__C __export infiniopStatus_t infiniopCreateSoftmaxDescriptor(infiniopHandle_t handle,
infiniopSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc);

__C __export infiniopStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, uint64_t *size);
__C __export infiniopStatus_t infiniopSoftmax(infiniopSoftmaxDescriptor_t desc, void *workspace,
uint64_t workspace_size,
void const *input,
void *output,
void *stream);

__C __export infiniopStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc);


#endif
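
The header above defines the usual four-call infiniop lifecycle: create a descriptor from the input/output tensor descriptors plus the softmax axis, query the workspace size, launch, then destroy. A minimal caller sketch follows, assuming `handle`, the tensor descriptors `x_desc`/`y_desc`, the device buffers `x`/`y`, a `stream`, and the `check`/`device_malloc` helpers already exist (all of those names are hypothetical, not part of this PR):

#include "ops/softmax/softmax.h"

// Sketch only: error handling reduced to a hypothetical check() helper.
infiniopSoftmaxDescriptor_t desc;
check(infiniopCreateSoftmaxDescriptor(handle, &desc, x_desc, /*axis=*/2, y_desc));

uint64_t workspace_size = 0;
check(infiniopGetSoftmaxWorkspaceSize(desc, &workspace_size));
void *workspace = device_malloc(workspace_size);  // hypothetical device allocator

check(infiniopSoftmax(desc, workspace, workspace_size, x, y, stream));
check(infiniopDestroySoftmaxDescriptor(desc));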
144 changes: 144 additions & 0 deletions operatorspy/tests/softmax.py
@@ -0,0 +1,144 @@
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
import ctypes
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
create_workspace,
check_error,
rearrange_tensor,
)

from operatorspy.tests.test_utils import get_args
import torch


class SoftmaxDescriptor(Structure):
_fields_ = [("device", c_int32)]


infiniopSoftmaxDescriptor_t = POINTER(SoftmaxDescriptor)


def softmax(x, axis):
return torch.softmax(x, dim=axis).to(x.dtype)


def test(lib, handle, torch_device, x_shape, axis, x_dtype=torch.float16):
print(
f"Testing Softmax on {torch_device} with x_shape:{x_shape} , axis:{axis} ,dtype:{x_dtype}"
)
x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
y = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
ans = softmax(x, axis)
x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)
descriptor = infiniopSoftmaxDescriptor_t()
check_error(
lib.infiniopCreateSoftmaxDescriptor(
handle, ctypes.byref(descriptor), x_tensor.descriptor, axis, y_tensor.descriptor
)
)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetSoftmaxWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = create_workspace(workspace_size.value, torch_device)
check_error(
lib.infiniopSoftmax(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
x_tensor.data,
y_tensor.data,
None,
)
)
err = (y - ans).abs().max()
print(f"max abs err: {err}")
assert torch.allclose(y, ans, atol=0, rtol=1e-2)
check_error(lib.infiniopDestroySoftmaxDescriptor(descriptor))


def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for x_shape, axis, x_dtype in test_cases:
test(lib, handle, "cpu", x_shape, axis, x_dtype)
destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for x_shape, axis, x_dtype in test_cases:
test(lib, handle, "cuda", x_shape, axis, x_dtype)
destroy_handle(lib, handle)


def test_bang(lib, test_cases):
import torch_mlu

device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for x_shape, axis, x_dtype in test_cases:
test(lib, handle, "mlu", x_shape, axis, x_dtype)
destroy_handle(lib, handle)


if __name__ == "__main__":
test_cases = [
# x_shape, axis
# The domestic CPUs paired with Cambricon chips may not support f16
((32, 20, 512), 0, torch.float16),
((32, 20, 512), 1, torch.float16),
((32, 20, 512), 2, torch.float16),

((32, 20, 512), 0, torch.float32),
((32, 20, 512), 1, torch.float32),
((32, 20, 512), 2, torch.float32),

]
args = get_args()
lib = open_lib()
lib.infiniopCreateSoftmaxDescriptor.restype = c_int32
lib.infiniopCreateSoftmaxDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopSoftmaxDescriptor_t),
infiniopTensorDescriptor_t,
c_int32,
infiniopTensorDescriptor_t,
]
lib.infiniopGetSoftmaxWorkspaceSize.restype = c_int32
lib.infiniopGetSoftmaxWorkspaceSize.argtypes = [
infiniopSoftmaxDescriptor_t,
POINTER(c_uint64),
]

lib.infiniopSoftmax.restype = c_int32
lib.infiniopSoftmax.argtypes = [
infiniopSoftmaxDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroySoftmaxDescriptor.restype = c_int32
lib.infiniopDestroySoftmaxDescriptor.argtypes = [
infiniopSoftmaxDescriptor_t,
]

if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if args.bang:
test_bang(lib, test_cases)

if not (args.cpu or args.cuda or args.bang):
test_cpu(lib, test_cases)
print("Test passed!")
61 changes: 61 additions & 0 deletions src/ops/softmax/bang/softmax_bang.cc
@@ -0,0 +1,61 @@
#include "softmax_bang.h"
#include "../../utils.h"

infiniopStatus_t bangCreateSoftmaxDescriptor(BangHandle_t handle,
SoftmaxBangDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc) {

if (input_desc->ndim != output_desc->ndim) {
return STATUS_BAD_TENSOR_SHAPE;
}
if (!dtype_eq(input_desc->dt, F16) && !dtype_eq(input_desc->dt, F32)) {
return STATUS_BAD_TENSOR_DTYPE;
}

int ndim = input_desc->ndim;

for (int i = 0; i < ndim; i++) {
if (input_desc->shape[i] != output_desc->shape[i]) {
return STATUS_BAD_TENSOR_SHAPE;
}
}

// Decompose the shape around the softmax axis:
//   stride:    product of the dims after `axis` (step between consecutive axis elements)
//   frontsize: product of the dims before `axis`
//   othersize: product of every dim except `axis` (equals frontsize * stride)
int stride = 1;
int dimsize = static_cast<int>(input_desc->shape[axis]);
int othersize = 1;
int frontsize = 1;

for (int s = ndim - 1; s >= 0; s--) {
if (s > axis) {
stride *= static_cast<int>(input_desc->shape[s]);
}
if (s < axis) {
frontsize *= static_cast<int>(input_desc->shape[s]);
}
if (s != axis) {
othersize *= static_cast<int>(input_desc->shape[s]);
}
}
*desc_ptr = new SoftmaxBangDescriptor{
handle->device,
handle->device_id,
input_desc->dt,
ndim,
axis,
dimsize,
stride,
othersize,
frontsize};

return STATUS_SUCCESS;
}
infiniopStatus_t bangGetSoftmaxWorkspaceSize(SoftmaxBangDescriptor_t desc, uint64_t *size) {
// taskDim * othersize * sizeof(T); taskDim is at most 32.
// e.g. f32, shape (32, 20, 512), axis = 2: othersize = 32 * 20 = 640, so 32 * 640 * 4 = 81920 bytes.
*size = 32 * desc->othersize * sizeof(desc->dtype);
return STATUS_SUCCESS;
}

infiniopStatus_t bangDestroySoftmaxDescriptor(SoftmaxBangDescriptor_t desc) {

delete desc;
return STATUS_SUCCESS;
}
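
To make the descriptor fields concrete: with `stride`, `dimsize`, and `frontsize` as computed above, the element with leading index f, axis index d, and trailing index s sits at flat offset f * dimsize * stride + d * stride + s. Below is a plain C++ reference of a numerically stable softmax over that layout, a sketch for float input only; it is not the BANG kernel (the .mlu source implementing bangSoftmax is not part of this diff):

#include <cmath>

// Reference softmax along `axis`, using the descriptor's decomposition:
// the tensor holds frontsize * dimsize * stride elements in total.
void softmax_ref(const float *x, float *y, int frontsize, int dimsize, int stride) {
    for (int f = 0; f < frontsize; f++) {
        for (int s = 0; s < stride; s++) {
            const int base = f * dimsize * stride + s;
            float max_val = x[base];
            for (int d = 1; d < dimsize; d++)  // pass 1: max, for numerical stability
                max_val = std::fmax(max_val, x[base + d * stride]);
            float sum = 0.0f;
            for (int d = 0; d < dimsize; d++) {  // pass 2: exponentiate and accumulate
                y[base + d * stride] = std::exp(x[base + d * stride] - max_val);
                sum += y[base + d * stride];
            }
            for (int d = 0; d < dimsize; d++)  // pass 3: normalize
                y[base + d * stride] /= sum;
        }
    }
}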
36 changes: 36 additions & 0 deletions src/ops/softmax/bang/softmax_bang.h
@@ -0,0 +1,36 @@
#ifndef __BANG_SOFTMAX_H__
#define __BANG_SOFTMAX_H__

#include "../../../devices/bang/bang_handle.h"
#include "../../utils.h"
#include "operators.h"

struct SoftmaxBangDescriptor {
Device device;
int device_id;
DT dtype;
int ndim;
int axis;
int dimsize;
int stride;
int othersize;
int frontsize;
};

typedef struct SoftmaxBangDescriptor *SoftmaxBangDescriptor_t;

infiniopStatus_t bangCreateSoftmaxDescriptor(BangHandle_t handle,
SoftmaxBangDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc);

infiniopStatus_t bangGetSoftmaxWorkspaceSize(SoftmaxBangDescriptor_t desc, uint64_t *size);
infiniopStatus_t bangSoftmax(SoftmaxBangDescriptor_t desc, void *workspace,
uint64_t workspace_size,
void const *input,
void *output,
void *stream);

infiniopStatus_t bangDestroySoftmaxDescriptor(SoftmaxBangDescriptor_t desc);


#endif