1313#include < ostream>
1414#include < vector>
1515
16+ bool denseCheckAdaptor (diopiSize_t shape, diopiSize_t stride);
17+
1618std::vector<int64_t > calcStrides (diopiSize_t size, diopiMemoryFormat_t format = diopiMemoryFormat_t::Contiguous);
1719
1820bool isLikeChannelsLast (diopiConstTensorHandle_t tensor, bool checkContiguous, diopiMemoryFormat_t format = diopiMemoryFormat_t::ChannelsLast);
@@ -52,7 +54,7 @@ struct RemoveConst<diopiConstTensorHandle_t> {
5254
5355class NoCast {
5456public:
55- static bool getDstDtype (diopiDtype_t srcDtype, diopiDtype_t & targetDtype) {
57+ static bool getDstDtype (diopiDtype_t srcDtype, diopiDtype_t& targetDtype) {
5658 bool convert = false ;
5759 switch (srcDtype) {
5860 default :
@@ -63,7 +65,7 @@ class NoCast {
6365};
6466
6567template <class T , class strategy = NoCast>
66- ConvertType castImpl (diopiContextHandle_t ctx, T src, T * dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats = {}) {
68+ ConvertType castImpl (diopiContextHandle_t ctx, T src, T* dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats = {}) {
6769 ConvertType convertType;
6870 if (!src) {
6971 *dst = src;
@@ -77,13 +79,13 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
7779 strategy::getDstDtype (srcDtype, dstDtype);
7880 std::vector<diopiMemoryFormat_t> targetMemoryFormats = obtainTargetMemoryFormats (srcSize.len , supportMemoryFormats);
7981 diopiTensorHandle_t memoryFormatedTensor = nullptr ;
80-
8182 // convertDtype
8283
8384 diopiDevice_t device;
8485 diopiGetTensorDevice (src, &device);
8586 diopiTensorHandle_t tmp0 = nullptr ;
8687 bool needConvertDtype = srcDtype != dstDtype;
88+
8789 if (needConvertDtype) {
8890 diopiRequireTensor (ctx, &tmp0, &srcSize, &srcStride, dstDtype, device);
8991 diopiCastDtype (ctx, tmp0, src);
@@ -108,6 +110,13 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
108110 }
109111 diopiSize_t dstStride = srcStride;
110112 diopiSize_t dstSize = srcSize;
113+ if (!targetMemoryFormats.empty ()) {
114+ if (!denseCheckAdaptor (srcSize, srcStride) && supportMemoryFormats[0 ] == diopiMemoryFormat_t::Preserve) {
115+ targetMemoryFormats.push_back (diopiMemoryFormat_t::Preserve);
116+ needConvertMemoryFormat = true ;
117+ }
118+ }
119+
111120 if (needConvertMemoryFormat) {
112121 diopiContiguous (ctx, &memoryFormatedTensor, tmp0, targetMemoryFormats[0 ]);
113122 convertType.setMemoryFormatConverted ();
@@ -122,7 +131,7 @@ ConvertType castImpl(diopiContextHandle_t ctx, T src, T *dst, std::vector<diopiM
122131}
123132
124133template <class T , class strategy >
125- ConvertType requireTensorIfMemoryFormatConvert (diopiContextHandle_t ctx, T src, T * dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats) {
134+ ConvertType requireTensorIfMemoryFormatConvert (diopiContextHandle_t ctx, T src, T* dst, std::vector<diopiMemoryFormat_t> supportMemoryFormats) {
126135 ConvertType convertType;
127136 if (!src) {
128137 *dst = src;
@@ -139,6 +148,7 @@ ConvertType requireTensorIfMemoryFormatConvert(diopiContextHandle_t ctx, T src,
139148 if (targetMemoryFormats.empty ()) {
140149 needConvertMemoryFormat = false ;
141150 }
151+
142152 for (auto memoryFormat : targetMemoryFormats) {
143153 if (isContiguous (srcSize, srcStride, memoryFormat)) {
144154 needConvertMemoryFormat = false ;
@@ -174,7 +184,7 @@ ConvertType requireTensorIfMemoryFormatConvert(diopiContextHandle_t ctx, T src,
174184}
175185
176186template <typename Adaptor, typename ... Args>
177- void dispatchDiopi (diopiContextHandle_t ctx, Args &&...args) {
187+ void dispatchDiopi (diopiContextHandle_t ctx, Args&&... args) {
178188 auto adaptor = Adaptor ();
179189 adaptor (ctx, std::forward<Args>(args)...);
180190}
@@ -195,10 +205,10 @@ template <class strategy = NoCast>
195205class DiopiTensorWrapper {
196206public:
197207 // forbid copy/move constructor/assignment
198- DiopiTensorWrapper (const DiopiTensorWrapper &) = delete ;
199- DiopiTensorWrapper & operator =(const DiopiTensorWrapper &) = delete ;
200- DiopiTensorWrapper (DiopiTensorWrapper &&) = delete ;
201- DiopiTensorWrapper & operator =(DiopiTensorWrapper &&) = delete ;
208+ DiopiTensorWrapper (const DiopiTensorWrapper&) = delete ;
209+ DiopiTensorWrapper& operator =(const DiopiTensorWrapper&) = delete ;
210+ DiopiTensorWrapper (DiopiTensorWrapper&&) = delete ;
211+ DiopiTensorWrapper& operator =(DiopiTensorWrapper&&) = delete ;
202212
203213private:
204214 diopiContextHandle_t ctx_;
@@ -230,26 +240,6 @@ class DiopiTensorWrapper {
230240 if (convertType_.isDtypeConverted ()) {
231241 diopiCastDtype (ctx_, payload_, memoryFormatedTensor);
232242 }
233-
234- // if (convertType_.isDtypeConverted() &&
235- // !convertType_.isMemoryFormatConverted()) {
236- // diopiCastDtype(ctx_, payload_, tmp_);
237- // } else if (!convertType_.isDtypeConverted() &&
238- // convertType_.isMemoryFormatConverted()) {
239- // diopiCopyInp(ctx_, tmp_, payload_);
240- // } else {
241- // diopiDtype_t dtype;
242- // diopiGetTensorDtype(tmp_, &dtype);
243- // diopiSize_t size, stride, dstStride;
244- // diopiGetTensorShape(payload_, &size);
245- // diopiGetTensorStride(payload_, &stride);
246- // diopiDevice_t device;
247- // diopiGetTensorDevice(payload_, &device);
248- // diopiTensorHandle_t tmp = nullptr;
249- // diopiRequireTensor(ctx_, &tmp, &size, &stride, dtype, device);
250- // diopiCopyInp(ctx_, tmp_, tmp);
251- // diopiCastDtype(ctx_, payload_, tmp);
252- // }
253243 }
254244
255245public:
0 commit comments