Commit b8c78bc

Introduce Structural Equality Checks
1 parent 574d5ae commit b8c78bc

36 files changed, +1907 -752 lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)

 project(
   mlc
-  VERSION 0.0.9
+  VERSION 0.0.10
   DESCRIPTION "MLC-Python"
   LANGUAGES C CXX
 )

README.md

Lines changed: 110 additions & 22 deletions
@@ -4,44 +4,132 @@
 MLC-Python
 </h1>

-* [:key: Key features](#keykey-features)
-* [:inbox_tray: Installation](#inbox_trayinstallation)
-  + [:package: Install From PyPI](#packageinstall-from-pypi)
-  + [:gear: Build from Source](#gearbuild-from-source)
-  + [:ferris_wheel: Create MLC-Python Wheels](#ferris_wheel-create-mlc-python-wheels)
+* [:inbox_tray: Installation](#inbox_tray-installation)
+* [:key: Key Features](#key-key-features)
+  + [:building_construction: MLC Dataclass](#building_construction-mlc-dataclass)
+  + [:dart: Structure-Aware Tooling](#dart-structure-aware-tooling)
+  + [:snake: Text Formats in Python AST](#snake-text-formats-in-python-ast)
+  + [:zap: Zero-Copy Interoperability with C++ Plugins](#zap-zero-copy-interoperability-with-c-plugins)
+* [:fuelpump: Development](#fuelpump-development)
+  + [:gear: Editable Build](#gear-editable-build)
+  + [:ferris_wheel: Create Wheels](#ferris_wheel-create-wheels)

-MLC is a Python-first toolkit that makes it more ergonomic to build AI compilers, runtimes, and compound AI systems. It provides Pythonic dataclasses with rich tooling infra, which includes:

-- Structure-aware equality and hashing methods;
-- Serialization in JSON / pickle;
-- Text format printing and parsing in Python syntax.
+MLC is a Python-first toolkit that makes it more ergonomic to build AI compilers, runtimes, and compound AI systems with Pythonic dataclass, rich tooling infra and zero-copy interoperability with C++ plugins.

-Additionally, MLC language bindings support:
+## :inbox_tray: Installation

-- Zero-copy bidirectional functioning calling for all MLC dataclasses.
+```bash
+pip install -U mlc-python
+```

-## :key: Key features
+## :key: Key Features

-TBD
+### :building_construction: MLC Dataclass

-## :inbox_tray: Installation
+MLC dataclass is similar to Python’s native dataclass:

-### :package: Install From PyPI
+```python
+import mlc.dataclasses as mlcd

-```bash
-pip install -U mlc-python
+@mlcd.py_class("demo.MyClass")
+class MyClass(mlcd.PyClass):
+    a: int
+    b: str
+    c: float | None
+
+instance = MyClass(12, "test", c=None)
+```
+
+**Type safety**. MLC dataclass checks type strictly in Cython and C++.
+
+```python
+>>> instance.c = 10; print(instance)
+demo.MyClass(a=12, b='test', c=10.0)
+
+>>> instance.c = "wrong type"
+TypeError: must be real number, not str
+
+>>> instance.non_exist = 1
+AttributeError: 'MyClass' object has no attribute 'non_exist' and no __dict__ for setting new attributes
+```
+
+**Serialization**. MLC dataclasses are picklable and JSON-serializable.
+
+```python
+>>> MyClass.from_json(instance.json())
+demo.MyClass(a=12, b='test', c=None)
+
+>>> import pickle; pickle.loads(pickle.dumps(instance))
+demo.MyClass(a=12, b='test', c=None)
 ```

-### :gear: Build from Source
+### :dart: Structure-Aware Tooling
+
+An extra `structure` field are used to specify a dataclass's structure, indicating def site and scoping in an IR.
+
+```python
+import mlc.dataclasses as mlcd
+
+@mlcd.py_class
+class Expr(mlcd.PyClass):
+    def __add__(self, other):
+        return Add(a=self, b=other)
+
+@mlcd.py_class(structure="nobind")
+class Add(Expr):
+    a: Expr
+    b: Expr
+
+@mlcd.py_class(structure="var")
+class Var(Expr):
+    name: str = mlcd.field(structure=None) # excludes `name` from defined structure
+
+@mlcd.py_class(structure="bind")
+class Let(Expr):
+    rhs: Expr
+    lhs: Var = mlcd.field(structure="bind") # `Let.lhs` is the def-site
+    body: Expr
+```
+
+**Structural equality**. Method eq_s is ready to use to compare the structural equality (alpha equivalence) of two IRs.
+
+```python
+"""
+L1: let z = x + y; z
+L2: let x = y + z; x
+L3: let z = x + x; z
+"""
+>>> x, y, z = Var("x"), Var("y"), Var("z")
+>>> L1 = Let(rhs=x + y, lhs=z, body=z)
+>>> L2 = Let(rhs=y + z, lhs=x, body=x)
+>>> L3 = Let(rhs=x + x, lhs=z, body=z)
+>>> L1.eq_s(L2)
+True
+>>> L1.eq_s(L3, assert_mode=True)
+ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound
+```
+
+**Structural hashing**. TBD
+
+### :snake: Text Formats in Python AST
+
+TBD
+
+### :zap: Zero-Copy Interoperability with C++ Plugins
+
+TBD
+
+## :fuelpump: Development
+
+### :gear: Editable Build

 ```bash
-python -m venv .venv
-source .venv/bin/activate
-python -m pip install --verbose --editable ".[dev]"
+pip install --verbose --editable ".[dev]"
 pre-commit install
 ```

-### :ferris_wheel: Create MLC-Python Wheels
+### :ferris_wheel: Create Wheels

 This project uses `cibuildwheel` to build cross-platform wheels. See `.github/workflows/wheels.ym` for more details.
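The structural equality example added to the README can be approximated without MLC. The sketch below is not MLC's implementation: it is a minimal, plain-Python illustration of an alpha-equivalence check that keeps a two-way mapping between variables of the two expressions. The function name `alpha_equal` and the stand-in classes are hypothetical, and the sketch assumes that free variables and def-sites are both bound on first encounter and never unbound, which is enough to reproduce the `L1`/`L2`/`L3` outcomes shown in the README.

```python
from dataclasses import dataclass


# eq=False keeps identity-based equality and hashing, so two Var objects
# are distinct nodes even when their names match.
@dataclass(eq=False)
class Var:
    name: str


@dataclass(eq=False)
class Add:
    a: object
    b: object


@dataclass(eq=False)
class Let:
    rhs: object
    lhs: Var
    body: object


def alpha_equal(x, y, l2r=None, r2l=None):
    """Hypothetical helper: compare two expressions up to variable renaming."""
    if l2r is None:
        l2r, r2l = {}, {}
    if type(x) is not type(y):
        return False
    if isinstance(x, Var):
        if x in l2r or y in r2l:
            # At least one side is already bound: the binding must agree both ways.
            return l2r.get(x) is y and r2l.get(y) is x
        l2r[x], r2l[y] = y, x  # first encounter on both sides: bind the pair
        return True
    if isinstance(x, Add):
        return alpha_equal(x.a, y.a, l2r, r2l) and alpha_equal(x.b, y.b, l2r, r2l)
    if isinstance(x, Let):
        # Fields are visited in declaration order: rhs, lhs, body.
        return (
            alpha_equal(x.rhs, y.rhs, l2r, r2l)
            and alpha_equal(x.lhs, y.lhs, l2r, r2l)
            and alpha_equal(x.body, y.body, l2r, r2l)
        )
    raise TypeError(f"unsupported node: {type(x).__name__}")


x, y, z = Var("x"), Var("y"), Var("z")
L1 = Let(rhs=Add(x, y), lhs=z, body=z)
L2 = Let(rhs=Add(y, z), lhs=x, body=x)
L3 = Let(rhs=Add(x, x), lhs=z, body=z)
print(alpha_equal(L1, L2))  # True
print(alpha_equal(L1, L3))  # False
```

`L1` and `L3` differ at `rhs.b`: on the right-hand side `x` is already paired with the left-hand `x`, while the left-hand `y` is still unbound, which matches the "Inconsistent binding" message in the README example.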

cpp/c_api.cc

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "./registry.h"
2-
#include "mlc/core/str.h"
2+
#include <mlc/all.h>
33

44
namespace mlc {
55
namespace registry {
@@ -20,22 +20,23 @@ using ::mlc::registry::TypeTable;
2020
namespace {
2121
thread_local Any last_error;
2222
MLC_REGISTER_FUNC("mlc.ffi.LoadDSO").set_body([](std::string name) { TypeTable::Get(nullptr)->LoadDSO(name); });
23-
MLC_REGISTER_FUNC("mlc.core.JSONParse").set_body([](AnyView json_str) {
23+
MLC_REGISTER_FUNC("mlc.core.JSONLoads").set_body([](AnyView json_str) {
2424
if (json_str.type_index == kMLCRawStr) {
25-
return ::mlc::core::ParseJSON(json_str.operator const char *());
25+
return ::mlc::core::JSONLoads(json_str.operator const char *());
2626
} else {
2727
::mlc::Str str = json_str;
28-
return ::mlc::core::ParseJSON(str);
28+
return ::mlc::core::JSONLoads(str);
2929
}
3030
});
31-
MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize);
31+
MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize); // TODO: `AnyView` as function argument
3232
MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) {
3333
if (json_str.type_index == kMLCRawStr) {
3434
return ::mlc::core::Deserialize(json_str.operator const char *());
3535
} else {
3636
return ::mlc::core::Deserialize(json_str.operator ::mlc::Str());
3737
}
3838
});
39+
MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
3940
} // namespace
4041

4142
MLC_API MLCAny MLCGetLastError() {
@@ -64,9 +65,14 @@ MLC_API int32_t MLCTypeKey2Info(MLCTypeTableHandle _self, const char *type_key,
6465
}
6566

6667
MLC_API int32_t MLCTypeDefReflection(MLCTypeTableHandle self, int32_t type_index, int64_t num_fields,
67-
MLCTypeField *fields, int64_t num_methods, MLCTypeMethod *methods) {
68+
MLCTypeField *fields, int64_t num_methods, MLCTypeMethod *methods,
69+
int32_t structure_kind, int64_t num_sub_structures, int32_t *sub_structure_indices,
70+
int32_t *sub_structure_kinds) {
6871
MLC_SAFE_CALL_BEGIN();
69-
TypeTable::Get(self)->TypeDefReflection(type_index, num_fields, fields, num_methods, methods);
72+
auto *type_info = TypeTable::Get(self)->GetTypeInfoWrapper(type_index);
73+
type_info->SetFields(num_fields, fields);
74+
type_info->SetMethods(num_methods, methods);
75+
type_info->SetStructure(structure_kind, num_sub_structures, sub_structure_indices, sub_structure_kinds);
7076
MLC_SAFE_CALL_END(&last_error);
7177
}
7278

@@ -168,7 +174,7 @@ MLC_API int32_t MLCExtObjCreate(int32_t num_bytes, int32_t type_index, MLCAny *r
168174

169175
MLC_API int32_t _MLCExtObjDeleteImpl(void *objptr) {
170176
MLC_SAFE_CALL_BEGIN();
171-
::mlc::core::DeleteExternObject(static_cast<MLCAny *>(objptr));
177+
::mlc::core::DeleteExternObject(static_cast<::mlc::Object *>(objptr));
172178
MLC_SAFE_CALL_END(&last_error);
173179
}
174180
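The C API functions in this file return an `int32_t`, wrap their bodies in `MLC_SAFE_CALL_BEGIN()` / `MLC_SAFE_CALL_END(&last_error)`, and share a `thread_local` `last_error` that `MLCGetLastError()` exposes. The macros themselves are not part of this diff, so the following Python sketch only illustrates the general "status code plus per-thread last error" pattern that this layout suggests. The names `safe_call`, `get_last_error`, and `type_def_reflection`, as well as the status values `0` and `-1`, are hypothetical.

```python
import functools
import threading

_state = threading.local()  # one "last error" slot per thread


def safe_call(func):
    """Run func; return 0 on success and -1 on failure (hypothetical codes)."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as err:
            # Record the error for the calling thread instead of propagating it.
            _state.last_error = err
            return -1
        return 0

    return wrapper


def get_last_error():
    """Return the error most recently recorded on the calling thread, if any."""
    return getattr(_state, "last_error", None)


@safe_call
def type_def_reflection(type_table, type_index):
    # Stand-in for a registry lookup that may fail.
    if type_index not in type_table:
        raise KeyError(f"Type index `{type_index}` not registered")


table = {0: "object"}
ok = type_def_reflection(table, 0)
failed = type_def_reflection(table, 42)
print(ok, failed, repr(get_last_error()))
```

Keeping the slot thread-local means a failure recorded on one thread is not visible from another, so callers can check the status code and then query the error without extra locking.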

cpp/registry.h

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,11 @@ struct TypeInfoWrapper {
7373
void Reset();
7474
void ResetFields();
7575
void ResetMethods();
76+
void ResetStructure();
7677
void SetFields(int64_t new_num_fields, MLCTypeField *fields);
7778
void SetMethods(int64_t new_num_methods, MLCTypeMethod *methods);
79+
void SetStructure(int32_t structure_kind, int64_t num_sub_structures, int32_t *sub_structure_indices,
80+
int32_t *sub_structure_kinds);
7881
~TypeInfoWrapper() { this->Reset(); }
7982
};
8083

@@ -210,6 +213,9 @@ struct TypeTable {
210213
}
211214
info->fields = nullptr;
212215
info->methods = nullptr;
216+
info->structure_kind = 0;
217+
info->sub_structure_indices = nullptr;
218+
info->sub_structure_kinds = nullptr;
213219
wrapper->table = this;
214220
return info;
215221
}
@@ -231,9 +237,7 @@ struct TypeTable {
231237
this->NewObjPtr(&it->second, func);
232238
}
233239

234-
void TypeDefReflection(int32_t type_index, //
235-
int64_t num_fields, MLCTypeField *fields, //
236-
int64_t num_methods, MLCTypeMethod *methods) {
240+
TypeInfoWrapper *GetTypeInfoWrapper(int32_t type_index) {
237241
TypeInfoWrapper *wrapper = nullptr;
238242
try {
239243
wrapper = this->type_table.at(type_index).get();
@@ -242,8 +246,7 @@ struct TypeTable {
242246
if (wrapper == nullptr || wrapper->table != this) {
243247
MLC_THROW(KeyError) << "Type index `" << type_index << "` not registered";
244248
}
245-
wrapper->SetFields(num_fields, fields);
246-
wrapper->SetMethods(num_methods, methods);
249+
return wrapper;
247250
}
248251

249252
void LoadDSO(std::string name) {
@@ -352,6 +355,15 @@ inline void TypeInfoWrapper::ResetMethods() {
352355
}
353356
}
354357

358+
inline void TypeInfoWrapper::ResetStructure() {
359+
if (this->info.sub_structure_indices) {
360+
this->table->DelArray(this->info.sub_structure_indices);
361+
}
362+
if (this->info.sub_structure_kinds) {
363+
this->table->DelArray(this->info.sub_structure_kinds);
364+
}
365+
}
366+
355367
inline void TypeInfoWrapper::SetFields(int64_t new_num_fields, MLCTypeField *fields) {
356368
this->ResetFields();
357369
this->num_fields = new_num_fields;
@@ -360,6 +372,9 @@ inline void TypeInfoWrapper::SetFields(int64_t new_num_fields, MLCTypeField *fie
360372
dst[i] = fields[i];
361373
dst[i].name = this->table->NewArray(fields[i].name);
362374
this->table->NewObjPtr(&dst[i].ty, dst[i].ty);
375+
if (dst[i].index != i) {
376+
MLC_THROW(ValueError) << "Field index mismatch: " << i << " vs " << dst[i].index;
377+
}
363378
}
364379
dst[num_fields] = MLCTypeField{};
365380
std::sort(dst, dst + num_fields, [](const MLCTypeField &a, const MLCTypeField &b) { return a.offset < b.offset; });
@@ -382,6 +397,25 @@ inline void TypeInfoWrapper::SetMethods(int64_t new_num_methods, MLCTypeMethod *
382397
[](const MLCTypeMethod &a, const MLCTypeMethod &b) { return std::strcmp(a.name, b.name) < 0; });
383398
}
384399

400+
inline void TypeInfoWrapper::SetStructure(int32_t structure_kind, int64_t num_sub_structures,
401+
int32_t *sub_structure_indices, int32_t *sub_structure_kinds) {
402+
this->ResetStructure();
403+
this->info.structure_kind = structure_kind;
404+
if (num_sub_structures > 0) {
405+
this->info.sub_structure_indices = this->table->NewArray<int32_t>(num_sub_structures + 1);
406+
this->info.sub_structure_kinds = this->table->NewArray<int32_t>(num_sub_structures + 1);
407+
std::memcpy(this->info.sub_structure_indices, sub_structure_indices, num_sub_structures * sizeof(int32_t));
408+
std::memcpy(this->info.sub_structure_kinds, sub_structure_kinds, num_sub_structures * sizeof(int32_t));
409+
std::reverse(this->info.sub_structure_indices, this->info.sub_structure_indices + num_sub_structures);
410+
std::reverse(this->info.sub_structure_kinds, this->info.sub_structure_kinds + num_sub_structures);
411+
this->info.sub_structure_indices[num_sub_structures] = -1;
412+
this->info.sub_structure_kinds[num_sub_structures] = -1;
413+
} else {
414+
this->info.sub_structure_indices = nullptr;
415+
this->info.sub_structure_kinds = nullptr;
416+
}
417+
}
418+
385419
} // namespace registry
386420
} // namespace mlc
387421
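`SetStructure` above stores the two sub-structure arrays in reversed order and appends a `-1` sentinel to each, or leaves both pointers null when there are no sub-structures. The diff does not say why the order is reversed, so the sketch below only mirrors the storage layout in plain Python, with `None` standing in for `nullptr`. The helper names `pack_sub_structures` and `iter_sub_structures` are hypothetical and do not exist in MLC.

```python
def pack_sub_structures(indices, kinds):
    """Mirror the layout written by SetStructure: reversed copy plus -1 sentinel."""
    if len(indices) != len(kinds):
        raise ValueError("indices and kinds must have the same length")
    if not indices:
        return None, None  # corresponds to the nullptr branch
    packed_indices = list(reversed(indices)) + [-1]
    packed_kinds = list(reversed(kinds)) + [-1]
    return packed_indices, packed_kinds


def iter_sub_structures(packed_indices, packed_kinds):
    """Walk the packed arrays until the -1 sentinel, as a consumer might."""
    if packed_indices is None:
        return
    for index, kind in zip(packed_indices, packed_kinds):
        if index == -1:
            break
        yield index, kind


packed = pack_sub_structures([0, 2, 3], [1, 0, 2])
print(packed)                               # ([3, 2, 0, -1], [2, 0, 1, -1])
print(list(iter_sub_structures(*packed)))   # [(3, 2), (2, 0), (0, 1)]
```

A sentinel-terminated array lets a reader iterate without carrying a separate length, which is consistent with the extra slot allocated via `num_sub_structures + 1`.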

include/mlc/base/any.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ struct AnyView : public MLCAny {
3535
}
3636
/***** Misc *****/
3737
bool defined() const { return this->type_index != static_cast<int32_t>(MLCTypeIndex::kMLCNone); }
38+
const char *GetTypeKey() const { return ::mlc::base::TypeIndex2TypeInfo(this->type_index)->type_key; }
39+
int32_t GetTypeIndex() const { return this->type_index; }
3840
Str str() const;
3941
friend std::ostream &operator<<(std::ostream &os, const AnyView &src);
4042

@@ -76,6 +78,8 @@ struct Any : public MLCAny {
7678
}
7779
/***** Misc *****/
7880
bool defined() const { return this->type_index != static_cast<int32_t>(MLCTypeIndex::kMLCNone); }
81+
const char *GetTypeKey() const { return ::mlc::base::TypeIndex2TypeInfo(this->type_index)->type_key; }
82+
int32_t GetTypeIndex() const { return this->type_index; }
7983
Str str() const;
8084
friend std::ostream &operator<<(std::ostream &os, const Any &src);
8185
