diff --git a/ppisp/ext.cpp b/ppisp/ext.cpp index 02c5a0e..0e29a2e 100644 --- a/ppisp/ext.cpp +++ b/ppisp/ext.cpp @@ -18,6 +18,9 @@ #include "bindings.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("ppisp_forward", &ppisp_forward_tensor); - m.def("ppisp_backward", &ppisp_backward_tensor); + // Use move policy to ensure proper tensor ownership transfer + m.def("ppisp_forward", &ppisp_forward_tensor, + py::return_value_policy::move); + m.def("ppisp_backward", &ppisp_backward_tensor, + py::return_value_policy::move); }