-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
-
+#include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/device_api.h>
 #include <string>
-#include "serve/lora_manager.h"
+#include <iostream>
+#include "lora_manager.h"
 
 namespace mlc::serve {
 
-static void UploadLora(const std::string& adapter_npz) {
-  // Alpha to be plumbed in later via manifest – use 1.0 for now.
-  mlc::serve::LoraManager::Global()->UploadAdapter(adapter_npz, /*alpha=*/1.0f);
-}
+using namespace tvm;
+using namespace tvm::runtime;
 
-}  // namespace mlc::serve
+// REAL TVM FFI registration for LoRA functions
+TVM_FFI_REGISTER_GLOBAL("mlc.get_lora_delta")
+    .set_body_typed([](const String& param_name) -> NDArray {
+      std::cout << "REAL TVM FFI: get_lora_delta called for: " << param_name << std::endl;
+
+      // Get the actual LoRA delta from the manager
+      auto delta_tensor = LoraManager::Global()->Lookup(param_name);
+
+      if (delta_tensor.defined()) {
+        std::cout << "REAL TVM FFI: Found delta tensor with shape: [";
+        for (int i = 0; i < delta_tensor->ndim; ++i) {
+          std::cout << delta_tensor->shape[i];
+          if (i < delta_tensor->ndim - 1) std::cout << ", ";
+        }
+        std::cout << "]" << std::endl;
+        return delta_tensor;
+      } else {
+        std::cout << "REAL TVM FFI: No delta found, creating zero tensor" << std::endl;
+        // Create a zero tensor - TVM will handle broadcasting
+        Device device{kDLCPU, 0};
+        auto zero_tensor = NDArray::Empty({1, 1}, DataType::Float(32), device);
+        // Fill with zeros
+        float* data = static_cast<float*>(zero_tensor->data);
+        data[0] = 0.0f;
+        return zero_tensor;
+      }
+    });
 
-// Expose a getter so Python (and other frontends) can retrieve the materialised
-// delta tensor for a given full parameter name. The returned NDArray may be
-// undefined if the key is missing.
-TVM_REGISTER_GLOBAL("mlc.get_lora_delta").set_body_typed([](const std::string& param_name) {
-  return mlc::serve::LoraManager::Global()->Lookup(param_name);
+TVM_FFI_REGISTER_GLOBAL("mlc.set_active_device")
+    .set_body_typed([](int dev_type, int dev_id) {
+      std::cout << "REAL TVM FFI: set_active_device called: " << dev_type << ", " << dev_id << std::endl;
+      LoraManager::Global()->SetDevice(dev_type, dev_id);
 });
 
-// Called once by Python side to tell C++ what device the runtime operates on.
-TVM_REGISTER_GLOBAL("mlc.set_active_device").set_body_typed([](int dev_type, int dev_id) {
-  mlc::serve::LoraManager::Global()->SetDevice(dev_type, dev_id);
+TVM_FFI_REGISTER_GLOBAL("mlc.serve.UploadLora")
+    .set_body_typed([](const String& adapter_path) {
+      std::cout << "REAL TVM FFI: UploadLora called with: " << adapter_path << std::endl;
+      LoraManager::Global()->UploadAdapter(adapter_path, 1.0f);
 });
 
-// Register with TVM's FFI so that python can call this symbol via
-// `tvm.get_global_func("mlc.serve.UploadLora")`.
-TVM_REGISTER_GLOBAL("mlc.serve.UploadLora")
-    .set_body_typed([](const std::string& adapter_path) {
-      mlc::serve::UploadLora(adapter_path);
-    });
+// Keep the namespace functions for direct C++ access
+void UploadLora(const std::string& adapter_path) {
+  LoraManager::Global()->UploadAdapter(adapter_path, 1.0f);
+}
+
+std::string GetLoraDelta(const std::string& param_name) {
+  auto result = LoraManager::Global()->Lookup(param_name);
+  return result.defined() ? "tensor_found" : "tensor_not_found";
+}
+
+void SetActiveDevice(int dev_type, int dev_id) {
+  LoraManager::Global()->SetDevice(dev_type, dev_id);
+}
+
+}  // namespace mlc::serve
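
Usage note: the globals registered above are meant to be fetched from Python via `tvm.get_global_func`, as the replaced comments indicated. Below is a minimal sketch of how a frontend might drive them, assuming this translation unit is linked into the loaded runtime; the adapter path and parameter name are placeholders, and the device-type code 2 is kDLCUDA from the DLPack enum.

import tvm

# Fetch the globals registered via TVM_FFI_REGISTER_GLOBAL above.
set_active_device = tvm.get_global_func("mlc.set_active_device")
upload_lora = tvm.get_global_func("mlc.serve.UploadLora")
get_lora_delta = tvm.get_global_func("mlc.get_lora_delta")

# Tell the LoraManager which device the runtime operates on
# (kDLCUDA = 2, device id 0).
set_active_device(2, 0)

# Upload an adapter; alpha is currently hard-coded to 1.0f on the C++ side.
upload_lora("/path/to/adapter.npz")  # placeholder path

# Look up the materialised delta for a (hypothetical) parameter name.
# When no adapter entry exists, the C++ fallback returns a [1, 1]
# float32 zero tensor rather than an undefined NDArray.
delta = get_lora_delta("model.layers.0.self_attn.qkv_proj.weight")
print(delta.shape)

The zero-tensor fallback means callers never have to branch on an undefined return value; the shape check happens on the C++ side and broadcasting is left to TVM.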