From adc02e016d075c56c5d300df9cdafce3d97458d8 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 14 Sep 2023 19:44:28 +0800 Subject: [PATCH 01/19] add doc --- ...30914_api_design_for_igamma_and_igammac.md | 433 ++++++++++++++++++ 1 file changed, 433 insertions(+) create mode 100644 rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md new file mode 100644 index 000000000..ad696471f --- /dev/null +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -0,0 +1,433 @@ +# igamma 和 igammac 设计文档 +| API 名称 | igamma / igammac | +| ------------ | --------------------------------- | +| 提交作者 | zrr1999 | +| 提交时间 | 2023-09-14 | +| 版本号 | V1.0 | +| 依赖飞桨版本 | develop | +| 文件名 | 20230914_api_design_for_igamma_and_igammac.md | + +# 一、概述 + +## 1、相关背景 + +为了提升飞桨 API 丰富度,支持随机分布生成相关 API,Paddle 需要扩充 API `paddle.igamma`, `paddle.igammac`。 + +## 2、功能目标 +新增 paddle.igamma /igammac API,即实现(上)不完全伽马函数和补(下)不完全伽马函数的 API。 +这两个函数的定义如下: +$$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ +$$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ + +相应的 API 需要输入两个参数 `input` 与 `other`,对应上式的 $a$ 和 $x$; + +## 3、意义 + +为 Paddle 增加(上)不完全伽马函数和补(下)不完全伽马函数,丰富 `paddle` 的 API。 + +# 二、飞桨现状 + +- 目前 Paddle 缺少 `igamma` 和 `igammac` API,无法方便地计算(上)不完全伽马函数和补(下)不完全伽马函数的数值,以及 inplace 的方式修改输入 `x`。 + +# 三、业内方案调研 + +## PyTorch + +PyTorch 中有 `torch.igamma(input, other, *, out=None)`和`torch.igammac(input, other, *, out=None)` 的 API,以及相应inplace版本。 +PyTorch 中有 `torch.Tensor.igamma(other)` 和 `torch.Tensor.igammac(other)` 的 API,以及相应inplace版本。 + +因为 PyTorch 中这些 API 的实际计算逻辑相似性较大,因此下文的分析均以 igammac 为例。 + +在 PyTorch (aten/src/ATen/native/Math.h)中,不完全伽马函数的核心计算逻辑是 `calc_igammac`/`calc_igammacc` 函数, +然后针对不同架构,进行了不同的并行化操作,核心计算逻辑代码如下 +```cpp +template +static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + scalar_t absxma_a; + + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} +``` + +针对一般 CPU 的并行化处理,主要是给`Vectorized`结构体添加一个新方法,代码如下(aten/src/ATen/cpu/vec/vec256/vec256_float.h) +```cpp + Vectorized igamma(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } +``` + +针对 CUDA 的并行化处理,代码如下(aten/src/ATen/native/cuda/IGammaKernel.cu) +```cpp +template +__noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + + using accscalar_t = at::acc_type; + accscalar_t absxma_a; + + static const accscalar_t SMALL = 20.0; + static const accscalar_t LARGE = 200.0; + static const accscalar_t SMALLRATIO = 0.3; + static const accscalar_t LARGERATIO = 4.5; + + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (::isinf(static_cast(a))) { + if (::isinf(static_cast(x))) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (::isinf(static_cast(x))) { + return 0.0; + } + + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / ::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + + +``` + +## Scipy + +Scipy 中有 `scipy.special.gammainc(a, x, dps=50, maxterms=10**8)`和`scipy.special.gammaincc(a, x, dps=50, maxterms=10**8)` 的 API。 + +在 Scipy (scipy/special/_precompute/gammainc_data.py)中, +gammainc 通过超几何函数计算,代码如下 +```py +def gammainc(a, x, dps=50, maxterms=10**8): + """Compute gammainc exactly like mpmath does but allow for more + summands in hypercomb. See + + mpmath/functions/expintegrals.py#L134 + + in the mpmath github repository. + + """ + with mp.workdps(dps): + z, a, b = mp.mpf(a), mp.mpf(x), mp.mpf(x) + G = [z] + negb = mp.fneg(b, exact=True) + + def h(z): + T1 = [mp.exp(negb), b, z], [1, z, -1], [], G, [1], [1+z], b + return (T1,) + + res = mp.hypercomb(h, [z], maxterms=maxterms) + return mpf2float(res) +``` + + +## TensorFlow +TensorFlow 中有 `Igamma(a: XlaOp, x: XlaOp)`和`Igammac(a: XlaOp, x: XlaOp)` 的 API。 + +TensorFlow 会转换成 XLA,最后 XLA 的实现代码如下: +```cpp + +XlaOp Igamma(XlaOp a, XlaOp x) { + auto& b = *a.builder(); + auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp { + XlaOp is_nan = Or(IsNan(a), IsNan(x)); + XlaOp x_is_zero = Eq(x, ScalarLike(x, 0)); + XlaOp x_is_infinity = + Eq(x, ScalarLike(x, std::numeric_limits::infinity())); + XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0))); + XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a)); + XlaOp ax = a * Log(x) - x - Lgamma(a); + XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type))); + ax = Exp(ax); + XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan)); + const double nan = std::numeric_limits::quiet_NaN(); + XlaOp output = Select( + use_igammac, + ScalarLike(a, 1) - IgammacContinuedFraction( + ax, x, a, And(enabled, use_igammac), type), + IgammaSeries(ax, x, a, And(enabled, Not(use_igammac)), type)); + output = Select(x_is_zero, ZerosLike(output), output); + output = Select(x_is_infinity, FullLike(output, 1), output); + output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); + return output; + }; + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a)); + TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x)); + if (a_shape != x_shape) { + return InvalidArgument( + "Arguments to Igamma must have equal shapes and types; got %s and %s", + a_shape.ToString(), x_shape.ToString()); + } + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); + PrimitiveType a_x_type = a_shape.element_type(); + bool needs_upcast = false; + for (PrimitiveType type : + {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + if (a_shape.element_type() == type) { + needs_upcast = true; + break; + } + } + + if (needs_upcast) { + a = ConvertElementType(a, F32); + x = ConvertElementType(x, F32); + a_x_type = F32; + } + XlaOp result = doit(a, x, a_x_type); + if (needs_upcast) { + result = ConvertElementType(result, a_shape.element_type()); + } + return result; + }); +} + +``` + +# 四、对比分析 + +## 共同点 +- 都有提供对 Python 的调用接口。 +- 均支持 tensor 的输入。 + +## 不同点 + +- PyTorch 是使用 C++ 独立编写的计算逻辑。 +- Scipy 是使用超几何函数计算。 +- Tensorflow 是通过转换为 XLA 再进行计算。 + +# 五、设计思路与实现方案 + +## 命名与参数设计 + +添加 Python API + +```python +paddle.igamma( + inout: Tensor, + other: Tensor, + name: str | None = None +) +``` + +```python +paddle.igammac( + inout: Tensor, + other: Tensor, + name: str | None = None +) +``` + +```python +paddle.igamma_( + inout: Tensor, + other: Tensor, + name: str | None = None +) +``` + +```python +paddle.igammac_( + inout: Tensor, + other: Tensor, + name: str | None = None +) +``` + +```python +paddle.Tensor.igamma( + other: Tensor +) +``` + +```python +paddle.Tensor.igammac( + other: Tensor +) +``` + +```python +paddle.Tensor.igamma_( + other: Tensor +) +``` + +```python +paddle.Tensor.igammac_( + other: Tensor +) +``` + +## 底层OP设计 + +不涉及 + +## API实现方案 + +该 API 实现于 `python/paddle/tensor/manipulation.py`。 + +### igamma +参考 PyTorch 的实现,使用 C++ 独立编写的计算逻辑。 + +### igammac + + +# 六、测试和验收的考量 + +测试需要考虑的 case 如下: + +- 输出数值结果的一致性和数据类型是否正确,使用 scipy 作为参考标准 +- 对不同 dtype 的输入数据 `x` 进行计算精度检验 (float32, float64) +- 输入输出的容错性与错误提示信息 +- 输出 Dtype 错误或不兼容时抛出异常 +- 保证调用属性时是可以被正常找到的 +- 覆盖静态图和动态图测试场景 + +# 七、可行性分析和排期规划 + +方案主要依赖现有原理实现。工期上可以满足在当前版本周期内开发完成。 + +# 八、影响面 + +新增 API,对其他模块无影响 + +# 名词解释 +gammainc 是 igamma 的另一种写法,gammaincc 是 igammac 的另一种写法。 + +# 附件及参考资料 + +- [torch.igamma](https://pytorch.org/docs/stable/special.html#torch.special.gammainc) +- [torch.igammac](https://pytorch.org/docs/stable/special.html#torch.special.gammaincc) +- [scipy](https://github.com/scipy/scipy) +- [tensorflow](https://github.com/tensorflow/tensorflow) From 4cb874c28d7091b0e28c1433f8f9b03751d84221 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 14 Sep 2023 23:18:10 +0800 Subject: [PATCH 02/19] up --- ...30914_api_design_for_igamma_and_igammac.md | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index ad696471f..8584e8476 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -14,7 +14,7 @@ 为了提升飞桨 API 丰富度,支持随机分布生成相关 API,Paddle 需要扩充 API `paddle.igamma`, `paddle.igammac`。 ## 2、功能目标 -新增 paddle.igamma /igammac API,即实现(上)不完全伽马函数和补(下)不完全伽马函数的 API。 +新增 paddle.igamma /igammac API,即实现[(上)不完全伽马函数和补(下)不完全伽马](https://wuli.wiki/online/IncGam.html)函数的 API。 这两个函数的定义如下: $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ $$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ @@ -39,7 +39,7 @@ PyTorch 中有 `torch.Tensor.igamma(other)` 和 `torch.Tensor.igammac(other)` 因为 PyTorch 中这些 API 的实际计算逻辑相似性较大,因此下文的分析均以 igammac 为例。 在 PyTorch (aten/src/ATen/native/Math.h)中,不完全伽马函数的核心计算逻辑是 `calc_igammac`/`calc_igammacc` 函数, -然后针对不同架构,进行了不同的并行化操作,核心计算逻辑代码如下 +这是一个`inline`函数,后续进行了不同的向量化操作,核心计算逻辑代码如下 ```cpp template static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { @@ -122,7 +122,7 @@ static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { } ``` -针对一般 CPU 的并行化处理,主要是给`Vectorized`结构体添加一个新方法,代码如下(aten/src/ATen/cpu/vec/vec256/vec256_float.h) +针对一般 float 的向量化处理,主要是给`Vectorized`结构体添加一个新方法,代码如下(aten/src/ATen/cpu/vec/vec256/vec256_float.h) ```cpp Vectorized igamma(const Vectorized &x) const { __at_align__ float tmp[size()]; @@ -136,7 +136,7 @@ static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { } ``` -针对 CUDA 的并行化处理,代码如下(aten/src/ATen/native/cuda/IGammaKernel.cu) +针对 CUDA,核心计算逻辑代码如下(aten/src/ATen/native/cuda/IGammaKernel.cu),这部分与 CPU的实现相似 ```cpp template __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { @@ -221,6 +221,28 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { ``` +然后通过一些内部的机制,例如 `AT_DISPATCH_FLOATING_TYPES` 对其处理,代码如下: +```cpp +template +struct CalcIgamma{ + CalcIgamma(bool calc_igammac): calc_igammac_(calc_igammac){} + bool calc_igammac_; + __device__ scalar_t operator() (scalar_t a, scalar_t b) const { + if (calc_igammac_) { + return calc_igammac(a,b); + } else { + return calc_igamma(a,b); + } + } +}; + +void igammac_kernel_cuda(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igammac_cuda", [&]() { + gpu_kernel(iter, CalcIgamma(true)); + }); +} +``` + ## Scipy Scipy 中有 `scipy.special.gammainc(a, x, dps=50, maxterms=10**8)`和`scipy.special.gammaincc(a, x, dps=50, maxterms=10**8)` 的 API。 @@ -390,8 +412,13 @@ paddle.Tensor.igammac_( ``` ## 底层OP设计 +对于底层 OP 主要分为三部分,由于 `igamma` 和 `igammac`是互补关系,所以实际上可服用代码很多, +因此底层OP设计仅以`igammac`为例。 -不涉及 +### 实现基础计算逻辑 +根据 igamma (上不完全伽马函数) 的定义。 +这两个函数的定义如下: +$$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ ## API实现方案 @@ -401,7 +428,7 @@ paddle.Tensor.igammac_( 参考 PyTorch 的实现,使用 C++ 独立编写的计算逻辑。 ### igammac - +参考 PyTorch 的实现,使用 C++ 独立编写的计算逻辑。 # 六、测试和验收的考量 @@ -426,7 +453,7 @@ paddle.Tensor.igammac_( gammainc 是 igamma 的另一种写法,gammaincc 是 igammac 的另一种写法。 # 附件及参考资料 - +- [不完全伽马函数的定义——小时百科](https://wuli.wiki/online/IncGam.html) - [torch.igamma](https://pytorch.org/docs/stable/special.html#torch.special.gammainc) - [torch.igammac](https://pytorch.org/docs/stable/special.html#torch.special.gammaincc) - [scipy](https://github.com/scipy/scipy) From afa5f0566163c5af0d90bb08b68efa0cf680956f Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 14 Sep 2023 23:25:54 +0800 Subject: [PATCH 03/19] up op --- .../20230914_api_design_for_igamma_and_igammac.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index 8584e8476..34fb87580 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -416,9 +416,16 @@ paddle.Tensor.igammac_( 因此底层OP设计仅以`igammac`为例。 ### 实现基础计算逻辑 -根据 igamma (上不完全伽马函数) 的定义。 -这两个函数的定义如下: +根据 igamma (上不完全伽马函数) 的定义,即 $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ +设计相应的CPU和CUDA计算函数(CPU和CUDA主体逻辑相似,仅写法上会存在一些差异),这部分与PyTorch相似,也是最核心的内容。 + +### 实现基础计算逻辑的向量化(针对CPU) +可采用类似 PyTorch 的向量化技术加速。 + +### 实现基础计算逻辑的向量化(针对GPU) +这里的 hip 和 cuda 的实现可利用 Paddle 已经实现的很多宏或函数,从而消除两者的差异, +最终实现 Kernel 函数。 ## API实现方案 From d9165a4860823185e474044718f32088536b85e4 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 14 Sep 2023 23:36:46 +0800 Subject: [PATCH 04/19] up api --- ...30914_api_design_for_igamma_and_igammac.md | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index 34fb87580..de75eefff 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -357,7 +357,7 @@ XlaOp Igamma(XlaOp a, XlaOp x) { ```python paddle.igamma( - inout: Tensor, + input: Tensor, other: Tensor, name: str | None = None ) @@ -365,7 +365,7 @@ paddle.igamma( ```python paddle.igammac( - inout: Tensor, + input: Tensor, other: Tensor, name: str | None = None ) @@ -373,7 +373,7 @@ paddle.igammac( ```python paddle.igamma_( - inout: Tensor, + input: Tensor, other: Tensor, name: str | None = None ) @@ -381,7 +381,7 @@ paddle.igamma_( ```python paddle.igammac_( - inout: Tensor, + input: Tensor, other: Tensor, name: str | None = None ) @@ -432,10 +432,16 @@ $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ 该 API 实现于 `python/paddle/tensor/manipulation.py`。 ### igamma -参考 PyTorch 的实现,使用 C++ 独立编写的计算逻辑。 +对于 igamma 、 igamma_ 、igammac 和 igammac_ 有类似的API,下面列出了`igamma`的情况。 -### igammac -参考 PyTorch 的实现,使用 C++ 独立编写的计算逻辑。 +具体的API为`paddle.igamma(input, other, name = None)`和`paddle.Tensor.igamma(input, other)` + +- input: 输入张量,即公式中的 $a$ +- other: 输入张量,即公式中的 $x$ + + +例如将一维张量$[3, 5]$和一维张量$[2, 7]$输入,则计算结果如下: +$$ \Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} dt] = $$ # 六、测试和验收的考量 From ccce89774d8226f1c7dab4c10ebc11ae87498964 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Fri, 15 Sep 2023 12:20:48 +0800 Subject: [PATCH 05/19] up --- ...0230914_api_design_for_igamma_and_igammac.md | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index de75eefff..630b63b90 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -14,7 +14,7 @@ 为了提升飞桨 API 丰富度,支持随机分布生成相关 API,Paddle 需要扩充 API `paddle.igamma`, `paddle.igammac`。 ## 2、功能目标 -新增 paddle.igamma /igammac API,即实现[(上)不完全伽马函数和补(下)不完全伽马](https://wuli.wiki/online/IncGam.html)函数的 API。 +新增 paddle.igamma /igammac API,即实现[上不完全伽马函数和下不完全伽马](https://wuli.wiki/online/IncGam.html)函数的 API。 这两个函数的定义如下: $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ $$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ @@ -23,11 +23,11 @@ $$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ ## 3、意义 -为 Paddle 增加(上)不完全伽马函数和补(下)不完全伽马函数,丰富 `paddle` 的 API。 +为 Paddle 增加上不完全伽马函数和下不完全伽马函数,丰富 `paddle` 的 API。 # 二、飞桨现状 -- 目前 Paddle 缺少 `igamma` 和 `igammac` API,无法方便地计算(上)不完全伽马函数和补(下)不完全伽马函数的数值,以及 inplace 的方式修改输入 `x`。 +- 目前 Paddle 缺少 `igamma` 和 `igammac` API,无法方便地计算上不完全伽马函数和下不完全伽马函数的数值,以及 inplace 的方式修改输入 `x`。 # 三、业内方案调研 @@ -444,6 +444,8 @@ $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ $$ \Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} dt] = $$ # 六、测试和验收的考量 +1. 添加单测文件 `test/legacy_test/test_igamma_op.py` 和 `test/legacy_test/test_igamma_op.py`。 +2. 在单测文件 `test/legacy_test/test_inplace.py` 补充测试。 测试需要考虑的 case 如下: @@ -454,16 +456,21 @@ $$ \Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} - 保证调用属性时是可以被正常找到的 - 覆盖静态图和动态图测试场景 +```python + + +``` + # 七、可行性分析和排期规划 -方案主要依赖现有原理实现。工期上可以满足在当前版本周期内开发完成。 +方案主要根据相关数学原理并参考 PyTorch 的工程实现方法,工期上可以满足在当前版本周期内开发完成。 # 八、影响面 新增 API,对其他模块无影响 # 名词解释 -gammainc 是 igamma 的另一种写法,gammaincc 是 igammac 的另一种写法。 +gammainc 是 igamma 的另一种写法,即上不完全伽马函数,gammaincc 是 igammac 的另一种写法,即下不完全伽马函数。 # 附件及参考资料 - [不完全伽马函数的定义——小时百科](https://wuli.wiki/online/IncGam.html) From e9b59f54c5d8b91e0eda75daa8ec347a651342b7 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Fri, 15 Sep 2023 12:29:37 +0800 Subject: [PATCH 06/19] up --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index 630b63b90..f06e19a17 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -19,6 +19,9 @@ $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ $$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ +上不完全伽马函数 $\Gamma(a,x)$ 的定义域为 $a>0$,$x\geq 0$,值域为 $(0,\Gamma(a)]$。 +下不完全伽马函数 $\gamma(a,x)$ 的定义域为 $a>0$,$x\geq 0$,值域为 $[0,\Gamma(a))$,其中 $\Gamma(a)$ 是伽马函数的值。 + 相应的 API 需要输入两个参数 `input` 与 `other`,对应上式的 $a$ 和 $x$; ## 3、意义 @@ -456,11 +459,6 @@ $$ \Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} - 保证调用属性时是可以被正常找到的 - 覆盖静态图和动态图测试场景 -```python - - -``` - # 七、可行性分析和排期规划 方案主要根据相关数学原理并参考 PyTorch 的工程实现方法,工期上可以满足在当前版本周期内开发完成。 From cb5a5676a0ab2303b47e0b761d3f3b52cf203317 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 18 Sep 2023 12:01:46 +0800 Subject: [PATCH 07/19] add dtype --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index f06e19a17..a1430ffab 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -39,6 +39,8 @@ $$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ PyTorch 中有 `torch.igamma(input, other, *, out=None)`和`torch.igammac(input, other, *, out=None)` 的 API,以及相应inplace版本。 PyTorch 中有 `torch.Tensor.igamma(other)` 和 `torch.Tensor.igammac(other)` 的 API,以及相应inplace版本。 +PyTorch 中输入 CPU 支持 float16, bfloat16, float32, float64,GPU支持 float32, float64 + 因为 PyTorch 中这些 API 的实际计算逻辑相似性较大,因此下文的分析均以 igammac 为例。 在 PyTorch (aten/src/ATen/native/Math.h)中,不完全伽马函数的核心计算逻辑是 `calc_igammac`/`calc_igammacc` 函数, @@ -453,7 +455,7 @@ $$ \Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} 测试需要考虑的 case 如下: - 输出数值结果的一致性和数据类型是否正确,使用 scipy 作为参考标准 -- 对不同 dtype 的输入数据 `x` 进行计算精度检验 (float32, float64) +- 对不同 dtype 的输入数据 `input` 和 `other` 进行计算精度检验,与PyTorch保持一致 - 输入输出的容错性与错误提示信息 - 输出 Dtype 错误或不兼容时抛出异常 - 保证调用属性时是可以被正常找到的 From 49f9687aac244117f46885e3fd9a08a9414a6d31 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 18 Sep 2023 12:03:09 +0800 Subject: [PATCH 08/19] up --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index a1430ffab..0a1f4326b 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -11,7 +11,7 @@ ## 1、相关背景 -为了提升飞桨 API 丰富度,支持随机分布生成相关 API,Paddle 需要扩充 API `paddle.igamma`, `paddle.igammac`。 +为了提升飞桨 API 丰富度,支持随机分布生成相关 API,Paddle 需要扩充 API `paddle.igamma`, `paddle.igammac`, `paddle.igamma_`, `paddle.igammac_`。 ## 2、功能目标 新增 paddle.igamma /igammac API,即实现[上不完全伽马函数和下不完全伽马](https://wuli.wiki/online/IncGam.html)函数的 API。 From b1f7b259cc448f8b442950689a44dd3195c28ea2 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 18 Sep 2023 12:03:29 +0800 Subject: [PATCH 09/19] up --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index 0a1f4326b..654c5ce2a 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -434,7 +434,7 @@ $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ ## API实现方案 -该 API 实现于 `python/paddle/tensor/manipulation.py`。 +该 API 实现于 `python/paddle/tensor/math.py`。 ### igamma 对于 igamma 、 igamma_ 、igammac 和 igammac_ 有类似的API,下面列出了`igamma`的情况。 From 2355d621507c5b0deb5b1430edc414f724d3bf44 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Tue, 19 Sep 2023 04:39:01 +0800 Subject: [PATCH 10/19] up --- .../APIs/20230914_api_design_for_igamma_and_igammac.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index 654c5ce2a..c486bdab4 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -14,7 +14,7 @@ 为了提升飞桨 API 丰富度,支持随机分布生成相关 API,Paddle 需要扩充 API `paddle.igamma`, `paddle.igammac`, `paddle.igamma_`, `paddle.igammac_`。 ## 2、功能目标 -新增 paddle.igamma /igammac API,即实现[上不完全伽马函数和下不完全伽马](https://wuli.wiki/online/IncGam.html)函数的 API。 +新增 `paddle.igamma`, `paddle.igammac`, `paddle.igamma_`, `paddle.igammac_` API,即实现[上不完全伽马函数和下不完全伽马](https://wuli.wiki/online/IncGam.html)函数的 API。 这两个函数的定义如下: $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ $$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ @@ -30,7 +30,7 @@ $$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ # 二、飞桨现状 -- 目前 Paddle 缺少 `igamma` 和 `igammac` API,无法方便地计算上不完全伽马函数和下不完全伽马函数的数值,以及 inplace 的方式修改输入 `x`。 +- 目前 Paddle 缺少 `paddle.igamma`, `paddle.igammac`, `paddle.igamma_`, `paddle.igammac_` API,无法方便地计算上不完全伽马函数和下不完全伽马函数的数值,以及 inplace 的方式修改输入 `x`。 # 三、业内方案调研 @@ -445,8 +445,10 @@ $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ - other: 输入张量,即公式中的 $x$ -例如将一维张量$[3, 5]$和一维张量$[2, 7]$输入,则计算结果如下: -$$ \Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} dt] = $$ +例如将一维张量 $[3, 5]$ 和一维张量 $[2, 7]$ 输入,则计算结果如下: +$ + \Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} dt] +$ # 六、测试和验收的考量 1. 添加单测文件 `test/legacy_test/test_igamma_op.py` 和 `test/legacy_test/test_igamma_op.py`。 From ecbf9f345580494361c3a811d199f09631205a01 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Tue, 19 Sep 2023 04:40:11 +0800 Subject: [PATCH 11/19] fix math --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index c486bdab4..607b3a0f1 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -446,9 +446,7 @@ $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ 例如将一维张量 $[3, 5]$ 和一维张量 $[2, 7]$ 输入,则计算结果如下: -$ - \Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} dt] -$ +$\Gamma(a, x) = [\int_2^{\infty} t^{2} e^{-t} dt, \int_7^{\infty} t^{4} e^{-t} dt]$ # 六、测试和验收的考量 1. 添加单测文件 `test/legacy_test/test_igamma_op.py` 和 `test/legacy_test/test_igamma_op.py`。 From 9c1e0154cfba87a2b852a3b5b8eb1bcdbae389a0 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Tue, 19 Sep 2023 04:51:38 +0800 Subject: [PATCH 12/19] up --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index 607b3a0f1..d90e16a9e 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -417,7 +417,8 @@ paddle.Tensor.igammac_( ``` ## 底层OP设计 -对于底层 OP 主要分为三部分,由于 `igamma` 和 `igammac`是互补关系,所以实际上可服用代码很多, +Kernel部分实现添加在 `paddle/phi/kernels/cpu/igamma_kernel.cc` 和 `paddle/phi/kernels/cpu/igammac_kernel.cc` +对于底层 OP 主要分为三部分,由于 `igamma` 和 `igammac`是互补关系,所以实际上可复用代码很多, 因此底层OP设计仅以`igammac`为例。 ### 实现基础计算逻辑 From 298cbe80def28c9f07c7ea19623ac56f3470e8e6 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Tue, 19 Sep 2023 04:52:27 +0800 Subject: [PATCH 13/19] up --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index d90e16a9e..466d1fc34 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -417,7 +417,8 @@ paddle.Tensor.igammac_( ``` ## 底层OP设计 -Kernel部分实现添加在 `paddle/phi/kernels/cpu/igamma_kernel.cc` 和 `paddle/phi/kernels/cpu/igammac_kernel.cc` +Kernel部分CPU实现添加在 `paddle/phi/kernels/cpu/igamma_kernel.cc` 和 `paddle/phi/kernels/cpu/igammac_kernel.cc`, +Kernel部分GPU实现添加在 `paddle/phi/kernels/gpu/igamma_kernel.cu` 和 `paddle/phi/kernels/gpu/igammac_kernel.cu`, 对于底层 OP 主要分为三部分,由于 `igamma` 和 `igammac`是互补关系,所以实际上可复用代码很多, 因此底层OP设计仅以`igammac`为例。 From 36cb853bbdb2321295b14b0d14a07648f508f42f Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Tue, 19 Sep 2023 04:54:04 +0800 Subject: [PATCH 14/19] up --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index 466d1fc34..ae4c07c21 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -419,6 +419,7 @@ paddle.Tensor.igammac_( ## 底层OP设计 Kernel部分CPU实现添加在 `paddle/phi/kernels/cpu/igamma_kernel.cc` 和 `paddle/phi/kernels/cpu/igammac_kernel.cc`, Kernel部分GPU实现添加在 `paddle/phi/kernels/gpu/igamma_kernel.cu` 和 `paddle/phi/kernels/gpu/igammac_kernel.cu`, +输入 CPU 支持 float16, bfloat16, float32, float64,GPU支持 float32, float64, 对于底层 OP 主要分为三部分,由于 `igamma` 和 `igammac`是互补关系,所以实际上可复用代码很多, 因此底层OP设计仅以`igammac`为例。 @@ -443,8 +444,8 @@ $$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ 具体的API为`paddle.igamma(input, other, name = None)`和`paddle.Tensor.igamma(input, other)` -- input: 输入张量,即公式中的 $a$ -- other: 输入张量,即公式中的 $x$ +- input: 输入张量,即公式中的 $a$, CPU 支持 float16, bfloat16, float32, float64,GPU支持 float32, float64 +- other: 输入张量,即公式中的 $x$, CPU 支持 float16, bfloat16, float32, float64,GPU支持 float32, float64 例如将一维张量 $[3, 5]$ 和一维张量 $[2, 7]$ 输入,则计算结果如下: From fae3e363f3b846043acc238c5f3f52526795837b Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 21 Sep 2023 16:53:02 +0800 Subject: [PATCH 15/19] fix --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index ae4c07c21..c9a37434e 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -16,8 +16,8 @@ ## 2、功能目标 新增 `paddle.igamma`, `paddle.igammac`, `paddle.igamma_`, `paddle.igammac_` API,即实现[上不完全伽马函数和下不完全伽马](https://wuli.wiki/online/IncGam.html)函数的 API。 这两个函数的定义如下: -$$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ -$$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $$ +$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $ +$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $ 上不完全伽马函数 $\Gamma(a,x)$ 的定义域为 $a>0$,$x\geq 0$,值域为 $(0,\Gamma(a)]$。 下不完全伽马函数 $\gamma(a,x)$ 的定义域为 $a>0$,$x\geq 0$,值域为 $[0,\Gamma(a))$,其中 $\Gamma(a)$ 是伽马函数的值。 @@ -425,7 +425,7 @@ Kernel部分GPU实现添加在 `paddle/phi/kernels/gpu/igamma_kernel.cu` 和 `pa ### 实现基础计算逻辑 根据 igamma (上不完全伽马函数) 的定义,即 -$$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $$ +$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $ 设计相应的CPU和CUDA计算函数(CPU和CUDA主体逻辑相似,仅写法上会存在一些差异),这部分与PyTorch相似,也是最核心的内容。 ### 实现基础计算逻辑的向量化(针对CPU) From 8e6f9137573fe5afcec96b2969377082017515e0 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 21 Sep 2023 16:56:27 +0800 Subject: [PATCH 16/19] fix --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index c9a37434e..ef0713b5e 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -16,11 +16,12 @@ ## 2、功能目标 新增 `paddle.igamma`, `paddle.igammac`, `paddle.igamma_`, `paddle.igammac_` API,即实现[上不完全伽马函数和下不完全伽马](https://wuli.wiki/online/IncGam.html)函数的 API。 这两个函数的定义如下: -$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $ -$ \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $ -上不完全伽马函数 $\Gamma(a,x)$ 的定义域为 $a>0$,$x\geq 0$,值域为 $(0,\Gamma(a)]$。 -下不完全伽马函数 $\gamma(a,x)$ 的定义域为 $a>0$,$x\geq 0$,值域为 $[0,\Gamma(a))$,其中 $\Gamma(a)$ 是伽马函数的值。 +$\Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $ +$\gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $ + +上不完全伽马函数 $\Gamma(a,x)$ 的定义域为 $a>0$,$x \geq 0$,值域为 $(0,\Gamma(a)]$。 +下不完全伽马函数 $\gamma(a,x)$ 的定义域为 $a>0$,$x \geq 0$,值域为 $[0,\Gamma(a))$,其中 $\Gamma(a)$ 是伽马函数的值。 相应的 API 需要输入两个参数 `input` 与 `other`,对应上式的 $a$ 和 $x$; From 8b834212bf4d6342413567a8e360086292b9b781 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 21 Sep 2023 16:57:33 +0800 Subject: [PATCH 17/19] fix --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index ef0713b5e..a421d3079 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -18,10 +18,11 @@ 这两个函数的定义如下: $\Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $ + $\gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $ -上不完全伽马函数 $\Gamma(a,x)$ 的定义域为 $a>0$,$x \geq 0$,值域为 $(0,\Gamma(a)]$。 -下不完全伽马函数 $\gamma(a,x)$ 的定义域为 $a>0$,$x \geq 0$,值域为 $[0,\Gamma(a))$,其中 $\Gamma(a)$ 是伽马函数的值。 +上不完全伽马函数 $\Gamma(a,x)$ 的定义域为 $a>0$,$x >= 0$,值域为 $(0,\Gamma(a)]$。 +下不完全伽马函数 $\gamma(a,x)$ 的定义域为 $a>0$,$x >= 0$,值域为 $[0,\Gamma(a))$,其中 $\Gamma(a)$ 是伽马函数的值。 相应的 API 需要输入两个参数 `input` 与 `other`,对应上式的 $a$ 和 $x$; From 6bbb6d3c67a1296ce970783441e46e200353498f Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 21 Sep 2023 16:58:26 +0800 Subject: [PATCH 18/19] fix --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index a421d3079..8b10f4dc4 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -21,8 +21,8 @@ $\Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $ $\gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt $ -上不完全伽马函数 $\Gamma(a,x)$ 的定义域为 $a>0$,$x >= 0$,值域为 $(0,\Gamma(a)]$。 -下不完全伽马函数 $\gamma(a,x)$ 的定义域为 $a>0$,$x >= 0$,值域为 $[0,\Gamma(a))$,其中 $\Gamma(a)$ 是伽马函数的值。 +上不完全伽马函数 $\Gamma(a,x)$ 的定义域为 $a>0$, $x \geq 0$,值域为 $(0,\Gamma(a)]$。 +下不完全伽马函数 $\gamma(a,x)$ 的定义域为 $a>0$, $x \geq 0$,值域为 $[0,\Gamma(a))$,其中 $\Gamma(a)$ 是伽马函数的值。 相应的 API 需要输入两个参数 `input` 与 `other`,对应上式的 $a$ 和 $x$; From 1c08fd7b25cedd59e18dc38bc941b6ee42493d98 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 21 Sep 2023 17:52:37 +0800 Subject: [PATCH 19/19] fix --- rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md index 8b10f4dc4..f919886f8 100644 --- a/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md +++ b/rfcs/APIs/20230914_api_design_for_igamma_and_igammac.md @@ -427,7 +427,7 @@ Kernel部分GPU实现添加在 `paddle/phi/kernels/gpu/igamma_kernel.cu` 和 `pa ### 实现基础计算逻辑 根据 igamma (上不完全伽马函数) 的定义,即 -$ \Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $ +$\Gamma(a, x) = \int_x^{\infty} t^{a-1} e^{-t} dt $ 设计相应的CPU和CUDA计算函数(CPU和CUDA主体逻辑相似,仅写法上会存在一些差异),这部分与PyTorch相似,也是最核心的内容。 ### 实现基础计算逻辑的向量化(针对CPU)