
【Paddle Tensor Phase 2, Complex Type Support No.9 & 10】Add complex type support for paddle.not_equal and paddle.equal #69968

Merged
66 commits merged into PaddlePaddle:develop on Dec 17, 2024

Conversation

@MrXnneHang (Contributor) commented Dec 5, 2024

PR Category

User Experience

PR Types

Others

Description

Add complex support to the compare kernels:

  • not_equal
  • equal

Specialize EqualFunctor with comparison logic for phi::dtype::complex<double> and phi::dtype::complex<float>.

error 1 (solved)

TypeError: (InvalidType) Type promotion only support calculations between floating-point numbers and between complex and real numbers. But got different data type x: int16, y: int8. (at /paddle/paddle/phi/common/type_promotion.h:229

This failure happens because automatic type promotion does not support integer types.
https://github.com/HydrogenSulfate/array-api-tests/blob/paddle/array_api_tests/test_operators_and_elementwise_functions.py#1119
After I moved this line to after the manual type conversion, the test passed.

error 2 (solved)

FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] - AssertionError: out=True, but should be (x1 == x2)=False [__eq__()]
  x1=0.2500000260770321, x2=0.2500000298023224
assert True == False
Falsifying example: test_equal(
    ctx=BinaryParamContext(<__eq__(x1, x2)>),
    data=data(...),
)
Draw 1 (x1): Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True,
       0.25000003)
Draw 2 (x2): Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
       0.25000003)

Problem 2 seems to be related to premature truncation; there is a bug in the test.
https://github.com/HydrogenSulfate/array-api-tests/blob/paddle/array_api_tests/test_operators_and_elementwise_functions.py#L1117

left and right have already truncated the data before the manual type conversion, so casting left and right afterwards comes too late. The data should be generated with promoted_dtype from the start instead of being cast afterwards.

Something like this seems to fix it:

def test_equal(ctx, data):
    left = data.draw(ctx.left_strat, label=ctx.left_sym)
    right = data.draw(ctx.right_strat, label=ctx.right_sym)

    if not ctx.right_is_scalar:
        left_dtype = left.dtype
        right_dtype = right.dtype
        # We manually promote the dtypes as incorrect internal type promotion
        # could lead to false positives. For example
        #
        #     >>> xp.equal(
        #     ...     xp.asarray(1.0, dtype=xp.float32),
        #     ...     xp.asarray(1.00000001, dtype=xp.float64),
        #     ... )
        #
        # would erroneously be True if float64 downcasted to float32.
        promoted_dtype = dh.promotion_table[left_dtype, right_dtype]
        print(f"promoted_dtype:{promoted_dtype}")

        # Adjust the strategies so the generated data already has the
        # promoted dtype, instead of casting (and truncating) afterwards.
        left_strat = ctx.left_strat.filter(lambda x: x.dtype == promoted_dtype)
        right_strat = ctx.right_strat.filter(lambda x: x.dtype == promoted_dtype)

        left = data.draw(left_strat, label=ctx.left_sym)
        right = data.draw(right_strat, label=ctx.right_sym)

    out = ctx.func(left, right)

    binary_param_assert_dtype(ctx, left, right, out, xp.bool)
    binary_param_assert_shape(ctx, left, right, out)
    binary_param_assert_against_refimpl(
        ctx, left, right, out, "==", operator.eq, res_stype=bool
    )

Every test that performs manual type promotion should probably be changed this way.

error 3 (solved)

In Paddle, the result shapes of Tensor == Tensor and Tensor != Tensor differ from those of equal and not_equal.

Manually calling Tensor.not_equal:

import torch
import numpy as np
import paddle

@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
@given(data=st.data())
def test_not_equal(ctx, data):
    left = data.draw(ctx.left_strat, label=ctx.left_sym)
    right = data.draw(ctx.right_strat, label=ctx.right_sym)
    # Use PyTorch as a reference for the expected result shape.
    left_np = np.array(left)
    left_tensor = torch.tensor(left_np)
    right_np = np.array(right)
    right_tensor = torch.tensor(right_np)
    expected = (left_tensor != right_tensor).shape
    out = ctx.func(left, right)
    # Shape produced by calling Tensor.not_equal directly.
    out_shape = paddle.to_tensor(left).not_equal(paddle.to_tensor(right)).shape
    binary_param_assert_dtype(ctx, left, right, out, xp.bool)
    binary_param_assert_shape(ctx, left, right, out_shape, expected=expected)
    if not ctx.right_is_scalar:
        # See test_equal note
        promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
        left = xp.astype(left, promoted_dtype)
        right = xp.astype(right, promoted_dtype)
    binary_param_assert_against_refimpl(
        ctx, left, right, out, "!=", operator.ne, res_stype=bool
    )


paddle-bot commented Dec 5, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot added the contributor (External developers) label on Dec 5, 2024
@luotao1 added the HappyOpenSource 快乐开源活动issue与PR label on Dec 5, 2024
@HydrogenSulfate (Contributor) left a comment

The approach is fine, but I think the complex-number part needs an additional specialized template; simply registering the operator probably won't be enough.
You can see that the comparison kernels are indeed generated together by a macro:
[screenshot]
Taking LessEqualFunctor as an example, the Functor the kernel calls is also generated by a macro, with its logic here:
[screenshot]
In other words, the core comparison is a op b. That is fine for real numbers, but for complex numbers the real and imaginary parts should be compared.

Also, judging from Python's and PyTorch's support for complex comparison, only == is supported; >, <, >=, <= are all unsupported. So it should be enough to write one specialized template for the complex == case.
[screenshot]

@MrXnneHang changed the title three times on Dec 5, 2024, arriving at the final title above.
@HydrogenSulfate (Contributor) left a comment

Apart from merging the templates, everything else looks fine.

Comment on lines 50 to 80
template <typename OutT>
struct EqualFunctor<phi::dtype::complex<float>, OutT> {
  HOSTDEVICE OutT operator()(const phi::dtype::complex<float>& a,
                             const phi::dtype::complex<float>& b) const {
    if (isinf(a.real) || isinf(a.imag) || isinf(b.real) || isinf(b.imag)) {
      return a == b;
    }
    if (isnan(a.real) || isnan(a.imag) || isnan(b.real) || isnan(b.imag)) {
      return false;
    }
    float epsilon = 1e-8f;
    return std::abs(a.real - b.real) < epsilon &&
           std::abs(a.imag - b.imag) < epsilon;
  }
};

template <typename OutT>
struct EqualFunctor<phi::dtype::complex<double>, OutT> {
  HOSTDEVICE OutT operator()(const phi::dtype::complex<double>& a,
                             const phi::dtype::complex<double>& b) const {
    if (isinf(a.real) || isinf(a.imag) || isinf(b.real) || isinf(b.imag)) {
      return a == b;
    }
    if (isnan(a.real) || isnan(a.imag) || isnan(b.real) || isnan(b.imag)) {
      return false;
    }
    double epsilon = 1e-8;
    return std::abs(a.real - b.real) < epsilon &&
           std::abs(a.imag - b.imag) < epsilon;
  }
};
Contributor:

Can these two be merged into the following code (see Paddle/paddle/phi/kernels/funcs/matrix_inverse.h for reference)?

Suggested change (merging the two specializations above into a single template):

template <typename InT, typename OutT>
struct EqualFunctor<phi::dtype::complex<InT>, OutT> {
  HOSTDEVICE OutT operator()(const phi::dtype::complex<InT>& a,
                             const phi::dtype::complex<InT>& b) const {
    if (isinf(a.real) || isinf(a.imag) || isinf(b.real) || isinf(b.imag)) {
      return a == b;
    }
    if (isnan(a.real) || isnan(a.imag) || isnan(b.real) || isnan(b.imag)) {
      return false;
    }
    InT epsilon = 1e-8;
    return std::abs(a.real - b.real) < epsilon &&
           std::abs(a.imag - b.imag) < epsilon;
  }
};

@MrXnneHang (Contributor, Author) commented Dec 5, 2024:

I didn't know it could be used like that!

@HydrogenSulfate (Contributor) commented:

@MrXnneHang the package build seems to have a problem; please test compiling and installing locally.
[screenshot]

@MrXnneHang (Contributor, Author) commented:

I'm setting up the environment on my home desktop; on my laptop I've always built only the CPU version...

Comment on lines 51 to 65
struct EqualFunctor<phi::dtype::complex<T>, OutT> {
  HOSTDEVICE OutT operator()(const phi::dtype::complex<float>& a,
                             const phi::dtype::complex<float>& b) const {
    if (isinf(a.real) || isinf(a.imag) || isinf(b.real) || isinf(b.imag)) {
      return a == b;
    }
    if (isnan(a.real) || isnan(a.imag) || isnan(b.real) || isnan(b.imag)) {
      return false;
    }
    float epsilon = 1e-8f;
    return std::abs(a.real - b.real) < epsilon &&
           std::abs(a.imag - b.imag) < epsilon;
  }
Contributor:

I have a question about isnan here: on CPU std::isnan can be used, while the GPU should use isnan. Are the two interchangeable? Reference: Paddle/paddle/phi/kernels/impl/isfinite_kernel_impl.h

Contributor (Author):

Maybe isnan is universal while std::isnan only works on CPU?
I noticed that the original EqualFunctor is registered in exactly the same way on CPU and GPU, and it only uses isnan.
When I build the GPU version I'll switch everything to std::isnan and test it.

Contributor:

If the two are not interchangeable, you can do something like this:
[screenshot]
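(The screenshot is not preserved. As a minimal sketch of one common host/device dispatch pattern; the wrapper name IsNan is hypothetical, and HOSTDEVICE is assumed to expand to __host__ __device__ under CUDA:)

#include <cmath>

template <typename T>
HOSTDEVICE inline bool IsNan(T x) {
#if defined(__CUDA_ARCH__)
  // Device compilation pass: use the CUDA builtin overloads.
  return isnan(x);
#else
  // Host pass: use the <cmath> version.
  return std::isnan(x);
#endif
}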

@MrXnneHang (Contributor, Author) commented Dec 9, 2024:

Switching everything to std::isnan seems to make the GPU build fail with "method not implemented on device".
With isnan everywhere, the CPU version works correctly without errors, but the GPU version doesn't seem to go through these Functors at all.
[screenshot]
I added some printf output inside the templates; it shows up in the CPU build but not in the GPU build...
I tested complex Tensor == complex Tensor and float Tensor == float Tensor respectively.

Contributor:

Code invoked on the GPU cannot use the std:: family; see Paddle/paddle/phi/kernels/impl/isfinite_kernel_impl.h.

@MrXnneHang (Contributor, Author) commented Dec 9, 2024:

For now I'm using isnan everywhere.
But I can't find where equal registers the <GPU, ALL_LAYOUT, int, float, ...> variants.
I noticed that in the GPU build, == on real-valued tensors runs fine, but for complex it errors out saying complex64 and complex128 are not registered on GPU.
I also noticed that the GPU build's == on real values doesn't seem to enter the EqualFunctor primary template or its specializations.
See: #69968 (comment)
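(For context, phi kernels are registered per backend. Extending the GPU registration of equal with the complex dtypes would look roughly like this; a sketch only, since the exact macro, file, and existing dtype list in Paddle's compare kernel may differ:)

PD_REGISTER_KERNEL(equal,
                   GPU,
                   ALL_LAYOUT,
                   phi::EqualKernel,
                   int,
                   float,
                   double,
                   phi::dtype::complex<float>,     // complex64
                   phi::dtype::complex<double>) {  // complex128
  // Compare kernels produce a boolean output regardless of input dtype.
  kernel->OutputAt(0)->SetDataType(phi::DataType::BOOL);
}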

@HydrogenSulfate (Contributor) commented:

> I'll verify that idea. After investigation, the problem is in the unit test: by printing debug information I found that the GPU ctest enters each template without any problem. It seems they were simply excluded when coverage was computed? In the end none of them were counted.

OK

@HydrogenSulfate (Contributor) commented:

[screenshot]
Add a skip to the unit test so it skips XPU devices.

template <typename T>
struct EqualFunctor<phi::dtype::complex<T>> {
  HOSTDEVICE bool operator()(const phi::dtype::complex<T>& a,
                             const phi::dtype::complex<T>& b) const {
@MrXnneHang (Contributor, Author) commented Dec 12, 2024:

The specialized template expects to receive a reference, or a pointer to a variable.
Five days ago I removed the references because of a package build error (ba8cbe8). As a result, the expected reference/pointer parameters could no longer match the temporary values being passed in, which is very likely why the C++ coverage in CI stayed at 0.

Taking the arguments by reference directly produces this conflict: reinterpret_cast casts away qualifiers (8eb19ee).

I tried converting with static_cast, but the reference input is not supposed to be modified, so compilation fails.

Can't take the arguments by reference in operator() -> can't use the specialized template -> merge the templates.
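(For reference, the two signature variants being discussed look like this, reconstructed from the snippets quoted in this thread; the surrounding kernel call site is not shown here:)

// By-value parameters, as in the NotEqualFunctor snippet later in this
// review (after the references were removed in ba8cbe8):
HOSTDEVICE bool operator()(const phi::dtype::complex<T> a,
                           const phi::dtype::complex<T> b) const;

// By-reference parameters, as in the specialization quoted above:
HOSTDEVICE bool operator()(const phi::dtype::complex<T>& a,
                           const phi::dtype::complex<T>& b) const;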

@MrXnneHang (Contributor, Author) commented Dec 12, 2024:

What puzzles me is that locally my printf output from inside the specialized templates does appear.
A possible explanation is the static-check warning that appeared this afternoon after I added the printf:

 You are using GPU version Paddle, but your CUDA device is not set properly. CPU device will be used by default.

After adding printf, the GPU falls into a Schrödinger state: it doesn't know whether it is a CPU or a GPU either.

@HydrogenSulfate (Contributor) commented Dec 12, 2024:

The coverage test itself may be problematic. If it is only coverage that fails, and your local self-test can print output from the lines coverage reports as uncovered (red), then don't worry about it for now.

@MrXnneHang (Contributor, Author) commented Dec 13, 2024:

It does indeed seem to be a CI problem.
But this also confirmed that the reference parameters can match the variable types.

  }
  if (isinf(static_cast<float>(a)) || isinf(static_cast<float>(b))) {
    return static_cast<OutT>(a == b);
  }
@MrXnneHang (Contributor, Author) commented Dec 13, 2024:

Here I moved the generic template's isnan check ahead of the isinf check. That way, when we hit an inf == nan case, we don't compute a == b an extra time; the result is the same. The logic is: as soon as a nan is encountered, return false right away.

@HydrogenSulfate (Contributor) commented Dec 13, 2024:

@MrXnneHang coverage didn't pass.
[screenshot]

@MrXnneHang (Contributor, Author) commented:

[screenshots]
I compiled and tested locally.

@MrXnneHang (Contributor, Author) commented:

[screenshot]
Even without the specialized template, the coverage was 0; this is the CI result from the previous template merge.
25a024d

@HydrogenSulfate (Contributor) commented:

OK

@HydrogenSulfate (Contributor) left a comment

One issue to take a look at:

Comment on lines -40 to +43 (the isnan check now comes before the isinf check):

    if (isnan(static_cast<float>(a)) || isnan(static_cast<float>(b))) {
      return static_cast<OutT>(false);
    }
    if (isinf(static_cast<float>(a)) || isinf(static_cast<float>(b))) {
      return static_cast<OutT>(a == b);
Contributor:

Under the std::is_floating_point branch, why cast to float again?

Contributor (Author):

/* IsnanFunctor */
template <typename T>
__global__ void IsnanCUDAKernel(
    const T* in_data,
    int num,
    bool* out_data,
    typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
    const T& a = in_data[i];
    out_data[i] = isnan(a);
  }
}

Perhaps it's for performance and memory in the parallel check? Every element has to be checked, and a double is twice the width of a float; if everything is the same data type, it's also friendlier to the GPU's instruction specialization. And for isnan and isinf there's really no need to stay in double.

Contributor (Author):

I'll go change it to float.

Contributor:

Wouldn't this make the result incorrect when a value that was originally a double is converted to float?
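(As a quick illustration of the concern; my own example, not from the thread: a finite double above float's range overflows when narrowed, so the isinf branch can fire for a value that is not infinite.)

#include <cmath>
#include <cstdio>

int main() {
  double x = 1e300;  // finite as a double
  // On typical IEEE-754 hardware the narrowing cast overflows to +inf.
  bool looks_inf = std::isinf(static_cast<float>(x));
  std::printf("%d\n", looks_inf);  // prints 1
  return 0;
}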

Contributor:

Oh, if that's how the original code was written, then let's keep it as is for now.

Comment on lines 77 to 83
template <typename T>
struct NotEqualFunctor<phi::dtype::complex<T>> {
  HOSTDEVICE bool operator()(const phi::dtype::complex<T> a,
                             const phi::dtype::complex<T> b) const {
    return !EqualFunctor<phi::dtype::complex<T>>()(a, b);
  }
};
Contributor:

Could this be done without changing the original template parameters?

Contributor (Author):

I need to test this; I recall NotEqual worked fine even without the specialization.

Contributor (Author):

Removing the NotEqual specialized template has no effect; execution still reaches the Equal specialization and the generic template.
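(A plausible explanation, sketched under the assumption that the generic NotEqualFunctor delegates to EqualFunctor the same way the specialization quoted above did; the exact shape of the generic template is not shown in this thread:)

// Assumed shape of the generic template: because it negates EqualFunctor,
// instantiating it with phi::dtype::complex<T> automatically picks up the
// complex specialization of EqualFunctor, so no separate NotEqual
// specialization is needed.
template <typename InT, typename OutT = bool>
struct NotEqualFunctor {
  HOSTDEVICE OutT operator()(const InT a, const InT b) const {
    return !EqualFunctor<InT, OutT>()(a, b);
  }
};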

@HydrogenSulfate (Contributor) left a comment

LGTM

@HydrogenSulfate merged commit 1e4a234 into PaddlePaddle:develop on Dec 17, 2024
28 checks passed