Skip to content

Commit 59b261d

Browse files
asitstandspiiswrong
authored andcommitted
[MXNET-68] Random shuffle implementation (apache#10048)
* Random shuffle implementation This operator randomly shuffles an NDArray along the first axis. The order of the elements in each subarray does not change. For example, if an NDArray `x` is shuffled, the order of the subarrays `x[i]` randomly changes but the order of the elements in each `x[i]` does not change. It is modeled on `numpy.random.shuffle`. In cpu, the shuffling of a 1D array is delegated to `__gnu_parallel::random_shuffle`, which utilizes openmp, for clang on linux and gcc on any OS and delegated to `std::shuffle` for other platforms. For a multidimensional array, the usual Fisher-Yates shuffling is implemented. In gpu, it shuffles the array of indices representing the subarrays and then rearranges the elements of the data array according to the shuffled index array. To shuffle the index array, a random key is generated for each index and then the indices are sorted by the keys. The sorting is delegated to mshadow's `SortByKey` which again delegates the call to thrust's `sort_by_key`. * Refactoring to avoid a preprocessing problem in Windows build * Cosmetic changes * Typo * Adding const keyword at several places * Fix the bug that integer arrays are not allowed * Revise the comments to explain the unit test * Add a check for correct array shape * Revised unit test with larger arrays * Replace the custom hash with 'str' * Fix a bug due to the integer arithmetic in python2 * Revise comments for the unit test * Fix the invalid fix in the commit f240714 * Update random.md * Update random.md
1 parent 2442d8e commit 59b261d

File tree

7 files changed

+386
-2
lines changed

7 files changed

+386
-2
lines changed

docs/api/python/ndarray/random.md

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ In the rest of this document, we list routines provided by the `ndarray.random`
3535
normal
3636
poisson
3737
uniform
38+
multinomial
39+
shuffle
3840
mxnet.random.seed
3941
```
4042

docs/api/python/symbol/random.md

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ In the rest of this document, we list routines provided by the `symbol.random` p
3535
normal
3636
poisson
3737
uniform
38+
multinomial
39+
shuffle
3840
mxnet.random.seed
3941
```
4042

python/mxnet/ndarray/random.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
__all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
27-
'negative_binomial', 'generalized_negative_binomial']
27+
'negative_binomial', 'generalized_negative_binomial', 'shuffle']
2828

2929

3030
def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs):
@@ -431,3 +431,35 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
431431
<NDArray 2 @cpu(0)>
432432
"""
433433
return _internal._sample_multinomial(data, shape, get_prob, out=out, **kwargs)
434+
435+
436+
def shuffle(data, **kwargs):
    """Randomly permute an array along its first axis.

    Only the order of the subarrays ``data[i]`` changes; the elements inside
    each subarray keep their original order. For a 2D input this means the
    rows are reordered while every individual row is left intact.

    Parameters
    ----------
    data : NDArray
        Input data array.
    out : NDArray
        Array to store the result. It is forwarded to the operator through
        ``**kwargs``.

    Returns
    -------
    NDArray
        The shuffled array, as illustrated in the examples below.

    Examples
    --------
    >>> data = mx.nd.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    >>> mx.nd.random.shuffle(data)
    [[ 0.  1.  2.]
     [ 6.  7.  8.]
     [ 3.  4.  5.]]
    <NDArray 2x3 @cpu(0)>
    >>> mx.nd.random.shuffle(data)
    [[ 3.  4.  5.]
     [ 0.  1.  2.]
     [ 6.  7.  8.]]
    <NDArray 2x3 @cpu(0)>
    """
    shuffled = _internal._shuffle(data, **kwargs)
    return shuffled

python/mxnet/symbol/random.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
__all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
26-
'negative_binomial', 'generalized_negative_binomial']
26+
'negative_binomial', 'generalized_negative_binomial', 'shuffle']
2727

2828

2929
def _random_helper(random, sampler, params, shape, dtype, kwargs):
@@ -247,3 +247,34 @@ def multinomial(data, shape=_Null, get_prob=True, **kwargs):
247247
reward as head gradient w.r.t. this array to estimate gradient.
248248
"""
249249
return _internal._sample_multinomial(data, shape, get_prob, **kwargs)
250+
251+
252+
def shuffle(data, **kwargs):
    """Shuffle the elements randomly.

    This shuffles the array along the first axis.
    The order of the elements in each subarray does not change.
    For example, if a 2D array is given, the order of the rows randomly changes,
    but the order of the elements in each row does not change.

    Parameters
    ----------
    data : Symbol
        Symbol representing the input data array.

    Returns
    -------
    Symbol
        Symbol representing the shuffled array; evaluate it (see the examples
        below) to obtain concrete values.

    Examples
    --------
    >>> data = mx.nd.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    >>> a = mx.sym.Variable('a')
    >>> b = mx.sym.random.shuffle(a)
    >>> b.eval(a=data)
    [[ 0.  1.  2.]
     [ 6.  7.  8.]
     [ 3.  4.  5.]]
    <NDArray 2x3 @cpu(0)>
    >>> b.eval(a=data)
    [[ 3.  4.  5.]
     [ 0.  1.  2.]
     [ 6.  7.  8.]]
    <NDArray 2x3 @cpu(0)>
    """
    return _internal._shuffle(data, **kwargs)

src/operator/random/shuffle_op.cc

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2018 by Contributors
22+
* \file shuffle_op.cc
23+
* \brief Operator to shuffle elements of an NDArray
24+
*/
25+
// Enable the OpenMP-based __gnu_parallel::random_shuffle only where
// <parallel/algorithm> is expected to exist: GCC newer than 4 on any OS, or
// clang newer than 4 on Linux.
// The first clause must exclude clang explicitly, so it tests the real
// predefined macro __clang_major__ (the previous spelling, __clang__major__,
// is never defined by any compiler, which made that test always true).
#if (__GNUC__ > 4 && !defined(__clang_major__)) || (__clang_major__ > 4 && __linux__)
#define USE_GNU_PARALLEL_SHUFFLE
#endif
28+
29+
#include <mxnet/operator_util.h>
30+
#include <algorithm>
31+
#include <random>
32+
#include <vector>
33+
#ifdef USE_GNU_PARALLEL_SHUFFLE
34+
#include <parallel/algorithm>
35+
#endif
36+
#include "../elemwise_op_common.h"
37+
38+
namespace mxnet {
39+
namespace op {
40+
41+
namespace {
42+
43+
template<typename DType, typename Rand>
44+
void Shuffle1D(DType* const out, const index_t size, Rand* const prnd) {
45+
#ifdef USE_GNU_PARALLEL_SHUFFLE
46+
auto rand_n = [prnd](index_t n) {
47+
std::uniform_int_distribution<index_t> dist(0, n - 1);
48+
return dist(*prnd);
49+
};
50+
__gnu_parallel::random_shuffle(out, out + size, rand_n);
51+
#else
52+
std::shuffle(out, out + size, *prnd);
53+
#endif
54+
}
55+
56+
template<typename DType, typename Rand>
57+
void ShuffleND(DType* const out, const index_t size, const index_t first_axis_len,
58+
Rand* const prnd) {
59+
// Fisher-Yates shuffling
60+
const index_t stride = size / first_axis_len;
61+
auto rand_n = [prnd](index_t n) {
62+
std::uniform_int_distribution<index_t> dist(0, n - 1);
63+
return dist(*prnd);
64+
};
65+
CHECK_GT(first_axis_len, 0U);
66+
for (index_t i = first_axis_len - 1; i > 0; --i) {
67+
const index_t j = rand_n(i + 1);
68+
if (i != j) {
69+
std::swap_ranges(out + stride * i, out + stride * (i + 1), out + stride * j);
70+
}
71+
}
72+
}
73+
74+
} // namespace
75+
76+
/*!
 * \brief CPU forward pass of the shuffle operator.
 *
 * Copies the input into the output (unless the operation is in place) and
 * then shuffles the output buffer along the first axis.
 *
 * \param attrs node attributes (unused)
 * \param ctx operator context providing the stream and the random resource
 * \param inputs single input blob holding the data to shuffle
 * \param req write request for the output; kAddTo is rejected
 * \param outputs single output blob receiving the shuffled data
 */
void ShuffleForwardCPU(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  if (req[0] == kNullOp) {
    return;
  }
  CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
  const TBlob& src = inputs[0];
  const TBlob& dst = outputs[0];
  const TShape& src_shape = src.shape_;
  const index_t num_elems = src.Size();
  const index_t num_rows = src_shape[0];
  const bool needs_copy = req[0] != kWriteInplace;
  Stream<cpu>* stream = ctx.get_stream<cpu>();
  MSHADOW_TYPE_SWITCH(src.type_flag_, DType, {
    // View both blobs as flat 1D tensors of the same length.
    Tensor<cpu, 1, DType> flat_in = src.get_with_shape<cpu, 1, DType>(Shape1(num_elems), stream);
    Tensor<cpu, 1, DType> flat_out = dst.get_with_shape<cpu, 1, DType>(Shape1(num_elems), stream);
    auto& engine = ctx.requested[0].get_random<cpu, index_t>(stream)->GetRndEngine();
    // The shuffle itself works in place on the output buffer.
    if (needs_copy) {
      std::copy_n(flat_in.dptr_, num_elems, flat_out.dptr_);
    }
    if (src_shape.ndim() != 1) {
      ShuffleND(flat_out.dptr_, num_elems, num_rows, &engine);
    } else {
      Shuffle1D(flat_out.dptr_, num_elems, &engine);
    }
  });
}
104+
105+
106+
// No parameter is declared.
107+
// No backward computation is registered. Shuffling is not differentiable.
108+
109+
NNVM_REGISTER_OP(_shuffle)
110+
.add_alias("shuffle")
111+
.describe(R"code(Randomly shuffle the elements.
112+
113+
This shuffles the array along the first axis.
114+
The order of the elements in each subarray does not change.
115+
For example, if a 2D array is given, the order of the rows randomly changes,
116+
but the order of the elements in each row does not change.
117+
)code")
118+
.set_num_inputs(1)
119+
.set_num_outputs(1)
120+
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
121+
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
122+
.set_attr<FResourceRequest>("FResourceRequest",
123+
[](const nnvm::NodeAttrs& attrs) {
124+
return std::vector<ResourceRequest>{ResourceRequest::kRandom, ResourceRequest::kTempSpace};
125+
})
126+
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
127+
[](const NodeAttrs& attrs) {
128+
return std::vector<std::pair<int, int>>{{0, 0}};
129+
})
130+
.set_attr<FCompute>("FCompute<cpu>", ShuffleForwardCPU)
131+
.add_argument("data", "NDArray-or-Symbol", "Data to be shuffled.");
132+
133+
} // namespace op
134+
} // namespace mxnet

src/operator/random/shuffle_op.cu

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2018 by Contributors
22+
* \file shuffle_op.cc
23+
* \brief Operator to shuffle elements of an NDArray
24+
*/
25+
#include <mxnet/operator_util.h>
26+
#include <algorithm>
27+
#include <random>
28+
#include <vector>
29+
#include "../elemwise_op_common.h"
30+
#include "../tensor/init_op.h"
31+
32+
namespace mxnet {
33+
namespace op {
34+
35+
namespace {
36+
37+
struct CopyForShuffle {
38+
template<typename DType>
39+
MSHADOW_XINLINE static void Map(int i, const DType* const in, DType* out,
40+
const index_t* indices, const index_t stride) {
41+
out[i] = in[indices[i / stride] * stride + i % stride];
42+
}
43+
};
44+
45+
} // namespace
46+
47+
/*!
 * \brief GPU forward pass of the shuffle operator.
 *
 * 1D input: a random key is generated per element and the output is sorted by
 * those keys. N-D input: an index array for the first axis is shuffled the
 * same way, then the data is gathered into the output according to it.
 *
 * \param attrs node attributes (unused)
 * \param ctx operator context; requested[0] is used as the random generator
 *            and requested[1] as temporary space
 * \param inputs single input blob holding the data to shuffle
 * \param req write request for the output; kAddTo is rejected
 * \param outputs single output blob receiving the shuffled data
 */
void ShuffleForwardGPU(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  if (req[0] == kNullOp) {
    return;
  }
  CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
  const TShape& input_shape = inputs[0].shape_;
  const index_t size = inputs[0].Size();
  const index_t first_axis_len = input_shape[0];
  Stream<gpu> *s = ctx.get_stream<gpu>();
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    using KeyType = index_t;
    Tensor<gpu, 1, DType> in = inputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
    Tensor<gpu, 1, DType> out = outputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
    Random<gpu, KeyType> *prnd = ctx.requested[0].get_random<gpu, KeyType>(s);
    if (input_shape.ndim() == 1) {
      // Sort the data itself by random keys.
      if (req[0] != kWriteInplace) {
        Copy(out, in, s);
      }
      Tensor<gpu, 1, KeyType> keys =
        ctx.requested[1].get_space_typed<gpu, 1, KeyType>(Shape1(size), s);
      prnd->GetRandInt(keys);
      SortByKey(keys, out, true);
    } else {
      // Guard the division below; this mirrors the check in the CPU N-D path.
      CHECK_GT(first_axis_len, 0U);
      // Number of elements in one subarray along the first axis.
      const index_t stride = size / first_axis_len;
      // Temporary space layout: [indices | keys | (in-place only) copy of the input].
      const size_t index_and_key_bytes = first_axis_len * (sizeof(index_t) + sizeof(KeyType));
      const size_t tmp_space_size = req[0] == kWriteInplace ?
        index_and_key_bytes + size * sizeof(DType) : index_and_key_bytes;
      Tensor<gpu, 1, char> tmp_space =
        ctx.requested[1].get_space_typed<gpu, 1, char>(Shape1(tmp_space_size), s);
      char* tmp_space_ptr = tmp_space.dptr_;
      Tensor<gpu, 1, index_t> indices(reinterpret_cast<index_t*>(tmp_space_ptr),
                                      Shape1(first_axis_len), s);
      tmp_space_ptr += sizeof(index_t) * first_axis_len;
      // Fill the index array, then shuffle it by sorting on random keys.
      Kernel<range_fwd, gpu>::Launch(s, first_axis_len, 1, 0U, 1U, kWriteTo, indices.dptr_);
      Tensor<gpu, 1, KeyType> keys(reinterpret_cast<KeyType*>(tmp_space_ptr),
                                   Shape1(first_axis_len), s);
      tmp_space_ptr += sizeof(KeyType) * first_axis_len;
      prnd->GetRandInt(keys);
      SortByKey(keys, indices, true);
      if (req[0] == kWriteInplace) {
        // In the in-place case the data is first copied to a buffer and the
        // gather reads from that copy instead of from the input.
        Tensor<gpu, 1, DType> buf(reinterpret_cast<DType*>(tmp_space_ptr), Shape1(size), s);
        Copy(buf, in, s);
        Kernel<CopyForShuffle, gpu>::Launch(s, size, buf.dptr_, out.dptr_, indices.dptr_, stride);
      } else {
        Kernel<CopyForShuffle, gpu>::Launch(s, size, in.dptr_, out.dptr_, indices.dptr_, stride);
      }
    }
  });
}
101+
102+
// Attach the GPU compute function to the _shuffle operator.
NNVM_REGISTER_OP(_shuffle).set_attr<FCompute>("FCompute<gpu>", ShuffleForwardGPU);
104+
105+
} // namespace op
106+
} // namespace mxnet

0 commit comments

Comments
 (0)