Skip to content

Commit ea826f1

Browse files
committed
Introduce mlc.core.typing
1 parent d3f57d8 commit ea826f1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+3249
-2066
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)
22

33
project(
44
mlc
5-
VERSION 0.0.7
5+
VERSION 0.0.8
66
DESCRIPTION "MLC-Python"
77
LANGUAGES C CXX
88
)

README.md

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,35 @@
11
MLC-Python
22
==========
33

4-
## Installation
4+
🛠️ MLC is a Python-first toolkit for building ML compilers, runtimes, and compound AI systems. It enables you to define nested data structures (like compiler IRs) as roundtrippable text formats in Python syntax, with structural comparison for unit-testing and zero-copy C++ interop when needed.
55

6-
```bash
7-
pip install -U mlc-python
8-
```
6+
## 🔑 Key features
7+
8+
### 🐍 `mlc.ast`: Text formats in Python Syntax
9+
10+
TBD
11+
12+
### 🏗️ `mlc.dataclasses`: Cross-Language Dataclasses
13+
14+
TBD
915

10-
## Features
16+
### `mlc.Func`: Zero-Copy Cross-Language Function Calling
1117

12-
TBA
18+
TBD
1319

14-
## Development
20+
### 🎯 Structural Testing for Nested Dataclasses
21+
22+
TBD
23+
24+
## 📥 Installation
25+
26+
### 📦 Install From PyPI
27+
28+
```bash
29+
pip install -U mlc-python
30+
```
1531

16-
### Build from Source
32+
### ⚙️ Build from Source
1733

1834
```bash
1935
python -m venv .venv
@@ -22,9 +38,9 @@ python -m pip install --verbose --editable ".[dev]"
2238
pre-commit install
2339
```
2440

25-
### Create Wheels
41+
### 🎡 Create MLC-Python Wheels
2642

27-
See `.github/workflows/wheels.yml` for more details. This project uses `cibuildwheel` to build cross-platform wheels.
43+
This project uses `cibuildwheel` to build cross-platform wheels. See `.github/workflows/wheels.ym` for more details.
2844

2945
```bash
3046
export CIBW_BUILD_VERBOSITY=3

cpp/c_api.cc

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "./registry.h"
2+
#include <cstdint>
3+
#include <cstdlib>
24
#include <iostream>
35

46
namespace mlc {
@@ -81,15 +83,15 @@ MLC_API int32_t MLCDynTypeTypeTableDestroy(MLCTypeTableHandle handle) {
8183
MLC_API int32_t MLCAnyIncRef(MLCAny *any) {
8284
MLC_SAFE_CALL_BEGIN();
8385
if (!::mlc::base::IsTypeIndexPOD(any->type_index)) {
84-
::mlc::base::IncRef(any->v_obj);
86+
::mlc::base::IncRef(any->v.v_obj);
8587
}
8688
MLC_SAFE_CALL_END(&last_error);
8789
}
8890

8991
MLC_API int32_t MLCAnyDecRef(MLCAny *any) {
9092
MLC_SAFE_CALL_BEGIN();
9193
if (!::mlc::base::IsTypeIndexPOD(any->type_index)) {
92-
::mlc::base::DecRef(any->v_obj);
94+
::mlc::base::DecRef(any->v.v_obj);
9395
}
9496
MLC_SAFE_CALL_END(&last_error);
9597
}
@@ -144,33 +146,28 @@ MLC_API int32_t MLCErrorGetInfo(MLCAny error, int32_t *num_strs, const char ***s
144146
MLC_SAFE_CALL_END(&last_error);
145147
}
146148

147-
MLC_API void *MLCExtObjCreate(int32_t bytes, int32_t type_index) {
148-
char *data = new char[bytes]();
149-
std::memset(data, 0, bytes);
150-
MLCAny *header = reinterpret_cast<MLCAny *>(data);
151-
header->type_index = type_index;
152-
header->ref_cnt = 0;
153-
header->deleter = MLCExtObjDelete;
154-
return data;
149+
MLC_API int32_t MLCExtObjCreate(int32_t num_bytes, int32_t type_index, MLCAny *ret) {
150+
MLC_SAFE_CALL_BEGIN();
151+
*static_cast<Any *>(ret) = mlc::AllocExternObject(type_index, num_bytes);
152+
MLC_SAFE_CALL_END(&last_error);
155153
}
156154

157-
MLC_API void MLCExtObjDelete(void *objptr) {
155+
MLC_API int32_t _MLCExtObjDeleteImpl(void *objptr) {
156+
MLC_SAFE_CALL_BEGIN();
158157
MLCAny *header = reinterpret_cast<MLCAny *>(objptr);
159-
MLCTypeInfo *info = TypeTable::Global()->GetTypeInfo(header->type_index);
160-
if (info == nullptr) { // TODO: error handling
158+
if (MLCTypeInfo *info = TypeTable::Global()->GetTypeInfo(header->type_index)) {
159+
::mlc::core::VisitTypeField(objptr, info, ::mlc::core::ExternObjDeleter{});
160+
std::free(objptr);
161+
} else {
161162
std::cerr << "Cannot find type info for type index: " << header->type_index << std::endl;
162163
std::abort();
163164
}
164-
MLCTypeField *fields = info->fields;
165-
for (int32_t i = 0;; i++) {
166-
MLCTypeField &field = fields[i];
167-
if (field.name == nullptr) {
168-
break;
169-
}
170-
if (field.is_owned_obj_ptr) {
171-
MLCObject *ptr = reinterpret_cast<MLCObjPtr *>(static_cast<char *>(objptr) + field.offset)->ptr;
172-
::mlc::base::DecRef(ptr);
173-
}
165+
MLC_SAFE_CALL_END(&last_error);
166+
}
167+
168+
MLC_API void MLCExtObjDelete(void *objptr) {
169+
if (int32_t error_code = _MLCExtObjDeleteImpl(objptr)) {
170+
std::cerr << "Error code (" << error_code << ") when deleting external object: " << last_error << std::endl;
171+
std::abort();
174172
}
175-
delete[] reinterpret_cast<char *>(objptr);
176173
}

cpp/c_api_tests.cc

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ MLC_REGISTER_FUNC("mlc.testing.cxx_raw_str").set_body([](const char *x) { return
1717
/**************** Reflection ****************/
1818

1919
struct ReflectionTestObj : public Object {
20-
std::string x_mutable;
20+
Str x_mutable;
2121
int32_t y_immutable;
2222

2323
ReflectionTestObj(std::string x, int32_t y) : x_mutable(x), y_immutable(y) {}
@@ -34,6 +34,50 @@ struct ReflectionTest : public ObjectRef {
3434
.MemFn("YPlusOne", &ReflectionTestObj::YPlusOne);
3535
};
3636

37+
struct TestingCClassObj : public Object {
38+
int8_t i8;
39+
int16_t i16;
40+
int32_t i32;
41+
int64_t i64;
42+
float f32;
43+
double f64;
44+
void *raw_ptr;
45+
DLDataType dtype;
46+
DLDevice device;
47+
Any any;
48+
Func func;
49+
UList ulist;
50+
UDict udict;
51+
Str str_;
52+
53+
explicit TestingCClassObj(int8_t i8, int16_t i16, int32_t i32, int64_t i64, float f32, double f64, void *raw_ptr,
54+
DLDataType dtype, DLDevice device, Any any, Func func, UList ulist, UDict udict, Str str_)
55+
: i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device),
56+
any(any), func(func), ulist(ulist), udict(udict), str_(str_) {}
57+
58+
MLC_DEF_DYN_TYPE(ReflectionTestObj, Object, "mlc.testing.c_class");
59+
};
60+
61+
struct TestingCClass : public ObjectRef {
62+
MLC_DEF_OBJ_REF(TestingCClass, TestingCClassObj, ObjectRef)
63+
.Field("i8", &TestingCClassObj::i8)
64+
.Field("i16", &TestingCClassObj::i16)
65+
.Field("i32", &TestingCClassObj::i32)
66+
.Field("i64", &TestingCClassObj::i64)
67+
.Field("f32", &TestingCClassObj::f32)
68+
.Field("f64", &TestingCClassObj::f64)
69+
.Field("raw_ptr", &TestingCClassObj::raw_ptr)
70+
.Field("dtype", &TestingCClassObj::dtype)
71+
.Field("device", &TestingCClassObj::device)
72+
.Field("any", &TestingCClassObj::any)
73+
.Field("func", &TestingCClassObj::func)
74+
.Field("ulist", &TestingCClassObj::ulist)
75+
.Field("udict", &TestingCClassObj::udict)
76+
.Field("str_", &TestingCClassObj::str_)
77+
.StaticFn("__init__", InitOf<TestingCClassObj, int8_t, int16_t, int32_t, int64_t, float, double, void *,
78+
DLDataType, DLDevice, Any, Func, UList, UDict, Str>);
79+
};
80+
3781
/**************** Traceback ****************/
3882

3983
MLC_REGISTER_FUNC("mlc.testing.throw_exception_from_c").set_body([]() {

cpp/registry.h

Lines changed: 49 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -78,74 +78,8 @@ struct TypeInfoWrapper {
7878
~TypeInfoWrapper() { this->Reset(); }
7979
};
8080

81-
template <typename T> struct PODGetterSetter {
82-
static int32_t Getter(MLCTypeField *, void *addr, MLCAny *ret) {
83-
using namespace ::mlc::base;
84-
TypeTraits<T>::TypeToAny(*static_cast<T *>(addr), ret);
85-
return 0;
86-
}
87-
static int32_t Setter(MLCTypeField *, void *addr, MLCAny *src) {
88-
using namespace mlc::base;
89-
try {
90-
*static_cast<T *>(addr) = TypeTraits<T>::AnyToTypeUnowned(src);
91-
} catch (const TemporaryTypeError &) {
92-
std::ostringstream oss;
93-
oss << "Cannot convert from type `" << TypeIndex2TypeKey(src->type_index) << "` to `" << TypeTraits<T>::type_str
94-
<< "`";
95-
*static_cast<::mlc::Any *>(src) = MLC_MAKE_ERROR_HERE(TypeError, oss.str());
96-
return -2;
97-
}
98-
return 0;
99-
}
100-
};
101-
102-
template <> struct PODGetterSetter<std::nullptr_t> {
103-
static int32_t Getter(MLCTypeField *, void *, MLCAny *ret) {
104-
MLC_SAFE_CALL_BEGIN();
105-
*static_cast<Any *>(ret) = nullptr;
106-
MLC_SAFE_CALL_END(static_cast<Any *>(ret));
107-
}
108-
static int32_t Setter(MLCTypeField *, void *addr, MLCAny *src) {
109-
MLC_SAFE_CALL_BEGIN();
110-
*static_cast<void **>(addr) = nullptr;
111-
MLC_SAFE_CALL_END(static_cast<Any *>(src));
112-
}
113-
};
114-
115-
MLC_INLINE int32_t ObjPtrGetterDefault(MLCTypeField *, void *addr, MLCAny *ret) {
116-
if (addr == nullptr) {
117-
ret->type_index = static_cast<int32_t>(MLCTypeIndex::kMLCNone);
118-
ret->v_obj = nullptr;
119-
} else {
120-
Object *v = static_cast<Object *>(addr);
121-
ret->type_index = v->_mlc_header.type_index;
122-
ret->v_obj = reinterpret_cast<MLCAny *>(v);
123-
}
124-
return 0;
125-
}
126-
127-
MLC_INLINE int32_t ObjPtrSetterDefault(MLCTypeField *field, void *addr, MLCAny *src) {
128-
if (field->type_annotation == nullptr) {
129-
std::ostringstream oss;
130-
oss << "Type annotation is required for field `" << field->name << "`";
131-
*static_cast<Any *>(src) = MLC_MAKE_ERROR_HERE(InternalError, oss.str());
132-
return -2;
133-
}
134-
int32_t target_type_index = field->type_annotation[0]->type_index;
135-
if (src == nullptr || src->type_index != target_type_index) {
136-
std::ostringstream oss;
137-
oss << "Cannot convert from type `" << ::mlc::base::TypeIndex2TypeKey(src) << "` to `"
138-
<< ::mlc::base::TypeIndex2TypeKey(target_type_index) << "`";
139-
*static_cast<Any *>(src) = MLC_MAKE_ERROR_HERE(TypeError, oss.str());
140-
return -2;
141-
}
142-
Ref<Object> *dst = static_cast<Ref<Object> *>(addr);
143-
*dst = reinterpret_cast<Object *>(src->v_obj);
144-
return 0;
145-
}
146-
14781
struct TypeTable {
148-
using ObjPtr = std::unique_ptr<MLCObject, void (*)(MLCObject *)>;
82+
using ObjPtr = std::unique_ptr<MLCAny, void (*)(MLCAny *)>;
14983

15084
int32_t num_types;
15185
std::vector<std::unique_ptr<TypeInfoWrapper>> type_table;
@@ -192,7 +126,7 @@ struct TypeTable {
192126
std::cerr << "Object already exists in the memory pool: " << source;
193127
std::abort();
194128
}
195-
MLCObject *source_casted = reinterpret_cast<MLCObject *>(source);
129+
MLCAny *source_casted = reinterpret_cast<MLCAny *>(source);
196130
::mlc::base::IncRef(source_casted);
197131
it->second = ObjPtr(source_casted, ::mlc::base::DecRef);
198132
}
@@ -274,8 +208,6 @@ struct TypeTable {
274208
std::copy(parent->type_ancestors, parent->type_ancestors + parent->type_depth, info->type_ancestors);
275209
info->type_ancestors[parent->type_depth] = parent_type_index;
276210
}
277-
info->getter = ObjPtrSetterDefault;
278-
info->setter = ObjPtrGetterDefault;
279211
info->fields = nullptr;
280212
info->methods = nullptr;
281213
wrapper->table = this;
@@ -329,20 +261,65 @@ struct _POD_REG {
329261
.MemFn("__str__", &::mlc::base::TypeTraits<std::nullptr_t>::__str__);
330262
inline static const int32_t _int = //
331263
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCInt))
264+
.StaticFn("__new_ref__",
265+
[](void *_dst, int64_t value) {
266+
MLCAny **dst = reinterpret_cast<MLCAny **>(_dst);
267+
MLCAny *ret = ::mlc::PODAllocator<int64_t>::New(value);
268+
if (*dst != nullptr) {
269+
::mlc::base::DecRef(*dst);
270+
}
271+
*dst = ret;
272+
})
332273
.MemFn("__str__", &::mlc::base::TypeTraits<int64_t>::__str__);
333274
inline static const int32_t _float = //
334275
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCFloat))
276+
.StaticFn("__new_ref__",
277+
[](void *_dst, double value) {
278+
MLCAny **dst = reinterpret_cast<MLCAny **>(_dst);
279+
MLCAny *ret = ::mlc::PODAllocator<double>::New(value);
280+
if (*dst != nullptr) {
281+
::mlc::base::DecRef(*dst);
282+
}
283+
*dst = ret;
284+
})
335285
.MemFn("__str__", &::mlc::base::TypeTraits<double>::__str__);
336286
inline static const int32_t _ptr = //
337287
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCPtr))
288+
.StaticFn("__new_ref__",
289+
[](void *_dst, void *value) {
290+
MLCAny **dst = reinterpret_cast<MLCAny **>(_dst);
291+
MLCAny *ret = ::mlc::PODAllocator<void *>::New(value);
292+
if (*dst != nullptr) {
293+
::mlc::base::DecRef(*dst);
294+
}
295+
*dst = ret;
296+
})
338297
.MemFn("__str__", &::mlc::base::TypeTraits<void *>::__str__);
339298
inline static const int32_t _device = //
340299
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCDevice))
341300
.StaticFn("__init__", [](AnyView device) { return device.operator DLDevice(); })
301+
.StaticFn("__new_ref__",
302+
[](void *_dst, DLDevice value) {
303+
MLCAny **dst = reinterpret_cast<MLCAny **>(_dst);
304+
MLCAny *ret = ::mlc::PODAllocator<DLDevice>::New(value);
305+
if (*dst != nullptr) {
306+
::mlc::base::DecRef(*dst);
307+
}
308+
*dst = ret;
309+
})
342310
.MemFn("__str__", &::mlc::base::TypeTraits<DLDevice>::__str__);
343311
inline static const int32_t _dtype = //
344312
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCDataType))
345313
.StaticFn("__init__", [](AnyView dtype) { return dtype.operator DLDataType(); })
314+
.StaticFn("__new_ref__",
315+
[](void *_dst, DLDataType value) {
316+
MLCAny **dst = reinterpret_cast<MLCAny **>(_dst);
317+
MLCAny *ret = ::mlc::PODAllocator<DLDataType>::New(value);
318+
if (*dst != nullptr) {
319+
::mlc::base::DecRef(*dst);
320+
}
321+
*dst = ret;
322+
})
346323
.MemFn("__str__", &::mlc::base::TypeTraits<DLDataType>::__str__);
347324
inline static const int32_t _str = //
348325
::mlc::core::ReflectionHelper(static_cast<int32_t>(MLCTypeIndex::kMLCRawStr))
@@ -358,8 +335,7 @@ inline TypeTable *TypeTable::New() {
358335
{ \
359336
using Traits = ::mlc::base::TypeTraits<UnderlyingType>; \
360337
MLCTypeInfo *info = Self->TypeRegister(-1, Traits::type_index, Traits::type_str); \
361-
info->setter = PODGetterSetter<UnderlyingType>::Setter; \
362-
info->getter = PODGetterSetter<UnderlyingType>::Getter; \
338+
(void)info; \
363339
}
364340
MLC_TYPE_TABLE_INIT_TYPE(std::nullptr_t, self);
365341
MLC_TYPE_TABLE_INIT_TYPE(int64_t, self);
@@ -416,12 +392,7 @@ inline void TypeInfoWrapper::SetFields(int64_t new_num_fields, MLCTypeField *fie
416392
for (int64_t i = 0; i < num_fields; i++) {
417393
dst[i] = fields[i];
418394
dst[i].name = this->table->NewArray(fields[i].name);
419-
int32_t len_type_ann = 0;
420-
while (fields[i].type_annotation[len_type_ann] != nullptr) {
421-
++len_type_ann;
422-
}
423-
dst[i].type_annotation = reinterpret_cast<MLCTypeInfo **>(this->table->NewArray<void *>(len_type_ann + 1));
424-
std::copy(fields[i].type_annotation, fields[i].type_annotation + len_type_ann + 1, dst[i].type_annotation);
395+
this->table->NewObjPtr(&dst[i].ty, dst[i].ty);
425396
}
426397
dst[num_fields] = MLCTypeField{};
427398
std::sort(dst, dst + num_fields, [](const MLCTypeField &a, const MLCTypeField &b) { return a.offset < b.offset; });

0 commit comments

Comments
 (0)