|
15 | 15 | #include <vector> |
16 | 16 |
|
17 | 17 | #include "xtensor/xtensor.hpp" |
| 18 | +#include "xtensor/xfixed.hpp" |
18 | 19 |
|
19 | 20 | #include <pybind11/numpy.h> |
20 | 21 | #include <pybind11/pybind11.h> |
@@ -64,6 +65,15 @@ namespace pybind11 |
64 | 65 | } |
65 | 66 | }; |
66 | 67 |
|
| 68 | + template <class T, class FSH, xt::layout_type L> |
| 69 | + struct pybind_array_getter<xt::xtensor_fixed<T, FSH, L>> |
| 70 | + { |
| 71 | + static auto run(handle src) |
| 72 | + { |
| 73 | + return pybind_array_getter_impl<T, L>::run(src); |
| 74 | + } |
| 75 | + }; |
| 76 | + |
67 | 77 | template <class CT, class S, xt::layout_type L, class FST> |
68 | 78 | struct pybind_array_getter<xt::xstrided_view<CT, S, L, FST>> |
69 | 79 | { |
@@ -113,6 +123,37 @@ namespace pybind11 |
113 | 123 | } |
114 | 124 | }; |
115 | 125 |
|
| 126 | + template <class T, class FSH, xt::layout_type L> |
| 127 | + struct pybind_array_dim_checker<xt::xtensor_fixed<T, FSH, L>> |
| 128 | + { |
| 129 | + template <class B> |
| 130 | + static bool run(const B& buf) |
| 131 | + { |
| 132 | + return buf.ndim() == FSH::size(); |
| 133 | + } |
| 134 | + }; |
| 135 | + |
| 136 | + |
| 137 | + template <class T> |
| 138 | + struct pybind_array_shape_checker |
| 139 | + { |
| 140 | + template <class B> |
| 141 | + static bool run(const B& buf) |
| 142 | + { |
| 143 | + return true; |
| 144 | + } |
| 145 | + }; |
| 146 | + |
| 147 | + template <class T, class FSH, xt::layout_type L> |
| 148 | + struct pybind_array_shape_checker<xt::xtensor_fixed<T, FSH, L>> |
| 149 | + { |
| 150 | + template <class B> |
| 151 | + static bool run(const B& buf) |
| 152 | + { |
| 153 | + auto shape = FSH(); |
| 154 | + return std::equal(shape.begin(), shape.end(), buf.shape()); |
| 155 | + } |
| 156 | + }; |
116 | 157 |
|
117 | 158 | // Casts a strided expression type to numpy array.If given a base, |
118 | 159 | // the numpy array references the src data, otherwise it'll make a copy. |
@@ -215,6 +256,11 @@ namespace pybind11 |
215 | 256 | return false; |
216 | 257 | } |
217 | 258 |
|
| 259 | + if (!pybind_array_shape_checker<Type>::run(buf)) |
| 260 | + { |
| 261 | + return false; |
| 262 | + } |
| 263 | + |
218 | 264 | std::vector<size_t> shape(buf.ndim()); |
219 | 265 | std::copy(buf.shape(), buf.shape() + buf.ndim(), shape.begin()); |
220 | 266 | value = Type::from_shape(shape); |
|
0 commit comments