From b4dc977c8346df27026c84a577bf4782d4efe4a1 Mon Sep 17 00:00:00 2001 From: co63oc Date: Tue, 16 Apr 2024 19:23:47 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=206th=20Fundable=20Projects?= =?UTF-8?q?=203=20No.273=E3=80=91Remove=20fluid=20operator=20precision=5Fr?= =?UTF-8?q?ecall=20(#63483)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/CMakeLists.txt | 1 - paddle/fluid/operators/metrics/CMakeLists.txt | 6 - .../operators/metrics/precision_recall_op.cc | 250 ------------------ .../operators/metrics/precision_recall_op.h | 186 ------------- .../operators/metrics/unity_build_rule.cmake | 7 - test/legacy_test/test_precision_recall_op.py | 206 --------------- 6 files changed, 656 deletions(-) delete mode 100644 paddle/fluid/operators/metrics/CMakeLists.txt delete mode 100644 paddle/fluid/operators/metrics/precision_recall_op.cc delete mode 100644 paddle/fluid/operators/metrics/precision_recall_op.h delete mode 100644 paddle/fluid/operators/metrics/unity_build_rule.cmake delete mode 100644 test/legacy_test/test_precision_recall_op.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 5ac6368c91e4b7..9126023d389bed 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -19,7 +19,6 @@ add_subdirectory(controlflow) add_subdirectory(detection) add_subdirectory(elementwise) add_subdirectory(fused) -add_subdirectory(metrics) add_subdirectory(optimizers) add_subdirectory(reduce_ops) add_subdirectory(sequence_ops) diff --git a/paddle/fluid/operators/metrics/CMakeLists.txt b/paddle/fluid/operators/metrics/CMakeLists.txt deleted file mode 100644 index b968dbf288ee22..00000000000000 --- a/paddle/fluid/operators/metrics/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -include(operators) -if(WITH_UNITY_BUILD) - # Load Unity Build rules for operators in paddle/fluid/operators/metrics. - include(unity_build_rule.cmake) -endif() -register_operators() diff --git a/paddle/fluid/operators/metrics/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc deleted file mode 100644 index 95a66cb2edd1dd..00000000000000 --- a/paddle/fluid/operators/metrics/precision_recall_op.cc +++ /dev/null @@ -1,250 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/metrics/precision_recall_op.h" - -namespace paddle { -namespace operators { - -class PrecisionRecallOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("MaxProbs"), - true, - phi::errors::NotFound( - "PrecisionRecallOp Input(MaxProbs) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Indices"), - true, - phi::errors::NotFound( - "PrecisionRecallOp Input(Indices) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Labels"), - true, - phi::errors::NotFound( - "PrecisionRecallOp Input(Labels) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("BatchMetrics"), - true, - phi::errors::NotFound( - "PrecisionRecallOp Output(BatchMetrics) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("AccumMetrics"), - true, - phi::errors::NotFound( - "PrecisionRecallOp Output(AccumMetrics) should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("AccumStatesInfo"), - true, - phi::errors::NotFound( - "PrecisionRecallOp Output(AccumStatesInfo) should not be null.")); - - int64_t cls_num = - static_cast(ctx->Attrs().Get("class_number")); - auto max_probs_dims = ctx->GetInputDim("MaxProbs"); - auto labels_dims = ctx->GetInputDim("Labels"); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(max_probs_dims[1], - 1, - phi::errors::InvalidArgument( - "Each instance of PrecisionRecallOp " - "Input(MaxProbs) contains one max probability, " - "the shape of Input(MaxProbs) should be " - "[batch_size, 1], the 2nd dimension of " - "Input(MaxProbs) should be 1. But the 2nd " - "dimension we received is %d", - max_probs_dims[1])); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Indices"), - max_probs_dims, - phi::errors::InvalidArgument( - "The shape of PrecisionRecallOp Input(Indices) should be same " - "with " - "max_probs_dims. But received the shape of Input(Indices) is " - "[%d, %d], max_probs_dims is [%d, %d]", - ctx->GetInputDim("Indices")[0], - ctx->GetInputDim("Indices")[1], - max_probs_dims[0], - max_probs_dims[1])); - PADDLE_ENFORCE_EQ( - max_probs_dims[0], - labels_dims[0], - phi::errors::InvalidArgument( - "The 1st dimension of PrecisionRecallOp Input(MaxProbs) and " - "Input(Labels) both should be batch_size" - "But the 1st dimension we received max_probs_dims[0] = %d, " - "labels_dims[0] = %d", - max_probs_dims[0], - labels_dims[0])); - PADDLE_ENFORCE_EQ(labels_dims[1], - 1, - phi::errors::InvalidArgument( - "The 2nd dimension of PrecisionRecallOp " - "Input(Labels) contains instance label and " - "the shape should be equal to 1. But the 2nd " - "dimension we received is %d", - labels_dims[1])); - } - if (ctx->HasInput("Weights")) { - auto weights_dims = ctx->GetInputDim("Weights"); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - weights_dims, - common::make_ddim({max_probs_dims[0], 1}), - phi::errors::InvalidArgument( - "The shape of PrecisionRecallOp Input(Weights) should be " - "[batch_size, 1]. But the shape we received is [%d, %d]", - weights_dims[0], - weights_dims[1])); - } - } - if (ctx->HasInput("StatesInfo")) { - auto states_dims = ctx->GetInputDim("StatesInfo"); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - states_dims, - common::make_ddim({cls_num, 4}), - phi::errors::InvalidArgument( - "The shape of PrecisionRecallOp Input(StatesInfo) should be " - "[class_number, 4]. But the shape we received is [%d, %d]", - states_dims[0], - states_dims[1])); - } - } - - // Layouts of BatchMetrics and AccumMetrics both are: - // [ - // macro average precision, macro average recall, macro average F1 score, - // micro average precision, micro average recall, micro average F1 score - // ] - ctx->SetOutputDim("BatchMetrics", {6}); - ctx->SetOutputDim("AccumMetrics", {6}); - // Shape of AccumStatesInfo is [class_number, 4] - // The layout of each row is: - // [ TP, FP, TN, FN ] - ctx->SetOutputDim("AccumStatesInfo", {cls_num, 4}); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return phi::KernelKey( - OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"), - ctx.GetPlace()); - } -}; - -class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("MaxProbs", - "(Tensor, default Tensor) A 2-D tensor with shape N x 1, " - "where N is the batch size. Each row contains the max probability " - "of an instance which computed by the previous top_k (k=1) " - "operator."); - AddInput("Indices", - "(Tensor, default Tensor) A 2-D tensor with shape N x 1, " - "where N is the batch size. Each row contains the corresponding " - "index which computed by the previous top_k (k=1) operator."); - AddInput("Labels", - "(Tensor, default Tensor) A 2-D tensor with shape N x 1, " - "where N is the batch size. Each element is a label and the " - "value should be in [0, class_number - 1]."); - AddInput("Weights", - "(Tensor, default Tensor) A 2-D tensor with shape N x 1, " - "where N is the batch size. This input is optional. If provided, " - "weight of instance would be considered when computing metrics.") - .AsDispensable(); - AddInput("StatesInfo", - "(Tensor, default Tensor) A 2-D tensor with shape D x 4, " - "where D is the number of classes. This input is optional. If " - "provided, current state will be accumulated to this state and " - "the accumulation state will be the output state.") - .AsDispensable(); - AddOutput("BatchMetrics", - "(Tensor, default Tensor) A 1-D tensor with shape {6}. " - "This output tensor contains metrics for current batch data. " - "The layout is [macro average precision, macro average recall, " - "macro f1 score, micro average precision, micro average recall, " - "micro f1 score]."); - AddOutput("AccumMetrics", - "(Tensor, default Tensor) A 1-D tensor with shape {6}. " - "This output tensor contains metrics for accumulated data. " - "The layout is [macro average precision, macro average recall, " - "macro f1 score, micro average precision, micro average recall, " - "micro f1 score]."); - AddOutput("AccumStatesInfo", - "(Tensor, default Tensor) A 2-D tensor with shape D x 4, " - "where D is equal to class number. This output tensor contains " - "accumulated state variables used to compute metrics. The layout " - "for each class is [true positives, false positives, " - "true negatives, false negatives]."); - AddAttr("class_number", "(int) Number of classes to be evaluated."); - AddComment(R"DOC( -Precision Recall Operator. - -When given Input(Indices) and Input(Labels), this operator can be used -to compute various metrics including: -1. macro average precision -2. macro average recall -3. macro f1 score -4. micro average precision -5. micro average recall -6. micro f1 score - -To compute the above metrics, we need to do statistics for true positives, -false positives and false negatives. Here the count of true negatives is not -necessary, but counting it may provide potential usage and the cost is -trivial, so the operator also provides the count of true negatives. - -We define state as a 2-D tensor with shape [class_number, 4]. Each row of a -state contains statistic variables for corresponding class. Layout of each row -is: TP(true positives), FP(false positives), TN(true negatives), -FN(false negatives). If Input(Weights) is provided, TP, FP, TN, FN will be -calculated by given weight instead of the instance count. - -This operator also supports metrics computing for cross-batch situation. To -achieve this, Input(StatesInfo) should be provided. State of current batch -data will be accumulated to Input(StatesInfo) and Output(AccumStatesInfo) -is the accumulation state. - -Output(BatchMetrics) is metrics of current batch data while -Output(AccumStatesInfo) is metrics of accumulation data. - -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR( - precision_recall, - ops::PrecisionRecallOp, - ops::PrecisionRecallOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -PD_REGISTER_STRUCT_KERNEL(precision_recall, - CPU, - ALL_LAYOUT, - ops::PrecisionRecallKernel, - float, - double) {} diff --git a/paddle/fluid/operators/metrics/precision_recall_op.h b/paddle/fluid/operators/metrics/precision_recall_op.h deleted file mode 100644 index 8a276d2fa5a32f..00000000000000 --- a/paddle/fluid/operators/metrics/precision_recall_op.h +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -using EigenMatrix = framework::EigenMatrix; - -enum StateVariable { TP = 0, FP, TN, FN }; - -template -class PrecisionRecallKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in0 = ctx.Input("Indices"); - auto* in1 = ctx.Input("Labels"); - auto* in2 = ctx.Input("Weights"); - auto* in3 = ctx.Input("StatesInfo"); - auto* out0 = ctx.Output("BatchMetrics"); - auto* out1 = ctx.Output("AccumMetrics"); - auto* out2 = ctx.Output("AccumStatesInfo"); - - const int* ids_data = in0->data(); - const int* labels_data = in1->data(); - size_t cls_num = static_cast(ctx.Attr("class_number")); - const T* weights_data = in2 ? in2->data() : nullptr; - const T* states_data = in3 ? in3->data() : nullptr; - double* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); - double* accum_metrics_data = out1->mutable_data(ctx.GetPlace()); - out2->mutable_data(ctx.GetPlace()); - auto accum_states = EigenMatrix::From(*out2); - accum_states.setZero(); - T* accum_states_data = out2->data(); - - size_t sample_num = in0->dims()[0]; - size_t state_var_num = 4; // TP FP TN FN - - // get states info for current batch - for (size_t i = 0; i < sample_num; ++i) { - size_t idx = ids_data[i]; - size_t label = labels_data[i]; - - PADDLE_ENFORCE_GE( - idx, - 0, - phi::errors::InvalidArgument( - "Class index of each instance should be " - "greater than or equal to 0, But the index we received is %d", - idx)); - PADDLE_ENFORCE_LT(idx, - cls_num, - phi::errors::InvalidArgument( - "Class index of each instance should be less than " - "cls_num = %d, But the index we received is %d", - cls_num, - idx)); - - PADDLE_ENFORCE_GE(label, - 0, - phi::errors::InvalidArgument( - "Label of each instance should be greater than or " - "equal to 0, But the label we received is %d", - label)); - PADDLE_ENFORCE_LT(label, - cls_num, - phi::errors::InvalidArgument( - "Label of each instance should be less than " - "cls_num = %d, But the label we received is %d", - cls_num, - label)); - - T w = weights_data ? weights_data[i] : 1.0; - if (idx == label) { - accum_states_data[idx * state_var_num + TP] += w; - for (size_t j = 0; j < cls_num; ++j) { - accum_states_data[j * state_var_num + TN] += w; - } - accum_states_data[idx * state_var_num + TN] -= w; - } else { - accum_states_data[label * state_var_num + FN] += w; - accum_states_data[idx * state_var_num + FP] += w; - for (size_t j = 0; j < cls_num; ++j) { - accum_states_data[j * state_var_num + TN] += w; - } - accum_states_data[idx * state_var_num + TN] -= w; - accum_states_data[label * state_var_num + TN] -= w; - } - } - - ComputeMetrics( - accum_states_data, batch_metrics_data, state_var_num, cls_num); - - if (states_data) { - for (size_t i = 0; i < cls_num; ++i) { - for (size_t j = 0; j < state_var_num; ++j) { - size_t idx = i * state_var_num + j; - accum_states_data[idx] += states_data[idx]; - } - } - } - - ComputeMetrics( - accum_states_data, accum_metrics_data, state_var_num, cls_num); - } - - // expose to be reused - static inline T CalcPrecision(T tp_count, T fp_count) { - if (tp_count > 0.0 || fp_count > 0.0) { - return tp_count / (tp_count + fp_count); - } - return 1.0; - } - - static inline T CalcRecall(T tp_count, T fn_count) { - if (tp_count > 0.0 || fn_count > 0.0) { - return tp_count / (tp_count + fn_count); - } - return 1.0; - } - - static inline T CalcF1Score(T precision, T recall) { - if (precision > 0.0 || recall > 0.0) { - return 2 * precision * recall / (precision + recall); - } - return 0.0; - } - - protected: - void ComputeMetrics(const T* states_data, - double* metrics_data, - size_t state_var_num, - size_t cls_num) const { - T total_tp_count = 0; - T total_fp_count = 0; - T total_fn_count = 0; - T macro_avg_precision = 0.0; - T macro_avg_recall = 0.0; - - for (size_t i = 0; i < cls_num; ++i) { - T tp_count = states_data[i * state_var_num + TP]; - T fp_count = states_data[i * state_var_num + FP]; - T fn_count = states_data[i * state_var_num + FN]; - total_tp_count += tp_count; - total_fp_count += fp_count; - total_fn_count += fn_count; - macro_avg_precision += CalcPrecision(tp_count, fp_count); - macro_avg_recall += CalcRecall(tp_count, fn_count); - } - macro_avg_precision /= cls_num; - macro_avg_recall /= cls_num; - T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall); - - T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count); - T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count); - T micro_f1_score = CalcF1Score(micro_avg_precision, micro_avg_recall); - - // fill metrics data - metrics_data[0] = macro_avg_precision; - metrics_data[1] = macro_avg_recall; - metrics_data[2] = macro_f1_score; - metrics_data[3] = micro_avg_precision; - metrics_data[4] = micro_avg_recall; - metrics_data[5] = micro_f1_score; - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/metrics/unity_build_rule.cmake b/paddle/fluid/operators/metrics/unity_build_rule.cmake deleted file mode 100644 index dee8680cc93d33..00000000000000 --- a/paddle/fluid/operators/metrics/unity_build_rule.cmake +++ /dev/null @@ -1,7 +0,0 @@ -# This file records the Unity Build compilation rules. -# The source files in a `register_unity_group` called are compiled in a unity -# file. -# Generally, the combination rules in this file do not need to be modified. -# If there are some redefined error in compiling with the source file which -# in combination rule, you can remove the source file from the following rules. -register_unity_group(cc precision_recall_op.cc) diff --git a/test/legacy_test/test_precision_recall_op.py b/test/legacy_test/test_precision_recall_op.py deleted file mode 100644 index 97f3d7e7724a47..00000000000000 --- a/test/legacy_test/test_precision_recall_op.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from op_test import OpTest - - -def calc_precision(tp_count, fp_count): - if tp_count > 0.0 or fp_count > 0.0: - return tp_count / (tp_count + fp_count) - return 1.0 - - -def calc_recall(tp_count, fn_count): - if tp_count > 0.0 or fn_count > 0.0: - return tp_count / (tp_count + fn_count) - return 1.0 - - -def calc_f1_score(precision, recall): - if precision > 0.0 or recall > 0.0: - return 2 * precision * recall / (precision + recall) - return 0.0 - - -def get_states(idxs, labels, cls_num, weights=None): - ins_num = idxs.shape[0] - # TP FP TN FN - states = np.zeros((cls_num, 4)).astype('float32') - for i in range(ins_num): - w = weights[i] if weights is not None else 1.0 - idx = idxs[i][0] - label = labels[i][0] - if idx == label: - states[idx][0] += w - for j in range(cls_num): - states[j][2] += w - states[idx][2] -= w - else: - states[label][3] += w - states[idx][1] += w - for j in range(cls_num): - states[j][2] += w - states[label][2] -= w - states[idx][2] -= w - return states - - -def compute_metrics(states, cls_num): - total_tp_count = 0.0 - total_fp_count = 0.0 - total_fn_count = 0.0 - macro_avg_precision = 0.0 - macro_avg_recall = 0.0 - for i in range(cls_num): - total_tp_count += states[i][0] - total_fp_count += states[i][1] - total_fn_count += states[i][3] - macro_avg_precision += calc_precision(states[i][0], states[i][1]) - macro_avg_recall += calc_recall(states[i][0], states[i][3]) - metrics = [] - macro_avg_precision /= cls_num - macro_avg_recall /= cls_num - metrics.append(macro_avg_precision) - metrics.append(macro_avg_recall) - metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall)) - micro_avg_precision = calc_precision(total_tp_count, total_fp_count) - metrics.append(micro_avg_precision) - micro_avg_recall = calc_recall(total_tp_count, total_fn_count) - metrics.append(micro_avg_recall) - metrics.append(calc_f1_score(micro_avg_precision, micro_avg_recall)) - return np.array(metrics).astype('float32') - - -class TestPrecisionRecallOp_0(OpTest): - def setUp(self): - self.op_type = "precision_recall" - ins_num = 64 - cls_num = 10 - max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - idxs = ( - np.random.choice(range(cls_num), ins_num) - .reshape((ins_num, 1)) - .astype('int32') - ) - labels = ( - np.random.choice(range(cls_num), ins_num) - .reshape((ins_num, 1)) - .astype('int32') - ) - states = get_states(idxs, labels, cls_num) - metrics = compute_metrics(states, cls_num) - - self.attrs = {'class_number': cls_num} - - self.inputs = {'MaxProbs': max_probs, 'Indices': idxs, 'Labels': labels} - - self.outputs = { - 'BatchMetrics': metrics, - 'AccumMetrics': metrics, - 'AccumStatesInfo': states, - } - - def test_check_output(self): - self.check_output() - - -class TestPrecisionRecallOp_1(OpTest): - def setUp(self): - self.op_type = "precision_recall" - ins_num = 64 - cls_num = 10 - max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - idxs = ( - np.random.choice(range(cls_num), ins_num) - .reshape((ins_num, 1)) - .astype('int32') - ) - weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - labels = ( - np.random.choice(range(cls_num), ins_num) - .reshape((ins_num, 1)) - .astype('int32') - ) - - states = get_states(idxs, labels, cls_num, weights) - metrics = compute_metrics(states, cls_num) - - self.attrs = {'class_number': cls_num} - - self.inputs = { - 'MaxProbs': max_probs, - 'Indices': idxs, - 'Labels': labels, - 'Weights': weights, - } - - self.outputs = { - 'BatchMetrics': metrics, - 'AccumMetrics': metrics, - 'AccumStatesInfo': states, - } - - def test_check_output(self): - self.check_output() - - -class TestPrecisionRecallOp_2(OpTest): - def setUp(self): - self.op_type = "precision_recall" - ins_num = 64 - cls_num = 10 - max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - idxs = ( - np.random.choice(range(cls_num), ins_num) - .reshape((ins_num, 1)) - .astype('int32') - ) - weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - labels = ( - np.random.choice(range(cls_num), ins_num) - .reshape((ins_num, 1)) - .astype('int32') - ) - states = np.random.randint(0, 30, (cls_num, 4)).astype('float32') - - accum_states = get_states(idxs, labels, cls_num, weights) - batch_metrics = compute_metrics(accum_states, cls_num) - accum_states += states - accum_metrics = compute_metrics(accum_states, cls_num) - - self.attrs = {'class_number': cls_num} - - self.inputs = { - 'MaxProbs': max_probs, - 'Indices': idxs, - 'Labels': labels, - 'Weights': weights, - 'StatesInfo': states, - } - - self.outputs = { - 'BatchMetrics': batch_metrics, - 'AccumMetrics': accum_metrics, - 'AccumStatesInfo': accum_states, - } - - def test_check_output(self): - self.check_output() - - -if __name__ == '__main__': - unittest.main()