@@ -2895,7 +2895,24 @@ at::Tensor viewStorage(const at::Tensor input, const c10::IntArrayRef sizes, con
28952895 if (st != -1 ) st *= sizes[i - 1 ];
28962896 }
28972897 }
2898- return fromPreAllocated (input.data_ptr () + storageOffset * input.itemsize (), sizes, stridesVec, input.options ());
2898+
2899+ // when shape[0]=-1, fill data
2900+ std::vector<int64_t > sizeVec (sizes.size (), 1 );
2901+ std::copy (sizes.begin (), sizes.end (), sizeVec.begin ());
2902+ if (!sizes.empty () && sizes[0 ] == -1 ) {
2903+ bool flag = true ;
2904+ for (auto i : sizes) {
2905+ if (!flag && i < 0 ) {
2906+ TORCH_CHECK (false , " more than one -1, sizes=" , sizes);
2907+ }
2908+ if (i < 0 ) {
2909+ flag = false ;
2910+ }
2911+ }
2912+ int count = std::accumulate (sizeVec.begin () + 1 , sizeVec.end (), 1 , std::multiplies<int >());
2913+ sizeVec[0 ] = input.numel () / count;
2914+ }
2915+ return fromPreAllocated (input.data_ptr () + storageOffset * input.itemsize (), sizeVec, stridesVec, input.options ());
28992916}
29002917
29012918c10::List<c10::optional<at::Tensor>> castIntIndicesToLongIndices (const c10::List<c10::optional<at::Tensor>>& indices) {
@@ -3057,7 +3074,11 @@ at::Tensor wrapper__transpose(const at::Tensor& self, int64_t dim0, int64_t dim1
30573074}
30583075
30593076at::Scalar wrapper___local_scalar_dense (const at::Tensor& self) { return at_npu::native::NPUNativeFunctions::_local_scalar_dense (self); }
3077+ at::Tensor& wrapper_out_mm_out (const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { return acl_op::mm_out (self, mat2, out); }
30603078
3079+ at::Tensor& wrapper_source_Tensor_set_ (at::Tensor& self, const at::Tensor& source) { return at_npu::native::NPUNativeFunctions::set_ (self, source); }
3080+ at::Tensor& wrapper_out_bmm_out (const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { return acl_op::bmm_out (self, mat2, out); }
3081+ at::Tensor wrapper__dot (const at::Tensor& self, const at::Tensor& tensor) { return acl_op::dot (self, tensor); }
30613082} // namespace
30623083
30633084namespace at {
@@ -3092,6 +3113,11 @@ TORCH_LIBRARY_IMPL(aten, XLA, m) {
30923113 m.impl (" repeat" , TORCH_FN (wrapper__repeat));
30933114 m.impl (" transpose.int" , TORCH_FN (wrapper__transpose));
30943115 m.impl (" _local_scalar_dense" , TORCH_FN (wrapper___local_scalar_dense));
3116+ m.impl (" cat" , TORCH_FN (wrapper__cat));
3117+ m.impl (" mm.out" , TORCH_FN (wrapper_out_mm_out));
3118+ m.impl (" set_.source_Tensor" , TORCH_FN (wrapper_source_Tensor_set_));
3119+ m.impl (" dot" , TORCH_FN (wrapper__dot));
3120+ m.impl (" bmm.out" , TORCH_FN (wrapper_out_bmm_out));
30953121};
30963122
30973123TORCH_LIBRARY_IMPL (_, XLA, m) { m.fallback (torch::CppFunction::makeFromBoxedFunction<&ascend_diopi_fallback>()); }
0 commit comments