Skip to content

Commit 7a9f0a5

Browse files
committed
Clip, reduce, gather, where Cpu operators
1 parent d7bda2a commit 7a9f0a5

34 files changed

+2362
-3
lines changed

env.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
export INFINI_ROOT=/home/zzw/workspace/operators-dev/build/linux/x86_64/release

include/infini_operators.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,10 @@
1616
#include "ops/rms_norm/rms_norm.h"
1717
#include "ops/rotary_embedding/rotary_embedding.h"
1818
#include "ops/swiglu/swiglu.h"
19+
#include "ops/reducemax/reducemax.h"
20+
#include "ops/reducemean/reducemean.h"
21+
#include "ops/reducemin/reducemin.h"
22+
#include "ops/clip/clip.h"
23+
#include "ops/where/where.h"
24+
#include "ops/gather/gather.h"
1925
#include "tensor/tensor_descriptor.h"

include/ops/clip/clip.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef CLIP_H
#define CLIP_H

#include "../../export.h"
#include "../../operators.h"

/* Public descriptor for the Clip operator; it exposes only the device. */
struct ClipDescriptor {
    Device device;
};
typedef struct ClipDescriptor ClipDescriptor;
typedef ClipDescriptor *infiniopClipDescriptor_t;

/*
 * Creates a Clip descriptor.
 * NOTE(review): desc_ptr presumably receives the new descriptor, and x / y
 * look like the input / output tensor descriptors - confirm against the
 * implementation.
 */
__C __export infiniopStatus_t infiniopCreateClipDescriptor(
    infiniopHandle_t handle,
    infiniopClipDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t x,
    infiniopTensorDescriptor_t y);

/*
 * Runs the Clip operator.
 * NOTE(review): min and max are passed by pointer; the expected element type
 * and whether NULL is accepted are not visible in this header - confirm.
 */
__C __export infiniopStatus_t infiniopClip(
    infiniopClipDescriptor_t desc,
    const void *x,
    void *min,
    void *max,
    void *y,
    void *stream);

/* Destroys a descriptor obtained from infiniopCreateClipDescriptor. */
__C __export infiniopStatus_t infiniopDestroyClipDescriptor(
    infiniopClipDescriptor_t desc);

#endif /* CLIP_H */

include/ops/gather/gather.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef GATHER_H
#define GATHER_H

#include "../../export.h"
#include "../../operators.h"

/* Public descriptor for the Gather operator; it exposes only the device. */
typedef struct GatherDescriptor {
    Device device;
} GatherDescriptor;
typedef GatherDescriptor *infiniopGatherDescriptor_t;

/*
 * Creates a Gather descriptor.
 * NOTE(review): desc_ptr presumably receives the new descriptor; y, x and
 * indices look like the output, input and index tensor descriptors, and axis
 * the dimension gathered along - confirm against the implementation.
 */
__C __export infiniopStatus_t infiniopCreateGatherDescriptor(
    infiniopHandle_t handle,
    infiniopGatherDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y,
    infiniopTensorDescriptor_t x,
    infiniopTensorDescriptor_t indices,
    int64_t axis);

/* Runs the Gather operator on the given data pointers. */
__C __export infiniopStatus_t infiniopGather(
    infiniopGatherDescriptor_t desc,
    void const *x,
    void *indices,
    void *y,
    void *stream);

/* Destroys a descriptor obtained from infiniopCreateGatherDescriptor. */
__C __export infiniopStatus_t infiniopDestroyGatherDescriptor(
    infiniopGatherDescriptor_t desc);

#endif /* GATHER_H */

include/ops/reducemax/reducemax.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef REDUCEMAX_H
#define REDUCEMAX_H

#include "../../export.h"
#include "../../operators.h"

/* Public descriptor for the Reducemax operator; it exposes only the device. */
struct ReducemaxDescriptor {
    Device device;
};
typedef struct ReducemaxDescriptor ReducemaxDescriptor;
typedef ReducemaxDescriptor *infiniopReducemaxDescriptor_t;

/*
 * Creates a Reducemax descriptor.
 * NOTE(review): n is presumably the number of entries in axes, and
 * keepdims / noop_with_empty_axes look like the ONNX reduce attributes of
 * the same names - confirm against the implementation.
 */
__C __export infiniopStatus_t infiniopCreateReducemaxDescriptor(
    infiniopHandle_t handle,
    infiniopReducemaxDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y,
    infiniopTensorDescriptor_t x,
    const int64_t *axes,
    uint64_t n,
    bool keepdims,
    bool noop_with_empty_axes);

/*
 * Runs the Reducemax operator.
 * NOTE(review): the element type of dynamic_axes and the unit of
 * dynamic_axes_size are not visible in this header - confirm.
 */
__C __export infiniopStatus_t infiniopReducemax(
    infiniopReducemaxDescriptor_t desc,
    void *y,
    const void *x,
    const void *dynamic_axes,
    uint64_t dynamic_axes_size,
    void *stream);

/* Destroys a descriptor obtained from infiniopCreateReducemaxDescriptor. */
__C __export infiniopStatus_t infiniopDestroyReducemaxDescriptor(
    infiniopReducemaxDescriptor_t desc);

#endif /* REDUCEMAX_H */

include/ops/reducemean/reducemean.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef REDUCEMEAN_H
#define REDUCEMEAN_H

#include "../../export.h"
#include "../../operators.h"

/* Public descriptor for the Reducemean operator; it exposes only the device. */
struct ReducemeanDescriptor {
    Device device;
};
typedef struct ReducemeanDescriptor ReducemeanDescriptor;
typedef ReducemeanDescriptor *infiniopReducemeanDescriptor_t;

/*
 * Creates a Reducemean descriptor.
 * NOTE(review): n is presumably the number of entries in axes, and
 * keepdims / noop_with_empty_axes look like the ONNX reduce attributes of
 * the same names - confirm against the implementation.
 */
__C __export infiniopStatus_t infiniopCreateReducemeanDescriptor(
    infiniopHandle_t handle,
    infiniopReducemeanDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y,
    infiniopTensorDescriptor_t x,
    const int64_t *axes,
    uint64_t n,
    bool keepdims,
    bool noop_with_empty_axes);

/*
 * Runs the Reducemean operator.
 * NOTE(review): the element type of dynamic_axes and the unit of
 * dynamic_axes_size are not visible in this header - confirm.
 */
__C __export infiniopStatus_t infiniopReducemean(
    infiniopReducemeanDescriptor_t desc,
    void *dst,
    const void *src,
    const void *dynamic_axes,
    uint64_t dynamic_axes_size,
    void *stream);

/* Destroys a descriptor obtained from infiniopCreateReducemeanDescriptor. */
__C __export infiniopStatus_t infiniopDestroyReducemeanDescriptor(
    infiniopReducemeanDescriptor_t desc);

#endif /* REDUCEMEAN_H */

include/ops/reducemin/reducemin.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef REDUCEMIN_H
#define REDUCEMIN_H

#include "../../export.h"
#include "../../operators.h"

/* Public descriptor for the Reducemin operator; it exposes only the device. */
struct ReduceminDescriptor {
    Device device;
};
typedef struct ReduceminDescriptor ReduceminDescriptor;
typedef ReduceminDescriptor *infiniopReduceminDescriptor_t;

/*
 * Creates a Reducemin descriptor.
 * NOTE(review): n is presumably the number of entries in axes, and
 * keepdims / noop_with_empty_axes look like the ONNX reduce attributes of
 * the same names - confirm against the implementation.
 */
__C __export infiniopStatus_t infiniopCreateReduceminDescriptor(
    infiniopHandle_t handle,
    infiniopReduceminDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t dst,
    infiniopTensorDescriptor_t src,
    const int64_t *axes,
    uint64_t n,
    bool keepdims,
    bool noop_with_empty_axes);

/*
 * Runs the Reducemin operator.
 * NOTE(review): the element type of dynamic_axes and the unit of
 * dynamic_axes_size are not visible in this header - confirm.
 */
__C __export infiniopStatus_t infiniopReducemin(
    infiniopReduceminDescriptor_t desc,
    void *dst,
    const void *src,
    const void *dynamic_axes,
    uint64_t dynamic_axes_size,
    void *stream);

/* Destroys a descriptor obtained from infiniopCreateReduceminDescriptor. */
__C __export infiniopStatus_t infiniopDestroyReduceminDescriptor(
    infiniopReduceminDescriptor_t desc);

#endif /* REDUCEMIN_H */

include/ops/where/where.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef WHERE_H
#define WHERE_H

#include "../../export.h"
#include "../../operators.h"

/* Public descriptor for the Where operator; it exposes only the device. */
struct WhereDescriptor {
    Device device;
};
typedef struct WhereDescriptor WhereDescriptor;
typedef WhereDescriptor *infiniopWhereDescriptor_t;

/*
 * Creates a Where descriptor.
 * NOTE(review): desc_ptr presumably receives the new descriptor; which of
 * src1 / src2 is selected where condition is true is not visible in this
 * header - confirm against the implementation.
 */
__C __export infiniopStatus_t infiniopCreateWhereDescriptor(
    infiniopHandle_t handle,
    infiniopWhereDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t dst,
    infiniopTensorDescriptor_t src1,
    infiniopTensorDescriptor_t src2,
    infiniopTensorDescriptor_t condition);

/* Runs the Where operator on the given data pointers. */
__C __export infiniopStatus_t infiniopWhere(
    infiniopWhereDescriptor_t desc,
    void *dst,
    void *src1,
    void *src2,
    void *condition,
    void *stream);

/* Destroys a descriptor obtained from infiniopCreateWhereDescriptor. */
__C __export infiniopStatus_t infiniopDestroyWhereDescriptor(
    infiniopWhereDescriptor_t desc);

#endif /* WHERE_H */

operatorspy/liboperators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class Handle(Structure):
4545
def open_lib():
4646
def find_library_in_ld_path(library_name):
4747
ld_library_path = LIB_OPERATORS_DIR
48+
49+
print(LIB_OPERATORS_DIR)
4850
paths = ld_library_path.split(os.pathsep)
4951
for path in paths:
5052
full_path = os.path.join(path, library_name)

operatorspy/tests/clip.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_bool
2+
import ctypes
3+
import sys
4+
import os
5+
import time
6+
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
8+
from operatorspy import (
9+
open_lib,
10+
to_tensor,
11+
DeviceEnum,
12+
infiniopHandle_t,
13+
infiniopTensorDescriptor_t,
14+
create_handle,
15+
destroy_handle,
16+
check_error,
17+
)
18+
19+
from operatorspy.tests.test_utils import get_args
20+
import torch
21+
from typing import Tuple
22+
import numpy as np
23+
24+
# Profiling switches: when PROFILE is true, each implementation is warmed up
# NUM_PRERUN times and then timed over NUM_ITERATIONS calls.
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class ClipDescriptor(Structure):
    """ctypes view of the C ``ClipDescriptor`` struct (a single device field)."""

    _fields_ = [
        ("device", c_int32),
    ]


# Pointer type matching the C typedef ``infiniopClipDescriptor_t``.
infiniopClipDescriptor_t = POINTER(ClipDescriptor)
32+
33+
def clip(input, min, max):
    """Reference implementation: clamp ``input`` to ``[min, max]`` with PyTorch."""
    clamped = torch.clamp(input, min=min, max=max)
    return clamped
35+
36+
37+
def tuple_to_void_p(py_tuple: Tuple):
    """Copy ``py_tuple`` into a C ``int64`` array and return it as a ``c_void_p``."""
    int64_array_type = ctypes.c_int64 * len(py_tuple)
    return ctypes.cast(int64_array_type(*py_tuple), ctypes.c_void_p)
41+
42+
def test(
    lib,
    handle,
    torch_device,
    x_shape,
    min,
    max,
    tensor_dtype=torch.float16
):
    """Run one Clip case through PyTorch and the operator library and compare.

    Args:
        lib: Loaded operators library whose ``infiniop*Clip*`` entry points
            are called.
        handle: Library handle passed to ``infiniopCreateClipDescriptor``.
        torch_device: Torch device string used to allocate the tensors.
        x_shape: Shape of the random input tensor.
        min: Lower bound, or ``None`` for no lower bound (``-inf`` is used).
        max: Upper bound, or ``None`` for no upper bound (``+inf`` is used).
        tensor_dtype: Dtype of the input and output tensors.

    Raises:
        AssertionError: If the library output is not exactly equal to the
            PyTorch reference.
    """
    print(
        f"Testing clip on {torch_device} with x_shape:{x_shape} dtype:{tensor_dtype} max:{max} min:{min}"
    )
    x = torch.randn(x_shape, dtype=tensor_dtype, device=torch_device)
    output = torch.randn(x_shape, dtype=tensor_dtype, device=torch_device)

    # A missing bound becomes +/-inf, so both bounds are always float32
    # scalar tensors and the library always receives valid pointers.
    lower = torch.tensor(
        float("-inf") if min is None else min,
        dtype=torch.float32,
        device=torch_device,
    )
    upper = torch.tensor(
        float("inf") if max is None else max,
        dtype=torch.float32,
        device=torch_device,
    )

    warmup_runs = NUM_PRERUN if PROFILE else 1

    # PyTorch reference result (and optional timing).
    for _ in range(warmup_runs):
        ans = clip(x, lower, upper)
    if PROFILE:
        start_time = time.time()
        for _ in range(NUM_ITERATIONS):
            _ = clip(x, lower, upper)
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f"pytorch time: {elapsed :10f}")

    x_tensor = to_tensor(x, lib)
    y_tensor = to_tensor(output, lib)
    descriptor = infiniopClipDescriptor_t()
    check_error(
        lib.infiniopCreateClipDescriptor(
            handle,
            ctypes.byref(descriptor),
            x_tensor.descriptor,
            y_tensor.descriptor,
        )
    )

    def run_lib_clip():
        # One library invocation; a None stream is passed as a NULL pointer.
        check_error(
            lib.infiniopClip(
                descriptor,
                x_tensor.data,
                lower.data_ptr(),
                upper.data_ptr(),
                y_tensor.data,
                None,
            )
        )

    try:
        # NOTE(review): presumably checks that the operator no longer depends
        # on the tensor descriptors once its own descriptor exists - confirm.
        x_tensor.descriptor.contents.invalidate()
        y_tensor.descriptor.contents.invalidate()

        for _ in range(warmup_runs):
            run_lib_clip()
        if PROFILE:
            start_time = time.time()
            for _ in range(NUM_ITERATIONS):
                run_lib_clip()
            elapsed = (time.time() - start_time) / NUM_ITERATIONS
            print(f"lib time: {elapsed :10f}")

        print("x:", x)
        print("custom op ans:", output)
        print("ans:", ans)
        # Zero tolerances: the library result must match the reference exactly.
        assert torch.allclose(output, ans, atol=0, rtol=0)
    finally:
        # Release the descriptor even when the call or the comparison fails.
        check_error(lib.infiniopDestroyClipDescriptor(descriptor))
118+
119+
def test_cpu(lib, test_cases):
    """Run every Clip test case on the CPU device.

    Args:
        lib: Loaded operators library.
        test_cases: Iterable of ``(x_shape, min, max)`` tuples; ``None`` for
            a bound means that bound is absent.
    """
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    try:
        for x_shape, lower, upper in test_cases:
            test(lib, handle, "cpu", x_shape, lower, upper, tensor_dtype=torch.float16)
            print("\n")
    finally:
        # Release the handle even if one of the cases raises.
        destroy_handle(lib, handle)
127+
128+
129+
if __name__ == "__main__":
    # Each case is (x_shape, min, max); None means that bound is absent.
    # TODO: no strided-input cases yet.
    test_cases = [
        ((3, 4), -1, 1),
        ((3, 4), None, 1),
        ((3, 4), -1, None),
        ((3, 4), None, None),
    ]
    # Parsed so the script accepts the shared test flags; only the CPU path
    # is exercised below.
    args = get_args()
    lib = open_lib()

    # Signatures mirror include/ops/clip/clip.h:
    #   infiniopCreateClipDescriptor(handle, desc_ptr, x, y)
    lib.infiniopCreateClipDescriptor.restype = c_int32
    lib.infiniopCreateClipDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopClipDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]
    #   infiniopClip(desc, x, min, max, y, stream)
    lib.infiniopClip.restype = c_int32
    lib.infiniopClip.argtypes = [
        infiniopClipDescriptor_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    lib.infiniopDestroyClipDescriptor.restype = c_int32
    lib.infiniopDestroyClipDescriptor.argtypes = [infiniopClipDescriptor_t]

    test_cpu(lib, test_cases)
    print("All tests passed!")

0 commit comments

Comments
 (0)