
【Hackathon 5th No.103】move skip_layernorm/fused_bias_dropout_residual_layer_norm to phi #58396

Merged — 19 commits merged into PaddlePaddle:develop, Nov 24, 2023

Conversation

@zeroRains (Contributor) commented Oct 26, 2023

PR types: Others

PR changes: Others

Description

move skip_layernorm/fused_bias_dropout_residual_layer_norm to phi
#57262

@paddle-bot commented Oct 26, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot added the "contributor" (External developers) label Oct 26, 2023
@zeroRains (Contributor, author) commented Oct 27, 2023

[screenshot]

Er, how should this be converted into a function parameter name in phi?

@zeroRains (Contributor, author) commented:

Hmm, CMake fails saying the backward op has more outputs than the forward op. I counted in the .cc file: the forward has 5 outputs and the backward has 6, but 3 of the backward outputs are gated by 3 optional inputs (without those optional inputs, the three outputs don't exist). How should this be resolved? Or am I misunderstanding something?

[screenshot]

@yuanlehome (Contributor) commented:

> Er, how should this be converted into a function parameter name in phi?

Put it in the extra section of op_compat.yaml.
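
(Illustrative sketch only — the attribute name below is hypothetical, since the real one is in the screenshot above. Listing an attribute under an op's extra block in op_compat.yaml keeps it as an extra attribute, so it never needs a phi parameter name:)

# Hypothetical op_compat.yaml entry: attributes under `extra` stay out of
# the phi kernel signature.
- op : skip_layernorm
  extra :
    attrs : [bool use_mkldnn = false]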

@yuanlehome (Contributor) commented:

> Hmm, CMake fails saying the backward op has more outputs than the forward op. [...]

That should just be the inputs and outputs not matching up one-to-one; worth double-checking~

Comment on lines -87 to -94:

AddOutput("BiasDropoutResidualOut", "Output of bias + dropout + residual.")
    .AsIntermediate();
AddOutput("DropoutMaskOut", "The random sampled dropout mask.")
    .AsIntermediate();
AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate();
AddOutput("LnVariance", "Variance of the current mini batch.")
    .AsIntermediate();
AddOutput("Y", "Result.");

@zeroRains (author): The forward does define 5 outputs, but the backward defines 6 _(:з」∠)_

Comment on lines -214 to -243:

if (this->HasInput("Bias")) {
  op->SetInput("Bias", this->Input("Bias"));
  op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
if (this->HasInput("LnScale")) {
  op->SetInput("LnScale", this->Input("LnScale"));
  op->SetOutput(framework::GradVarName("LnScale"),
                this->InputGrad("LnScale"));
}
if (this->HasInput("LnBias")) {
  op->SetInput("LnBias", this->Input("LnBias"));
  op->SetOutput(framework::GradVarName("LnBias"),
                this->InputGrad("LnBias"));
}
if (this->HasOutput("LnMean")) {
  op->SetInput("LnMean", this->Output("LnMean"));
}
if (this->HasOutput("LnVariance")) {
  op->SetInput("LnVariance", this->Output("LnVariance"));
}
if (this->HasOutput("BiasDropoutResidualOut")) {
  op->SetInput("BiasDropoutResidualOut",
               this->Output("BiasDropoutResidualOut"));
}
op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Residual"),
              this->InputGrad("Residual"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
              this->OutputGrad("BiasDropoutResidualOut"));

@zeroRains (author) commented Nov 4, 2023: This backward really does call SetOutput 6 times. The first 3 outputs are gated by optional inputs, and I also added them to optional, but it still errors out (:з」∠)

@zeroRains (author): The main problem is still this spot.

@zeroRains (Contributor, author) commented:

> Put it in the extra section of op_compat.yaml.

Er, I don't quite understand what to put in extra. Like this?
[screenshot]

@paddle-ci-bot commented Nov 5, 2023

Sorry to inform you that 8092f50's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@yuanlehome (Contributor) commented:

> Er, I don't quite understand what to put in extra. Like this? [screenshot]

Yes, give it a try.

@zeroRains (Contributor, author) commented:

> Yes, give it a try.

That works indeed.

@yuanlehome (Contributor) commented:

Please push this PR forward soon~ other work depends on the fc migration.

@zeroRains (Contributor, author) commented Nov 7, 2023:

> Please push this PR forward soon~ other work depends on the fc migration.

I'm stuck on the forward/backward output counts.. I really can't see where the problem is _(:з」∠)_

@luotao1 (Contributor) commented Nov 7, 2023:

How about submitting skip_layernorm and fused_bias_dropout_residual_layer_norm first?

@yuanlehome (Contributor) commented:

> I'm stuck on the forward/backward output counts..

Split them up; you can get fc done first.

@yuanlehome (Contributor) commented:

> Split them up; you can get fc done first.

For fc, make sure it also runs correctly with that flag enabled.

@zeroRains (Contributor, author) commented:

> For fc, make sure it also runs correctly with that flag enabled.

OK, I'll take a look tonight.

@zeroRains (Contributor, author) commented:

The fc PR is at #58777, but it has a small issue, which I've written up in the fc PR.

@zeroRains changed the title from 【Hackathon 5th No.103】move skip_layernorm/fc/fused_bias_dropout_residual_layer_norm to phi to 【Hackathon 5th No.103】move skip_layernorm/fused_bias_dropout_residual_layer_norm to phi on Nov 7, 2023
@CLAassistant commented Nov 7, 2023

CLA assistant check: all committers have signed the CLA.

@luotao1 (Contributor) commented Nov 15, 2023:

@zeroRains you can update this PR now.

@paddle-ci-bot commented Nov 16, 2023

Sorry to inform you that b3ea0d5's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@zeroRains (Contributor, author) commented Nov 16, 2023:

[screenshot]
This one feels redundant. I'll try removing this parameter from the backward's argument list and creating a new DenseTensor inside the backward kernel to stand in for it; that should resolve the backward having more outputs than the forward has inputs.

@zeroRains (Contributor, author) commented Nov 16, 2023:

[screenshot]
But I've run into a strange problem... what is this vjp? @yuanlehome

@zeroRains (Contributor, author) commented:

> But I've run into a strange problem... what is this vjp? @yuanlehome

Resolved by adding the op to vjp_interface_black_list.

@zeroRains (Contributor, author) commented:

The unit tests for both ops now pass~
@yuanlehome

Comment on:

infer_meta :
  func : SkipLayerNormInferMeta
kernel :
  func : skip_layer

@yuanlehome (Contributor) suggested adding:

data_type : x
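
(Illustrative sketch — the suggestion slotted into the kernel entry; data_type : x makes the kernel pick its dtype from input x. The kernel func name is copied verbatim from the truncated quote above:)

infer_meta :
  func : SkipLayerNormInferMeta
kernel :
  func : skip_layer
  data_type : x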

@zeroRains (author): done

@yuanlehome (Contributor) commented:

> The unit tests for both ops now pass~

👍 Do they also pass with that flag enabled?

@zeroRains (Contributor, author) commented:

> 👍 Do they also pass with that flag enabled?

Yes~

Comment on:

- backward_op : fused_bias_dropout_residual_layer_norm_grad
  forward: fused_bias_dropout_residual_layer_norm (Tensor x, Tensor residual, Tensor bias, Tensor ln_scale, Tensor ln_bias, float dropout_rate, bool is_test, bool dropout_fix_seed, int dropout_seed, str dropout_implementation, float ln_epsilon) -> Tensor(bias_dropout_residual_out), Tensor(dropout_mask_out), Tensor(ln_mean), Tensor(ln_variance), Tensor(y)
  args : (Tensor y_grad, Tensor x, Tensor residual, Tensor bias, Tensor ln_scale, Tensor ln_bias, Tensor ln_mean, Tensor ln_variance, Tensor bias_dropout_residual_out, Tensor dropout_mask_out, float dropout_rate = 0.5f, bool is_test = false, bool dropout_fix_seed = true, int dropout_seed = true, str dropout_implementation = "downgrade_in_infer", float ln_epsilon = 1e-5)
  output : Tensor(bias_grad), Tensor(ln_scale_grad), Tensor(ln_bias_grad), Tensor(x_grad), Tensor(residual_grad)

@yuanlehome (Contributor): The grads in output need to follow the order of the tensors in the forward signature.
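
(Illustrative sketch — the grad outputs reordered to follow the forward input order x, residual, bias, ln_scale, ln_bias:)

  output : Tensor(x_grad), Tensor(residual_grad), Tensor(bias_grad), Tensor(ln_scale_grad), Tensor(ln_bias_grad)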

@zeroRains (author): done

Comment on:

- op : fused_bias_dropout_residual_layer_norm
  args : (Tensor x, Tensor residual, Tensor bias, Tensor ln_scale, Tensor ln_bias, float dropout_rate = 0.5f, bool is_test = false, bool dropout_fix_seed = true, int dropout_seed = true, str dropout_implementation = "downgrade_in_infer", float ln_epsilon = 1e-5)
  optional : bias, ln_scale, ln_bias
  output : Tensor(bias_dropout_residual_out), Tensor(dropout_mask_out), Tensor(ln_mean), Tensor(ln_variance), Tensor(y)

@yuanlehome (Contributor): Put the tensors marked intermediate at the end; the non-intermediate ones go first.
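
(Illustrative sketch — y, the only output not marked AsIntermediate() in the original op definition quoted earlier, moves to the front, and the intermediates are declared as such:)

  output : Tensor(y), Tensor(bias_dropout_residual_out), Tensor(dropout_mask_out), Tensor(ln_mean), Tensor(ln_variance)
  intermediate : bias_dropout_residual_out, dropout_mask_out, ln_mean, ln_variance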

@zeroRains (author): done

Comment on lines 606 to 614:

void SkipLayerNormInferMeta(const MetaTensor& x,
                            const MetaTensor& y,
                            const MetaTensor& scale,
                            const MetaTensor& bias,
                            const float epsilon,
                            const int begin_norm_axis,
                            MetaTensor* out);

void FusedBiasDropoutResidualLnInferMeta(

@yuanlehome (Contributor): Place the InferMeta functions in alphabetical order.

@zeroRains (author): done

@yuanlehome (Contributor): This SkipLayerNormFunctor looks like it's only used by the skip_layernorm kernel; just move it into skip_layernorm_kernel.cu.

@zeroRains (author) commented Nov 23, 2023: Copying it straight into skip_layernorm_kernel causes some odd matching problems; I need to look into it further.

@zeroRains (author): Can we leave this unchanged? o(╥﹏╥)o There are so many weird mismatch bugs. At a glance the parameters look like they match fine.

[screenshot]

@yuanlehome (Contributor): Leaving it as-is is fine too.

Comment on:

kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(2).SetDataType(kernel_key.dtype());
kernel->OutputAt(3).SetDataType(kernel_key.dtype());
kernel->OutputAt(3).SetDataType(kernel_key.dtype());

@yuanlehome (Contributor): Outputs whose dtype is the same as the kernel_key dtype don't need to be set; those lines can be removed.

@zeroRains (author): done

Comment on:

#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

@yuanlehome (Contributor): Is blas.h actually used?

@zeroRains (author): Indeed not; removed.

@yuanlehome (Contributor) left a review comment:

LGTM

@luotao1 merged commit e57e312 into PaddlePaddle:develop on Nov 24, 2023
@zeroRains deleted the 103 branch on November 24, 2023 14:33
SecretXV pushed a commit to SecretXV/Paddle that referenced this pull request on Nov 28, 2023:
…_layer_norm to phi (PaddlePaddle#58396)

* temp commit
* move skip_layernorm to phi, but have a bug in include skip_layernorm_functor.h
* temp save
* move skip_layernorm to phi
* move fused_bias_dropout_layer_norm to phi
* roback
  move fc to file but have a bug in onednn
  roback
* change the register name
* fix the forward input len smaller than backward output len
* roback
* add in vjp_interface_black_list
* fix parse
* fix the bug
* Update fused_ops.yaml
* update
@yuanlehome (Contributor) commented Nov 28, 2023:

@zeroRains please open a PR to fix this.
[screenshot]

@zeroRains (Contributor, author) commented Nov 28, 2023:

> @zeroRains please open a PR to fix this. [screenshot]

Submitted: #59461, sorry.

@luotao1 changed the title to append "-part" on Dec 22, 2023, then reverted it to 【Hackathon 5th No.103】move skip_layernorm/fused_bias_dropout_residual_layer_norm to phi on Jan 29, 2024