@@ -250,4 +250,49 @@ Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value) {
250
250
return src.put_ (linearIndex, expandedValue);
251
251
}
252
252
253
+ Tensor & index_copy_ (Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
254
+ dim = maybe_wrap_dim (dim, self.dim ());
255
+
256
+ if (index .dim () >= 2 ) {
257
+ runtime_error (
258
+ " index_copy_(): Index should have dimension 1 or 0 (got %d)" ,
259
+ (int )index .dim ());
260
+ }
261
+ int64_t numIndices = index .numel ();
262
+ if (source.dim () == 0 && numIndices != 1 ) {
263
+ runtime_error (
264
+ " index_copy_(): When source is scalar, index should have one element (got %d)" ,
265
+ (int )numIndices);
266
+ }
267
+ if (source.dim () > 0 && numIndices != source.size (dim)) {
268
+ runtime_error (
269
+ " index_copy_(): Number of indices (%d) should be equal to source.size(dim) (%d)" ,
270
+ (int )numIndices, (int )source.size (dim));
271
+ }
272
+ if (index .type ().scalarType () != ScalarType::Long) {
273
+ runtime_error (" index_copy_(): Expected LongTensor for index" );
274
+ }
275
+
276
+ // Check that source and destination slices have the same size
277
+ auto selfSlicedSizes = std::vector<int64_t >(self.sizes ());
278
+ if (selfSlicedSizes.size () > 0 ) {
279
+ selfSlicedSizes.erase (selfSlicedSizes.begin () + dim);
280
+ }
281
+ auto sourceSlicedSizes = std::vector<int64_t >(source.sizes ());
282
+ if (sourceSlicedSizes.size () > 0 ) {
283
+ sourceSlicedSizes.erase (sourceSlicedSizes.begin ());
284
+ }
285
+ if (selfSlicedSizes.size () != sourceSlicedSizes.size () ||
286
+ !std::equal (selfSlicedSizes.begin (), selfSlicedSizes.end (),
287
+ sourceSlicedSizes.begin ())) {
288
+ std::stringstream ss;
289
+ ss << " index_copy_(): Source/destination tensor must have same slice shapes. " ;
290
+ ss << " Destination slice shape: " << selfSlicedSizes << " at dimension " << dim;
291
+ ss << " and source slice shape: " << sourceSlicedSizes << " at dimension 0." ;
292
+ throw std::runtime_error (ss.str ());
293
+ }
294
+
295
+ return self._indexCopy_ (dim, index , source);
296
+ }
297
+
253
298
}} // at::native
0 commit comments