@@ -575,6 +575,30 @@ static void check_shape_backward(
575
575
check_shape_forward (input, weight_sizes, /* bias=*/ Tensor (), params);
576
576
}
577
577
578
+ // Given an input tensor and an expected number of spatial dimensions, checks that the
579
+ // input is a valid shape and returns the batched form of the input.
580
+ //
581
+ // Args:
582
+ // input (Tensor): Input tensor
583
+ // num_spatial_dims (int): Number of spatial dimensions expected for the input
584
+ // func_name (string): Function name to produce a nice error message for invalid input
585
+ //
586
+ // Returns a std::tuple containing:
587
+ // batched_input (Tensor): Input with a batch dimension
588
+ // is_batched (bool): Indicates whether the original input was already batched
589
+ static std::tuple<Tensor, bool > batchify (
590
+ const Tensor& input,
591
+ const int64_t num_spatial_dims,
592
+ const std::string& func_name) {
593
+ const auto dim_count_no_batch = num_spatial_dims + 1 ;
594
+ const auto dim_count_batch = dim_count_no_batch + 1 ;
595
+ const auto is_batched = (input.dim () == dim_count_batch);
596
+ TORCH_CHECK (input.dim () == dim_count_no_batch || is_batched,
597
+ " Expected " , dim_count_no_batch, " D (unbatched) or " , dim_count_batch,
598
+ " D (batched) input to " , func_name, " , but got input of size: " , input.sizes ());
599
+ return std::make_tuple (is_batched ? input : input.unsqueeze (0 ), is_batched);
600
+ }
601
+
578
602
static void check_input_same_type_as_parameters (
579
603
const Tensor& input,
580
604
const Tensor& weight,
@@ -618,36 +642,45 @@ static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
618
642
619
643
620
644
// 1-D convolution over an unbatched (C, L) or batched (N, C, L) input.
// Unbatched inputs are given a temporary batch dim of 1 before dispatching to
// at::convolution, and that dim is squeezed away again before returning.
at::Tensor conv1d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/1, "conv1d");
  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
  return is_batched ? output : output.squeeze(0);
}
630
657
631
658
// 2-D convolution over an unbatched (C, H, W) or batched (N, C, H, W) input.
// Unbatched inputs are given a temporary batch dim of 1 before dispatching to
// at::convolution, and that dim is squeezed away again before returning.
at::Tensor conv2d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/2, "conv2d");
  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
  return is_batched ? output : output.squeeze(0);
}
641
671
642
672
// 3-D convolution over an unbatched (C, D, H, W) or batched (N, C, D, H, W) input.
// Unbatched inputs are given a temporary batch dim of 1 before dispatching to
// at::convolution, and that dim is squeezed away again before returning.
at::Tensor conv3d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/3, "conv3d");
  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
  return is_batched ? output : output.squeeze(0);
}
652
685
653
686
@@ -736,60 +769,84 @@ Tensor _convolution_mode(
736
769
}
737
770
738
771
// String-padding ("same"/"valid") overload of conv1d. Accepts unbatched (C, L)
// or batched (N, C, L) input; unbatched inputs are batched with a size-1 dim
// for the call to at::_convolution_mode and squeezed back before returning.
at::Tensor conv1d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
    IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
    int64_t groups) {
  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/1, "conv1d");
  auto output = at::_convolution_mode(
      input, weight, bias, stride, std::move(padding), dilation, groups);
  return is_batched ? output : output.squeeze(0);
}
745
782
746
783
// String-padding ("same"/"valid") overload of conv2d. Accepts unbatched (C, H, W)
// or batched (N, C, H, W) input; unbatched inputs are batched with a size-1 dim
// for the call to at::_convolution_mode and squeezed back before returning.
at::Tensor conv2d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
    IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
    int64_t groups) {
  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/2, "conv2d");
  auto output = at::_convolution_mode(
      input, weight, bias, stride, std::move(padding), dilation, groups);
  return is_batched ? output : output.squeeze(0);
}
753
794
754
795
// String-padding ("same"/"valid") overload of conv3d. Accepts unbatched
// (C, D, H, W) or batched (N, C, D, H, W) input; unbatched inputs are batched
// with a size-1 dim for the call to at::_convolution_mode and squeezed back
// before returning.
at::Tensor conv3d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
    IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
    int64_t groups) {
  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/3, "conv3d");
  auto output = at::_convolution_mode(
      input, weight, bias, stride, std::move(padding), dilation, groups);
  return is_batched ? output : output.squeeze(0);
}
761
806
762
807
// 1-D transposed convolution over an unbatched (C, L) or batched (N, C, L)
// input. Unbatched inputs are given a temporary batch dim of 1 before
// dispatching to at::convolution, and that dim is squeezed away on return.
at::Tensor conv_transpose1d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/1, "conv_transpose1d");
  auto output = at::convolution(
      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
  return is_batched ? output : output.squeeze(0);
}
772
821
773
822
// 2-D transposed convolution over an unbatched (C, H, W) or batched
// (N, C, H, W) input. Unbatched inputs are given a temporary batch dim of 1
// before dispatching to at::convolution, and that dim is squeezed away on return.
at::Tensor conv_transpose2d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/2, "conv_transpose2d");
  auto output = at::convolution(
      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
  return is_batched ? output : output.squeeze(0);
}
783
836
784
837
// 3-D transposed convolution over an unbatched (C, D, H, W) or batched
// (N, C, D, H, W) input. Unbatched inputs are given a temporary batch dim of 1
// before dispatching to at::convolution, and that dim is squeezed away on return.
at::Tensor conv_transpose3d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/3, "conv_transpose3d");
  auto output = at::convolution(
      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
  return is_batched ? output : output.squeeze(0);
}
794
851
795
852
at::Tensor convolution (
0 commit comments