[MXNET-68] Random shuffle implementation #10048

Merged Mar 20, 2018 (15 commits). Changes shown from 4 commits.
2 changes: 2 additions & 0 deletions docs/api/python/ndarray/random.md
@@ -35,6 +35,8 @@ In the rest of this document, we list routines provided by the `ndarray.random`
normal
poisson
uniform
multinomial
shuffle
mxnet.random.seed
```

2 changes: 2 additions & 0 deletions docs/api/python/symbol/random.md
@@ -35,6 +35,8 @@ In the rest of this document, we list routines provided by the `symbol.random` p
normal
poisson
uniform
multinomial
shuffle
mxnet.random.seed
```

34 changes: 33 additions & 1 deletion python/mxnet/ndarray/random.py
@@ -24,7 +24,7 @@


__all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
- 'negative_binomial', 'generalized_negative_binomial']
+ 'negative_binomial', 'generalized_negative_binomial', 'shuffle']


def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs):
@@ -431,3 +431,35 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
<NDArray 2 @cpu(0)>
"""
return _internal._sample_multinomial(data, shape, get_prob, out=out, **kwargs)


def shuffle(data, **kwargs):
Contributor: Python interface seems unnecessary. You can register the operator with name _random_shuffle in C++.
Contributor: Still valid @reminisce ?

"""Shuffle the elements randomly.

This shuffles the array along the first axis.
The order of the elements in each subarray does not change.
For example, if a 2D array is given, the order of the rows randomly changes,
but the order of the elements in each row does not change.

Parameters
----------
data : NDArray
Input data array.
out : NDArray
Array to store the result.

Examples
--------
>>> data = mx.nd.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
>>> mx.nd.random.shuffle(data)
[[ 0. 1. 2.]
[ 6. 7. 8.]
[ 3. 4. 5.]]
<NDArray 3x3 @cpu(0)>
>>> mx.nd.random.shuffle(data)
[[ 3. 4. 5.]
[ 0. 1. 2.]
[ 6. 7. 8.]]
<NDArray 3x3 @cpu(0)>
"""
return _internal._shuffle(data, **kwargs)
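
For readers coming from NumPy: the semantics match shuffling along the first axis only. A minimal NumPy sketch of the equivalent operation (an editor's illustration, not part of this patch):

```python
import numpy as np

data = np.arange(9, dtype=float).reshape(3, 3)
# np.random.permutation returns a shuffled copy; rows move, but the
# elements within each row do not, matching mx.nd.random.shuffle.
shuffled = np.random.permutation(data)
```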
33 changes: 32 additions & 1 deletion python/mxnet/symbol/random.py
@@ -23,7 +23,7 @@


__all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
- 'negative_binomial', 'generalized_negative_binomial']
+ 'negative_binomial', 'generalized_negative_binomial', 'shuffle']


def _random_helper(random, sampler, params, shape, dtype, kwargs):
@@ -247,3 +247,34 @@ def multinomial(data, shape=_Null, get_prob=True, **kwargs):
reward as head gradient w.r.t. this array to estimate gradient.
"""
return _internal._sample_multinomial(data, shape, get_prob, **kwargs)


def shuffle(data, **kwargs):
Contributor: Python interface seems unnecessary. You can register the operator with name _random_shuffle in C++.

"""Shuffle the elements randomly.

This shuffles the array along the first axis.
The order of the elements in each subarray does not change.
For example, if a 2D array is given, the order of the rows randomly changes,
but the order of the elements in each row does not change.

Parameters
----------
data : Symbol
Input data symbol.

Examples
--------
>>> data = mx.nd.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
>>> a = mx.sym.Variable('a')
>>> b = mx.sym.random.shuffle(a)
>>> b.eval(a=data)
[[ 0. 1. 2.]
[ 6. 7. 8.]
[ 3. 4. 5.]]
<NDArray 3x3 @cpu(0)>
>>> b.eval(a=data)
[[ 3. 4. 5.]
[ 0. 1. 2.]
[ 6. 7. 8.]]
<NDArray 3x3 @cpu(0)>
"""
return _internal._shuffle(data, **kwargs)
132 changes: 132 additions & 0 deletions src/operator/random/shuffle_op.cc
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file shuffle_op.cc
* \brief Operator to shuffle elements of an NDArray
*/
#if (__GNUC__ > 4 && !defined(__clang__)) || (__clang_major__ > 4 && __linux__)
#define USE_GNU_PARALLEL_SHUFFLE
#endif

#include <mxnet/operator_util.h>
#include <algorithm>
#include <random>
#include <vector>
#ifdef USE_GNU_PARALLEL_SHUFFLE
#include <parallel/algorithm>
#endif
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

namespace {

template<typename DType, typename Rand>
Contributor: no need to indent in namespace.
Author: Done.

void Shuffle1D(DType* out, index_t size, Rand* prnd) {
#ifdef USE_GNU_PARALLEL_SHUFFLE
auto rand_n = [prnd](index_t n) {
std::uniform_int_distribution<index_t> dist(0, n - 1);
return dist(*prnd);
};
__gnu_parallel::random_shuffle(out, out + size, rand_n);
#else
std::shuffle(out, out + size, *prnd);
#endif
}

template<typename DType, typename Rand>
void ShuffleND(DType* out, index_t size, index_t first_axis_len, Rand* prnd) {
Contributor: Add const qualifier to the arguments if possible.
Author: Done.

// Fisher-Yates shuffling
const index_t stride = size / first_axis_len;
auto rand_n = [prnd](index_t n) {
std::uniform_int_distribution<index_t> dist(0, n - 1);
return dist(*prnd);
};
for (index_t i = first_axis_len - 1; i > 0; --i) {
Contributor: Add CHECK_GT(first_axis_len, 0U); above this line.
Author: Done.

index_t j = rand_n(i + 1);
if (i != j) {
std::swap_ranges(out + stride * i, out + stride * (i + 1), out + stride * j);
}
}
}
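
For intuition, here is a minimal NumPy sketch of the same row-block Fisher-Yates idea (an editor's illustration, not part of the patch): the flattened array is treated as `first_axis_len` contiguous blocks of `stride` elements, and whole blocks are swapped.

```python
import numpy as np

def shuffle_first_axis(flat, first_axis_len, rng):
    # Treat `flat` as first_axis_len contiguous row blocks of equal size.
    stride = flat.size // first_axis_len
    for i in range(first_axis_len - 1, 0, -1):
        j = rng.integers(0, i + 1)  # uniform on [0, i], like rand_n(i + 1)
        if i != j:
            # Swap the two row blocks, mirroring std::swap_ranges.
            tmp = flat[stride * i : stride * (i + 1)].copy()
            flat[stride * i : stride * (i + 1)] = flat[stride * j : stride * (j + 1)]
            flat[stride * j : stride * (j + 1)] = tmp

rng = np.random.default_rng(0)
x = np.arange(9, dtype=float)
shuffle_first_axis(x, 3, rng)
print(x.reshape(3, 3))  # rows permuted, each row intact
```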

} // namespace

void ShuffleForwardCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (req[0] == kNullOp) {
return;
}
CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
const TShape input_shape = inputs[0].shape_;
Contributor: const TShape&
Author: Done.

const index_t size = inputs[0].Size();
const index_t first_axis_len = input_shape[0];
Stream<cpu> *s = ctx.get_stream<cpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Contributor: What's the reason for not supporting integer types?
Author: Fixed with MSHADOW_TYPE_SWITCH.

Tensor<cpu, 1, DType> in = inputs[0].get_with_shape<cpu, 1, DType>(Shape1(size), s);
Tensor<cpu, 1, DType> out = outputs[0].get_with_shape<cpu, 1, DType>(Shape1(size), s);
auto& prnd = ctx.requested[0].get_random<cpu, index_t>(ctx.get_stream<cpu>())->GetRndEngine();
if (req[0] != kWriteInplace) {
std::copy(in.dptr_, in.dptr_ + size, out.dptr_);
}
if (input_shape.ndim() == 1) {
Shuffle1D(out.dptr_, size, &prnd);
} else {
ShuffleND(out.dptr_, size, first_axis_len, &prnd);
}
});
}


// No parameter is declared.
// No backward computation is registered. Shuffling is not differentiable.

NNVM_REGISTER_OP(_shuffle)
Contributor: Why register as internal? You can register with the name _random_shuffle and remove the Python interface.
Author (asitstands, Mar 10, 2018): It does not work. The random module is not generated automatically, since _OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_'] in base.py does not contain 'random', and random.py does not pick up any generated module either. Is there any reason for this?
Contributor: Ok, I forgot @piiswrong refactored it. It makes sense to keep the Python interface.

.add_alias("shuffle")
.describe(R"code(Randomly shuffle the elements.

This shuffles the array along the first axis.
The order of the elements in each subarray does not change.
For example, if a 2D array is given, the order of the rows randomly changes,
but the order of the elements in each row does not change.
)code")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int>>{{0, 0}};
})
.set_attr<FCompute>("FCompute<cpu>", ShuffleForwardCPU)
.add_argument("data", "NDArray-or-Symbol", "Data to be shuffled.");

} // namespace op
} // namespace mxnet
106 changes: 106 additions & 0 deletions src/operator/random/shuffle_op.cu
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file shuffle_op.cu
* \brief Operator to shuffle elements of an NDArray
*/
#include <mxnet/operator_util.h>
#include <algorithm>
#include <random>
#include <vector>
#include "../elemwise_op_common.h"
#include "../tensor/init_op.h"

namespace mxnet {
namespace op {

namespace {

struct CopyForShuffle {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const DType* in, DType* out,
Contributor: Add const for in and stride.
Author: Done.

const index_t* indices, index_t stride) {
out[i] = in[indices[i / stride] * stride + i % stride];
}
};
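
The kernel is a flat gather: output element `i` is read from row `indices[i / stride]` at offset `i % stride`. In NumPy terms (an editor's sketch, not part of the patch), the whole kernel collapses to one fancy-indexing line:

```python
import numpy as np

def copy_for_shuffle(flat_in, indices, stride):
    # Gathers whole rows: out[i] = in[indices[i // stride] * stride + i % stride]
    return flat_in.reshape(-1, stride)[indices].reshape(-1)

x = np.arange(9)
print(copy_for_shuffle(x, np.array([2, 0, 1]), 3))  # rows 2, 0, 1 of the 3x3 view
```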

} // namespace

void ShuffleForwardGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (req[0] == kNullOp) {
return;
}
CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
const TShape input_shape = inputs[0].shape_;
Contributor: const TShape&.
Author: Done.

const index_t size = inputs[0].Size();
const index_t first_axis_len = input_shape[0];
const index_t stride = size / first_axis_len;
Stream<gpu> *s = ctx.get_stream<gpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Contributor: Why are integers not supported?
Author: Fixed with MSHADOW_TYPE_SWITCH.

using KeyType = index_t;
Tensor<gpu, 1, DType> in = inputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
Tensor<gpu, 1, DType> out = outputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
Random<gpu, KeyType> *prnd = ctx.requested[0].get_random<gpu, KeyType>(s);
if (input_shape.ndim() == 1) {
if (req[0] != kWriteInplace) {
Copy(out, in, s);
}
Tensor<gpu, 1, KeyType> keys =
ctx.requested[1].get_space_typed<gpu, 1, KeyType>(Shape1(size), s);
prnd->GetRandInt(keys);
SortByKey(keys, out, true);
} else {
size_t tmp_space_size = req[0] == kWriteInplace ?
Contributor: const
Author: Done.

2 * first_axis_len * sizeof(index_t) + size * sizeof(DType) :
2 * first_axis_len * sizeof(index_t);
Tensor<gpu, 1, char> tmp_space =
ctx.requested[1].get_space_typed<gpu, 1, char>(Shape1(tmp_space_size), s);
char* tmp_space_ptr = tmp_space.dptr_;
Tensor<gpu, 1, index_t> indices(reinterpret_cast<index_t*>(tmp_space_ptr),
Shape1(first_axis_len), s);
tmp_space_ptr += sizeof(index_t) * first_axis_len;
Kernel<range_fwd, gpu>::Launch(s, first_axis_len, 1, 0U, 1U, kWriteTo, indices.dptr_);
Tensor<gpu, 1, KeyType> keys(reinterpret_cast<KeyType*>(tmp_space_ptr),
Shape1(first_axis_len), s);
tmp_space_ptr += sizeof(KeyType) * first_axis_len;
prnd->GetRandInt(keys);
SortByKey(keys, indices, true);
if (req[0] == kWriteInplace) {
Tensor<gpu, 1, DType> buf(reinterpret_cast<DType*>(tmp_space_ptr), Shape1(size), s);
Copy(buf, in, s);
Kernel<CopyForShuffle, gpu>::Launch(s, size, buf.dptr_, out.dptr_, indices.dptr_, stride);
} else {
Kernel<CopyForShuffle, gpu>::Launch(s, size, in.dptr_, out.dptr_, indices.dptr_, stride);
}
}
});
}
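
The GPU path sidesteps the sequential Fisher-Yates loop: it draws one random integer key per element (1-D case) or per row (N-D case) and sorts by those keys, which yields a uniform permutation up to the usual caveat of duplicate keys. A NumPy sketch of the same sort-based trick (an editor's illustration, not part of the patch):

```python
import numpy as np

def shuffle_by_sort(x, rng):
    # One random key per row, then reorder rows by sorted key,
    # mirroring GetRandInt + SortByKey in ShuffleForwardGPU above.
    keys = rng.integers(0, np.iinfo(np.int64).max, size=x.shape[0])
    return x[np.argsort(keys)]

rng = np.random.default_rng(0)
print(shuffle_by_sort(np.arange(9).reshape(3, 3), rng))
```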

NNVM_REGISTER_OP(_shuffle)
.set_attr<FCompute>("FCompute<gpu>", ShuffleForwardGPU);

} // namespace op
} // namespace mxnet