@@ -668,6 +668,131 @@ scalar_t THTensor_(get4d)(const THTensor *tensor, int64_t x0, int64_t x1, int64_
  return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2)+x3*tensor->stride(3));
}

+
+/* Shape manipulation methods */
+void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
+{
+  THTensor* inputs[2];
+  inputs[0] = ta;
+  inputs[1] = tb;
+  THTensor_(catArray)(r_, inputs, 2, dimension);
+}
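+
+/* Usage sketch (illustrative, not part of the API surface this patch adds):
+   assuming the float instantiation, where THTensor_(cat) expands to
+   THFloatTensor_cat:
+
+     THFloatTensor *a = THFloatTensor_newWithSize2d(2, 3);
+     THFloatTensor *b = THFloatTensor_newWithSize2d(4, 3);
+     THFloatTensor *r = THFloatTensor_new();
+     THFloatTensor_cat(r, a, b, 0);  // r now has size {6, 3}
+*/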
+
+void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension);
+inline void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension)
+{
+  int first_dims = first->dim();
+  int second_dims = second->dim();
+  THArgCheck(first_dims == second_dims, 0,
+      "Tensors must have same number of dimensions: got %d and %d",
+      first_dims, second_dims);
+  for (int dim = 0; dim < first_dims; dim++) {
+    if (dim == dimension) {
+      continue;
+    }
+    int64_t first_dim_size = first->size(dim);
+    int64_t second_dim_size = second->size(dim);
+    THArgCheck(first_dim_size == second_dim_size, 0,
+        "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d",
+        dimension, (long long)first_dim_size, (long long)second_dim_size, dim);
+  }
+}
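+
+/* Shape-check sketch (illustrative): with dimension == 0, sizes {2, 3}
+   and {4, 3} pass, since only the cat dimension differs; sizes {2, 3}
+   and {2, 4} fail the second THArgCheck, because dim 1 differs and is
+   not the cat dimension. */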
+
+void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension)
+{
+  // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
+  // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
+  // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
+  // size (i.e. other empty sizes are not skipped).
+  // FIXME: warn if this is the case
+  bool allSkipped = true;
+  int64_t nDims = 0;
+  THTensor *notSkippedTensor = NULL;  // non-owning reference
+  auto should_skip = [](THTensor *t) { return t->is_empty() && t->dim() == 1; };
+  for (int i = 0; i < numInputs; i++) {
+    if (should_skip(inputs[i])) {
+      continue;
+    }
+    // We've found a non-empty tensor
+    allSkipped = false;
+    notSkippedTensor = inputs[i];
+    nDims = notSkippedTensor->dim();
+    break;
+  }
+  if (allSkipped) {
+    return;
+  }
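+
+  /* Skip-rule sketch (illustrative): a 1-dimensional size-{0} tensor is
+     skipped for backwards compatibility, so catting it with a {2, 3}
+     tensor just yields a {2, 3} copy; a {0, 4} empty tensor is not
+     skipped and must still pass the shape check below. */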
+
+  // Compute cat_dimension based on the non-empty tensor
+  THArgCheck(dimension < nDims, 4, "invalid dimension %d", dimension);
+  THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
+
+  // Compute size of the result in the cat dimension
+  int64_t cat_dim_size = 0;
+  for (int i = 0; i < numInputs; i++) {
+    THTensor *tensor = inputs[i];
+    if (should_skip(tensor)) {
+      continue;
+    }
+    THTensor_(check_shape_except_dim)(notSkippedTensor, tensor, dimension);
+    cat_dim_size += tensor->size(dimension);
+  }
+
+  // Compute the size of the result
+  std::vector<int64_t> size(nDims);
+  for (int dim = 0; dim < nDims; dim++) {
+    int64_t result_dim_size = notSkippedTensor->size(dim);
+    if (dim == dimension) {
+      result_dim_size = cat_dim_size;
+    }
+    size[dim] = result_dim_size;
+  }
+  THTensor_(resize)(result, size, {});
+
+  // Check contiguity of all inputs and result
+  bool allContiguous = true;
+  for (int i = 0; i < numInputs; i++) {
+    if (!should_skip(inputs[i])) {
+      allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]);
+    }
+  }
+  allContiguous = allContiguous && THTensor_(isContiguous)(result);
+
+  // First path is for contiguous inputs along dim 0
+  // Second path for non-contiguous
+  int64_t offset;
+  if (dimension == 0 && allContiguous) {
+    scalar_t* result_data = THStorage_(data)(THTensor_getStoragePtr(result)) + result->storage_offset();
+    offset = 0;
+    for (int j = 0; j < numInputs; j++) {
+      if (!should_skip(inputs[j])) {
+        THTensor* input0 = inputs[j];
+        scalar_t* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset();
+        int64_t input0_size = THTensor_(nElement)(input0);
+        // C standard says you can't pass nullptrs to memcpy, even if the size is 0; ubsan checks this.
+        if (input0_size != 0) {
+          memcpy(result_data + offset, input0_data, input0_size*sizeof(scalar_t));
+        }
+        offset += input0_size;
+      }
+    }
+  } else {
+    offset = 0;
+    for (int j = 0; j < numInputs; j++) {
+      if (!should_skip(inputs[j])) {
+        int64_t dimSize = inputs[j]->size(dimension);
+        THTensor *nt = THTensor_(newWithTensor)(result);
+        THTensor_(narrow)(nt, NULL, dimension, offset, dimSize);
+        at::Tensor nt__wrap = THTensor_wrap(nt);
+        at::Tensor inputs_wrap = THTensor_wrap(inputs[j]);
+        at::_copy_same_type_(nt__wrap, inputs_wrap);
+        c10::raw::intrusive_ptr::decref(nt);
+        offset += dimSize;
+      }
+    }
+  }
+}
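+
+/* catArray sketch (illustrative; float instantiation assumed): given two
+   contiguous inputs of sizes {2, 3} and {1, 3} with dimension == 0, the
+   memcpy fast path fills a {3, 3} result with 6 elements at offset 0 and
+   3 elements at offset 6; any other cat dimension, or a non-contiguous
+   tensor, takes the narrow-and-copy path instead. With hypothetical
+   tensors t0 and t1 of those sizes:
+
+     THFloatTensor *ins[2] = {t0, t1};
+     THFloatTensor *r = THFloatTensor_new();
+     THFloatTensor_catArray(r, ins, 2, 0);  // r has size {3, 3}
+*/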
+
THDescBuff THTensor_(desc)(const THTensor *tensor) {
  const int L = TH_DESC_BUFF_LEN;
  THDescBuff buf;