diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index a49dc15199d8b5..729cf467ea7675 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -162,7 +162,7 @@ struct ConcatTensorsForAllReduce { void operator()(const DeviceContext &context, const std::vector &dense_tensors_, Tensor *p_dense_contents) { - operators::math::ConcatFunctor concat_functor_; + phi::funcs::ConcatFunctor concat_functor_; concat_functor_( context, dense_tensors_, @@ -191,7 +191,7 @@ struct SplitTensorsForAllReduce { shape_refer.emplace_back(&tensor); } - operators::math::SplitFunctor split_functor_; + phi::funcs::SplitFunctor split_functor_; split_functor_(context, *in, shape_refer, 0, &outs); } }; diff --git a/paddle/fluid/distributed/collective/reducer.h b/paddle/fluid/distributed/collective/reducer.h index c16b194ac9c073..661675c449117e 100644 --- a/paddle/fluid/distributed/collective/reducer.h +++ b/paddle/fluid/distributed/collective/reducer.h @@ -22,12 +22,12 @@ #include "paddle/fluid/eager/api/utils/hook_utils.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/utils.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/phi/api/include/api.h" #include "paddle/phi/api/include/fused_api.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/common/data_type.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/utils/string/string_helper.h" diff --git a/paddle/fluid/distributed/index_dataset/index_sampler.h b/paddle/fluid/distributed/index_dataset/index_sampler.h index e8fbf39ce9341b..f32cd62445d40c 100644 --- a/paddle/fluid/distributed/index_dataset/index_sampler.h +++ b/paddle/fluid/distributed/index_dataset/index_sampler.h @@ -18,8 +18,8 @@ #include "paddle/fluid/distributed/index_dataset/index_wrapper.h" #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/funcs/math/sampler.h" namespace paddle { namespace distributed { @@ -107,9 +107,8 @@ class LayerWiseSampler : public IndexSampler { while (layer_index >= start_sample_layer_) { auto layer_codes = tree_->GetLayerCodes(layer_index); layer_ids_.push_back(tree_->GetNodes(layer_codes)); - auto sampler_temp = - std::make_shared( - layer_ids_[idx].size() - 1, seed_); + auto sampler_temp = std::make_shared( + layer_ids_[idx].size() - 1, seed_); sampler_vec_.push_back(sampler_temp); layer_index--; idx++; @@ -131,7 +130,7 @@ class LayerWiseSampler : public IndexSampler { std::shared_ptr tree_{nullptr}; int seed_{0}; int start_sample_layer_{1}; - std::vector> sampler_vec_; + std::vector> sampler_vec_; std::vector> layer_ids_; }; diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 62459827d3c390..83da397e8a7cc4 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -671,7 +671,6 @@ if(WITH_DISTRIBUTE) glog index_sampler index_wrapper - sampler index_dataset_proto lod_rank_table framework_io diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 526935a5182be6..3d6f38aac1ecea 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/parallel_context.h" -#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/strided_memcpy.h" #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" @@ -74,7 +74,7 @@ static void ConcatTensorsForAllReduce( const DeviceContext &context, const std::vector &dense_tensors_, framework::Variable *p_dense_contents) { - operators::math::ConcatFunctor concat_functor_; + phi::funcs::ConcatFunctor concat_functor_; concat_functor_(context, dense_tensors_, 0, @@ -102,7 +102,7 @@ static void SplitTensorsForAllReduce( phi::funcs::StridedMemcpyWithAxis0( context, *in, shape_refer, &outs); } else { - operators::math::SplitFunctor split_functor_; + phi::funcs::SplitFunctor split_functor_; split_functor_(context, *in, shape_refer, 0, &outs); } } @@ -179,8 +179,7 @@ void SplitTensorsForAllReduce( outs.emplace_back(&tensor); shape_refer.emplace_back(&tensor); } - operators::math::SplitFunctor - split_functor_; + phi::funcs::SplitFunctor split_functor_; split_functor_(context, *in, shape_refer, 0, &outs); } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a498b2aca31963..43d33643713420 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -134,7 +134,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} phi common) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_utils lod_tensor unpooling lod_rank_table context_project executor static_prim_api) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc static_prim_api static_utils static_global_utils prim_utils) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} cos_sim_functor concat_and_split sampler sample_prob tree2col) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} cos_sim_functor concat_and_split tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} beam_search) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper ps_gpu_wrapper) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions) diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc index fae4ecbf9eb2b3..9c21dfbb1d327f 100644 --- a/paddle/fluid/operators/array_to_lod_tensor_op.cc +++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc @@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device_context.h" @@ -77,7 +77,7 @@ struct ArrayToLoDFunctor { template template void ArrayToLoDFunctorImpl::apply() { - math::ConcatFunctor func; + phi::funcs::ConcatFunctor func; func(*dev_ctx_, prev_functor_->in, 0, prev_functor_->out); } diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc index 7211c0f295d01f..22610a8fb1f15d 100644 --- a/paddle/fluid/operators/collective/c_concat_op.cu.cc +++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc @@ -17,8 +17,8 @@ limitations under the License. */ #include #include "paddle/phi/core/distributed/comm_context_manager.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/common/flags.h" @@ -151,7 +151,7 @@ class CConcatOpCUDAKernel : public framework::OpKernel { offset += rows_per_tensor; } - math::ConcatFunctor functor; + phi::funcs::ConcatFunctor functor; out->mutable_data(out_dims, place); auto& dev_ctx2 = ctx.template device_context(); functor(dev_ctx2, inputs, axis, out); diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu index b2bbd9c82095c8..65cb7d3043d18d 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu @@ -23,10 +23,10 @@ namespace cub = hipcub; #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/mixed_vector.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/strided_memcpy.h" diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index 42f6a4786fb25b..ff9197f40f8d76 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/lod_utils.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" namespace paddle { namespace framework { @@ -88,7 +88,7 @@ struct LoDTensorToArrayFunctor { template template void LoDTensorToArrayFunctorImpl::apply() { - math::SplitFunctor func; + phi::funcs::SplitFunctor func; func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 0e0423bd64ff45..e7545b8fd4f2df 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -10,8 +10,6 @@ math_library(concat_and_split DEPS phi common) math_library(context_project DEPS phi common) math_library(cos_sim_functor) math_library(depthwise_conv) -math_library(sample_prob) -math_library(sampler DEPS phi common) if(WITH_XPU) math_library(beam_search DEPS phi common beam_search_xpu) diff --git a/paddle/fluid/operators/math/sample_prob.cc b/paddle/fluid/operators/math/sample_prob.cc deleted file mode 100644 index 18321cf9b9ece6..00000000000000 --- a/paddle/fluid/operators/math/sample_prob.cc +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/sample_prob.h" - -namespace paddle { -namespace operators { -namespace math {} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sample_prob.cu b/paddle/fluid/operators/math/sample_prob.cu deleted file mode 100644 index 1d70b402104f58..00000000000000 --- a/paddle/fluid/operators/math/sample_prob.cu +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include - -#include -#include - -#include "paddle/common/ddim.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/sample_prob.h" -#include "paddle/fluid/operators/math/sampler.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -namespace math { - -template -__device__ T gpu_adjust_prob(const T prob, - const int num_samples, - const int num_tries) { - if (num_samples == num_tries) { - return prob * num_samples; - } else { - return -expm1(num_tries * log1p(-prob)); - } -} - -class GPULogUniformSampler { - public: - __device__ int64_t Sample(float random, - const int range, - const float log_range) const; - __device__ float Probability(int64_t value, const float log_range) const; -}; - -__device__ int64_t GPULogUniformSampler::Sample(float random, - const int range, - const float log_range) const { - // Got Log Uniform distribution from uniform distribution by - // inverse_transform_sampling method - const int64_t value = static_cast(exp(random * log_range)) - 1; - // Mathematically, value should be <= range_, but might not be due to some - // floating point roundoff, so we mod by range_. - return value % range; -} - -__device__ float GPULogUniformSampler::Probability( - int64_t value, const float log_range) const { - // Given f(x) = 1/[(x+1) * log_range_] - // The value's probability is integral of f(x) from value to (value + 1) - return (log((value + 2.0) / (value + 1.0))) / log_range; -} - -template -__global__ void SamplingCondidate(const size_t n, - const int num_tries, - const int range, - const float log_range, - const int num_true, - const std::size_t num_samples, - const int64_t* label_data, - int64_t* samples_data, - T* probabilities_data) { - const int num_sampled_classes = num_true + num_samples; - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int step_size = 0; - GPULogUniformSampler sampler; - - for (; idx < n; idx += blockDim.x * gridDim.x) { - int col_idx = idx % num_sampled_classes; - int row_idx = idx / num_sampled_classes; - if (col_idx < num_true) { - samples_data[idx] = label_data[row_idx * num_true + col_idx]; - } else { - samples_data[idx] = samples_data[col_idx]; - } - probabilities_data[idx] = sampler.Probability(samples_data[idx], log_range); - probabilities_data[idx] = - gpu_adjust_prob(probabilities_data[idx], num_samples, num_tries); - } -} - -template -int UniqSampler(const Sampler& sampler, - const std::size_t num_samples, - int64_t* samples_data) { - // sample num_samles unique samples for an example, note that they are not - // all negative samples - std::unordered_set tmp_samples; - tmp_samples.clear(); - int num_tries = 0; - int j = 0; - while (j < num_samples) { - ++num_tries; - auto v = sampler.Sample(); - auto insert_ok = tmp_samples.insert(v).second; - if (!insert_ok) { - continue; - } - samples_data[j] = v; - ++j; - } - return num_tries; -} - -template -void GPUSampleWithProb::operator()(const phi::GPUContext& context, - const int seed, - const int dict_size, - const bool uniq, - const std::size_t num_samples, - const phi::DenseTensor* L, - phi::DenseTensor* S, - phi::DenseTensor* P) { - // UNDERSTAND: dimension issues - const auto lbl_dim = L->dims(); - const int batch_size = lbl_dim[0]; - const int num_true = lbl_dim[1]; - const int num_sampled_classes = num_true + num_samples; - framework::DDim ret_dim{batch_size, num_sampled_classes}; - - // UNDERSTAND: raw data view - const int64_t* label_data = L->data(); - int64_t* samples_data = S->data(); - T* probabilities_data = P->data(); - - int s_size = num_samples; - framework::DDim s_dim{s_size}; - phi::DenseTensor s; - int64_t* s_data = s.mutable_data(s_dim, platform::CPUPlace()); - - math::LogUniformSampler sampler(dict_size, seed); - - int range = dict_size; - float log_range = log(range + 1); - - int num_tries = UniqSampler(sampler, num_samples, s_data); - VLOG(1) << "num_tries: " << num_tries; - -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(samples_data + num_true, - s_data, - sizeof(int64_t) * num_samples, - hipMemcpyHostToDevice)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(samples_data + num_true, - s_data, - sizeof(int64_t) * num_samples, - cudaMemcpyHostToDevice)); -#endif - - int threads = 512; - const size_t size = batch_size * num_sampled_classes; - int grid = (batch_size * num_sampled_classes + threads - 1) / threads; -#ifdef PADDLE_WITH_HIP - hipLaunchKernelGGL(HIP_KERNEL_NAME(SamplingCondidate), - dim3(grid), - dim3(threads), - 0, - context.stream(), - size, - num_tries, - range, - log_range, - num_true, - num_samples, - label_data, - samples_data, - probabilities_data); -#else - SamplingCondidate - <<>>(size, - num_tries, - range, - log_range, - num_true, - num_samples, - label_data, - samples_data, - probabilities_data); -#endif -} - -template class GPUSampleWithProb; -template class GPUSampleWithProb; -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sample_prob.h b/paddle/fluid/operators/math/sample_prob.h deleted file mode 100644 index f30ada2f1f3c52..00000000000000 --- a/paddle/fluid/operators/math/sample_prob.h +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include - -#include "paddle/common/ddim.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/sampler.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" - -namespace paddle { -namespace operators { -namespace math { - -/* UNDERSTAND: utility function to adjust probability for unique sampling, -return whatever as it is if not using unique samping */ -template -static T adjust_prob(const T prob, const int num_samples, const int num_tries) { - if (num_samples == num_tries) { - return prob * num_samples; - } else { - return -expm1(num_tries * log1p(-prob)); - } -} - -template -class SampleWithProb { - public: - void operator()(const DeviceContext& context, - const Sampler& sampler, - const std::size_t num_samples, - const phi::DenseTensor* L, - phi::DenseTensor* S, - phi::DenseTensor* P) { - // UNDERSTAND: dimension issues - const auto& lbl_dim = L->dims(); - const int batch_size = lbl_dim[0]; - const int num_true = lbl_dim[1]; - const int num_sampled_classes = num_true + num_samples; - framework::DDim ret_dim{batch_size, num_sampled_classes}; - - // UNDERSTAND: raw data view - const int64_t* label_data = L->data(); - int64_t* samples_data = - S->mutable_data(ret_dim, context.GetPlace()); - T* probabilities_data = P->mutable_data(ret_dim, context.GetPlace()); - - // temp sets for unique sampling - std::unordered_set tmp_samples; - int j = 0; // column index - // add true labels, not that efficient - while (j < num_true) { - for (int i = 0; i < batch_size; ++i) { - auto samples_index = i * num_sampled_classes + j; - auto v = label_data[i * num_true + j]; - samples_data[samples_index] = v; - probabilities_data[samples_index] = sampler.Probability(v); - } - ++j; - } - - // sample num_samles unique samples for an example, note that they are not - // all negative samples - tmp_samples.clear(); - int num_tries = 0; - while (j < num_sampled_classes) { - ++num_tries; - auto v = sampler.Sample(); - auto insert_ok = tmp_samples.insert(v).second; - if (!insert_ok) { - continue; - } - auto p = sampler.Probability(v); - for (int i = 0; i < batch_size; ++i) { - auto samples_index = i * num_sampled_classes + j; - samples_data[samples_index] = v; - probabilities_data[samples_index] = p; - } - ++j; - } - - // compute Q(y|x), because of unique sampling, probabilities need to be - // adjusted - for (int k = 0; k < num_sampled_classes; ++k) { - for (int i = 0; i < batch_size; ++i) { - auto samples_index = i * num_sampled_classes + k; - probabilities_data[samples_index] = adjust_prob( - probabilities_data[samples_index], num_samples, num_tries); - } - } - } -}; - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -class GPUSampleWithProb { - public: - void operator()(const phi::GPUContext& context, - const int seed, - const int dict_size, - const bool uniq, - const std::size_t num_samples, - const phi::DenseTensor* L, - phi::DenseTensor* S, - phi::DenseTensor* P); -}; -#endif -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sampler.cc b/paddle/fluid/operators/math/sampler.cc deleted file mode 100644 index 0ea4336e92ec0b..00000000000000 --- a/paddle/fluid/operators/math/sampler.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/sampler.h" - -#include - -#include "paddle/phi/core/generator.h" - -namespace paddle { -namespace operators { -namespace math { - -Sampler::~Sampler() = default; - -UniformSampler::UniformSampler(int64_t range, unsigned int seed) - : Sampler(range, seed), inv_range_(1.0f / (range + 1)) { // NOLINT - random_engine_ = phi::GetCPURandomEngine(seed_); - dist_ = std::make_shared>(0, range); -} - -int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); } - -float UniformSampler::Probability(int64_t value) const { return inv_range_; } - -LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed) - : Sampler(range, seed), log_range_(log(range + 1)) { // NOLINT - random_engine_ = phi::GetCPURandomEngine(seed_); - dist_ = std::make_shared>(0, 1); -} - -int64_t LogUniformSampler::Sample() const { - // Got Log Uniform distribution from uniform distribution by - // inverse_transform_sampling method - // More details: - // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ - auto cur_random = (*dist_)(*random_engine_); - const int64_t value = static_cast(exp(cur_random * log_range_)) - 1; - // Mathematically, value should be <= range_, but might not be due to some - // floating point roundoff, so we mod by range_. - return value % range_; -} - -float LogUniformSampler::Probability(int64_t value) const { - // Given f(x) = 1/[(x+1) * log_range_] - // The value's probability is integral of f(x) from value to (value + 1) - // More details: - // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler - return (log((value + 2.0) / (value + 1.0))) / log_range_; // NOLINT -} - -CustomSampler::CustomSampler(int64_t range, - const float *probabilities, - const int *alias, - const float *alias_probabilities, - unsigned int seed) - : Sampler(range, seed) { - random_engine_ = phi::GetCPURandomEngine(seed_); - real_dist_ = std::make_shared>(0, 1); - int_dist_ = std::make_shared>(0, range); - - alias_probs_ = alias_probabilities; - probs_ = probabilities; - alias_ = alias; -} - -int64_t CustomSampler::Sample() const { - auto index = (*int_dist_)(*random_engine_); - auto p = (*real_dist_)(*random_engine_); - if (p > alias_probs_[index]) { - int alias = alias_[index]; - - if (alias == exceptional_val) { - LOG(WARNING) << "WARNING: CustomSampler get alias " << exceptional_val; - return index; - } - - return alias; - } else { - return index; - } -} - -float CustomSampler::Probability(int64_t value) const { return probs_[value]; } - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sampler.h b/paddle/fluid/operators/math/sampler.h deleted file mode 100644 index e14e1ca572cab7..00000000000000 --- a/paddle/fluid/operators/math/sampler.h +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace math { - -// TODO(wanghaoshuang): Support for GPU - -/** - * Sample integers from [0, range). - */ -class Sampler { - public: - explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) { - PADDLE_ENFORCE_GT( - range, - 0, - phi::errors::InvalidArgument( - "Range should be greater than 0, but received %d.", range)); - if (seed == 0) { - std::random_device r; - seed_ = r(); - } else { - seed_ = seed; - } - } - - virtual ~Sampler(); - - // Sample a single value - virtual int64_t Sample() const = 0; - - // The probability that a single call to Sample() returns the given value. - virtual float Probability(int64_t value) const = 0; - - int64_t range() { return range_; } - - protected: - const int64_t range_; - unsigned int seed_; -}; - -/** - * Sample integers from [0, range). - * And the distribution function is: - * P(x) = 1 / range - */ -class UniformSampler : public Sampler { - public: - explicit UniformSampler(int64_t range, unsigned int seed = 0UL); - - ~UniformSampler() override {} - - int64_t Sample() const override; - - float Probability(int64_t value) const override; - - private: - const float inv_range_; - std::shared_ptr random_engine_; - std::shared_ptr> dist_; -}; - -/** - * Sample integers from [0, range). - * And the distribution function is: - * P(x) = (1/ln(range+1)) * ln(1 + 1/(x + 1)) - */ -class LogUniformSampler : public Sampler { - public: - explicit LogUniformSampler(int64_t range, unsigned int seed = 0UL); - - ~LogUniformSampler() override {} - - int64_t Sample() const override; - - float Probability(int64_t value) const override; - - private: - const float log_range_; - std::shared_ptr random_engine_; - std::shared_ptr> dist_; -}; - -/** - * Sample integers from [0, range) from custom distribution. - */ -class CustomSampler : public Sampler { - public: - explicit CustomSampler(int64_t range, - const float* probabilities, - const int* alias, - const float* alias_probabilities, - unsigned int seed = 0UL); - - ~CustomSampler() override {} - - int64_t Sample() const override; - - float Probability(int64_t value) const override; - - private: - const float* alias_probs_; - const int* alias_; - const float* probs_; - const int exceptional_val = -1; - std::shared_ptr random_engine_; - std::shared_ptr> real_dist_; - std::shared_ptr> int_dist_; -}; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 5ad76785276dad..19eb81cb3d2b73 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -24,15 +24,15 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/fluid/operators/math/sampler.h" #include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math/sampler.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace operators { using SelectedRows = phi::SelectedRows; -using Sampler = math::Sampler; +using Sampler = phi::math::Sampler; using DDim = framework::DDim; template { Sampler *sampler; switch (sampler_type) { case 0: { - sampler = new math::UniformSampler(num_total_classes - 1, seed); + sampler = new phi::math::UniformSampler(num_total_classes - 1, seed); break; } case 1: { - sampler = new math::LogUniformSampler(num_total_classes - 1, seed); + sampler = new phi::math::LogUniformSampler(num_total_classes - 1, seed); break; } case 2: { @@ -136,11 +136,11 @@ class NCEKernel : public framework::OpKernel { const float *probs_data = dist_probs->data(); const int *alias_data = dist_alias->data(); const float *alias_probs_data = dist_alias_probs->data(); - sampler = new math::CustomSampler(num_total_classes - 1, - probs_data, - alias_data, - alias_probs_data, - seed); + sampler = new phi::math::CustomSampler(num_total_classes - 1, + probs_data, + alias_data, + alias_probs_data, + seed); break; } default: { @@ -274,11 +274,11 @@ class NCEGradKernel : public framework::OpKernel { Sampler *sampler; switch (sampler_type) { case 0: { - sampler = new math::UniformSampler(num_total_classes - 1, seed); + sampler = new phi::math::UniformSampler(num_total_classes - 1, seed); break; } case 1: { - sampler = new math::LogUniformSampler(num_total_classes - 1, seed); + sampler = new phi::math::LogUniformSampler(num_total_classes - 1, seed); break; } case 2: { @@ -322,11 +322,11 @@ class NCEGradKernel : public framework::OpKernel { const float *probs_data = dist_probs->data(); const int *alias_data = dist_alias->data(); const float *alias_probs_data = dist_alias_probs->data(); - sampler = new math::CustomSampler(num_total_classes - 1, - probs_data, - alias_data, - alias_probs_data, - seed); + sampler = new phi::math::CustomSampler(num_total_classes - 1, + probs_data, + alias_data, + alias_probs_data, + seed); break; } default: { diff --git a/paddle/fluid/operators/tdm_child_op.cc b/paddle/fluid/operators/tdm_child_op.cc index 41bcae86c551bd..6e3804fcb0a923 100644 --- a/paddle/fluid/operators/tdm_child_op.cc +++ b/paddle/fluid/operators/tdm_child_op.cc @@ -17,7 +17,6 @@ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { diff --git a/paddle/fluid/operators/unbind_op.h b/paddle/fluid/operators/unbind_op.h index ea2c6d4ee2bb8c..dad3e2ed9001bd 100644 --- a/paddle/fluid/operators/unbind_op.h +++ b/paddle/fluid/operators/unbind_op.h @@ -20,8 +20,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/strided_memcpy.h" namespace paddle { diff --git a/paddle/fluid/operators/unique_op.h b/paddle/fluid/operators/unique_op.h index 47bd4674c9a299..0bced76407b7e8 100644 --- a/paddle/fluid/operators/unique_op.h +++ b/paddle/fluid/operators/unique_op.h @@ -22,7 +22,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -304,7 +304,7 @@ static void UniqueDim(const framework::ExecutionContext& context, indices_vec.erase(indices_vec.begin() + input_unbind.size(), indices_vec.end()); - math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; phi::DenseTensor out_trans; std::vector out_trans_dims_vec = in_trans_dims_vec; out_trans_dims_vec[0] = input_unbind.size(); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index c93588f73d6f3b..ba096252689e05 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -31,10 +31,10 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/pybind/complex.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/strided_memcpy.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -721,7 +721,7 @@ void _concatCompute(const std::vector &ins, output_offset += in_stride[axis]; } } else { - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(ctx, ins, static_cast(axis), out); } } diff --git a/test/cpp/fluid/math/concat_test.cc b/test/cpp/fluid/math/concat_test.cc index 080a659ecdbbc6..b93c7c9a4870bd 100644 --- a/test/cpp/fluid/math/concat_test.cc +++ b/test/cpp/fluid/math/concat_test.cc @@ -15,9 +15,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" /** * case 1: @@ -77,7 +77,7 @@ void ConcatCase1(DeviceContext* context) { input.push_back(input_a); input.push_back(input_b); - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(*context, input, 0, &out); // check the dim of input_a, input_b @@ -182,7 +182,7 @@ void ConcatCase2(DeviceContext* context) { input.push_back(input_a); input.push_back(input_b); - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(*context, input, 1, &out); // check the dim of input_a, input_b @@ -291,7 +291,7 @@ void ConcatCase3(DeviceContext* context) { input.push_back(input_a); input.push_back(input_b); - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(*context, input, 2, &out); // check the dim of input_a, input_b @@ -402,7 +402,7 @@ void ConcatCase4(DeviceContext* context) { input.push_back(input_a); input.push_back(input_b); - paddle::operators::math::ConcatFunctor concat_functor; + phi::funcs::ConcatFunctor concat_functor; concat_functor(*context, input, 1, &out); context->Wait();