Skip to content

Commit e3e0c34

Browse files
zou3519soumith
authored andcommitted
Unify error checking for tesnor.index_copy_ (pytorch#5642)
1 parent d946267 commit e3e0c34

File tree

5 files changed

+52
-6
lines changed

5 files changed

+52
-6
lines changed

aten/src/ATen/Declarations.cwrap

+1-2
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,7 @@
258258
- THIndexTensor* index
259259
]]
260260
[[
261-
name: indexCopy_
262-
python_name: index_copy_
261+
name: _indexCopy_
263262
cname: indexCopy
264263
return: argument 0
265264
arguments:

aten/src/ATen/native/Indexing.cpp

+45
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,49 @@ Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value) {
250250
return src.put_(linearIndex, expandedValue);
251251
}
252252

253+
Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
254+
dim = maybe_wrap_dim(dim, self.dim());
255+
256+
if (index.dim() >= 2) {
257+
runtime_error(
258+
"index_copy_(): Index should have dimension 1 or 0 (got %d)",
259+
(int)index.dim());
260+
}
261+
int64_t numIndices = index.numel();
262+
if (source.dim() == 0 && numIndices != 1) {
263+
runtime_error(
264+
"index_copy_(): When source is scalar, index should have one element (got %d)",
265+
(int)numIndices);
266+
}
267+
if (source.dim() > 0 && numIndices != source.size(dim)) {
268+
runtime_error(
269+
"index_copy_(): Number of indices (%d) should be equal to source.size(dim) (%d)",
270+
(int)numIndices, (int)source.size(dim));
271+
}
272+
if (index.type().scalarType() != ScalarType::Long) {
273+
runtime_error("index_copy_(): Expected LongTensor for index");
274+
}
275+
276+
// Check that source and destination slices have the same size
277+
auto selfSlicedSizes = std::vector<int64_t>(self.sizes());
278+
if (selfSlicedSizes.size() > 0) {
279+
selfSlicedSizes.erase(selfSlicedSizes.begin() + dim);
280+
}
281+
auto sourceSlicedSizes = std::vector<int64_t>(source.sizes());
282+
if (sourceSlicedSizes.size() > 0) {
283+
sourceSlicedSizes.erase(sourceSlicedSizes.begin());
284+
}
285+
if (selfSlicedSizes.size() != sourceSlicedSizes.size() ||
286+
!std::equal(selfSlicedSizes.begin(), selfSlicedSizes.end(),
287+
sourceSlicedSizes.begin())) {
288+
std::stringstream ss;
289+
ss << "index_copy_(): Source/destination tensor must have same slice shapes. ";
290+
ss << "Destination slice shape: " << selfSlicedSizes << " at dimension " << dim;
291+
ss << " and source slice shape: " << sourceSlicedSizes << " at dimension 0.";
292+
throw std::runtime_error(ss.str());
293+
}
294+
295+
return self._indexCopy_(dim, index, source);
296+
}
297+
253298
}} // at::native

aten/src/ATen/native/native_functions.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,9 @@
294294
- func: index(Tensor self, TensorList indices) -> Tensor
295295
# NB: This function is special-cased in tools/autograd/gen_variable_type.py
296296

297+
- func: index_copy_(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
298+
variants: method
299+
297300
- func: index_put_(Tensor self, TensorList indices, Tensor values) -> Tensor
298301

299302
- func: is_cuda(Tensor self) -> bool

aten/src/TH/generic/THTensorMath.c

+2-3
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,9 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTens
382382
THTensor *tSlice, *sSlice;
383383
int64_t *index_data;
384384

385+
// Error checking for this function has moved to ATen!!
386+
385387
numel = THLongTensor_nElement(index);
386-
THArgCheck(index->nDimension == 1, 3, "Index is supposed to be a vector");
387-
THArgCheck(dim < src->nDimension, 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
388-
THArgCheck(numel == src->size[dim],4,"Number of indices should be equal to source:size(dim)");
389388

390389
index = THLongTensor_newContiguous(index);
391390
index_data = THLongTensor_data(index);

tools/autograd/gen_python_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
'alias', 'contiguous', 'clamp.*', 'is_cuda', 'is_sparse', 'size', 'stride',
1818
'.*_backward', '.*_backward_out', '.*_forward', '.*_forward_out',
1919
'sparse_raw_resize_', '_unsafe_view', 'tensor', 'sparse_coo_tensor',
20-
'_arange.*', '_range.*', '_linspace.*', '_logspace.*'
20+
'_arange.*', '_range.*', '_linspace.*', '_logspace.*', '_indexCopy_',
2121
]
2222

2323
PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')

0 commit comments

Comments
 (0)