From 02ea09271a56d49fe42ea61f689b00061b54c153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 28 Nov 2020 02:18:57 +0100 Subject: [PATCH 1/8] improves processing time by 10 --- .../cpu/ml/tree_ensemble_classifier.cc | 2 +- .../providers/cpu/ml/tree_ensemble_common.h | 192 ++++++++++++------ .../core/providers/cpu/ml/treeregressor.cc | 2 +- .../providers/cpu/ml/treeregressor_test.cc | 25 ++- 4 files changed, 151 insertions(+), 70 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc index f4512e5df05bc..216bacd607cd0 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc @@ -139,7 +139,7 @@ template TreeEnsembleClassifier::TreeEnsembleClassifier(const OpKernelInfo& info) : OpKernel(info), tree_ensemble_( - 100, + 80, 50, info.GetAttrOrDefault("aggregate_function", "SUM"), info.GetAttrsOrDefault("base_values"), diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 0dbdf544bc8d3..fc8aba378dfe5 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -7,6 +7,9 @@ #include "core/platform/ort_mutex.h" #include "core/platform/threadpool.h" +#define TREEENSEMBLE_BATCHSIZE 256 +#define TREEENSEMBLE_MAXSIZE 8589934592 + namespace onnxruntime { namespace ml { namespace detail { @@ -266,11 +269,11 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, if (n_targets_or_classes_ == 1) { if (N == 1) { ScoreValue score = {0, 0}; - if (n_trees_ <= parallel_tree_) { + if (n_trees_ <= parallel_tree_) { /* section A */ for (int64_t j = 0; j < n_trees_; ++j) { agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data)); } - } else { + } else { /* section B */ std::vector> 
scores_t(n_trees_, {0, 0}); concurrency::ThreadPool::TryBatchParallelFor( ttp, @@ -284,46 +287,105 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, agg.MergePrediction1(score, *it); } } - agg.FinalizeScores1(z_data, score, label_data); - } else { - if (N <= parallel_N_) { - ScoreValue score; - size_t j; - - for (int64_t i = 0; i < N; ++i) { - score = {0, 0}; - for (j = 0; j < static_cast(n_trees_); ++j) { - agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); - } - - agg.FinalizeScores1(z_data + i * n_targets_or_classes_, score, - label_data == nullptr ? nullptr : (label_data + i)); + } else if (N <= parallel_N_) { /* section C */ + ScoreValue score; + size_t j; + + for (int64_t i = 0; i < N; ++i) { + score = {0, 0}; + for (j = 0; j < static_cast(n_trees_); ++j) { + agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); } - } else { - concurrency::ThreadPool::TryBatchParallelFor( - ttp, - SafeInt(N), - [this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) { - ScoreValue score = {0, 0}; - for (size_t j = 0; j < static_cast(n_trees_); ++j) { - agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); - } - agg.FinalizeScores1(z_data + i * n_targets_or_classes_, score, + agg.FinalizeScores1(z_data + i, score, + label_data == nullptr ? nullptr : (label_data + i)); + } + } else if ((n_trees_ > parallel_tree_) && (n_trees_ * N < TREEENSEMBLE_MAXSIZE)) { /* section D */ + // Parallelization by trees. + // This could use an array N * nth where nth is the number of threads. + // It would requires function omp_get_thread_num and omp_get_max_threads. 
+ std::vector> scores_t(n_trees_ * N, {0, 0}); + concurrency::ThreadPool::TryBatchParallelFor( + ttp, + SafeInt(n_trees_), + [this, &scores_t, &agg, x_data, N, stride](ptrdiff_t j) { + for (int64_t i = 0; i < N; ++i) { + agg.ProcessTreeNodePrediction1(scores_t[j * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + } + }, + 0); + + concurrency::ThreadPool::TryBatchParallelFor( + ttp, + SafeInt(N), + [this, &scores_t, &agg, x_data, z_data, label_data, N](ptrdiff_t i) { + for (int64_t j = 1; j < this->n_trees_; ++j) { + agg.MergePrediction1(scores_t[i], scores_t[j * N + i]); + } + agg.FinalizeScores1(z_data + i, scores_t[i], label_data == nullptr ? nullptr : (label_data + i)); + }, + 0); + } else if (N < TREEENSEMBLE_BATCHSIZE * 16) { /* section E */ + // Simple parallelization by observations. + concurrency::ThreadPool::TryBatchParallelFor( + ttp, + SafeInt(N), + [this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) { + ScoreValue score = {0, 0}; + for (size_t j = 0; j < static_cast(n_trees_); ++j) { + agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + } + + agg.FinalizeScores1(z_data + i, score, + label_data == nullptr ? nullptr : (label_data + i)); + }, + 0); + } else { /* section F */ + // Parallelization by blocs, and inside every bloc, + // the first loop goes through trees, the inside loop goes + // through observations. This trick is twice faster than section E. 
+ int64_t NB = N - N % TREEENSEMBLE_BATCHSIZE; + concurrency::ThreadPool::TryBatchParallelFor( + ttp, + SafeInt(NB / TREEENSEMBLE_BATCHSIZE), + [this, &agg, x_data, z_data, stride, label_data](ptrdiff_t loop_i) { + ScoreValue score[TREEENSEMBLE_BATCHSIZE]; + memset(&score[0], 0, sizeof(ScoreValue) * TREEENSEMBLE_BATCHSIZE); + const ITYPE* x_data_loop = x_data + loop_i * TREEENSEMBLE_BATCHSIZE * stride; + OTYPE* z_data_loop = z_data + loop_i * TREEENSEMBLE_BATCHSIZE; + for (size_t j = 0; j < static_cast(n_trees_); ++j) { + for (int64_t i = 0; i < TREEENSEMBLE_BATCHSIZE; ++i) { + agg.ProcessTreeNodePrediction1(score[i], *ProcessTreeNodeLeave(roots_[j], x_data_loop + i * stride)); + } + } + for (int64_t i = 0; i < TREEENSEMBLE_BATCHSIZE; ++i) { + agg.FinalizeScores1(z_data_loop + i, score[i], label_data == nullptr ? nullptr : (label_data + i)); - }, - 0); + } + }, + 0); + ScoreValue score; + size_t j; + + for (int64_t i = NB; i < N; ++i) { + score = {0, 0}; + for (j = 0; j < static_cast(n_trees_); ++j) { + agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + } + + agg.FinalizeScores1(z_data + i, score, + label_data == nullptr ? 
nullptr : (label_data + i)); } } } else { if (N == 1) { std::vector> scores(n_targets_or_classes_, {0, 0}); - if (n_trees_ <= parallel_tree_) { + if (n_trees_ <= parallel_tree_) { /* section A2 */ for (int64_t j = 0; j < n_trees_; ++j) { agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data)); } - } else { + } else { /* section B2 */ // split the work into one block per thread so we can re-use the 'private_scores' vector as much as possible // TODO: Refine the number of threads used auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(n_trees_)); @@ -344,44 +406,42 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, } agg.FinalizeScores(scores, z_data, -1, label_data); - } else { - if (N <= parallel_N_) { - std::vector> scores(n_targets_or_classes_); - size_t j; - - for (int64_t i = 0; i < N; ++i) { - std::fill(scores.begin(), scores.end(), ScoreValue({0, 0})); - for (j = 0; j < roots_.size(); ++j) { - agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); - } - - agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1, - label_data == nullptr ? nullptr : (label_data + i)); + } else if (N <= parallel_N_) { /* section D2 */ + std::vector> scores(n_targets_or_classes_); + size_t j; + + for (int64_t i = 0; i < N; ++i) { + std::fill(scores.begin(), scores.end(), ScoreValue({0, 0})); + for (j = 0; j < roots_.size(); ++j) { + agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); } - } else { - // split the work into one block per thread so we can re-use the 'scores' vector as much as possible - // TODO: Refine the number of threads used. 
- auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(N)); - concurrency::ThreadPool::TrySimpleParallelFor( - ttp, - num_threads, - [this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) { - size_t j; - std::vector> scores(n_targets_or_classes_); - auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N); - - for (auto i = work.start; i < work.end; ++i) { - std::fill(scores.begin(), scores.end(), ScoreValue({0, 0})); - for (j = 0; j < roots_.size(); ++j) { - agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); - } - - agg.FinalizeScores(scores, - z_data + i * n_targets_or_classes_, -1, - label_data == nullptr ? nullptr : (label_data + i)); - } - }); + + agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1, + label_data == nullptr ? nullptr : (label_data + i)); } + } else { /* section F2 */ + // split the work into one block per thread so we can re-use the 'scores' vector as much as possible + // TODO: Refine the number of threads used. + auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(N)); + concurrency::ThreadPool::TrySimpleParallelFor( + ttp, + num_threads, + [this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) { + size_t j; + std::vector> scores(n_targets_or_classes_); + auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N); + + for (auto i = work.start; i < work.end; ++i) { + std::fill(scores.begin(), scores.end(), ScoreValue({0, 0})); + for (j = 0; j < roots_.size(); ++j) { + agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + } + + agg.FinalizeScores(scores, + z_data + i * n_targets_or_classes_, -1, + label_data == nullptr ? 
nullptr : (label_data + i)); + } + }); } } } // namespace detail diff --git a/onnxruntime/core/providers/cpu/ml/treeregressor.cc b/onnxruntime/core/providers/cpu/ml/treeregressor.cc index cfee2ccae80fe..960e4fcf973f3 100644 --- a/onnxruntime/core/providers/cpu/ml/treeregressor.cc +++ b/onnxruntime/core/providers/cpu/ml/treeregressor.cc @@ -24,7 +24,7 @@ template TreeEnsembleRegressor::TreeEnsembleRegressor(const OpKernelInfo& info) : OpKernel(info), tree_ensemble_( - 100, + 80, 50, info.GetAttrOrDefault("aggregate_function", "SUM"), info.GetAttrsOrDefault("base_values"), diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 87a3ecde9d24d..9d8335600c290 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -97,7 +97,7 @@ TEST(MLOpTest, TreeRegressorMultiTargetMaxDouble) { GenTreeAndRunTest(X, base_values, results, "MAX", true); } -void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs) { +void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs, int64_t n_obs = 3) { OpTester test("TreeEnsembleRegressor", 1, onnxruntime::kMLDomain); //tree @@ -149,16 +149,32 @@ void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs) { // SUM aggregation by default -- no need to add explicitly //fill input data + std::vector xn, yn; if (one_obs) { + ASSERT_TRUE(n_obs == 3); auto X1 = X; auto results1 = results; X1.resize(2); results1.resize(1); test.AddInput("X", {1, 2}, X1); test.AddOutput("Y", {1, 1}, results1); - } else { + } else if (n_obs == 3) { test.AddInput("X", {3, 2}, X); test.AddOutput("Y", {3, 1}, results); + } else { + ASSERT_TRUE(n_obs % 3 == 0); + xn.resize(n_obs * 2); + yn.resize(n_obs); + for (int64_t i = 0; i < n_obs; i += 3) { + for (size_t k = 0; k < 6; ++k) { + xn[i * 2 + k] = X[k]; + } + for (size_t k = 0; k < 3; ++k) { + yn[i + k] = results[k]; + } + } + 
test.AddInput("X", {n_obs, 2}, xn); + test.AddOutput("Y", {n_obs, 1}, yn); } test.Run(); } @@ -168,6 +184,11 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum) { GenTreeAndRunTest1("SUM", true); } +TEST(MLOpTest, TreeRegressorSingleTargetSumBatch) { + GenTreeAndRunTest1("SUM", false, 201); + GenTreeAndRunTest1("SUM", false, 40002); +} + TEST(MLOpTest, TreeRegressorSingleTargetAverage) { GenTreeAndRunTest1("AVERAGE", false); GenTreeAndRunTest1("AVERAGE", true); From bafe5ce0a8c16634229db7e4696fecc7993936a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sat, 28 Nov 2020 03:16:24 +0100 Subject: [PATCH 2/8] extend coverage unit test coverage --- .../providers/cpu/ml/treeregressor_test.cc | 52 ++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 9d8335600c290..22fb7bc46e8b4 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -97,7 +97,30 @@ TEST(MLOpTest, TreeRegressorMultiTargetMaxDouble) { GenTreeAndRunTest(X, base_values, results, "MAX", true); } -void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs, int64_t n_obs = 3) { +template +void _multiply_update_array(std::vector& data, int n, T inc = 0) { + std::vector copy = data; + data.resize(copy.size() * n); + T cst = 0; + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + data[j + i * copy.size()] = copy[j] + cst; + } + cst += inc; + } +} + +void _multiply_update_array_string(std::vector& data, int n) { + std::vector copy = data; + data.resize(copy.size() * n); + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + data[j + i * copy.size()] = copy[j]; + } + } +} + +void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs, int64_t n_obs = 3, int n_trees = 1) { OpTester test("TreeEnsembleRegressor", 
1, onnxruntime::kMLDomain); //tree @@ -115,6 +138,21 @@ void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs, int64_t n_ std::vector target_weights = {33.33333f, 16.66666f, 33.33333f, -3.33333f, 16.66666f, -3.333333f}; std::vector classes = {0, 1}; + if (n_trees > 1) { + // Multiplies the number of trees to test the parallelization by trees. + _multiply_update_array(lefts, n_trees); + _multiply_update_array(rights, n_trees); + _multiply_update_array(treeids, n_trees, (int64_t)3); + _multiply_update_array(nodeids, n_trees); + _multiply_update_array(featureids, n_trees); + _multiply_update_array(thresholds, n_trees); + _multiply_update_array_string(modes, n_trees); + _multiply_update_array(target_treeids, n_trees, (int64_t)3); + _multiply_update_array(target_nodeids, n_trees); + _multiply_update_array(target_classids, n_trees); + _multiply_update_array(target_weights, n_trees); + } + std::vector results; if (aggFunction == "AVERAGE") { test.AddAttribute("aggregate_function", "AVERAGE"); @@ -126,7 +164,11 @@ void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs, int64_t n_ test.AddAttribute("aggregate_function", "MAX"); results = {33.33333f, 33.33333f, 16.66666f}; } else { // default function is SUM - results = {63.33333333f, 26.66666667f, 30.0f}; + if (n_trees > 1) { + results = {63.33333333f * n_trees + 0.000244140625f, 26.66666667f * n_trees - 0.0001220703125f, 30.0f * n_trees + 0.00042724609375f}; + } else { + results = {63.33333333f, 26.66666667f, 30.0f}; + } } //test data @@ -189,6 +231,12 @@ TEST(MLOpTest, TreeRegressorSingleTargetSumBatch) { GenTreeAndRunTest1("SUM", false, 40002); } +TEST(MLOpTest, TreeRegressorSingleTargetSumBatchTree) { + GenTreeAndRunTest1("SUM", true, 3, 30); + GenTreeAndRunTest1("SUM", false, 201, 30); + GenTreeAndRunTest1("SUM", false, 111040002, 30); +} + TEST(MLOpTest, TreeRegressorSingleTargetAverage) { GenTreeAndRunTest1("AVERAGE", false); GenTreeAndRunTest1("AVERAGE", true); From 
b19a46010710ccc8602c69f3025340168f66af45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 30 Nov 2020 02:23:09 +0100 Subject: [PATCH 3/8] better implementation for the multi regression case --- .../providers/cpu/ml/tree_ensemble_common.h | 35 +++++- .../providers/cpu/ml/treeregressor_test.cc | 105 +++++++++++++----- 2 files changed, 112 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index fc8aba378dfe5..9b26f0c7f0e5d 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -406,7 +406,7 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, } agg.FinalizeScores(scores, z_data, -1, label_data); - } else if (N <= parallel_N_) { /* section D2 */ + } else if (N <= parallel_N_) { /* section C2 */ std::vector> scores(n_targets_or_classes_); size_t j; @@ -419,7 +419,38 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1, label_data == nullptr ? 
nullptr : (label_data + i)); } - } else { /* section F2 */ + } else if ((n_trees_ > parallel_tree_) && (N * n_targets_or_classes_ < TREEENSEMBLE_MAXSIZE)) { /* section D2 */ + auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(n_trees_)); + std::vector>> scores(num_threads * N); + concurrency::ThreadPool::TrySimpleParallelFor( + ttp, + num_threads, + [this, &agg, &scores, num_threads, x_data, N, stride](ptrdiff_t batch_num) { + auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, this->n_trees_); + for (int64_t i = 0; i < N; ++i) { + scores[batch_num * N + i].resize(n_targets_or_classes_, {0, 0}); + } + for (auto j = work.start; j < work.end; ++j) { + for (int64_t i = 0; i < N; ++i) { + agg.ProcessTreeNodePrediction(scores[batch_num * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + } + } + }); + + concurrency::ThreadPool::TrySimpleParallelFor( + ttp, + num_threads, + [this, &agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) { + auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N); + for (auto i = work.start; i < work.end; ++i) { + for (int64_t j = 1; j < num_threads; ++j) { + agg.MergePrediction(scores[i], scores[j * N + i]); + } + agg.FinalizeScores(scores[i], z_data + i * this->n_targets_or_classes_, -1, + label_data == nullptr ? nullptr : (label_data + i)); + } + }); + } else { /* section E2 */ // split the work into one block per thread so we can re-use the 'scores' vector as much as possible // TODO: Refine the number of threads used. 
auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(N)); diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 22fb7bc46e8b4..beb3076fe4fdf 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -8,7 +8,31 @@ namespace onnxruntime { namespace test { template -void GenTreeAndRunTest(const std::vector& X, const std::vector& base_values, const std::vector& results, const std::string& aggFunction, bool one_obs = false) { +void _multiply_update_array(std::vector& data, int n, T inc = 0) { + std::vector copy = data; + data.resize(copy.size() * n); + T cst = 0; + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + data[j + i * copy.size()] = copy[j] + cst; + } + cst += inc; + } +} + +void _multiply_update_array_string(std::vector& data, int n) { + std::vector copy = data; + data.resize(copy.size() * n); + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + data[j + i * copy.size()] = copy[j]; + } + } +} + +template +void GenTreeAndRunTest(const std::vector& X, const std::vector& base_values, const std::vector& results, const std::string& aggFunction, + bool one_obs = false, int64_t n_obs = 8, int n_trees = 1) { OpTester test("TreeEnsembleRegressor", 1, onnxruntime::kMLDomain); //tree @@ -26,6 +50,21 @@ void GenTreeAndRunTest(const std::vector& X, const std::vector& base_v std::vector target_weights = {1.5f, 27.5f, 2.25f, 20.75f, 2.f, 23.f, 3.f, 14.f, 0.f, 41.f, 1.83333333f, 24.5f, 0.f, 41.f, 2.75f, 16.25f, 2.f, 23.f, 3.f, 14.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; std::vector classes = {0, 1}; + if (n_trees > 1) { + // Multiplies the number of trees to test the parallelization by trees. 
+ _multiply_update_array(lefts, n_trees); + _multiply_update_array(rights, n_trees); + _multiply_update_array(treeids, n_trees, (int64_t)3); + _multiply_update_array(nodeids, n_trees); + _multiply_update_array(featureids, n_trees); + _multiply_update_array(thresholds, n_trees); + _multiply_update_array_string(modes, n_trees); + _multiply_update_array(target_treeids, n_trees, (int64_t)3); + _multiply_update_array(target_nodeids, n_trees); + _multiply_update_array(target_classids, n_trees); + _multiply_update_array(target_weights, n_trees); + } + //add attributes test.AddAttribute("nodes_truenodeids", lefts); test.AddAttribute("nodes_falsenodeids", rights); @@ -51,6 +90,8 @@ void GenTreeAndRunTest(const std::vector& X, const std::vector& base_v } // default function is SUM //fill input data + std::vector xn; + std::vector yn; if (one_obs) { auto X1 = X; auto results1 = results; @@ -58,13 +99,44 @@ void GenTreeAndRunTest(const std::vector& X, const std::vector& base_v results1.resize(2); test.AddInput("X", {1, 3}, X1); test.AddOutput("Y", {1, 2}, results1); - } else { + } else if (n_obs == 8) { test.AddInput("X", {8, 3}, X); test.AddOutput("Y", {8, 2}, results); + } else { + int64_t i; + size_t k; + ASSERT_TRUE(n_obs % 8 == 0); + xn.resize(n_obs * 3); + yn.resize(n_obs * 2); + for (i = 0; i < n_obs; i += 8) { + for (k = 0; k < 24; ++k) { + xn[i * 3 + k] = X[k]; + } + for (k = 0; k < 16; ++k) { + yn[i * 2 + k] = results[k]; + } + } + ASSERT_TRUE(i == n_obs); + test.AddInput("X", {n_obs, 3}, xn); + test.AddOutput("Y", {n_obs, 2}, yn); } + test.Run(); } // namespace test +TEST(MLOpTest, TreeRegressorMultiTargetSumBatchTree) { + // TreeEnsemble implements different paths depending on n_trees or N. + // It is not possible to test all of them in a short time without + // changing two thresholds which cannot be changed with the current API. 
+ std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; + std::vector results = {1.33333333f, 29.f, 3.f, 14.f, 2.f, 23.f, 2.f, 23.f, 2.f, 23.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; + std::vector base_values{0.f, 0.f}; + GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 30); + GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 30); + GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 130); + // GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 111040008, 30); +} + TEST(MLOpTest, TreeRegressorMultiTargetAverage) { std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; std::vector results = {1.33333333f, 29.f, 3.f, 14.f, 2.f, 23.f, 2.f, 23.f, 2.f, 23.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; @@ -97,29 +169,6 @@ TEST(MLOpTest, TreeRegressorMultiTargetMaxDouble) { GenTreeAndRunTest(X, base_values, results, "MAX", true); } -template -void _multiply_update_array(std::vector& data, int n, T inc = 0) { - std::vector copy = data; - data.resize(copy.size() * n); - T cst = 0; - for (int i = 0; i < n; ++i) { - for (size_t j = 0; j < copy.size(); ++j) { - data[j + i * copy.size()] = copy[j] + cst; - } - cst += inc; - } -} - -void _multiply_update_array_string(std::vector& data, int n) { - std::vector copy = data; - data.resize(copy.size() * n); - for (int i = 0; i < n; ++i) { - for (size_t j = 0; j < copy.size(); ++j) { - data[j + i * copy.size()] = copy[j]; - } - } -} - void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs, int64_t n_obs = 3, int n_trees = 1) { OpTester test("TreeEnsembleRegressor", 1, onnxruntime::kMLDomain); @@ -232,9 +281,13 @@ TEST(MLOpTest, TreeRegressorSingleTargetSumBatch) { } TEST(MLOpTest, 
TreeRegressorSingleTargetSumBatchTree) { + // TreeEnsemble implements different paths depending on n_trees or N. + // It is not possible to test all of them in a short time without + // changing two thresholds which cannot be changed with the current API. GenTreeAndRunTest1("SUM", true, 3, 30); GenTreeAndRunTest1("SUM", false, 201, 30); - GenTreeAndRunTest1("SUM", false, 111040002, 30); + GenTreeAndRunTest1("AVERAGE", false, 201, 130); + //GenTreeAndRunTest1("SUM", false, 111040002, 30); } TEST(MLOpTest, TreeRegressorSingleTargetAverage) { From 5d52525c90d6238b1b21c60306c0ffec66aa1b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 30 Nov 2020 12:44:22 +0100 Subject: [PATCH 4/8] remove unnecessary variable --- onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 9b26f0c7f0e5d..0cbf12f6503ed 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -319,7 +319,7 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, concurrency::ThreadPool::TryBatchParallelFor( ttp, SafeInt(N), - [this, &scores_t, &agg, x_data, z_data, label_data, N](ptrdiff_t i) { + [this, &scores_t, &agg, z_data, label_data, N](ptrdiff_t i) { for (int64_t j = 1; j < this->n_trees_; ++j) { agg.MergePrediction1(scores_t[i], scores_t[j * N + i]); } @@ -333,7 +333,7 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, SafeInt(N), [this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) { ScoreValue score = {0, 0}; - for (size_t j = 0; j < static_cast(n_trees_); ++j) { + for (size_t j = 0; j < static_cast(this->n_trees_); ++j) { agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); } From 91ed22e2bc1a9fd39dd8566e10ea8d6e58b18b88 Mon 
Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 1 Dec 2020 00:57:27 +0100 Subject: [PATCH 5/8] reduce number of distinct implementations --- .../providers/cpu/ml/tree_ensemble_common.h | 96 ++++++------------- .../providers/cpu/ml/treeregressor_test.cc | 11 +-- 2 files changed, 31 insertions(+), 76 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 0cbf12f6503ed..ff854563a13bf 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -301,7 +301,8 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, agg.FinalizeScores1(z_data + i, score, label_data == nullptr ? nullptr : (label_data + i)); } - } else if ((n_trees_ > parallel_tree_) && (n_trees_ * N < TREEENSEMBLE_MAXSIZE)) { /* section D */ + } else { /* section D */ + /* // Parallelization by trees. // This could use an array N * nth where nth is the number of threads. // It would requires function omp_get_thread_num and omp_get_max_threads. @@ -326,57 +327,37 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, agg.FinalizeScores1(z_data + i, scores_t[i], label_data == nullptr ? nullptr : (label_data + i)); }, 0); - } else if (N < TREEENSEMBLE_BATCHSIZE * 16) { /* section E */ - // Simple parallelization by observations. 
- concurrency::ThreadPool::TryBatchParallelFor( + */ + auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(n_trees_)); + std::vector> scores(num_threads * N); + concurrency::ThreadPool::TrySimpleParallelFor( ttp, - SafeInt(N), - [this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) { - ScoreValue score = {0, 0}; - for (size_t j = 0; j < static_cast(this->n_trees_); ++j) { - agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + num_threads, + [this, &agg, &scores, num_threads, x_data, N, stride](ptrdiff_t batch_num) { + auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, this->n_trees_); + for (int64_t i = 0; i < N; ++i) { + scores[batch_num * N + i] = {0, 0}; + } + for (auto j = work.start; j < work.end; ++j) { + for (int64_t i = 0; i < N; ++i) { + agg.ProcessTreeNodePrediction1(scores[batch_num * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + } } + }); - agg.FinalizeScores1(z_data + i, score, - label_data == nullptr ? nullptr : (label_data + i)); - }, - 0); - } else { /* section F */ - // Parallelization by blocs, and inside every bloc, - // the first loop goes through trees, the inside loop goes - // through observations. This trick is twice faster than section E. 
- int64_t NB = N - N % TREEENSEMBLE_BATCHSIZE; - concurrency::ThreadPool::TryBatchParallelFor( + concurrency::ThreadPool::TrySimpleParallelFor( ttp, - SafeInt(NB / TREEENSEMBLE_BATCHSIZE), - [this, &agg, x_data, z_data, stride, label_data](ptrdiff_t loop_i) { - ScoreValue score[TREEENSEMBLE_BATCHSIZE]; - memset(&score[0], 0, sizeof(ScoreValue) * TREEENSEMBLE_BATCHSIZE); - const ITYPE* x_data_loop = x_data + loop_i * TREEENSEMBLE_BATCHSIZE * stride; - OTYPE* z_data_loop = z_data + loop_i * TREEENSEMBLE_BATCHSIZE; - for (size_t j = 0; j < static_cast(n_trees_); ++j) { - for (int64_t i = 0; i < TREEENSEMBLE_BATCHSIZE; ++i) { - agg.ProcessTreeNodePrediction1(score[i], *ProcessTreeNodeLeave(roots_[j], x_data_loop + i * stride)); + num_threads, + [this, &agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) { + auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N); + for (auto i = work.start; i < work.end; ++i) { + for (int64_t j = 1; j < num_threads; ++j) { + agg.MergePrediction1(scores[i], scores[j * N + i]); } - } - for (int64_t i = 0; i < TREEENSEMBLE_BATCHSIZE; ++i) { - agg.FinalizeScores1(z_data_loop + i, score[i], + agg.FinalizeScores1(z_data + i, scores[i], label_data == nullptr ? nullptr : (label_data + i)); } - }, - 0); - ScoreValue score; - size_t j; - - for (int64_t i = NB; i < N; ++i) { - score = {0, 0}; - for (j = 0; j < static_cast(n_trees_); ++j) { - agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); - } - - agg.FinalizeScores1(z_data + i, score, - label_data == nullptr ? nullptr : (label_data + i)); - } + }); } } else { if (N == 1) { @@ -419,7 +400,7 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1, label_data == nullptr ? 
nullptr : (label_data + i)); } - } else if ((n_trees_ > parallel_tree_) && (N * n_targets_or_classes_ < TREEENSEMBLE_MAXSIZE)) { /* section D2 */ + } else { auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(n_trees_)); std::vector>> scores(num_threads * N); concurrency::ThreadPool::TrySimpleParallelFor( @@ -450,29 +431,6 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, label_data == nullptr ? nullptr : (label_data + i)); } }); - } else { /* section E2 */ - // split the work into one block per thread so we can re-use the 'scores' vector as much as possible - // TODO: Refine the number of threads used. - auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(N)); - concurrency::ThreadPool::TrySimpleParallelFor( - ttp, - num_threads, - [this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) { - size_t j; - std::vector> scores(n_targets_or_classes_); - auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N); - - for (auto i = work.start; i < work.end; ++i) { - std::fill(scores.begin(), scores.end(), ScoreValue({0, 0})); - for (j = 0; j < roots_.size(); ++j) { - agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); - } - - agg.FinalizeScores(scores, - z_data + i * n_targets_or_classes_, -1, - label_data == nullptr ? 
nullptr : (label_data + i)); - } - }); } } } // namespace detail diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index beb3076fe4fdf..b24a7c5ff59fb 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -213,11 +213,7 @@ void GenTreeAndRunTest1(const std::string& aggFunction, bool one_obs, int64_t n_ test.AddAttribute("aggregate_function", "MAX"); results = {33.33333f, 33.33333f, 16.66666f}; } else { // default function is SUM - if (n_trees > 1) { - results = {63.33333333f * n_trees + 0.000244140625f, 26.66666667f * n_trees - 0.0001220703125f, 30.0f * n_trees + 0.00042724609375f}; - } else { - results = {63.33333333f, 26.66666667f, 30.0f}; - } + results = {63.33333333f, 26.66666667f, 30.0f}; } //test data @@ -284,8 +280,9 @@ TEST(MLOpTest, TreeRegressorSingleTargetSumBatchTree) { // TreeEnsemble implements different paths depending on n_trees or N. // It is not possible to test all of them in a short time without // changing two thresholds which cannot be changed with the current API. 
- GenTreeAndRunTest1("SUM", true, 3, 30); - GenTreeAndRunTest1("SUM", false, 201, 30); + GenTreeAndRunTest1("SUM", true, 3, 1); + GenTreeAndRunTest1("AVERAGE", true, 3, 30); + GenTreeAndRunTest1("AVERAGE", false, 201, 30); GenTreeAndRunTest1("AVERAGE", false, 201, 130); //GenTreeAndRunTest1("SUM", false, 111040002, 30); } From c4c7ecb1408f1b5103d6c9707a3b0891bddcd2f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Tue, 1 Dec 2020 01:33:49 +0100 Subject: [PATCH 6/8] remove unused variable --- onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index ff854563a13bf..90c64eac2acdb 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -348,7 +348,7 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, concurrency::ThreadPool::TrySimpleParallelFor( ttp, num_threads, - [this, &agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) { + [&agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) { auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N); for (auto i = work.start; i < work.end; ++i) { for (int64_t j = 1; j < num_threads; ++j) { From 72f20aef00b6599441bd0215e455f7f6b7ea672d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 2 Dec 2020 19:48:22 +0100 Subject: [PATCH 7/8] better comment, keep parallelization by trees when not enough trees --- .../providers/cpu/ml/tree_ensemble_common.h | 90 +++++++++++-------- .../providers/cpu/ml/treeregressor_test.cc | 31 +++---- 2 files changed, 67 insertions(+), 54 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 90c64eac2acdb..48e757cb5d4c6 100644 --- 
a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -265,15 +265,16 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, const ITYPE* x_data = X->template Data(); OTYPE* z_data = Z->template MutableData(); int64_t* label_data = label == nullptr ? nullptr : label->template MutableData(); + auto max_num_threads = concurrency::ThreadPool::DegreeOfParallelism(ttp); if (n_targets_or_classes_ == 1) { if (N == 1) { ScoreValue score = {0, 0}; - if (n_trees_ <= parallel_tree_) { /* section A */ + if (n_trees_ <= parallel_tree_) { /* section A: 1 output, 1 row and not enough trees to parallelize */ for (int64_t j = 0; j < n_trees_; ++j) { agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data)); } - } else { /* section B */ + } else { /* section B: 1 output, 1 row and enough trees to parallelize */ std::vector> scores_t(n_trees_, {0, 0}); concurrency::ThreadPool::TryBatchParallelFor( ttp, @@ -288,7 +289,7 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, } } agg.FinalizeScores1(z_data, score, label_data); - } else if (N <= parallel_N_) { /* section C */ + } else if (N <= parallel_N_) { /* section C: 1 output, 2+ rows but not enough rows to parallelize */ ScoreValue score; size_t j; @@ -301,34 +302,8 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, agg.FinalizeScores1(z_data + i, score, label_data == nullptr ? nullptr : (label_data + i)); } - } else { /* section D */ - /* - // Parallelization by trees. - // This could use an array N * nth where nth is the number of threads. - // It would requires function omp_get_thread_num and omp_get_max_threads. 
- std::vector> scores_t(n_trees_ * N, {0, 0}); - concurrency::ThreadPool::TryBatchParallelFor( - ttp, - SafeInt(n_trees_), - [this, &scores_t, &agg, x_data, N, stride](ptrdiff_t j) { - for (int64_t i = 0; i < N; ++i) { - agg.ProcessTreeNodePrediction1(scores_t[j * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); - } - }, - 0); - - concurrency::ThreadPool::TryBatchParallelFor( - ttp, - SafeInt(N), - [this, &scores_t, &agg, z_data, label_data, N](ptrdiff_t i) { - for (int64_t j = 1; j < this->n_trees_; ++j) { - agg.MergePrediction1(scores_t[i], scores_t[j * N + i]); - } - agg.FinalizeScores1(z_data + i, scores_t[i], label_data == nullptr ? nullptr : (label_data + i)); - }, - 0); - */ - auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(n_trees_)); + } else if (n_trees_ > max_num_threads) { /* section D: 1 output, 2+ rows and enough trees to parallelize */ + auto num_threads = std::min(max_num_threads, SafeInt(n_trees_)); std::vector> scores(num_threads * N); concurrency::ThreadPool::TrySimpleParallelFor( ttp, @@ -358,18 +333,32 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, label_data == nullptr ? nullptr : (label_data + i)); } }); + } else { /* section E: 1 output, 2+ rows, parallelization by rows */ + concurrency::ThreadPool::TryBatchParallelFor( + ttp, + SafeInt(N), + [this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) { + ScoreValue score = {0, 0}; + for (size_t j = 0; j < static_cast(n_trees_); ++j) { + agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + } + + agg.FinalizeScores1(z_data + i, score, + label_data == nullptr ? 
nullptr : (label_data + i)); + }, + 0); } } else { - if (N == 1) { + if (N == 1) { /* section A2: 2+ outputs, 1 row, not enough trees to parallelize */ std::vector> scores(n_targets_or_classes_, {0, 0}); if (n_trees_ <= parallel_tree_) { /* section A2 */ for (int64_t j = 0; j < n_trees_; ++j) { agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data)); } - } else { /* section B2 */ - // split the work into one block per thread so we can re-use the 'private_scores' vector as much as possible - // TODO: Refine the number of threads used - auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(n_trees_)); + } else { /* section B2: 2+ outputs, 1 row, enough trees to parallelize */ + // Splits the work into one block per thread so we can re-use the 'private_scores' vector as much as possible. + // TODO: Refine the number of threads used. + auto num_threads = std::min(max_num_threads, SafeInt(n_trees_)); OrtMutex merge_mutex; concurrency::ThreadPool::TrySimpleParallelFor( ttp, @@ -387,7 +376,7 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, } agg.FinalizeScores(scores, z_data, -1, label_data); - } else if (N <= parallel_N_) { /* section C2 */ + } else if (N <= parallel_N_) { /* section C2: 2+ outputs, 2+ rows, not enough rows to parallelize */ std::vector> scores(n_targets_or_classes_); size_t j; @@ -400,8 +389,8 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1, label_data == nullptr ? 
nullptr : (label_data + i)); } - } else { - auto num_threads = std::min(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt(n_trees_)); + } else if (n_trees_ >= max_num_threads) { /* section: D2: 2+ outputs, 2+ rows, enough trees to parallelize*/ + auto num_threads = std::min(max_num_threads, SafeInt(n_trees_)); std::vector>> scores(num_threads * N); concurrency::ThreadPool::TrySimpleParallelFor( ttp, @@ -431,6 +420,29 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, label_data == nullptr ? nullptr : (label_data + i)); } }); + } else { /* section E2: 2+ outputs, 2+ rows, parallelization by rows */ + // Split the work into one block per thread so we can re-use the 'scores' vector as much as possible. + // TODO: Refine the number of threads used. + auto num_threads = std::min(max_num_threads, SafeInt(N)); + concurrency::ThreadPool::TrySimpleParallelFor( + ttp, + num_threads, + [this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) { + size_t j; + std::vector> scores(n_targets_or_classes_); + auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N); + + for (auto i = work.start; i < work.end; ++i) { + std::fill(scores.begin(), scores.end(), ScoreValue({0, 0})); + for (j = 0; j < roots_.size(); ++j) { + agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride)); + } + + agg.FinalizeScores(scores, + z_data + i * n_targets_or_classes_, -1, + label_data == nullptr ? 
nullptr : (label_data + i)); + } + }); } } } // namespace detail diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index b24a7c5ff59fb..40c8c2895f8f6 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -124,17 +124,18 @@ void GenTreeAndRunTest(const std::vector& X, const std::vector& base_v test.Run(); } // namespace test -TEST(MLOpTest, TreeRegressorMultiTargetSumBatchTree) { +TEST(MLOpTest, TreeRegressorMultiTargetBatchTree) { // TreeEnsemble implements different paths depending on n_trees or N. - // It is not possible to test all of them in a short time without - // changing two thresholds which cannot be changed with the current API. + // This test goes through all sections for multi-targets. std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; std::vector results = {1.33333333f, 29.f, 3.f, 14.f, 2.f, 23.f, 2.f, 23.f, 2.f, 23.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; std::vector base_values{0.f, 0.f}; - GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 30); - GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 30); - GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 130); - // GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 111040008, 30); + GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 1); // section A2 + GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 130); // section B2 + GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 130); // section C2 + GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 30); // section D2 + GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 30); // section D2 + GenTreeAndRunTest(X, base_values, results, "AVERAGE", 
false, 200, 1); // section E2 } TEST(MLOpTest, TreeRegressorMultiTargetAverage) { @@ -276,15 +277,15 @@ TEST(MLOpTest, TreeRegressorSingleTargetSumBatch) { GenTreeAndRunTest1("SUM", false, 40002); } -TEST(MLOpTest, TreeRegressorSingleTargetSumBatchTree) { +TEST(MLOpTest, TreeRegressorSingleTargetBatchTree) { // TreeEnsemble implements different paths depending on n_trees or N. - // It is not possible to test all of them in a short time without - // changing two thresholds which cannot be changed with the current API. - GenTreeAndRunTest1("SUM", true, 3, 1); - GenTreeAndRunTest1("AVERAGE", true, 3, 30); - GenTreeAndRunTest1("AVERAGE", false, 201, 30); - GenTreeAndRunTest1("AVERAGE", false, 201, 130); - //GenTreeAndRunTest1("SUM", false, 111040002, 30); + // This test goes through all sections for one target. + GenTreeAndRunTest1("SUM", true, 3, 1); // section A + GenTreeAndRunTest1("AVERAGE", true, 3, 30); // section B + GenTreeAndRunTest1("AVERAGE", false, 3, 1); // section C + GenTreeAndRunTest1("AVERAGE", false, 201, 30); // section D + GenTreeAndRunTest1("AVERAGE", false, 201, 130); // section D + GenTreeAndRunTest1("AVERAGE", false, 201, 1); // section E } TEST(MLOpTest, TreeRegressorSingleTargetAverage) { From 520442caf123b3e180710b9c5d520ddcd62e1f5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 2 Dec 2020 20:12:37 +0100 Subject: [PATCH 8/8] split a unit test to have a name for every tested section --- .../providers/cpu/ml/tree_ensemble_common.h | 37 +++++------ .../providers/cpu/ml/treeregressor_test.cc | 62 +++++++++++++++---- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 48e757cb5d4c6..4a06f5a7a2bf0 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -7,9 +7,6 @@ #include "core/platform/ort_mutex.h" #include 
"core/platform/threadpool.h" -#define TREEENSEMBLE_BATCHSIZE 256 -#define TREEENSEMBLE_MAXSIZE 8589934592 - namespace onnxruntime { namespace ml { namespace detail { @@ -275,16 +272,16 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data)); } } else { /* section B: 1 output, 1 row and enough trees to parallelize */ - std::vector> scores_t(n_trees_, {0, 0}); + std::vector> scores(n_trees_, {0, 0}); concurrency::ThreadPool::TryBatchParallelFor( ttp, SafeInt(n_trees_), - [this, &scores_t, &agg, x_data](ptrdiff_t j) { - agg.ProcessTreeNodePrediction1(scores_t[j], *ProcessTreeNodeLeave(roots_[j], x_data)); + [this, &scores, &agg, x_data](ptrdiff_t j) { + agg.ProcessTreeNodePrediction1(scores[j], *ProcessTreeNodeLeave(roots_[j], x_data)); }, 0); - for (auto it = scores_t.cbegin(); it != scores_t.cend(); ++it) { + for (auto it = scores.cbegin(); it != scores.cend(); ++it) { agg.MergePrediction1(score, *it); } } @@ -349,33 +346,31 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, 0); } } else { - if (N == 1) { /* section A2: 2+ outputs, 1 row, not enough trees to parallelize */ - std::vector> scores(n_targets_or_classes_, {0, 0}); + if (N == 1) { /* section A2: 2+ outputs, 1 row, not enough trees to parallelize */ if (n_trees_ <= parallel_tree_) { /* section A2 */ + std::vector> scores(n_targets_or_classes_, {0, 0}); for (int64_t j = 0; j < n_trees_; ++j) { agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data)); } + agg.FinalizeScores(scores, z_data, -1, label_data); } else { /* section B2: 2+ outputs, 1 row, enough trees to parallelize */ - // Splits the work into one block per thread so we can re-use the 'private_scores' vector as much as possible. - // TODO: Refine the number of threads used. 
auto num_threads = std::min(max_num_threads, SafeInt(n_trees_)); - OrtMutex merge_mutex; + std::vector>> scores(num_threads); concurrency::ThreadPool::TrySimpleParallelFor( ttp, num_threads, - [this, &agg, &scores, &merge_mutex, num_threads, x_data](ptrdiff_t batch_num) { - std::vector> private_scores(n_targets_or_classes_, {0, 0}); + [this, &agg, &scores, num_threads, x_data](ptrdiff_t batch_num) { + scores[batch_num].resize(n_targets_or_classes_, {0, 0}); auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, n_trees_); for (auto j = work.start; j < work.end; ++j) { - agg.ProcessTreeNodePrediction(private_scores, *ProcessTreeNodeLeave(roots_[j], x_data)); + agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data)); } - - std::lock_guard lock(merge_mutex); - agg.MergePrediction(scores, private_scores); }); + for (size_t i = 1; i < scores.size(); ++i) { + agg.MergePrediction(scores[0], scores[i]); + } + agg.FinalizeScores(scores[0], z_data, -1, label_data); } - - agg.FinalizeScores(scores, z_data, -1, label_data); } else if (N <= parallel_N_) { /* section C2: 2+ outputs, 2+ rows, not enough rows to parallelize */ std::vector> scores(n_targets_or_classes_); size_t j; @@ -421,8 +416,6 @@ void TreeEnsembleCommon::ComputeAgg(concurrency::ThreadPool* ttp, } }); } else { /* section E2: 2+ outputs, 2+ rows, parallelization by rows */ - // Split the work into one block per thread so we can re-use the 'scores' vector as much as possible. - // TODO: Refine the number of threads used. 
auto num_threads = std::min(max_num_threads, SafeInt(N)); concurrency::ThreadPool::TrySimpleParallelFor( ttp, diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 40c8c2895f8f6..e6effd1427394 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -124,18 +124,42 @@ void GenTreeAndRunTest(const std::vector& X, const std::vector& base_v test.Run(); } // namespace test -TEST(MLOpTest, TreeRegressorMultiTargetBatchTree) { +TEST(MLOpTest, TreeRegressorMultiTargetBatchTreeA2) { // TreeEnsemble implements different paths depending on n_trees or N. - // This test goes through all sections for multi-targets. + // This test and the next ones go through all sections for multi-targets. + std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; + std::vector results = {1.33333333f, 29.f, 3.f, 14.f, 2.f, 23.f, 2.f, 23.f, 2.f, 23.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; + std::vector base_values{0.f, 0.f}; + GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 1); // section A2 +} + +TEST(MLOpTest, TreeRegressorMultiTargetBatchTreeB2) { + std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; + std::vector results = {1.33333333f, 29.f, 3.f, 14.f, 2.f, 23.f, 2.f, 23.f, 2.f, 23.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; + std::vector base_values{0.f, 0.f}; + GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 130); // section B2 +} + +TEST(MLOpTest, TreeRegressorMultiTargetBatchTreeC2) { std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 
11.3f, -222.f, 43.0f, 413.3f, -114.f}; std::vector results = {1.33333333f, 29.f, 3.f, 14.f, 2.f, 23.f, 2.f, 23.f, 2.f, 23.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; std::vector base_values{0.f, 0.f}; - GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 1); // section A2 - GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 130); // section B2 GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 130); // section C2 - GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 30); // section D2 - GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 30); // section D2 - GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 1); // section E2 +} + +TEST(MLOpTest, TreeRegressorMultiTargetBatchTreeD2) { + std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; + std::vector results = {1.33333333f, 29.f, 3.f, 14.f, 2.f, 23.f, 2.f, 23.f, 2.f, 23.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; + std::vector base_values{0.f, 0.f}; + GenTreeAndRunTest(X, base_values, results, "AVERAGE", true, 8, 30); // section D2 + GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 30); // section D2 +} + +TEST(MLOpTest, TreeRegressorMultiTargetBatchTreeE2) { + std::vector X = {1.f, 0.0f, 0.4f, 3.0f, 44.0f, -3.f, 12.0f, 12.9f, -312.f, 23.0f, 11.3f, -222.f, 23.0f, 11.3f, -222.f, 23.0f, 3311.3f, -222.f, 23.0f, 11.3f, -222.f, 43.0f, 413.3f, -114.f}; + std::vector results = {1.33333333f, 29.f, 3.f, 14.f, 2.f, 23.f, 2.f, 23.f, 2.f, 23.f, 2.66666667f, 17.f, 2.f, 23.f, 3.f, 14.f}; + std::vector base_values{0.f, 0.f}; + GenTreeAndRunTest(X, base_values, results, "AVERAGE", false, 200, 1); // section E2 } TEST(MLOpTest, TreeRegressorMultiTargetAverage) { @@ -277,15 +301,27 @@ TEST(MLOpTest, TreeRegressorSingleTargetSumBatch) { GenTreeAndRunTest1("SUM", false, 40002); } -TEST(MLOpTest, 
TreeRegressorSingleTargetBatchTree) { +TEST(MLOpTest, TreeRegressorSingleTargetBatchTreeA) { // TreeEnsemble implements different paths depending on n_trees or N. - // This test goes through all sections for one target. - GenTreeAndRunTest1("SUM", true, 3, 1); // section A - GenTreeAndRunTest1("AVERAGE", true, 3, 30); // section B - GenTreeAndRunTest1("AVERAGE", false, 3, 1); // section C + // This test and the next ones go through all sections for one target. + GenTreeAndRunTest1("SUM", true, 3, 1); // section A +} + +TEST(MLOpTest, TreeRegressorSingleTargetBatchTreeB) { + GenTreeAndRunTest1("AVERAGE", true, 3, 30); // section B +} + +TEST(MLOpTest, TreeRegressorSingleTargetBatchTreeC) { + GenTreeAndRunTest1("AVERAGE", false, 3, 1); // section C +} + +TEST(MLOpTest, TreeRegressorSingleTargetBatchTreeD) { GenTreeAndRunTest1("AVERAGE", false, 201, 30); // section D GenTreeAndRunTest1("AVERAGE", false, 201, 130); // section D - GenTreeAndRunTest1("AVERAGE", false, 201, 1); // section E +} + +TEST(MLOpTest, TreeRegressorSingleTargetBatchTreeE) { + GenTreeAndRunTest1("AVERAGE", false, 201, 1); // section E } TEST(MLOpTest, TreeRegressorSingleTargetAverage) {