diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 280f24bdd6fa69..7659c42963196a 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -80,7 +80,6 @@ if (NOT WITH_MKL OR NOT WITH_AVX) SET(OP_MKL_DEPS ${OP_MKL_DEPS} var_conv_2d_op) endif() if(WITH_COVERAGE OR WIN32 OR WITH_NV_JETSON) - SET(OP_MKL_DEPS ${OP_MKL_DEPS} pyramid_hash_op) endif() if(WITH_UNITY_BUILD) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index c9bee1eb607059..f28fa284482631 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -43,8 +43,6 @@ detection_library(bipartite_match_op SRCS bipartite_match_op.cc) detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc) detection_library(anchor_generator_op SRCS anchor_generator_op.cc anchor_generator_op.cu) -detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc - polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) diff --git a/paddle/fluid/operators/detection/polygon_box_transform_op.cc b/paddle/fluid/operators/detection/polygon_box_transform_op.cc deleted file mode 100644 index 0059aedcdc86ca..00000000000000 --- a/paddle/fluid/operators/detection/polygon_box_transform_op.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class PolygonBoxTransformCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_cpu_place(ctx.GetPlace()), - true, - platform::errors::InvalidArgument("It must use CUDAPlace.")); - auto* in = ctx.Input("Input"); - auto in_dims = common::vectorize(in->dims()); - const T* in_data = in->data(); - auto* out = ctx.Output("Output"); - T* out_data = out->mutable_data(ctx.GetPlace()); - - int batch_size = in_dims[0]; - int geo_channel = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - int id = 0; - for (int id_n = 0; id_n < batch_size * geo_channel; ++id_n) { - for (int id_h = 0; id_h < height; ++id_h) { - for (int id_w = 0; id_w < width; ++id_w) { - id = id_n * height * width + width * id_h + id_w; - if (id_n % 2 == 0) { - out_data[id] = id_w * 4 - in_data[id]; - } else { - out_data[id] = id_h * 4 - in_data[id]; - } - } - } - } - } -}; - -class PolygonBoxTransformOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK( - ctx->HasInput("Input"), "Input", "Input", "polygon_box_transform"); - OP_INOUT_CHECK( - ctx->HasOutput("Output"), "Output", "Output", "polygon_box_transform"); - - auto in_dim = ctx->GetInputDim("Input"); - - PADDLE_ENFORCE_EQ( - in_dim.size(), - 4, - platform::errors::InvalidArgument( - "input's rank must be 4. But received: Input rank is [%d]", - in_dim.size())); - PADDLE_ENFORCE_EQ(in_dim[1] % 2, - 0, - platform::errors::InvalidArgument( - "input's second dimension must be even. But " - "received: Input 2nd dimension is [%d]", - in_dim[1])); - - ctx->SetOutputDim("Output", in_dim); - } -}; - -class PolygonBoxTransformOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "Input", - "The input with shape [batch_size, geometry_channels, height, width]"); - AddOutput("Output", "The output with the same shape as input"); - - AddComment(R"DOC( -PolygonBoxTransform Operator. - -PolygonBoxTransform Operator is used to transform the coordinate shift to the real coordinate. - -The input is the final geometry output in detection network. -We use 2*n numbers to denote the coordinate shift from n corner vertices of -the polygon_box to the pixel location. As each distance offset contains two numbers (xi, yi), -the geometry output contains 2*n channels. -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR( - polygon_box_transform, - ops::PolygonBoxTransformOp, - ops::PolygonBoxTransformOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -PD_REGISTER_STRUCT_KERNEL(polygon_box_transform, - CPU, - ALL_LAYOUT, - ops::PolygonBoxTransformCPUKernel, - float, - double) {} diff --git a/paddle/fluid/operators/detection/polygon_box_transform_op.cu b/paddle/fluid/operators/detection/polygon_box_transform_op.cu deleted file mode 100644 index 4f182464f77b50..00000000000000 --- a/paddle/fluid/operators/detection/polygon_box_transform_op.cu +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/phi/backends/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -using phi::PADDLE_CUDA_NUM_THREADS; -#define CUDA_BLOCK_SIZE 16 - -template -__global__ void PolygonBoxTransformKernel( - const int n, const int h, const int w, const T* input, T* output) { - int id_n = threadIdx.x + blockDim.x * blockIdx.x; - int id_h = threadIdx.y + blockDim.y * blockIdx.y; - int id_w = threadIdx.z + blockDim.z * blockIdx.z; - if (id_n < n && id_h < h && id_w < w) { - int id = id_n * h * w + w * id_h + id_w; - if (id_n % 2 == 0) { - output[id] = id_w * 4 - input[id]; - } else { - output[id] = id_h * 4 - input[id]; - } - } -} - -template -class PolygonBoxTransformOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::InvalidArgument( - "The polygon_box_transform operator needs to be executed on GPU.")); - auto* in = ctx.Input("Input"); - auto in_dims = in->dims(); - const T* in_data = in->data(); - auto* out = ctx.Output("Output"); - T* out_data = out->mutable_data(ctx.GetPlace()); - - int batch_size = in_dims[0]; - int geo_channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - dim3 threadsPerBlock( - PADDLE_CUDA_NUM_THREADS / (CUDA_BLOCK_SIZE * CUDA_BLOCK_SIZE), - CUDA_BLOCK_SIZE, - CUDA_BLOCK_SIZE); - dim3 numBlocks((batch_size * geo_channels) / threadsPerBlock.x, - (height + threadsPerBlock.y - 1) / threadsPerBlock.y, - (width + threadsPerBlock.z - 1) / threadsPerBlock.z); - auto stream = ctx.cuda_device_context().stream(); - PolygonBoxTransformKernel<<>>( - batch_size * geo_channels, height, width, in_data, out_data); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL(polygon_box_transform, - GPU, - ALL_LAYOUT, - ops::PolygonBoxTransformOpCUDAKernel, - float, - double) {} diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cc b/paddle/fluid/operators/optimizers/proximal_gd_op.cc deleted file mode 100644 index 074cc26c994e37..00000000000000 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.cc +++ /dev/null @@ -1,111 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
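Both deleted kernels above apply the same per-element rule: even geometry channels hold x-offsets and produce `4 * w_index - input`, odd channels hold y-offsets and produce `4 * h_index - input`, converting predicted coordinate shifts back to absolute pixel coordinates (the feature-map index is scaled by 4). A minimal NumPy sketch of that semantics — the helper name is illustrative, not part of the codebase:

```python
import numpy as np

def polygon_box_transform(geo):
    """Reference semantics of the deleted polygon_box_transform kernels.

    geo: float array of shape [batch, 2 * n, height, width]; even channels
    hold x-offsets, odd channels hold y-offsets.
    """
    batch, channels, height, width = geo.shape
    assert channels % 2 == 0, "geometry channels must come in (x, y) pairs"
    ws = np.arange(width).reshape(1, 1, 1, width)    # pixel column index
    hs = np.arange(height).reshape(1, 1, height, 1)  # pixel row index
    out = np.empty_like(geo)
    out[:, 0::2] = 4 * ws - geo[:, 0::2]  # even channels: x = 4*w - dx
    out[:, 1::2] = 4 * hs - geo[:, 1::2]  # odd channels:  y = 4*h - dy
    return out
```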
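The proximal_gd files removed next implement the update spelled out in their DOC block: a plain SGD step followed by L1 soft-thresholding and L2 shrinkage. A small NumPy sketch of one update, assuming scalar `lr` (the kernel broadcasts a size-1 learning-rate tensor):

```python
import numpy as np

def proximal_gd_step(param, grad, lr, l1=0.0, l2=0.0):
    """One proximal gradient descent update, mirroring ProximalGDOpKernel:

    prox  = param - lr * grad
    param = sign(prox) * max(|prox| - lr * l1, 0) / (1 + lr * l2)
    """
    prox = param - lr * grad
    if l1 > 0:
        return (np.sign(prox)
                * np.maximum(np.abs(prox) - lr * l1, 0.0)
                / (1.0 + lr * l2))
    # with no L1 term, only the L2 shrinkage applies
    return prox / (1.0 + lr * l2)
```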
*/ - -#include "paddle/fluid/operators/optimizers/proximal_gd_op.h" - -namespace paddle { -namespace operators { - -class ProximalGDOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "ProximalGDOp"); - OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "ProximalGDOp"); - OP_INOUT_CHECK( - ctx->HasInput("LearningRate"), "Input", "LearningRate", "ProximalGDOp"); - - OP_INOUT_CHECK( - ctx->HasOutput("ParamOut"), "Output", "Paramout", "ProximalGDOp"); - - auto param_dim = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ(param_dim, - ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument( - "The shape of Intput(Param) should be equal to the " - "Input(Grad) of ProximalGD Op. But received " - "Input(Param).dimensions=[%s], " - "Input(Grad).dimensions=[%s]", - param_dim, - ctx->GetInputDim("Grad"))); - - auto lr_dim = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_EQ( - common::product(lr_dim), - 1, - platform::errors::InvalidArgument( - "Learning Rate should be a scalar. But received dimensions:[%s]", - lr_dim)); - - ctx->SetOutputDim("ParamOut", param_dim); - } - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"), - ctx.GetPlace()); - } -}; - -class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Param", - "(Tensor, default Tensor) " - "Input parameter value that has to be updated."); - AddInput("Grad", - "(Tensor, default Tensor) " - "Input gradient of the parameter."); - AddInput("LearningRate", - "(Tensor, default Tensor) " - "The learning rate should be a tensor of size 1."); - - AddOutput("ParamOut", "(Tensor) Output updated parameter value."); - - AddAttr("l1", - "(float, default 0.0) " - "L1 regularization strength.") - .SetDefault(0.0f); - AddAttr("l2", - "(float, default 0.0) " - "L2 regularization strength.") - .SetDefault(0.0f); - AddComment(R"DOC( -ProximalGD Operator. - -Optimizer that implements the proximal gradient descent algorithm: - -$$ -prox\_param = param - learning\_rate * grad \\ -param = sign(prox\_param) / (1 + learning\_rate * l2) * - \max(|prox\_param| - learning\_rate * l1, 0) -$$ - -The paper that proposed Proximal Gradient Descent: -(http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf) - -)DOC"); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(proximal_gd, - ops::ProximalGDOp, - ops::ProximalGDOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - proximal_gd, CPU, ALL_LAYOUT, ops::ProximalGDOpKernel, float) {} diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cu b/paddle/fluid/operators/optimizers/proximal_gd_op.cu deleted file mode 100644 index ef1edfc2ee458f..00000000000000 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.cu +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -You may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software distributed -under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -CONDITIONS OF ANY KIND, either express or implied. See the License for the -specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/optimizers/proximal_gd_op.h" - -namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL( - proximal_gd, GPU, ALL_LAYOUT, ops::ProximalGDOpKernel, float) {} diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.h b/paddle/fluid/operators/optimizers/proximal_gd_op.h deleted file mode 100644 index 1945ef5bf6b778..00000000000000 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class ProximalGDOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* param_out = ctx.Output("ParamOut"); - - param_out->mutable_data(ctx.GetPlace()); - - auto grad = ctx.Input("Grad"); - - auto l1 = static_cast(ctx.Attr("l1")); - auto l2 = static_cast(ctx.Attr("l2")); - - auto p = framework::EigenVector::Flatten( - *ctx.Input("Param")); - auto g = framework::EigenVector::Flatten(*grad); - auto lr = framework::EigenVector::Flatten( - *ctx.Input("LearningRate")); - - auto p_out = framework::EigenVector::Flatten(*param_out); - auto& place = *ctx.template device_context().eigen_device(); - - Eigen::DSizes grad_dsize(grad->numel()); - - auto prox_param = p - lr.broadcast(grad_dsize) * g; - if (l1 > 0) { - p_out.device(place) = - prox_param.sign() * - (((prox_param.abs() - (lr * l1).broadcast(grad_dsize)) - .cwiseMax(T(0.0))) / - (1.0f + (lr * l2).broadcast(grad_dsize))); - } else { - p_out.device(place) = - prox_param / (1.0f + (lr * l2).broadcast(grad_dsize)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc deleted file mode 100644 index 96d8bbaa6f772f..00000000000000 --- a/paddle/fluid/operators/positive_negative_pair_op.cc +++ /dev/null @@ -1,262 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/positive_negative_pair_op.h" - -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { - -class PositiveNegativePairOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK( - ctx->HasInput("Score"), "Input", "Score", "positive_negative_pair"); - OP_INOUT_CHECK( - ctx->HasInput("Label"), "Input", "Label", "positive_negative_pair"); - OP_INOUT_CHECK( - ctx->HasInput("QueryID"), "Input", "QueryID", "positive_negative_pair"); - OP_INOUT_CHECK(ctx->HasOutput("PositivePair"), - "Output", - "PositivePair", - "positive_negative_pair"); - OP_INOUT_CHECK(ctx->HasOutput("NegativePair"), - "Output", - "NegativePair", - "positive_negative_pair"); - OP_INOUT_CHECK(ctx->HasOutput("NeutralPair"), - "Output", - "NeutralPair", - "positive_negative_pair"); - - auto scalar_dim = common::make_ddim({1}); - if (ctx->HasInput("AccumulatePositivePair") || - ctx->HasInput("AccumulateNegativePair") || - ctx->HasInput("AccumulateNeutralPair")) { - PADDLE_ENFORCE_EQ( - ctx->HasInput("AccumulatePositivePair") && - ctx->HasInput("AccumulateNegativePair") && - ctx->HasInput("AccumulateNeutralPair"), - true, - platform::errors::InvalidArgument( - "All optional inputs(AccumulatePositivePair, " - "AccumulateNegativePair, AccumulateNeutralPair) of " - "PositiveNegativePairOp are required if one of them " - "is specified.")); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("AccumulatePositivePair"), - scalar_dim, - platform::errors::InvalidArgument( - "Shape of Input(AccumulatePositivePair) should be [1]. Received " - "shape of Input(AccumulatePositivePair): [%s].", - ctx->GetInputDim("AccumulatePositivePair"))); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("AccumulateNegativePair"), - scalar_dim, - platform::errors::InvalidArgument( - "Shape of Input(AccumulateNegativePair) should be [1]. Received " - "shape of Input(AccumulateNegativePair): [%s].", - ctx->GetInputDim("AccumulateNegativePair"))); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("AccumulateNeutralPair"), - scalar_dim, - platform::errors::InvalidArgument( - "Shape of Input(AccumulateNeutralPair) should be [1]. Received " - "shape of Input(AccumulateNeutralPair): [%s].", - ctx->GetInputDim("AccumulateNeutralPair"))); - } - - auto score_dim = ctx->GetInputDim("Score"); - auto label_dim = ctx->GetInputDim("Label"); - auto query_dim = ctx->GetInputDim("QueryID"); - PADDLE_ENFORCE_EQ(score_dim.size(), - 2, - platform::errors::InvalidArgument( - "Score should be a 2-D tensor. Received shape of " - "Input(Score): [%s].", - score_dim)); - PADDLE_ENFORCE_EQ(label_dim.size(), - 2, - platform::errors::InvalidArgument( - "Label should be a 2-D tensor. Received shape of " - "Input(Label): [%s].", - label_dim)); - - if (ctx->IsRuntime() || - (score_dim[0] > 0 && label_dim[0] > 0 && query_dim[0] > 0)) { - PADDLE_ENFORCE_EQ( - label_dim[0], - score_dim[0], - platform::errors::InvalidArgument( - "Input(Score) and Input(Label) should have the same " - "height (batch size). Received: the shape of Input(Score) is " - "[%s], while the shape of Input(Label) is [%s]. The first " - "dimensions of them are different.", - label_dim, - score_dim)); - - PADDLE_ENFORCE_EQ( - label_dim[1], - 1, - platform::errors::InvalidArgument( - "The width of Label should be 1, i.e. 
each item should " - "have a scalar label. Received shape of Input(Label) is [%s]. " - "The second dimension of it is %d, while the expected is %d.", - label_dim, - label_dim[1], - 1)); - - PADDLE_ENFORCE_EQ( - query_dim, - label_dim, - platform::errors::InvalidArgument( - "Input(QueryID) should have the same shape as Input(Label). " - "Received: the shape of Input(QueryID) is [%s], " - "while the shape of Input(Label) is [%s].", - query_dim, - label_dim)); - - if (ctx->HasInput("Weight")) { - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Weight"), - label_dim, - platform::errors::InvalidArgument( - "Input(Weight) should have the same shape as Input(Label). " - "Received: the shape of Input(Weight) is [%s] while the shape " - "of Input(Label) is [%s].", - ctx->GetInputDim("Weight"), - label_dim)); - } - - int column = ctx->Attrs().Get("column"); - auto depth = score_dim[1]; - PADDLE_ENFORCE_LT( - column, - depth, - platform::errors::OutOfRange( - "Attr(column) should be less than depth(the second " - "dimension of Input(Score)). Received Attr(column): %d, while " - "depth is %d.", - column, - depth)); - PADDLE_ENFORCE_GE( - column, - -depth, - platform::errors::OutOfRange( - "Attr(column) should be greater than equal to negative " - "depth, i.e. the second dimension of Input(Score). " - "Received Attr(column): %d, while negative depth is %d.", - column, - -depth)); - } - - ctx->SetOutputDim("PositivePair", scalar_dim); - ctx->SetOutputDim("NegativePair", scalar_dim); - ctx->SetOutputDim("NeutralPair", scalar_dim); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Score"), - ctx.device_context().GetPlace()); - } -}; - -class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Score", - "(Tensor, float) Model Score on an item (with " - "respect to QueryID). It's a 2-D tensor with shape [batch_size, " - "depth], where the column specified by the attribute \"column\" " - "is used as item score."); - AddInput("Label", - "(Tensor, float) Label of an item (with repsect to " - "QueryId). It's a 2-D tensor with shape [batch_size, 1]."); - AddInput("QueryID", - "(Tensor, int64) Query ID that indicates the context. Its shape " - "should be the same as Label."); - AddInput( - "AccumulatePositivePair", - "(float) Optional. The accumulated number of positive pairs over a " - "stream of data. If provided, the output PositivePair will be " - "initialized with this number rather than 0. it won't be modified " - "in place.") - .AsDispensable(); - AddInput( - "AccumulateNegativePair", - "(float) Optional. The accumulated number of negative pairs over a " - "stream of data. If provided, the output NegativePair will be " - "initialized with this number rather than 0. it won't be modified " - "in place.") - .AsDispensable(); - AddInput("AccumulateNeutralPair", - "(float) Optional. The accumulated number of neutral pairs over a " - "stream of data. If provided, the output NeutralPair will be " - "initialized with this number rather than 0. it won't be modified " - "in place.") - .AsDispensable(); - AddInput("Weight", - "(float) Optional. Weight of current item. If specified, its " - "shape should be the same as Label, and the meaning of the output " - "changes from numbers of pairs to the total sum of pairs' " - "weights. 
Weight of a pair of items is the average of their " - "weights.") - .AsDispensable(); - AddOutput("PositivePair", - "(float) Number of positive pairs, i.e. the pairs of " - "items that are ranked correctly."); - AddOutput("NegativePair", - "(float) Number of negative pairs, i.e. the pairs of " - "items that are ranked incorrectly."); - AddOutput("NeutralPair", - "(float) Number of neutral pairs, i.e. the pairs of items " - "that have the same score.") - .AsDispensable(); - AddAttr( - "column", - "(int, default -1) The column position of Score used to rank items in " - "descending order. It must be in the range of [-rank(Score), " - "rank(Score)). " - "If `dim < 0`, the dim to reduce is `rank + dim`. " - "Noting that reducing on the first dim will make the LoD info lost.") - .SetDefault(0); - AddComment(R"DOC( -PositiveNegativePairOp can be used to evaluate Learning To Rank(LTR) model's -performance. - -Within some context, e.g. the "query", a LTR model generates scores for a list -of items, which gives a partial order of the items. PositiveNegativePairOp -takes a list of reference rank order (Input("Label")) and the model generated -scores (Input(Score)) as inputs and counts the pairs that ranked correctly -and incorrectly. -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(positive_negative_pair, - ops::PositiveNegativePairOp, - ops::PositiveNegativePairOpMaker); - -PD_REGISTER_STRUCT_KERNEL(positive_negative_pair, - CPU, - ALL_LAYOUT, - ops::PositiveNegativePairKernel, - float, - double) {} diff --git a/paddle/fluid/operators/positive_negative_pair_op.h b/paddle/fluid/operators/positive_negative_pair_op.h deleted file mode 100644 index 0cddbcc3abf853..00000000000000 --- a/paddle/fluid/operators/positive_negative_pair_op.h +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class PositiveNegativePairKernel : public framework::OpKernel { - public: - struct PredictionResult { - PredictionResult(T score, T label, T weight) - : score(score), label(label), weight(weight) {} - T score; - T label; - T weight; - }; - - void Compute(const framework::ExecutionContext& context) const override { - auto score_t = context.Input("Score"); - auto label_t = context.Input("Label"); - auto query_t = context.Input("QueryID"); - auto acc_positive_t = - context.Input("AccumulatePositivePair"); - auto acc_negative_t = - context.Input("AccumulateNegativePair"); - auto acc_neutral_t = - context.Input("AccumulateNeutralPair"); - auto positive_t = context.Output("PositivePair"); - auto negative_t = context.Output("NegativePair"); - auto neutral_t = context.Output("NeutralPair"); - auto weight_t = context.Input("Weight"); - - auto score = score_t->data(); - auto label = label_t->data(); - auto query = query_t->data(); - const T* weight = nullptr; - if (weight_t != nullptr) { - weight = weight_t->data(); - } - T* positive = positive_t->mutable_data(context.GetPlace()); - T* negative = negative_t->mutable_data(context.GetPlace()); - T* neutral = neutral_t->mutable_data(context.GetPlace()); - - auto score_dim = score_t->dims(); - auto batch_size = score_dim[0]; - auto width = score_dim[1]; - auto column = context.Attr("column"); - if (column < 0) { - column += width; - } - - // construct document instances for each query: Query => List[, ...] - std::unordered_map> predictions; - for (auto i = 0; i < batch_size; ++i) { - if (predictions.find(query[i]) == predictions.end()) { - predictions.emplace( - std::make_pair(query[i], std::vector())); - } - predictions[query[i]].emplace_back(score[i * width + column], - label[i], - weight_t != nullptr ? weight[i] : 1.0); - } - - // for each query, accumulate pair counts - T pos = 0, neg = 0, neu = 0; - if (acc_positive_t != nullptr && acc_negative_t != nullptr && - acc_neutral_t != nullptr) { - pos = acc_positive_t->data()[0]; - neg = acc_negative_t->data()[0]; - neu = acc_neutral_t->data()[0]; - } - auto evaluate_one_list = - [&pos, &neg, &neu](std::vector vec) { - for (auto ite1 = vec.begin(); ite1 != vec.end(); ++ite1) { - for (auto ite2 = ite1 + 1; ite2 != vec.end(); ++ite2) { - if (ite1->label == ite2->label) { // labels are equal, ignore. - continue; - } - T w = (ite1->weight + ite2->weight) * 0.5; - if (ite1->score == ite2->score) { - neu += w; - } - (ite1->score - ite2->score) * (ite1->label - ite2->label) > 0.0 - ? pos += w - : neg += w; - } - } - }; - for (auto prediction : predictions) { - evaluate_one_list(prediction.second); - } - *positive = pos; - *negative = neg; - *neutral = neu; - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/pyramid_hash_op.cc b/paddle/fluid/operators/pyramid_hash_op.cc deleted file mode 100644 index f5a8fcaa9de0ce..00000000000000 --- a/paddle/fluid/operators/pyramid_hash_op.cc +++ /dev/null @@ -1,596 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include -#include - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/search_compute.h" - -extern "C" { -#include "math/bloomfilter.h" -} - -namespace paddle { -namespace operators { - -using LoD = framework::LoD; - -class PyramidHashOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "X (Tensor, MUST be Tensor) Input variable which " - "should contain lod information."); - AddInput("W", "W (Tensor)"); - AddInput("WhiteList", "WhiteList (Tensor)"); - AddInput("BlackList", "BlackList (Tensor)"); - AddAttr("num_emb", "num_emb").SetDefault(0).EqualGreaterThan(0); - AddAttr("space_len", "space_len").SetDefault(0).EqualGreaterThan(0); - AddAttr("pyramid_layer", "pyramid_layer (must be >= 2)") - .SetDefault(2) - .EqualGreaterThan(2); - AddAttr("rand_len", "rand_len").SetDefault(0).EqualGreaterThan(0); - AddAttr("drop_out_percent", "drop_out_percent") - .SetDefault(0) - .EqualGreaterThan(0); - AddAttr("is_training", "is_training") - .SetDefault(0) - .EqualGreaterThan(0); - AddAttr("use_filter", "use_filter").SetDefault(true); - AddAttr("white_list_len", "white_list_len") - .SetDefault(0) - .EqualGreaterThan(0); - AddAttr("black_list_len", "black_list_len") - .SetDefault(0) - .EqualGreaterThan(0); - AddAttr("seed", "seed").SetDefault(0).EqualGreaterThan(0); - AddAttr("lr", "learning rate").SetDefault(0.0).EqualGreaterThan(0.0); - AddAttr( - "distribute_update_vars", - "['PyramidHash_emb_0','Filter']" - "Decided which params should be updated in distribute training. " - "Used in Distribute Transpiler to create a trainer/server program.") - .SetDefault(""); - AddOutput("Out", "Out (Tensor, default Tensor) Output variable"); - AddOutput("DropPos", "Out (Tensor, Tensor) Output variable"); - AddOutput("X_Temp_Out", "Out (Tensor, Tensor) Output variable") - .AsIntermediate(); - - AddComment(R"DOC( - PyramidHash - - NOTE: only support 'float32' data type now. - - )DOC"); - } -}; - -class PyramidHashOP : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), - true, - platform::errors::NotFound("Input(X) of PyramidHashOP is not found.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("W"), - true, - platform::errors::NotFound("Input(W) of PyramidHashOP is not found.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), - true, - platform::errors::NotFound( - "Output(Out) of PyramidHashOP is not found.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("DropPos"), - true, - platform::errors::NotFound( - "Output(DropPos) of PyramidHashOP is not found.")); - - auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(x_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(X) of PyramidHashOP is invalid. 
" - "It should be 2, but got %d", - x_dims.size())); - - auto w_dims = ctx->GetInputDim("W"); - PADDLE_ENFORCE_EQ(w_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(W) of PyramidHashOP is invalid. " - "It should be 2, but got %d", - w_dims.size())); - - int space_len = ctx->Attrs().Get("space_len"); - int rand_len = ctx->Attrs().Get("rand_len"); - - PADDLE_ENFORCE_EQ( - w_dims[0], - space_len + rand_len, - platform::errors::InvalidArgument( - "The first dimension of Input(W) of PyramidHashOP is invalid. " - "It should be space_len + rand_len, but now %d != %d + %d", - w_dims[0], - space_len, - rand_len)); - PADDLE_ENFORCE_EQ( - w_dims[1], - 1, - platform::errors::InvalidArgument( - "The second dimension of Input(W) of PyramidHashOP is invalid." - " It should be 1, but got %d", - w_dims[1])); - - int num_emb = ctx->Attrs().Get("num_emb"); - PADDLE_ENFORCE_EQ( - num_emb % rand_len, - 0, - platform::errors::InvalidArgument( - "The PyramidHashOP's Attr(num_emb) should mod Attr(rand_len), " - "but num_emb is %d, rand_len is %d", - num_emb, - rand_len)); - - int white_list_len = ctx->Attrs().Get("white_list_len"); - if (white_list_len > 0) { - PADDLE_ENFORCE_EQ( - ctx->HasInput("WhiteList"), - true, - platform::errors::NotFound("Input(WhiteList) of PyramidHashOP is not " - "found but white_list_len > 0.")); - auto wl_dims = ctx->GetInputDim("WhiteList"); - PADDLE_ENFORCE_EQ( - wl_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(WhiteList) of PyramidHashOP is invalid." - " It should be 2, but got %d", - wl_dims.size())); - PADDLE_ENFORCE_EQ(wl_dims[0], - white_list_len, - platform::errors::InvalidArgument( - "The first dimension of Input(WhiteList) of " - "PyramidHashOP is invalid." - " It should be equal to Attr(white_list_len) " - ", but first dimension is %d, white_list_len is %d", - wl_dims[0], - white_list_len)); - PADDLE_ENFORCE_EQ(wl_dims[1], - 1, - platform::errors::InvalidArgument( - "The second dimension of Input(WhiteList) of " - "PyramidHashOP is invalid." - " It should be 1, but got %d", - wl_dims[1])); - } - - int black_list_len = ctx->Attrs().Get("black_list_len"); - if (black_list_len > 0) { - PADDLE_ENFORCE_EQ( - ctx->HasInput("BlackList"), - true, - platform::errors::NotFound("Input(BlackList) of PyramidHashOP is not " - "found but black_list_len > 0.")); - auto bl_dims = ctx->GetInputDim("BlackList"); - PADDLE_ENFORCE_EQ( - bl_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(BlackList) of PyramidHashOP is invalid." - " It should be 2, but got %d", - bl_dims.size())); - PADDLE_ENFORCE_EQ(bl_dims[0], - black_list_len, - platform::errors::InvalidArgument( - "The first dimension of Input(BlackList) of " - "PyramidHashOP is invalid." - " It should be equal to Attr(black_list_len)" - ", but first dimension is %d, black_list_len is %d", - bl_dims[0], - black_list_len)); - PADDLE_ENFORCE_EQ(bl_dims[1], - 1, - platform::errors::InvalidArgument( - "The second dimension of Input(BlackList) of " - "PyramidHashOP is invalid." - " It should be 1, but got %d", - bl_dims[1])); - } - - if (ctx->IsRuntime()) { - // something to do in runtime. 
- } else { - // compile time - ctx->SetOutputDim("Out", common::make_ddim({-1, num_emb})); - ctx->SetOutputDim("X_Temp_Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "W"), - ctx.GetPlace()); - } -}; - -template -class CPUPyramidHashOPKernel : public framework::OpKernel { - public: - bool should_use_term(math::bloomfilter* _filter, - math::bloomfilter* _black_filter, - const float* word_repr, - int len) const { - return (!_filter || 1 == math::bloomfilter_get( - _filter, word_repr, len * sizeof(float))) && - (!_black_filter || - 0 == math::bloomfilter_get( - _black_filter, word_repr, len * sizeof(float))); - } - - void hash_embedding_ff(const float* hash_id, - int len, - T* top_pos, - const T* weights, - int _num_emb, - int _rand_len, - int _space_len) const { - unsigned int pos1 = XXH32(hash_id, len * sizeof(float), 0) % _space_len; - unsigned int pos2 = - XXH32(hash_id, len * sizeof(float), _rand_len) % _space_len; - - for (int j = 0; j != _num_emb; j += _rand_len) { - if (j + _rand_len < _num_emb) { - __builtin_prefetch(weights + pos2); - __builtin_prefetch(top_pos + j + _rand_len); - } - - unsigned int pos3 = - XXH32(hash_id, len * sizeof(float), j + 2 * _rand_len) % _space_len; - memcpy( - top_pos + j, const_cast(weights + pos1), _rand_len * sizeof(T)); - pos1 = pos2; - pos2 = pos3; - } - } - - void Compute(const framework::ExecutionContext& ctx) const override { - auto* bottom = ctx.Input("X"); - auto* _blobs_0 = ctx.Input("W"); - auto* _blobs_1 = ctx.Input("WhiteList"); - auto* _blobs_2 = ctx.Input("BlackList"); - auto* top = ctx.Output("Out"); - auto* drop_pos = ctx.Output("DropPos"); - - int _num_emb = ctx.Attr("num_emb"); - bool use_filter = ctx.Attr("use_filter"); - int white_list_len = ctx.Attr("white_list_len"); - int black_list_len = ctx.Attr("black_list_len"); - int _pyramid_layer = ctx.Attr("pyramid_layer"); - int _is_training = ctx.Attr("is_training"); - int seed = ctx.Attr("seed"); - unsigned int _seed = (unsigned int)seed; - int _rand_len = ctx.Attr("rand_len"); - int _space_len = ctx.Attr("space_len"); - float _drop_out_percent = ctx.Attr("drop_out_percent"); - - const auto& offset = bottom->lod()[0]; - const auto* bottom_data_ori = bottom->data(); - auto* buff = ctx.Output("X_Temp_Out"); - buff->Resize(common::make_ddim({bottom->dims()[0], bottom->dims()[1]})); - float* bottom_data = buff->mutable_data(ctx.GetPlace()); - for (int i = 0; i < bottom->dims()[0]; i++) { - bottom_data[i] = bottom_data_ori[i]; // NOLINT - } - - const auto* weights = _blobs_0->data(); - - std::vector top_offset; - top_offset.resize(offset.size()); - top_offset[0] = 0; - - math::bloomfilter* _filter = nullptr; - math::bloomfilter* _black_filter = nullptr; - if (use_filter) { - if (white_list_len != 0) { - _filter = (math::bloomfilter*)_blobs_1->data(); - PADDLE_ENFORCE_EQ( - math::bloomfilter_check(_filter), - 1, - platform::errors::PreconditionNotMet( - "The white filter is not loaded successfully, please make sure " - "'white_list_len': %d is valid for Input(WhiteList).", - white_list_len)); - } - if (black_list_len != 0) { - _black_filter = (math::bloomfilter*)_blobs_2->data(); - PADDLE_ENFORCE_EQ( - math::bloomfilter_check(_black_filter), - 1, - platform::errors::PreconditionNotMet( - "The black filter is not loaded successfully, please make sure " - "'black_list_len': %d is valid for 
Input(BlackList).", - black_list_len)); - } - } - - drop_pos->Resize(common::make_ddim( - {bottom->dims()[0] * bottom->dims()[1] * _pyramid_layer, 1})); - std::vector drop_pos_offset; - drop_pos_offset.resize(offset.size()); - drop_pos_offset[0] = 0; - int* iter = drop_pos->mutable_data(ctx.GetPlace()); - int* iter_end = iter; - - for (size_t i = 0; i < top_offset.size() - 1; ++i) { - int w = static_cast(offset[i + 1] - offset[i]); - int nsentense_with_pyramid = 0; - if (w < 2) { - nsentense_with_pyramid = 0; - } else { - for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) { - for (int l = 0; l < w - ilayer; ++l) { - if (should_use_term(_filter, - _black_filter, - (const float*)(bottom_data + offset[i] + l), - ilayer + 1)) { - if (_is_training != 0) { - unsigned int rand_val = rand_r(&_seed); - double rate = static_cast(rand_val) / (RAND_MAX); - *(iter_end++) = (rate < _drop_out_percent ? 0 : 1); - } else { - *(iter_end++) = 1; - } - } else { - *(iter_end++) = 0; - } - } - } - nsentense_with_pyramid = - static_cast(std::count(iter, iter_end, 1)); - iter = iter_end; - } - drop_pos_offset[i + 1] = drop_pos_offset[i] + nsentense_with_pyramid; - top_offset[i + 1] = - top_offset[i] + - (nsentense_with_pyramid == 0 ? 1 : nsentense_with_pyramid); - } - - int top_l = static_cast(top_offset[top_offset.size() - 1]); - - framework::LoD top_lod; - top_lod.push_back(top_offset); - top->set_lod(top_lod); - top->Resize(common::make_ddim({top_l, _num_emb})); - auto* top_data = top->mutable_data(ctx.GetPlace()); - - framework::LoD drop_pos_lod; - drop_pos_lod.push_back(drop_pos_offset); - drop_pos->set_lod(drop_pos_lod); - - iter = drop_pos->mutable_data(ctx.GetPlace()); - int top_counter = 0; - for (size_t i = 0; i < offset.size() - 1; ++i) { - int w_drop = - static_cast(drop_pos_offset[i + 1] - drop_pos_offset[i]); - int w = static_cast(offset[i + 1] - offset[i]); - if (w_drop == 0) { - if (w >= 2) { - for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; - ++ilayer) { - for (int l = 0; l < w - ilayer; ++l) { - iter++; - } - } - } - auto* top_pos = top_data + top_counter++ * _num_emb; - memset(top_pos, 0, _num_emb * sizeof(T)); - continue; - } - if (w >= 2) { - for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) { - for (int l = 0; l < w - ilayer; ++l) { - if (*(iter++) == 0) { - // do nothing - } else { - auto* top_pos = top_data + top_counter++ * _num_emb; - hash_embedding_ff((const float*)(bottom_data + offset[i] + l), - ilayer + 1, - top_pos, - weights, - _num_emb, - _rand_len, - _space_len); - } - } - } - } - } - if (iter != iter_end) { - exit(1); - } - auto weight_type = framework::TransToProtoVarType(_blobs_0->dtype()); - if (_is_training == 0 && weight_type != framework::proto::VarType::INT8) { - axpy_noadd(top_data, - top_data, - top->dims()[0] * top->dims()[1], - _drop_out_percent); - } - } -}; - -class PyramidHashOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), - true, - platform::errors::NotFound( - "Input(X) of PyramidHashOpGrad is not found.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("W"), - true, - platform::errors::NotFound( - "Input(W) of PyramidHashOpGrad is not found.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("DropPos"), - true, - platform::errors::NotFound( - "Input(DropPos) of PyramidHashOpGrad is not found.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("X_Temp_Out"), - true, - 
platform::errors::NotFound( - "Input(X_Temp_Out) of PyramidHashOpGrad is not found.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput(framework::GradVarName("Out")), - true, - platform::errors::NotFound( - "Input(Out@Grad) of PyramidHashOpGrad is not found.")); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "W"), - ctx.GetPlace()); - } -}; - -template -class PyramidHashGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op_desc_ptr) const override { - op_desc_ptr->SetType("pyramid_hash_grad"); - op_desc_ptr->SetInput("X", this->Input("X")); - op_desc_ptr->SetInput("W", this->Input("W")); - op_desc_ptr->SetInput("DropPos", this->Output("DropPos")); - op_desc_ptr->SetInput("X_Temp_Out", this->Output("X_Temp_Out")); - - op_desc_ptr->SetInput(framework::GradVarName("Out"), - this->OutputGrad("Out")); - op_desc_ptr->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op_desc_ptr->SetAttrMap(this->Attrs()); - } -}; - -template -class CPUPyramidHashOPGradKernel : public framework::OpKernel { - public: - void hash_embedding_bp(const T* hash_id, - int len, - const T* top_pos, - T* weights, - T mlr, - int _num_emb, - int _rand_len, - int _space_len) const { - for (int j = 0; j != _num_emb; j += _rand_len) { - unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len; - axpy(top_pos + j, weights + pos, _rand_len, mlr); - } - } - - void Compute(const framework::ExecutionContext& ctx) const override { - auto* bottom = ctx.Input("X"); - auto* _blobs = ctx.Input("W"); - auto* drop_pos = ctx.Input("DropPos"); - auto* top = ctx.Input(framework::GradVarName("Out")); - - int _num_emb = ctx.Attr("num_emb"); - float _lr = ctx.Attr("lr"); - int _rand_len = ctx.Attr("rand_len"); - int _space_len = ctx.Attr("space_len"); - int _pyramid_layer = ctx.Attr("pyramid_layer"); - - auto* buff = ctx.Input("X_Temp_Out"); - auto* bottom_data = buff->data(); - - int _slot_len = static_cast(bottom->dims()[0]); - if (static_cast(_slot_len) == bottom->lod()[0].size() - 1 && - std::count(bottom_data, bottom_data + _slot_len, -1) == _slot_len) { - return; - } - - auto& offset = bottom->lod()[0]; - auto& drop_pos_offset = drop_pos->lod()[0]; - - const auto* top_diff = top->data(); - // in-place update weight, so need const_cast - T* weights = const_cast(_blobs->data()); - T mlr = -1.0 * _lr; - - const int* iter = drop_pos->data(); - int top_counter = 0; - for (size_t i = 0; i < offset.size() - 1; ++i) { - int w = static_cast(offset[i + 1] - offset[i]); - int w_drop = - static_cast(drop_pos_offset[i + 1] - drop_pos_offset[i]); - if (w_drop == 0) { - top_counter++; - } - if (w > 1) { - for (int ilayer = 1; ilayer < _pyramid_layer && ilayer < w; ++ilayer) { - for (int l = 0; l < w - ilayer; ++l) { - if (*(iter++) == 0) { - // do nothing - } else { - const T* top_pos = top_diff + top_counter++ * _num_emb; - hash_embedding_bp((const T*)(bottom_data + offset[i] + l), - ilayer + 1, - top_pos, - weights, - mlr, - _num_emb, - _rand_len, - _space_len); - } - } - } - } else { - // do nothing - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(pyramid_hash, - ops::PyramidHashOP, - ops::PyramidHashOpMaker, - ops::PyramidHashGradOpMaker, - ops::PyramidHashGradOpMaker); -REGISTER_OPERATOR(pyramid_hash_grad, 
ops::PyramidHashOpGrad); - -PD_REGISTER_STRUCT_KERNEL( - pyramid_hash, CPU, ALL_LAYOUT, ops::CPUPyramidHashOPKernel, float, int8_t) { -} -PD_REGISTER_STRUCT_KERNEL(pyramid_hash_grad, - CPU, - ALL_LAYOUT, - ops::CPUPyramidHashOPGradKernel, - float) {} diff --git a/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc deleted file mode 100644 index 319fad9b392317..00000000000000 --- a/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc +++ /dev/null @@ -1,303 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device_context.h" - -namespace phi { -class DenseTensor; -} // namespace phi - -namespace paddle { -namespace framework { -class LoDRankTable; -class OpDesc; -class Scope; -} // namespace framework -namespace imperative { -class OpBase; -} // namespace imperative -} // namespace paddle - -namespace paddle { -namespace operators { - -class ReorderLoDTensorByRankTableOpProtoMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "X", - "(phi::DenseTensor), the input lod tensor to be reordered according to " - "Input(RankTable)."); - AddInput("RankTable", - "(LoDRankTable), the rank table according to which Input(X) is " - "reordered."); - AddOutput("Out", "phi::DenseTensor, the reordered lod tensor."); - AddComment(R"DOC(ReorderLoDTensorByRankTable operator. - -Input(X) is a batch of sequences. Input(RankTable) stores new orders of the -input sequence batch. The reorder_lod_tensor_by_rank operator reorders the -Input(X) according to the information provided by Input(RankTable). - -For example: - -If the indices stored in the Input(RankTable) are [3, 0, 2, 1], the -Input(X) will be reordered that the fourth sequence in Input(X) will become the -first one, and then followed by the original first, third, and the second one. - -This is: -X = [Seq0, Seq1, Seq2, Seq3]. The indices in RankTable are [3, 0, 2, 1]. -Out = [Seq3, Seq0, Seq2, Seq1] with a new LoD information. - -If the LoD information of Input(X) is empty, this means Input(X) is not sequence -data. This is also identical to a batch of sequences where each sequence has a -fixed length 1. In this case, the reorder_lod_tensor_by_rank operator reorders -each slice of Input(X) along the first axis according to Input(RankTable). - -This is: -X = [Slice0, Slice1, Slice2, Slice3] and its LoD information is empty. The -indices in RankTable are [3, 0, 2, 1]. -Out = [Slice3, Slice0, Slice2, Slice1] with no LoD information is appended. - -**NOTE**: -This operator sorts Input(X) according to a given LoDRankTable which does -not need to be calculated according to Input(X). It can be calculated according -to another different sequence, and then this operator sorts Input(X) according -to the given LoDRankTable. 
- -)DOC"); - } -}; - -class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { - public: - ReorderLoDTensorByRankTableBase(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto &x = GET_DATA_SAFELY(scope.FindVar(Input("X")), - "Input", - "X", - "ReorderLoDTensorByRankTable") - .Get(); - auto &rank_table = GET_DATA_SAFELY(scope.FindVar(Input("RankTable")), - "Input", - "RankTable", - "ReorderLoDTensorByRankTable") - .Get(); - auto &out = *(GET_DATA_SAFELY(scope.FindVar(Output("Out")), - "Output", - "Out", - "ReorderLoDTensorByRankTable") - .GetMutable()); - - out.Resize(x.dims()); - out.mutable_data(x.place(), x.type()); - this->process(place, x, rank_table, &out); - } - - protected: - virtual void process(const platform::Place &place, - const phi::DenseTensor &x, - const framework::LoDRankTable &rank_table, - phi::DenseTensor *out) const = 0; - - struct AbsoluteRankTableItem { - size_t offset; // the absolute/accumulated offset. - size_t length; // the length - framework::LoD lod; - }; - - std::vector GetAbsoluteOffsetAndLengthByLoDRankTable( - const phi::DenseTensor &x) const { - std::vector absolute_table; - - if (x.lod().empty()) { - // For Tensor without lod, such as the output of sequence_pool_op - size_t size = x.dims()[0]; - absolute_table.reserve(size); - for (size_t i = 0; i < size; ++i) { - absolute_table.emplace_back(); - absolute_table.back().length = 1; - absolute_table.back().offset = i; - } - } else { - size_t level = 0; - size_t size = x.lod()[level].size(); - - for (size_t i = 0; i < size - 1; ++i) { - auto lod_offset = - framework::GetSubLoDAndAbsoluteOffset(x.lod(), i, i + 1, level); - - auto &offset = lod_offset.second; - - absolute_table.emplace_back(); - absolute_table.back().length = offset.second - offset.first; - absolute_table.back().offset = offset.first; - absolute_table.back().lod = lod_offset.first; - } - } - - return absolute_table; - } - - size_t CopyTensorAndLod(const platform::Place &place, - const AbsoluteRankTableItem &item, - const phi::DenseTensor &x, - phi::DenseTensor *out, - size_t out_offset) const { - auto &out_lod = *out->mutable_lod(); - auto len = item.length; - auto x_offset = item.offset; - - if (out_lod.empty()) { - for (size_t i = 0; i < item.lod.size(); ++i) { - out_lod.push_back(std::vector({0})); - } - } - - for (size_t i = 0; i < out_lod.size(); ++i) { - auto &out_v = out_lod[i]; - auto &new_lod_v = item.lod[i]; - - for (auto &detail : new_lod_v) { - out_v.push_back(out_v.back() + detail); - } - } - - auto x_sliced = x.Slice(x_offset, x_offset + len); // NOLINT - auto out_sliced = out->Slice(out_offset, out_offset + len); // NOLINT - - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - framework::TensorCopy(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); - out_offset += len; - return out_offset; - } -}; - -class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase { - public: - ReorderLoDTensorByRankTableOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {} - - protected: - void process(const 
platform::Place &place, - const phi::DenseTensor &x, - const framework::LoDRankTable &rank_table, - phi::DenseTensor *out) const override { - auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x); - size_t out_offset = 0; - out->mutable_lod()->clear(); - for (auto &item : rank_table.items()) { - PADDLE_ENFORCE_LT(item.index, - absolute_table.size(), - platform::errors::OutOfRange( - "The value of rank_table is out of range.")); - out_offset = CopyTensorAndLod( - place, absolute_table[item.index], x, out, out_offset); - } - } -}; - -class IdentityInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *context) const override { - context->SetOutputDim("Out", context->GetInputDim("X")); - // X'lod and Out'lod is different on runtime, so there is no need to call - // ShareLoD for runtime. While the setting of Out's lod is done in detail - // kernel implementation. - if (!context->IsRuntime()) { - context->ShareLoD("X", /*->*/ "Out"); - } - } -}; - -template -class ReorderLodTensorByRankGradOpMaker - : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("reorder_lod_tensor_by_rank_grad"); - grad_op->SetInput("X", this->OutputGrad("Out")); - grad_op->SetOutput("Out", this->InputGrad("X")); - grad_op->SetInput("RankTable", this->Input("RankTable")); - } -}; - -class ReorderLoDTensorByRankGradOp : public ReorderLoDTensorByRankTableBase { - public: - ReorderLoDTensorByRankGradOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {} - - protected: - void process(const platform::Place &place, - const phi::DenseTensor &x, - const framework::LoDRankTable &rank_table, - phi::DenseTensor *out) const override { - auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x); - - // offsets = enumerate([item.index for item in rank_table.items()]) - std::vector> offsets; - offsets.reserve(rank_table.items().size()); - for (size_t i = 0; i < rank_table.items().size(); ++i) { - offsets.push_back({i, rank_table.items()[i].index}); - } - - // offsets.sort(key=lambda x: x[1]) - std::sort( - offsets.begin(), - offsets.end(), - [](const std::pair &a, - const std::pair &b) { return a.second < b.second; }); - - // Copy TensorAndLod - size_t out_offset = 0; - for (auto &offset : offsets) { - out_offset = this->CopyTensorAndLod( - place, absolute_table[offset.first], x, out, out_offset); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR( - reorder_lod_tensor_by_rank, - ops::ReorderLoDTensorByRankTableOp, - ops::ReorderLodTensorByRankGradOpMaker, - ops::ReorderLodTensorByRankGradOpMaker, - ops::ReorderLoDTensorByRankTableOpProtoMaker, - ops::IdentityInferShape); -REGISTER_OPERATOR(reorder_lod_tensor_by_rank_grad, - ops::ReorderLoDTensorByRankGradOp, - ops::IdentityInferShape); diff --git a/python/paddle/incubate/layers/__init__.py b/python/paddle/incubate/layers/__init__.py index f25a845d0a4dc6..cadfbb4186bb7b 100644 --- a/python/paddle/incubate/layers/__init__.py +++ b/python/paddle/incubate/layers/__init__.py @@ -28,7 +28,6 @@ partial_sum, pow2_decay_with_linear_warmup, rank_attention, - search_pyramid_hash, shuffle_batch, tdm_child, tdm_sampler, diff --git 
a/python/paddle/incubate/layers/nn.py b/python/paddle/incubate/layers/nn.py index 5b6236567e6491..34ba1ad7188079 100644 --- a/python/paddle/incubate/layers/nn.py +++ b/python/paddle/incubate/layers/nn.py @@ -322,128 +322,6 @@ class number. return output -def search_pyramid_hash( - input, - num_emb, - space_len, - pyramid_layer, - rand_len, - drop_out_percent, - is_training, - use_filter, - white_list_len, - black_list_len, - seed, - lr, - param_attr=None, - param_attr_wl=None, - param_attr_bl=None, - name=None, - distribute_update_vars=None, - dtype='float32', -): - """ - **Pyramid hash embedding** - - Args: - input (Tensor): LoDTensor Tensor contained the IDs' information. - num_emb (int): The embedding size of output. - space_len (int): The length of pyramid hash embedding space. - pyramid_layer (int): The number of pyramid layers. It should be greater than 2. - rand_len (int): The minimum length of pyramid hash cell. - drop_out_percent (float): The probability of dropping out the input token randomly. - It should satisfy: [0., 1.]. - is_training (bool): Whether in training or testing phrase. - use_filter (bool): If set True, the white filter and black filter should be given by - :attr:`param_attr_wl` and :attr:`param_attr_bl` . - white_list_len (int): If set :math:`white_list_len>0` , white filter with shape [white_list_len, 1] - should be provided by param_attr_wl. - black_list_len (int): If set :math:`black_list_len>0` , black filter with shape [black_list_len, 1] - should be provided by param_attr_bl. - seed (int): The number of random seed. - lr (float): The learning rate of weight created by :attr:`param_attr` with shape [space_len+rand_len, 1] - in this layer. - param_attr (ParamAttr, optional): To specify the weight parameter property. Default: None, which means the - default weight parameter property is used. See usage for details in :ref:`api_paddle_ParamAttr` . - param_attr_wl (ParamAttr, optional): Specified parameters of white filter. Default: None. - param_attr_bl (ParamAttr, optional): Specified parameters of black filter. Default: None. - distribute_update_vars(list[ParamAttr.name], optional): Decided which params should be updated in distribute training. - Used in Distribute Transpiler to create a trainer/server program. Default: None. - name (str, optional): The default value is None. Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name` . Default: None. - dtype (str, optional): The data type of output Tensor, float32. Default: float32. - - Returns: - Tensor: LoDTensor of pyramid hash embedding. 
- """ - helper = LayerHelper('search_pyramid_hash', **locals()) - - w_shape = [space_len + rand_len, 1] - w = helper.create_parameter( - attr=param_attr, shape=w_shape, dtype=dtype, is_bias=False - ) - w.stop_gradient = True - - input_vars = {'X': input, 'W': w} - if white_list_len > 0: - wl_shape = [white_list_len, 1] - white_list = helper.create_parameter( - attr=param_attr_wl, shape=wl_shape, dtype=dtype, is_bias=False - ) - white_list.stop_gradient = True - input_vars['WhiteList'] = white_list - - if black_list_len >= 0: - bl_shape = [black_list_len, 1] - black_list = helper.create_parameter( - attr=param_attr_bl, shape=bl_shape, dtype=dtype, is_bias=False - ) - black_list.stop_gradient = True - input_vars['BlackList'] = black_list - - distribute_update_vars_str = "" - if distribute_update_vars: - assert isinstance(distribute_update_vars, list) - special_name_list = [] - if param_attr: - special_name_list.append(param_attr.name) - if param_attr_wl: - special_name_list.append(param_attr_wl.name) - if param_attr_bl: - special_name_list.append(param_attr_bl.name) - for param in distribute_update_vars: - if param not in special_name_list: - raise ValueError( - f"Pyramid Hash layer didn't have parameter {param}" - ) - distribute_update_vars_str = ",".join(distribute_update_vars) - - res = helper.create_variable_for_type_inference(dtype) - drop_pos = helper.create_variable_for_type_inference(dtype) - x_temp_out = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type='pyramid_hash', - inputs=input_vars, - outputs={"Out": res, "X_Temp_Out": x_temp_out, 'DropPos': drop_pos}, - attrs={ - 'num_emb': num_emb, - 'space_len': space_len, - 'pyramid_layer': pyramid_layer, - 'rand_len': rand_len, - 'drop_out_percent': drop_out_percent, - 'is_training': is_training, - 'use_filter': use_filter, - 'white_list_len': white_list_len, - 'black_list_len': black_list_len, - 'seed': seed, - 'lr': lr, - 'distribute_update_vars': distribute_update_vars_str, - }, - ) - - return res - - def shuffle_batch(x, seed=None): """ This layer shuffle input tensor :attr:`x` . Normally, :attr:`x` is 2-D LoDTensor. diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 63d84ece4aa988..3ba9a2f1a492e8 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -267,10 +267,8 @@ endif() if(WITH_COVERAGE OR WIN32 OR WITH_NV_JETSON) - list(REMOVE_ITEM TEST_OPS test_pyramid_hash_op) -endif() -list(REMOVE_ITEM TEST_OPS test_fleet_pyramid_hash) +endif() if((WITH_ROCM OR WITH_GPU) OR NOT WITH_MKLML) # matmul with multiple heads need MKL support diff --git a/test/legacy_test/test_fleet_pyramid_hash.py b/test/legacy_test/test_fleet_pyramid_hash.py deleted file mode 100644 index cfd02ee72ced9b..00000000000000 --- a/test/legacy_test/test_fleet_pyramid_hash.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-import unittest
-
-import paddle
-from paddle import base
-from paddle.incubate.distributed.fleet import role_maker
-from paddle.incubate.distributed.fleet.parameter_server.distribute_transpiler import (
-    fleet,
-)
-from paddle.incubate.distributed.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
-    StrategyFactory,
-)
-from paddle.incubate.layers.nn import search_pyramid_hash
-
-
-class TestPyramidHashOpApi(unittest.TestCase):
-    def test_dist_geo_server_transpiler(self):
-        num_voc = 128
-        embed_dim = 64
-        x_shape, x_lod = [16, 10], [[3, 5, 2, 6]]
-        x = paddle.static.data(
-            name='x', shape=x_shape, dtype='int32', lod_level=1
-        )
-        hash_embd = search_pyramid_hash(
-            input=x,
-            num_emb=embed_dim,
-            space_len=num_voc * embed_dim,
-            pyramid_layer=4,
-            rand_len=16,
-            drop_out_percent=0.5,
-            is_training=True,
-            use_filter=False,
-            white_list_len=6400,
-            black_list_len=2800,
-            seed=3,
-            lr=0.002,
-            param_attr=base.ParamAttr(
-                name="PyramidHash_emb_0",
-                learning_rate=0,
-            ),
-            param_attr_wl=base.ParamAttr(
-                name="Filter",
-                learning_rate=0,
-            ),
-            param_attr_bl=None,
-            distribute_update_vars=["PyramidHash_emb_0"],
-            name=None,
-        )
-
-        cost = paddle.sum(hash_embd)
-
-        role = role_maker.UserDefinedRoleMaker(
-            current_id=0,
-            role=role_maker.Role.SERVER,
-            worker_num=2,
-            server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"],
-        )
-
-        fleet.init(role)
-
-        strategy = StrategyFactory.create_geo_strategy(5)
-        optimizer = paddle.optimizer.SGD(0.1)
-        optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        optimizer.minimize(cost)
-
-        pserver_startup_program = fleet.startup_program
-        pserver_mian_program = fleet.main_program
-
-
-if __name__ == "__main__":
-    paddle.enable_static()
-    unittest.main()
diff --git a/test/legacy_test/test_polygon_box_transform.py b/test/legacy_test/test_polygon_box_transform.py
deleted file mode 100644
index 6e3f19927d5cc2..00000000000000
--- a/test/legacy_test/test_polygon_box_transform.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-from op_test import OpTest
-
-
-def PolygonBoxRestore(input):
-    shape = input.shape
-    batch_size = shape[0]
-    geo_channels = shape[1]
-    h = shape[2]
-    w = shape[3]
-    h_indexes = (
-        np.array(list(range(h)) * w).reshape([w, h]).transpose()[np.newaxis, :]
-    )  # [1, h, w]
-    w_indexes = np.array(list(range(w)) * h).reshape([h, w])[
-        np.newaxis, :
-    ]  # [1, h, w]
-    indexes = np.concatenate((w_indexes, h_indexes))[
-        np.newaxis, :
-    ]  # [1, 2, h, w]
-    indexes = indexes.repeat([geo_channels / 2], axis=0)[
-        np.newaxis, :
-    ]  # [1, geo_channels/2, 2, h, w]
-    indexes = indexes.repeat(
-        [batch_size], axis=0
-    )  # [batch_size, geo_channels/2, 2, h, w]
-    return (
-        indexes.reshape(input.shape) * 4 - input
-    )  # [batch_size, geo_channels, h, w]
-
-
-class TestPolygonBoxRestoreOp(OpTest):
-    def config(self):
-        self.input_shape = (1, 8, 2, 2)
-
-    def setUp(self):
-        self.config()
-        self.op_type = "polygon_box_transform"
-        input = np.random.random(self.input_shape).astype("float32")
-        self.inputs = {'Input': input}
-        output = PolygonBoxRestore(input)
-        self.outputs = {'Output': output}
-
-    def test_check_output(self):
-        self.check_output()
-
-
-class TestCase1(TestPolygonBoxRestoreOp):
-    def config(self):
-        self.input_shape = (2, 10, 3, 2)
-
-
-class TestCase2(TestPolygonBoxRestoreOp):
-    def config(self):
-        self.input_shape = (3, 12, 4, 5)
-
-
-if __name__ == '__main__':
-    unittest.main()
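The PolygonBoxRestore reference deleted above encodes the operator's contract: even geometry channels hold x-offsets and odd channels hold y-offsets, and each output element is four times the pixel's grid coordinate minus the predicted offset. A small worked sketch of that formula, assuming a zero input so the coordinates are visible directly (illustrative only, not part of the patch):

    import numpy as np

    # output[n, c, h, w] = 4 * (w if c is even else h) - input[n, c, h, w]
    inp = np.zeros((1, 2, 2, 2), dtype=np.float32)
    out = np.empty_like(inp)
    for c in range(2):
        for h in range(2):
            for w in range(2):
                coord = w if c % 2 == 0 else h
                out[0, c, h, w] = 4.0 * coord - inp[0, c, h, w]
    # Channel 0 recovers 4*w per column, channel 1 recovers 4*h per row:
    # out[0, 0] == [[0., 4.], [0., 4.]] and out[0, 1] == [[0., 0.], [4., 4.]]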
diff --git a/test/legacy_test/test_positive_negative_pair_op.py b/test/legacy_test/test_positive_negative_pair_op.py
deleted file mode 100644
index cf3440f365cd7c..00000000000000
--- a/test/legacy_test/test_positive_negative_pair_op.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import itertools
-import unittest
-
-import numpy as np
-from op_test import OpTest
-
-
-def py_pnpair_op(score, label, query, column=-1, weight=None):
-    # group by query id
-    predictions = {}
-    batch_size = label.shape[0]
-    if weight is None:
-        weight = np.ones(shape=(batch_size, 1)).astype('float32')
-    for s, l, q, w in zip(score, label, query, weight):
-        s, l, q, w = s[column], l[0], q[0], w[0]
-        if q not in predictions:
-            predictions[q] = []
-        predictions[q].append((s, l, w))
-
-    # accumulate statistics
-    pos, neg, neu = 0, 0, 0
-    for _, ranks in predictions.items():
-        for e1, e2 in itertools.combinations(ranks, 2):
-            s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
-            w = (w1 + w2) * 0.5
-            if l1 == l2:
-                continue
-            if s1 == s2:
-                neu += w
-            elif (s1 - s2) * (l1 - l2) > 0:
-                pos += w
-            else:
-                neg += w
-
-    return (
-        np.array([pos]).astype('float32'),
-        np.array([neg]).astype('float32'),
-        np.array([neu]).astype('float32'),
-    )
-
-
-class TestPositiveNegativePairOp(OpTest):
-    def setUp(self):
-        self.op_type = 'positive_negative_pair'
-        batch_size = 20
-        max_query_id = 5
-        score = np.random.normal(size=(batch_size, 1)).astype('float32')
-        label = np.random.normal(size=(batch_size, 1)).astype('float32')
-        query = np.array(
-            [np.random.randint(max_query_id) for i in range(batch_size)]
-        )
-        query = np.reshape(query, newshape=(batch_size, 1)).astype('int64')
-
-        pos, neg, neu = py_pnpair_op(score, label, query)
-        self.inputs = {'Score': score, 'Label': label, 'QueryID': query}
-        self.attrs = {'column': -1}
-        self.outputs = {
-            'PositivePair': pos,
-            'NegativePair': neg,
-            'NeutralPair': neu,
-        }
-
-    def test_check_output(self):
-        # NODE(yjjiang11): This op will be deprecated.
-        self.check_output(check_dygraph=False)
-
-
-class TestPositiveNegativePairOpAccumulateWeight(OpTest):
-    def setUp(self):
-        self.op_type = 'positive_negative_pair'
-        batch_size = 20
-        max_query_id = 5
-        max_random_num = 2 << 15
-        score_dim = 2
-        score = np.random.normal(size=(batch_size, 2)).astype('float32')
-        label = np.random.normal(size=(batch_size, 1)).astype('float32')
-        weight = np.random.normal(size=(batch_size, 1)).astype('float32')
-        query = np.array(
-            [np.random.randint(max_query_id) for i in range(batch_size)]
-        )
-        query = np.reshape(query, newshape=(batch_size, 1)).astype('int64')
-        acc_pos = np.reshape(
-            np.random.randint(max_random_num), newshape=(1)
-        ).astype('float32')
-        acc_neg = np.reshape(
-            np.random.randint(max_random_num), newshape=(1)
-        ).astype('float32')
-        acc_neu = np.reshape(
-            np.random.randint(max_random_num), newshape=(1)
-        ).astype('float32')
-        column = np.random.randint(score_dim)
-
-        pos, neg, neu = py_pnpair_op(
-            score, label, query, column=column, weight=weight
-        )
-        self.inputs = {
-            'Score': score,
-            'Label': label,
-            'QueryID': query,
-            'AccumulatePositivePair': acc_pos,
-            'AccumulateNegativePair': acc_neg,
-            'AccumulateNeutralPair': acc_neu,
-            'Weight': weight,
-        }
-        self.attrs = {'column': column}
-        self.outputs = {
-            'PositivePair': pos + acc_pos,
-            'NegativePair': neg + acc_neg,
-            'NeutralPair': neu + acc_neu,
-        }
-
-    def test_check_output(self):
-        self.check_output(check_dygraph=False)
-
-
-if __name__ == '__main__':
-    unittest.main()
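To make the pair statistic in the deleted py_pnpair_op concrete: within each query, every pair of items with distinct labels counts as positive when score order agrees with label order, negative when they disagree, and neutral on a score tie, each pair weighted by the mean of the two item weights. A hand-checkable example using the same rules (standalone sketch, not part of the patch):

    import itertools

    ranks = [(0.9, 1, 1.0), (0.3, 0, 1.0), (0.3, 1, 1.0)]  # (score, label, weight)
    pos = neg = neu = 0.0
    for (s1, l1, w1), (s2, l2, w2) in itertools.combinations(ranks, 2):
        if l1 == l2:
            continue  # pairs with equal labels carry no ranking signal
        w = 0.5 * (w1 + w2)
        if s1 == s2:
            neu += w  # tied scores, different labels
        elif (s1 - s2) * (l1 - l2) > 0:
            pos += w  # score order matches label order
        else:
            neg += w  # score order contradicts label order
    assert (pos, neg, neu) == (1.0, 0.0, 1.0)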
diff --git a/test/legacy_test/test_proximal_gd_op.py b/test/legacy_test/test_proximal_gd_op.py
deleted file mode 100644
index d55c1ffcc2d8d7..00000000000000
--- a/test/legacy_test/test_proximal_gd_op.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-from op_test import OpTest
-
-
-class TestProximalGDOp(OpTest):
-    def setUp(self):
-        self.op_type = "proximal_gd"
-        w = np.random.random((102, 105)).astype("float32")
-        g = np.random.random((102, 105)).astype("float32")
-        lr = np.array([0.1]).astype("float32")
-        l1 = 0.1
-        l2 = 0.2
-
-        self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
-        self.attrs = {'l1': l1, 'l2': l2}
-        prox_param = w - lr * g
-        param_out = 0.0
-        if l1 > 0.0:
-            x = np.abs(prox_param) - lr * l1
-            x[x < 0] = 0
-            param_out = np.sign(prox_param) * (x / (1.0 + lr * l2))
-        else:
-            param_out = prox_param / (1.0 + lr * l2)
-
-        self.outputs = {'ParamOut': param_out}
-
-    def test_check_output(self):
-        self.check_output()
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/test/legacy_test/test_pyramid_hash_op.py b/test/legacy_test/test_pyramid_hash_op.py
deleted file mode 100644
index b7173774921644..00000000000000
--- a/test/legacy_test/test_pyramid_hash_op.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-
-import paddle
-from paddle import base
-from paddle.incubate.layers.nn import search_pyramid_hash
-
-
-class TestPyramidHashOpApi(unittest.TestCase):
-    def test_api(self):
-        num_voc = 128
-        embed_dim = 64
-        x_shape, x_lod = [16, 10], [[3, 5, 2, 6]]
-        x = paddle.static.data(
-            name='x', shape=x_shape, dtype='int32', lod_level=1
-        )
-        hash_embed = search_pyramid_hash(
-            input=x,
-            num_emb=embed_dim,
-            space_len=num_voc * embed_dim,
-            pyramid_layer=4,
-            rand_len=16,
-            drop_out_percent=0.5,
-            is_training=True,
-            use_filter=False,
-            white_list_len=6400,
-            black_list_len=2800,
-            seed=3,
-            lr=0.002,
-            param_attr=base.ParamAttr(
-                name="PyramidHash_emb_0",
-                learning_rate=0,
-            ),
-            param_attr_wl=base.ParamAttr(
-                name="Filter",
-                learning_rate=0,
-            ),
-            param_attr_bl=None,
-            distribute_update_vars=["PyramidHash_emb_0"],
-            name=None,
-        )
-
-        place = base.CPUPlace()
-        x_tensor = base.create_lod_tensor(
-            np.random.randint(0, num_voc, x_shape).astype('int32'), x_lod, place
-        )
-
-        exe = base.Executor(place)
-        exe.run(base.default_startup_program())
-        ret = exe.run(
-            feed={'x': x_tensor}, fetch_list=[hash_embed], return_numpy=False
-        )
-
-
-if __name__ == "__main__":
-    unittest.main()
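One closing note on test_proximal_gd_op.py, removed above: its NumPy reference is the closed-form proximal step for combined L1/L2 regularization. Writing z = w - lr * g, the update is ParamOut = sign(z) * max(|z| - lr*l1, 0) / (1 + lr*l2) when l1 > 0, and z / (1 + lr*l2) otherwise. A standalone restatement of that reference, a sketch rather than the op's implementation (not part of the patch):

    import numpy as np

    def proximal_gd_step(w, g, lr, l1, l2):
        # Plain gradient step, then the prox of lr * (l1*|.|_1 + 0.5*l2*|.|_2^2):
        z = w - lr * g
        if l1 > 0.0:
            # Soft-thresholding (L1) followed by L2 shrinkage.
            return np.sign(z) * np.maximum(np.abs(z) - lr * l1, 0.0) / (1.0 + lr * l2)
        return z / (1.0 + lr * l2)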