Skip to content

Commit fba2689

Browse files
committed
feat(dataclass): Support copy and deepcopy (#4)
This PR introduces support for Python's native `__copy__` and `__deepcopy__` method for all MLC dataclasses. This is done by field visitor and topo visitor.
1 parent 73bbe0f commit fba2689

File tree

14 files changed

+514
-54
lines changed

14 files changed

+514
-54
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ jobs:
2020
with:
2121
python-version: ${{ env.MLC_PYTHON_VERSION }}
2222
- uses: pre-commit/[email protected]
23+
- uses: ytanikin/[email protected]
24+
if: github.event_name == 'pull_request'
25+
with:
26+
task_types: '["feat", "fix", "ci", "chore", "test"]'
27+
add_label: 'false'
2328
windows:
2429
name: Windows
2530
runs-on: windows-latest

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# See https://pre-commit.com for more information
22
# See https://pre-commit.com/hooks.html for more hooks
3+
default_install_hook_types:
4+
- pre-commit
5+
- commit-msg
36
repos:
47
- repo: https://github.com/pre-commit/pre-commit-hooks
58
rev: v5.0.0

cpp/json.cc

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,30 @@ inline mlc::Str Serialize(Any any) {
3131
using TObj2Idx = std::unordered_map<Object *, int32_t>;
3232
using TJsonTypeIndex = decltype(get_json_type_index);
3333
struct Emitter {
34+
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
3435
// clang-format off
35-
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
36-
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
37-
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
38-
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
39-
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
40-
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
41-
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); }
42-
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast<int64_t>(*v)); }
43-
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast<int64_t>(*v)); }
44-
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast<int64_t>(*v)); }
45-
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast<int64_t>(*v)); }
46-
MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast<double>(*v)); }
47-
MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast<double>(*v)); }
48-
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); }
49-
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); }
50-
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
51-
MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
52-
MLC_INLINE void operator()(MLCTypeField *, const char **) { MLC_THROW(TypeError) << "Unserializable type: const char *"; }
36+
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
37+
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
38+
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
39+
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
40+
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
41+
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); }
5342
// clang-format on
43+
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast<int64_t>(*v)); }
44+
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast<int64_t>(*v)); }
45+
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast<int64_t>(*v)); }
46+
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast<int64_t>(*v)); }
47+
MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast<double>(*v)); }
48+
MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast<double>(*v)); }
49+
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); }
50+
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); }
51+
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *) {
52+
MLC_THROW(TypeError) << "Unserializable type: void *";
53+
}
54+
MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
55+
MLC_INLINE void operator()(MLCTypeField *, const char **) {
56+
MLC_THROW(TypeError) << "Unserializable type: const char *";
57+
}
5458
inline void EmitNil() { (*os) << ", null"; }
5559
inline void EmitFloat(double v) { (*os) << ", " << std::fixed << std::setprecision(19) << v; }
5660
inline void EmitInt(int64_t v) {
@@ -98,10 +102,17 @@ inline mlc::Str Serialize(Any any) {
98102
const TObj2Idx *obj2index;
99103
};
100104

105+
std::unordered_map<Object *, int32_t> topo_indices;
101106
std::ostringstream os;
102-
auto on_visit = [get_json_type_index = &get_json_type_index, os = &os, is_first_object = true](
103-
Object *object, MLCTypeInfo *type_info, const TObj2Idx &obj2index) mutable -> void {
104-
Emitter emitter{os, get_json_type_index, &obj2index};
107+
auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os,
108+
is_first_object = true](Object *object, MLCTypeInfo *type_info) mutable -> void {
109+
int32_t &topo_index = topo_indices[object];
110+
if (topo_index == 0) {
111+
topo_index = static_cast<int32_t>(topo_indices.size()) - 1;
112+
} else {
113+
MLC_THROW(InternalError) << "This should never happen: object already visited";
114+
}
115+
Emitter emitter{os, get_json_type_index, &topo_indices};
105116
if (is_first_object) {
106117
is_first_object = false;
107118
} else {
@@ -163,29 +174,23 @@ inline mlc::Str Serialize(Any any) {
163174
}
164175

165176
inline Any Deserialize(const char *json_str, int64_t json_str_len) {
166-
MLCVTableHandle init_vtable;
167-
MLCVTableGetGlobal(nullptr, "__init__", &init_vtable);
177+
MLCVTableHandle init_table = ::mlc::base::LibState::init;
168178
// Step 0. Parse JSON string
169179
UDict json_obj = JSONLoads(json_str, json_str_len);
170180
// Step 1. type_key => constructors
171181
UList type_keys = json_obj->at("type_keys");
172-
std::vector<Func> constructors;
182+
std::vector<FuncObj *> constructors;
173183
constructors.reserve(type_keys.size());
174184
for (Str type_key : type_keys) {
175-
Any init_func;
176185
int32_t type_index = ::mlc::base::TypeKey2TypeIndex(type_key->data());
177-
MLCVTableGetFunc(init_vtable, type_index, false, &init_func);
178-
if (!::mlc::base::IsTypeIndexNone(init_func.type_index)) {
179-
constructors.push_back(init_func.operator Func());
180-
} else {
181-
MLC_THROW(InternalError) << "Method `__init__` is not defined for type " << type_key;
182-
}
186+
FuncObj *func = ::mlc::base::LibState::VTableGetFunc(init_table, type_index, "__init__");
187+
constructors.push_back(func);
183188
}
184189
auto invoke_init = [&constructors](UList args) {
185190
int32_t json_type_index = args[0];
186191
Any ret;
187-
::mlc::base::FuncCall(constructors.at(json_type_index).get(), static_cast<int32_t>(args.size()) - 1,
188-
args->data() + 1, &ret);
192+
::mlc::base::FuncCall(constructors.at(json_type_index), static_cast<int32_t>(args.size()) - 1, args->data() + 1,
193+
&ret);
189194
return ret;
190195
};
191196
// Step 2. Translate JSON object to objects

cpp/structure.cc

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "mlc/core/error.h"
12
#include <algorithm>
23
#include <cmath>
34
#include <cstdint>
@@ -532,11 +533,141 @@ inline uint64_t StructuralHash(Object *obj) {
532533
#undef MLC_CORE_HASH_S_POD
533534
#undef MLC_CORE_HASH_S_ANY
534535

536+
inline Any CopyShallow(AnyView source) {
537+
int32_t type_index = source.type_index;
538+
if (::mlc::base::IsTypeIndexPOD(type_index)) {
539+
return source;
540+
} else if (UListObj *list = source.TryCast<UListObj>()) {
541+
return UList(list->begin(), list->end());
542+
} else if (UDictObj *dict = source.TryCast<UDictObj>()) {
543+
return UDict(dict->begin(), dict->end());
544+
} else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>()) {
545+
return source;
546+
}
547+
struct Copier {
548+
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { fields->push_back(AnyView(*any)); }
549+
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { fields->push_back(AnyView(*obj)); }
550+
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { fields->push_back(AnyView(*opt)); }
551+
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { fields->push_back(AnyView(*opt)); }
552+
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { fields->push_back(AnyView(*opt)); }
553+
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { fields->push_back(AnyView(*opt)); }
554+
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { fields->push_back(AnyView(*opt)); }
555+
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); }
556+
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); }
557+
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); }
558+
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); }
559+
MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); }
560+
MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); }
561+
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); }
562+
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); }
563+
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *v) { fields->push_back(AnyView(*v)); }
564+
MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); }
565+
MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); }
566+
std::vector<AnyView> *fields;
567+
};
568+
FuncObj *init_func = ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_index, "__init__");
569+
MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(type_index);
570+
std::vector<AnyView> fields;
571+
VisitFields(source.operator Object *(), type_info, Copier{&fields});
572+
Any ret;
573+
::mlc::base::FuncCall(init_func, static_cast<int32_t>(fields.size()), fields.data(), &ret);
574+
return ret;
575+
}
576+
577+
inline Any CopyDeep(AnyView source) {
578+
if (::mlc::base::IsTypeIndexPOD(source.type_index)) {
579+
return source;
580+
}
581+
struct Copier {
582+
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { HandleAny(any); }
583+
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *ref) {
584+
if (const Object *obj = ref->get()) {
585+
HandleObject(obj);
586+
} else {
587+
fields->push_back(AnyView());
588+
}
589+
}
590+
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) {
591+
if (const Object *obj = opt->get()) {
592+
HandleObject(obj);
593+
} else {
594+
fields->push_back(AnyView());
595+
}
596+
}
597+
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { fields->push_back(AnyView(*opt)); }
598+
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { fields->push_back(AnyView(*opt)); }
599+
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { fields->push_back(AnyView(*opt)); }
600+
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { fields->push_back(AnyView(*opt)); }
601+
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); }
602+
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); }
603+
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); }
604+
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); }
605+
MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); }
606+
MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); }
607+
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); }
608+
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); }
609+
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *v) { fields->push_back(AnyView(*v)); }
610+
MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); }
611+
MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); }
612+
613+
void HandleObject(const Object *obj) {
614+
if (auto it = orig2copy->find(obj); it != orig2copy->end()) {
615+
fields->push_back(AnyView(it->second));
616+
} else {
617+
MLC_THROW(InternalError) << "InternalError: object doesn't exist in the memo: " << AnyView(obj);
618+
}
619+
}
620+
621+
void HandleAny(const Any *any) {
622+
if (const Object *obj = any->TryCast<Object>()) {
623+
HandleObject(obj);
624+
} else {
625+
fields->push_back(AnyView(*any));
626+
}
627+
}
628+
629+
std::unordered_map<const Object *, ObjectRef> *orig2copy;
630+
std::vector<AnyView> *fields;
631+
};
632+
std::unordered_map<const Object *, ObjectRef> orig2copy;
633+
std::vector<AnyView> fields;
634+
TopoVisit(source.operator Object *(), nullptr, [&](Object *object, MLCTypeInfo *type_info) mutable -> void {
635+
Any ret;
636+
if (UListObj *list = object->TryCast<UListObj>()) {
637+
fields.clear();
638+
fields.reserve(list->size());
639+
for (Any &e : *list) {
640+
Copier{&orig2copy, &fields}.HandleAny(&e);
641+
}
642+
UList::FromAnyTuple(static_cast<int32_t>(fields.size()), fields.data(), &ret);
643+
} else if (UDictObj *dict = object->TryCast<UDictObj>()) {
644+
fields.clear();
645+
for (auto [key, value] : *dict) {
646+
Copier{&orig2copy, &fields}.HandleAny(&key);
647+
Copier{&orig2copy, &fields}.HandleAny(&value);
648+
}
649+
UDict::FromAnyTuple(static_cast<int32_t>(fields.size()), fields.data(), &ret);
650+
} else if (object->IsInstance<StrObj>() || object->IsInstance<ErrorObj>() || object->IsInstance<FuncObj>()) {
651+
ret = object;
652+
} else {
653+
fields.clear();
654+
VisitFields(object, type_info, Copier{&orig2copy, &fields});
655+
FuncObj *func =
656+
::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_info->type_index, "__init__");
657+
::mlc::base::FuncCall(func, static_cast<int32_t>(fields.size()), fields.data(), &ret);
658+
}
659+
orig2copy[object] = ret.operator ObjectRef();
660+
});
661+
return orig2copy.at(source.operator Object *());
662+
}
663+
535664
MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
536665
MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t {
537666
uint64_t ret = ::mlc::core::StructuralHash(obj);
538667
return static_cast<int64_t>(ret);
539668
});
669+
MLC_REGISTER_FUNC("mlc.core.CopyShallow").set_body(::mlc::core::CopyShallow);
670+
MLC_REGISTER_FUNC("mlc.core.CopyDeep").set_body(::mlc::core::CopyDeep);
540671
} // namespace
541672
} // namespace core
542673
} // namespace mlc

include/mlc/base/all.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ template <typename T> MLC_INLINE AnyView::AnyView(Ref<T> &&src) : AnyView(static
5858
// `src` is not reset here because `AnyView` does not take ownership of the object
5959
}
6060

61-
template <typename T> MLC_INLINE AnyView::AnyView(const Optional<T> &src) {
61+
template <typename T> MLC_INLINE AnyView::AnyView(const Optional<T> &src) : MLCAny() {
6262
if (const auto *value = src.get()) {
6363
if constexpr (::mlc::base::IsPOD<T>) {
6464
using TPOD = T;

include/mlc/base/utils.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,11 @@ struct LibState {
387387
DecRef(func.v.v_obj);
388388
}
389389
FuncObj *ret = reinterpret_cast<FuncObj *>(func.v.v_obj);
390-
if (func.type_index != kMLCFunc) {
390+
if (func.type_index == kMLCNone) {
391+
MLC_THROW(TypeError) << "Function `" << vtable_name
392+
<< "` for type: " << ::mlc::base::TypeIndex2TypeKey(type_index)
393+
<< " is not defined in the vtable";
394+
} else if (func.type_index != kMLCFunc) {
391395
MLC_THROW(TypeError) << "Function `" << vtable_name
392396
<< "` for type: " << ::mlc::base::TypeIndex2TypeKey(type_index)
393397
<< " is not callable. Its type is " << ::mlc::base::TypeIndex2TypeKey(func.type_index);
@@ -401,6 +405,7 @@ struct LibState {
401405
static MLC_SYMBOL_HIDE inline MLCVTableHandle cxx_str = VTableGetGlobal("__cxx_str__");
402406
static MLC_SYMBOL_HIDE inline MLCVTableHandle str = VTableGetGlobal("__str__");
403407
static MLC_SYMBOL_HIDE inline MLCVTableHandle ir_print = VTableGetGlobal("__ir_print__");
408+
static MLC_SYMBOL_HIDE inline MLCVTableHandle init = VTableGetGlobal("__init__");
404409
};
405410

406411
} // namespace base

include/mlc/core/dict.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,14 @@ struct UDict : public ObjectRef {
138138
MLC_INLINE const_iterator end() const { return get()->end(); }
139139
MLC_INLINE const_reverse_iterator rbegin() const { return get()->rbegin(); }
140140
MLC_INLINE const_reverse_iterator rend() const { return get()->rend(); }
141+
MLC_INLINE static void FromAnyTuple(int32_t num_args, const AnyView *args, Any *ret) {
142+
::mlc::core::DictBase::Accessor<UDictObj>::New(num_args, args, ret);
143+
}
141144
MLC_DEF_OBJ_REF(UDict, UDictObj, ObjectRef)
142145
.FieldReadOnly("capacity", &MLCDict::capacity)
143146
.FieldReadOnly("size", &MLCDict::size)
144147
.FieldReadOnly("data", &MLCDict::data)
145-
.StaticFn("__init__", ::mlc::core::DictBase::Accessor<UDictObj>::New)
148+
.StaticFn("__init__", FromAnyTuple)
146149
.MemFn("__str__", &UDictObj::__str__)
147150
.MemFn("__getitem__", ::mlc::core::DictBase::Accessor<UDictObj>::GetItem)
148151
.MemFn("__iter_get_key__", ::mlc::core::DictBase::Accessor<UDictObj>::GetKey)

include/mlc/core/field_visitor.h

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,7 @@ template <typename Visitor> inline void VisitStructure(Object *root, MLCTypeInfo
164164
}
165165

166166
inline void TopoVisit(Object *root, std::function<void(Object *object, MLCTypeInfo *type_info)> pre_visit,
167-
std::function<void(Object *object, MLCTypeInfo *type_info,
168-
const std::unordered_map<Object *, int32_t> &topo_indices)>
169-
on_visit) {
167+
std::function<void(Object *object, MLCTypeInfo *type_info)> on_visit) {
170168
struct TopoInfo {
171169
Object *obj;
172170
MLCTypeInfo *type_info;
@@ -271,20 +269,13 @@ inline void TopoVisit(Object *root, std::function<void(Object *object, MLCTypeIn
271269
}
272270
}
273271
// Step 3. Traverse the graph by topological order
274-
std::unordered_map<Object *, int32_t> topo_indices;
275272
size_t num_objects = 0;
276273
for (; !stack.empty(); ++num_objects) {
277274
TopoInfo *current = stack.back();
278275
stack.pop_back();
279-
// Step 3.1. Lable object index
280-
int32_t &topo_index = topo_indices[current->obj];
281-
if (topo_index != 0) {
282-
MLC_THROW(InternalError) << "This should never happen: object already visited";
283-
}
284-
topo_index = static_cast<int32_t>(num_objects);
285-
// Step 3.2. Visit object
286-
on_visit(current->obj, current->type_info, topo_indices);
287-
// Step 3.3. Decrease the dependency count of topo_parents
276+
// Step 3.1. Visit object
277+
on_visit(current->obj, current->type_info);
278+
// Step 3.2. Decrease the dependency count of topo_parents
288279
for (TopoInfo *parent : current->topo_parents) {
289280
if (--parent->topo_deps == 0) {
290281
stack.push_back(parent);

0 commit comments

Comments
 (0)