Skip to content

Commit 0f9fde5

Browse files
committed
feat(python_ffi): 支持从 executor 存取 Device::Blob
Signed-off-by: YdrMaster <[email protected]>
1 parent 5046a53 commit 0f9fde5

File tree

5 files changed

+38
-13
lines changed

5 files changed

+38
-13
lines changed

src/03runtime/include/runtime/stream.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ namespace refactor::runtime {
4242
decltype(_device));
4343

4444
decltype(_graph) const &graph() const noexcept { return _graph; }
45-
void setData(count_t, void const *, size_t);
45+
auto setData(count_t, size_t) -> Arc<hardware::Device::Blob>;
4646
void setData(count_t, Arc<hardware::Device::Blob>);
47-
bool getData(count_t, void *, size_t) const;
47+
auto getData(count_t) -> Arc<hardware::Device::Blob> const;
48+
void setData(count_t, void const *, size_t);
49+
bool copyData(count_t, void *, size_t) const;
4850
void run();
4951
auto bench(void (*sync)()) -> std::vector<std::chrono::nanoseconds>;
5052
void trace(std::function<void(count_t, void const *const *, void const *const *)>);

src/03runtime/src/stream.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,21 @@ namespace refactor::runtime {
1818
std::move(edges),
1919
} {}
2020

21+
auto Stream::setData(count_t i, size_t size) -> Arc<hardware::Device::Blob> {
22+
return _graph.edges[i].blob = _device->malloc(size);
23+
}
24+
void Stream::setData(count_t i, Arc<hardware::Device::Blob> blob) {
25+
_graph.edges[i].blob = std::move(blob);
26+
}
2127
void Stream::setData(count_t i, void const *data, size_t size) {
2228
auto blob = _device->malloc(size);
2329
blob->copyFromHost(data, size);
2430
_graph.edges[i].blob = std::move(blob);
2531
}
26-
void Stream::setData(count_t i, Arc<hardware::Device::Blob> blob) {
27-
_graph.edges[i].blob = std::move(blob);
32+
auto Stream::getData(count_t i) -> Arc<hardware::Device::Blob> const {
33+
return _graph.edges[i].blob;
2834
}
29-
bool Stream::getData(count_t i, void *data, size_t size) const {
35+
bool Stream::copyData(count_t i, void *data, size_t size) const {
3036
if (!_graph.edges[i].blob) { return false; }
3137
_graph.edges[i].blob->copyToHost(data, size);
3238
return true;

src/09python_ffi/src/executor.cc

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace refactor::python_ffi {
2626
for (auto i : graph.topology.globalInputs()) {
2727
auto size = graph.edges[i].tensor->bytesSize();
2828
buffer.resize(size);
29-
if (stream.getData(i, buffer.data(), size)) {
29+
if (stream.copyData(i, buffer.data(), size)) {
3030
_stream.setData(i, buffer.data(), size);
3131
}
3232
}
@@ -35,24 +35,36 @@ namespace refactor::python_ffi {
3535
void Executor::setInput(count_t i, pybind11::array data) {
3636
i = _stream.graph().topology.globalInputs().at(i);
3737

38-
auto const &name = _stream.graph().edges[i].name;
39-
auto const &edges = _graph.internal().contiguous().edges;
40-
auto const &tensor = *std::find_if(edges.begin(), edges.end(), [&](auto const &e) { return e.name == name; })->tensor;
38+
auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
4139
ASSERT(tensor.bytesSize() == static_cast<size_t>(data.nbytes()), "input size mismatch");
4240
_stream.setData(i, data.data(), data.nbytes());
4341
}
4442

4543
auto Executor::getOutput(count_t i) -> pybind11::array {
4644
i = _stream.graph().topology.globalOutputs().at(i);
4745

48-
auto const &name = _stream.graph().edges[i].name;
49-
auto const &edges = _graph.internal().contiguous().edges;
50-
auto const &tensor = *std::find_if(edges.begin(), edges.end(), [&](auto const &e) { return e.name == name; })->tensor;
46+
auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
5147
auto ans = pybind11::array(buildNumpyDType(tensor.dataType), std::move(tensor.shape));
52-
_stream.getData(i, ans.mutable_data(), ans.nbytes());
48+
_stream.copyData(i, ans.mutable_data(), ans.nbytes());
5349
return ans;
5450
}
5551

52+
auto Executor::pin(count_t i) -> Arc<hardware::Device::Blob> {
53+
i = _stream.graph().topology.globalInputs().at(i);
54+
55+
if (auto pinned = _stream.getData(i); pinned) {
56+
return pinned;
57+
} else {
58+
auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
59+
return _stream.setData(i, tensor.bytesSize());
60+
}
61+
}
62+
void Executor::setPinned(count_t i, Arc<hardware::Device::Blob> pinned) {
63+
i = _stream.graph().topology.globalInputs().at(i);
64+
65+
_stream.setData(i, std::move(pinned));
66+
}
67+
5668
void Executor::run() {
5769
_stream.run();
5870
}

src/09python_ffi/src/executor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ namespace refactor::python_ffi {
1616
void dispatch(Arc<hardware::Device>, std::string allocator);
1717
void setInput(count_t, pybind11::array);
1818
auto getOutput(count_t) -> pybind11::array;
19+
auto pin(count_t) -> Arc<hardware::Device::Blob>;
20+
void setPinned(count_t, Arc<hardware::Device::Blob>);
1921
void run();
2022
void bench(bool sync);
2123
void trace(std::string path, std::string format);

src/09python_ffi/src/main.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace refactor::python_ffi {
2121
py::class_<Tensor , Arc<Tensor> >(m, "Tensor" );
2222
py::class_<OpBox , Arc<OpBox> >(m, "Operator" );
2323
py::class_<Device , Arc<Device> >(m, "Device" );
24+
py::class_<Device::Blob, Arc<Device::Blob>>(m, "Pinned" );
2425

2526
m .def("config_log" , &configLog , return_::automatic )
2627
.def("find_device" , &findDevice , return_::move )
@@ -44,6 +45,8 @@ namespace refactor::python_ffi {
4445
.def("dispatch" , &Executor::dispatch , return_::automatic )
4546
.def("set_input" , &Executor::setInput , return_::automatic )
4647
.def("get_output" , &Executor::getOutput , return_::move )
48+
.def("pin" , &Executor::pin , return_::move )
49+
.def("set_pinned" , &Executor::setPinned , return_::automatic )
4750
.def("run" , &Executor::run , return_::automatic )
4851
.def("bench" , &Executor::bench , return_::automatic )
4952
.def("trace" , &Executor::trace , return_::automatic )

0 commit comments

Comments
 (0)