Skip to content

Commit 1813692

Browse files
committed
feat(parser): feature completion for v0.1
1 parent 4434b87 commit 1813692

27 files changed

+793
-1262
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)
22

33
project(
44
mlc
5-
VERSION 0.0.13
5+
VERSION 0.0.14
66
DESCRIPTION "MLC-Python"
77
LANGUAGES C CXX
88
)

README.md

Lines changed: 20 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -121,100 +121,47 @@ ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent bindi
121121

122122
### :snake: Text Formats in Python
123123

124-
**IR Printer.** By defining an `__ir_print__` method, which converts an IR node to MLC's Python-style AST, MLC's `IRPrinter` handles variable scoping, renaming and syntax highlighting automatically for a text format based on Python syntax.
124+
**Printer.** MLC converts an IR node to Python AST by looking up the `__ir_print__` method.
125125

126-
<details><summary>Defining Python-based text format on a toy IR using `__ir_print__`.</summary>
126+
**[[Example](https://github.com/mlc-ai/mlc-python/blob/main/python/mlc/testing/toy_ir/ir.py)]**. Copy the toy IR definition to REPL and then create a `Func` node below:
127127

128128
```python
129-
import mlc.dataclasses as mlcd
130-
import mlc.printer as mlcp
131-
from mlc.printer import ast as mlt
132-
133-
@mlcd.py_class
134-
class Expr(mlcd.PyClass): ...
135-
136-
@mlcd.py_class
137-
class Stmt(mlcd.PyClass): ...
138-
139-
@mlcd.py_class
140-
class Var(Expr):
141-
name: str
142-
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
143-
if not printer.var_is_defined(obj=self):
144-
printer.var_def(obj=self, frame=printer.frames[-1], name=self.name)
145-
return printer.var_get(obj=self)
146-
147-
@mlcd.py_class
148-
class Add(Expr):
149-
lhs: Expr
150-
rhs: Expr
151-
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
152-
lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
153-
rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
154-
return lhs + rhs
155-
156-
@mlcd.py_class
157-
class Assign(Stmt):
158-
lhs: Var
159-
rhs: Expr
160-
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
161-
rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
162-
printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name)
163-
lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
164-
return mlt.Assign(lhs=lhs, rhs=rhs)
165-
166-
@mlcd.py_class
167-
class Func(mlcd.PyClass):
168-
name: str
169-
args: list[Var]
170-
stmts: list[Stmt]
171-
ret: Var
172-
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
173-
with printer.with_frame(mlcp.DefaultFrame()):
174-
for arg in self.args:
175-
printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name)
176-
args: list[mlt.Expr] = [printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args)]
177-
stmts: list[mlt.Expr] = [printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)]
178-
ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"]))
179-
return mlt.Function(
180-
name=mlt.Id(self.name),
181-
args=[mlt.Assign(lhs=arg, rhs=None) for arg in args],
182-
decorators=[],
183-
return_type=None,
184-
body=[*stmts, ret_stmt],
185-
)
186-
187-
# An example IR:
188-
a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
189-
f = Func(
190-
name="f",
191-
args=[a, b, c],
129+
>>> a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
130+
>>> f = Func("f", [a, b, c],
192131
stmts=[
193132
Assign(lhs=d, rhs=Add(a, b)), # d = a + b
194133
Assign(lhs=e, rhs=Add(d, c)), # e = d + c
195134
],
196-
ret=e,
197-
)
135+
ret=e)
198136
```
199137

200-
</details>
201-
202-
Two printer APIs are provided for Python-based text format:
203-
- `mlc.printer.to_python` that converts an IR fragment to Python text, and
204-
- `mlc.printer.print_python` that further renders the text with proper syntax highlighting.
138+
- Method `mlc.printer.to_python` converts an IR node to Python-based text;
205139

206140
```python
207141
>>> print(mlcp.to_python(f)) # Stringify to Python
208142
def f(a, b, c):
209143
d = a + b
210144
e = d + c
211145
return e
146+
```
147+
148+
- Method `mlc.printer.print_python` further renders the text with proper syntax highlighting. [[Screenshot](https://raw.githubusercontent.com/gist/potatomashed/5a9b20edbdde1b9a91a360baa6bce9ff/raw/3c68031eaba0620a93add270f8ad7ed2c8724a78/mlc-python-printer.svg)]
149+
150+
```python
212151
>>> mlcp.print_python(f) # Syntax highlighting
213152
```
214153

154+
**AST Parser.** MLC has a concise set of APIs for implementing parser with Python's AST module, including:
155+
- Inspection API that obtains source code of a Python class or function and the variables they capture;
156+
- Variable management APIs that help with proper scoping;
157+
- AST fragment evaluation APIs;
158+
- Error rendering APIs.
159+
160+
**[[Example](https://github.com/mlc-ai/mlc-python/blob/main/python/mlc/testing/toy_ir/parser.py)]**. With MLC APIs, a parser can be implemented with 100 lines of code for the Python text format above defined by `__ir_printer__`.
161+
215162
### :zap: Zero-Copy Interoperability with C++ Plugins
216163

217-
TBD
164+
🚧 Under construction.
218165

219166
## :fuelpump: Development
220167

include/mlc/core/typing.h

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ struct AtomicType : public Type {
122122
};
123123

124124
struct PtrTypeObj : protected MLCTypingPtr {
125-
explicit PtrTypeObj(Type ty) : MLCTypingPtr{} { this->TyMutable() = ty; }
126-
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingPtr::ty)); }
125+
explicit PtrTypeObj(Type ty) : MLCTypingPtr{} { this->TyMut() = ty; }
127126
::mlc::Str __str__() const {
128127
std::ostringstream os;
129128
os << "Ptr[" << this->Ty() << "]";
@@ -138,20 +137,20 @@ struct PtrTypeObj : protected MLCTypingPtr {
138137
MLC_DEF_STATIC_TYPE(PtrTypeObj, TypeObj, MLCTypeIndex::kMLCTypingPtr, "mlc.core.typing.PtrType");
139138

140139
private:
141-
Type &TyMutable() { return reinterpret_cast<Type &>(this->MLCTypingPtr::ty); }
140+
Type &TyMut() { return reinterpret_cast<Type &>(this->MLCTypingPtr::ty); }
141+
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingPtr::ty)); }
142142
};
143143

144144
struct PtrType : public Type {
145145
MLC_DEF_OBJ_REF(PtrType, PtrTypeObj, Type)
146146
.StaticFn("__init__", InitOf<PtrTypeObj, Type>)
147-
.MemFn("_ty", &PtrTypeObj::Ty)
147+
._Field("ty", offsetof(MLCTypingPtr, ty), sizeof(MLCTypingPtr::ty), false, ParseType<Type>())
148148
.MemFn("__str__", &PtrTypeObj::__str__)
149149
.MemFn("__cxx_str__", &PtrTypeObj::__cxx_str__);
150150
};
151151

152152
struct OptionalObj : protected MLCTypingOptional {
153153
explicit OptionalObj(Type ty) : MLCTypingOptional{} { this->TyMutable() = ty; }
154-
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingOptional::ty)); }
155154
::mlc::Str __str__() const {
156155
std::ostringstream os;
157156
os << this->Ty() << " | None";
@@ -167,19 +166,19 @@ struct OptionalObj : protected MLCTypingOptional {
167166

168167
private:
169168
Type &TyMutable() { return reinterpret_cast<Type &>(this->MLCTypingOptional::ty); }
169+
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingOptional::ty)); }
170170
};
171171

172172
struct Optional : public Type {
173173
MLC_DEF_OBJ_REF(Optional, OptionalObj, Type)
174174
.StaticFn("__init__", InitOf<OptionalObj, Type>)
175-
.MemFn("_ty", &OptionalObj::Ty)
175+
._Field("ty", offsetof(MLCTypingOptional, ty), sizeof(MLCTypingOptional::ty), false, ParseType<Type>())
176176
.MemFn("__str__", &OptionalObj::__str__)
177177
.MemFn("__cxx_str__", &OptionalObj::__cxx_str__);
178178
};
179179

180180
struct ListObj : protected MLCTypingList {
181181
explicit ListObj(Type ty) : MLCTypingList{} { this->TyMutable() = ty; }
182-
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingList::ty)); }
183182
::mlc::Str __str__() const {
184183
std::ostringstream os;
185184
os << "list[" << this->Ty() << "]";
@@ -195,47 +194,48 @@ struct ListObj : protected MLCTypingList {
195194

196195
protected:
197196
Type &TyMutable() { return reinterpret_cast<Type &>(this->MLCTypingList::ty); }
197+
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingList::ty)); }
198198
};
199199

200200
struct List : public Type {
201201
MLC_DEF_OBJ_REF(List, ListObj, Type)
202202
.StaticFn("__init__", InitOf<ListObj, Type>)
203-
.MemFn("_ty", &ListObj::Ty)
203+
._Field("ty", offsetof(MLCTypingList, ty), sizeof(MLCTypingList::ty), false, ParseType<Type>())
204204
.MemFn("__str__", &ListObj::__str__)
205205
.MemFn("__cxx_str__", &ListObj::__cxx_str__);
206206
};
207207

208208
struct DictObj : protected MLCTypingDict {
209209
explicit DictObj(Type ty_k, Type ty_v) : MLCTypingDict{} {
210-
this->TyMutableK() = ty_k;
211-
this->TyMutableV() = ty_v;
210+
this->TyKMut() = ty_k;
211+
this->TyVMut() = ty_v;
212212
}
213-
Type key() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->ty_k)); }
214-
Type value() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->ty_v)); }
215213
::mlc::Str __str__() const {
216214
std::ostringstream os;
217-
os << "dict[" << this->key() << ", " << this->value() << "]";
215+
os << "dict[" << this->TyK() << ", " << this->TyV() << "]";
218216
return os.str();
219217
}
220218
::mlc::Str __cxx_str__() const {
221-
::mlc::Str k_str = ::mlc::base::LibState::CxxStr(this->key());
222-
::mlc::Str v_str = ::mlc::base::LibState::CxxStr(this->value());
219+
::mlc::Str k_str = ::mlc::base::LibState::CxxStr(this->TyK());
220+
::mlc::Str v_str = ::mlc::base::LibState::CxxStr(this->TyV());
223221
std::ostringstream os;
224222
os << "::mlc::Dict<" << k_str->data() << ", " << v_str->data() << ">";
225223
return os.str();
226224
}
227225
MLC_DEF_STATIC_TYPE(DictObj, TypeObj, MLCTypeIndex::kMLCTypingDict, "mlc.core.typing.Dict");
228226

229227
protected:
230-
Type &TyMutableK() { return reinterpret_cast<Type &>(this->ty_k); }
231-
Type &TyMutableV() { return reinterpret_cast<Type &>(this->ty_v); }
228+
Type &TyKMut() { return reinterpret_cast<Type &>(this->ty_k); }
229+
Type &TyVMut() { return reinterpret_cast<Type &>(this->ty_v); }
230+
Type TyK() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->ty_k)); }
231+
Type TyV() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->ty_v)); }
232232
};
233233

234234
struct Dict : public Type {
235235
MLC_DEF_OBJ_REF(Dict, DictObj, Type)
236236
.StaticFn("__init__", InitOf<DictObj, Type, Type>)
237-
.MemFn("_key", &DictObj::key)
238-
.MemFn("_value", &DictObj::value)
237+
._Field("ty_k", offsetof(MLCTypingDict, ty_k), sizeof(MLCTypingDict::ty_k), false, ParseType<Type>())
238+
._Field("ty_v", offsetof(MLCTypingDict, ty_v), sizeof(MLCTypingDict::ty_v), false, ParseType<Type>())
239239
.MemFn("__str__", &DictObj::__str__)
240240
.MemFn("__cxx_str__", &DictObj::__cxx_str__);
241241
};

include/mlc/core/utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ struct ReflectionHelper {
128128
return *this;
129129
}
130130

131+
inline ReflectionHelper &_Field(const char *name, int64_t field_offset, int32_t num_bytes, bool frozen, Any ty) {
132+
this->any_pool.push_back(ty);
133+
int32_t index = static_cast<int32_t>(this->fields.size());
134+
this->fields.emplace_back(MLCTypeField{name, index, field_offset, num_bytes, frozen, ty.v.v_obj});
135+
return *this;
136+
}
137+
131138
template <typename Callable> inline ReflectionHelper &MemFn(const char *name, Callable &&method) {
132139
MLCTypeMethod m = this->PrepareMethod(name, std::forward<Callable>(method));
133140
m.kind = kMemFn;

include/mlc/printer/ir_printer.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ struct IRPrinterObj : public Object {
4949

5050
bool VarIsDefined(const ObjectRef &obj) { return obj2info->count(obj) > 0; }
5151

52-
Id VarDef(const ObjectRef &obj, const ObjectRef &frame, Str name_hint) {
52+
Id VarDef(Str name_hint, const ObjectRef &obj, const Optional<ObjectRef> &frame) {
5353
if (auto it = obj2info.find(obj); it != obj2info.end()) {
5454
Optional<Str> name = (*it).second->name;
5555
return Id(name.value());
@@ -66,18 +66,19 @@ struct IRPrinterObj : public Object {
6666
name = name_hint.ToStdString() + '_' + std::to_string(i);
6767
}
6868
defined_names->Set(name, 1);
69-
this->_VarDef(obj, frame, VarInfo(name, Func([name]() { return Id(name); })));
69+
this->_VarDef(VarInfo(name, Func([name]() { return Id(name); })), obj, frame);
7070
return Id(name);
7171
}
7272

73-
void VarDefNoName(const ObjectRef &obj, const ObjectRef &frame, const Func &creator) {
73+
void VarDefNoName(const Func &creator, const ObjectRef &obj, const Optional<ObjectRef> &frame) {
7474
if (obj2info.count(obj) > 0) {
7575
MLC_THROW(KeyError) << "Variable already defined: " << obj;
7676
}
77-
this->_VarDef(obj, frame, VarInfo(mlc::Null, creator));
77+
this->_VarDef(VarInfo(mlc::Null, creator), obj, frame);
7878
}
7979

80-
void _VarDef(const ObjectRef &obj, const ObjectRef &frame, VarInfo var_info) {
80+
void _VarDef(VarInfo var_info, const ObjectRef &obj, const Optional<ObjectRef> &_frame) {
81+
ObjectRef frame = _frame.defined() ? _frame.value() : this->frames.back().operator ObjectRef();
8182
obj2info->Set(obj, var_info);
8283
auto it = frame_vars.find(frame);
8384
if (it == frame_vars.end()) {
@@ -99,7 +100,7 @@ struct IRPrinterObj : public Object {
99100
obj2info.erase(it);
100101
}
101102

102-
Optional<Id> VarGet(const ObjectRef &obj) {
103+
Optional<Expr> VarGet(const ObjectRef &obj) {
103104
auto it = obj2info.find(obj);
104105
if (it == obj2info.end()) {
105106
return Null;

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
[project]
22
name = "mlc-python"
3-
version = "0.0.13"
3+
version = "0.0.14"
44
dependencies = [
55
'numpy >= 1.22',
6-
"ml-dtypes >= 0.1",
6+
'ml-dtypes >= 0.1',
77
'Pygments>=2.4.0',
8+
'colorama',
89
'setuptools ; platform_system == "Windows"',
910
]
1011
description = ""

python/mlc/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import _cython, ast, cc, dataclasses, printer
1+
from . import _cython, cc, dataclasses, parser, printer
22
from ._cython import Ptr, Str
33
from .core import DataType, Device, Dict, Error, Func, List, Object, ObjectPath, typing
44
from .dataclasses import PyClass, c_class, py_class

python/mlc/_cython/core.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,12 +1360,16 @@ def make_mlc_init(list fields):
13601360
cdef tuple setters = _setters
13611361
cdef int32_t num_args = len(args)
13621362
cdef int32_t i = 0
1363+
cdef object e = None
13631364
assert num_args == len(setters)
13641365
while i < num_args:
13651366
try:
13661367
setters[i](self, args[i])
1367-
except Exception as e: # no-cython-lint
1368-
raise ValueError(f"Failed to set field `{fields[i].name}`: {str(e)}. Got: {args[i]}")
1368+
except Exception as _e: # no-cython-lint
1369+
e = ValueError(f"Failed to set field `{fields[i].name}`: {str(_e)}. Got: {args[i]}")
1370+
e = e.with_traceback(_e.__traceback__)
1371+
if e is not None:
1372+
raise e
13691373
i += 1
13701374

13711375
return _mlc_init

python/mlc/ast/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

0 commit comments

Comments
 (0)