@@ -40,7 +40,15 @@ namespace refactor::python_ffi {
4040 _stream.setData (i, data.data (), data.nbytes ());
4141 }
4242
43- auto Executor::getOutput (count_t i) -> pybind11::array {
43+ void Executor::setInputBlob (count_t i, Arc<hardware::Device::Blob> blob) {
44+ i = _stream.graph ().topology .globalInputs ().at (i);
45+
46+ auto const &tensor = *_graph.internal ().contiguous ().edges [i].tensor ;
47+ ASSERT (tensor.bytesSize () == blob->size (), " input size mismatch" );
48+ _stream.setData (i, std::move (blob));
49+ }
50+
51+ auto Executor::getOutput (count_t i) const -> pybind11::array {
4452 i = _stream.graph ().topology .globalOutputs ().at (i);
4553
4654 auto const &tensor = *_graph.internal ().contiguous ().edges [i].tensor ;
@@ -49,20 +57,10 @@ namespace refactor::python_ffi {
4957 return ans;
5058 }
5159
52- auto Executor::pin (count_t i) -> Arc<hardware::Device::Blob> {
53- i = _stream.graph ().topology .globalInputs ().at (i);
54-
55- if (auto pinned = _stream.getData (i); pinned) {
56- return pinned;
57- } else {
58- auto const &tensor = *_graph.internal ().contiguous ().edges [i].tensor ;
59- return _stream.setData (i, tensor.bytesSize ());
60- }
61- }
62- void Executor::setPinned (count_t i, Arc<hardware::Device::Blob> pinned) {
63- i = _stream.graph ().topology .globalInputs ().at (i);
60+ auto Executor::getOutputBlob (count_t i) const -> Arc<hardware::Device::Blob> {
61+ i = _stream.graph ().topology .globalOutputs ().at (i);
6462
65- _stream.setData (i, std::move (pinned) );
63+ return _stream.getData (i );
6664 }
6765
6866 void Executor::run () {
0 commit comments