From 47b51597b8331224a9944b37af7f1fe1a136dc09 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 8 Apr 2024 13:54:57 +0000 Subject: [PATCH 1/7] [Dy2St] Mark FullWithTensor shape as mutable attribute --- .../ir_adaptor/translator/op_translator.cc | 25 ++++++++++--------- .../fluid/pir/dialect/op_generator/op_gen.py | 12 ++++----- paddle/fluid/pir/dialect/operator/ir/ops.yaml | 2 +- .../manual/manual_static_prim_backend.cc | 2 +- .../fluid/pybind/manual_static_op_function.h | 2 +- paddle/phi/api/yaml/legacy_ops.yaml | 2 +- paddle/phi/api/yaml/op_compat.yaml | 6 +++++ paddle/phi/infermeta/multiary.cc | 4 +-- paddle/phi/infermeta/multiary.h | 2 +- paddle/phi/kernels/full_kernel.h | 2 +- .../impl/full_whit_tensor_kernel_impl.h | 7 +++--- .../phi/kernels/selected_rows/full_kernel.cc | 4 +-- .../phi/kernels/selected_rows/full_kernel.h | 2 +- 13 files changed, 39 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index f41a25fe9717c..d3d9174e84c48 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -2023,6 +2023,19 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { const OpInputInfoList& input_infos, pir::Block* block) override { std::vector op_inputs; + if (op_desc.HasInput("ValueTensor", true) && + op_desc.Input("ValueTensor", true).size() > 0) { + auto value_tensor_vars = op_desc.Input("ValueTensor", true); + auto defining_info = (*param_map)[value_tensor_vars[0]]; + op_inputs.push_back(defining_info.value); + } else { + float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value")); + pir::Attribute new_attr = pir::FloatAttribute::get(ctx, value); + auto defining_op = + InsertFullOperationForAttributeInput(ctx, block, new_attr); + op_inputs.push_back(defining_op->result(0)); + } + if (op_desc.HasInput("ShapeTensor", true) && op_desc.Input("ShapeTensor", true).size() > 0) { auto shape_tensor_vars = op_desc.Input("ShapeTensor", true); @@ -2044,18 +2057,6 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { op_inputs.push_back(defining_op->result(0)); } - if (op_desc.HasInput("ValueTensor", true) && - op_desc.Input("ValueTensor", true).size() > 0) { - auto value_tensor_vars = op_desc.Input("ValueTensor", true); - auto defining_info = (*param_map)[value_tensor_vars[0]]; - op_inputs.push_back(defining_info.value); - } else { - float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value")); - pir::Attribute new_attr = pir::FloatAttribute::get(ctx, value); - auto defining_op = - InsertFullOperationForAttributeInput(ctx, block, new_attr); - op_inputs.push_back(defining_op->result(0)); - } return op_inputs; } diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 37e620ab24589..d6aea4e94bb0e 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1802,12 +1802,12 @@ def AutoCodeGen( extra_args=extra_args, skip_transform_inputs=skip_transform_inputs, data_format_tensors=data_format_tensors, - is_onednn_only="true" - if op_info.is_onednn_only - else "false", - dynamic_fallback="true" - if op_info.dynamic_fallback - else "false", + is_onednn_only=( + "true" if op_info.is_onednn_only else "false" + ), + dynamic_fallback=( + "true" if op_info.dynamic_fallback else "false" + ), ) # generate op verify function str op_verify_str = '' diff --git 
a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 9cc5b1bf4a341..58e88c4eb07ab 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -766,7 +766,7 @@ skip_transform : x - op : full_with_tensor - args : (Tensor shape, Tensor value, DataType dtype=DataType::FLOAT32) + args : (Tensor value, IntArray shape, DataType dtype=DataType::FLOAT32) output: Tensor(out) infer_meta : func : FullWithTensorInferMeta diff --git a/paddle/fluid/primitive/backend/manual/manual_static_prim_backend.cc b/paddle/fluid/primitive/backend/manual/manual_static_prim_backend.cc index a79e929a6e5cc..a479379cc6ab4 100644 --- a/paddle/fluid/primitive/backend/manual/manual_static_prim_backend.cc +++ b/paddle/fluid/primitive/backend/manual/manual_static_prim_backend.cc @@ -43,7 +43,7 @@ Tensor full_with_tensor(const Tensor& shape, std::static_pointer_cast(shape.impl())->value(); pir::Value value_res = paddle::dialect::full( std::vector{}, value.to(), dtype, place); - auto op_res = paddle::dialect::full_with_tensor(shape_res, value_res, dtype); + auto op_res = paddle::dialect::full_with_tensor(value_res, shape_res, dtype); Tensor out(std::make_shared(op_res)); return out; } diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h index 7767c4a4569b3..8943633fb4cda 100644 --- a/paddle/fluid/pybind/manual_static_op_function.h +++ b/paddle/fluid/pybind/manual_static_op_function.h @@ -159,7 +159,7 @@ PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) { CallStackRecorder callstack_recoder("full_with_tensor"); callstack_recoder.Record(); auto static_api_out = - paddle::dialect::full_with_tensor(shape, value, dtype); + paddle::dialect::full_with_tensor(value, shape, dtype); callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 142814e1cc01e..f6e9b1e4e135b 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -551,7 +551,7 @@ skip_transform : x - op : full_with_tensor - args : (Tensor shape, Tensor value, DataType dtype=DataType::FLOAT32) + args : (Tensor value, IntArray shape, DataType dtype=DataType::FLOAT32) output: Tensor(out) infer_meta : func : FullWithTensorInferMeta diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index f1db7cb97191b..dc025d7841be8 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1279,6 +1279,12 @@ data_type : float support_tensor : true +- op : full_with_tensor + int_array: + shape : + data_type : int + support_tensor : true + - op : fused_adam_(fused_adam) inputs : {params : Params, grads : Grads, learning_rate : LearningRate, moments1 : Moments1, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index a87fdd936b89d..ceebbdb5b2d74 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -4921,10 +4921,10 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x, } } -void FullWithTensorInferMeta(const MetaTensor& shape, +void FullWithTensorInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) { - out->set_dims(common::make_ddim(std::vector(shape.numel(), -1))); + out->set_dims(common::make_ddim(shape.GetData())); out->set_dtype(dtype); } diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 
7a94ef98bc993..8d6a366fdbb24 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -952,7 +952,7 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x, MetaTensor* cache_kv_out, MetaTensor* beam_cache_offset_out); -void FullWithTensorInferMeta(const MetaTensor& shape, +void FullWithTensorInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out); diff --git a/paddle/phi/kernels/full_kernel.h b/paddle/phi/kernels/full_kernel.h index b10e02658fe75..e6d80ed43dff4 100644 --- a/paddle/phi/kernels/full_kernel.h +++ b/paddle/phi/kernels/full_kernel.h @@ -33,8 +33,8 @@ void FullKernel(const Context& dev_ctx, template void FullWithTensorKernel(const Context& dev_ctx, - const DenseTensor& shape, const DenseTensor& value, + const IntArray& shape, DataType dtype, DenseTensor* out); diff --git a/paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h b/paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h index ae7ce8a3f41a8..375c2f8eae696 100644 --- a/paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h +++ b/paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h @@ -20,12 +20,11 @@ namespace phi { template void FullWithTensorKernel(const Context& dev_ctx, - const DenseTensor& shape, const DenseTensor& value, + const IntArray& shape, DataType dtype, DenseTensor* out) { - auto shape_tmp = IntArray(shape); - out->Resize(common::make_ddim(shape_tmp.GetData())); - FullKernel(dev_ctx, shape_tmp, Scalar(value), dtype, out); + out->Resize(common::make_ddim(shape.GetData())); + FullKernel(dev_ctx, shape, Scalar(value), dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/selected_rows/full_kernel.cc b/paddle/phi/kernels/selected_rows/full_kernel.cc index 0a3b3ae62fe63..9ea85bf094a73 100644 --- a/paddle/phi/kernels/selected_rows/full_kernel.cc +++ b/paddle/phi/kernels/selected_rows/full_kernel.cc @@ -37,12 +37,12 @@ void FullKernel(const Context& dev_ctx, template void FullWithTensorKernel(const Context& dev_ctx, - const DenseTensor& shape, const DenseTensor& value, + const IntArray& shape, DataType dtype, SelectedRows* out) { phi::FullWithTensorKernel( - dev_ctx, shape, value, dtype, out->mutable_value()); + dev_ctx, value, shape, dtype, out->mutable_value()); } } // namespace sr diff --git a/paddle/phi/kernels/selected_rows/full_kernel.h b/paddle/phi/kernels/selected_rows/full_kernel.h index 07cfe7fd6378b..2515c60ebcfb5 100644 --- a/paddle/phi/kernels/selected_rows/full_kernel.h +++ b/paddle/phi/kernels/selected_rows/full_kernel.h @@ -30,8 +30,8 @@ void FullKernel(const Context& dev_ctx, template void FullWithTensorKernel(const Context& dev_ctx, - const DenseTensor& shape, const DenseTensor& value, + const IntArray& shape, DataType dtype, SelectedRows* out); } // namespace sr From 77899df6156ff016baca7ac8524decfca3ab120f Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 8 Apr 2024 16:21:26 +0000 Subject: [PATCH 2/7] add ut --- paddle/fluid/pir/dialect/op_generator/op_gen.py | 12 ++++++------ test/legacy_test/test_zeros_op.py | 8 ++++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index d6aea4e94bb0e..37e620ab24589 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1802,12 +1802,12 @@ def AutoCodeGen( extra_args=extra_args, skip_transform_inputs=skip_transform_inputs, data_format_tensors=data_format_tensors, - is_onednn_only=( - "true" if op_info.is_onednn_only 
else "false" - ), - dynamic_fallback=( - "true" if op_info.dynamic_fallback else "false" - ), + is_onednn_only="true" + if op_info.is_onednn_only + else "false", + dynamic_fallback="true" + if op_info.dynamic_fallback + else "false", ) # generate op verify function str op_verify_str = '' diff --git a/test/legacy_test/test_zeros_op.py b/test/legacy_test/test_zeros_op.py index ce4449335425c..a5888c11d086e 100644 --- a/test/legacy_test/test_zeros_op.py +++ b/test/legacy_test/test_zeros_op.py @@ -80,5 +80,13 @@ def test_shape_errors(self): assert error_msg.find("expected to be no less than 0") > 0 +class ApiZerosWithDynamicShape(unittest.TestCase): + def test_dynamic_shape(self): + with paddle.pir_utils.IrGuard(): + x = paddle.static.data("x", shape=[], dtype='int32') + out = paddle.zeros(shape=[101, x]) + self.assertEqual(out.shape, [101, -1]) + + if __name__ == '__main__': unittest.main() From b13eb465d5205777babfaa1b1b4d369bea044fa4 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 8 Apr 2024 16:24:16 +0000 Subject: [PATCH 3/7] fix typo `with` --- paddle/phi/kernels/cpu/full_kernel.cc | 2 +- paddle/phi/kernels/gpu/full_kernel.cu | 2 +- ...whit_tensor_kernel_impl.h => full_with_tensor_kernel_impl.h} | 0 paddle/phi/kernels/xpu/full_kernel.cc | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename paddle/phi/kernels/impl/{full_whit_tensor_kernel_impl.h => full_with_tensor_kernel_impl.h} (100%) diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc index b1a6ceda3647d..06267595fc3f3 100644 --- a/paddle/phi/kernels/cpu/full_kernel.cc +++ b/paddle/phi/kernels/cpu/full_kernel.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" -#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h" +#include "paddle/phi/kernels/impl/full_with_tensor_kernel_impl.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu index fde2e33505f97..b815b754c8eb7 100644 --- a/paddle/phi/kernels/gpu/full_kernel.cu +++ b/paddle/phi/kernels/gpu/full_kernel.cu @@ -17,7 +17,7 @@ limitations under the License. 
*/ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" -#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h" +#include "paddle/phi/kernels/impl/full_with_tensor_kernel_impl.h" namespace phi { diff --git a/paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h b/paddle/phi/kernels/impl/full_with_tensor_kernel_impl.h similarity index 100% rename from paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h rename to paddle/phi/kernels/impl/full_with_tensor_kernel_impl.h diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc index 1a780f132016d..758ad186b83ce 100644 --- a/paddle/phi/kernels/xpu/full_kernel.cc +++ b/paddle/phi/kernels/xpu/full_kernel.cc @@ -23,7 +23,7 @@ #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/visit_type.h" -#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h" +#include "paddle/phi/kernels/impl/full_with_tensor_kernel_impl.h" namespace phi { From a6823fd552380b0316ffec7885b40e68c1db5c11 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 9 Apr 2024 02:04:34 +0000 Subject: [PATCH 4/7] fix pd_to_cinn_pass --- .../cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 3bf32aa91837d..be57629fe8747 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -761,8 +761,8 @@ class FullWithTensorOpPattern bool MatchAndRewrite(paddle::dialect::FullWithTensorOp op, pir::PatternRewriter &rewriter) const override { - auto shape = op->operand_source(0); - auto value = op->operand_source(1); + auto value = op->operand_source(0); + auto shape = op->operand_source(1); if (paddle::dialect::TransToPhiDataType( value.type() From ca5200659c218ac4214a08aef861d2d6e1ff785a Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 9 Apr 2024 03:39:57 +0000 Subject: [PATCH 5/7] fix infer_symbolic_shape --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 3a1c411caf1b3..30b79314c365d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -261,7 +261,7 @@ bool ConcatOpInferSymbolicShape( bool FullWithTensorOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source = op->operand_source(0); + pir::Value operand_source = op->operand_source(1); const symbol::ShapeOrDataDimExprs &operand_shape_or_data = shape_analysis->GetShapeOrDataForValue(operand_source); From 09fe64fa8486cacbd6d630c0cec805628d504bc4 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 9 Apr 2024 05:22:22 +0000 Subject: [PATCH 6/7] inverse inputs in simple_llama.config --- test/ir/pir/cinn/symbolic/simple_llama.config | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/ir/pir/cinn/symbolic/simple_llama.config 
b/test/ir/pir/cinn/symbolic/simple_llama.config index 1e80f206a970d..3898b93f62723 100644 --- a/test/ir/pir/cinn/symbolic/simple_llama.config +++ b/test/ir/pir/cinn/symbolic/simple_llama.config @@ -26,24 +26,24 @@ (%24) = "pd_op.slice" (%18, %22, %23) {axes:[(Int64)0],decrease_axis:[(Int64)0],infer_flags:[(Int64)1],is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<2xi32>, builtin.tensor<1xi64>, builtin.tensor<1xi64>) -> builtin.tensor (%25) = "pd_op.cast" (%24) {dtype:(pd_op.DataType)int64,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor) -> builtin.tensor (%26) = "pd_op.full_int_array" () {dtype:(pd_op.DataType)int64,place:(pd_op.Place)Place(cpu),stop_gradient:[true],value:[(Int64)1]} : () -> builtin.tensor<1xi64> - (%27) = "pd_op.full_with_tensor" (%26, %25) {dtype:(pd_op.DataType)int64,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<1xi64>, builtin.tensor) -> builtin.tensor<1xi64> + (%27) = "pd_op.full_with_tensor" (%25, %26) {dtype:(pd_op.DataType)int64,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor, builtin.tensor<1xi64>) -> builtin.tensor<1xi64> (%28) = "pd_op.shape" (%17) {is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<-1x-1xi64>) -> builtin.tensor<2xi32> (%29) = "pd_op.full_int_array" () {dtype:(pd_op.DataType)int64,place:(pd_op.Place)Place(cpu),stop_gradient:[true],value:[(Int64)1]} : () -> builtin.tensor<1xi64> (%30) = "pd_op.full_int_array" () {dtype:(pd_op.DataType)int64,place:(pd_op.Place)Place(cpu),stop_gradient:[true],value:[(Int64)2]} : () -> builtin.tensor<1xi64> (%31) = "pd_op.slice" (%28, %29, %30) {axes:[(Int64)0],decrease_axis:[(Int64)0],infer_flags:[(Int64)1],is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<2xi32>, builtin.tensor<1xi64>, builtin.tensor<1xi64>) -> builtin.tensor (%32) = "pd_op.cast" (%31) {dtype:(pd_op.DataType)int64,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor) -> builtin.tensor (%33) = "pd_op.full_int_array" () {dtype:(pd_op.DataType)int64,place:(pd_op.Place)Place(cpu),stop_gradient:[true],value:[(Int64)1]} : () -> builtin.tensor<1xi64> - (%34) = "pd_op.full_with_tensor" (%33, %32) {dtype:(pd_op.DataType)int64,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<1xi64>, builtin.tensor) -> builtin.tensor<1xi64> + (%34) = "pd_op.full_with_tensor" (%32, %33) {dtype:(pd_op.DataType)int64,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor, builtin.tensor<1xi64>) -> builtin.tensor<1xi64> (%35) = "pd_op.full" () {dtype:(pd_op.DataType)int32,is_persistable:[false],place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[],stop_gradient:[false],value:(Float)1} : () -> builtin.tensor (%36) = "builtin.combine" (%21, %35) {} : (builtin.tensor, builtin.tensor) -> vec[builtin.tensor,builtin.tensor] (%37) = "pd_op.stack" (%36) {axis:(Int32)0,stop_gradient:[true]} : (vec[builtin.tensor,builtin.tensor]) -> builtin.tensor<2xi32> (%38) = "pd_op.full" () {dtype:(pd_op.DataType)float32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Float)1} : () -> builtin.tensor<1xf32> - (%39) = "pd_op.full_with_tensor" (%37, %38) {dtype:(pd_op.DataType)bool,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<2xi32>, builtin.tensor<1xf32>) -> builtin.tensor<-1x1xb> + (%39) = "pd_op.full_with_tensor" (%38, %37) {dtype:(pd_op.DataType)bool,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<1xf32>, builtin.tensor<2xi32>) -> builtin.tensor<-1x1xb> (%40) = "pd_op.full" () 
{dtype:(pd_op.DataType)int32,is_persistable:[false],place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[],stop_gradient:[false],value:(Float)1} : () -> builtin.tensor (%41) = "builtin.combine" (%21, %40) {} : (builtin.tensor, builtin.tensor) -> vec[builtin.tensor,builtin.tensor] (%42) = "pd_op.stack" (%41) {axis:(Int32)0,stop_gradient:[true]} : (vec[builtin.tensor,builtin.tensor]) -> builtin.tensor<2xi32> (%43) = "pd_op.full" () {dtype:(pd_op.DataType)float32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Float)0} : () -> builtin.tensor<1xf32> - (%44) = "pd_op.full_with_tensor" (%42, %43) {dtype:(pd_op.DataType)float16,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<2xi32>, builtin.tensor<1xf32>) -> builtin.tensor<-1x1xf16> + (%44) = "pd_op.full_with_tensor" (%43, %42) {dtype:(pd_op.DataType)float16,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<1xf32>, builtin.tensor<2xi32>) -> builtin.tensor<-1x1xf16> (%45) = "pd_op.shape" (%17) {is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<-1x-1xi64>) -> builtin.tensor<2xi32> (%46) = "pd_op.full_int_array" () {dtype:(pd_op.DataType)int64,place:(pd_op.Place)Place(cpu),stop_gradient:[true],value:[(Int64)1]} : () -> builtin.tensor<1xi64> (%47) = "pd_op.full_int_array" () {dtype:(pd_op.DataType)int64,place:(pd_op.Place)Place(cpu),stop_gradient:[true],value:[(Int64)2]} : () -> builtin.tensor<1xi64> @@ -222,7 +222,7 @@ (%232) = "pd_op.full" () {dtype:(pd_op.DataType)int32,is_persistable:[false],place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[],stop_gradient:[false],value:(Float)1} : () -> builtin.tensor (%233) = "builtin.combine" (%230, %232) {} : (builtin.tensor, builtin.tensor) -> vec[builtin.tensor,builtin.tensor] (%234) = "pd_op.stack" (%233) {axis:(Int32)0,stop_gradient:[true]} : (vec[builtin.tensor,builtin.tensor]) -> builtin.tensor<2xi32> - (%235) = "pd_op.full_with_tensor" (%234, %231) {dtype:(pd_op.DataType)float16,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<2xi32>, builtin.tensor<1xf16>) -> builtin.tensor<-1x1xf16> + (%235) = "pd_op.full_with_tensor" (%231, %234) {dtype:(pd_op.DataType)float16,is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<1xf16>, builtin.tensor<2xi32>) -> builtin.tensor<-1x1xf16> (%236, %237) = "pd_op.top_p_sampling" (%225, %235, <>) {is_persistable:[false,false],seed:(Int32)-1,stop_gradient:[false,false]} : (builtin.tensor<-1x32000xf16>, builtin.tensor<-1x1xf16>, <>) -> builtin.tensor<-1x1xf16>, builtin.tensor<-1x1xi64> (%238) = "pd_op.index_sample" (%226, %237) {is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<-1x32000xf16>, builtin.tensor<-1x1xi64>) -> builtin.tensor<-1x1xf16> (%239) = "pd_op.subtract" (%27, %34) {is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<1xi64>, builtin.tensor<1xi64>) -> builtin.tensor<1xi64> From 531269b26de0ab24f2ecd7d1e2d542299328e312 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 10 Apr 2024 03:02:34 +0000 Subject: [PATCH 7/7] change shape dtype to `int64_t` --- paddle/phi/api/yaml/op_compat.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index dc025d7841be8..87bc330570125 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1282,7 +1282,7 @@ - op : full_with_tensor int_array: shape : - data_type : int + data_type : int64_t support_tensor : true - op : fused_adam_(fused_adam)
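
Net effect of the series: `pd_op.full_with_tensor` now takes `(Tensor value, IntArray shape, DataType dtype)` instead of `(Tensor shape, Tensor value, DataType dtype)`, with `shape` carried as an IntArray mutable attribute, and every consumer that indexed the operands (the op translator, `pd_to_cinn_pass`, `infer_symbolic_shape`, the simple_llama.config IR dump) is updated to the new order. Because `FullWithTensorInferMeta` previously saw `shape` only as an opaque tensor, it had to report -1 for every output dim; with `shape.GetData()` the statically known extents now survive infer-meta. A minimal sketch of the user-visible effect, adapted from the unit test added in PATCH 2/7 (the import lines and `__main__` harness are assumptions about the surrounding test file, which the diff does not show):

    import unittest

    import paddle


    class ApiZerosWithDynamicShape(unittest.TestCase):
        def test_dynamic_shape(self):
            # Build the program under PIR, where the new op definition applies.
            with paddle.pir_utils.IrGuard():
                # A 0-D int32 tensor whose value is known only at run time.
                x = paddle.static.data("x", shape=[], dtype='int32')
                # The requested shape mixes a static extent (101) with a
                # tensor extent (x). With shape as an IntArray mutable
                # attribute, infer-meta keeps the static 101 and marks only
                # the tensor-dependent dimension as dynamic (-1).
                out = paddle.zeros(shape=[101, x])
                self.assertEqual(out.shape, [101, -1])


    if __name__ == '__main__':
        unittest.main()

With the pre-series signature, the same program would have inferred [-1, -1] here, since the old infer-meta could only read `shape.numel()` and fill every dimension with -1.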