Skip to content

Commit a7b3452

Browse files
authored
feat(core): Interop torch.dtype/torch.device (#6)
This commit supports - dtype conversion: `mlc.DataType` <=> `torch.dtype` <=> `numpy.dtype` - device conversion: `mlc.Device` <=> `torch.device`. - dtype registration: `mlc.DataType.register(name: str, bits: int)` - device registration: `mlc.Device(name: str)`
1 parent 81fb94f commit a7b3452

File tree

16 files changed

+652
-256
lines changed

16 files changed

+652
-256
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: CI
33
on: [push, pull_request]
44
env:
55
CIBW_BUILD_VERBOSITY: 3
6-
CIBW_TEST_REQUIRES: "pytest"
6+
CIBW_TEST_REQUIRES: "pytest torch"
77
CIBW_TEST_COMMAND: "pytest -svv --durations=20 {project}/tests/python/"
88
MLC_CIBW_VERSION: "2.22.0"
99
MLC_PYTHON_VERSION: "3.9"

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@ repos:
1414
- id: check-toml
1515
- id: check-added-large-files
1616
- repo: https://github.com/astral-sh/ruff-pre-commit
17-
rev: v0.8.4
17+
rev: v0.9.0
1818
hooks:
1919
- id: ruff
2020
types_or: [python, pyi, jupyter]
2121
args: [--fix]
2222
- id: ruff-format
2323
types_or: [python, pyi, jupyter]
2424
- repo: https://github.com/pre-commit/mirrors-mypy
25-
rev: "v1.14.0"
25+
rev: "v1.14.1"
2626
hooks:
2727
- id: mypy
28-
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest"]
28+
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "torch"]
2929
args: [--show-error-codes]
3030
- repo: https://github.com/pre-commit/mirrors-clang-format
31-
rev: "v19.1.5"
31+
rev: "v19.1.6"
3232
hooks:
3333
- id: clang-format
3434
- repo: https://github.com/MarcoGorelli/cython-lint

cpp/c_api.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ using ::mlc::registry::TypeTable;
1919

2020
namespace {
2121
thread_local Any last_error;
22-
MLC_REGISTER_FUNC("mlc.ffi.LoadDSO").set_body([](std::string name) { TypeTable::Get(nullptr)->LoadDSO(name); });
2322
} // namespace
2423

2524
MLC_API MLCAny MLCGetLastError() {

cpp/registry.h

Lines changed: 277 additions & 52 deletions
Large diffs are not rendered by default.

include/mlc/base/lib.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ struct Lib {
1111
static ::mlc::Str CxxStr(AnyView obj);
1212
static ::mlc::Str Str(AnyView obj);
1313
static Any IRPrint(AnyView obj, AnyView printer, AnyView path);
14+
static const char *DeviceTypeToStr(int32_t device_type);
15+
static int32_t DeviceTypeFromStr(const char *source);
16+
static void DeviceTypeRegister(const char *name);
17+
static const char *DataTypeCodeToStr(int32_t dtype_code);
18+
static DLDataType DataTypeFromStr(const char *source);
19+
static void DataTypeRegister(const char *name, int32_t dtype_bits);
1420

1521
static FuncObj *_init(int32_t type_index) { return VTableGetFunc(init, type_index, "__init__"); }
1622
MLC_INLINE static MLCTypeInfo *GetTypeInfo(int32_t type_index) {

include/mlc/base/traits_device.h

Lines changed: 15 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
#ifndef MLC_BASE_TRAITS_DEVICE_H_
22
#define MLC_BASE_TRAITS_DEVICE_H_
33

4+
#include "./lib.h"
45
#include "./utils.h"
5-
#include <unordered_map>
66

77
namespace mlc {
88
namespace base {
99

10-
const char *DLDeviceType2Str(DLDeviceType type);
11-
DLDevice String2DLDevice(const std::string &source);
10+
DLDevice DeviceFromStr(const std::string &source);
11+
1212
inline bool DeviceEqual(DLDevice a, DLDevice b) { return a.device_type == b.device_type && a.device_id == b.device_id; }
13+
inline const char *DeviceType2Str(int32_t device_type) { return ::mlc::Lib::DeviceTypeToStr(device_type); }
1314

1415
template <> struct TypeTraits<DLDevice> {
1516
static constexpr int32_t type_index = static_cast<int32_t>(MLCTypeIndex::kMLCDevice);
@@ -26,10 +27,10 @@ template <> struct TypeTraits<DLDevice> {
2627
return v->v.v_device;
2728
}
2829
if (ty == MLCTypeIndex::kMLCRawStr) {
29-
return String2DLDevice(v->v.v_str);
30+
return DeviceFromStr(v->v.v_str);
3031
}
3132
if (ty == MLCTypeIndex::kMLCStr) {
32-
return String2DLDevice(reinterpret_cast<const MLCStr *>(v->v.v_obj)->data);
33+
return DeviceFromStr(reinterpret_cast<const MLCStr *>(v->v.v_obj)->data);
3334
}
3435
throw TemporaryTypeError();
3536
}
@@ -38,87 +39,27 @@ template <> struct TypeTraits<DLDevice> {
3839

3940
MLC_INLINE static std::string __str__(DLDevice device) {
4041
std::ostringstream os;
41-
os << DLDeviceType2Str(static_cast<DLDeviceType>(device.device_type)) << ":" << device.device_id;
42+
os << DeviceType2Str(static_cast<DLDeviceType>(device.device_type)) << ":" << device.device_id;
4243
return os.str();
4344
}
44-
45-
static inline MLC_SYMBOL_HIDE std::unordered_map<std::string, DLDeviceType> str2device_type = {
46-
{"cpu", kDLCPU},
47-
{"cuda", kDLCUDA},
48-
{"cuda_host", kDLCUDAHost},
49-
{"opencl", kDLOpenCL},
50-
{"vulkan", kDLVulkan},
51-
{"mps", kDLMetal},
52-
{"vpi", kDLVPI},
53-
{"rocm", kDLROCM},
54-
{"rocm_host", kDLROCMHost},
55-
{"ext_dev", kDLExtDev},
56-
{"cuda_managed", kDLCUDAManaged},
57-
{"oneapi", kDLOneAPI},
58-
{"webgpu", kDLWebGPU},
59-
{"hexagon", kDLHexagon},
60-
{"maia", kDLMAIA},
61-
// aliases
62-
{"llvm", kDLCPU},
63-
{"nvptx", kDLCUDA},
64-
{"cl", kDLOpenCL},
65-
{"sdaccel", kDLOpenCL},
66-
{"metal", kDLMetal},
67-
};
6845
};
6946

70-
MLC_INLINE const char *DLDeviceType2Str(DLDeviceType type) {
71-
switch (type) {
72-
case kDLCPU:
73-
return "cpu";
74-
case kDLCUDA:
75-
return "cuda";
76-
case kDLCUDAHost:
77-
return "cuda_host";
78-
case kDLOpenCL:
79-
return "opencl";
80-
case kDLVulkan:
81-
return "vulkan";
82-
case kDLMetal:
83-
return "mps";
84-
case kDLVPI:
85-
return "vpi";
86-
case kDLROCM:
87-
return "rocm";
88-
case kDLROCMHost:
89-
return "rocm_host";
90-
case kDLExtDev:
91-
return "ext_dev";
92-
case kDLCUDAManaged:
93-
return "cuda_managed";
94-
case kDLOneAPI:
95-
return "oneapi";
96-
case kDLWebGPU:
97-
return "webgpu";
98-
case kDLHexagon:
99-
return "hexagon";
100-
case kDLMAIA:
101-
return "maia";
102-
}
103-
return "unknown";
104-
}
105-
106-
inline DLDevice String2DLDevice(const std::string &source) {
47+
inline DLDevice DeviceFromStr(const std::string &source) {
10748
constexpr int64_t i32_max = 2147483647;
108-
using Traits = TypeTraits<DLDevice>;
109-
DLDeviceType device_type;
49+
int32_t device_type;
11050
int64_t device_id = 0;
11151
try {
11252
if (size_t c_pos = source.rfind(':'); c_pos != std::string::npos) {
113-
device_type = Traits::str2device_type.at(source.substr(0, c_pos));
53+
device_type = ::mlc::Lib::DeviceTypeFromStr(source.substr(0, c_pos).c_str());
11454
device_id = StrToInt(source, c_pos + 1);
11555
} else {
116-
device_type = Traits::str2device_type.at(source);
56+
device_type = ::mlc::Lib::DeviceTypeFromStr(source.c_str());
57+
device_id = 0;
11758
}
118-
if (device_id < 0 || device_id > i32_max) {
119-
throw std::runtime_error("Invalid device id");
59+
if (device_type < 0 || device_id < 0 || device_id > i32_max) {
60+
throw std::runtime_error(""); // Going to catch it below
12061
}
121-
return DLDevice{device_type, static_cast<int32_t>(device_id)};
62+
return DLDevice{static_cast<DLDeviceType>(device_type), static_cast<int32_t>(device_id)};
12263
} catch (...) {
12364
}
12465
MLC_THROW(ValueError) << "Cannot convert to `Device` from string: " << source;

include/mlc/base/traits_dtype.h

Lines changed: 10 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
#ifndef MLC_BASE_TRAITS_DTYPE_H_
22
#define MLC_BASE_TRAITS_DTYPE_H_
33

4+
#include "./lib.h"
45
#include "./utils.h"
5-
#include <cstdlib>
6-
#include <unordered_map>
76

87
namespace mlc {
98
namespace base {
109

11-
inline const char *DLDataTypeCode2Str(int32_t type_code);
12-
inline DLDataType String2DLDataType(const std::string &source);
10+
inline DLDataType DataTypeFromStr(const char *source);
11+
1312
inline bool DataTypeEqual(DLDataType a, DLDataType b) {
1413
return a.code == b.code && a.bits == b.bits && a.lanes == b.lanes;
1514
}
15+
inline const char *DataTypeCode2Str(int32_t type_code) { return ::mlc::Lib::DataTypeCodeToStr(type_code); }
1616

1717
template <> struct TypeTraits<DLDataType> {
1818
static constexpr int32_t type_index = static_cast<int32_t>(MLCTypeIndex::kMLCDataType);
@@ -29,10 +29,10 @@ template <> struct TypeTraits<DLDataType> {
2929
return v->v.v_dtype;
3030
}
3131
if (ty == MLCTypeIndex::kMLCRawStr) {
32-
return String2DLDataType(v->v.v_str);
32+
return DataTypeFromStr(v->v.v_str);
3333
}
3434
if (ty == MLCTypeIndex::kMLCStr) {
35-
return String2DLDataType(reinterpret_cast<const MLCStr *>(v->v.v_obj)->data);
35+
return DataTypeFromStr(reinterpret_cast<const MLCStr *>(v->v.v_obj)->data);
3636
}
3737
throw TemporaryTypeError();
3838
}
@@ -50,107 +50,19 @@ template <> struct TypeTraits<DLDataType> {
5050
return "void";
5151
}
5252
std::ostringstream os;
53-
os << DLDataTypeCode2Str(code);
54-
if (code != kDLDataTypeFloat8E5M2 && code != kDLDataTypeFloat8E4M3FN) {
53+
os << DataTypeCode2Str(code);
54+
if (code < kMLCExtension_DLDataTypeCode_Begin) {
55+
// for `code >= kMLCExtension_DLDataTypeCode_Begin`, the `bits` is already encoded in `code`
5556
os << bits;
5657
}
5758
if (lanes != 1) {
5859
os << "x" << lanes;
5960
}
6061
return os.str();
6162
}
62-
63-
static inline MLC_SYMBOL_HIDE std::unordered_map<std::string, DLDataType> preset = {
64-
{"void", {kDLOpaqueHandle, 0, 0}},
65-
{"bool", {kDLUInt, 1, 1}},
66-
{"int4", {kDLInt, 4, 1}},
67-
{"int8", {kDLInt, 8, 1}},
68-
{"int16", {kDLInt, 16, 1}},
69-
{"int32", {kDLInt, 32, 1}},
70-
{"int64", {kDLInt, 64, 1}},
71-
{"uint4", {kDLUInt, 4, 1}},
72-
{"uint8", {kDLUInt, 8, 1}},
73-
{"uint16", {kDLUInt, 16, 1}},
74-
{"uint32", {kDLUInt, 32, 1}},
75-
{"uint64", {kDLUInt, 64, 1}},
76-
{"float8_e4m3fn", {kDLDataTypeFloat8E4M3FN, 8, 1}},
77-
{"float8_e5m2", {kDLDataTypeFloat8E5M2, 8, 1}},
78-
{"float16", {kDLFloat, 16, 1}},
79-
{"float32", {kDLFloat, 32, 1}},
80-
{"float64", {kDLFloat, 64, 1}},
81-
{"bfloat16", {kDLBfloat, 16, 1}},
82-
};
8363
};
8464

85-
MLC_INLINE const char *DLDataTypeCode2Str(int32_t type_code) {
86-
switch (type_code) {
87-
case kDLInt:
88-
return "int";
89-
case kDLUInt:
90-
return "uint";
91-
case kDLFloat:
92-
return "float";
93-
case kDLOpaqueHandle:
94-
return "ptr";
95-
case kDLBfloat:
96-
return "bfloat";
97-
case kDLComplex:
98-
return "complex";
99-
case kDLBool:
100-
return "bool";
101-
case kDLDataTypeFloat8E4M3FN:
102-
return "float8_e4m3fn";
103-
case kDLDataTypeFloat8E5M2:
104-
return "float8_e5m2";
105-
}
106-
return "unknown";
107-
}
108-
109-
inline DLDataType String2DLDataType(const std::string &source) {
110-
constexpr int64_t u16_max = 65535;
111-
constexpr int64_t u8_max = 255;
112-
using Traits = TypeTraits<DLDataType>;
113-
if (auto it = Traits::preset.find(source); it != Traits::preset.end()) {
114-
return it->second;
115-
}
116-
try {
117-
int64_t dtype_lanes = 1;
118-
std::string dtype_str;
119-
if (size_t x_pos = source.rfind('x'); x_pos != std::string::npos) {
120-
dtype_str = source.substr(0, x_pos);
121-
dtype_lanes = StrToInt(source, x_pos + 1);
122-
if (dtype_lanes < 0 || dtype_lanes > u16_max) {
123-
throw std::runtime_error("Invalid DLDataType");
124-
}
125-
} else {
126-
dtype_str = source;
127-
}
128-
if (dtype_str == "float8_e4m3fn") {
129-
return {static_cast<uint8_t>(kDLDataTypeFloat8E4M3FN), 8, static_cast<uint16_t>(dtype_lanes)};
130-
}
131-
if (dtype_str == "float8_e5m2") {
132-
return {static_cast<uint8_t>(kDLDataTypeFloat8E5M2), 8, static_cast<uint16_t>(dtype_lanes)};
133-
}
134-
#define MLC_DTYPE_PARSE_(str, prefix, prefix_len, dtype_code) \
135-
if (str.length() >= prefix_len && str.compare(0, prefix_len, prefix) == 0) { \
136-
int64_t dtype_bits = StrToInt(str, prefix_len); \
137-
if (dtype_bits < 0 || dtype_bits > u8_max) { \
138-
throw std::runtime_error("Invalid DLDataType"); \
139-
} \
140-
return {static_cast<uint8_t>(dtype_code), static_cast<uint8_t>(dtype_bits), static_cast<uint16_t>(dtype_lanes)}; \
141-
}
142-
MLC_DTYPE_PARSE_(dtype_str, "int", 3, kDLInt)
143-
MLC_DTYPE_PARSE_(dtype_str, "uint", 4, kDLUInt)
144-
MLC_DTYPE_PARSE_(dtype_str, "float", 5, kDLFloat)
145-
MLC_DTYPE_PARSE_(dtype_str, "ptr", 3, kDLOpaqueHandle)
146-
MLC_DTYPE_PARSE_(dtype_str, "bfloat", 6, kDLBfloat)
147-
MLC_DTYPE_PARSE_(dtype_str, "complex", 7, kDLComplex)
148-
#undef MLC_DTYPE_PARSE_
149-
} catch (...) {
150-
}
151-
MLC_THROW(ValueError) << "Cannot convert to `dtype` from string: " << source;
152-
MLC_UNREACHABLE();
153-
}
65+
inline DLDataType DataTypeFromStr(const char *source) { return ::mlc::Lib::DataTypeFromStr(source); }
15466

15567
} // namespace base
15668
} // namespace mlc

include/mlc/c_api.h

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,30 @@
3030
extern "C" {
3131
#endif
3232

33-
typedef enum {
34-
// TODO: 1) add complete set of fp8 support; 2) allow more flexible dtype definition
35-
kDLDataTypeFloat8E4M3FN = 7,
36-
kDLDataTypeFloat8E5M2 = 8,
37-
} DLDataTypeCodeExtension;
33+
typedef enum { // ranged [0, 2 ** 8)
34+
kMLCExtension_DLDataTypeCode_Begin = 7,
35+
// 8-bit floating point representations
36+
kDLDataTypeFloat8Begin = 7,
37+
kDLDataTypeFloat8E3M4 = 7,
38+
kDLDataTypeFloat8E4M3 = 8,
39+
kDLDataTypeFloat8E4M3B11FNUZ = 9,
40+
kDLDataTypeFloat8E4M3FN = 10,
41+
kDLDataTypeFloat8E4M3FNUZ = 11,
42+
kDLDataTypeFloat8E5M2 = 12,
43+
kDLDataTypeFloat8E5M2FNUZ = 13,
44+
kDLDataTypeFloat8E8M0FNU = 14,
45+
kDLDataTypeFloat8End = 15,
46+
// Microscaling (MX) sub-byte floating point representations
47+
kDLDataTypeFloat4E2M1FN = 15, // higher 4 bits are unused
48+
kDLDataTypeFloat6E2M3FN = 16, // higher 2 bits are unused
49+
kDLDataTypeFloat6E3M2FN = 17, // higher 2 bits are unused
50+
kMLCExtension_DLDataTypeCode_End = kDLDataTypeFloat6E3M2FN,
51+
} MLCExtension_DLDataTypeCode;
52+
53+
typedef enum { // ranged [0, 2 ** 32)
54+
kMLCExtension_DLDeviceType_Begin = 18,
55+
kMLCExtension_DLDeviceType_End = kMLCExtension_DLDeviceType_Begin,
56+
} MLCExtension_DLDeviceType;
3857

3958
#ifdef __cplusplus
4059
enum MLCTypeIndex : int32_t {

0 commit comments

Comments
 (0)