1
+ #include < cstring>
1
2
#include < mlc/core/all.h>
2
3
3
4
namespace mlc {
@@ -7,6 +8,7 @@ namespace {
7
8
8
9
MLC_REGISTER_FUNC (" mlc.testing.cxx_none" ).set_body([]() -> void { return ; });
9
10
MLC_REGISTER_FUNC (" mlc.testing.cxx_null" ).set_body([]() -> void * { return nullptr ; });
11
+ MLC_REGISTER_FUNC (" mlc.testing.cxx_bool" ).set_body([](bool x) -> bool { return x; });
10
12
MLC_REGISTER_FUNC (" mlc.testing.cxx_int" ).set_body([](int x) -> int { return x; });
11
13
MLC_REGISTER_FUNC (" mlc.testing.cxx_float" ).set_body([](double x) -> double { return x; });
12
14
MLC_REGISTER_FUNC (" mlc.testing.cxx_ptr" ).set_body([](void *x) -> void * { return x; });
@@ -17,6 +19,7 @@ MLC_REGISTER_FUNC("mlc.testing.cxx_raw_str").set_body([](const char *x) { return
17
19
/* *************** Reflection ****************/
18
20
19
21
struct TestingCClassObj : public Object {
22
+ bool bool_;
20
23
int8_t i8 ;
21
24
int16_t i16 ;
22
25
int32_t i32 ;
@@ -40,6 +43,7 @@ struct TestingCClassObj : public Object {
40
43
Dict<Any, Str> dict_any_str;
41
44
Dict<Str, List<int >> dict_str_list_int;
42
45
46
+ Optional<bool > opt_bool;
43
47
Optional<int64_t > opt_i64;
44
48
Optional<double > opt_f64;
45
49
Optional<void *> opt_raw_ptr;
@@ -57,22 +61,22 @@ struct TestingCClassObj : public Object {
57
61
Optional<Dict<Any, Str>> opt_dict_any_str;
58
62
Optional<Dict<Str, List<int >>> opt_dict_str_list_int;
59
63
60
- explicit TestingCClassObj (int8_t i8 , int16_t i16 , int32_t i32 , int64_t i64 , float f32 , double f64 , void *raw_ptr ,
61
- DLDataType dtype, DLDevice device, Any any, Func func, UList ulist, UDict udict, Str str_ ,
62
- Str str_readonly, List<Any> list_any, List<List<int >> list_list_int,
64
+ explicit TestingCClassObj (bool bool_, int8_t i8 , int16_t i16 , int32_t i32 , int64_t i64 , float f32 , double f64 ,
65
+ void *raw_ptr, DLDataType dtype, DLDevice device, Any any, Func func, UList ulist,
66
+ UDict udict, Str str_, Str str_readonly, List<Any> list_any, List<List<int >> list_list_int,
63
67
Dict<Any, Any> dict_any_any, Dict<Str, Any> dict_str_any, Dict<Any, Str> dict_any_str,
64
- Dict<Str, List<int >> dict_str_list_int, Optional<int64_t > opt_i64 , Optional<double > opt_f64 ,
65
- Optional<void *> opt_raw_ptr , Optional<DLDataType> opt_dtype , Optional<DLDevice> opt_device ,
66
- Optional<Func> opt_func , Optional<UList> opt_ulist , Optional<UDict> opt_udict ,
67
- Optional<Str> opt_str, Optional<List<Any>> opt_list_any,
68
+ Dict<Str, List<int >> dict_str_list_int, Optional<bool > opt_bool , Optional<int64_t > opt_i64 ,
69
+ Optional<double > opt_f64 , Optional<void *> opt_raw_ptr , Optional<DLDataType> opt_dtype ,
70
+ Optional<DLDevice> opt_device , Optional<Func> opt_func , Optional<UList> opt_ulist ,
71
+ Optional<UDict> opt_udict, Optional< Str> opt_str, Optional<List<Any>> opt_list_any,
68
72
Optional<List<List<int >>> opt_list_list_int, Optional<Dict<Any, Any>> opt_dict_any_any,
69
73
Optional<Dict<Str, Any>> opt_dict_str_any, Optional<Dict<Any, Str>> opt_dict_any_str,
70
74
Optional<Dict<Str, List<int >>> opt_dict_str_list_int)
71
- : i8(i8 ), i16(i16 ), i32(i32 ), i64(i64 ), f32(f32 ), f64(f64 ), raw_ptr(raw_ptr), dtype(dtype), device(device ),
72
- any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any ),
73
- list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any),
74
- dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_i64(opt_i64 ), opt_f64(opt_f64 ),
75
- opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func),
75
+ : bool_(bool_), i8(i8 ), i16(i16 ), i32(i32 ), i64(i64 ), f32(f32 ), f64(f64 ), raw_ptr(raw_ptr), dtype(dtype),
76
+ device(device), any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly),
77
+ list_any(list_any), list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any),
78
+ dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_bool(opt_bool ), opt_i64(opt_i64 ),
79
+ opt_f64(opt_f64), opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func),
76
80
opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any),
77
81
opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any),
78
82
opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {}
@@ -84,6 +88,7 @@ struct TestingCClassObj : public Object {
84
88
85
89
struct TestingCClass : public ObjectRef {
86
90
MLC_DEF_OBJ_REF (MLC_EXPORTS, TestingCClass, TestingCClassObj, ObjectRef)
91
+ .Field(" bool_" , &TestingCClassObj::bool_)
87
92
.Field(" i8" , &TestingCClassObj::i8 )
88
93
.Field(" i16" , &TestingCClassObj::i16 )
89
94
.Field(" i32" , &TestingCClassObj::i32 )
@@ -105,6 +110,7 @@ struct TestingCClass : public ObjectRef {
105
110
.Field(" dict_str_any" , &TestingCClassObj::dict_str_any)
106
111
.Field(" dict_any_str" , &TestingCClassObj::dict_any_str)
107
112
.Field(" dict_str_list_int" , &TestingCClassObj::dict_str_list_int)
113
+ .Field(" opt_bool" , &TestingCClassObj::opt_bool)
108
114
.Field(" opt_i64" , &TestingCClassObj::opt_i64)
109
115
.Field(" opt_f64" , &TestingCClassObj::opt_f64)
110
116
.Field(" opt_raw_ptr" , &TestingCClassObj::opt_raw_ptr)
@@ -121,13 +127,13 @@ struct TestingCClass : public ObjectRef {
121
127
.Field(" opt_dict_any_str" , &TestingCClassObj::opt_dict_any_str)
122
128
.Field(" opt_dict_str_list_int" , &TestingCClassObj::opt_dict_str_list_int)
123
129
.MemFn(" i64_plus_one" , &TestingCClassObj::i64_plus_one)
124
- .StaticFn(" __init__" ,
125
- InitOf<TestingCClassObj, int8_t , int16_t , int32_t , int64_t , float , double , void *, DLDataType, DLDevice ,
126
- Any, Func, UList, UDict, Str, Str, List <Any>, List<List< int >> , Dict<Any, Any >, Dict<Str, Any >,
127
- Dict<Any, Str>, Dict<Str, List< int >>, Optional<int64_t >, Optional<double >, Optional<void *>,
128
- Optional<DLDataType>, Optional<DLDevice>, Optional<Func>, Optional<UList>, Optional<UDict>,
129
- Optional<Str>, Optional<List<Any>>, Optional<List<List<int >>>, Optional<Dict<Any, Any>>,
130
- Optional<Dict<Str, Any>>, Optional<Dict<Any, Str>>, Optional<Dict<Str, List<int >>>>);
130
+ .StaticFn(" __init__" , InitOf<TestingCClassObj, bool , int8_t , int16_t , int32_t , int64_t , float , double , void *,
131
+ DLDataType, DLDevice, Any, Func, UList, UDict, Str, Str, List<Any>, List<List< int >> ,
132
+ Dict <Any, Any >, Dict<Str, Any> , Dict<Any, Str >, Dict<Str, List< int >>, Optional< bool >,
133
+ Optional<int64_t >, Optional<double >, Optional<void *>, Optional<DLDataType >,
134
+ Optional<DLDevice>, Optional<Func>, Optional<UList>, Optional<UDict>, Optional<Str >,
135
+ Optional<List<Any>>, Optional<List<List<int >>>, Optional<Dict<Any, Any>>,
136
+ Optional<Dict<Str, Any>>, Optional<Dict<Any, Str>>, Optional<Dict<Str, List<int >>>>);
131
137
};
132
138
133
139
/* *************** Traceback ****************/
@@ -191,5 +197,140 @@ MLC_REGISTER_FUNC("mlc.testing.nested_type_checking_list").set_body([](Str name)
191
197
MLC_UNREACHABLE ();
192
198
});
193
199
200
+ /* *************** Visitor ****************/
201
+
202
+ MLC_REGISTER_FUNC (" mlc.testing.VisitFields" ).set_body([](ObjectRef root) {
203
+ struct Visitor {
204
+ void operator ()(MLCTypeField *f, const Any *any) { Push (" Any" , f->name , *any); }
205
+ void operator ()(MLCTypeField *f, ObjectRef *obj) { Push (" ObjectRef" , f->name , *obj); }
206
+ void operator ()(MLCTypeField *f, Optional<ObjectRef> *opt) { Push (" Optional<ObjectRef>" , f->name , *opt); }
207
+ void operator ()(MLCTypeField *f, Optional<bool > *opt) { Push (" Optional<bool>" , f->name , *opt); }
208
+ void operator ()(MLCTypeField *f, Optional<int64_t > *opt) { Push (" Optional<int64_t>" , f->name , *opt); }
209
+ void operator ()(MLCTypeField *f, Optional<double > *opt) { Push (" Optional<double>" , f->name , *opt); }
210
+ void operator ()(MLCTypeField *f, Optional<DLDevice> *opt) { Push (" Optional<DLDevice>" , f->name , *opt); }
211
+ void operator ()(MLCTypeField *f, Optional<DLDataType> *opt) { Push (" Optional<DLDataType>" , f->name , *opt); }
212
+ void operator ()(MLCTypeField *f, bool *v) { Push (" bool" , f->name , *v); }
213
+ void operator ()(MLCTypeField *f, int8_t *v) { Push (" int8_t" , f->name , *v); }
214
+ void operator ()(MLCTypeField *f, int16_t *v) { Push (" int16_t" , f->name , *v); }
215
+ void operator ()(MLCTypeField *f, int32_t *v) { Push (" int32_t" , f->name , *v); }
216
+ void operator ()(MLCTypeField *f, int64_t *v) { Push (" int64_t" , f->name , *v); }
217
+ void operator ()(MLCTypeField *f, float *v) { Push (" float" , f->name , *v); }
218
+ void operator ()(MLCTypeField *f, double *v) { Push (" double" , f->name , *v); }
219
+ void operator ()(MLCTypeField *f, DLDataType *v) { Push (" DLDataType" , f->name , *v); }
220
+ void operator ()(MLCTypeField *f, DLDevice *v) { Push (" DLDevice" , f->name , *v); }
221
+ void operator ()(MLCTypeField *f, Optional<void *> *v) { Push (" Optional<void *>" , f->name , *v); }
222
+ void operator ()(MLCTypeField *f, void **v) { Push (" void *" , f->name , *v); }
223
+ void operator ()(MLCTypeField *f, const char **v) { Push (" const char *" , f->name , *v); }
224
+
225
+ void Push (const char *ty, const char *name, Any value) {
226
+ types->push_back (ty);
227
+ names->push_back (name);
228
+ values->push_back (value);
229
+ }
230
+ List<Str> *types;
231
+ List<Str> *names;
232
+ UList *values;
233
+ };
234
+ List<Str> types;
235
+ List<Str> names;
236
+ UList values;
237
+ MLCTypeInfo *info = ::mlc::Lib::GetTypeInfo (root.GetTypeIndex ());
238
+ ::mlc::core::VisitFields (root.get(), info, Visitor{&types, &names, &values});
239
+ return UList{types, names, values};
240
+ });
241
+
242
+ struct FieldFoundException : public ::std::exception {};
243
+
244
+ struct FieldGetter {
245
+ void operator ()(MLCTypeField *f, const Any *any) { Check (f->name , any); }
246
+ void operator ()(MLCTypeField *f, ObjectRef *obj) { Check (f->name , obj); }
247
+ void operator ()(MLCTypeField *f, Optional<ObjectRef> *opt) { Check (f->name , opt); }
248
+ void operator ()(MLCTypeField *f, Optional<bool > *opt) { Check (f->name , opt); }
249
+ void operator ()(MLCTypeField *f, Optional<int64_t > *opt) { Check (f->name , opt); }
250
+ void operator ()(MLCTypeField *f, Optional<double > *opt) { Check (f->name , opt); }
251
+ void operator ()(MLCTypeField *f, Optional<DLDevice> *opt) { Check (f->name , opt); }
252
+ void operator ()(MLCTypeField *f, Optional<DLDataType> *opt) { Check (f->name , opt); }
253
+ void operator ()(MLCTypeField *f, bool *v) { Check (f->name , v); }
254
+ void operator ()(MLCTypeField *f, int8_t *v) { Check (f->name , v); }
255
+ void operator ()(MLCTypeField *f, int16_t *v) { Check (f->name , v); }
256
+ void operator ()(MLCTypeField *f, int32_t *v) { Check (f->name , v); }
257
+ void operator ()(MLCTypeField *f, int64_t *v) { Check (f->name , v); }
258
+ void operator ()(MLCTypeField *f, float *v) { Check (f->name , v); }
259
+ void operator ()(MLCTypeField *f, double *v) { Check (f->name , v); }
260
+ void operator ()(MLCTypeField *f, DLDataType *v) { Check (f->name , v); }
261
+ void operator ()(MLCTypeField *f, DLDevice *v) { Check (f->name , v); }
262
+ void operator ()(MLCTypeField *f, Optional<void *> *v) { Check (f->name , v); }
263
+ void operator ()(MLCTypeField *f, void **v) { Check (f->name , v); }
264
+ void operator ()(MLCTypeField *f, const char **v) { Check (f->name , v); }
265
+
266
+ template <typename T> void Check (const char *name, T *v) {
267
+ if (std::strcmp (name, target_name) == 0 ) {
268
+ *ret = Any (*v);
269
+ throw FieldFoundException ();
270
+ }
271
+ }
272
+ const char *target_name;
273
+ Any *ret;
274
+ };
275
+
276
+ struct FieldSetter {
277
+ void operator ()(MLCTypeField *f, Any *any) { Check (f->name , any); }
278
+ void operator ()(MLCTypeField *f, ObjectRef *obj) { Check (f->name , obj); }
279
+ void operator ()(MLCTypeField *f, Optional<ObjectRef> *opt) { Check (f->name , opt); }
280
+ void operator ()(MLCTypeField *f, Optional<bool > *opt) { Check (f->name , opt); }
281
+ void operator ()(MLCTypeField *f, Optional<int64_t > *opt) { Check (f->name , opt); }
282
+ void operator ()(MLCTypeField *f, Optional<double > *opt) { Check (f->name , opt); }
283
+ void operator ()(MLCTypeField *f, Optional<DLDevice> *opt) { Check (f->name , opt); }
284
+ void operator ()(MLCTypeField *f, Optional<DLDataType> *opt) { Check (f->name , opt); }
285
+ void operator ()(MLCTypeField *f, bool *v) { Check (f->name , v); }
286
+ void operator ()(MLCTypeField *f, int8_t *v) { Check (f->name , v); }
287
+ void operator ()(MLCTypeField *f, int16_t *v) { Check (f->name , v); }
288
+ void operator ()(MLCTypeField *f, int32_t *v) { Check (f->name , v); }
289
+ void operator ()(MLCTypeField *f, int64_t *v) { Check (f->name , v); }
290
+ void operator ()(MLCTypeField *f, float *v) { Check (f->name , v); }
291
+ void operator ()(MLCTypeField *f, double *v) { Check (f->name , v); }
292
+ void operator ()(MLCTypeField *f, DLDataType *v) { Check (f->name , v); }
293
+ void operator ()(MLCTypeField *f, DLDevice *v) { Check (f->name , v); }
294
+ void operator ()(MLCTypeField *f, Optional<void *> *v) { Check (f->name , v); }
295
+ void operator ()(MLCTypeField *f, void **v) { Check (f->name , v); }
296
+ void operator ()(MLCTypeField *f, const char **v) { Check (f->name , v); }
297
+
298
+ template <typename T> void Check (const char *name, T *v) {
299
+ if (std::strcmp (name, target_name) == 0 ) {
300
+ if constexpr (std::is_same_v<T, Any>) {
301
+ *v = src;
302
+ } else {
303
+ *v = src.operator T ();
304
+ }
305
+ throw FieldFoundException ();
306
+ }
307
+ }
308
+ const char *target_name;
309
+ Any src;
310
+ };
311
+
312
+ MLC_REGISTER_FUNC (" mlc.testing.FieldGet" ).set_body([](ObjectRef root, const char *target_name) {
313
+ Any ret;
314
+ MLCTypeInfo *info = ::mlc::Lib::GetTypeInfo (root.GetTypeIndex ());
315
+ try {
316
+ ::mlc::core::VisitFields (root.get(), info, FieldGetter{target_name, &ret});
317
+ } catch (FieldFoundException &) {
318
+ return ret;
319
+ }
320
+ MLC_THROW (ValueError) << " Field not found: " << target_name;
321
+ MLC_UNREACHABLE ();
322
+ });
323
+
324
+ MLC_REGISTER_FUNC (" mlc.testing.FieldSet" ).set_body([](ObjectRef root, const char *target_name, Any src) {
325
+ MLCTypeInfo *info = ::mlc::Lib::GetTypeInfo (root.GetTypeIndex ());
326
+ try {
327
+ ::mlc::core::VisitFields (root.get(), info, FieldSetter{target_name, src});
328
+ } catch (FieldFoundException &) {
329
+ return ;
330
+ }
331
+ MLC_THROW (ValueError) << " Field not found: " << target_name;
332
+ MLC_UNREACHABLE ();
333
+ });
334
+
194
335
} // namespace
195
336
} // namespace mlc
0 commit comments