Skip to content

Commit 573019f

Browse files
zheng-dalanking520
authored andcommitted
fix type inference in index_copy. (apache#12890)
1 parent 793f9c6 commit 573019f

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

src/operator/contrib/index_copy.cc

+11-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@
2626
namespace mxnet {
2727
namespace op {
2828

29+
static bool IndexCopyType(const nnvm::NodeAttrs& attrs,
30+
std::vector<int> *in_attrs,
31+
std::vector<int> *out_attrs) {
32+
CHECK_EQ(in_attrs->size(), 3U);
33+
CHECK_EQ(out_attrs->size(), 1U);
34+
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
35+
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
36+
return out_attrs->at(0) != -1;
37+
}
38+
2939
NNVM_REGISTER_OP(_contrib_index_copy)
3040
.describe(R"code(Copies the elements of a `new_tensor` into the `old_tensor` by
3141
selecting the indices in the order given in `index`. The output will be a new tensor
@@ -56,7 +66,7 @@ mx.nd.contrib.index_copy(x, index, t)
5666
.set_num_inputs(3)
5767
.set_num_outputs(1)
5868
.set_attr<nnvm::FInferShape>("FInferShape", IndexCopyShape)
59-
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
69+
.set_attr<nnvm::FInferType>("FInferType", IndexCopyType)
6070
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_index_copy"})
6171
.set_attr<FCompute>("FCompute<cpu>", IndexCopyForward<cpu>)
6272
.add_argument("old_tensor", "NDArray-or-Symbol", "Old tensor")

tests/python/unittest/test_operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4766,7 +4766,7 @@ def test_quantization_op():
47664766
def test_index_copy():
47674767
x = mx.nd.zeros((5,3))
47684768
t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
4769-
index = mx.nd.array([0,4,2])
4769+
index = mx.nd.array([0,4,2], dtype=np.int64)
47704770

47714771
x.attach_grad()
47724772
t.attach_grad()

0 commit comments

Comments
 (0)