1
1
#ifndef MLC_BASE_TRAITS_DTYPE_H_
2
2
#define MLC_BASE_TRAITS_DTYPE_H_
3
3
4
+ #include " ./lib.h"
4
5
#include " ./utils.h"
5
- #include < cstdlib>
6
- #include < unordered_map>
7
6
8
7
namespace mlc {
9
8
namespace base {
10
9
11
- inline const char *DLDataTypeCode2Str ( int32_t type_code );
12
- inline DLDataType String2DLDataType ( const std::string &source);
10
+ inline DLDataType DataTypeFromStr ( const char *source );
11
+
13
12
inline bool DataTypeEqual (DLDataType a, DLDataType b) {
14
13
return a.code == b.code && a.bits == b.bits && a.lanes == b.lanes ;
15
14
}
15
+ inline const char *DataTypeCode2Str (int32_t type_code) { return ::mlc::Lib::DataTypeCodeToStr (type_code); }
16
16
17
17
template <> struct TypeTraits <DLDataType> {
18
18
static constexpr int32_t type_index = static_cast <int32_t >(MLCTypeIndex::kMLCDataType );
@@ -29,10 +29,10 @@ template <> struct TypeTraits<DLDataType> {
29
29
return v->v .v_dtype ;
30
30
}
31
31
if (ty == MLCTypeIndex::kMLCRawStr ) {
32
- return String2DLDataType (v->v .v_str );
32
+ return DataTypeFromStr (v->v .v_str );
33
33
}
34
34
if (ty == MLCTypeIndex::kMLCStr ) {
35
- return String2DLDataType (reinterpret_cast <const MLCStr *>(v->v .v_obj )->data );
35
+ return DataTypeFromStr (reinterpret_cast <const MLCStr *>(v->v .v_obj )->data );
36
36
}
37
37
throw TemporaryTypeError ();
38
38
}
@@ -50,107 +50,19 @@ template <> struct TypeTraits<DLDataType> {
50
50
return " void" ;
51
51
}
52
52
std::ostringstream os;
53
- os << DLDataTypeCode2Str (code);
54
- if (code != kDLDataTypeFloat8E5M2 && code != kDLDataTypeFloat8E4M3FN ) {
53
+ os << DataTypeCode2Str (code);
54
+ if (code < kMLCExtension_DLDataTypeCode_Begin ) {
55
+ // for `code >= kMLCExtension_DLDataTypeCode_Begin`, the `bits` is already encoded in `code`
55
56
os << bits;
56
57
}
57
58
if (lanes != 1 ) {
58
59
os << " x" << lanes;
59
60
}
60
61
return os.str ();
61
62
}
62
-
63
- static inline MLC_SYMBOL_HIDE std::unordered_map<std::string, DLDataType> preset = {
64
- {" void" , {kDLOpaqueHandle , 0 , 0 }},
65
- {" bool" , {kDLUInt , 1 , 1 }},
66
- {" int4" , {kDLInt , 4 , 1 }},
67
- {" int8" , {kDLInt , 8 , 1 }},
68
- {" int16" , {kDLInt , 16 , 1 }},
69
- {" int32" , {kDLInt , 32 , 1 }},
70
- {" int64" , {kDLInt , 64 , 1 }},
71
- {" uint4" , {kDLUInt , 4 , 1 }},
72
- {" uint8" , {kDLUInt , 8 , 1 }},
73
- {" uint16" , {kDLUInt , 16 , 1 }},
74
- {" uint32" , {kDLUInt , 32 , 1 }},
75
- {" uint64" , {kDLUInt , 64 , 1 }},
76
- {" float8_e4m3fn" , {kDLDataTypeFloat8E4M3FN , 8 , 1 }},
77
- {" float8_e5m2" , {kDLDataTypeFloat8E5M2 , 8 , 1 }},
78
- {" float16" , {kDLFloat , 16 , 1 }},
79
- {" float32" , {kDLFloat , 32 , 1 }},
80
- {" float64" , {kDLFloat , 64 , 1 }},
81
- {" bfloat16" , {kDLBfloat , 16 , 1 }},
82
- };
83
63
};
84
64
85
- MLC_INLINE const char *DLDataTypeCode2Str (int32_t type_code) {
86
- switch (type_code) {
87
- case kDLInt :
88
- return " int" ;
89
- case kDLUInt :
90
- return " uint" ;
91
- case kDLFloat :
92
- return " float" ;
93
- case kDLOpaqueHandle :
94
- return " ptr" ;
95
- case kDLBfloat :
96
- return " bfloat" ;
97
- case kDLComplex :
98
- return " complex" ;
99
- case kDLBool :
100
- return " bool" ;
101
- case kDLDataTypeFloat8E4M3FN :
102
- return " float8_e4m3fn" ;
103
- case kDLDataTypeFloat8E5M2 :
104
- return " float8_e5m2" ;
105
- }
106
- return " unknown" ;
107
- }
108
-
109
- inline DLDataType String2DLDataType (const std::string &source) {
110
- constexpr int64_t u16_max = 65535 ;
111
- constexpr int64_t u8_max = 255 ;
112
- using Traits = TypeTraits<DLDataType>;
113
- if (auto it = Traits::preset.find (source); it != Traits::preset.end ()) {
114
- return it->second ;
115
- }
116
- try {
117
- int64_t dtype_lanes = 1 ;
118
- std::string dtype_str;
119
- if (size_t x_pos = source.rfind (' x' ); x_pos != std::string::npos) {
120
- dtype_str = source.substr (0 , x_pos);
121
- dtype_lanes = StrToInt (source, x_pos + 1 );
122
- if (dtype_lanes < 0 || dtype_lanes > u16_max) {
123
- throw std::runtime_error (" Invalid DLDataType" );
124
- }
125
- } else {
126
- dtype_str = source;
127
- }
128
- if (dtype_str == " float8_e4m3fn" ) {
129
- return {static_cast <uint8_t >(kDLDataTypeFloat8E4M3FN ), 8 , static_cast <uint16_t >(dtype_lanes)};
130
- }
131
- if (dtype_str == " float8_e5m2" ) {
132
- return {static_cast <uint8_t >(kDLDataTypeFloat8E5M2 ), 8 , static_cast <uint16_t >(dtype_lanes)};
133
- }
134
- #define MLC_DTYPE_PARSE_ (str, prefix, prefix_len, dtype_code ) \
135
- if (str.length () >= prefix_len && str.compare (0 , prefix_len, prefix) == 0 ) { \
136
- int64_t dtype_bits = StrToInt (str, prefix_len); \
137
- if (dtype_bits < 0 || dtype_bits > u8_max) { \
138
- throw std::runtime_error (" Invalid DLDataType" ); \
139
- } \
140
- return {static_cast <uint8_t >(dtype_code), static_cast <uint8_t >(dtype_bits), static_cast <uint16_t >(dtype_lanes)}; \
141
- }
142
- MLC_DTYPE_PARSE_ (dtype_str, " int" , 3 , kDLInt )
143
- MLC_DTYPE_PARSE_ (dtype_str, " uint" , 4 , kDLUInt )
144
- MLC_DTYPE_PARSE_ (dtype_str, " float" , 5 , kDLFloat )
145
- MLC_DTYPE_PARSE_ (dtype_str, " ptr" , 3 , kDLOpaqueHandle )
146
- MLC_DTYPE_PARSE_ (dtype_str, " bfloat" , 6 , kDLBfloat )
147
- MLC_DTYPE_PARSE_ (dtype_str, " complex" , 7 , kDLComplex )
148
- #undef MLC_DTYPE_PARSE_
149
- } catch (...) {
150
- }
151
- MLC_THROW (ValueError) << " Cannot convert to `dtype` from string: " << source;
152
- MLC_UNREACHABLE ();
153
- }
65
+ inline DLDataType DataTypeFromStr (const char *source) { return ::mlc::Lib::DataTypeFromStr (source); }
154
66
155
67
} // namespace base
156
68
} // namespace mlc
0 commit comments