Skip to content

Commit 4434b87

Browse files
committed
feat(AST): Introduce mlc.printer
1 parent cd668fe commit 4434b87

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+7313
-2016
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on: [push, pull_request]
44
env:
55
CIBW_BUILD_VERBOSITY: 3
66
CIBW_TEST_REQUIRES: "pytest"
7-
CIBW_TEST_COMMAND: "pytest -svv {project}/tests/python/"
7+
CIBW_TEST_COMMAND: "pytest -svv --durations=20 {project}/tests/python/"
88
MLC_CIBW_VERSION: "2.20.0"
99
MLC_PYTHON_VERSION: "3.9"
1010
MLC_CIBW_WIN_BUILD: "cp39-win_amd64"

CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)
22

33
project(
44
mlc
5-
VERSION 0.0.12
5+
VERSION 0.0.13
66
DESCRIPTION "MLC-Python"
77
LANGUAGES C CXX
88
)
@@ -41,6 +41,9 @@ if (MLC_BUILD_REGISTRY)
4141
add_library(mlc_registry_objs OBJECT
4242
"${CMAKE_CURRENT_SOURCE_DIR}/cpp/c_api.cc"
4343
"${CMAKE_CURRENT_SOURCE_DIR}/cpp/c_api_tests.cc"
44+
"${CMAKE_CURRENT_SOURCE_DIR}/cpp/printer.cc"
45+
"${CMAKE_CURRENT_SOURCE_DIR}/cpp/json.cc"
46+
"${CMAKE_CURRENT_SOURCE_DIR}/cpp/structure.cc"
4447
"${CMAKE_CURRENT_SOURCE_DIR}/cpp/traceback.cc"
4548
"${CMAKE_CURRENT_SOURCE_DIR}/cpp/traceback_win.cc"
4649
)

README.md

Lines changed: 103 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88
* [:key: Key Features](#key-key-features)
99
+ [:building_construction: MLC Dataclass](#building_construction-mlc-dataclass)
1010
+ [:dart: Structure-Aware Tooling](#dart-structure-aware-tooling)
11-
+ [:snake: Text Formats in Python AST](#snake-text-formats-in-python-ast)
11+
+ [:snake: Text Formats in Python](#snake-text-formats-in-python)
1212
+ [:zap: Zero-Copy Interoperability with C++ Plugins](#zap-zero-copy-interoperability-with-c-plugins)
1313
* [:fuelpump: Development](#fuelpump-development)
1414
+ [:gear: Editable Build](#gear-editable-build)
1515
+ [:ferris_wheel: Create Wheels](#ferris_wheel-create-wheels)
1616

1717

18-
MLC is a Python-first toolkit that makes it more ergonomic to build AI compilers, runtimes, and compound AI systems with Pythonic dataclass, rich tooling infra and zero-copy interoperability with C++ plugins.
18+
MLC is a Python-first toolkit that streamlines the development of AI compilers, runtimes, and compound AI systems with its Pythonic dataclasses, structure-aware tooling, and Python-based text formats.
19+
20+
Beyond pure Python, MLC natively supports zero-copy interoperation with C++ plugins, and enables a smooth engineering practice transitioning from Python to hybrid or Python-free development.
1921

2022
## :inbox_tray: Installation
2123

@@ -41,7 +43,7 @@ class MyClass(mlcd.PyClass):
4143
instance = MyClass(12, "test", c=None)
4244
```
4345

44-
**Type safety**. MLC dataclass checks type strictly in Cython and C++.
46+
**Type safety**. MLC dataclass enforces strict type checking using Cython and C++.
4547

4648
```python
4749
>>> instance.c = 10; print(instance)
@@ -68,6 +70,8 @@ demo.MyClass(a=12, b='test', c=None)
6870

6971
An extra `structure` field are used to specify a dataclass's structure, indicating def site and scoping in an IR.
7072

73+
<details><summary> Define a toy IR with `structure`. </summary>
74+
7175
```python
7276
import mlc.dataclasses as mlcd
7377

@@ -92,18 +96,15 @@ class Let(Expr):
9296
body: Expr
9397
```
9498

99+
</details>
100+
95101
**Structural equality**. Member method `eq_s` compares the structural equality (alpha equivalence) of two IRs represented by MLC's structured dataclass.
96102

97103
```python
98-
"""
99-
L1: let z = x + y; z
100-
L2: let x = y + z; x
101-
L3: let z = x + x; z
102-
"""
103104
>>> x, y, z = Var("x"), Var("y"), Var("z")
104-
>>> L1 = Let(rhs=x + y, lhs=z, body=z)
105-
>>> L2 = Let(rhs=y + z, lhs=x, body=x)
106-
>>> L3 = Let(rhs=x + x, lhs=z, body=z)
105+
>>> L1 = Let(rhs=x + y, lhs=z, body=z) # let z = x + y; z
106+
>>> L2 = Let(rhs=y + z, lhs=x, body=x) # let x = y + z; x
107+
>>> L3 = Let(rhs=x + x, lhs=z, body=z) # let z = x + x; z
107108
>>> L1.eq_s(L2)
108109
True
109110
>>> L1.eq_s(L3, assert_mode=True)
@@ -118,9 +119,98 @@ ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent bindi
118119
>>> assert L1_hash != L3_hash
119120
```
120121

121-
### :snake: Text Formats in Python AST
122+
### :snake: Text Formats in Python
122123

123-
TBD
124+
**IR Printer.** By defining an `__ir_print__` method, which converts an IR node to MLC's Python-style AST, MLC's `IRPrinter` handles variable scoping, renaming and syntax highlighting automatically for a text format based on Python syntax.
125+
126+
<details><summary>Defining Python-based text format on a toy IR using `__ir_print__`.</summary>
127+
128+
```python
129+
import mlc.dataclasses as mlcd
130+
import mlc.printer as mlcp
131+
from mlc.printer import ast as mlt
132+
133+
@mlcd.py_class
134+
class Expr(mlcd.PyClass): ...
135+
136+
@mlcd.py_class
137+
class Stmt(mlcd.PyClass): ...
138+
139+
@mlcd.py_class
140+
class Var(Expr):
141+
name: str
142+
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
143+
if not printer.var_is_defined(obj=self):
144+
printer.var_def(obj=self, frame=printer.frames[-1], name=self.name)
145+
return printer.var_get(obj=self)
146+
147+
@mlcd.py_class
148+
class Add(Expr):
149+
lhs: Expr
150+
rhs: Expr
151+
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
152+
lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
153+
rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
154+
return lhs + rhs
155+
156+
@mlcd.py_class
157+
class Assign(Stmt):
158+
lhs: Var
159+
rhs: Expr
160+
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
161+
rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
162+
printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name)
163+
lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
164+
return mlt.Assign(lhs=lhs, rhs=rhs)
165+
166+
@mlcd.py_class
167+
class Func(mlcd.PyClass):
168+
name: str
169+
args: list[Var]
170+
stmts: list[Stmt]
171+
ret: Var
172+
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
173+
with printer.with_frame(mlcp.DefaultFrame()):
174+
for arg in self.args:
175+
printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name)
176+
args: list[mlt.Expr] = [printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args)]
177+
stmts: list[mlt.Expr] = [printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)]
178+
ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"]))
179+
return mlt.Function(
180+
name=mlt.Id(self.name),
181+
args=[mlt.Assign(lhs=arg, rhs=None) for arg in args],
182+
decorators=[],
183+
return_type=None,
184+
body=[*stmts, ret_stmt],
185+
)
186+
187+
# An example IR:
188+
a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
189+
f = Func(
190+
name="f",
191+
args=[a, b, c],
192+
stmts=[
193+
Assign(lhs=d, rhs=Add(a, b)), # d = a + b
194+
Assign(lhs=e, rhs=Add(d, c)), # e = d + c
195+
],
196+
ret=e,
197+
)
198+
```
199+
200+
</details>
201+
202+
Two printer APIs are provided for Python-based text format:
203+
- `mlc.printer.to_python` that converts an IR fragment to Python text, and
204+
- `mlc.printer.print_python` that further renders the text with proper syntax highlighting.
205+
206+
```python
207+
>>> print(mlcp.to_python(f)) # Stringify to Python
208+
def f(a, b, c):
209+
d = a + b
210+
e = d + c
211+
return e
212+
>>> mlcp.print_python(f) # Syntax highlighting
213+
```
124214

125215
### :zap: Zero-Copy Interoperability with C++ Plugins
126216

cpp/c_api.cc

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "./registry.h"
2-
#include <mlc/all.h>
2+
#include <mlc/core/all.h>
33

44
namespace mlc {
55
namespace registry {
@@ -20,27 +20,6 @@ using ::mlc::registry::TypeTable;
2020
namespace {
2121
thread_local Any last_error;
2222
MLC_REGISTER_FUNC("mlc.ffi.LoadDSO").set_body([](std::string name) { TypeTable::Get(nullptr)->LoadDSO(name); });
23-
MLC_REGISTER_FUNC("mlc.core.JSONLoads").set_body([](AnyView json_str) {
24-
if (json_str.type_index == kMLCRawStr) {
25-
return ::mlc::core::JSONLoads(json_str.operator const char *());
26-
} else {
27-
::mlc::Str str = json_str;
28-
return ::mlc::core::JSONLoads(str);
29-
}
30-
});
31-
MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize); // TODO: `AnyView` as function argument
32-
MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) {
33-
if (json_str.type_index == kMLCRawStr) {
34-
return ::mlc::core::Deserialize(json_str.operator const char *());
35-
} else {
36-
return ::mlc::core::Deserialize(json_str.operator ::mlc::Str());
37-
}
38-
});
39-
MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
40-
MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t {
41-
uint64_t ret = ::mlc::core::StructuralHash(obj);
42-
return static_cast<int64_t>(ret);
43-
});
4423
} // namespace
4524

4625
MLC_API MLCAny MLCGetLastError() {

cpp/c_api_tests.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include <mlc/all.h>
1+
#include <mlc/core/all.h>
22

33
namespace mlc {
44
namespace {

include/mlc/core/json.h renamed to cpp/json.cc

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,19 @@
1-
#ifndef MLC_CORE_JSON_H_
2-
#define MLC_CORE_JSON_H_
3-
#include "./field_visitor.h"
4-
#include "./func.h"
5-
#include "./str.h"
6-
#include "./udict.h"
7-
#include "./ulist.h"
81
#include <cstdint>
92
#include <iomanip>
10-
#include <mlc/base/all.h>
3+
#include <mlc/core/all.h>
114
#include <sstream>
125

136
namespace mlc {
147
namespace core {
8+
namespace {
159

1610
mlc::Str Serialize(Any any);
1711
Any Deserialize(const char *json_str, int64_t json_str_len);
1812
Any JSONLoads(const char *json_str, int64_t json_str_len);
1913
MLC_INLINE Any Deserialize(const char *json_str) { return Deserialize(json_str, -1); }
2014
MLC_INLINE Any Deserialize(const Str &json_str) { return Deserialize(json_str->data(), json_str->size()); }
21-
MLC_INLINE Any Deserialize(const std::string &json_str) {
22-
return Deserialize(json_str.data(), static_cast<int64_t>(json_str.size()));
23-
}
2415
MLC_INLINE Any JSONLoads(const char *json_str) { return JSONLoads(json_str, -1); }
2516
MLC_INLINE Any JSONLoads(const Str &json_str) { return JSONLoads(json_str->data(), json_str->size()); }
26-
MLC_INLINE Any JSONLoads(const std::string &json_str) {
27-
return JSONLoads(json_str.data(), static_cast<int64_t>(json_str.size()));
28-
}
2917

3018
inline mlc::Str Serialize(Any any) {
3119
using mlc::base::TypeTraits;
@@ -46,7 +34,7 @@ inline mlc::Str Serialize(Any any) {
4634
// clang-format off
4735
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
4836
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
49-
MLC_INLINE void operator()(MLCTypeField *, Optional<Object> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
37+
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
5038
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
5139
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
5240
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
@@ -119,25 +107,21 @@ inline mlc::Str Serialize(Any any) {
119107
} else {
120108
os->put(',');
121109
}
122-
int32_t type_index = type_info->type_index;
123-
if (type_index == kMLCStr) {
124-
StrObj *str = reinterpret_cast<StrObj *>(object);
110+
if (StrObj *str = object->TryCast<StrObj>()) {
125111
str->PrintEscape(*os);
126112
return;
127113
}
128114
(*os) << '[' << (*get_json_type_index)(type_info->type_key);
129-
if (type_index == kMLCList) {
130-
UListObj *list = reinterpret_cast<UListObj *>(object); // TODO: support Downcast
115+
if (UListObj *list = object->TryCast<UListObj>()) {
131116
for (Any &any : *list) {
132117
emitter(nullptr, &any);
133118
}
134-
} else if (type_index == kMLCDict) {
135-
UDictObj *dict = reinterpret_cast<UDictObj *>(object); // TODO: support Downcast
119+
} else if (UDictObj *dict = object->TryCast<UDictObj>()) {
136120
for (auto &kv : *dict) {
137121
emitter(nullptr, &kv.first);
138122
emitter(nullptr, &kv.second);
139123
}
140-
} else if (type_index == kMLCFunc || type_index == kMLCError) {
124+
} else if (object->IsInstance<FuncObj>() || object->IsInstance<ErrorObj>()) {
141125
MLC_THROW(TypeError) << "Unserializable type: " << object->GetTypeKey();
142126
} else {
143127
VisitFields(object, type_info, emitter);
@@ -182,9 +166,9 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
182166
MLCVTableHandle init_vtable;
183167
MLCVTableGetGlobal(nullptr, "__init__", &init_vtable);
184168
// Step 0. Parse JSON string
185-
UDict json_obj = JSONLoads(json_str, json_str_len).operator UDict(); // TODO: impl "Any -> UDict"
169+
UDict json_obj = JSONLoads(json_str, json_str_len);
186170
// Step 1. type_key => constructors
187-
UList type_keys = json_obj->at(Str("type_keys")).operator UList(); // TODO: impl `UDict::at(Str)`
171+
UList type_keys = json_obj->at("type_keys");
188172
std::vector<Func> constructors;
189173
constructors.reserve(type_keys.size());
190174
for (Str type_key : type_keys) {
@@ -205,7 +189,7 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
205189
return ret;
206190
};
207191
// Step 2. Translate JSON object to objects
208-
UList values = json_obj->at(Str("values")).operator UList(); // TODO: impl `UDict::at(Str)`
192+
UList values = json_obj->at("values");
209193
for (int64_t i = 0; i < values->size(); ++i) {
210194
Any obj = values[i];
211195
if (obj.type_index == kMLCList) {
@@ -492,7 +476,22 @@ inline Any JSONLoads(const char *json_str, int64_t json_str_len) {
492476
return JSONParser{0, json_str_len, json_str}.Parse();
493477
}
494478

479+
MLC_REGISTER_FUNC("mlc.core.JSONLoads").set_body([](AnyView json_str) {
480+
if (json_str.type_index == kMLCRawStr) {
481+
return ::mlc::core::JSONLoads(json_str.operator const char *());
482+
} else {
483+
::mlc::Str str = json_str;
484+
return ::mlc::core::JSONLoads(str);
485+
}
486+
});
487+
MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize);
488+
MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) {
489+
if (json_str.type_index == kMLCRawStr) {
490+
return ::mlc::core::Deserialize(json_str.operator const char *());
491+
} else {
492+
return ::mlc::core::Deserialize(json_str.operator ::mlc::Str());
493+
}
494+
});
495+
} // namespace
495496
} // namespace core
496497
} // namespace mlc
497-
498-
#endif // MLC_CORE_JSON_H_

0 commit comments

Comments
 (0)