Skip to content

Commit 9f7c693

Browse files
Fixed python binding
1 parent f7d85d6 commit 9f7c693

File tree

8 files changed

+196
-17
lines changed

8 files changed

+196
-17
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/build*
22
__pycache__
33
.DS_Store
4-
.vscode
4+
.vscode
5+
logs

include/ion/c_ion.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ typedef struct ion_port_map_t_ *ion_port_map_t;
3535
int ion_port_create(ion_port_t *, const char *, ion_type_t, int);
3636
int ion_port_create_with_index(ion_port_t *, ion_port_t , int);
3737
int ion_port_destroy(ion_port_t);
38+
int ion_port_bind_i8(ion_port_t, int8_t*);
39+
int ion_port_bind_i16(ion_port_t, int16_t*);
40+
int ion_port_bind_i32(ion_port_t, int32_t*);
41+
int ion_port_bind_i64(ion_port_t, int64_t*);
42+
int ion_port_bind_u1(ion_port_t, bool*);
43+
int ion_port_bind_u8(ion_port_t, uint8_t*);
44+
int ion_port_bind_u16(ion_port_t, uint16_t*);
45+
int ion_port_bind_u32(ion_port_t, uint32_t*);
46+
int ion_port_bind_u64(ion_port_t, uint64_t*);
47+
int ion_port_bind_f32(ion_port_t, float*);
48+
int ion_port_bind_f64(ion_port_t, double*);
49+
int ion_port_bind_buffer(ion_port_t, ion_buffer_t);
50+
int ion_port_bind_buffer_array(ion_port_t, ion_buffer_t *, int);
3851

3952
int ion_param_create(ion_param_t *, const char *, const char *);
4053
int ion_param_destroy(ion_param_t);
@@ -57,24 +70,40 @@ int ion_builder_bb_metadata(ion_builder_t, char *, int, int *);
5770
int ion_builder_run(ion_builder_t, ion_port_map_t);
5871

5972
int ion_buffer_create(ion_buffer_t *, ion_type_t, int *, int);
73+
int ion_buffer_create_with_ptr(ion_buffer_t *, ion_type_t, void *, int *, int);
6074
int ion_buffer_destroy(ion_buffer_t);
6175
int ion_buffer_write(ion_buffer_t, void *, int size);
6276
int ion_buffer_read(ion_buffer_t, void *, int size);
6377

78+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
6479
int ion_port_map_create(ion_port_map_t *);
80+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
6581
int ion_port_map_destroy(ion_port_map_t);
82+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
6683
int ion_port_map_set_i8(ion_port_map_t, ion_port_t, int8_t);
84+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
6785
int ion_port_map_set_i16(ion_port_map_t, ion_port_t, int16_t);
86+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
6887
int ion_port_map_set_i32(ion_port_map_t, ion_port_t, int32_t);
88+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
6989
int ion_port_map_set_i64(ion_port_map_t, ion_port_t, int64_t);
90+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
7091
int ion_port_map_set_u1(ion_port_map_t, ion_port_t, bool);
92+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
7193
int ion_port_map_set_u8(ion_port_map_t, ion_port_t, uint8_t);
94+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
7295
int ion_port_map_set_u16(ion_port_map_t, ion_port_t, uint16_t);
96+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
7397
int ion_port_map_set_u32(ion_port_map_t, ion_port_t, uint32_t);
98+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
7499
int ion_port_map_set_u64(ion_port_map_t, ion_port_t, uint64_t);
100+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
75101
int ion_port_map_set_f32(ion_port_map_t, ion_port_t, float);
102+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
76103
int ion_port_map_set_f64(ion_port_map_t, ion_port_t, double);
104+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
77105
int ion_port_map_set_buffer(ion_port_map_t, ion_port_t, ion_buffer_t);
106+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
78107
int ion_port_map_set_buffer_array(ion_port_map_t, ion_port_t, ion_buffer_t *, int);
79108

80109
#if defined __cplusplus

python/ionpy/Node.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
ion_node_create,
99
ion_node_destroy,
1010
ion_node_get_port,
11-
ion_node_set_port,
11+
ion_node_set_iport,
1212
ion_node_set_param,
1313
)
1414
from .Type import Type
@@ -42,15 +42,15 @@ def get_port(self, name: str) -> Port:
4242

4343
return Port(obj_=c_port)
4444

45-
def set_port(self, ports: List[Port]) -> 'Node':
45+
def set_iport(self, ports: List[Port]) -> 'Node':
4646
num_ports = len(ports)
4747
c_ion_port_sized_array_t = c_ion_port_t * num_ports # arraysize == num_ports
4848
c_ports = c_ion_port_sized_array_t() # instance
4949

5050
for i in range(num_ports):
5151
c_ports[i] = ports[i].obj
5252

53-
ret = ion_node_set_port(self.obj, c_ports, num_ports)
53+
ret = ion_node_set_iport(self.obj, c_ports, num_ports)
5454
if ret != 0:
5555
raise Exception('Invalid operation')
5656

python/ionpy/native.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ class c_builder_compile_option_t(ctypes.Structure):
7777
ion_node_get_port.restype = ctypes.c_int
7878
ion_node_get_port.argtypes = [ c_ion_node_t, ctypes.c_char_p, ctypes.POINTER(c_ion_port_t) ]
7979

80-
# int ion_node_set_port(ion_node_t, ion_port_t *, int);
81-
ion_node_set_port = ion_core.ion_node_set_port
82-
ion_node_set_port.restype = ctypes.c_int
83-
ion_node_set_port.argtypes = [ c_ion_node_t, ctypes.POINTER(c_ion_port_t), ctypes.c_int ]
80+
# int ion_node_set_iport(ion_node_t, ion_port_t *, int);
81+
ion_node_set_iport = ion_core.ion_node_set_iport
82+
ion_node_set_iport.restype = ctypes.c_int
83+
ion_node_set_iport.argtypes = [ c_ion_node_t, ctypes.POINTER(c_ion_port_t), ctypes.c_int ]
8484

8585
# int ion_node_set_param(ion_node_t, ion_param_t *, int);
8686
ion_node_set_param = ion_core.ion_node_set_param

python/test/test_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_all():
1414
builder.with_bb_module(path='ion-bb-test')
1515
# builder.with_bb_module(path='ion-bb-test.dll') # for Windows
1616

17-
node = builder.add('test_inc_i32x2').set_port(ports=[ input_port, ]).set_param(params=[ value41, ])
17+
node = builder.add('test_inc_i32x2').set_iport(ports=[ input_port, ]).set_param(params=[ value41, ])
1818

1919
port_map = PortMap()
2020

python/test/test_node_port.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1-
from ionpy import Node, Port, Type, TypeCode
1+
from ionpy import Node, Builder, Port, Type, TypeCode
22

33

44
def test_node_port():
55
t = Type(code_=TypeCode.Int, bits_=32, lanes_=1)
66

7-
port_to_set = Port(name='iamkey', type=t, dim=3)
7+
port_to_set = Port(name='input', type=t, dim=2)
88

9-
ports = [ port_to_set, ]
9+
builder = Builder()
10+
builder.set_target(target='host')
11+
# make sure path includes libion-bb-test.so
12+
builder.with_bb_module(path='ion-bb-test')
13+
# builder.with_bb_module(path='ion-bb-test.dll') # for Windows
1014

11-
n = Node()
12-
n.set_port(ports)
15+
n = builder.add('test_inc_i32x2').set_iport(ports=[ port_to_set, ])
1316

14-
port_to_get = n.get_port('iamkey')
17+
port_to_get = n.get_port('input')
1518
print(f'from node.get_port: {port_to_get}')
19+

python/test/test_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def test_pipeline():
2020
builder.with_bb_module(path='ion-bb')
2121

2222
node = builder.add('image_io_cameraN').set_param(params=[width, height, urls])
23-
node1 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(name='output')[0], ]);
24-
node2 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(name='output')[1], ]);
23+
node1 = builder.add("base_normalize_3d_uint8").set_iport(ports=[node.get_port(name='output')[0], ]);
24+
node2 = builder.add("base_normalize_3d_uint8").set_iport(ports=[node.get_port(name='output')[1], ]);
2525

2626
port_map = PortMap()
2727

src/c_ion.cc

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,151 @@ int ion_port_destroy(ion_port_t obj)
8080
return 0;
8181
}
8282

83+
#define ION_PORT_BIND_IMPL(T, POSTFIX) \
84+
int ion_port_bind_##POSTFIX(ion_port_t obj, T v) { \
85+
try { \
86+
reinterpret_cast<Port*>(obj)->bind(v); \
87+
} catch (const Halide::Error& e) { \
88+
log::error(e.what()); \
89+
return 1; \
90+
} catch (const std::exception& e) { \
91+
log::error(e.what()); \
92+
return 1; \
93+
} catch (...) { \
94+
log::error("Unknown exception was happened"); \
95+
return 1; \
96+
} \
97+
\
98+
return 0; \
99+
}
100+
101+
ION_PORT_BIND_IMPL(int8_t*, i8)
102+
ION_PORT_BIND_IMPL(int16_t*, i16)
103+
ION_PORT_BIND_IMPL(int32_t*, i32)
104+
ION_PORT_BIND_IMPL(int64_t*, i64)
105+
ION_PORT_BIND_IMPL(bool*, u1)
106+
ION_PORT_BIND_IMPL(uint8_t*, u8)
107+
ION_PORT_BIND_IMPL(uint16_t*, u16)
108+
ION_PORT_BIND_IMPL(uint32_t*, u32)
109+
ION_PORT_BIND_IMPL(uint64_t*, u64)
110+
ION_PORT_BIND_IMPL(float*, f32)
111+
ION_PORT_BIND_IMPL(double*, f64)
112+
113+
#undef ION_PORT_BIND_IMPL
114+
115+
int ion_port_bind_buffer(ion_port_t obj, ion_buffer_t b)
116+
{
117+
try {
118+
// NOTE: Halide::Buffer class layout is safe to call Halide::Buffer<void>::type()
119+
auto type = reinterpret_cast<Halide::Buffer<void>*>(b)->type();
120+
if (type.is_int()) {
121+
if (type.bits() == 8) {
122+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<int8_t>*>(b));
123+
} else if (type.bits() == 16) {
124+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<int16_t>*>(b));
125+
} else if (type.bits() == 32) {
126+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<int32_t>*>(b));
127+
} else if (type.bits() == 64) {
128+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<int64_t>*>(b));
129+
} else {
130+
throw std::runtime_error("Unsupported bits number");
131+
}
132+
} else if (type.is_uint()) {
133+
if (type.bits() == 1) {
134+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<bool>*>(b));
135+
} else if (type.bits() == 8) {
136+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<uint8_t>*>(b));
137+
} else if (type.bits() == 16) {
138+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<uint16_t>*>(b));
139+
} else if (type.bits() == 32) {
140+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<uint32_t>*>(b));
141+
} else if (type.bits() == 64) {
142+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<uint64_t>*>(b));
143+
} else {
144+
throw std::runtime_error("Unsupported bits number");
145+
}
146+
} else if (type.is_float()) {
147+
if (type.bits() == 32) {
148+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<float>*>(b));
149+
} else if (type.bits() == 64) {
150+
reinterpret_cast<Port*>(obj)->bind(*reinterpret_cast<Halide::Buffer<double>*>(b));
151+
} else {
152+
throw std::runtime_error("Unsupported bits number");
153+
}
154+
} else {
155+
throw std::runtime_error("Unsupported type code");
156+
}
157+
} catch (const Halide::Error& e) {
158+
log::error(e.what());
159+
return 1;
160+
} catch (const std::exception& e) {
161+
log::error(e.what());
162+
return 1;
163+
} catch (...) {
164+
log::error("Unknown exception was happened");
165+
return 1;
166+
}
167+
168+
169+
return 0;
170+
}
171+
172+
int ion_port_bind_buffer_array(ion_port_t obj, ion_buffer_t *bs, int n)
173+
{
174+
try {
175+
// NOTE: Halide::Buffer class layout is safe to call Halide::Buffer<void>::type()
176+
auto type = reinterpret_cast<Halide::Buffer<void>*>(*bs)->type();
177+
if (type.is_int()) {
178+
if (type.bits() == 8) {
179+
reinterpret_cast<Port*>(obj)->bind(convert<int8_t>(bs, n));
180+
} else if (type.bits() == 16) {
181+
reinterpret_cast<Port*>(obj)->bind(convert<int16_t>(bs, n));
182+
} else if (type.bits() == 32) {
183+
reinterpret_cast<Port*>(obj)->bind(convert<int32_t>(bs, n));
184+
} else if (type.bits() == 64) {
185+
reinterpret_cast<Port*>(obj)->bind(convert<int64_t>(bs, n));
186+
} else {
187+
throw std::runtime_error("Unsupported bits number");
188+
}
189+
} else if (type.is_uint()) {
190+
if (type.bits() == 1) {
191+
reinterpret_cast<Port*>(obj)->bind(convert<bool>(bs, n));
192+
} else if (type.bits() == 8) {
193+
reinterpret_cast<Port*>(obj)->bind(convert<uint8_t>(bs, n));
194+
} else if (type.bits() == 16) {
195+
reinterpret_cast<Port*>(obj)->bind(convert<uint16_t>(bs, n));
196+
} else if (type.bits() == 32) {
197+
reinterpret_cast<Port*>(obj)->bind(convert<uint32_t>(bs, n));
198+
} else if (type.bits() == 64) {
199+
reinterpret_cast<Port*>(obj)->bind(convert<uint64_t>(bs, n));
200+
} else {
201+
throw std::runtime_error("Unsupported bits number");
202+
}
203+
} else if (type.is_float()) {
204+
if (type.bits() == 32) {
205+
reinterpret_cast<Port*>(obj)->bind(convert<float>(bs, n));
206+
} else if (type.bits() == 64) {
207+
reinterpret_cast<Port*>(obj)->bind(convert<double>(bs, n));
208+
} else {
209+
throw std::runtime_error("Unsupported bits number");
210+
}
211+
} else {
212+
throw std::runtime_error("Unsupported type code");
213+
}
214+
} catch (const Halide::Error& e) {
215+
log::error(e.what());
216+
return 1;
217+
} catch (const std::exception& e) {
218+
log::error(e.what());
219+
return 1;
220+
} catch (...) {
221+
log::error("Unknown exception was happened");
222+
return 1;
223+
}
224+
225+
226+
return 0;
227+
}
83228
//
84229
// ion_param_t
85230
//

0 commit comments

Comments
 (0)