|
| 1 | +#include "mlc/core/error.h" |
1 | 2 | #include <algorithm>
|
2 | 3 | #include <cmath>
|
3 | 4 | #include <cstdint>
|
@@ -532,11 +533,141 @@ inline uint64_t StructuralHash(Object *obj) {
|
532 | 533 | #undef MLC_CORE_HASH_S_POD
|
533 | 534 | #undef MLC_CORE_HASH_S_ANY
|
534 | 535 |
|
| 536 | +inline Any CopyShallow(AnyView source) { |
| 537 | + int32_t type_index = source.type_index; |
| 538 | + if (::mlc::base::IsTypeIndexPOD(type_index)) { |
| 539 | + return source; |
| 540 | + } else if (UListObj *list = source.TryCast<UListObj>()) { |
| 541 | + return UList(list->begin(), list->end()); |
| 542 | + } else if (UDictObj *dict = source.TryCast<UDictObj>()) { |
| 543 | + return UDict(dict->begin(), dict->end()); |
| 544 | + } else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>()) { |
| 545 | + return source; |
| 546 | + } |
| 547 | + struct Copier { |
| 548 | + MLC_INLINE void operator()(MLCTypeField *, const Any *any) { fields->push_back(AnyView(*any)); } |
| 549 | + MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { fields->push_back(AnyView(*obj)); } |
| 550 | + MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { fields->push_back(AnyView(*opt)); } |
| 551 | + MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { fields->push_back(AnyView(*opt)); } |
| 552 | + MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { fields->push_back(AnyView(*opt)); } |
| 553 | + MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { fields->push_back(AnyView(*opt)); } |
| 554 | + MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { fields->push_back(AnyView(*opt)); } |
| 555 | + MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); } |
| 556 | + MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); } |
| 557 | + MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); } |
| 558 | + MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); } |
| 559 | + MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); } |
| 560 | + MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); } |
| 561 | + MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); } |
| 562 | + MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); } |
| 563 | + MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *v) { fields->push_back(AnyView(*v)); } |
| 564 | + MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); } |
| 565 | + MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); } |
| 566 | + std::vector<AnyView> *fields; |
| 567 | + }; |
| 568 | + FuncObj *init_func = ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_index, "__init__"); |
| 569 | + MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(type_index); |
| 570 | + std::vector<AnyView> fields; |
| 571 | + VisitFields(source.operator Object *(), type_info, Copier{&fields}); |
| 572 | + Any ret; |
| 573 | + ::mlc::base::FuncCall(init_func, static_cast<int32_t>(fields.size()), fields.data(), &ret); |
| 574 | + return ret; |
| 575 | +} |
| 576 | + |
| 577 | +inline Any CopyDeep(AnyView source) { |
| 578 | + if (::mlc::base::IsTypeIndexPOD(source.type_index)) { |
| 579 | + return source; |
| 580 | + } |
| 581 | + struct Copier { |
| 582 | + MLC_INLINE void operator()(MLCTypeField *, const Any *any) { HandleAny(any); } |
| 583 | + MLC_INLINE void operator()(MLCTypeField *, ObjectRef *ref) { |
| 584 | + if (const Object *obj = ref->get()) { |
| 585 | + HandleObject(obj); |
| 586 | + } else { |
| 587 | + fields->push_back(AnyView()); |
| 588 | + } |
| 589 | + } |
| 590 | + MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { |
| 591 | + if (const Object *obj = opt->get()) { |
| 592 | + HandleObject(obj); |
| 593 | + } else { |
| 594 | + fields->push_back(AnyView()); |
| 595 | + } |
| 596 | + } |
| 597 | + MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { fields->push_back(AnyView(*opt)); } |
| 598 | + MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { fields->push_back(AnyView(*opt)); } |
| 599 | + MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { fields->push_back(AnyView(*opt)); } |
| 600 | + MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { fields->push_back(AnyView(*opt)); } |
| 601 | + MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); } |
| 602 | + MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); } |
| 603 | + MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); } |
| 604 | + MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); } |
| 605 | + MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); } |
| 606 | + MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); } |
| 607 | + MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); } |
| 608 | + MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); } |
| 609 | + MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *v) { fields->push_back(AnyView(*v)); } |
| 610 | + MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); } |
| 611 | + MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); } |
| 612 | + |
| 613 | + void HandleObject(const Object *obj) { |
| 614 | + if (auto it = orig2copy->find(obj); it != orig2copy->end()) { |
| 615 | + fields->push_back(AnyView(it->second)); |
| 616 | + } else { |
| 617 | + MLC_THROW(InternalError) << "InternalError: object doesn't exist in the memo: " << AnyView(obj); |
| 618 | + } |
| 619 | + } |
| 620 | + |
| 621 | + void HandleAny(const Any *any) { |
| 622 | + if (const Object *obj = any->TryCast<Object>()) { |
| 623 | + HandleObject(obj); |
| 624 | + } else { |
| 625 | + fields->push_back(AnyView(*any)); |
| 626 | + } |
| 627 | + } |
| 628 | + |
| 629 | + std::unordered_map<const Object *, ObjectRef> *orig2copy; |
| 630 | + std::vector<AnyView> *fields; |
| 631 | + }; |
| 632 | + std::unordered_map<const Object *, ObjectRef> orig2copy; |
| 633 | + std::vector<AnyView> fields; |
| 634 | + TopoVisit(source.operator Object *(), nullptr, [&](Object *object, MLCTypeInfo *type_info) mutable -> void { |
| 635 | + Any ret; |
| 636 | + if (UListObj *list = object->TryCast<UListObj>()) { |
| 637 | + fields.clear(); |
| 638 | + fields.reserve(list->size()); |
| 639 | + for (Any &e : *list) { |
| 640 | + Copier{&orig2copy, &fields}.HandleAny(&e); |
| 641 | + } |
| 642 | + UList::FromAnyTuple(static_cast<int32_t>(fields.size()), fields.data(), &ret); |
| 643 | + } else if (UDictObj *dict = object->TryCast<UDictObj>()) { |
| 644 | + fields.clear(); |
| 645 | + for (auto [key, value] : *dict) { |
| 646 | + Copier{&orig2copy, &fields}.HandleAny(&key); |
| 647 | + Copier{&orig2copy, &fields}.HandleAny(&value); |
| 648 | + } |
| 649 | + UDict::FromAnyTuple(static_cast<int32_t>(fields.size()), fields.data(), &ret); |
| 650 | + } else if (object->IsInstance<StrObj>() || object->IsInstance<ErrorObj>() || object->IsInstance<FuncObj>()) { |
| 651 | + ret = object; |
| 652 | + } else { |
| 653 | + fields.clear(); |
| 654 | + VisitFields(object, type_info, Copier{&orig2copy, &fields}); |
| 655 | + FuncObj *func = |
| 656 | + ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_info->type_index, "__init__"); |
| 657 | + ::mlc::base::FuncCall(func, static_cast<int32_t>(fields.size()), fields.data(), &ret); |
| 658 | + } |
| 659 | + orig2copy[object] = ret.operator ObjectRef(); |
| 660 | + }); |
| 661 | + return orig2copy.at(source.operator Object *()); |
| 662 | +} |
| 663 | + |
535 | 664 | MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
|
536 | 665 | MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t {
|
537 | 666 | uint64_t ret = ::mlc::core::StructuralHash(obj);
|
538 | 667 | return static_cast<int64_t>(ret);
|
539 | 668 | });
|
| 669 | +MLC_REGISTER_FUNC("mlc.core.CopyShallow").set_body(::mlc::core::CopyShallow); |
| 670 | +MLC_REGISTER_FUNC("mlc.core.CopyDeep").set_body(::mlc::core::CopyDeep); |
540 | 671 | } // namespace
|
541 | 672 | } // namespace core
|
542 | 673 | } // namespace mlc
|
0 commit comments