From d06a714643bc7a2c3e31b49785ee5b25c219754a Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Fri, 24 Nov 2023 13:21:29 +0800 Subject: [PATCH 01/10] :sparkles: Feature: added base code for igamma and igammac --- paddle/phi/api/yaml/ops.yaml | 18 ++++++++++ paddle/phi/kernels/cpu/igamma_kernel.cc | 38 ++++++++++++++++++++ paddle/phi/kernels/cpu/igammac_kernel.cc | 38 ++++++++++++++++++++ paddle/phi/kernels/gpu/igamma_kernel.cu | 30 ++++++++++++++++ paddle/phi/kernels/gpu/igammac_kernel.cu | 32 +++++++++++++++++ paddle/phi/kernels/igamma_kernel.h | 34 ++++++++++++++++++ paddle/phi/kernels/igammac_kernel.h | 34 ++++++++++++++++++ paddle/phi/kernels/impl/igamma_kernel_impl.h | 20 +++++++++++ 8 files changed, 244 insertions(+) create mode 100644 paddle/phi/kernels/cpu/igamma_kernel.cc create mode 100644 paddle/phi/kernels/cpu/igammac_kernel.cc create mode 100644 paddle/phi/kernels/gpu/igamma_kernel.cu create mode 100644 paddle/phi/kernels/gpu/igammac_kernel.cu create mode 100644 paddle/phi/kernels/igamma_kernel.h create mode 100644 paddle/phi/kernels/igammac_kernel.h create mode 100644 paddle/phi/kernels/impl/igamma_kernel_impl.h diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 82d24dfff4dd81..5fef80efbdc7a8 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1222,6 +1222,24 @@ func : i1e backward : i1e_grad +- op : igamma + args : (Tensor x, Tensor a) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : igamma + inplace: (x -> out) + +- op : igammac + args : (Tensor x, Tensor a) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : igammac + inplace: (x -> out) + - op : imag args : (Tensor x) output : Tensor (out) diff --git a/paddle/phi/kernels/cpu/igamma_kernel.cc b/paddle/phi/kernels/cpu/igamma_kernel.cc new file mode 100644 index 00000000000000..ea35a108d74cce --- /dev/null +++ b/paddle/phi/kernels/cpu/igamma_kernel.cc @@ -0,0 +1,38 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/igamma_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { + +template +void IgammaKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& a, + DenseTensor* out) {} + +} // namespace phi + +PD_REGISTER_KERNEL(igamma, + CPU, + ALL_LAYOUT, + phi::IgammaKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/igammac_kernel.cc b/paddle/phi/kernels/cpu/igammac_kernel.cc new file mode 100644 index 00000000000000..35911c523facf4 --- /dev/null +++ b/paddle/phi/kernels/cpu/igammac_kernel.cc @@ -0,0 +1,38 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/igammac_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { + +template +void IgammacKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& a, + DenseTensor* out) {} + +} // namespace phi + +PD_REGISTER_KERNEL(igammac, + CPU, + ALL_LAYOUT, + phi::IgammacKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/igamma_kernel.cu b/paddle/phi/kernels/gpu/igamma_kernel.cu new file mode 100644 index 00000000000000..5746500b1977bf --- /dev/null +++ b/paddle/phi/kernels/gpu/igamma_kernel.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/igamma_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" + +namespace phi { + +template +void IgammaKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& a, DenseTensor* out) { + +} + +} // namespace phi + +PD_REGISTER_KERNEL(igamma, GPU, ALL_LAYOUT, phi::IgammaKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/igammac_kernel.cu b/paddle/phi/kernels/gpu/igammac_kernel.cu new file mode 100644 index 00000000000000..8fe029b716f307 --- /dev/null +++ b/paddle/phi/kernels/gpu/igammac_kernel.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/igammac_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" + +namespace phi { + +template +void IgammacKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& a, + DenseTensor* out) {} + +} // namespace phi + +PD_REGISTER_KERNEL( + igammac, GPU, ALL_LAYOUT, phi::IgammacKernel, float, double) {} diff --git a/paddle/phi/kernels/igamma_kernel.h b/paddle/phi/kernels/igamma_kernel.h new file mode 100644 index 00000000000000..53d9dc0370fa90 --- /dev/null +++ b/paddle/phi/kernels/igamma_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +/** + * @brief This kernel calculate Modified Bessel function of order 0. + * @param ctx device context + * @param x The input tensor of i0 + * @param out The output tensor of i0 kernel, it has the same shape and + * dtype with input. Each element corresponds to input tensor + */ +template +void IgammaKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& a, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/igammac_kernel.h b/paddle/phi/kernels/igammac_kernel.h new file mode 100644 index 00000000000000..30ede703ce470b --- /dev/null +++ b/paddle/phi/kernels/igammac_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +/** + * @brief This kernel calculate Modified Bessel function of order 0. + * @param ctx device context + * @param x The input tensor of i0 + * @param out The output tensor of i0 kernel, it has the same shape and + * dtype with input. Each element corresponds to input tensor + */ +template +void IgammacKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& a, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/impl/igamma_kernel_impl.h b/paddle/phi/kernels/impl/igamma_kernel_impl.h new file mode 100644 index 00000000000000..8054b09f97ed79 --- /dev/null +++ b/paddle/phi/kernels/impl/igamma_kernel_impl.h @@ -0,0 +1,20 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi {} // namespace phi From 5b95b346028a9e9a881f1930bdab7a47d4d7bb31 Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Fri, 24 Nov 2023 23:24:55 +0800 Subject: [PATCH 02/10] :sparkles: Feature: added partial code for igamma --- paddle/phi/kernels/cpu/igamma_kernel.cc | 12 +- paddle/phi/kernels/impl/igamma_kernel_impl.h | 235 ++++++++++++++++++- 2 files changed, 245 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/cpu/igamma_kernel.cc b/paddle/phi/kernels/cpu/igamma_kernel.cc index ea35a108d74cce..e91938e937457c 100644 --- a/paddle/phi/kernels/cpu/igamma_kernel.cc +++ b/paddle/phi/kernels/cpu/igamma_kernel.cc @@ -24,7 +24,17 @@ template void IgammaKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& a, - DenseTensor* out) {} + DenseTensor* out) { + const int64_t size = x.numel(); + const T* x_data = x.data(); + const T* a_data = a.data(); + T* out_data = ctx.template Alloc(out); + + phi::funcs::ForRange for_range(ctx, size); + IgammaFunctor functor(x_data, a_data, out_data, size); + for_range(functor); + +} } // namespace phi diff --git a/paddle/phi/kernels/impl/igamma_kernel_impl.h b/paddle/phi/kernels/impl/igamma_kernel_impl.h index 8054b09f97ed79..f9d0c36a539575 100644 --- a/paddle/phi/kernels/impl/igamma_kernel_impl.h +++ b/paddle/phi/kernels/impl/igamma_kernel_impl.h @@ -17,4 +17,237 @@ limitations under the License. */ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -namespace phi {} // namespace phi +namespace phi { + +template +static inline T ratevl(T x, const T num[], int64_t M, + const T denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + T y, num_ans, denom_ans; + T absx = ::fabs(x); + const T *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. 
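+       The coefficient arrays are walked in reverse (dir = -1, starting from
+       num[M] and denom[N]) and both polynomials are evaluated at y = 1 / x,
+       which keeps the Horner recurrences well scaled when |x| is large.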
*/ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return ::pow(x, static_cast(i)) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +template +static inline T lanczos_sum_expg_scaled(T x) { + // lanczos approximation + + static const T lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const T lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0 + }; + return ratevl(static_cast(x), lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static inline T _igam_helper_fac(T a, T x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + T ax, fac, res, num, numfac; + static const T MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static const T EXP1 = 2.718281828459045; + static const T lanczos_g = 6.024680040776729583740234375; + + if (::fabs(a - x) > 0.4 * ::fabs(a)) { + ax = a * ::log(x) - x - ::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return ::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = ::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= ::exp(a - x) * ::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static inline T _igam_helper_series(T a, T x) { + // Compute igam using DLMF 8.11.4. [igam1] + static const T MACHEP = std::is_same::value ? 
+ 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static const int MAXITER = 2000; + + int i; + T ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +struct IgammaFunctor { + static const T SMALL = 20.0; + static const T LARGE = 200.0; + static const T SMALLRATIO = 0.3; + static const T LARGERATIO = 4.5; + IgammaFunctor(const T* input, const T* other, T* output, int64_t numel) + : input_(input), other_(other), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + a = other_[idx]; + x = input_[idx]; + + // boundary values following SciPy + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (::isinf(static_cast(a))) { + if (::isinf(static_cast(x))) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (::isinf(static_cast(x))) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. 
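+       Both the power series and the continued fraction can converge slowly
+       when x is close to a, so the uniform asymptotic expansion for large
+       parameter (DLMF 8.12.3 [igam1]) is used instead: for 20 < a < 200 it is
+       selected when |x - a| / a < 0.3, and for a > 200 when
+       |x - a| / a < 4.5 / sqrt(a).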
*/ + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); + } + + private: + const T* input_; + const T* other_; + T* output_; + int64_t numel_; +}; + +} // namespace phi From 0218c4bfba9ea2829ca5d39a0bd0f5a952b69d11 Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Sun, 26 Nov 2023 12:33:29 +0800 Subject: [PATCH 03/10] :sparkles: Feature: added infermeta for igamma and igammac --- paddle/phi/api/yaml/ops.yaml | 6 +- paddle/phi/infermeta/unary.cc | 56 ++ paddle/phi/infermeta/unary.h | 6 + paddle/phi/kernels/cpu/igamma_kernel.cc | 17 +- paddle/phi/kernels/impl/igamma_kernel_impl.h | 603 ++++++++++++++++--- 5 files changed, 579 insertions(+), 109 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 5fef80efbdc7a8..aa0fccf3e39b3c 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1226,7 +1226,8 @@ args : (Tensor x, Tensor a) output : Tensor(out) infer_meta : - func : UnchangedInferMeta + func : IgammaInferMeta + param: [x] kernel : func : igamma inplace: (x -> out) @@ -1235,7 +1236,8 @@ args : (Tensor x, Tensor a) output : Tensor(out) infer_meta : - func : UnchangedInferMeta + func : IgammacInferMeta + param: [x] kernel : func : igammac inplace: (x -> out) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 5bdcb58a51fb60..3c254f4c27868e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5230,6 +5230,62 @@ void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out) { out->set_strides(x.strides()); } +void IgammaInferMeta(const MetaTensor& x, + const MetaTensor& a, + MetaTensor* out) { + out->share_meta(x); + out->set_dtype(x.dtype()); + + auto x_dims = x.dims(); + auto a_dims = a.dims(); + auto x_dims_size = x_dims.size(); + auto a_dims_size = a_dims.size(); + + PADDLE_ENFORCE_EQ( + x_dims_size == a_dims_size || x_dims_size == 1 || a_dims_size == 1, + false, + phi::errors::InvalidArgument( + "The size of tesnor x and tensor a should be same dimension, or one " + "of the two tensors should be 1, but received tensor x (%d) and " + "tensor a (%d).", + x_dims_size, + a_dims_size)); + + if (x_dims_size == 1) { + out->set_dims(a_dims); + } else { + out->set_dims(x_dims); + } +} + +void IgammacInferMeta(const MetaTensor& x, + const MetaTensor& a, + MetaTensor* out) { + out->share_meta(x); + out->set_dtype(x.dtype()); + + auto x_dims = x.dims(); + auto a_dims = a.dims(); + auto x_dims_size = x_dims.size(); + auto a_dims_size = a_dims.size(); + + PADDLE_ENFORCE_EQ( + x_dims_size == a_dims_size || x_dims_size == 1 || a_dims_size == 1, + false, + phi::errors::InvalidArgument( + "The size of tesnor x and tensor a should be same dimension, or one " + "of the two tensors should be 1, but received tensor x (%d) and " + "tensor a (%d).", + x_dims_size, + a_dims_size)); + + if (x_dims_size == 1) { + out->set_dims(a_dims); + } else { + out->set_dims(x_dims); + } +} + } // namespace phi PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index daab02f2b46b14..151c4f136262e2 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -295,6 
+295,12 @@ void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out); void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out); +void IgammaInferMeta(const MetaTensor& x, const MetaTensor& a, MetaTensor* out); + +void IgammacInferMeta(const MetaTensor& x, + const MetaTensor& a, + MetaTensor* out); + void KthvalueInferMeta(const MetaTensor& x, int k, int axis, diff --git a/paddle/phi/kernels/cpu/igamma_kernel.cc b/paddle/phi/kernels/cpu/igamma_kernel.cc index e91938e937457c..284fc4b0f372fb 100644 --- a/paddle/phi/kernels/cpu/igamma_kernel.cc +++ b/paddle/phi/kernels/cpu/igamma_kernel.cc @@ -25,15 +25,14 @@ void IgammaKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& a, DenseTensor* out) { - const int64_t size = x.numel(); - const T* x_data = x.data(); - const T* a_data = a.data(); - T* out_data = ctx.template Alloc(out); - - phi::funcs::ForRange for_range(ctx, size); - IgammaFunctor functor(x_data, a_data, out_data, size); - for_range(functor); - + const int64_t size = x.numel(); + const T* x_data = x.data(); + const T* a_data = a.data(); + T* out_data = ctx.template Alloc(out); + + phi::funcs::ForRange for_range(ctx, size); + IgammaFunctor functor(x_data, a_data, out_data, size); + for_range(functor); } } // namespace phi diff --git a/paddle/phi/kernels/impl/igamma_kernel_impl.h b/paddle/phi/kernels/impl/igamma_kernel_impl.h index f9d0c36a539575..29c2cd5c4fc743 100644 --- a/paddle/phi/kernels/impl/igamma_kernel_impl.h +++ b/paddle/phi/kernels/impl/igamma_kernel_impl.h @@ -20,8 +20,8 @@ limitations under the License. */ namespace phi { template -static inline T ratevl(T x, const T num[], int64_t M, - const T denom[], int64_t N) { +static inline T ratevl( + T x, const T num[], int64_t M, const T denom[], int64_t N) { // evaluating rational function, i.e., the ratio of two polynomials // the coefficients for numerator are given by `num` while coeffs for // denumerator are given by `denom` @@ -29,15 +29,14 @@ static inline T ratevl(T x, const T num[], int64_t M, int64_t i, dir; T y, num_ans, denom_ans; T absx = ::fabs(x); - const T *p; + const T* p; if (absx > 1) { /* Evaluate as a polynomial in 1/x. 
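+       (the reversed Horner evaluation at y = 1 / x is undone by the
+        pow(x, N - M) factor below; for lanczos_sum_expg_scaled, the only
+        caller in this file, M == N, so that factor is 1)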
*/ dir = -1; p = num + M; y = 1 / x; - } - else { + } else { dir = 1; p = num; y = x; @@ -53,8 +52,7 @@ static inline T ratevl(T x, const T num[], int64_t M, /* Evaluate the denominator */ if (absx > 1) { p = denom + N; - } - else { + } else { p = denom; } @@ -67,8 +65,7 @@ static inline T ratevl(T x, const T num[], int64_t M, if (absx > 1) { i = N - M; return ::pow(x, static_cast(i)) * num_ans / denom_ans; - } - else { + } else { return num_ans / denom_ans; } } @@ -78,39 +75,41 @@ static inline T lanczos_sum_expg_scaled(T x) { // lanczos approximation static const T lanczos_sum_expg_scaled_num[13] = { - 0.006061842346248906525783753964555936883222, - 0.5098416655656676188125178644804694509993, - 19.51992788247617482847860966235652136208, - 449.9445569063168119446858607650988409623, - 6955.999602515376140356310115515198987526, - 75999.29304014542649875303443598909137092, - 601859.6171681098786670226533699352302507, - 3481712.15498064590882071018964774556468, - 14605578.08768506808414169982791359218571, - 43338889.32467613834773723740590533316085, - 86363131.28813859145546927288977868422342, - 103794043.1163445451906271053616070238554, - 56906521.91347156388090791033559122686859 - }; - static const T lanczos_sum_expg_scaled_denom[13] = { - 1., - 66., - 1925., - 32670., - 357423., - 2637558., - 13339535., - 45995730., - 105258076., - 150917976., - 120543840., - 39916800., - 0 - }; - return ratevl(static_cast(x), lanczos_sum_expg_scaled_num, - sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, - lanczos_sum_expg_scaled_denom, - sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859}; + static const T lanczos_sum_expg_scaled_denom[13] = {1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0}; + return ratevl(static_cast(x), + lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / + sizeof(lanczos_sum_expg_scaled_num[0]) - + 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / + sizeof(lanczos_sum_expg_scaled_denom[0]) - + 1); } template @@ -120,8 +119,9 @@ static inline T _igam_helper_fac(T a, T x) { // exp(a - x). T ax, fac, res, num, numfac; - static const T MAXLOG = std::is_same::value ? - 7.09782712893383996843E2 : 88.72283905206835; + static const T MAXLOG = std::is_same::value + ? 
7.09782712893383996843E2 + : 88.72283905206835; static const T EXP1 = 2.718281828459045; static const T lanczos_g = 6.024680040776729583740234375; @@ -138,8 +138,7 @@ static inline T _igam_helper_fac(T a, T x) { if ((a < 200) && (x < 200)) { res *= ::exp(a - x) * ::pow(x / fac, a); - } - else { + } else { num = x - a - lanczos_g + 0.5; numfac = num / fac; res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); @@ -150,8 +149,9 @@ static inline T _igam_helper_fac(T a, T x) { template static inline T _igam_helper_series(T a, T x) { // Compute igam using DLMF 8.11.4. [igam1] - static const T MACHEP = std::is_same::value ? - 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static const T MACHEP = std::is_same::value + ? 1.11022302462515654042E-16 + : 5.9604644775390625E-8; static const int MAXITER = 2000; int i; @@ -179,68 +179,475 @@ static inline T _igam_helper_series(T a, T x) { } template -struct IgammaFunctor { - static const T SMALL = 20.0; - static const T LARGE = 200.0; - static const T SMALLRATIO = 0.3; - static const T LARGERATIO = 4.5; - IgammaFunctor(const T* input, const T* other, T* output, int64_t numel) - : input_(input), other_(other), output_(output), numel_(numel) {} +static inline T _igam_helper_asymptotic_series(T a, T x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] - HOSTDEVICE void operator()(int64_t idx) const { - /* the calculation of the regularized lower incomplete gamma function - * is done differently based on the values of a and x: - * - if x and/or a is at the boundary of defined region, then assign the - * result at the boundary - * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for - * Large Parameter (see DLMF 8.12.3 [igam1]) - * - if x > 1 and x > a, using the substraction from the regularized upper - * incomplete gamma - * - otherwise, calculate the series from [igam2] eq (4) - */ - a = other_[idx]; - x = input_[idx]; - - // boundary values following SciPy - if ((x < 0) || (a < 0)) { - // out of defined-region of the function - return std::numeric_limits::quiet_NaN(); + static const T d[25][25] = { + {-3.3333333333333333e-1, 8.3333333333333333e-2, + -1.4814814814814815e-2, 1.1574074074074074e-3, + 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, + -1.85406221071516e-6, 8.296711340953086e-7, + -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, + 9.1476995822367902e-10, -2.551419399494625e-11, + -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, + 3.3717632624009854e-13, -1.3923887224181621e-13, + 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 
3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, + -2.9907248030319018e-4, -1.4638452578843418e-6, + 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, + -1.6954149536558306e-6, 8.9075075322053097e-7, + -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, + 3.4463580499464897e-9, -2.3024517174528067e-13, + -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, + 4.6792750266579195e-12, -2.1492464706134829e-12, + 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, + -3.3493161081142236e-4, 2.812695154763237e-4, + -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, + 5.7876949497350524e-6, 4.9387589339362704e-10, + -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, + 2.695423606288966e-8, -1.4578352908731271e-8, + 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, + -6.9957960920705679e-11, 2.5899863874868481e-17, + 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + 
{-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, + 6.7823088376673284e-4, -6.4014752602627585e-4, + 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, + -2.1073920183404862e-5, -8.8585890141255994e-10, + 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, + -1.5344695190702061e-7, 8.862466778790695e-8, + -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, + 5.7370135528051385e-10, -1.887749850169741e-19, + -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 
7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 
4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static const T MACHEP = std::is_same::value + ? 
1.11022302462515654042E-16 + : 5.9604644775390625E-8; + T lambda = x / a; + T sigma = (x - a) / a; + T eta, res, ck, ckterm, term, absterm; + T absoldterm = INFINITY; + T etapow[25] = {1}; + T sum = 0; + T afac = 1; + + if (igam) { + sgn = -1; + } else { + sgn = 1; } - else if (a == 0) { - if (x > 0) { - return 1.0; + + if (lambda > 1) { + eta = ::sqrt(-2 * (::log1p(sigma) - sigma)); + } else if (lambda < 1) { + eta = -::sqrt(-2 * (::log1p(sigma) - sigma)); + } else { + eta = 0; + } + res = 0.5 * ::erfc(sgn * eta * ::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n - 1]; + maxpow += 1; + } + ckterm = d[k][n] * etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } } - else { - return std::numeric_limits::quiet_NaN(); + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; } - } - else if (x == 0) { - return 0.0; // zero integration limit - } - else if (::isinf(static_cast(a))) { - if (::isinf(static_cast(x))) { - return std::numeric_limits::quiet_NaN(); + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; } - return 0.0; - } - else if (::isinf(static_cast(x))) { - return 1.0; + absoldterm = absterm; + afac /= a; } + res += sgn * ::exp(-0.5 * a * eta * eta) * sum / ::sqrt(2 * 3.1415926535 * a); - /* Asymptotic regime where a ~ x. */ - absxma_a = ::fabs(x - a) / a; - if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { - return _igam_helper_asymptotic_series(a, x, 1); + return res; +} + +template +static T _igamc_helper_continued_fraction(T a, T x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + T ans, ax, c, yc, r, t, y, z; + T pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static T MACHEP = std::is_same::value ? 1.11022302462515654042E-16 + : 5.9604644775390625E-8; + static T BIG = + std::is_same::value ? 4.503599627370496e15 : 16777216.; + static T BIGINV = std::is_same::value ? 
2.22044604925031308085e-16 + : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; } - else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { - return _igam_helper_asymptotic_series(a, x, 1); + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } } + return ans * ax; +} + +template +struct IgammaFunctor { + static const T SMALL = 20.0; + static const T LARGE = 200.0; + static const T SMALLRATIO = 0.3; + static const T LARGERATIO = 4.5; + IgammaFunctor(const T* input, const T* other, T* output, int64_t numel) + : input_(input), other_(other), output_(output), numel_(numel) {} - if ((x > 1.0) && (x > a)) { - return 1.0 - calc_igammac(a, x); + HOSTDEVICE void operator()(int64_t idx) const { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + a = other_[idx]; + x = input_[idx]; + + T absxma_a; } - return _igam_helper_series(a, x); + private: + const T* input_; + const T* other_; + T* output_; + int64_t numel_; +}; + +template +struct IgammacFunctor { + static const T SMALL = 20.0; + static const T LARGE = 200.0; + static const T SMALLRATIO = 0.3; + static const T LARGERATIO = 4.5; + IgammacFunctor(const T* input, const T* other, T* output, int64_t numel) + : input_(input), other_(other), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + a = other_[idx]; + x = input_[idx]; + T absxma_a; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. 
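+    // boundary handling below, following the SciPy conventions noted above:
+    // Q(a, 0) = 1 for a > 0, Q(0, x) = 0 for x > 0, and NaN when either
+    // argument is negative or both limits are degenerate.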
+ if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } else if (a == 0) { + if (x > 0) { + return 0.0; + } else { + return std::numeric_limits::quiet_NaN(); + } + } else if (x == 0) { + return 1.0; + } else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_continued_fraction(a, x); + } + } else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_series(a, x); + } + } else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_series(a, x); + } + } } private: From e5b67043560c28d2e285567355f1b7d00617053d Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Sun, 26 Nov 2023 21:08:03 +0800 Subject: [PATCH 04/10] :bug: Fix: fixed type error in impl file --- paddle/phi/api/yaml/ops.yaml | 2 - paddle/phi/kernels/cpu/igamma_kernel.cc | 5 +- paddle/phi/kernels/cpu/igammac_kernel.cc | 16 ++- paddle/phi/kernels/impl/igamma_kernel_impl.h | 102 +++++++++++++------ 4 files changed, 86 insertions(+), 39 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index aa0fccf3e39b3c..aab850cad3ddcc 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1227,7 +1227,6 @@ output : Tensor(out) infer_meta : func : IgammaInferMeta - param: [x] kernel : func : igamma inplace: (x -> out) @@ -1237,7 +1236,6 @@ output : Tensor(out) infer_meta : func : IgammacInferMeta - param: [x] kernel : func : igammac inplace: (x -> out) diff --git a/paddle/phi/kernels/cpu/igamma_kernel.cc b/paddle/phi/kernels/cpu/igamma_kernel.cc index 284fc4b0f372fb..ecb9a4f25e3d9b 100644 --- a/paddle/phi/kernels/cpu/igamma_kernel.cc +++ b/paddle/phi/kernels/cpu/igamma_kernel.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/impl/igamma_kernel_impl.h" namespace phi { @@ -42,6 +43,4 @@ PD_REGISTER_KERNEL(igamma, ALL_LAYOUT, phi::IgammaKernel, float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + double) {} diff --git a/paddle/phi/kernels/cpu/igammac_kernel.cc b/paddle/phi/kernels/cpu/igammac_kernel.cc index 35911c523facf4..58bb587533543e 100644 --- a/paddle/phi/kernels/cpu/igammac_kernel.cc +++ b/paddle/phi/kernels/cpu/igammac_kernel.cc @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/impl/igamma_kernel_impl.h" namespace phi { @@ -24,7 +25,16 @@ template void IgammacKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& a, - DenseTensor* out) {} + DenseTensor* out) { + const int64_t size = x.numel(); + const T* x_data = x.data(); + const T* a_data = a.data(); + T* out_data = ctx.template Alloc(out); + + phi::funcs::ForRange for_range(ctx, size); + IgammacFunctor functor(x_data, a_data, out_data, size); + for_range(functor); +} } // namespace phi @@ -33,6 +43,4 @@ PD_REGISTER_KERNEL(igammac, ALL_LAYOUT, phi::IgammacKernel, float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + double) {} diff --git a/paddle/phi/kernels/impl/igamma_kernel_impl.h b/paddle/phi/kernels/impl/igamma_kernel_impl.h index 29c2cd5c4fc743..6c439df6721021 100644 --- a/paddle/phi/kernels/impl/igamma_kernel_impl.h +++ b/paddle/phi/kernels/impl/igamma_kernel_impl.h @@ -484,7 +484,34 @@ static inline T _igam_helper_asymptotic_series(T a, T x, bool igam) { } template -static T _igamc_helper_continued_fraction(T a, T x) { +static inline T _igamc_helper_series(T a, T x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + T fac = 1; + T sum = 0; + T term, logx; + static T MAXITER = 2000; + static T MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1+a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +static inline T _igamc_helper_continued_fraction(T a, T x) { // Compute igamc using DLMF 8.9.2. 
[igam1] int i; T ans, ax, c, yc, r, t, y, z; @@ -545,10 +572,6 @@ static T _igamc_helper_continued_fraction(T a, T x) { template struct IgammaFunctor { - static const T SMALL = 20.0; - static const T LARGE = 200.0; - static const T SMALLRATIO = 0.3; - static const T LARGERATIO = 4.5; IgammaFunctor(const T* input, const T* other, T* output, int64_t numel) : input_(input), other_(other), output_(output), numel_(numel) {} @@ -563,10 +586,14 @@ struct IgammaFunctor { * incomplete gamma * - otherwise, calculate the series from [igam2] eq (4) */ - a = other_[idx]; - x = input_[idx]; - - T absxma_a; + // static const T SMALL = 20.0; + // static const T LARGE = 200.0; + // static const T SMALLRATIO = 0.3; + // static const T LARGERATIO = 4.5; + // a = other_[idx]; + // x = input_[idx]; + + // T absxma_a; } private: @@ -578,10 +605,6 @@ struct IgammaFunctor { template struct IgammacFunctor { - static const T SMALL = 20.0; - static const T LARGE = 200.0; - static const T SMALLRATIO = 0.3; - static const T LARGERATIO = 4.5; IgammacFunctor(const T* input, const T* other, T* output, int64_t numel) : input_(input), other_(other), output_(output), numel_(numel) {} @@ -596,56 +619,75 @@ struct IgammacFunctor { * incomplete gamma * - otherwise, calculate the series from [igam2] eq (5) */ - a = other_[idx]; - x = input_[idx]; + static const T SMALL = 20.0; + static const T LARGE = 200.0; + static const T SMALLRATIO = 0.3; + static const T LARGERATIO = 4.5; + T a = other_[idx]; + T x = input_[idx]; T absxma_a; // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. if ((x < 0) || (a < 0)) { // out of defined-region of the function - return std::numeric_limits::quiet_NaN(); + output_[idx] = std::numeric_limits::quiet_NaN(); + return; } else if (a == 0) { if (x > 0) { - return 0.0; + output_[idx] = 0.0; + return; } else { - return std::numeric_limits::quiet_NaN(); + output_[idx] = std::numeric_limits::quiet_NaN(); + return; } } else if (x == 0) { - return 1.0; + output_[idx] = 1.0; + return; } else if (std::isinf(a)) { if (std::isinf(x)) { - return std::numeric_limits::quiet_NaN(); + output_[idx] = std::numeric_limits::quiet_NaN(); + return; } - return 1.0; + output_[idx] = 1.0; + return; } else if (std::isinf(x)) { - return 0.0; + output_[idx] = 0.0; + return; } absxma_a = std::fabs(x - a) / a; if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { - return _igam_helper_asymptotic_series(a, x, 0); + output_[idx] = _igam_helper_asymptotic_series(a, x, 0); + return; } else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { - return _igam_helper_asymptotic_series(a, x, 0); + output_[idx] = _igam_helper_asymptotic_series(a, x, 0); + return; } if (x > 1.1) { if (x < a) { - return 1.0 - _igam_helper_series(a, x); + output_[idx] = 1.0 - _igam_helper_series(a, x); + return; } else { - return _igamc_helper_continued_fraction(a, x); + output_[idx] = _igamc_helper_continued_fraction(a, x); + return; } } else if (x <= 0.5) { if (-0.4 / std::log(x) < a) { - return 1.0 - _igam_helper_series(a, x); + output_[idx] = 1.0 - _igam_helper_series(a, x); + return; } else { - return _igamc_helper_series(a, x); + output_[idx] = _igamc_helper_series(a, x); + return; } } else { if (x * 1.1 < a) { - return 1.0 - _igam_helper_series(a, x); + output_[idx] = 1.0 - _igam_helper_series(a, x); + return; } else { - return _igamc_helper_series(a, x); + output_[idx] = _igamc_helper_series(a, x); + return; } } } From 
8e99783f19fa19f582b98d4d63225a84058b756b Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Tue, 28 Nov 2023 00:14:18 +0800 Subject: [PATCH 05/10] :recycle: Refactor: fixed cpu code for igamma and igammac --- paddle/phi/infermeta/unary.cc | 12 +- paddle/phi/kernels/cpu/igamma_kernel.cc | 7 +- paddle/phi/kernels/cpu/igammac_kernel.cc | 8 +- paddle/phi/kernels/impl/igamma_kernel_impl.h | 199 +++++++++++-------- python/paddle/__init__.py | 4 + python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/manipulation.py | 97 +++++++++ 7 files changed, 231 insertions(+), 100 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 3c254f4c27868e..a6aa6ca579d082 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5243,11 +5243,11 @@ void IgammaInferMeta(const MetaTensor& x, PADDLE_ENFORCE_EQ( x_dims_size == a_dims_size || x_dims_size == 1 || a_dims_size == 1, - false, + true, phi::errors::InvalidArgument( "The size of tesnor x and tensor a should be same dimension, or one " - "of the two tensors should be 1, but received tensor x (%d) and " - "tensor a (%d).", + "of the two tensors should be 1, but received tensor x (%u) and " + "tensor a (%u).", x_dims_size, a_dims_size)); @@ -5271,11 +5271,11 @@ void IgammacInferMeta(const MetaTensor& x, PADDLE_ENFORCE_EQ( x_dims_size == a_dims_size || x_dims_size == 1 || a_dims_size == 1, - false, + true, phi::errors::InvalidArgument( "The size of tesnor x and tensor a should be same dimension, or one " - "of the two tensors should be 1, but received tensor x (%d) and " - "tensor a (%d).", + "of the two tensors should be 1, but received tensor x (%u) and " + "tensor a (%u).", x_dims_size, a_dims_size)); diff --git a/paddle/phi/kernels/cpu/igamma_kernel.cc b/paddle/phi/kernels/cpu/igamma_kernel.cc index ecb9a4f25e3d9b..9e4c4d3c805273 100644 --- a/paddle/phi/kernels/cpu/igamma_kernel.cc +++ b/paddle/phi/kernels/cpu/igamma_kernel.cc @@ -38,9 +38,4 @@ void IgammaKernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL(igamma, - CPU, - ALL_LAYOUT, - phi::IgammaKernel, - float, - double) {} +PD_REGISTER_KERNEL(igamma, CPU, ALL_LAYOUT, phi::IgammaKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/igammac_kernel.cc b/paddle/phi/kernels/cpu/igammac_kernel.cc index 58bb587533543e..41ceb6eb0b1bc0 100644 --- a/paddle/phi/kernels/cpu/igammac_kernel.cc +++ b/paddle/phi/kernels/cpu/igammac_kernel.cc @@ -38,9 +38,5 @@ void IgammacKernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL(igammac, - CPU, - ALL_LAYOUT, - phi::IgammacKernel, - float, - double) {} +PD_REGISTER_KERNEL( + igammac, CPU, ALL_LAYOUT, phi::IgammacKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/igamma_kernel_impl.h b/paddle/phi/kernels/impl/igamma_kernel_impl.h index 6c439df6721021..4ea3ee8a482ff4 100644 --- a/paddle/phi/kernels/impl/igamma_kernel_impl.h +++ b/paddle/phi/kernels/impl/igamma_kernel_impl.h @@ -493,22 +493,22 @@ static inline T _igamc_helper_series(T a, T x) { T sum = 0; T term, logx; static T MAXITER = 2000; - static T MACHEP = std::is_same::value ? - 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static T MACHEP = std::is_same::value ? 
1.11022302462515654042E-16 + : 5.9604644775390625E-8; for (n = 1; n < MAXITER; n++) { fac *= -x / n; term = fac / (a + n); sum += term; if (std::fabs(term) <= MACHEP * std::fabs(sum)) { - break; + break; } } logx = std::log(x); - term = -std::expm1(a * logx - std::lgamma(1+a)); + term = -std::expm1(a * logx - std::lgamma(1 + a)); return term - std::exp(a * logx - std::lgamma(a)) * sum; -} +} template static inline T _igamc_helper_continued_fraction(T a, T x) { @@ -570,6 +570,114 @@ static inline T _igamc_helper_continued_fraction(T a, T x) { return ans * ax; } +template +static inline T compute_igammac(T x, T a) { + static T SMALL = 20.0; + static T LARGE = 200.0; + static T SMALLRATIO = 0.3; + static T LARGERATIO = 4.5; + T absxma_a; + + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } else if (a == 0) { + if (x > 0) { + return 0.0; + } else { + return std::numeric_limits::quiet_NaN(); + } + } else if (x == 0) { + return 1.0; + } else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_continued_fraction(a, x); + } + } else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_series(a, x); + } + } else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_series(a, x); + } + } +} + +template +static inline T compute_igamma(T x, T a) { + static T SMALL = 20.0; + static T LARGE = 200.0; + static T SMALLRATIO = 0.3; + static T LARGERATIO = 4.5; + T absxma_a; + + // boundary values following SciPy + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. 
See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - compute_igammac(x, a); + } + + return _igam_helper_series(a, x); +} + template struct IgammaFunctor { IgammaFunctor(const T* input, const T* other, T* output, int64_t numel) @@ -586,14 +694,9 @@ struct IgammaFunctor { * incomplete gamma * - otherwise, calculate the series from [igam2] eq (4) */ - // static const T SMALL = 20.0; - // static const T LARGE = 200.0; - // static const T SMALLRATIO = 0.3; - // static const T LARGERATIO = 4.5; - // a = other_[idx]; - // x = input_[idx]; - - // T absxma_a; + T a = other_[idx]; + T x = input_[idx]; + output_[idx] = compute_igamma(x, a); } private: @@ -619,77 +722,9 @@ struct IgammacFunctor { * incomplete gamma * - otherwise, calculate the series from [igam2] eq (5) */ - static const T SMALL = 20.0; - static const T LARGE = 200.0; - static const T SMALLRATIO = 0.3; - static const T LARGERATIO = 4.5; T a = other_[idx]; T x = input_[idx]; - T absxma_a; - - // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., - // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. - if ((x < 0) || (a < 0)) { - // out of defined-region of the function - output_[idx] = std::numeric_limits::quiet_NaN(); - return; - } else if (a == 0) { - if (x > 0) { - output_[idx] = 0.0; - return; - } else { - output_[idx] = std::numeric_limits::quiet_NaN(); - return; - } - } else if (x == 0) { - output_[idx] = 1.0; - return; - } else if (std::isinf(a)) { - if (std::isinf(x)) { - output_[idx] = std::numeric_limits::quiet_NaN(); - return; - } - output_[idx] = 1.0; - return; - } else if (std::isinf(x)) { - output_[idx] = 0.0; - return; - } - - absxma_a = std::fabs(x - a) / a; - if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { - output_[idx] = _igam_helper_asymptotic_series(a, x, 0); - return; - } else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { - output_[idx] = _igam_helper_asymptotic_series(a, x, 0); - return; - } - - if (x > 1.1) { - if (x < a) { - output_[idx] = 1.0 - _igam_helper_series(a, x); - return; - } else { - output_[idx] = _igamc_helper_continued_fraction(a, x); - return; - } - } else if (x <= 0.5) { - if (-0.4 / std::log(x) < a) { - output_[idx] = 1.0 - _igam_helper_series(a, x); - return; - } else { - output_[idx] = _igamc_helper_series(a, x); - return; - } - } else { - if (x * 1.1 < a) { - output_[idx] = 1.0 - _igam_helper_series(a, x); - return; - } else { - output_[idx] = _igamc_helper_series(a, x); - return; - } - } + output_[idx] = compute_igammac(x, a); } private: diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5e87e9514c0e98..c55fe11e5b680b 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -261,6 +261,8 @@ index_fill, index_fill_, diagonal_scatter, + igamma, + igammac, ) from .tensor.math import ( # noqa: F401 @@ -931,4 +933,6 @@ 'index_fill', "index_fill_", 'diagonal_scatter', + 'igamma', + 'igammac', ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b96045d35faf69..b404d6993b5ed8 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -174,6 +174,8 @@ from .manipulation import index_fill # noqa: F401 from .manipulation import index_fill_ # noqa: F401 from 
.manipulation import diagonal_scatter # noqa: F401 +from .manipulation import igamma # noqa: F401 +from .manipulation import igammac # noqa: F401 from .math import abs # noqa: F401 from .math import abs_ # noqa: F401 from .math import acos # noqa: F401 @@ -743,6 +745,8 @@ 'atleast_2d', 'atleast_3d', 'diagonal_scatter', + 'igamma', + 'igammac', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 4afff9d3a9ad75..bdaf71a8e2ebd8 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5896,3 +5896,100 @@ def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None): """ return fill_diagonal_tensor(x, y, offset, axis1, axis2, name) + + +def igamma(x, other, name=None): + r""" + Computes the regularized lower incomplete gamma function. + The equation is: + + .. math:: + \mathrm{igamma}(x, a)=\frac{1}{\Gamma(x)}\int_0^at^{x-1}e^{-t}dt + + + Args: + x (Tensor): ``x`` is the original Tensor. Must be at least 2-dimensional. + y (Tensor): ``y`` is the Tensor to embed into ``x`` + offset (int, optional): which diagonal to consider. Default: 0 (main diagonal). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, Tensor with diagonal embedeed with ``y``. + + Examples: + .. code-block:: python + + >>> import paddle + >>> x = paddle.arange(6.0).reshape((2, 3)) + >>> a = paddle.ones((2,)) + >>> out = x.diagonal_scatter(y) + >>> print(out) + Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[1., 1., 2.], + [3., 1., 5.]]) + """ + if in_dynamic_or_pir_mode(): + return _C_ops.igamma(x, other) + else: + check_variable_and_dtype(x, "x", ["float32", "float64"], "igamma") + check_variable_and_dtype( + other, "other", ["float32", "float64"], "igamma" + ) + + helper = LayerHelper("igamma", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='igamma', + inputs={'x': x, 'a': other}, + outputs={'out': out}, + attrs={}, + ) + return out + + +def igammac(x, other, name=None): + r""" + Computes the regularized upper incomplete gamma function. + The equation is: + + .. math:: + \mathrm{igammac}(x, a)=\frac{1}{\Gamma(x)}\int_a^{\infty}t^{x-1}e^{-t}dt + + Args: + x (Tensor): ``x`` is the original Tensor. Must be at least 2-dimensional. + y (Tensor): ``y`` is the Tensor to embed into ``x`` + offset (int, optional): which diagonal to consider. Default: 0 (main diagonal). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, Tensor with diagonal embedeed with ``y``. + + Examples: + .. 
code-block:: python + + >>> import paddle + >>> x = paddle.arange(6.0).reshape((2, 3)) + >>> y = paddle.ones((2,)) + >>> out = x.diagonal_scatter(y) + >>> print(out) + Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[1., 1., 2.], + [3., 1., 5.]]) + """ + if in_dynamic_or_pir_mode(): + return _C_ops.igammac(x, other) + else: + check_variable_and_dtype(x, "x", ["float32", "float64"], "igammac") + check_variable_and_dtype( + other, "other", ["float32", "float64"], "igammac" + ) + + helper = LayerHelper("igammac", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='igammac', + inputs={'x': x, 'a': other}, + outputs={'out': out}, + attrs={}, + ) + return out From ba88f1f5d907971ffabd6878a4bc9ae9118e8446 Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Wed, 29 Nov 2023 00:30:16 +0800 Subject: [PATCH 06/10] Feature: added gpu code --- paddle/phi/kernels/gpu/igamma_kernel.cu | 7 +- paddle/phi/kernels/gpu/igammac_kernel.cu | 10 +- paddle/phi/kernels/impl/igamma_kernel_impl.h | 687 +++++++++++++++++++ 3 files changed, 702 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/igamma_kernel.cu b/paddle/phi/kernels/gpu/igamma_kernel.cu index 5746500b1977bf..d801490bc1a7ef 100644 --- a/paddle/phi/kernels/gpu/igamma_kernel.cu +++ b/paddle/phi/kernels/gpu/igamma_kernel.cu @@ -17,12 +17,17 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/impl/igamma_kernel_impl.h" namespace phi { template void IgammaKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& a, DenseTensor* out) { - + ctx.template Alloc(out); + std::vector ins = {&x, &a}; + std::vector outs = {out}; + auto functor = CudaIgammaFunctor(); + phi::funcs::ElementwiseKernel(ctx, ins, &outs, functor); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/igammac_kernel.cu b/paddle/phi/kernels/gpu/igammac_kernel.cu index 8fe029b716f307..74147aa5f8c592 100644 --- a/paddle/phi/kernels/gpu/igammac_kernel.cu +++ b/paddle/phi/kernels/gpu/igammac_kernel.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/impl/igamma_kernel_impl.h" namespace phi { @@ -24,7 +25,14 @@ template void IgammacKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& a, - DenseTensor* out) {} + DenseTensor* out) { +ctx.template Alloc(out); + std::vector ins = {&x, &a}; + std::vector outs = {out}; + auto functor = CudaIgammacFunctor(); + phi::funcs::ElementwiseKernel(ctx, ins, &outs, functor); + + } } // namespace phi diff --git a/paddle/phi/kernels/impl/igamma_kernel_impl.h b/paddle/phi/kernels/impl/igamma_kernel_impl.h index 4ea3ee8a482ff4..2663c56c9d048b 100644 --- a/paddle/phi/kernels/impl/igamma_kernel_impl.h +++ b/paddle/phi/kernels/impl/igamma_kernel_impl.h @@ -15,10 +15,697 @@ limitations under the License. 
*/ #pragma once #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#else +#include "paddle/phi/kernels/funcs/for_range.h" +#endif namespace phi { +#if defined(__NVCC__) || defined(__HIPCC__) +template +__host__ __device__ T ratevl( + T x, const T num[], int64_t M, const T denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + T y, num_ans, denom_ans; + T absx = ::fabs(x); + const T* p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return ::pow(x, static_cast(i)) * num_ans / denom_ans; + } else { + return num_ans / denom_ans; + } +} + +template +__host__ __device__ T lanczos_sum_expg_scaled(T x) { + // lanczos approximation + + static const T lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859}; + static const T lanczos_sum_expg_scaled_denom[13] = {1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0}; + return ratevl(static_cast(x), + lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / + sizeof(lanczos_sum_expg_scaled_num[0]) - + 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / + sizeof(lanczos_sum_expg_scaled_denom[0]) - + 1); +} + +template +__host__ __device__ T _igam_helper_fac(T a, T x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + T ax, fac, res, num, numfac; + static const T MAXLOG = std::is_same::value + ? 
7.09782712893383996843E2 + : 88.72283905206835; + static const T EXP1 = 2.718281828459045; + static const T lanczos_g = 6.024680040776729583740234375; + + if (::fabs(a - x) > 0.4 * ::fabs(a)) { + ax = a * ::log(x) - x - ::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return ::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = ::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= ::exp(a - x) * ::pow(x / fac, a); + } else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +__host__ __device__ T _igam_helper_series(T a, T x) { + // Compute igam using DLMF 8.11.4. [igam1] + static const T MACHEP = std::is_same::value + ? 1.11022302462515654042E-16 + : 5.9604644775390625E-8; + static const int MAXITER = 2000; + + int i; + T ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +__host__ __device__ T _igam_helper_asymptotic_series(T a, T x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + + static const T d[25][25] = { + {-3.3333333333333333e-1, 8.3333333333333333e-2, + -1.4814814814814815e-2, 1.1574074074074074e-3, + 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, + -1.85406221071516e-6, 8.296711340953086e-7, + -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, + 9.1476995822367902e-10, -2.551419399494625e-11, + -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, + 3.3717632624009854e-13, -1.3923887224181621e-13, + 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 
2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, + -2.9907248030319018e-4, -1.4638452578843418e-6, + 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, + -1.6954149536558306e-6, 8.9075075322053097e-7, + -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, + 3.4463580499464897e-9, -2.3024517174528067e-13, + -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, + 4.6792750266579195e-12, -2.1492464706134829e-12, + 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, + -3.3493161081142236e-4, 2.812695154763237e-4, + -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, + 5.7876949497350524e-6, 4.9387589339362704e-10, + -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, + 2.695423606288966e-8, -1.4578352908731271e-8, + 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, + -6.9957960920705679e-11, 2.5899863874868481e-17, + 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, + 6.7823088376673284e-4, -6.4014752602627585e-4, + 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 
6.105192082501531e-5, + -2.1073920183404862e-5, -8.8585890141255994e-10, + 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, + -1.5344695190702061e-7, 8.862466778790695e-8, + -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, + 5.7370135528051385e-10, -1.887749850169741e-19, + -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, 
-4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + 
{-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static const T MACHEP = std::is_same::value + ? 
1.11022302462515654042E-16 + : 5.9604644775390625E-8; + T lambda = x / a; + T sigma = (x - a) / a; + T eta, res, ck, ckterm, term, absterm; + T absoldterm = INFINITY; + T etapow[25] = {1}; + T sum = 0; + T afac = 1; + + if (igam) { + sgn = -1; + } else { + sgn = 1; + } + + if (lambda > 1) { + eta = ::sqrt(-2 * (::log1p(sigma) - sigma)); + } else if (lambda < 1) { + eta = -::sqrt(-2 * (::log1p(sigma) - sigma)); + } else { + eta = 0; + } + res = 0.5 * ::erfc(sgn * eta * ::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n - 1]; + maxpow += 1; + } + ckterm = d[k][n] * etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * ::exp(-0.5 * a * eta * eta) * sum / ::sqrt(2 * 3.1415926535 * a); + + return res; +} + +template +__host__ __device__ T _igamc_helper_series(T a, T x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + T fac = 1; + T sum = 0; + T term, logx; + static T MAXITER = 2000; + static T MACHEP = std::is_same::value ? 1.11022302462515654042E-16 + : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1 + a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +__host__ __device__ T _igamc_helper_continued_fraction(T a, T x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + T ans, ax, c, yc, r, t, y, z; + T pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static T MACHEP = std::is_same::value ? 1.11022302462515654042E-16 + : 5.9604644775390625E-8; + static T BIG = + std::is_same::value ? 4.503599627370496e15 : 16777216.; + static T BIGINV = std::is_same::value ? 
2.22044604925031308085e-16 + : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +__host__ __device__ T compute_igammac(T x, T a) { + static T SMALL = 20.0; + static T LARGE = 200.0; + static T SMALLRATIO = 0.3; + static T LARGERATIO = 4.5; + T absxma_a; + + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } else if (a == 0) { + if (x > 0) { + return 0.0; + } else { + return std::numeric_limits::quiet_NaN(); + } + } else if (x == 0) { + return 1.0; + } else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_continued_fraction(a, x); + } + } else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_series(a, x); + } + } else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_series(a, x); + } + } +} + +template +__host__ __device__ T compute_igamma(T x, T a) { + static T SMALL = 20.0; + static T LARGE = 200.0; + static T SMALLRATIO = 0.3; + static T LARGERATIO = 4.5; + T absxma_a; + + // boundary values following SciPy + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. 
See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - compute_igammac(x, a); + } + + return _igam_helper_series(a, x); +} + +template +struct CudaIgammaFunctor { + __device__ __forceinline__ T operator()(const T _x, const T _a) const { + using MT = typename phi::dtype::MPTypeTrait::Type; + const MT mp_x = static_cast(_x); + const MT mp_a = static_cast(_a); + return compute_igamma(mp_x, mp_a); + } +}; + +template +struct CudaIgammacFunctor { + __device__ __forceinline__ T operator()(const T _x, const T _a) const { + using MT = typename phi::dtype::MPTypeTrait::Type; + const MT mp_x = static_cast(_x); + const MT mp_a = static_cast(_a); + return compute_igammac(mp_x, mp_a); zeta(static_cast(_n + 1), mp_x)); + } +}; + +#else template static inline T ratevl( T x, const T num[], int64_t M, const T denom[], int64_t N) { From 34724ccce93efacbc54caf1d285b23693ca8a1d2 Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Wed, 29 Nov 2023 00:39:06 +0800 Subject: [PATCH 07/10] :recycle: Refactor: fixed gpu bugs --- paddle/phi/kernels/gpu/igammac_kernel.cu | 5 +- paddle/phi/kernels/impl/igamma_kernel_impl.h | 48 ++++++++------------ python/paddle/tensor/manipulation.py | 6 +-- 3 files changed, 23 insertions(+), 36 deletions(-) diff --git a/paddle/phi/kernels/gpu/igammac_kernel.cu b/paddle/phi/kernels/gpu/igammac_kernel.cu index 74147aa5f8c592..fa0d19581c7aac 100644 --- a/paddle/phi/kernels/gpu/igammac_kernel.cu +++ b/paddle/phi/kernels/gpu/igammac_kernel.cu @@ -26,13 +26,12 @@ void IgammacKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& a, DenseTensor* out) { -ctx.template Alloc(out); + ctx.template Alloc(out); std::vector ins = {&x, &a}; std::vector outs = {out}; auto functor = CudaIgammacFunctor(); phi::funcs::ElementwiseKernel(ctx, ins, &outs, functor); - - } +} } // namespace phi diff --git a/paddle/phi/kernels/impl/igamma_kernel_impl.h b/paddle/phi/kernels/impl/igamma_kernel_impl.h index 2663c56c9d048b..b5400738742e1c 100644 --- a/paddle/phi/kernels/impl/igamma_kernel_impl.h +++ b/paddle/phi/kernels/impl/igamma_kernel_impl.h @@ -27,8 +27,8 @@ namespace phi { #if defined(__NVCC__) || defined(__HIPCC__) template -__host__ __device__ T ratevl( - T x, const T num[], int64_t M, const T denom[], int64_t N) { +__host__ __device__ T +ratevl(T x, const T num[], int64_t M, const T denom[], int64_t N) { // evaluating rational function, i.e., the ratio of two polynomials // the coefficients for numerator are given by `num` while coeffs for // denumerator are given by `denom` @@ -647,25 +647,20 @@ __host__ __device__ T compute_igamma(T x, T a) { if ((x < 0) || (a < 0)) { // out of defined-region of the function return std::numeric_limits::quiet_NaN(); - } - else if (a == 0) { + } else if (a == 0) { if (x > 0) { return 1.0; - } - else { + } else { return std::numeric_limits::quiet_NaN(); } - } - else if (x == 0) { - return 0.0; // zero integration limit - } - else if (std::isinf(a)) { + } else if (x == 0) { + return 0.0; // zero integration limit + } else if (std::isinf(a)) { if (std::isinf(x)) { return std::numeric_limits::quiet_NaN(); } return 0.0; - } - else if (std::isinf(x)) { + } else if (std::isinf(x)) { return 1.0; } @@ -673,8 +668,7 @@ __host__ __device__ T compute_igamma(T x, T a) { absxma_a 
= std::fabs(x - a) / a; if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { return _igam_helper_asymptotic_series(a, x, 1); - } - else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + } else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { return _igam_helper_asymptotic_series(a, x, 1); } @@ -701,10 +695,9 @@ struct CudaIgammacFunctor { using MT = typename phi::dtype::MPTypeTrait::Type; const MT mp_x = static_cast(_x); const MT mp_a = static_cast(_a); - return compute_igammac(mp_x, mp_a); zeta(static_cast(_n + 1), mp_x)); + return compute_igammac(mp_x, mp_a); } }; - #else template static inline T ratevl( @@ -1327,25 +1320,20 @@ static inline T compute_igamma(T x, T a) { if ((x < 0) || (a < 0)) { // out of defined-region of the function return std::numeric_limits::quiet_NaN(); - } - else if (a == 0) { + } else if (a == 0) { if (x > 0) { return 1.0; - } - else { + } else { return std::numeric_limits::quiet_NaN(); } - } - else if (x == 0) { - return 0.0; // zero integration limit - } - else if (std::isinf(a)) { + } else if (x == 0) { + return 0.0; // zero integration limit + } else if (std::isinf(a)) { if (std::isinf(x)) { return std::numeric_limits::quiet_NaN(); } return 0.0; - } - else if (std::isinf(x)) { + } else if (std::isinf(x)) { return 1.0; } @@ -1353,8 +1341,7 @@ static inline T compute_igamma(T x, T a) { absxma_a = std::fabs(x - a) / a; if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { return _igam_helper_asymptotic_series(a, x, 1); - } - else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + } else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { return _igam_helper_asymptotic_series(a, x, 1); } @@ -1420,5 +1407,6 @@ struct IgammacFunctor { T* output_; int64_t numel_; }; +#endif } // namespace phi diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index bdaf71a8e2ebd8..fbdf570decbf29 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5905,8 +5905,8 @@ def igamma(x, other, name=None): .. math:: \mathrm{igamma}(x, a)=\frac{1}{\Gamma(x)}\int_0^at^{x-1}e^{-t}dt - - + + Args: x (Tensor): ``x`` is the original Tensor. Must be at least 2-dimensional. y (Tensor): ``y`` is the Tensor to embed into ``x`` @@ -5954,7 +5954,7 @@ def igammac(x, other, name=None): .. math:: \mathrm{igammac}(x, a)=\frac{1}{\Gamma(x)}\int_a^{\infty}t^{x-1}e^{-t}dt - + Args: x (Tensor): ``x`` is the original Tensor. Must be at least 2-dimensional. 
y (Tensor): ``y`` is the Tensor to embed into ``x`` From 35a85cdad1b53491cf21ea256c64e63eb950b9ca Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Wed, 29 Nov 2023 01:29:39 +0800 Subject: [PATCH 08/10] :thread: Refactor: fixed gpu bugs --- paddle/phi/infermeta/unary.cc | 28 +++++++--------------- python/paddle/tensor/manipulation.py | 36 +++++++++++++--------------- 2 files changed, 24 insertions(+), 40 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a6aa6ca579d082..5dd64bdc5c97de 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5242,20 +5242,14 @@ void IgammaInferMeta(const MetaTensor& x, auto a_dims_size = a_dims.size(); PADDLE_ENFORCE_EQ( - x_dims_size == a_dims_size || x_dims_size == 1 || a_dims_size == 1, + x_dims_size == a_dims_size, true, phi::errors::InvalidArgument( - "The size of tesnor x and tensor a should be same dimension, or one " - "of the two tensors should be 1, but received tensor x (%u) and " - "tensor a (%u).", + "The size of tesnor x and tensor a should be same dimension," + "but received tensor x (%u) and tensor a (%u).", x_dims_size, a_dims_size)); - - if (x_dims_size == 1) { - out->set_dims(a_dims); - } else { - out->set_dims(x_dims); - } + out->set_dims(x.dims()); } void IgammacInferMeta(const MetaTensor& x, @@ -5270,20 +5264,14 @@ void IgammacInferMeta(const MetaTensor& x, auto a_dims_size = a_dims.size(); PADDLE_ENFORCE_EQ( - x_dims_size == a_dims_size || x_dims_size == 1 || a_dims_size == 1, + x_dims_size == a_dims_size, true, phi::errors::InvalidArgument( - "The size of tesnor x and tensor a should be same dimension, or one " - "of the two tensors should be 1, but received tensor x (%u) and " - "tensor a (%u).", + "The size of tesnor x and tensor a should be same dimension," + "but received tensor x (%u) and tensor a (%u).", x_dims_size, a_dims_size)); - - if (x_dims_size == 1) { - out->set_dims(a_dims); - } else { - out->set_dims(x_dims); - } + out->set_dims(x.dims()); } } // namespace phi diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index fbdf570decbf29..430b5d17f96b65 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5908,25 +5908,23 @@ def igamma(x, other, name=None): Args: - x (Tensor): ``x`` is the original Tensor. Must be at least 2-dimensional. - y (Tensor): ``y`` is the Tensor to embed into ``x`` - offset (int, optional): which diagonal to consider. Default: 0 (main diagonal). + x (Tensor): ``x`` is the non-negative input tensor. + other (Tensor): ``other`` is the Tensor to the upper boundary of Integral. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, Tensor with diagonal embedeed with ``y``. + - out (Tensor), A Tensor. the igamma of the input Tensor, the shape and data type is the same with input. Examples: .. 
code-block:: python >>> import paddle - >>> x = paddle.arange(6.0).reshape((2, 3)) - >>> a = paddle.ones((2,)) - >>> out = x.diagonal_scatter(y) + >>> x = paddle.to_tensor([4.0, 2.5, 1.0]) + >>> other = paddle.to_tensor([4.0, 3.0, 2.5]) + >>> out = paddle.igamma(x, other) >>> print(out) - Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - [[1., 1., 2.], - [3., 1., 5.]]) + Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [0.56652987, 0.45618692, 0.15085495]) """ if in_dynamic_or_pir_mode(): return _C_ops.igamma(x, other) @@ -5956,25 +5954,23 @@ def igammac(x, other, name=None): \mathrm{igammac}(x, a)=\frac{1}{\Gamma(x)}\int_a^{\infty}t^{x-1}e^{-t}dt Args: - x (Tensor): ``x`` is the original Tensor. Must be at least 2-dimensional. - y (Tensor): ``y`` is the Tensor to embed into ``x`` - offset (int, optional): which diagonal to consider. Default: 0 (main diagonal). + x (Tensor): ``x`` is the non-negative input tensor. + other (Tensor): ``other`` is the Tensor to the upper boundary of Integral. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, Tensor with diagonal embedeed with ``y``. + - out (Tensor), A Tensor. the igamma of the input Tensor, the shape and data type is the same with input. Examples: .. code-block:: python >>> import paddle - >>> x = paddle.arange(6.0).reshape((2, 3)) - >>> y = paddle.ones((2,)) - >>> out = x.diagonal_scatter(y) + >>> x = paddle.to_tensor([4.0, 2.5, 1.0]) + >>> other = paddle.to_tensor([4.0, 3.0, 2.5]) + >>> out = paddle.igammac(x, other) >>> print(out) - Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - [[1., 1., 2.], - [3., 1., 5.]]) + Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [0.43347010, 0.54381311, 0.84914505]) """ if in_dynamic_or_pir_mode(): return _C_ops.igammac(x, other) From f539a00c37e1ac183e9f3a4b028866823a8edf95 Mon Sep 17 00:00:00 2001 From: PommesPeter <434596665@qq.com> Date: Wed, 29 Nov 2023 17:08:03 +0800 Subject: [PATCH 09/10] :white_check_mark: Test: added igamma test case --- test/legacy_test/test_igamma_op.py | 207 ++++++++++++++++++++++++++++ test/legacy_test/test_igammac_op.py | 207 ++++++++++++++++++++++++++++ 2 files changed, 414 insertions(+) create mode 100644 test/legacy_test/test_igamma_op.py create mode 100644 test/legacy_test/test_igammac_op.py diff --git a/test/legacy_test/test_igamma_op.py b/test/legacy_test/test_igamma_op.py new file mode 100644 index 00000000000000..28f09fa0ae77f6 --- /dev/null +++ b/test/legacy_test/test_igamma_op.py @@ -0,0 +1,207 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from op_test import OpTest +from scipy import special + +import paddle +from paddle.base import core +from paddle.pir_utils import test_with_pir_api + +np.random.seed(100) +paddle.seed(100) + + +def ref_igamma(x, a): + """ + The case where x = 0 differs from + the current mainstream implementation, + and requires specifying a special value point. + """ + pass + + +def ref_igamma_grad(x, dout): + """ + The case where x = 0 differs from + the current mainstream implementation, + and requires specifying a special value point. + """ + pass + + +class TestIgammaAPI(unittest.TestCase): + DTYPE = "float64" + DATA = [0, 1, 2, 3, 4, 5] + ORDER = 1 + + def setUp(self): + self.x = np.array(self.DATA).astype(self.DTYPE) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + @test_with_pir_api + def test_api_static(self): + def run(place): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data( + name="x", shape=self.x.shape, dtype=self.DTYPE + ) + y = paddle.polygamma(x, self.ORDER) + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={"x": self.x}, + fetch_list=[y], + ) + out_ref = ref_igamma(self.x, self.ORDER) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-5) + paddle.disable_static() + + for place in self.place: + run(place) + + def test_api_dygraph(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + out = paddle.polygamma(x, self.ORDER) + + out_ref = ref_igamma(self.x, self.ORDER) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_empty_input_error(self): + for place in self.place: + paddle.disable_static(place) + x = None + self.assertRaises(ValueError, paddle.polygamma, x, self.ORDER) + paddle.enable_static() + + def test_input_type_error(self): + for place in self.place: + paddle.disable_static(place) + self.assertRaises( + TypeError, paddle.polygamma, self.x, float(self.ORDER) + ) + paddle.enable_static() + + def test_negative_order_error(self): + for place in self.place: + paddle.disable_static(place) + self.assertRaises(ValueError, paddle.polygamma, self.x, -self.ORDER) + paddle.enable_static() + + +class TestIgammaFloat32Order1(TestIgammaAPI): + DTYPE = "float32" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 1 + + +class TestIgammaFloat32Order2(TestIgammaAPI): + DTYPE = "float32" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 2 + + +class TestIgammaFloat32Order3(TestIgammaAPI): + DTYPE = "float32" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 3 + + +class TestIgammaFloat64Order1(TestIgammaAPI): + DTYPE = "float64" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 1 + + +class TestIgammaFloat64Order2(TestIgammaAPI): + DTYPE = "float64" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 2 + + +class TestIgammaFloat64Order3(TestIgammaAPI): + DTYPE = "float64" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 3 + + +class TestIgammaNegativeInputOrder1(TestIgammaAPI): + DTYPE = "float64" + DATA = [-2, 3, 5, 2.25, 7, 7.25] + ORDER = 1 + + +class TestIgammaMultiDimOrder1(TestIgammaAPI): + DTYPE = "float64" + DATA = [[-2, 3, 5, 2.25, 7, 7.25], [0, 1, 2, 3, 4, 5]] + ORDER = 1 + + +class TestIgammaMultiDimOrder2(TestIgammaAPI): + DTYPE = "float64" + DATA = [ + [[-2, 3, 5, 2.25, 7, 7.25], [0, 1, 2, 3, 4, 5]], + [[6, 7, 8, 9, 1, 2], [0, 1, 2, 3, 4, 5]], + ] + ORDER = 2 + + +class 
TestIgammaOp(OpTest): + def setUp(self) -> None: + self.op_type = "igamma" + self.python_api = paddle.igamma + self.init_config() + self.outputs = {"out": self.target} + + def init_config(self): + self.dtype = np.float64 + rand_case = np.random.randn(100).astype(self.dtype) + int_case = np.random.randint(low=1, high=100, size=100).astype( + self.dtype + ) + self.other_case = int_case = np.random.randint(low=1, high=100, size=100).astype( + self.dtype + ) + self.case = np.concatenate([rand_case, int_case]) + self.inputs = {'x': self.case, 'a': self.other_case} + self.target = ref_igamma(self.inputs['x'], self.inputs['a']) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + self.check_grad( + ['x'], + 'out', + user_defined_grads=[ + ref_igamma_grad(self.case, 1 / self.case.size, self.order) + ], + check_pir=True, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_igammac_op.py b/test/legacy_test/test_igammac_op.py new file mode 100644 index 00000000000000..606e9b073838aa --- /dev/null +++ b/test/legacy_test/test_igammac_op.py @@ -0,0 +1,207 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest +from scipy import special + +import paddle +from paddle.base import core +from paddle.pir_utils import test_with_pir_api + +np.random.seed(100) +paddle.seed(100) + + +def ref_igammac(x, a): + """ + The case where x = 0 differs from + the current mainstream implementation, + and requires specifying a special value point. + """ + pass + + +def ref_igammac_grad(x, dout): + """ + The case where x = 0 differs from + the current mainstream implementation, + and requires specifying a special value point. 
+    """
+    pass
+
+
+class TestIgammacAPI(unittest.TestCase):
+    DTYPE = "float64"
+    DATA = [0, 1, 2, 3, 4, 5]
+    ORDER = 1
+
+    def setUp(self):
+        self.x = np.array(self.DATA).astype(self.DTYPE)
+        self.place = [paddle.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            self.place.append(paddle.CUDAPlace(0))
+
+    @test_with_pir_api
+    def test_api_static(self):
+        def run(place):
+            paddle.enable_static()
+            with paddle.static.program_guard(paddle.static.Program()):
+                x = paddle.static.data(
+                    name="x", shape=self.x.shape, dtype=self.DTYPE
+                )
+                y = paddle.igammac(x, paddle.full_like(x, self.ORDER))
+                exe = paddle.static.Executor(place)
+                res = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"x": self.x},
+                    fetch_list=[y],
+                )
+            out_ref = ref_igammac(self.x, self.ORDER)
+            np.testing.assert_allclose(out_ref, res[0], rtol=1e-5)
+            paddle.disable_static()
+
+        for place in self.place:
+            run(place)
+
+    def test_api_dygraph(self):
+        def run(place):
+            paddle.disable_static(place)
+            x = paddle.to_tensor(self.x)
+            out = paddle.igammac(x, paddle.full_like(x, self.ORDER))
+
+            out_ref = ref_igammac(self.x, self.ORDER)
+            np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+    def test_empty_input_error(self):
+        for place in self.place:
+            paddle.disable_static(place)
+            x = None
+            self.assertRaises(ValueError, paddle.igammac, x, self.ORDER)
+            paddle.enable_static()
+
+    def test_input_type_error(self):
+        for place in self.place:
+            paddle.disable_static(place)
+            self.assertRaises(
+                TypeError, paddle.igammac, self.x, float(self.ORDER)
+            )
+            paddle.enable_static()
+
+    def test_negative_order_error(self):
+        for place in self.place:
+            paddle.disable_static(place)
+            self.assertRaises(ValueError, paddle.igammac, self.x, -self.ORDER)
+            paddle.enable_static()
+
+
+class TestIgammacFloat32Order1(TestIgammacAPI):
+    DTYPE = "float32"
+    DATA = [2, 3, 5, 2.25, 7, 7.25]
+    ORDER = 1
+
+
+class TestIgammacFloat32Order2(TestIgammacAPI):
+    DTYPE = "float32"
+    DATA = [2, 3, 5, 2.25, 7, 7.25]
+    ORDER = 2
+
+
+class TestIgammacFloat32Order3(TestIgammacAPI):
+    DTYPE = "float32"
+    DATA = [2, 3, 5, 2.25, 7, 7.25]
+    ORDER = 3
+
+
+class TestIgammacFloat64Order1(TestIgammacAPI):
+    DTYPE = "float64"
+    DATA = [2, 3, 5, 2.25, 7, 7.25]
+    ORDER = 1
+
+
+class TestIgammacFloat64Order2(TestIgammacAPI):
+    DTYPE = "float64"
+    DATA = [2, 3, 5, 2.25, 7, 7.25]
+    ORDER = 2
+
+
+class TestIgammacFloat64Order3(TestIgammacAPI):
+    DTYPE = "float64"
+    DATA = [2, 3, 5, 2.25, 7, 7.25]
+    ORDER = 3
+
+
+class TestIgammacNegativeInputOrder1(TestIgammacAPI):
+    DTYPE = "float64"
+    DATA = [-2, 3, 5, 2.25, 7, 7.25]
+    ORDER = 1
+
+
+class TestIgammacMultiDimOrder1(TestIgammacAPI):
+    DTYPE = "float64"
+    DATA = [[-2, 3, 5, 2.25, 7, 7.25], [0, 1, 2, 3, 4, 5]]
+    ORDER = 1
+
+
+class TestIgammacMultiDimOrder2(TestIgammacAPI):
+    DTYPE = "float64"
+    DATA = [
+        [[-2, 3, 5, 2.25, 7, 7.25], [0, 1, 2, 3, 4, 5]],
+        [[6, 7, 8, 9, 1, 2], [0, 1, 2, 3, 4, 5]],
+    ]
+    ORDER = 2
+
+
+class TestIgammacOp(OpTest):
+    def setUp(self) -> None:
+        self.op_type = "igammac"
+        self.python_api = paddle.igammac
+        self.init_config()
+        self.outputs = {"out": self.target}
+
+    def init_config(self):
+        self.dtype = np.float64
+        rand_case = np.random.randn(100).astype(self.dtype)
+        int_case = np.random.randint(low=1, high=100, size=100).astype(
+            self.dtype
+        )
+        self.case = np.concatenate([rand_case, int_case])
+        self.other_case = np.random.randint(
+            low=1, high=100, size=self.case.size
+        ).astype(self.dtype)
+        self.inputs = {'x': self.case, 'a': self.other_case}
+        self.target = ref_igammac(self.inputs['x'], self.inputs['a'])
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['x'],
+            'out',
+            user_defined_grads=[
+                ref_igammac_grad(self.case, 1 / self.case.size)
+            ],
+            check_pir=True,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()

From 54b00e59c64efb47c4ed3bc4d0653e5abcf63874 Mon Sep 17 00:00:00 2001
From: PommesPeter <434596665@qq.com>
Date: Wed, 29 Nov 2023 19:33:50 +0800
Subject: [PATCH 10/10] :art: Updated code style

---
 python/paddle/tensor/__init__.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 9e972d842989a5..e8da87f26e3c71 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -194,6 +194,8 @@
     view,
     view_as,
     vsplit,
+    igamma,
+    igammac,
 )
 from .math import (  # noqa: F401
     abs,
@@ -364,7 +366,6 @@
     trunc,
     trunc_,
     vander,
-    signbit,
 )
 from .random import (  # noqa: F401
     exponential_,
@@ -760,7 +761,8 @@
     'atleast_2d',
     'atleast_3d',
     'diagonal_scatter',
-    'signbit',
+    'igamma',
+    'igammac',
 ]

 # this list used in math_op_patch.py for magic_method bind
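
A note on the reference helpers: ref_igamma, ref_igammac and their gradient
counterparts are still stubs in this base commit. Below is a minimal
NumPy/SciPy sketch of what they could compute, not the PR's final code. It
assumes igamma denotes the regularized lower incomplete gamma function P(a, x)
and igammac its complement Q(a, x) = 1 - P(a, x), with `a` the shape parameter
and `x` the integration limit, and the gradient helpers here take `a` as an
explicit argument (unlike the two-argument stubs above); the x = 0 special
case mentioned in the docstrings is not handled.

import numpy as np
from scipy import special


def ref_igamma(x, a):
    # Regularized lower incomplete gamma function P(a, x); SciPy broadcasts
    # a scalar `a` against an array `x`.
    return special.gammainc(a, x)


def ref_igammac(x, a):
    # Regularized upper incomplete gamma function Q(a, x) = 1 - P(a, x).
    return special.gammaincc(a, x)


def ref_igamma_grad(x, a, dout):
    # dP(a, x)/dx = x^(a - 1) * exp(-x) / Gamma(a), scaled by the upstream
    # gradient dout (the op tests above pass dout = 1 / x.size).
    return dout * np.power(x, a - 1) * np.exp(-x) / special.gamma(a)


def ref_igammac_grad(x, a, dout):
    # Q(a, x) = 1 - P(a, x), so its derivative in x is just the negation.
    return -ref_igamma_grad(x, a, dout)

Note that scipy.special.gammainc/gammaincc expect a > 0 and x >= 0, so the
negative samples produced by np.random.randn in init_config would need to be
clipped (or the x <= 0 cases special-cased) before comparing against these
helpers.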