Skip to content

Commit 81fb94f

Browse files
authored
feat(cpp): Import Mode (#5)
This feature simplifies C++ library integration of MLC by avoiding double registration. **What is and when double registration happens**. Originally, macro `MLC_DEF_{DYN/STATIC}_TYPE` and `MLC_DEF_OBJ_REF` registers an MLC object's type key and reflection information into the global type table, and those macros are idiomatically used in public header files. It means, if there's a different DLL that includes MLC headers, the registration of MLC objects will be performed again at loading time of this DLL. **Implication of double registration**. Double registration is practically harmless in most of the cases, as long as libraries are not dynamically unloaded, and in our case, ABI issue is carefully avoided. However, it may take a significant hit in compilation time and is not usually best practice for potential ABI issues. **Solution**. This PR introduces a new parameter `IS_EXPORT` to the three macros above, as part of the practice where we define macros like `MLC_EXPORTS` already in C++. When it is set to `False`, all the registrations are skipped and the macros instead look up the type table for already-existed information. Note that while setting `IS_EXPORT` can be tedious, it could be resolved by wrapping it with a simpler macro: ```C++ // Step 1. Check `MY_LIB_EXPORTS` #ifndef MY_LIB_EXPORTS #define MY_LIB_EXPORTS 0 #else #undef MY_LIB_EXPORTS #define MY_LIB_EXPORTS 1 #endif // Step 2. Wrap `MLC_*` macros using new macros #define MY_LIB_DEF_OBJ(A, B, C) \ MLC_DEF_DYN_TYPE(MY_LIB_EXPORTS, A, B, C) #define MY_LIB_DEF_OBJ_REF(A, B, C) \ MLC_DEF_OBJ_REF(MY_LIB_EXPORTS, A, B, C) ``` This PR also refactors the existing approach accessing `libmlc.so`, limiting the access to MLC C APIs via `mlc::Lib` defined in `lib.h`, which exposes a global MLC type table handle as an inline static member.
1 parent fba2689 commit 81fb94f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+539
-457
lines changed

cmake/Utils/Library.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
function(add_target_from_obj target_name obj_target_name)
22
add_library(${target_name}_static STATIC $<TARGET_OBJECTS:${obj_target_name}>)
3+
target_compile_definitions(${obj_target_name} PRIVATE MLC_EXPORTS)
34
set_target_properties(
45
${target_name}_static PROPERTIES
56
OUTPUT_NAME "${target_name}_static"
67
PREFIX "lib"
78
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
89
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
9-
)
10+
)
1011
add_library(${target_name}_shared SHARED $<TARGET_OBJECTS:${obj_target_name}>)
1112
set_target_properties(
1213
${target_name}_shared PROPERTIES
@@ -18,7 +19,6 @@ function(add_target_from_obj target_name obj_target_name)
1819
add_custom_target(${target_name})
1920
add_dependencies(${target_name} ${target_name}_static ${target_name}_shared)
2021
if (MSVC)
21-
target_compile_definitions(${obj_target_name} PRIVATE MLC_EXPORTS)
2222
set_target_properties(
2323
${obj_target_name} ${target_name}_shared ${target_name}_static
2424
PROPERTIES

cpp/c_api.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,9 @@ MLC_API void MLCExtObjDelete(void *objptr) {
185185
std::abort();
186186
}
187187
}
188+
189+
MLC_API int32_t MLCHandleGetGlobal(MLCTypeTableHandle *self) {
190+
MLC_SAFE_CALL_BEGIN();
191+
*self = TypeTable::Global();
192+
MLC_SAFE_CALL_END(&last_error);
193+
}

cpp/c_api_tests.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,11 @@ struct TestingCClassObj : public Object {
7979

8080
int64_t i64_plus_one() const { return i64 + 1; }
8181

82-
MLC_DEF_DYN_TYPE(TestingCClassObj, Object, "mlc.testing.c_class");
82+
MLC_DEF_DYN_TYPE(MLC_EXPORTS, TestingCClassObj, Object, "mlc.testing.c_class");
8383
};
8484

8585
struct TestingCClass : public ObjectRef {
86-
MLC_DEF_OBJ_REF(TestingCClass, TestingCClassObj, ObjectRef)
86+
MLC_DEF_OBJ_REF(MLC_EXPORTS, TestingCClass, TestingCClassObj, ObjectRef)
8787
.Field("i8", &TestingCClassObj::i8)
8888
.Field("i16", &TestingCClassObj::i16)
8989
.Field("i32", &TestingCClassObj::i32)

cpp/json.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ inline mlc::Str Serialize(Any any) {
8484
} else if (type_index >= kMLCStaticObjectBegin) {
8585
EmitObject(any->operator Object *());
8686
} else {
87-
MLC_THROW(TypeError) << "Cannot serialize type: " << ::mlc::base::TypeIndex2TypeKey(type_index);
87+
MLC_THROW(TypeError) << "Cannot serialize type: " << Lib::GetTypeKey(type_index);
8888
}
8989
}
9090
inline void EmitObject(Object *obj) {
@@ -160,7 +160,7 @@ inline mlc::Str Serialize(Any any) {
160160
DLDataType v = any;
161161
os << "[" << type_dtype << ", \"" << TypeTraits<DLDataType>::__str__(v) << "\"]";
162162
} else {
163-
MLC_THROW(TypeError) << "Cannot serialize type: " << mlc::base::TypeIndex2TypeKey(any.type_index);
163+
MLC_THROW(TypeError) << "Cannot serialize type: " << Lib::GetTypeKey(any.type_index);
164164
}
165165
os << "], \"type_keys\": [";
166166
for (size_t i = 0; i < type_keys.size(); ++i) {
@@ -174,16 +174,15 @@ inline mlc::Str Serialize(Any any) {
174174
}
175175

176176
inline Any Deserialize(const char *json_str, int64_t json_str_len) {
177-
MLCVTableHandle init_table = ::mlc::base::LibState::init;
178177
// Step 0. Parse JSON string
179178
UDict json_obj = JSONLoads(json_str, json_str_len);
180179
// Step 1. type_key => constructors
181180
UList type_keys = json_obj->at("type_keys");
182181
std::vector<FuncObj *> constructors;
183182
constructors.reserve(type_keys.size());
184183
for (Str type_key : type_keys) {
185-
int32_t type_index = ::mlc::base::TypeKey2TypeIndex(type_key->data());
186-
FuncObj *func = ::mlc::base::LibState::VTableGetFunc(init_table, type_index, "__init__");
184+
int32_t type_index = Lib::GetTypeIndex(type_key->data());
185+
FuncObj *func = Lib::_init(type_index);
187186
constructors.push_back(func);
188187
}
189188
auto invoke_init = [&constructors](UList args) {

cpp/registry.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,7 @@ struct TypeTable {
379379
}
380380

381381
void AddMethod(int32_t type_index, MLCTypeMethod method) {
382-
// TODO: check `override_mode`
383-
this->GetGlobalVTable(method.name)->Set(type_index, reinterpret_cast<FuncObj *>(method.func), 0);
382+
this->GetGlobalVTable(method.name)->Set(type_index, reinterpret_cast<FuncObj *>(method.func), 2);
384383
this->GetTypeInfoWrapper(type_index)->AddMethod(method);
385384
}
386385

@@ -395,39 +394,39 @@ struct TypeTable {
395394

396395
struct _POD_REG {
397396
inline static const int32_t _none = //
398-
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCNone))
397+
::mlc::core::_Reflect(static_cast<int32_t>(MLCTypeIndex::kMLCNone))
399398
.MemFn("__str__", &::mlc::base::TypeTraits<std::nullptr_t>::__str__);
400399
inline static const int32_t _int = //
401-
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCInt))
400+
::mlc::core::_Reflect(static_cast<int32_t>(MLCTypeIndex::kMLCInt))
402401
.StaticFn("__init__", [](AnyView value) { return value.operator int64_t(); })
403402
.StaticFn("__new_ref__",
404403
[](void *dst, Optional<int64_t> value) { *reinterpret_cast<Optional<int64_t> *>(dst) = value; })
405404
.MemFn("__str__", &::mlc::base::TypeTraits<int64_t>::__str__);
406405
inline static const int32_t _float =
407-
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCFloat))
406+
::mlc::core::_Reflect(static_cast<int32_t>(MLCTypeIndex::kMLCFloat))
408407
.StaticFn("__new_ref__",
409408
[](void *dst, Optional<double> value) { *reinterpret_cast<Optional<double> *>(dst) = value; })
410409
.MemFn("__str__", &::mlc::base::TypeTraits<double>::__str__);
411410
inline static const int32_t _ptr =
412-
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCPtr))
411+
::mlc::core::_Reflect(static_cast<int32_t>(MLCTypeIndex::kMLCPtr))
413412
.StaticFn("__new_ref__",
414413
[](void *dst, Optional<void *> value) { *reinterpret_cast<Optional<void *> *>(dst) = value; })
415414
.MemFn("__str__", &::mlc::base::TypeTraits<void *>::__str__);
416415
inline static const int32_t _device =
417-
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCDevice))
416+
::mlc::core::_Reflect(static_cast<int32_t>(MLCTypeIndex::kMLCDevice))
418417
.StaticFn("__init__", [](AnyView device) { return device.operator DLDevice(); })
419418
.StaticFn("__new_ref__",
420419
[](void *dst, Optional<DLDevice> value) { *reinterpret_cast<Optional<DLDevice> *>(dst) = value; })
421420
.MemFn("__str__", &::mlc::base::TypeTraits<DLDevice>::__str__);
422421
inline static const int32_t _dtype =
423-
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCDataType))
422+
::mlc::core::_Reflect(static_cast<int32_t>(MLCTypeIndex::kMLCDataType))
424423
.StaticFn("__init__", [](AnyView dtype) { return dtype.operator DLDataType(); })
425424
.StaticFn(
426425
"__new_ref__",
427426
[](void *dst, Optional<DLDataType> value) { *reinterpret_cast<Optional<DLDataType> *>(dst) = value; })
428427
.MemFn("__str__", &::mlc::base::TypeTraits<DLDataType>::__str__);
429428
inline static const int32_t _str = //
430-
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCRawStr))
429+
::mlc::core::_Reflect(static_cast<int32_t>(MLCTypeIndex::kMLCRawStr))
431430
.MemFn("__str__", &::mlc::base::TypeTraits<const char *>::__str__);
432431
};
433432

@@ -463,7 +462,6 @@ inline void VTable::Set(int32_t type_index, FuncObj *func, int32_t override_mode
463462
// Allow override
464463
this->type_table->pool.DelObj(it->second);
465464
} else if (override_mode == 2) {
466-
// TODO: throw exception
467465
MLCTypeInfo *type_info = this->type_table->GetTypeInfo(type_index);
468466
if (type_info && !name.empty()) {
469467
MLC_THROW(KeyError) << "VTable `" << name << "` already registered for type: " << type_info->type_key;

cpp/structure.cc

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
152152
int32_t lhs_type_index = lhs ? lhs->GetTypeIndex() : kMLCNone;
153153
int32_t rhs_type_index = rhs ? rhs->GetTypeIndex() : kMLCNone;
154154
if (lhs_type_index != rhs_type_index) {
155-
MLC_CORE_EQ_S_ERR(::mlc::base::TypeIndex2TypeKey(lhs_type_index),
156-
::mlc::base::TypeIndex2TypeKey(rhs_type_index), new_path);
155+
MLC_CORE_EQ_S_ERR(Lib::GetTypeKey(lhs_type_index), Lib::GetTypeKey(rhs_type_index), new_path);
157156
} else if (lhs_type_index == kMLCStr) {
158157
Str lhs_str(reinterpret_cast<StrObj *>(lhs));
159158
Str rhs_str(reinterpret_cast<StrObj *>(rhs));
@@ -164,7 +163,7 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
164163
throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", new_path);
165164
} else {
166165
bool visited = false;
167-
MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(lhs_type_index);
166+
MLCTypeInfo *type_info = Lib::GetTypeInfo(lhs_type_index);
168167
tasks->push_back(Task{lhs, rhs, type_info, visited, bind_free_vars, new_path, nullptr});
169168
}
170169
}
@@ -301,14 +300,14 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
301300

302301
struct HashCache {
303302
inline static const uint64_t MLC_SYMBOL_HIDE kNoneCombined =
304-
::mlc::base::HashCombine(::mlc::base::TypeIndex2TypeInfo(kMLCNone)->type_key_hash, 0);
305-
inline static const uint64_t MLC_SYMBOL_HIDE kInt = ::mlc::base::TypeIndex2TypeInfo(kMLCInt)->type_key_hash;
306-
inline static const uint64_t MLC_SYMBOL_HIDE kFloat = ::mlc::base::TypeIndex2TypeInfo(kMLCFloat)->type_key_hash;
307-
inline static const uint64_t MLC_SYMBOL_HIDE kPtr = ::mlc::base::TypeIndex2TypeInfo(kMLCPtr)->type_key_hash;
308-
inline static const uint64_t MLC_SYMBOL_HIDE kDType = ::mlc::base::TypeIndex2TypeInfo(kMLCDataType)->type_key_hash;
309-
inline static const uint64_t MLC_SYMBOL_HIDE kDevice = ::mlc::base::TypeIndex2TypeInfo(kMLCDevice)->type_key_hash;
310-
inline static const uint64_t MLC_SYMBOL_HIDE kRawStr = ::mlc::base::TypeIndex2TypeInfo(kMLCRawStr)->type_key_hash;
311-
inline static const uint64_t MLC_SYMBOL_HIDE kStrObj = ::mlc::base::TypeIndex2TypeInfo(kMLCStr)->type_key_hash;
303+
::mlc::base::HashCombine(Lib::GetTypeInfo(kMLCNone)->type_key_hash, 0);
304+
inline static const uint64_t MLC_SYMBOL_HIDE kInt = Lib::GetTypeInfo(kMLCInt)->type_key_hash;
305+
inline static const uint64_t MLC_SYMBOL_HIDE kFloat = Lib::GetTypeInfo(kMLCFloat)->type_key_hash;
306+
inline static const uint64_t MLC_SYMBOL_HIDE kPtr = Lib::GetTypeInfo(kMLCPtr)->type_key_hash;
307+
inline static const uint64_t MLC_SYMBOL_HIDE kDType = Lib::GetTypeInfo(kMLCDataType)->type_key_hash;
308+
inline static const uint64_t MLC_SYMBOL_HIDE kDevice = Lib::GetTypeInfo(kMLCDevice)->type_key_hash;
309+
inline static const uint64_t MLC_SYMBOL_HIDE kRawStr = Lib::GetTypeInfo(kMLCRawStr)->type_key_hash;
310+
inline static const uint64_t MLC_SYMBOL_HIDE kStrObj = Lib::GetTypeInfo(kMLCStr)->type_key_hash;
312311
inline static const uint64_t MLC_SYMBOL_HIDE kBound = ::mlc::base::StrHash("$$Bounds$$");
313312
inline static const uint64_t MLC_SYMBOL_HIDE kUnbound = ::mlc::base::StrHash("$$Unbound$$");
314313
};
@@ -399,7 +398,7 @@ inline uint64_t StructuralHash(Object *obj) {
399398
} else if (type_index == kMLCFunc || type_index == kMLCError) {
400399
throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", ObjectPath::Root());
401400
} else {
402-
MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(type_index);
401+
MLCTypeInfo *type_info = Lib::GetTypeInfo(type_index);
403402
tasks->emplace_back(Task{obj, type_info, false, bind_free_vars, type_info->type_key_hash});
404403
}
405404
}
@@ -565,8 +564,8 @@ inline Any CopyShallow(AnyView source) {
565564
MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); }
566565
std::vector<AnyView> *fields;
567566
};
568-
FuncObj *init_func = ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_index, "__init__");
569-
MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(type_index);
567+
FuncObj *init_func = Lib::_init(type_index);
568+
MLCTypeInfo *type_info = Lib::GetTypeInfo(type_index);
570569
std::vector<AnyView> fields;
571570
VisitFields(source.operator Object *(), type_info, Copier{&fields});
572571
Any ret;
@@ -652,9 +651,8 @@ inline Any CopyDeep(AnyView source) {
652651
} else {
653652
fields.clear();
654653
VisitFields(object, type_info, Copier{&orig2copy, &fields});
655-
FuncObj *func =
656-
::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_info->type_index, "__init__");
657-
::mlc::base::FuncCall(func, static_cast<int32_t>(fields.size()), fields.data(), &ret);
654+
FuncObj *init_func = Lib::_init(type_info->type_index);
655+
::mlc::base::FuncCall(init_func, static_cast<int32_t>(fields.size()), fields.data(), &ret);
658656
}
659657
orig2copy[object] = ret.operator ObjectRef();
660658
});

include/mlc/base/all.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "./alloc.h" // IWYU pragma: export
55
#include "./any.h" // IWYU pragma: export
66
#include "./base_traits.h" // IWYU pragma: export
7+
#include "./lib.h" // IWYU pragma: export
78
#include "./optional.h" // IWYU pragma: export
89
#include "./ref.h" // IWYU pragma: export
910
#include "./traits_device.h" // IWYU pragma: export

include/mlc/base/any.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef MLC_BASE_ANY_H_
22
#define MLC_BASE_ANY_H_
33
#include "./base_traits.h"
4+
#include "./lib.h"
45
#include "./utils.h"
56
#include <cstring>
67
#include <type_traits>
@@ -72,20 +73,20 @@ struct AnyView : public MLCAny {
7273
}
7374
template <typename DerivedObj> inline DerivedObj *Cast() {
7475
if (!this->IsInstance<DerivedObj>()) {
75-
MLC_THROW(TypeError) << "Cannot cast from type `" << ::mlc::base::TypeIndex2TypeKey(this->type_index)
76-
<< "` to type `" << ::mlc::base::Type2Str<DerivedObj>::Run() << "`";
76+
MLC_THROW(TypeError) << "Cannot cast from type `" << Lib::GetTypeKey(this->type_index) << "` to type `"
77+
<< ::mlc::base::Type2Str<DerivedObj>::Run() << "`";
7778
}
7879
return reinterpret_cast<DerivedObj *>(this->v.v_obj);
7980
}
8081
template <typename DerivedObj> MLC_INLINE const DerivedObj *Cast() const {
8182
if (!this->IsInstance<DerivedObj>()) {
82-
MLC_THROW(TypeError) << "Cannot cast from type `" << ::mlc::base::TypeIndex2TypeKey(this->type_index)
83-
<< "` to type `" << ::mlc::base::Type2Str<DerivedObj>::Run() << "`";
83+
MLC_THROW(TypeError) << "Cannot cast from type `" << Lib::GetTypeKey(this->type_index) << "` to type `"
84+
<< ::mlc::base::Type2Str<DerivedObj>::Run() << "`";
8485
}
8586
return reinterpret_cast<const DerivedObj *>(this->v.v_obj);
8687
}
8788
int32_t GetTypeIndex() const { return this->type_index; }
88-
const char *GetTypeKey() const { return ::mlc::base::TypeIndex2TypeInfo(this->type_index)->type_key; }
89+
const char *GetTypeKey() const { return Lib::GetTypeKey(this->type_index); }
8990

9091
template <typename T, typename = std::enable_if_t<::mlc::base::Anyable<::mlc::base::RemoveCR<T>>>>
9192
MLC_INLINE_NO_MSVC T _CastWithStorage(Any *storage) const; // TODO: reemove this
@@ -166,20 +167,20 @@ struct Any : public MLCAny {
166167
}
167168
template <typename DerivedObj> inline DerivedObj *Cast() {
168169
if (!this->IsInstance<DerivedObj>()) {
169-
MLC_THROW(TypeError) << "Cannot cast from type `" << ::mlc::base::TypeIndex2TypeKey(this->type_index)
170-
<< "` to type `" << ::mlc::base::Type2Str<DerivedObj>::Run() << "`";
170+
MLC_THROW(TypeError) << "Cannot cast from type `" << Lib::GetTypeKey(this->type_index) << "` to type `"
171+
<< ::mlc::base::Type2Str<DerivedObj>::Run() << "`";
171172
}
172173
return reinterpret_cast<DerivedObj *>(this->v.v_obj);
173174
}
174175
template <typename DerivedObj> MLC_INLINE const DerivedObj *Cast() const {
175176
if (!this->IsInstance<DerivedObj>()) {
176-
MLC_THROW(TypeError) << "Cannot cast from type `" << ::mlc::base::TypeIndex2TypeKey(this->type_index)
177-
<< "` to type `" << ::mlc::base::Type2Str<DerivedObj>::Run() << "`";
177+
MLC_THROW(TypeError) << "Cannot cast from type `" << Lib::GetTypeKey(this->type_index) << "` to type `"
178+
<< ::mlc::base::Type2Str<DerivedObj>::Run() << "`";
178179
}
179180
return reinterpret_cast<const DerivedObj *>(this->v.v_obj);
180181
}
181182
int32_t GetTypeIndex() const { return this->type_index; }
182-
const char *GetTypeKey() const { return ::mlc::base::TypeIndex2TypeInfo(this->type_index)->type_key; }
183+
const char *GetTypeKey() const { return Lib::GetTypeKey(this->type_index); }
183184

184185
protected:
185186
MLC_INLINE void Swap(MLCAny &src) {

include/mlc/base/lib.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#ifndef MLC_BASE_LIB_H_
2+
#define MLC_BASE_LIB_H_
3+
4+
#include "./utils.h"
5+
6+
namespace mlc {
7+
8+
struct Lib {
9+
static int32_t FuncSetGlobal(const char *name, FuncObj *func, bool allow_override = false);
10+
static FuncObj *FuncGetGlobal(const char *name, bool allow_missing = false);
11+
static ::mlc::Str CxxStr(AnyView obj);
12+
static ::mlc::Str Str(AnyView obj);
13+
static Any IRPrint(AnyView obj, AnyView printer, AnyView path);
14+
15+
static FuncObj *_init(int32_t type_index) { return VTableGetFunc(init, type_index, "__init__"); }
16+
MLC_INLINE static MLCTypeInfo *GetTypeInfo(int32_t type_index) {
17+
MLCTypeInfo *type_info;
18+
MLCTypeIndex2Info(_lib, type_index, &type_info);
19+
return type_info;
20+
}
21+
MLC_INLINE static MLCTypeInfo *GetTypeInfo(const char *type_key) {
22+
MLCTypeInfo *type_info;
23+
MLCTypeKey2Info(_lib, type_key, &type_info);
24+
return type_info;
25+
}
26+
MLC_INLINE static const char *GetTypeKey(int32_t type_index) {
27+
if (MLCTypeInfo *type_info = GetTypeInfo(type_index)) {
28+
return type_info->type_key;
29+
}
30+
return "(undefined)";
31+
}
32+
MLC_INLINE static const char *GetTypeKey(const MLCAny *self) {
33+
if (self == nullptr) {
34+
return "None";
35+
} else if (MLCTypeInfo *type_info = GetTypeInfo(self->type_index)) {
36+
return type_info->type_key;
37+
}
38+
return "(undefined)";
39+
}
40+
MLC_INLINE static int32_t GetTypeIndex(const char *type_key) {
41+
if (MLCTypeInfo *type_info = GetTypeInfo(type_key)) {
42+
return type_info->type_index;
43+
}
44+
MLC_THROW(TypeError) << "Cannot find type with key: " << type_key;
45+
MLC_UNREACHABLE();
46+
}
47+
MLC_INLINE static MLCTypeInfo *TypeRegister(int32_t parent_type_index, int32_t type_index, const char *type_key) {
48+
MLCTypeInfo *info = nullptr;
49+
MLCTypeRegister(_lib, parent_type_index, type_key, type_index, &info);
50+
return info;
51+
}
52+
53+
private:
54+
static FuncObj *VTableGetFunc(MLCVTableHandle vtable, int32_t type_index, const char *vtable_name) {
55+
MLCAny func{};
56+
MLCVTableGetFunc(vtable, type_index, true, &func);
57+
if (!::mlc::base::IsTypeIndexPOD(func.type_index)) {
58+
::mlc::base::DecRef(func.v.v_obj);
59+
}
60+
FuncObj *ret = reinterpret_cast<FuncObj *>(func.v.v_obj);
61+
if (func.type_index == kMLCNone) {
62+
MLC_THROW(TypeError) << "Function `" << vtable_name << "` for type: " << GetTypeKey(type_index)
63+
<< " is not defined in the vtable";
64+
} else if (func.type_index != kMLCFunc) {
65+
MLC_THROW(TypeError) << "Function `" << vtable_name << "` for type: " << GetTypeKey(type_index)
66+
<< " is not callable. Its type is " << GetTypeKey(func.type_index);
67+
}
68+
return ret;
69+
}
70+
static MLCVTableHandle VTableGetGlobal(const char *name) {
71+
MLCVTableHandle ret;
72+
MLCVTableGetGlobal(_lib, name, &ret);
73+
return ret;
74+
}
75+
static MLC_SYMBOL_HIDE inline MLCTypeTableHandle _lib = []() {
76+
MLCTypeTableHandle ret = nullptr;
77+
::MLCHandleGetGlobal(&ret);
78+
return ret;
79+
}();
80+
static MLC_SYMBOL_HIDE inline MLCVTableHandle cxx_str = VTableGetGlobal("__cxx_str__");
81+
static MLC_SYMBOL_HIDE inline MLCVTableHandle str = VTableGetGlobal("__str__");
82+
static MLC_SYMBOL_HIDE inline MLCVTableHandle ir_print = VTableGetGlobal("__ir_print__");
83+
static MLC_SYMBOL_HIDE inline MLCVTableHandle init = VTableGetGlobal("__init__");
84+
};
85+
86+
} // namespace mlc
87+
#endif // MLC_BASE_LIB_H_

include/mlc/base/optional.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#ifndef MLC_BASE_OPTIONAL_H_
22
#define MLC_BASE_OPTIONAL_H_
3+
#include "./lib.h"
34
#include "./ref.h"
45
#include "./utils.h"
56
#include <type_traits>

0 commit comments

Comments
 (0)