Skip to content

Commit 54c2f7e

Browse files
committed
refactor(python_ffi): 修改从 exector 存取数据块的接口
Signed-off-by: YdrMaster <[email protected]>
1 parent a41a09f commit 54c2f7e

File tree

5 files changed

+19
-21
lines changed

5 files changed

+19
-21
lines changed

src/03runtime/include/runtime/stream.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace refactor::runtime {
4444
decltype(_graph) const &graph() const noexcept { return _graph; }
4545
auto setData(count_t, size_t) -> Arc<hardware::Device::Blob>;
4646
void setData(count_t, Arc<hardware::Device::Blob>);
47-
auto getData(count_t) -> Arc<hardware::Device::Blob> const;
47+
auto getData(count_t) const -> Arc<hardware::Device::Blob>;
4848
void setData(count_t, void const *, size_t);
4949
bool copyData(count_t, void *, size_t) const;
5050
void run();

src/03runtime/src/stream.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace refactor::runtime {
2929
blob->copyFromHost(data, size);
3030
_graph.edges[i].blob = std::move(blob);
3131
}
32-
auto Stream::getData(count_t i) -> Arc<hardware::Device::Blob> const {
32+
auto Stream::getData(count_t i) const -> Arc<hardware::Device::Blob> {
3333
return _graph.edges[i].blob;
3434
}
3535
bool Stream::copyData(count_t i, void *data, size_t size) const {

src/09python_ffi/src/executor.cc

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,15 @@ namespace refactor::python_ffi {
4040
_stream.setData(i, data.data(), data.nbytes());
4141
}
4242

43-
auto Executor::getOutput(count_t i) -> pybind11::array {
43+
void Executor::setInputBlob(count_t i, Arc<hardware::Device::Blob> blob) {
44+
i = _stream.graph().topology.globalInputs().at(i);
45+
46+
auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
47+
ASSERT(tensor.bytesSize() == blob->size(), "input size mismatch");
48+
_stream.setData(i, std::move(blob));
49+
}
50+
51+
auto Executor::getOutput(count_t i) const -> pybind11::array {
4452
i = _stream.graph().topology.globalOutputs().at(i);
4553

4654
auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
@@ -49,20 +57,10 @@ namespace refactor::python_ffi {
4957
return ans;
5058
}
5159

52-
auto Executor::pin(count_t i) -> Arc<hardware::Device::Blob> {
53-
i = _stream.graph().topology.globalInputs().at(i);
54-
55-
if (auto pinned = _stream.getData(i); pinned) {
56-
return pinned;
57-
} else {
58-
auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
59-
return _stream.setData(i, tensor.bytesSize());
60-
}
61-
}
62-
void Executor::setPinned(count_t i, Arc<hardware::Device::Blob> pinned) {
63-
i = _stream.graph().topology.globalInputs().at(i);
60+
auto Executor::getOutputBlob(count_t i) const -> Arc<hardware::Device::Blob> {
61+
i = _stream.graph().topology.globalOutputs().at(i);
6462

65-
_stream.setData(i, std::move(pinned));
63+
return _stream.getData(i);
6664
}
6765

6866
void Executor::run() {

src/09python_ffi/src/executor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ namespace refactor::python_ffi {
1515
Executor(computation::Graph, runtime::Stream);
1616
void dispatch(Arc<hardware::Device>, std::string allocator);
1717
void setInput(count_t, pybind11::array);
18-
auto getOutput(count_t) -> pybind11::array;
19-
auto pin(count_t) -> Arc<hardware::Device::Blob>;
20-
void setPinned(count_t, Arc<hardware::Device::Blob>);
18+
void setInputBlob(count_t, Arc<hardware::Device::Blob>);
19+
auto getOutput(count_t) const -> pybind11::array;
20+
auto getOutputBlob(count_t) const -> Arc<hardware::Device::Blob>;
2121
void run();
2222
void bench(bool sync);
2323
void trace(std::string path, std::string format);

src/09python_ffi/src/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ namespace refactor::python_ffi {
4444
py::class_<Executor , Arc<Executor>>(m, "Executor" )
4545
.def("dispatch" , &Executor::dispatch , return_::automatic )
4646
.def("set_input" , &Executor::setInput , return_::automatic )
47+
.def("set_input_blob" , &Executor::setInputBlob , return_::automatic )
4748
.def("get_output" , &Executor::getOutput , return_::move )
48-
.def("pin" , &Executor::pin , return_::move )
49-
.def("set_pinned" , &Executor::setPinned , return_::automatic )
49+
.def("get_output_blob" , &Executor::getOutputBlob , return_::move )
5050
.def("run" , &Executor::run , return_::automatic )
5151
.def("bench" , &Executor::bench , return_::automatic )
5252
.def("trace" , &Executor::trace , return_::automatic )

0 commit comments

Comments
 (0)