@@ -538,9 +538,24 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
538
538
} else if (lhs_type_index == kMLCFunc || lhs_type_index == kMLCError ) {
539
539
throw SEqualError (" Cannot compare `mlc.Func` or `mlc.Error`" , new_path);
540
540
} else if (lhs_type_index == kMLCOpaque ) {
541
- std::ostringstream err;
542
- err << " Cannot compare `mlc.Opaque` of type: " << lhs->DynCast <OpaqueObj>()->opaque_type_name ;
543
- throw SEqualError (err.str ().c_str (), new_path);
541
+ std::string func_name = " Opaque.eq_s." ;
542
+ func_name += lhs->DynCast <OpaqueObj>()->opaque_type_name ;
543
+ FuncObj *func = Func::GetGlobal (func_name.c_str (), true );
544
+ if (func == nullptr ) {
545
+ std::ostringstream err;
546
+ err << " Cannot compare `mlc.Opaque` of type: " << lhs->DynCast <OpaqueObj>()->opaque_type_name << " ; Use "
547
+ << " `mlc.Func.register(\" " << func_name << " \" )(eq_s_func)` to register a comparison method" ;
548
+ throw SEqualError (err.str ().c_str (), new_path);
549
+ }
550
+ Any result = (*func)(lhs, rhs);
551
+ if (result.type_index != kMLCBool ) {
552
+ std::ostringstream err;
553
+ err << " Comparison function `" << func_name << " ` must return a boolean value, but got: " << result;
554
+ throw SEqualError (err.str ().c_str (), new_path);
555
+ }
556
+ if (result.operator bool () == false ) {
557
+ MLC_CORE_EQ_S_ERR (lhs, rhs, new_path);
558
+ }
544
559
} else {
545
560
bool visited = false ;
546
561
MLCTypeInfo *type_info = Lib::GetTypeInfo (lhs_type_index);
@@ -802,9 +817,21 @@ inline uint64_t StructuralHashImpl(Object *obj) {
802
817
} else if (type_index == kMLCFunc || type_index == kMLCError ) {
803
818
throw SEqualError (" Cannot compare `mlc.Func` or `mlc.Error`" , ObjectPath::Root ());
804
819
} else if (type_index == kMLCOpaque ) {
805
- std::ostringstream err;
806
- err << " Cannot compare `mlc.Opaque` of type: " << obj->DynCast <OpaqueObj>()->opaque_type_name ;
807
- throw SEqualError (err.str ().c_str (), ObjectPath::Root ());
820
+ std::string func_name = " Opaque.hash_s." ;
821
+ func_name += obj->DynCast <OpaqueObj>()->opaque_type_name ;
822
+ FuncObj *func = Func::GetGlobal (func_name.c_str (), true );
823
+ if (func == nullptr ) {
824
+ MLC_THROW (ValueError) << " Cannot hash `mlc.Opaque` of type: " << obj->DynCast <OpaqueObj>()->opaque_type_name
825
+ << " ; Use `mlc.Func.register(\" " << func_name
826
+ << " \" )(hash_s_func)` to register a hashing method" ;
827
+ }
828
+ Any result = (*func)(obj);
829
+ if (result.type_index != kMLCInt ) {
830
+ MLC_THROW (TypeError) << " Hashing function `" << func_name
831
+ << " ` must return an integer value, but got: " << result;
832
+ }
833
+ int64_t hash_value = result.operator int64_t ();
834
+ EnqueuePOD (tasks, hash_value);
808
835
} else {
809
836
MLCTypeInfo *type_info = Lib::GetTypeInfo (type_index);
810
837
tasks->emplace_back (Task{obj, type_info, false , bind_free_vars, type_info->type_key_hash });
0 commit comments