diff --git a/cmake/external/jemalloc.cmake b/cmake/external/jemalloc.cmake index bdd6bfc6c00378..183c9369a2b2cb 100644 --- a/cmake/external/jemalloc.cmake +++ b/cmake/external/jemalloc.cmake @@ -5,8 +5,7 @@ set(JEMALLOC_DOWNLOAD_DIR set(JEMALLOC_PROJECT "extern_jemalloc") set(JEMALLOC_BUILD ${THIRD_PARTY_PATH}/jemalloc/src/extern_jemalloc) set(JEMALLOC_PREFIX_DIR ${THIRD_PARTY_PATH}/jemalloc) -set(JEMALLOC_URL - ${GIT_URL}/jemalloc/jemalloc/releases/download/5.1.0/jemalloc-5.1.0.tar.bz2) +set(JEMALLOC_URL https://paddle-ci.gz.bcebos.com/jemalloc-5.1.0.tar.bz2) set(JEMALLOC_INSTALL ${THIRD_PARTY_PATH}/install/jemalloc) set(JEMALLOC_INCLUDE_DIR ${JEMALLOC_INSTALL}/include) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 447c744da39c3c..6eb5a36d6e00f2 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -24,7 +24,7 @@ set(XPU_XFT_LIB_NAME "libxft.so") set(XPU_XPTI_LIB_NAME "libxpti.so") if(NOT DEFINED XPU_BASE_DATE) - set(XPU_BASE_DATE "20230926") + set(XPU_BASE_DATE "20231023") endif() set(XPU_XCCL_BASE_VERSION "1.0.53.6") if(NOT DEFINED XPU_XFT_BASE_VERSION) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 9f4ffd23a57e1c..92aaa69cb46f66 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -499,12 +499,15 @@ function(cc_test_run TARGET_NAME) NAME ${TARGET_NAME} COMMAND ${cc_test_COMMAND} ${cc_test_ARGS} WORKING_DIRECTORY ${cc_test_DIR}) - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT - FLAGS_cpu_deterministic=true) - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT - FLAGS_init_allocated_mem=true) - set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT - FLAGS_cudnn_deterministic=true) + set_property( + TEST ${TARGET_NAME} + PROPERTY + ENVIRONMENT + FLAGS_cpu_deterministic=true + FLAGS_init_allocated_mem=true + FLAGS_cudnn_deterministic=true + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${PADDLE_BINARY_DIR}/python/paddle/libs:${PADDLE_BINARY_DIR}/python/paddle/base + ) # No unit test should exceed 2 minutes. if(WIN32) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 150) @@ -726,6 +729,7 @@ function(nv_test TARGET_NAME) # 2. cuda_add_executable does not support ccache. # Reference: https://cmake.org/cmake/help/v3.10/module/FindCUDA.html add_executable(${TARGET_NAME} ${nv_test_SRCS}) + target_compile_definitions(${TARGET_NAME} PUBLIC STATIC_PADDLE) get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} ${os_dependency_modules} paddle_gtest_main phi) diff --git a/cmake/hip.cmake b/cmake/hip.cmake index d3972e577a8009..4f005e95bb98a6 100644 --- a/cmake/hip.cmake +++ b/cmake/hip.cmake @@ -118,6 +118,11 @@ list(APPEND HIP_CXX_FLAGS -Wno-unused-value) list(APPEND HIP_CXX_FLAGS -Wno-braced-scalar-init) list(APPEND HIP_CXX_FLAGS -Wno-return-type) list(APPEND HIP_CXX_FLAGS -Wno-pragma-once-outside-header) +list(APPEND HIP_CXX_FLAGS -Wno-deprecated-builtins) +list(APPEND HIP_CXX_FLAGS -Wno-switch) +list(APPEND HIP_CXX_FLAGS -Wno-literal-conversion) +list(APPEND HIP_CXX_FLAGS -Wno-constant-conversion) +list(APPEND HIP_CXX_FLAGS -Wno-defaulted-function-deleted) if(WITH_CINN) list(APPEND HIP_CXX_FLAGS -std=c++14) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index a0f5d2c82eeb88..61813e3f5e2ffd 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -684,6 +684,9 @@ function(prune_pybind_h) list(APPEND op_list "load_combine") list(APPEND op_list "tensorrt_engine") + # TODO(ming1753): conditional_block_infer is temporarily reserved here to avoid link errors in functions of standalone_executor + list(APPEND op_list "conditional_block_infer") + # add fused_op in op_list list(APPEND op_list "fc") list(APPEND op_list "conv2d_fusion") diff --git a/paddle/cinn/README.md b/paddle/cinn/README.md new file mode 100644 index 00000000000000..204feab7f2798f --- /dev/null +++ b/paddle/cinn/README.md @@ -0,0 +1,121 @@ +``` + ___ ___ ___ + /\__\ /\ \ /\ \ + /:/ / ___ \:\ \ \:\ \ + /:/ / /\__\ \:\ \ \:\ \ + /:/ / ___ /:/__/ _____\:\ \ _____\:\ \ + /:/__/ /\__\/::\ \ /::::::::\__\/::::::::\__\ + \:\ \ /:/ /\/\:\ \__\:\~~\~~\/__/\:\~~\~~\/__/ + \:\ /:/ / \:\/\__\\:\ \ \:\ \ + \:\/:/ / \::/ / \:\ \ \:\ \ + \::/ / /:/ / \:\__\ \:\__\ + \/__/ \/__/ \/__/ \/__/ + +``` + + +# CINN : Compiler Infrastructure for Neural Networks + +The project CINN is a machine learning compiler and executor for multiple hardware backends. +It is designed to provide multiple layers of APIs to make tensor computation easier to define, faster to execute, and more convenient to extend with hardware backends. +Currently, it targets x86 CPUs and Nvidia GPUs. + +This project is under active development. + +## How it works + +The CINN lowers a traditional DNN model into a two-level intermediate representation(IR), the high-level IR(HLIR) and CINN IR. + +The HLIR helps to define some domain-specific computation and perform some overall optimization on the IR-graph; +the CINN IR helps to represent some computation semantic and finally lower to a hardware backend. + +Both levels of IR have the similar SSA graph, analysis and optimization facilities. +The schedule transform is applied on the CINN IR to do optimizations. + +For more details, you can refer to: +https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/cinn + +## Getting Started + +### Compile + +Clone PaddlePaddle first. + +``` +git clone https://github.com/PaddlePaddle/Paddle.git +cd Paddle +mkdir build +cd build +``` + +Build paddle with cinn: + +``` +cmake .. -DCINN_ONLY=OFF -DWITH_CINN=ON -DWITH_GPU=ON +``` + +Build cinn only: + +``` +cmake .. -DCINN_ONLY=ON -DWITH_CINN=ON -DWITH_GPU=ON +``` + +And then + +``` +make -j +``` + +### Install + +Install paddle with cinn: + +``` +pip install python/dist/paddlepaddle_gpu-xxx.whl +``` + +Install cinn only: + +``` +pip install python/dist/cinn_gpu-xxx.whl +``` + +Then you can import paddle in the python environment and check if a paddle version with CINN is installed. + +``` +import paddle +paddle.is_compiled_with_cinn() +``` + +### Concepts + +There are two levels of APIs in CINN, the higher level is HLIR and the lower level is CINN IR, both contain some concepts. + +In HLIR + +- `frontend::Program`, the program helps to define a machine learning computation, +- `hlir::framework::Tensor`, multi-dimensional arrays helps to manage a memory buffer. +- `hlir::framework::Program`, the final executable program in runtime. It holds many basic executable elements. +- `hlir::framework::Graph`, the graph that represents the structure of a model. Each node in the graph represents an operator (conv2d, relu, mul, etc.). +- `hlir::framework::GraphCompiler`, the compiler that transforms the graph representation(hlir::framework::Graph) of a model into an executable program(hlir::framework::Program). + +In CINN IR + +- `Compute`, the method to define a computation, +- `Lower`, the method to lower a computation to the corresponding IR, +- `LoweredFunc`, the function defined in CINN IR, +- `Var`, a scalar variable, +- `Expr`, an expression represents any CINN IR node(no specified Statement node), + +## License + +CINN is licensed under the [Apache 2.0 license](LICENSE). + +## Acknowledgement + +CINN learned a lot from the following projects: + +- [Halide](https://github.com/halide/Halide): Referenced the design of most IR nodes, +- [TVM](https://github.com/apache/tvm): We learned many ideas including the semantics of some schedule primitives, TOPI, NNVM, and so on, +- [tiramisu](https://github.com/Tiramisu-Compiler): The isl usage, polyhedral compilation, schedule primitive implementation, and so on, +- [tensorflow/xla](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla): Referenced the semantics of the primitive operations. diff --git a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt index e831bc7114f95c..542ed6c21d0ce4 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt @@ -1,4 +1,4 @@ -# TODO(Aurelius84): new_ir_compiler depends on pd_op_dialect and could +# TODO(Aurelius84): pir_compiler depends on pd_op_dialect and could # not found under CINN_ONLY mode if(NOT CINN_ONLY) set(CINN_DIALECT_BINARY_DIR diff --git a/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h b/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h index 724aed031165d6..9c6959db093e4b 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h +++ b/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h @@ -18,8 +18,8 @@ #include #include #include -#include "paddle/cinn/hlir/framework/new_ir/utils.h" #include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/pir/core/attribute_base.h" #include "paddle/pir/core/operation.h" @@ -51,7 +51,7 @@ struct GroupInfo { private: void Initialize() { op_pattern_kind = hlir::framework::OpPatternKind::kElementWise; - fn_name = hlir::framework::newir::CompatibleInfo::GroupOpsName(ops); + fn_name = hlir::framework::pir::CompatibleInfo::GroupOpsName(ops); } }; @@ -78,7 +78,7 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage { }; struct JITInfoAttributeStorage : public pir::AttributeStorage { - using ParamKey = cinn::hlir::framework::newir::CUDAJITInfo; + using ParamKey = cinn::hlir::framework::pir::CUDAJITInfo; explicit JITInfoAttributeStorage(const ParamKey& key) : data_(key) {} static JITInfoAttributeStorage* Construct(const ParamKey& key) { diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc index 43d7a79f03de48..1899d5f44bee11 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc @@ -20,7 +20,7 @@ const GroupInfo &GroupInfoAttribute::data() const { return storage()->GetAsKey(); } -const cinn::hlir::framework::newir::CUDAJITInfo &CUDAJITInfoAttribute::data() +const cinn::hlir::framework::pir::CUDAJITInfo &CUDAJITInfoAttribute::data() const { return storage()->GetAsKey(); } diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h index 21724e7e3f6c9b..10bd5ebc300a47 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h +++ b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h @@ -44,7 +44,7 @@ class CUDAJITInfoAttribute : public pir::Attribute { return storage() < right.storage(); } - const cinn::hlir::framework::newir::CUDAJITInfo& data() const; + const cinn::hlir::framework::pir::CUDAJITInfo& data() const; }; } // namespace dialect diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index 770e78d191e3dc..36f71616e93ca5 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -4,7 +4,19 @@ if(NOT CINN_ONLY) SRCS group_with_group_merge_pass.cc op_with_group_merge_pass.cc + cinn_group_lowering_pass.cc tensor_node.cc DEPS + pd_op_dialect + pir_compiler + cinn_runtime_dialect) + + cinn_cc_library( + pd_to_cinn_pass + SRCS + pd_to_cinn_pass.cc + DEPS + drr + cinn_op_dialect pd_op_dialect) endif() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc new file mode 100644 index 00000000000000..5906323650e6de --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc @@ -0,0 +1,191 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h" + +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" + +namespace cinn { +namespace dialect { +namespace ir { + +std::vector GetBlockOutsideInput( + const std::vector op_list) { + std::vector vec_res; + std::unordered_set<::pir::Value> block_inner_output; + for (size_t k = 0; k < op_list.size(); ++k) { + for (size_t i = 0; i < op_list[k]->num_results(); ++i) { + block_inner_output.insert(op_list[k]->result(i)); + } + } + + for (size_t k = 0; k < op_list.size(); ++k) { + for (size_t i = 0; i < op_list[k]->num_operands(); ++i) { + if (!block_inner_output.count(op_list[k]->operand_source(i))) { + vec_res.push_back(op_list[k]->operand_source(i)); + } + } + } + + return vec_res; +} + +std::vector GetBlockOutsideOutput( + const std::vector op_list) { + std::vector vec_res; + std::unordered_set<::pir::Value> block_inner_output; + for (size_t k = 0; k < op_list.size(); ++k) { + for (size_t i = 0; i < op_list[k]->num_operands(); ++i) { + block_inner_output.insert(op_list[k]->operand_source(i)); + } + } + + for (size_t k = 0; k < op_list.size(); ++k) { + for (size_t i = 0; i < op_list[k]->num_results(); ++i) { + if (!block_inner_output.count(op_list[k]->result(i))) { + vec_res.push_back(op_list[k]->result(i)); + } + } + } + + return vec_res; +} + +std::vector GetOpListNotIncludeYield( + const std::vector& op_list) { + std::vector vec_res; + for (size_t i = 0; i < op_list.size(); ++i) { + if (!op_list[i]->isa()) { + vec_res.push_back(op_list[i]); + } + } + + return vec_res; +} + +std::unique_ptr CINNGroupLoweringPass(::pir::Program* program) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + std::string jit_op_name = cinn::dialect::JitKernelOp::name(); + ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name); + + auto ir_program = std::make_unique<::pir::Program>(ctx); + std::unordered_map value_map; + std::vector compiler_list; + + auto target = cinn::common::DefaultNVGPUTarget(); + auto scope = cinn::hlir::framework::BuildScope(target, *program); + + for (auto it = program->block()->begin(); it != program->block()->end(); + ++it) { + if ((*it)->isa()) { + // GetOpList and Call cinn CodeGen + auto group_op = (*it)->dyn_cast(); + + // op fusion + auto op_fusion = cinn::dialect::ir::OpFusionPassInternal( + GetOpListNotIncludeYield(group_op.ops())); + + // fusion merge + auto group_list = + cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion); + + PADDLE_ENFORCE_EQ(group_list.size(), + 1u, + phi::errors::Unimplemented( + "Only support one group after group fusion")); + for (auto group : group_list) { + auto ir_compiler = + new cinn::hlir::framework::PIRCompiler(*program, target, scope); + auto group1 = + std::make_shared(group->nodes); + auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group1}); + compiler_list.push_back(ir_compiler); + std::unordered_map op_attrs{ + {cinn::dialect::JitKernelOp::kAttrName, + cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])}, + }; + + // Generate jit kernel op input and output + auto vec_ins = GetBlockOutsideInput(group->nodes); + + std::vector vec_new_ins; + for (size_t i = 0; i < vec_ins.size(); ++i) { + vec_new_ins.push_back(value_map.at(vec_ins[i])); + } + + auto vec_outs = GetBlockOutsideOutput(group->nodes); + + std::vector vec_types; + for (auto& out : vec_outs) { + vec_types.push_back(out.type()); + } + + ::pir::Operation* cinn_op = + ::pir::Operation::Create(vec_new_ins, op_attrs, vec_types, op_info); + + // for (size_t i = 0; i < vec_outs.size(); ++i) { + // value_map[vec_outs[i]] = cinn_op->result(i); + // } + + // auto yield_op = group_op.ops().back()->dyn_cast(); + for (size_t i = 0; i < group_op.num_results(); ++i) { + value_map[group_op.result(i)] = cinn_op->result(i); + } + + ir_program->block()->push_back(cinn_op); + } + + } else { + std::vector vec_ins; + + for (size_t i = 0; i < (*it)->num_operands(); ++i) { + vec_ins.push_back(value_map.at((*it)->operand_source(i))); + } + + std::vector vec_types; + for (size_t i = 0; i < (*it)->num_results(); ++i) { + vec_types.push_back((*it)->result(i).type()); + } + + ::pir::OpInfo info1 = ctx->GetRegisteredOpInfo((*it)->name()); + ::pir::Operation* op = ::pir::Operation::Create( + vec_ins, (*it)->attributes(), vec_types, info1); + + ir_program->block()->push_back(op); + + value_map[(*it)->result(0)] = op->result(0); + } + } + return ir_program; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h new file mode 100644 index 00000000000000..99d113555a39f1 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h @@ -0,0 +1,27 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/program.h" + +namespace cinn { +namespace dialect { +namespace ir { + +std::unique_ptr CINNGroupLoweringPass(::pir::Program* program); + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc index e9c165bbcec523..865a137c80be0d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc @@ -1023,9 +1023,7 @@ class FusionPassRegistrar final : public Registrar { // code generation. class GeneralFusionMergePassHelper { public: - explicit GeneralFusionMergePassHelper(const ::pir::Program* graph, - const GroupList& group_list) - : graph_(graph) { + explicit GeneralFusionMergePassHelper(const GroupList& group_list) { fusion_groups_ = group_list; // init input to consumers. InitInputToConsumers(); @@ -2099,7 +2097,6 @@ class GeneralFusionMergePassHelper { } } - const ::pir::Program* graph_; GroupList fusion_groups_; std::unordered_map fusion_groups_index_; std::unordered_set output_nodes_set_; @@ -2108,14 +2105,13 @@ class GeneralFusionMergePassHelper { input_to_consumers_; }; -GroupList GeneralFusionMergePassInternal(const ::pir::Program* graph, - const GroupList& group_list) { +GroupList GeneralFusionMergePassInternal(const GroupList& group_list) { if (group_list.size() <= 1) { VLOG(3) << "Don't do Fusoin Merge Pass...!"; return group_list; } - GeneralFusionMergePassHelper fusion_merge_pass_helper(graph, group_list); + GeneralFusionMergePassHelper fusion_merge_pass_helper(group_list); auto res = fusion_merge_pass_helper(); return res; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc index 3039d81ff83a35..66977667772745 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" + #include #include #include @@ -19,8 +21,6 @@ #include #include -#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" - #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/operation.h" @@ -40,6 +40,8 @@ std::unordered_map OpKindMap = { {"pd_op.full", OpPatternKind::kElementWise}, {"pd_op.relu", OpPatternKind::kElementWise}, {"pd_op.exp", OpPatternKind::kElementWise}, + {"pd_op.sin", OpPatternKind::kElementWise}, + {"pd_op.cos", OpPatternKind::kElementWise}, {"pd_op.sum", OpPatternKind::kReduction}, {"cinn_op.reduce_sum", OpPatternKind::kReduction}, {"cinn_op.reduce_max", OpPatternKind::kReduction}, @@ -143,19 +145,18 @@ using ConditionFunction = // code generation. class OpFusionPassHelper { public: - explicit OpFusionPassHelper(const ::pir::Program& graph) { + explicit OpFusionPassHelper(const std::vector& op_list) { // init fusion relation InitFusionRelation(); // filter node data, create group for each node // auto nodes_inorder = std::get<0>(graph->topological_order()); - for (auto it = graph.block()->begin(); it != graph.block()->end(); ++it) { - auto node = *it; - local_ops_.insert(node); + for (auto it = op_list.begin(); it != op_list.end(); ++it) { + local_ops_.insert(*it); } int index = 0; - for (auto it = graph.block()->begin(); it != graph.block()->end(); ++it) { + for (auto it = op_list.begin(); it != op_list.end(); ++it) { auto node = *it; if (node) { nodes_.push_back(node); @@ -491,9 +492,9 @@ class OpFusionPassHelper { std::unordered_map fusion_relation_map_; }; -GroupList OpFusionPassInternal(const ::pir::Program& program) { +GroupList OpFusionPassInternal(const std::vector& op_list) { VLOG(3) << "OpFusionPass...!"; - auto op_fusion_helper = OpFusionPassHelper(program); + auto op_fusion_helper = OpFusionPassHelper(op_list); auto res = op_fusion_helper(); for (size_t i = 0; i < res.size(); ++i) { @@ -502,27 +503,11 @@ GroupList OpFusionPassInternal(const ::pir::Program& program) { for (size_t j = 0; j < group->nodes.size(); ++j) { } } - - // for (auto& group : graph->fusion_groups) { - // VLOG(3) << "Group Id : " << group->group_id; - // for (const auto& producer : group->producer_groups()) { - // VLOG(3) << " producer group -> " << producer->group_id; - // } - // for (const auto& consumer : group->consumer_groups()) { - // VLOG(3) << " consumer group -> " << consumer->group_id; - // } - // } VLOG(3) << "OpFusionPass Finish...!"; return res; } -// void BuildNonFusedGroupsPassInternal(framework::Graph* graph) { -// auto op_fusion_helper = OpFusionPassHelper(graph); -// VLOG(3) << "Apply OpFusionPass to generate initial non-fusion groups"; -// graph->fusion_groups = op_fusion_helper(false); -// } - } // namespace ir } // namespace dialect } // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h index c784140c1cf363..d9e07273791fea 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h @@ -24,10 +24,9 @@ namespace ir { using GroupPtr = std::shared_ptr; using GroupList = std::vector; -GroupList OpFusionPassInternal(const ::pir::Program& program); +GroupList OpFusionPassInternal(const std::vector& op_list); -GroupList GeneralFusionMergePassInternal(const ::pir::Program* graph, - const GroupList& group_list); +GroupList GeneralFusionMergePassInternal(const GroupList& group_list); } // namespace ir } // namespace dialect diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc new file mode 100644 index 00000000000000..eefa4b9ae5c8e3 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class SumOpPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source Pattern + pir::drr::SourcePattern patttern = ctx->SourcePattern(); + const auto &full_int_array = + patttern.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", patttern.Attr("axis_info")}, + {"dtype", patttern.Attr("dtype_2")}, + {"place", patttern.Attr("place_2")}}); + + const auto &sum = patttern.Op(paddle::dialect::SumOp::name(), + {{"dtype", patttern.Attr("dtype")}, + {"keepdim", patttern.Attr("keep_dim")}}); + patttern.Tensor("ret") = sum(patttern.Tensor("arg0"), full_int_array()); + + // Result patterns + pir::drr::ResultPattern res = patttern.ResultPattern(); + const auto &cinn_reduce_sum = + res.Op(cinn::dialect::ReduceSumOp::name(), + {{"axis", patttern.Attr("axis_info")}, + {"keep_dim", patttern.Attr("keep_dim")}}); + res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0")); + } +}; + +class MaxOpPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source Pattern + pir::drr::SourcePattern patttern = ctx->SourcePattern(); + const auto &full_int_array = + patttern.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", patttern.Attr("axis_info")}, + {"dtype", patttern.Attr("dtype_2")}, + {"place", patttern.Attr("place_2")}}); + + const auto &pd_max = patttern.Op(paddle::dialect::MaxOp::name(), + {{"keepdim", patttern.Attr("keep_dim")}}); + patttern.Tensor("ret") = pd_max(patttern.Tensor("arg0"), full_int_array()); + + // Result patterns + pir::drr::ResultPattern res = patttern.ResultPattern(); + const auto &cinn_reduce_max = + res.Op(cinn::dialect::ReduceMaxOp::name(), + {{"axis", patttern.Attr("axis_info")}, + {"keep_dim", patttern.Attr("keep_dim")}}); + res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); + } +}; + +PdOpToCinnOpPass::PdOpToCinnOpPass() : pir::Pass("pd_to_cinn_pass", 1) {} + +bool PdOpToCinnOpPass::Initialize(pir::IrContext *context) { + pir::RewritePatternSet ps(context); + ps.Add(SumOpPattern().Build(context)); + ps.Add(MaxOpPattern().Build(context)); + + patterns_ = ::pir::FrozenRewritePatternSet(std::move(ps)); + return true; +} + +void PdOpToCinnOpPass::Run(pir::Operation *op) { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); +} + +bool PdOpToCinnOpPass::CanApplyOn(pir::Operation *op) const { + return op->isa() && op->num_regions() > 0; +} + +void PdOp2CinnOpConverter(::pir::Program *program) { + pir::IrContext *ctx = pir::IrContext::Instance(); + + pir::PassManager pm(ctx); + pm.AddPass(std::make_unique()); + + pm.Run(program); +} +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h new file mode 100644 index 00000000000000..d6c0bd2013bbc7 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h @@ -0,0 +1,43 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class PdOpToCinnOpPass : public pir::Pass { + public: + PdOpToCinnOpPass(); + + bool Initialize(pir::IrContext *context) override; + + void Run(pir::Operation *op) override; + + bool CanApplyOn(pir::Operation *op) const override; + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +void PdOp2CinnOpConverter(::pir::Program *program); + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc index 56f598b55bf525..2d8833a6acefc0 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc @@ -15,7 +15,7 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" -#include "paddle/cinn/hlir/framework/new_ir_compiler.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/enforce.h" @@ -35,7 +35,7 @@ void JitKernelOp::VerifySig() { "Type of attribute: instruction is not right."); } -const hlir::framework::newir::CUDAJITInfo& JitKernelOp::cuda_jit_info() { +const hlir::framework::pir::CUDAJITInfo& JitKernelOp::cuda_jit_info() { return attributes() .at(kAttrName) .dyn_cast() diff --git a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h index 0078d0d3b172d4..0ac3d26c262b74 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/cinn/hlir/framework/new_ir/utils.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/pir/core/op_base.h" namespace cinn { @@ -44,7 +44,7 @@ class JitKernelOp : public ::pir::Op { static constexpr char* kAttrName = "jit_info"; static const char* attributes_name[attributes_num]; - const hlir::framework::newir::CUDAJITInfo& cuda_jit_info(); + const hlir::framework::pir::CUDAJITInfo& cuda_jit_info(); void VerifySig(); }; diff --git a/paddle/cinn/hlir/framework/CMakeLists.txt b/paddle/cinn/hlir/framework/CMakeLists.txt index 1aa6817164a1cf..c353eb3810ff89 100755 --- a/paddle/cinn/hlir/framework/CMakeLists.txt +++ b/paddle/cinn/hlir/framework/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(new_ir) +add_subdirectory(pir) core_gather_headers() gather_srcs( @@ -24,11 +24,10 @@ gather_srcs( visualize_helper.cc compile_error.cc) -# TODO(Aurelius84): new_ir_compiler depends on pd_op_dialect and could +# TODO(Aurelius84): pir_compiler depends on pd_op_dialect and could # not found under CINN_ONLY mode if(NOT CINN_ONLY) - cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi - pd_op_dialect) + cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi pd_op_dialect) endif() if(WITH_CUDA) diff --git a/paddle/cinn/hlir/framework/new_ir/CMakeLists.txt b/paddle/cinn/hlir/framework/new_ir/CMakeLists.txt deleted file mode 100755 index e08baf06dbd13f..00000000000000 --- a/paddle/cinn/hlir/framework/new_ir/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -if(NOT CINN_ONLY) - core_gather_headers() - gather_srcs(cinnapi_src SRCS utils.cc op_lowering_impl.cc) -endif() diff --git a/paddle/cinn/hlir/framework/new_ir/utils.cc b/paddle/cinn/hlir/framework/new_ir/utils.cc deleted file mode 100644 index 86cf0e187cc45a..00000000000000 --- a/paddle/cinn/hlir/framework/new_ir/utils.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/hlir/framework/new_ir/utils.h" - -namespace cinn { -namespace hlir { -namespace framework { -namespace newir { - -const std::unordered_map CompatibleInfo::OP_NAMES = { - {"pd_op.full", "fill_constant"}, {"pd_op.add", "elementwise_add"}}; - -std::string CompatibleInfo::OpName(const ::pir::Operation& op) { - std::string name = op.name(); - if (OP_NAMES.count(name)) { - return OP_NAMES.at(name); - } - auto pos = name.find("."); - if (pos == std::string::npos) { - return name; - } - auto cinn_op_name = name.substr(pos + 1); - VLOG(4) << "GetOpName: " << name << " -> " << cinn_op_name; - return cinn_op_name; -} - -std::string CompatibleInfo::ValueName(const ::pir::Value& value) { - return CompatibleInfo::kNamePrefix + - std::to_string(std::hash<::pir::Value>()(value)); -} - -std::string CompatibleInfo::OpFuncName(const ::pir::Operation& op) { - std::string op_name = OpName(op); - std::string func_name = - cinn::common::Context::Global().NewName("fn_" + op_name); - return func_name; -} - -std::string CompatibleInfo::GroupOpsName( - const std::vector<::pir::Operation*>& ops) { - std::string name = "fn"; - for (auto* op : ops) { - std::string op_name = OpName(*op); - name += "_" + cinn::common::Context::Global().NewName(op_name); - } - return name; -} - -std::vector CompatibleInfo::InputNames(const ::pir::Operation& op, - bool allow_duplicate) { - std::vector names; - std::unordered_set repeat; - for (int i = 0; i < op.num_operands(); ++i) { - auto value = op.operand_source(i); - std::string name = CompatibleInfo::ValueName(value); - if (!allow_duplicate && repeat.count(name)) { - continue; - } - repeat.insert(name); - names.push_back(name); - } - return names; -} - -std::vector CompatibleInfo::OutputNames(::pir::Operation& op) { - std::vector names; - for (int i = 0; i < op.num_results(); ++i) { - auto value = op.result(i); - std::string name = CompatibleInfo::ValueName(value); - names.push_back(std::move(name)); - } - return names; -} - -} // namespace newir -} // namespace framework -} // namespace hlir -} // namespace cinn diff --git a/paddle/cinn/hlir/framework/op_lowering.h b/paddle/cinn/hlir/framework/op_lowering.h index ac52aea80de714..8ae0d5869c1a4c 100644 --- a/paddle/cinn/hlir/framework/op_lowering.h +++ b/paddle/cinn/hlir/framework/op_lowering.h @@ -22,7 +22,7 @@ #include "paddle/cinn/hlir/framework/op_lowering_impl_base.h" #include "paddle/cinn/lang/packed_func.h" #ifndef CINN_WITH_ONLY -#include "paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h" #endif namespace cinn { @@ -65,13 +65,13 @@ inline OpLowerer CreateOpLowerer( } #ifndef CINN_WITH_ONLY -template +template OpLowerer CreateOpLowerer(const Target&); template <> -inline OpLowerer CreateOpLowerer(const Target& target) { - auto* impl_base = new newir::OpLowererImpl(target); - return OpLowerer(impl_base); +inline OpLowerer CreateOpLowerer(const Target& target) { + auto* impl_base = new pir::OpLowererImpl(target); + return OpLowerer(impl_base); } #endif diff --git a/paddle/cinn/hlir/framework/pir/CMakeLists.txt b/paddle/cinn/hlir/framework/pir/CMakeLists.txt new file mode 100755 index 00000000000000..775bed0d835493 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/CMakeLists.txt @@ -0,0 +1,4 @@ +if(NOT CINN_ONLY) + core_gather_headers() + gather_srcs(cinnapi_src SRCS utils.cc op_lowering_impl.cc op_mapper.cc) +endif() diff --git a/paddle/cinn/hlir/framework/new_ir/group.h b/paddle/cinn/hlir/framework/pir/group.h similarity index 94% rename from paddle/cinn/hlir/framework/new_ir/group.h rename to paddle/cinn/hlir/framework/pir/group.h index 1a67a02e58ca9a..cb6c23c4d1e59e 100644 --- a/paddle/cinn/hlir/framework/new_ir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -16,14 +16,14 @@ #include #include -#include "paddle/cinn/hlir/framework/new_ir/utils.h" #include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/pir/core/operation.h" namespace cinn { namespace hlir { namespace framework { -namespace newir { +namespace pir { using framework::OpPatternKind; // TODO(Aurelius84): Need to be replaced with CinnGroupOp @@ -53,7 +53,7 @@ struct Group { } }; -} // namespace newir +} // namespace pir } // namespace framework } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc similarity index 96% rename from paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc rename to paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index ac43f808e7303e..19b613aac4a244 100644 --- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h" #include @@ -22,9 +22,8 @@ #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" -#include "paddle/cinn/hlir/framework/new_ir/utils.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/lang/placeholder.h" -#include "paddle/cinn/utils/attribute_util.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/phi/core/ddim.h" @@ -33,7 +32,7 @@ PD_DECLARE_bool(cinn_use_cuda_vectorize); namespace cinn { namespace hlir { namespace framework { -namespace newir { +namespace pir { using cinn::hlir::op::ExternalApiRegistry; using common::Type; @@ -47,7 +46,7 @@ ir::Tensor GetTensor(const ::pir::Value& value) { auto dtype = type_info.dtype(); std::string input_id = CompatibleInfo::ValueName(value); return lang::CreatePlaceHolder( - in_shape, utils::ConvertIRType(dtype), input_id); + in_shape, CompatibleInfo::ConvertIRType(dtype), input_id); } std::vector CollectInputTensor( @@ -55,9 +54,8 @@ std::vector CollectInputTensor( std::vector* func_args, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) { std::vector tensors; - for (auto in_value : op->operands_source()) { + for (auto in_value : CompatibleInfo::RealOperandSources(*op)) { VLOG(4) << "input tensor name: " << CompatibleInfo::ValueName(in_value); - // NOTE(Aurelius84): Need always to create placeholder for input tensor. ir::Tensor tensor = details::GetTensor(in_value); if (!tensor_map->count(in_value)) { // record tensor. @@ -82,7 +80,7 @@ void CollectOutputInfo(::pir::Operation* op, auto type_info = out_value.type().dyn_cast(); - out_types->push_back(utils::ConvertIRType(type_info.dtype())); + out_types->push_back(CompatibleInfo::ConvertIRType(type_info.dtype())); auto out_shape = phi::vectorize(type_info.dims()); out_shapes->push_back(std::move(out_shape)); } @@ -91,7 +89,7 @@ void CollectOutputInfo(::pir::Operation* op, NodeAttr CollectAttrs(const ::pir::Operation& op) { NodeAttr node_attrs; VLOG(4) << "op.attributes():" << op.attributes().size(); - auto attrs = utils::ConvertAttributes(op.attributes()); + auto attrs = CompatibleInfo::ConvertAttributes(op); node_attrs.node_name = CompatibleInfo::OpName(op); node_attrs.attr_store = std::move(attrs); @@ -337,7 +335,6 @@ std::vector OpLowererImpl::LowerOps( const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name); auto op_impl = OpStrategy::SelectImpl(strategy[cinn_op]( node_attrs, op_func_arg_tensors, out_types, out_shapes, this->target_)); - // 2.Perform the lower process of Op std::vector funcs = DoOpLower(op_impl, op, tensor_map, &op_func_arg_tensors); @@ -384,14 +381,14 @@ std::vector OpLowererImpl::DoOpLower( for (int idx = 0; idx < pack.size() - 1; ++idx) { Expr expr = pack[idx]; // Insert the output tensor defined by Compute into the tensor_map - if (pack.size() - 1 > op_results.size()) { + if (pack.size() - 1 > op_results.size() && post == "") { // Some op may output multiple temp tensors in their Compute // definition, but only one output in the graph, and we use id + // "_0"/"_1" as key. // FIXME(Aurelius84): It seems that the implementation is relate with // string name. - // (*tensor_map)[op_results[0] + post] = expr.as_tensor_ref(); - // post = "_" + std::to_string(idx); + (*tensor_map)[op_results[idx]] = expr.as_tensor_ref(); + post = "_" + std::to_string(idx); } else { // If the number of output tensors defined by Compute is less equal than // the output node_data on the graph, then there is a one-to-one @@ -455,7 +452,7 @@ ir::Expr OpLowererImpl::DoOpSchedule( return expr_pack[0].operator ir::Expr(); } -} // namespace newir +} // namespace pir } // namespace framework } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h similarity index 98% rename from paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h rename to paddle/cinn/hlir/framework/pir/op_lowering_impl.h index 3fa859bbce880b..ead590526dd407 100644 --- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h @@ -19,9 +19,9 @@ #include "paddle/cinn/common/target.h" #include "paddle/cinn/hlir/framework/instruction.h" -#include "paddle/cinn/hlir/framework/new_ir/group.h" #include "paddle/cinn/hlir/framework/op_lowering_impl_base.h" #include "paddle/cinn/hlir/framework/op_strategy.h" +#include "paddle/cinn/hlir/framework/pir/group.h" #include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule/ir_schedule_util.h" @@ -36,7 +36,7 @@ namespace cinn { namespace hlir { namespace framework { -namespace newir { +namespace pir { using GroupPtr = std::shared_ptr; @@ -157,7 +157,7 @@ class OpLowererImpl : public OpLowererImplBase { Target target_; }; -} // namespace newir +} // namespace pir } // namespace framework } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/op_mapper.cc b/paddle/cinn/hlir/framework/pir/op_mapper.cc new file mode 100644 index 00000000000000..804a6664dc9445 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/op_mapper.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/op_mapper.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { + +namespace { + +void AppendAttrForReduceOp(const ::pir::Operation& op, + utils::AttributeMap& attrs) { // NOLINT + auto* source_op = + op.operand_source(/*dim_idx=*/1).dyn_cast<::pir::OpResult>().owner(); + CHECK(source_op->isa()); + auto dim_val = + paddle::dialect::GetInt64Vector(source_op->attributes().at("value")); + std::vector dim(dim_val.begin(), dim_val.end()); + attrs["dim"] = dim; +} + +} // namespace + +#define REGISTER_OPERAND_RULE(OP, args...) \ + operand_funcs_[paddle::dialect::OP::name()] = []() -> std::vector { \ + return {args}; \ + }; + +#define REGISTER_ATTR_RULE(OP, func) \ + attr_funcs_[paddle::dialect::OP::name()] = func; + +void OpMapper::RegisterMapRules() { + // max(x, dim) -> reduce_max(x) + REGISTER_OPERAND_RULE(MaxOp, 0); + REGISTER_OPERAND_RULE(SumOp, 0); + REGISTER_OPERAND_RULE(MinOp, 0); + REGISTER_OPERAND_RULE(ProdOp, 0); + REGISTER_ATTR_RULE(MaxOp, AppendAttrForReduceOp); + REGISTER_ATTR_RULE(SumOp, AppendAttrForReduceOp); + REGISTER_ATTR_RULE(MinOp, AppendAttrForReduceOp); + REGISTER_ATTR_RULE(ProdOp, AppendAttrForReduceOp); +} + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/op_mapper.h b/paddle/cinn/hlir/framework/pir/op_mapper.h new file mode 100644 index 00000000000000..0a0527cf9abf18 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/op_mapper.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/cinn/utils/type_defs.h" +#include "paddle/pir/core/operation.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { + +enum MapperType { + OPERAND, + ATTRIBUTE, +}; + +class OpMapper { + using OprandIndexsFunction = std::function()>; + using AppendAttrFunction = + std::function; // NOLINT + + public: + static OpMapper& Instance() { + static OpMapper instance; + return instance; + } + + bool has(const ::pir::Operation& op, MapperType type) const { + if (type == MapperType::OPERAND) { + return operand_funcs_.find(op.name()) != operand_funcs_.end(); + } else if (type == MapperType::ATTRIBUTE) { + return attr_funcs_.find(op.name()) != attr_funcs_.end(); + } + return false; + } + + std::vector<::pir::Value> RealOprandSources( + const ::pir::Operation& op) const { + CHECK(has(op, MapperType::OPERAND)) + << "Not register OprandIndexsFunction for " << op.name(); + std::vector<::pir::Value> inputs; + for (auto idx : operand_funcs_.at(op.name())()) { + inputs.push_back(op.operand_source(idx)); + } + return inputs; + } + + void AppendVariantAttrs(const ::pir::Operation& op, + utils::AttributeMap& attrs) const { // NOLINT + CHECK(has(op, MapperType::ATTRIBUTE)) + << "Not register AppendAttrFunction for " << op.name(); + attr_funcs_.at(op.name())(op, attrs); + } + + private: + OpMapper() { RegisterMapRules(); } + void RegisterMapRules(); + + std::unordered_map operand_funcs_; + std::unordered_map attr_funcs_; +}; + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/utils/attribute_util.h b/paddle/cinn/hlir/framework/pir/utils.cc similarity index 51% rename from paddle/cinn/utils/attribute_util.h rename to paddle/cinn/hlir/framework/pir/utils.cc index 474bc09e2c64c2..2f7b05c72fb302 100644 --- a/paddle/cinn/utils/attribute_util.h +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -12,24 +12,102 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once +#include "paddle/cinn/hlir/framework/pir/utils.h" + #include #include -#include "paddle/cinn/common/type.h" -#include "paddle/cinn/utils/type_defs.h" +#include "paddle/cinn/hlir/framework/pir/op_mapper.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/phi/common/data_type.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/builtin_type.h" namespace cinn { -namespace utils { +namespace hlir { +namespace framework { +namespace pir { + +const std::unordered_map CompatibleInfo::OP_NAMES = { + {"pd_op.full", "fill_constant"}, + {"pd_op.sum", "reduce_sum"}, + {"pd_op.max", "reduce_max"}, + {"pd_op.add", "elementwise_add"}}; + +std::string CompatibleInfo::OpName(const ::pir::Operation& op) { + std::string name = op.name(); + if (OP_NAMES.count(name)) { + return OP_NAMES.at(name); + } + auto pos = name.find("."); + if (pos == std::string::npos) { + return name; + } + auto cinn_op_name = name.substr(pos + 1); + VLOG(4) << "GetOpName: " << name << " -> " << cinn_op_name; + return cinn_op_name; +} + +std::string CompatibleInfo::ValueName(const ::pir::Value& value) { + return CompatibleInfo::kNamePrefix + + std::to_string(std::hash<::pir::Value>()(value)); +} + +std::string CompatibleInfo::OpFuncName(const ::pir::Operation& op) { + std::string op_name = OpName(op); + std::string func_name = + cinn::common::Context::Global().NewName("fn_" + op_name); + return func_name; +} + +std::string CompatibleInfo::GroupOpsName( + const std::vector<::pir::Operation*>& ops) { + std::string name = "fn"; + for (auto* op : ops) { + std::string op_name = OpName(*op); + name += "_" + cinn::common::Context::Global().NewName(op_name); + } + return name; +} -using NewIR_AttributeMap = std::unordered_map; +std::vector CompatibleInfo::InputNames(const ::pir::Operation& op, + bool allow_duplicate) { + std::vector names; + std::unordered_set repeat; + for (int i = 0; i < op.num_operands(); ++i) { + auto value = op.operand_source(i); + std::string name = CompatibleInfo::ValueName(value); + if (!allow_duplicate && repeat.count(name)) { + continue; + } + repeat.insert(name); + names.push_back(name); + } + return names; +} + +std::vector CompatibleInfo::OutputNames(::pir::Operation& op) { + std::vector names; + for (int i = 0; i < op.num_results(); ++i) { + auto value = op.result(i); + std::string name = CompatibleInfo::ValueName(value); + names.push_back(std::move(name)); + } + return names; +} -Attribute ConvertAttribute(const ::pir::Attribute& src_attr) { - Attribute dst_attr; +std::vector<::pir::Value> CompatibleInfo::RealOperandSources( + const ::pir::Operation& op) { + if (OpMapper::Instance().has(op, MapperType::OPERAND)) { + return OpMapper::Instance().RealOprandSources(op); + } else { + return op.operands_source(); + } +} + +utils::Attribute CompatibleInfo::ConvertAttribute( + const ::pir::Attribute& src_attr) { + utils::Attribute dst_attr; if (src_attr.isa<::pir::BoolAttribute>()) { dst_attr = src_attr.dyn_cast<::pir::BoolAttribute>().data(); } else if (src_attr.isa<::pir::FloatAttribute>()) { @@ -58,8 +136,10 @@ Attribute ConvertAttribute(const ::pir::Attribute& src_attr) { return dst_attr; } -AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) { - AttributeMap dst_attrs; +utils::AttributeMap CompatibleInfo::ConvertAttributes( + const ::pir::Operation& op) { + auto& src_attrs = op.attributes(); + utils::AttributeMap dst_attrs; for (auto& item : src_attrs) { VLOG(4) << "deal with " << item.first; if (item.first == ::pir::kStopGradientAttrName) { @@ -73,6 +153,10 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) { dst_attrs[item.first] = std::move(ConvertAttribute(item.second)); } } + + if (OpMapper::Instance().has(op, MapperType::ATTRIBUTE)) { + OpMapper::Instance().AppendVariantAttrs(op, dst_attrs); + } VLOG(4) << "dst_attrs.size(): " << dst_attrs.size(); return dst_attrs; } @@ -80,7 +164,7 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) { #define CASE_TYPE(src, dst) \ else if (type.isa<::pir::src>()) return common::dst(); -common::Type ConvertIRType(::pir::Type type) { +common::Type CompatibleInfo::ConvertIRType(::pir::Type type) { if (type.isa<::pir::BFloat16Type>()) return common::BF16(); CASE_TYPE(Float16Type, F16) CASE_TYPE(Float32Type, F32) @@ -96,5 +180,7 @@ common::Type ConvertIRType(::pir::Type type) { LOG(FATAL) << "unknown ir::Type " << type; } -} // namespace utils +} // namespace pir +} // namespace framework +} // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/utils.h b/paddle/cinn/hlir/framework/pir/utils.h similarity index 80% rename from paddle/cinn/hlir/framework/new_ir/utils.h rename to paddle/cinn/hlir/framework/pir/utils.h index 755f11fcae2206..0df3666f6be71b 100644 --- a/paddle/cinn/hlir/framework/new_ir/utils.h +++ b/paddle/cinn/hlir/framework/pir/utils.h @@ -16,17 +16,20 @@ #include #include #include "paddle/cinn/common/context.h" +#include "paddle/cinn/common/type.h" +#include "paddle/cinn/utils/type_defs.h" #include "paddle/pir/core/operation.h" namespace cinn { namespace hlir { namespace framework { -namespace newir { +namespace pir { struct CUDAJITInfo { void* fn_ptr; std::vector block_dims; std::vector grid_dims; + void* compiler; }; struct CompatibleInfo { @@ -47,9 +50,18 @@ struct CompatibleInfo { bool allow_duplicate = false); static std::vector OutputNames(::pir::Operation& op); // NOLINT + + static std::vector<::pir::Value> RealOperandSources( + const ::pir::Operation& op); + + static utils::Attribute ConvertAttribute(const ::pir::Attribute& src_attr); + + static utils::AttributeMap ConvertAttributes(const ::pir::Operation& op); + + static common::Type ConvertIRType(::pir::Type type); }; -} // namespace newir +} // namespace pir } // namespace framework } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.cc b/paddle/cinn/hlir/framework/pir_compiler.cc similarity index 85% rename from paddle/cinn/hlir/framework/new_ir_compiler.cc rename to paddle/cinn/hlir/framework/pir_compiler.cc index fbc4c58a5ed9a9..df037a08568d58 100644 --- a/paddle/cinn/hlir/framework/new_ir_compiler.cc +++ b/paddle/cinn/hlir/framework/pir_compiler.cc @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/hlir/framework/new_ir_compiler.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" #include -#include "paddle/cinn/hlir/framework/new_ir/utils.h" -#include "paddle/cinn/utils/attribute_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/pir/core/builtin_type.h" @@ -26,24 +25,24 @@ namespace framework { // TODO(Aurelius84): Need abstract this logic to implement Proxy for // the co-existance with GraphCompiler. -std::unique_ptr NewIRCompiler::Build() { +std::unique_ptr PIRCompiler::Build() { m_builder_.Clear(); // NOTE(Aurelius84): Currently only support each op for one group - std::vector groups; + std::vector groups; for (auto it = program_.block()->begin(); it != program_.block()->end(); ++it) { std::vector<::pir::Operation*> ops = {*it}; - groups.push_back(std::make_shared(ops)); + groups.push_back(std::make_shared(ops)); } VLOG(4) << "Groups size: " << groups.size(); return std::move(Build(groups)); } -std::vector NewIRCompiler::BuildCUDAJITInfo( - const std::vector& groups) { - std::vector vec_res; +std::vector PIRCompiler::BuildCUDAJITInfo( + const std::vector& groups) { + std::vector vec_res; - auto op_lowerer = CreateOpLowerer(target_); + auto op_lowerer = CreateOpLowerer(target_); std::vector> lowered_funcs; for (int i = 0; i < groups.size(); ++i) { @@ -62,9 +61,11 @@ std::vector NewIRCompiler::BuildCUDAJITInfo( auto fn_ptrs = compiler_->GetFnPtr(); + auto* compilter_ptr = compiler_.release(); for (int idx = 0; idx < groups.size(); ++idx) { - newir::CUDAJITInfo jit_info; + pir::CUDAJITInfo jit_info; jit_info.fn_ptr = fn_ptrs[idx]; + jit_info.compiler = reinterpret_cast(compilter_ptr); lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo( &(jit_info.block_dims)); @@ -77,9 +78,9 @@ std::vector NewIRCompiler::BuildCUDAJITInfo( return vec_res; } -std::unique_ptr NewIRCompiler::Build( - const std::vector& groups) { - auto op_lowerer = CreateOpLowerer(target_); +std::unique_ptr PIRCompiler::Build( + const std::vector& groups) { + auto op_lowerer = CreateOpLowerer(target_); std::vector> lowered_funcs; for (int i = 0; i < groups.size(); ++i) { @@ -110,7 +111,7 @@ std::unique_ptr NewIRCompiler::Build( return std::make_unique(scope_, std::move(instructions)); } -void NewIRCompiler::ProcessFunction( +void PIRCompiler::ProcessFunction( const std::vector& lowered_funcs) { for (auto&& func : lowered_funcs) { for (auto&& arg : func->args) { @@ -135,8 +136,8 @@ void NewIRCompiler::ProcessFunction( } } -std::vector> NewIRCompiler::BuildInstructions( - const std::vector& groups) { +std::vector> PIRCompiler::BuildInstructions( + const std::vector& groups) { std::vector> instructions; for (int idx = 0; idx < groups.size(); ++idx) { auto& fn_name = groups[idx]->fn_name; @@ -168,7 +169,7 @@ std::shared_ptr BuildScope(const Target& target, if (visited.count(value) > 0) return; visited.emplace(value); - std::string name = newir::CompatibleInfo::ValueName(value); + std::string name = pir::CompatibleInfo::ValueName(value); auto type_info = value.type().dyn_cast(); auto* var = scope->Var(name); auto& tensor = absl::get(*var); @@ -178,7 +179,7 @@ std::shared_ptr BuildScope(const Target& target, shape.push_back(Shape::dim_t(type_info.dims()[i])); } tensor->Resize(Shape{shape}); - tensor->set_type(utils::ConvertIRType(type_info.dtype())); + tensor->set_type(pir::CompatibleInfo::ConvertIRType(type_info.dtype())); }; for (auto it = program.block()->begin(); it != program.block()->end(); ++it) { diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.h b/paddle/cinn/hlir/framework/pir_compiler.h similarity index 80% rename from paddle/cinn/hlir/framework/new_ir_compiler.h rename to paddle/cinn/hlir/framework/pir_compiler.h index 44d92ad1386bf0..c567ec2c44eb29 100644 --- a/paddle/cinn/hlir/framework/new_ir_compiler.h +++ b/paddle/cinn/hlir/framework/pir_compiler.h @@ -28,11 +28,11 @@ namespace framework { // TODO(Aurelius84): Need abstract this logic to implement Proxy for // the co-existance with GraphCompiler. -class NewIRCompiler final { +class PIRCompiler final { public: - NewIRCompiler(const ::pir::Program& prog, - const Target& target, - const std::shared_ptr& scope) + PIRCompiler(const ::pir::Program& prog, + const Target& target, + const std::shared_ptr& scope) : program_(prog), m_builder_("NewIR", target), target_(target), @@ -40,20 +40,20 @@ class NewIRCompiler final { std::unique_ptr Build(); - std::vector BuildCUDAJITInfo( - const std::vector& groups); + std::vector BuildCUDAJITInfo( + const std::vector& groups); - std::unique_ptr Build(const std::vector& groups); + std::unique_ptr Build(const std::vector& groups); private: - CINN_DISALLOW_COPY_AND_ASSIGN(NewIRCompiler); + CINN_DISALLOW_COPY_AND_ASSIGN(PIRCompiler); std::vector GetOpFunc(const ::pir::Operation& op, int idx); void ProcessFunction(const std::vector& lowered_funcs); std::vector> BuildInstructions( - const std::vector& groups); + const std::vector& groups); const ::pir::Program& program_; ir::Module::Builder m_builder_; diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index 215f55f2d18839..a2267e1f6cebdd 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -18,7 +18,7 @@ endif() if(WITH_NCCL OR WITH_RCCL) cc_library( process_group_nccl - SRCS process_group_nccl.cc nccl_tools.cc common.cc + SRCS process_group_nccl.cc common.cc DEPS process_group phi place diff --git a/paddle/fluid/distributed/collective/bkcl_tools.cc b/paddle/fluid/distributed/collective/bkcl_tools.cc index ba5afbbf1feb55..7e95eb8b748eb6 100644 --- a/paddle/fluid/distributed/collective/bkcl_tools.cc +++ b/paddle/fluid/distributed/collective/bkcl_tools.cc @@ -14,8 +14,6 @@ #include "paddle/fluid/distributed/collective/bkcl_tools.h" -#include "paddle/fluid/distributed/collective/types.h" - namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/collective/bkcl_tools.h b/paddle/fluid/distributed/collective/bkcl_tools.h index 533498cd8e119c..19d321080d47af 100644 --- a/paddle/fluid/distributed/collective/bkcl_tools.h +++ b/paddle/fluid/distributed/collective/bkcl_tools.h @@ -14,14 +14,15 @@ #pragma once -#include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/distributed/types.h" namespace paddle { namespace distributed { using XPUContext = phi::XPUContext; +using phi::distributed::ReduceOp; #define BKCLCHECK(cmd) \ do { \ diff --git a/paddle/fluid/distributed/collective/custom_ccl_tools.cc b/paddle/fluid/distributed/collective/custom_ccl_tools.cc index ccafcf12a6c26f..15e8b680b7805f 100644 --- a/paddle/fluid/distributed/collective/custom_ccl_tools.cc +++ b/paddle/fluid/distributed/collective/custom_ccl_tools.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/distributed/collective/custom_ccl_tools.h" -#include "paddle/fluid/distributed/collective/types.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/collective/custom_ccl_tools.h b/paddle/fluid/distributed/collective/custom_ccl_tools.h index 95557079a8252d..4fb336e929065f 100644 --- a/paddle/fluid/distributed/collective/custom_ccl_tools.h +++ b/paddle/fluid/distributed/collective/custom_ccl_tools.h @@ -22,7 +22,6 @@ #include -#include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/collective_helper.h" @@ -30,10 +29,13 @@ #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/backends/device_guard.h" #include "paddle/phi/backends/device_manager.h" +#include "paddle/phi/core/distributed/types.h" namespace paddle { namespace distributed { +using phi::distributed::ReduceOp; + phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction); } // namespace distributed diff --git a/paddle/fluid/distributed/collective/nccl_tools.h b/paddle/fluid/distributed/collective/nccl_tools.h deleted file mode 100644 index 135aadd2a24145..00000000000000 --- a/paddle/fluid/distributed/collective/nccl_tools.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/distributed/collective/types.h" - -#ifdef PADDLE_WITH_RCCL -#include -#include "paddle/phi/backends/dynload/rccl.h" -#else -#include -#include "paddle/phi/backends/dynload/nccl.h" -#endif - -namespace paddle { -namespace distributed { - -#define NCCL_CHECK(cmd) \ - do { \ - ncclResult_t r = cmd; \ - if (r != ncclSuccess) { \ - printf("Failed, NCCL error %s:%d '%s'\n", \ - __FILE__, \ - __LINE__, \ - phi::dynload::ncclGetErrorString(r)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - -ncclRedOp_t ToNCCLRedType(ReduceOp reduction); - -std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID); - -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h index e643348eeed0de..8767dfa60cf181 100644 --- a/paddle/fluid/distributed/collective/process_group.h +++ b/paddle/fluid/distributed/collective/process_group.h @@ -20,10 +20,10 @@ #include #include -#include "paddle/fluid/distributed/collective/types.h" -#include "paddle/fluid/eager/utils.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/types.h" +#include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" @@ -32,24 +32,18 @@ constexpr auto kWaitTimeout = std::chrono::milliseconds(0); namespace paddle { namespace distributed { +using phi::distributed::AllreduceOptions; +using phi::distributed::BarrierOptions; +using phi::distributed::BroadcastOptions; +using phi::distributed::CommType; +using phi::distributed::GatherOptions; +using phi::distributed::GetPartialTensor; +using phi::distributed::ReduceOp; +using phi::distributed::ReduceOptions; +using phi::distributed::ReduceScatterOptions; +using phi::distributed::ScatterOptions; constexpr int kIgnoreId = -1; -enum class CommType : std::uint8_t { - BROADCAST = 0, - ALLREDUCE = 1, - ALLREDUCE_SPARSE = 2, // TODO(shenliang03): to support sparse in allreduce - REDUCE = 3, - ALLGATHER = 4, - GATHER = 5, - SCATTER = 6, - REDUCE_SCATTER = 7, - ALLTOALL = 8, - SEND = 9, - RECV = 10, - BARRIER = 11, - UNKNOWN = 100, -}; - class ProcessGroup { public: class Task { @@ -95,6 +89,15 @@ class ProcessGroup { int GetSize() const { return size_; } + int GetGid() const { return gid_; } + + std::string GetGroupMessage() const { + return std::string("rank_in_group: ") + std::to_string(rank_) + + std::string(", nranks: ") + std::to_string(size_) + + std::string(", gid: ") + std::to_string(gid_) + + std::string(", backend: ") + GetBackendName(); + } + virtual std::string GetBackendName() const = 0; virtual phi::DeviceContext* GetDeviceContext( @@ -294,7 +297,7 @@ class ProcessGroup { const phi::DenseTensor& in_tensor UNUSED, const BroadcastOptions& opts UNUSED, bool sync_op UNUSED, - bool use_calc_stream UNUSED) { + bool use_calc_stream) { PADDLE_THROW( phi::errors::Unimplemented("ProcessGroup%s does not support broadcast " "with sync_op and use_calc_stream flag.", @@ -412,68 +415,57 @@ class ProcessGroup { // legacy APIs // TODO(liyurui): This API will be moved later virtual std::shared_ptr AllReduce( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const AllreduceOptions& UNUSED = AllreduceOptions()) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support allreduce", GetBackendName())); + std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + const AllreduceOptions& options = AllreduceOptions()) { + return AllReduce(outputs.data(), inputs.front(), options, false); } virtual std::shared_ptr AllReduce( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const AllreduceOptions& UNUSED, - bool) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support allreduce with sync_op flag", - GetBackendName())); + std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + const AllreduceOptions& options, + bool sync_op) { + return AllReduce(outputs.data(), inputs.front(), options, sync_op); } // TODO(sunyilun): methods below will be removed later virtual std::shared_ptr Broadcast( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const BroadcastOptions& UNUSED = BroadcastOptions()) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support broadcast", GetBackendName())); + std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + const BroadcastOptions& options = BroadcastOptions()) { + return Broadcast(outputs.data(), inputs.front(), options, false); } virtual std::shared_ptr Broadcast( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const BroadcastOptions& UNUSED, - bool) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support broadcast with sync_op flag", - GetBackendName())); + std::vector& inputs, // NOLINT + std::vector& outputs, // NOLINT + const BroadcastOptions& options, + bool sync_op) { + return Broadcast(outputs.data(), inputs.front(), options, sync_op); } virtual std::shared_ptr Send( - std::vector&, int) { // NOLINT - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support send", GetBackendName())); + std::vector& tensors, int dst_rank) { // NOLINT + return Send(tensors.front(), dst_rank, false); } virtual std::shared_ptr Recv( - std::vector&, int) { // NOLINT - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support recv", GetBackendName())); + std::vector& tensors, int src_rank) { // NOLINT + return Recv(&tensors.front(), src_rank, false); } virtual std::shared_ptr AllGather( - std::vector&, // NOLINT - std::vector&) { // NOLINT - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support all_gather", GetBackendName())); + std::vector& in_tensors, // NOLINT + std::vector& out_tensors) { // NOLINT + return AllGather(out_tensors.data(), in_tensors.front(), false); } virtual std::shared_ptr AllGather( - std::vector&, // NOLINT - std::vector&, // NOLINT - bool) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support all_gather with sync_op flag", - GetBackendName())); + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + bool sync_op) { + return AllGather(out_tensors.data(), in_tensors.front(), sync_op); } virtual std::shared_ptr AllToAll( @@ -484,19 +476,17 @@ class ProcessGroup { } virtual std::shared_ptr Reduce( - std::vector&, // NOLINT - std::vector&, // NOLINT - const ReduceOptions& opts UNUSED) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support reduce", GetBackendName())); + std::vector& ins, // NOLINT + std::vector& outs, // NOLINT + const ReduceOptions& opts) { + return Reduce(outs.data(), ins.front(), opts, false); } virtual std::shared_ptr Scatter( - std::vector&, // NOLINT - std::vector&, // NOLINT - const ScatterOptions&) { - PADDLE_THROW(phi::errors::InvalidArgument( - "ProcessGroup%s does not support scatter", GetBackendName())); + std::vector& ins, // NOLINT + std::vector& outs, // NOLINT + const ScatterOptions& opts) { + return Scatter(outs.data(), ins.front(), opts, false); } protected: diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index 4331041c4f043e..81f52bc97f3342 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/distributed/collective/bkcl_tools.h" #include "paddle/fluid/distributed/collective/common.h" -#include "paddle/fluid/distributed/collective/utils.h" +#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h" #include "paddle/phi/api/lib/utils/allocator.h" diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index 64dce7b4c6b116..1313d19a2bbfa3 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -16,7 +16,6 @@ #include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/distributed/collective/custom_ccl_tools.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/core/distributed/check/static_check.h" #include "paddle/phi/core/enforce.h" @@ -32,6 +31,8 @@ PD_DECLARE_bool(use_stream_safe_cuda_allocator); namespace paddle { namespace distributed { +using phi::distributed::CheckSizeOnEachRank; +using phi::distributed::GetPointerByOffset; static std::mutex g_unfinished_xccl_task_events_mutex; static std::list> g_unfinished_xccl_task_events; diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index 13970b2e349a0e..a3fb060376597a 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -22,6 +22,7 @@ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h" +#include "paddle/phi/backends/custom/custom_context.h" #include "paddle/phi/backends/device_manager.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 89f5dcb222e63b..8877224eb7674d 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -15,21 +15,27 @@ #include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/distributed/collective/common.h" -#include "paddle/fluid/distributed/collective/nccl_tools.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" #include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/comm_task_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_task.h" +#include "paddle/phi/core/distributed/nccl_tools.h" +#include "paddle/phi/core/distributed/trace_utils.h" +#include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/utils/data_type.h" -#include "paddle/phi/core/distributed/comm_context_manager.h" - +PHI_DECLARE_bool(benchmark); +PHI_DECLARE_bool(benchmark_nccl); PHI_DECLARE_bool(nccl_blocking_wait); -PD_DECLARE_bool(use_stream_safe_cuda_allocator); +PHI_DECLARE_bool(use_stream_safe_cuda_allocator); +PHI_DECLARE_bool(enable_async_trace); // set this flag to `true` and recompile to enable dynamic checks constexpr bool FLAGS_enable_nccl_dynamic_check = false; @@ -38,6 +44,17 @@ constexpr int64_t kWaitBlockTImeout = 10; namespace paddle { namespace distributed { +using phi::distributed::CheckSizeOnEachRank; +using phi::distributed::GetTraceEndKey; +using phi::distributed::GetTraceStartKey; +using phi::distributed::IsP2POP; +using phi::distributed::NCCLDTypeToString; +using phi::distributed::NCCLRedTypeToString; +using phi::distributed::SerializeNCCLUniqueId; +using phi::distributed::ToNCCLRedType; + +uint64_t ProcessGroupNCCL::s_group_call_counter = 0; + ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place, int rank, CommType comm_type, @@ -60,7 +77,7 @@ void ProcessGroupNCCL::NCCLTask::UpdateWaitChain( bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { // Warning here when use calc stream but also invoke waiting explicitly. if (UseCalcStream()) { - VLOG(3) << "Warning: The communication is on calc stream, wait here is " + VLOG(5) << "Warning: The communication is on calc stream, wait here is " "useless."; return true; } @@ -80,7 +97,7 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { // If we use the work to do barrier, we should block cpu #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); -#else +#else // PADDLE_WITH_HIP PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); #endif } @@ -94,20 +111,40 @@ ProcessGroupNCCL::ProcessGroupNCCL( const std::shared_ptr& store, int rank, int size, - int gid) - : ProcessGroupWithStream(rank, size, gid), store_(store) {} + int gid, + int64_t timeout) + : ProcessGroupWithStream(rank, size, gid), + store_(store), + pg_timeout_(timeout) { + LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_; +} void ProcessGroupNCCL::GroupStart() { NCCL_CHECK(phi::dynload::ncclGroupStart()); + ++s_group_call_counter; } -void ProcessGroupNCCL::GroupEnd() { NCCL_CHECK(phi::dynload::ncclGroupEnd()); } +void ProcessGroupNCCL::GroupEnd() { + NCCL_CHECK(phi::dynload::ncclGroupEnd()); + --s_group_call_counter; + // NOTE: This is to sync the calc stream and comm stream for debug using + // batch_isend_irecv + if (FLAGS_benchmark || FLAGS_benchmark_nccl) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else // PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + } +} phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext( const Place& place) const { return GetDeviceContext(place, /*use_calc_stream*/ false); } +// NOTE(shenliang03): GetDeviceContext is only used for collective, it can't +// be used for p2p op. phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext( const Place& place, bool use_calc_stream) const { const std::string& key = GetKeyFromPlace(place); @@ -146,9 +183,21 @@ std::shared_ptr ProcessGroupNCCL::AllGather( // numel > 0 indicates the tensor need to be sliced const phi::DenseTensor& in_tensor_maybe_partial = numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor; - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { + VLOG(3) << "[ncclAllGather] " + << "sendbuff: " << in_tensor_maybe_partial.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor_maybe_partial.numel() + << ", datatype: " + << NCCLDTypeToString( + phi::ToNCCLDataType(in_tensor_maybe_partial.dtype())) + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", offset: " << offset + << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream); }, in_tensor_maybe_partial, @@ -163,9 +212,21 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( const AllreduceOptions& opts, bool sync_op, bool use_calc_stream) { - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { + VLOG(3) << "[ncclAllReduce] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", redop: " + << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op)) + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + comm_context->AllReduce( out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream); }, @@ -191,9 +252,15 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( // simply be covered by static checks. Factors are set to 0 here to skip the // shape check. Its shape check will be done by dynamic checks with // FLAGS_enable_nccl_dynamic_check. - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + phi::distributed::CommStaticCheck::CheckShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + /*out_size_factor*/ 0, + /*in_size_factor*/ 0); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckShape( *out_tensor, @@ -203,13 +270,27 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( size_, comm_context->GetNcclComm()); } - int64_t in_row_size = in_tensor.numel() / in_dim[0], out_row_size = out_tensor->numel() / out_dim[0]; int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; phi::DenseTensor input_partial, output_partial; - comm_context->GroupStart(); + VLOG(3) << "[AllToAll] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", out_size_each_rank: " + << string::join_strings(out_size_each_rank, ',') + << ", in_size_each_rank: " + << string::join_strings(in_size_each_rank, ',') + << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + + GroupStart(); for (auto i = 0; i < size_; i++) { in_numel = in_size_each_rank[i] * in_row_size; input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); @@ -221,7 +302,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( comm_context->Recv(&output_partial, out_numel, i, stream); out_offset += out_numel; } - comm_context->GroupEnd(); + GroupEnd(); }, in_tensor, CommType::ALLTOALL, @@ -241,6 +322,9 @@ std::shared_ptr ProcessGroupNCCL::Barrier( phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1}); phi::DenseTensor barrier_tensor{allocator.get(), meta}; + VLOG(3) << "[Barrier] " + << "barrier opt: " << opts.device_id; + auto task = AllReduce(&barrier_tensor, barrier_tensor, {}, @@ -257,10 +341,21 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( const BroadcastOptions& opts, bool sync_op, bool use_calc_stream) { - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { int root = opts.source_rank + opts.source_root; - auto comm_context = this->GetCommContext(); + + VLOG(3) << "[ncclBroadcast] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", root: " << root + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); comm_context->Broadcast(out_tensor, in_tensor, root, stream); }, in_tensor, @@ -275,9 +370,21 @@ std::shared_ptr ProcessGroupNCCL::Reduce( const ReduceOptions& opts, bool sync_op, bool use_calc_stream) { - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { + VLOG(3) << "[ncclReduce] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", redop: " + << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op)) + << ", root: " << opts.root_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); comm_context->Reduce(out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), @@ -296,9 +403,20 @@ std::shared_ptr ProcessGroupNCCL::ReduceScatter( const ReduceScatterOptions& opts, bool sync_op, bool use_calc_stream) { - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { + VLOG(3) << "[ncclReduceScatter] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", redop: " + << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op)) + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); comm_context->ReduceScatter( out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream); }, @@ -320,9 +438,8 @@ std::shared_ptr ProcessGroupNCCL::Scatter( /*dst_rank*/ opts.root_rank, /*cur_rank*/ rank_, size_); - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Collective( + [&](phi::distributed::NCCLCommContext* comm_context, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckShape( *out_tensor, @@ -331,18 +448,30 @@ std::shared_ptr ProcessGroupNCCL::Scatter( comm_context->GetNcclComm()); } + VLOG(3) << "[Scatter] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", root: " << opts.root_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + int64_t numel = in_tensor.numel() / size_; if (rank_ == opts.root_rank) { int64_t offset = 0; phi::DenseTensor partial_tensor; - comm_context->GroupStart(); + GroupStart(); for (auto i = 0; i < size_; i++) { partial_tensor = GetPartialTensor(in_tensor, offset, numel); comm_context->Send(partial_tensor, numel, i, stream); offset += numel; } comm_context->Recv(out_tensor, numel, opts.root_rank, stream); - comm_context->GroupEnd(); + GroupEnd(); } else { comm_context->Recv(out_tensor, numel, opts.root_rank, stream); } @@ -385,8 +514,8 @@ std::shared_ptr ProcessGroupNCCL::Gather( "root world size [%d] is less than root rank [%d]", size_, opts.root_rank)); - auto gather_func = [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + auto gather_func = [&](phi::distributed::NCCLCommContext* comm_context, + gpuStream_t stream) { // shape check if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckGatherShape( @@ -398,7 +527,17 @@ std::shared_ptr ProcessGroupNCCL::Gather( comm_context->GetNcclComm()); } - comm_context->GroupStart(); + VLOG(3) << "[Gather] " + << "sendbuff: " << in_tensor.data() + << ", count: " << in_tensor.numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype())) + << ", root: " << opts.root_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << GetGroupMessage(); + + GroupStart(); // root receive from all devices if (rank_ == opts.root_rank) { for (auto i = 0; i < size_; i++) { @@ -408,9 +547,9 @@ std::shared_ptr ProcessGroupNCCL::Gather( } // send to root comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream); - comm_context->GroupEnd(); + GroupEnd(); }; - return RunFnInNCCLEnv( + return Collective( gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream); } @@ -428,11 +567,25 @@ std::shared_ptr ProcessGroupNCCL::Recv( tensor = &partial_tensor; } - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); - comm_context->Recv(tensor, tensor->numel(), src_rank, stream); + return Point2Point( + [&](phi::distributed::NCCLCommContext* comm_context, + gpuStream_t stream, + int rank_in_group) { + VLOG(3) << "[ncclRecv] " + << "recvbuff: " << tensor->data() + << ", count: " << tensor->numel() << ", datatype: " + << NCCLDTypeToString(phi::ToNCCLDataType(tensor->dtype())) + << ", src_in_group: " << src_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream + << ", rank_in_group: " << rank_in_group << ", nranks: " << size_ + << ", offset: " << offset << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + + comm_context->Recv(tensor, tensor->numel(), rank_in_group, stream); }, + src_rank, *tensor, CommType::RECV, sync_op, @@ -450,14 +603,29 @@ std::shared_ptr ProcessGroupNCCL::Send( const phi::DenseTensor& tensor_maybe_partial = numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor; - return RunFnInNCCLEnv( - [&](gpuStream_t stream) { - auto comm_context = this->GetCommContext(); + return Point2Point( + [&](phi::distributed::NCCLCommContext* comm_context, + gpuStream_t stream, + int rank_in_group) { + VLOG(3) << "[ncclSend] " + << "sendbuff: " << tensor_maybe_partial.data() + << ", count: " << tensor_maybe_partial.numel() << ", datatype: " + << NCCLDTypeToString( + phi::ToNCCLDataType(tensor_maybe_partial.dtype())) + << ", dst_in_group: " << dst_rank + << ", ncclcomm: " << comm_context->GetNcclComm() + << ", stream: " << stream + << ", rank_in_group: " << rank_in_group << ", nranks: " << size_ + << ", offset: " << offset << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream + << GetGroupMessage(); + comm_context->Send(tensor_maybe_partial, tensor_maybe_partial.numel(), - dst_rank, + rank_in_group, stream); }, + dst_rank, tensor_maybe_partial, CommType::SEND, sync_op, @@ -474,84 +642,133 @@ std::shared_ptr ProcessGroupNCCL::CreateTask( place, rank, comm_type, is_sync, use_calc_stream); } -void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) { - const std::string key = - "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0"; - if (rank_ == 0) { - std::vector nccl_id_wrapper( - reinterpret_cast(nccl_id), - reinterpret_cast(nccl_id) + NCCL_UNIQUE_ID_BYTES); - store_->set(key, nccl_id_wrapper); +void ProcessGroupNCCL::GetStoreKey(const std::string& place_key, + CommType comm_type, + std::string* store_key) { + bool is_batch_p2p = s_group_call_counter > 0; + bool is_p2p_op = IsP2POP(comm_type, is_batch_p2p); + + if (!is_p2p_op) { + *store_key = "nccl_ids/" + std::to_string(gid_) + "/0"; } else { - const auto& nccl_id_wrapper = store_->get(key); - std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size()); + *store_key = "nccl_ids/" + std::to_string(gid_) + "/" + place_key; } } void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, - const std::string& place_key) { - if (!place_to_comm_ctx_.empty()) { - VLOG(3) << "Warning: Tensors from multiple devices are not supported yet."; + const std::string& place_key, + const std::string& store_key, + CommType comm_type, + int p2p_rank) { + VLOG(3) << "init nccl rank_in_group: " << rank_ << ", nranks: " << size_ + << ", gid: " << gid_ << ", place key: " << place_key + << ", store_key: " << store_key; + + for (size_t i = 0; i < s_group_call_counter; ++i) { + NCCL_CHECK(phi::dynload::ncclGroupEnd()); } - VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ - << ", place: " << place_key; + bool is_batch_p2p = s_group_call_counter > 0; + bool is_p2p_op = IsP2POP(comm_type, is_batch_p2p); + + int num_ranks = is_p2p_op ? 2 : GetSize(); + int rank = is_p2p_op ? p2p_rank : GetRank(); + NCCL_CHECK(phi::dynload::ncclGroupStart()); + + phi::distributed::P2POption p2p_opts({is_p2p_op, p2p_rank, num_ranks, rank}); phi::distributed::CommContextManager::CreateNCCLCommContext( - store_, std::to_string(gid_), rank_, size_); + store_, store_key, rank_, size_, "", &p2p_opts); + + NCCL_CHECK(phi::dynload::ncclGroupEnd()); + + auto nccl_comm_ctx = this->GetCommContext(&store_key); + VLOG(3) << "Get nccl comm: " << nccl_comm_ctx->GetNcclComm() + << " for place_key: " << place_key << " on rank_in_group: " << rank + << " nranks: " << num_ranks << " gid: " << gid_; - auto* calc_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); auto comm_ctx = std::make_unique(place); - auto nccl_comm_ctx = this->GetCommContext(); comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm()); + auto* calc_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + place_to_calc_event_.emplace( place_key, platform::DeviceEvent(place, platform::GenerateDeviceEventFlag())); place_to_calc_ctx_.emplace(place_key, calc_ctx); place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx)); - // TODO(sunyilun): for compatibility, will be removed later - std::vector comm_ctx_wrapper{ - place_to_comm_ctx_[place_key].get()}; - places_to_ctx_.emplace(place_key, comm_ctx_wrapper); + for (size_t i = 0; i < s_group_call_counter; ++i) { + NCCL_CHECK(phi::dynload::ncclGroupStart()); + } } -void ProcessGroupNCCL::SyncCalcStream(const Place& place) { - const std::string& key = GetKeyFromPlace(place); - auto& calc_event = place_to_calc_event_.at(key); - const auto* calc_ctx = place_to_calc_ctx_.at(key); - const auto* comm_ctx = place_to_comm_ctx_.at(key).get(); +void ProcessGroupNCCL::SyncCalcStream(const Place& place, + const std::string& place_key) { + auto& calc_event = place_to_calc_event_.at(place_key); + const auto* calc_ctx = place_to_calc_ctx_.at(place_key); + const auto* comm_ctx = place_to_comm_ctx_.at(place_key).get(); calc_event.Record(calc_ctx); calc_event.Wait(platform::Place2DeviceType(place), comm_ctx); } -std::shared_ptr ProcessGroupNCCL::RunFnInNCCLEnv( - std::function fn, +std::shared_ptr ProcessGroupNCCL::Collective( + std::function fn, const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, bool use_calc_stream) { + comm_seq_++; const auto& place = tensor.place(); const auto& key = GetKeyFromPlace(place); platform::CUDADeviceGuard cuda_guard(place); + std::string store_key; + GetStoreKey(key, comm_type, &store_key); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { - CreateNCCLEnvCache(place, key); + CreateNCCLEnvCache(place, key, store_key, comm_type); } if (!use_calc_stream) { - SyncCalcStream(place); + SyncCalcStream(place, key); } auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); const auto* calc_ctx = place_to_calc_ctx_.at(key); const auto& comm_ctx = place_to_comm_ctx_.at(key); + auto nccl_comm = comm_ctx->nccl_comm(); auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); - fn(nccl_stream); + + auto nccl_comm_ctx = this->GetCommContext(&store_key); + + if (!FLAGS_enable_async_trace) { + fn(nccl_comm_ctx, nccl_stream); + } else { + auto comm_task = + std::make_shared(place, + rank_, + size_, + gid_, + comm_seq_, + tensor.numel(), + sync_op, + use_calc_stream, + nccl_comm, + nccl_stream, + comm_type, + pg_timeout_); + comm_task->StartRecord(); + fn(nccl_comm_ctx, nccl_stream); + comm_task->EndRecord(); + comm_task->SetStore(store_); + + auto& comm_task_manager = phi::distributed::CommTaskManager::GetInstance(); + comm_task_manager.CommTaskEnqueue(std::move(comm_task)); + } if (!use_calc_stream) { if (FLAGS_use_stream_safe_cuda_allocator) { @@ -564,443 +781,145 @@ std::shared_ptr ProcessGroupNCCL::RunFnInNCCLEnv( task->SetBlockCPUInWait(); task->Wait(); } - return task; -} -// TODO(sunyilun): methods below will be removed later -void SyncDefaultStream(const std::vector& places, - platform::DeviceEvent& nccl_event, // NOLINT - std::vector& dev_ctx) { // NOLINT - for (size_t i = 0; i < places.size(); ++i) { - auto* default_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(places[i])); - nccl_event.Record(default_ctx); - nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]); - } -} - -std::shared_ptr ProcessGroupNCCL::CreateTask( - std::vector places, - int rank, - CommType comm_type, - const std::vector& inputs) { - return std::make_shared( - places, rank, comm_type, inputs); -} - -ProcessGroupNCCL::NCCLTask::NCCLTask( - const std::vector& places, - int rank, - CommType CommType, - const std::vector& inputs) - : TaskStream(rank, inputs, CommType), - comm_event_(places[0], platform::GenerateDeviceEventFlag()), - task_place_(places[0]) {} - -// create NCCLManager cache for places_key -void ProcessGroupNCCL::CreateNCCLManagerCache( - const std::string& places_key, const std::vector& places) { - PADDLE_ENFORCE_EQ(places_key.empty(), - false, - phi::errors::PreconditionNotMet( - "Not able to create/get the NCCL Communicator since " - "the GPU place are not known")); - - ncclUniqueId nccl_id; - if (rank_ == 0) { - NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id)); + if (sync_op) { + task->Wait(); } - BroadcastUniqueNCCLID(&nccl_id); - - VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ - << ", place: " << places_key - << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id); - std::vector> dev_ctx; - dev_ctx.resize(places.size()); - - std::vector dev_ctx_raw; - dev_ctx_raw.resize(places.size()); - - GroupStart(); - - for (size_t i = 0; i < places.size(); ++i) { - platform::CUDADeviceGuard guard(places[i]); - - dev_ctx[i] = std::make_unique(places[i]); - ncclComm_t nccl_comm; - NCCL_CHECK(phi::dynload::ncclCommInitRank( - &nccl_comm, GetSize(), nccl_id, GetRank())); - dev_ctx[i]->set_nccl_comm(nccl_comm); - dev_ctx_raw[i] = dev_ctx[i].get(); + if (FLAGS_benchmark || FLAGS_benchmark_nccl) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else // PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif } - GroupEnd(); - - // TODO(sunyilun): for compatibility, will be removed later - place_to_calc_event_.emplace( - places_key, - platform::DeviceEvent(places[0], platform::GenerateDeviceEventFlag())); - place_to_calc_ctx_.emplace( - places_key, - static_cast( - platform::DeviceContextPool::Instance().Get(places[0]))); - place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0])); - - // These caches will be useful to process sync/wait/communicate - places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw)); + return task; } -template -std::shared_ptr ProcessGroupNCCL::Collective( - std::vector& inputs, - std::vector& outputs, - Fn fn, - CommType op_type) { - const auto places = GetPlaceList(inputs); - const auto key = GetKeyFromPlaces(places); - - { - std::lock_guard lock(mutex_); - if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { - CreateNCCLManagerCache(key, places); - } - } - - SyncDefaultStream( - places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); - - auto task = CreateTask(places, rank_, op_type, inputs); +std::shared_ptr ProcessGroupNCCL::Point2Point( + std::function + fn, + int peer, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream) { + const auto& place = tensor.place(); - // construct uninitialize guard for device - platform::CUDADeviceGuard cuda_guard; + int p2p_rank = 0; + int p2p_target_rank = 0; + bool is_batch_p2p = s_group_call_counter > 0; + std::string key = ""; - { - platform::NCCLGroupGuard nccl_guard; - for (size_t i = 0; i < inputs.size(); ++i) { - cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream(); - fn(inputs[i], - outputs[i], - places_to_ctx_.at(key)[i]->nccl_comm(), - nccl_stream); - } + if (is_batch_p2p) { + key = GetKeyFromPlace(place); + p2p_rank = rank_; + p2p_target_rank = peer; + } else { + int low_rank = rank_ < peer ? rank_ : peer; + int high_rank = rank_ < peer ? peer : rank_; + key = std::to_string(low_rank) + "->" + std::to_string(high_rank); + p2p_rank = rank_ < peer ? 0 : 1; + p2p_target_rank = 1 - p2p_rank; } - if (FLAGS_use_stream_safe_cuda_allocator) { - for (size_t i = 0; i < inputs.size(); ++i) { - cuda_guard.SetDevice(places[i]); - memory::RecordStream(inputs[i].Holder(), - places_to_ctx_.at(key)[i]->stream()); - } - } + platform::CUDADeviceGuard cuda_guard(place); + + std::string store_key; + GetStoreKey(key, comm_type, &store_key); - for (size_t i = 0; i < inputs.size(); ++i) { - cuda_guard.SetDevice(places[i]); - task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateNCCLEnvCache(place, key, store_key, comm_type, p2p_rank); } - return task; -} -template -std::shared_ptr ProcessGroupNCCL::PointToPoint( - std::vector& tensors, - Fn fn, - int dst_rank, - CommType op_type) { - const auto places = GetPlaceList(tensors); - const auto key = GetKeyFromPlaces(places); - - { - std::lock_guard lock(mutex_); - if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { - CreateNCCLManagerCache(key, places); - } + if (!use_calc_stream) { + SyncCalcStream(place, key); } - SyncDefaultStream( - places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); + auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto& comm_ctx = place_to_comm_ctx_.at(key); - auto task = CreateTask(places, rank_, op_type, tensors); + auto nccl_comm = comm_ctx->nccl_comm(); + auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); - // construct uninitialize guard for device - platform::CUDADeviceGuard cuda_guard; + auto comm_task = + std::make_shared(place, + rank_, + size_, + gid_, + comm_seq_, + tensor.numel(), + sync_op, + use_calc_stream, + nccl_comm, + nccl_stream, + comm_type); + + auto nccl_comm_ctx = this->GetCommContext(&store_key); + + if (!FLAGS_enable_async_trace) { + fn(nccl_comm_ctx, nccl_stream, p2p_target_rank); + } else { + comm_task->StartRecord(); + fn(nccl_comm_ctx, nccl_stream, p2p_target_rank); + comm_task->EndRecord(); + comm_task->SetStore(store_); - { - platform::NCCLGroupGuard nccl_guard; - for (size_t i = 0; i < tensors.size(); ++i) { - cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream(); - fn(tensors[i], - places_to_ctx_.at(key)[i]->nccl_comm(), - nccl_stream, - dst_rank); - } + auto& comm_task_manager = phi::distributed::CommTaskManager::GetInstance(); + comm_task_manager.CommTaskEnqueue(std::move(comm_task)); } - if (FLAGS_use_stream_safe_cuda_allocator) { - for (size_t i = 0; i < tensors.size(); ++i) { - cuda_guard.SetDevice(places[i]); - memory::RecordStream(tensors[i].Holder(), - places_to_ctx_.at(key)[i]->stream()); + if (!use_calc_stream) { + if (FLAGS_use_stream_safe_cuda_allocator) { + memory::RecordStream(tensor.Holder(), nccl_stream); } + task->UpdateWaitChain(*comm_ctx); } - for (size_t i = 0; i < tensors.size(); ++i) { - cuda_guard.SetDevice(places[i]); - task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); + if (FLAGS_enable_nccl_dynamic_check) { + task->SetBlockCPUInWait(); + task->Wait(); } - return task; -} - -std::shared_ptr ProcessGroupNCCL::AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - auto comm_context = this->GetCommContext(); - comm_context->AllReduce( - &output, input, ToNCCLRedType(opts.reduce_op), stream); - }, - CommType::ALLREDUCE); -} -std::shared_ptr ProcessGroupNCCL::Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - const auto root = - opts.source_rank * in_tensors.size() + opts.source_root; - auto comm_context = this->GetCommContext(); - comm_context->Broadcast(&output, input, root, stream); - }, - CommType::BROADCAST); -} - -void CheckTensorsInDifferentDevices( - const std::vector& tensors, const size_t num_devices) { - PADDLE_ENFORCE_EQ( - tensors.empty(), - false, - phi::errors::InvalidArgument("Tensor list must be nonempty.")); - PADDLE_ENFORCE_LE( - tensors.size(), - num_devices, - phi::errors::InvalidArgument( - "Tensor list mustn't be larger than the number of available GPUs.")); - - std::set used_devices; - - for (const auto& t : tensors) { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(t.place()), - true, - phi::errors::InvalidArgument("Tensors must be CUDA and dense tensor.")); - - const auto inserted = used_devices.insert(t.place()).second; - PADDLE_ENFORCE_EQ(inserted, - true, - phi::errors::InvalidArgument( - "Tensors must be on distinct GPU devices.")); + if (sync_op) { + task->Wait(); } -} -std::shared_ptr ProcessGroupNCCL::Send( - std::vector& tensors, int dst_rank) { - CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); - - auto task = PointToPoint( - tensors, - [&](phi::DenseTensor& input, - ncclComm_t comm, - const gpuStream_t& stream, - int dst_rank) { - auto comm_context = this->GetCommContext(); - comm_context->Send(input, input.numel(), dst_rank, stream); - }, - dst_rank, - CommType::SEND); - return task; -} + if (!is_batch_p2p && (FLAGS_benchmark || FLAGS_benchmark_nccl)) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else // PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + } -std::shared_ptr ProcessGroupNCCL::Recv( - std::vector& tensors, int src_rank) { - CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); - - auto task = PointToPoint( - tensors, - [&](phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream, - int src_rank) { - auto comm_context = this->GetCommContext(); - comm_context->Recv(&output, output.numel(), src_rank, stream); - }, - src_rank, - CommType::RECV); return task; } -std::shared_ptr ProcessGroupNCCL::AllGather( - std::vector& in_tensors, - std::vector& out_tensors) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - phi::errors::InvalidArgument("All outputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - auto comm_context = this->GetCommContext(); - comm_context->AllGather(&output, input, stream); - }, - CommType::ALLGATHER); -} - -std::shared_ptr ProcessGroupNCCL::AllToAll( - std::vector& in_tensors, - std::vector& out_tensors) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - size_t offset = 0; - size_t count = input.numel() / size_; - auto comm_context = this->GetCommContext(); - comm_context->GroupStart(); - for (auto i = 0; i < size_; i++) { - auto input_data = GetPartialTensor(input, offset, count); - comm_context->Send(input_data, count, i, stream); - auto output_data = GetPartialTensor(output, offset, count); - comm_context->Recv(&output_data, count, i, stream); - offset += count; - } - comm_context->GroupEnd(); - }, - CommType::ALLTOALL); -} - -std::shared_ptr ProcessGroupNCCL::Reduce( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - auto comm_context = this->GetCommContext(); - comm_context->Reduce(&output, - input, - ToNCCLRedType(opts.reduce_op), - opts.root_rank, - stream); - }, - CommType::REDUCE); -} - -std::shared_ptr ProcessGroupNCCL::Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - auto comm_context = this->GetCommContext(); - size_t offset = 0; - size_t count = input.numel() / size_; - if (rank_ == opts.root_rank) { - comm_context->GroupStart(); - for (auto i = 0; i < size_; i++) { - auto input_data = reinterpret_cast( - GetPointerByOffset(input.data(), offset, input.dtype())); - comm_context->Send(*input_data, count, i, stream); - offset += count; - } - comm_context->Recv(&output, count, opts.root_rank, stream); - comm_context->GroupEnd(); - } else { - comm_context->Recv(&output, count, opts.root_rank, stream); - } - }, - CommType::SCATTER); -} - std::shared_ptr ProcessGroupNCCL::CreateProcessGroupNCCL( const std::shared_ptr& store, int rank, int size, - int gid) { + int gid, + int64_t timeout) { auto process_group = - std::make_shared(store, rank, size, gid); + std::make_shared(store, rank, size, gid, timeout); ProcessGroupIdMap::GetInstance().emplace(gid, process_group); return process_group; } -phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext() { +phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext( + const std::string* key) { + std::string store_key = std::to_string(this->gid_); + if (key && !key->empty()) { + store_key = *key; + } const auto& comm_context_manager = phi::distributed::CommContextManager::GetInstance(); auto comm_context = static_cast( - comm_context_manager.Get(std::to_string(this->gid_))); + comm_context_manager.Get(store_key)); PADDLE_ENFORCE_NE(comm_context, nullptr, phi::errors::Unavailable("NCCLCommContext is nullptr")); diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index b4f90dea777619..96c907e622b170 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -71,12 +71,14 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { const std::shared_ptr& store, int rank, int size, - int gid); + int gid, + int64_t timeout); ProcessGroupNCCL(const std::shared_ptr& store, int rank, int size, - int gid); + int gid, + int64_t timeout = 30 * 60 * 1000); std::string GetBackendName() const override { return "NCCL"; } @@ -170,42 +172,6 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { ncclComm_t NCCLComm(const Place& place) const; - // TODO(liyurui): This API will be moved later - std::shared_ptr AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& = AllreduceOptions()) override; - - // TODO(sunyilun): methods below will be removed later - std::shared_ptr Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& = BroadcastOptions()) override; - - std::shared_ptr Send( - std::vector& tensors, int dst_rank) override; - - std::shared_ptr Recv( - std::vector& tensors, int src_rank) override; - - std::shared_ptr AllGather( - std::vector& in_tensors, - std::vector& out_tensors) override; - - std::shared_ptr AllToAll( - std::vector& in_tensors, - std::vector& out_tensors) override; - - std::shared_ptr Reduce( - std::vector& tensors, - std::vector& out_tensors, - const ReduceOptions& opts) override; - - std::shared_ptr Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts) override; - private: std::shared_ptr CreateTask(const Place& place, int rank, @@ -213,44 +179,36 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { bool sync_op, bool use_calc_stream); - void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id); + void GetStoreKey(const std::string& place_key, + CommType comm_type, + std::string* store_key); - void CreateNCCLEnvCache(const Place& place, const std::string& place_key); + void CreateNCCLEnvCache(const Place& place, + const std::string& place_key, + const std::string& store_key, + CommType comm_type, + int p2p_rank = 0); - void SyncCalcStream(const Place& place); + void SyncCalcStream(const Place& place, const std::string& place_key); - std::shared_ptr RunFnInNCCLEnv( - std::function fn, + std::shared_ptr Collective( + std::function fn, const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, bool use_calc_stream); - // TODO(sunyilun): methods below will be removed later - std::shared_ptr CreateTask( - std::vector places, - int rank, - CommType op_type, - const std::vector& inputs); - - template - std::shared_ptr Collective( - std::vector& inputs, // NOLINT - std::vector& outputs, // NOLINT - Fn fn, - CommType op_type); - - template - std::shared_ptr PointToPoint( - std::vector& tensors, // NOLINT - Fn fn, - int dst_rank, - CommType op_type); - - void CreateNCCLManagerCache(const std::string& places_key, - const std::vector& places); + std::shared_ptr Point2Point( + std::function + fn, + int peer, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream); - phi::distributed::NCCLCommContext* GetCommContext(); + phi::distributed::NCCLCommContext* GetCommContext( + const std::string* key = nullptr); private: std::shared_ptr store_; @@ -261,9 +219,13 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { std::unordered_map> place_to_comm_ctx_; + uint64_t comm_seq_{0}; + // TODO(sunyilun): attrs below will be removed later std::mutex mutex_; - std::unordered_map> places_to_ctx_; + static uint64_t s_group_call_counter; + // default 30 minutes + int64_t pg_timeout_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/collective/utils.h b/paddle/fluid/distributed/collective/utils.h deleted file mode 100644 index 90149f88bbc4f2..00000000000000 --- a/paddle/fluid/distributed/collective/utils.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace paddle { -namespace distributed { - -inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor& tensor, - int64_t offset, - int64_t numel) { - phi::DenseTensor tensor_flattened; - tensor_flattened.ShareDataWith(tensor); - tensor_flattened.Resize({tensor.numel()}); - return tensor_flattened.Slice(offset, offset + numel); -} - -inline void* GetPointerByOffset(void* raw_pointer, - size_t offset, - phi::DataType type) { - if (type == phi::DataType::FLOAT32) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::FLOAT64) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::FLOAT16) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::INT32) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::INT64) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::INT8) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::UINT8) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::BOOL) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::BFLOAT16) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Datatype %s in NCCL is not supported.", type)); - } - return nullptr; -} - -inline void CheckSizeOnEachRank(const phi::DDim& tensor_dim, - const std::vector& size_on_each_rank, - int world_size) { - int length_size_on_each_rank = size_on_each_rank.size(); - PADDLE_ENFORCE_EQ( - length_size_on_each_rank, - world_size, - phi::errors::InvalidArgument( - "The length of size_on_each_rank must be equal to world_size.")); - - int64_t sum_size_on_each_rank = - std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0); - PADDLE_ENFORCE_EQ( - sum_size_on_each_rank, - tensor_dim[0], - phi::errors::InvalidArgument( - "The sum of size_on_each_rank must be equal to tensor's dim[0].")); -} -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index ab155de79feedd..f948e050387bca 100755 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -75,6 +75,7 @@ cc_library( generated_op autograd_meta hook_utils) + # FIXME(Aurelius84): It seems utils library is depended in cycle, but # CMake only find it twice to deal cycle depend problem. If it is still # not found, ld error will be raised. diff --git a/paddle/fluid/eager/accumulation/accumulation_node.cc b/paddle/fluid/eager/accumulation/accumulation_node.cc index c2c09444aab2f5..c15739385dd433 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.cc +++ b/paddle/fluid/eager/accumulation/accumulation_node.cc @@ -113,6 +113,24 @@ static void CopyOrAddTensor(paddle::Tensor* tensor, &tensor_values); } } + } else if (LIKELY(t.is_dist_tensor())) { + PADDLE_ENFORCE( + tensor->is_dist_tensor(), + paddle::platform::errors::Fatal("A DistTensor can only do gradient " + "merge with another DistTensor.")); + PADDLE_ENFORCE(!t.is_custom_device(), + paddle::platform::errors::Fatal( + "DistTensor doesn't support custom device.")); + auto t_dist = + std::dynamic_pointer_cast(t.impl()); + paddle::Tensor t_values( + std::make_shared(t_dist->value())); + auto tensor_dist = + std::dynamic_pointer_cast( + tensor->impl()); + paddle::Tensor tensor_values( + std::make_shared(tensor_dist->value())); + paddle::imperative::TensorAdd(t_values, &tensor_values); } else { // TODO(jiabin): Support Other TensorBase later // TODO(zhanlve): Replace SelectedRowsAddTensor with diff --git a/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h b/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h index 8302af3169ee01..5ff677b143d605 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h +++ b/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" paddle::Tensor add_n_ad_func(const std::vector& x); @@ -49,6 +50,11 @@ sync_batch_norm__ad_func(const paddle::Tensor& x, std::string data_layout, bool use_global_stats, bool trainable_statistics); + +paddle::Tensor reshard_ad_function( + const paddle::Tensor& tensor, + const phi::distributed::TensorDistAttr dist_attr); + namespace sparse { std::tuple grad_node; + + // Set grad_node before API Call + if (require_any_grad) { + paddle::platform::RecordEvent node_creation_record_event( + "reshard node_creation", + paddle::platform::TracerEventType::Communication, + 1); + + // Node Construction + grad_node = + std::shared_ptr(new ReshardGradNode(1, 1)); // NOLINT + + // Set TensorWrappers for Forward Inputs if needed + grad_node->SetTensorWrapperNoNeedBufferInput(input); + } + + // Forward API Call + // reshard_func(input, api_result, dist_attr); + auto dist_out_ptr = paddle::reshard(input, dist_attr); + auto api_result = paddle::Tensor(dist_out_ptr); + + // Get Outputs + auto& out = api_result; + + // Get Output AutoGradMeta + egr::AutogradMeta* out_autograd_meta = egr::EagerUtils::autograd_meta(&out); + + // Set grad_node after API call + if (require_any_grad) { + egr::EagerUtils::PassStopGradient(false, out_autograd_meta); + + // SetGradOutMeta & SetEdges + grad_node->SetGradOutMeta(input, 0); + // SetOutRank & SetHistory & SetGradInMeta + if (out_autograd_meta) { + egr::EagerUtils::SetOutRankWithSlot(out_autograd_meta, 0); + egr::EagerUtils::SetHistory(out_autograd_meta, grad_node); + } + grad_node->SetGradInMeta(out, 0); + } + + return out; +#else + PADDLE_THROW(phi::errors::Unavailable( + "Reshard is not supported in this version of Paddle. Try to recompile it " + "with WITH_DISTRIBTUE=ON and reinstall this package.")); + return paddle::Tensor(); +#endif +} diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt b/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt index efdcaa70131e68..7072c5568ab062 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt @@ -3,4 +3,5 @@ set(eager_manual_nodes ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc + ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h index 8f63f4fdfeb613..bc6d1d9f1a1b65 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h @@ -396,6 +396,53 @@ class SyncBatchNormGradNode : public egr::GradNodeBase { bool trainable_statistics_; }; +class ReshardGradNode : public egr::GradNodeBase { + public: + ReshardGradNode() : egr::GradNodeBase() { + VLOG(3) << " Construct ReshardGrad Node."; + } + + ReshardGradNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num) + : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { + VLOG(3) << " Construct ReshardGrad Node, bwd_in_slot_num: " + << bwd_in_slot_num << ", bwd_out_slot_num: " << bwd_out_slot_num; + } + + ~ReshardGradNode() override { VLOG(3) << " Destruct ReshardGrad Node."; } + + virtual paddle::small_vector, + egr::kSlotSmallVectorSize> + operator()(paddle::small_vector, + egr::kSlotSmallVectorSize>& grads, // NOLINT + bool create_graph = false, + bool is_new_grad = false) override; + + void ClearTensorWrappers() override { + input_.clear(); + SetIsTensorWrappersCleared(true); + } + + std::string name() override { return "ReshardGradNode"; } + + std::shared_ptr Copy() const override { + { + auto copied_node = + std::shared_ptr(new ReshardGradNode(*this)); + return copied_node; + } + } + + // SetTensorWrapperX + // Only input's meta is needed. + void SetTensorWrapperNoNeedBufferInput(const paddle::Tensor& input) { + input_ = egr::TensorWrapper(input, true); + } + + private: + // TensorWrappers + egr::TensorWrapper input_; +}; + namespace sparse { class SyncBatchNormGradNode : public egr::GradNodeBase { public: diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc new file mode 100644 index 00000000000000..2df60f60977045 --- /dev/null +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" +#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/imperative/tracer.h" + +paddle::small_vector, + egr::kSlotSmallVectorSize> // NOLINT +ReshardGradNode::operator()( + paddle::small_vector, + egr::kSlotSmallVectorSize>& grads, + bool create_graph, + bool is_new_grad) { +#ifdef PADDLE_WITH_DISTRIBUTE + VLOG(3) << "Running AD API GRAD: " + << "reshard_grad"; + + // Apply Gradient Hooks + auto hooked_grad = ApplyGradientHooks(grads); + + // Collect GradIn Tensors, Attrs and Recovered TensorWrappers + auto input = egr::EagerUtils::RecoverTensorWrapper(&this->input_); + const auto& dist_attr = + std::static_pointer_cast(input.impl()) + ->dist_attr(); + auto& grad_out = hooked_grad[0][0]; + // Prepare Grad function call + + const auto& out_metas = OutputMeta(); + paddle::small_vector, egr::kSlotSmallVectorSize> + returns(1); + + out_metas[0].size() == 0 ? returns[0].resize(1) + : returns[0].resize(out_metas[0].size()); + + auto& grad_input = returns[0][0]; + + VLOG(5) << "Running C++ API: " + << "reshard_func"; + + if (VLOG_IS_ON(3)) { + const char* INPUT_PRINT_TEMPLATE = "{ Input: [%s]} "; + + std::string input_str = ""; + const char* TENSOR_OUT_GRAD_TEMPLATE = " \n( out_grad , [%s]), "; + std::string input_out_grad_str = paddle::string::Sprintf( + TENSOR_OUT_GRAD_TEMPLATE, egr::EagerUtils::TensorStr(grad_out)); + input_str += input_out_grad_str; + const char* TENSOR_X_TEMPLATE = " \n( x , [%s]), "; + std::string input_x_str = paddle::string::Sprintf( + TENSOR_X_TEMPLATE, egr::EagerUtils::TensorStr(input)); + input_str += input_x_str; + VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str); + } + + // Backward call reshard_func function + auto dist_out_ptr = paddle::reshard(grad_out, dist_attr); + grad_input.set_impl(dist_out_ptr); + + VLOG(5) << "Finish C++ API: reshard_func"; + VLOG(6) << "gradnode_ptr = " << this; + + if (VLOG_IS_ON(4)) { + const char* INPUT_PRINT_TEMPLATE = "{ Input: [%s], \n Output: [%s] } "; + std::string input_str = ""; + std::string output_str = ""; + const char* TENSOR_OUT_GRAD_TEMPLATE = " \n( out_grad , [%s]), "; + std::string input_out_grad_str = paddle::string::Sprintf( + TENSOR_OUT_GRAD_TEMPLATE, egr::EagerUtils::TensorStr(grad_out)); + input_str += input_out_grad_str; + const char* TENSOR_X_TEMPLATE = " \n( x , [%s]), "; + std::string input_x_str = paddle::string::Sprintf( + TENSOR_X_TEMPLATE, egr::EagerUtils::TensorStr(input)); + input_str += input_x_str; + const char* TENSOR_X_GRAD_TEMPLATE = " \n ( input_grad , [%s]), "; + std::string output_x_grad_str = paddle::string::Sprintf( + TENSOR_X_GRAD_TEMPLATE, egr::EagerUtils::TensorStr(grad_input)); + output_str += output_x_grad_str; + VLOG(4) << paddle::string::Sprintf( + INPUT_PRINT_TEMPLATE, input_str, output_str); + } + + return returns; +#else + PADDLE_THROW(phi::errors::Unavailable( + "ReshardGrad is not supported in this version of Paddle. Try to " + "recompile it with WITH_DISTRIBTUE=ON and reinstall this package.")); + return paddle::small_vector, + egr::kSlotSmallVectorSize>(1); +#endif +} diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index ab0e3d6a3e3700..54d22f531a3088 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -1059,7 +1059,11 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): or IsVectorTensorType(atype) or (name in self.optional_inputs) ): - set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});" + if for_backward is False: + set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});" + else: + set_tensor_wrappers = f"{indent}if({name}_optional) grad_node->SetTensorWrapper{name}(*{name}_optional);" + else: need_pre_contiguous_set.add(name) set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name}_tmp);" @@ -1138,7 +1142,10 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): ) if is_optional: - set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" + if for_backward is False: + set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" + else: + set_grad_out_meta = f"{indent}if({name}_optional.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}_optional.get_ptr()), {pos});" else: if ( is_special_forward_api diff --git a/paddle/fluid/eager/nan_inf_utils.cc b/paddle/fluid/eager/nan_inf_utils.cc index 29922e37beb439..a1e62ea6ba519b 100644 --- a/paddle/fluid/eager/nan_inf_utils.cc +++ b/paddle/fluid/eager/nan_inf_utils.cc @@ -19,6 +19,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/selected_rows.h" @@ -90,8 +91,12 @@ void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) { } else if (tensor.is_selected_rows()) { dense_tensor = &( static_cast(tensor.impl().get())->value()); + } else if (tensor.is_dist_tensor()) { + dense_tensor = &( + static_cast(tensor.impl().get()) + ->value()); } else { - VLOG(10) << "Only DenseTensor or SelectedRows need to check, " + VLOG(10) << "Only DenseTensor,SelectedRows,DistTensor need to check, " << tensor_name << " is no need."; return; } diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 83e4424a212514..ec4edbcc74fdec 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -502,7 +502,7 @@ inline void NewIRRunProgramAPI( // Step 2. create new interpretercore auto kernel_forward_program = paddle::dialect::PdOpLowerToKernelPass(forward_program, place); - interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache( + interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( std::move(kernel_forward_program), place, /*is_grad=*/false, @@ -708,13 +708,12 @@ inline void RunProgramAPI( input_names, params, place); - interpreter_core = - paddle::framework::CreateNewIRInterpreterCoreInfoToCache( - std::move(ir_program), - place, - /*is_grad=*/false, - program_id, - global_inner_scope); + interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( + std::move(ir_program), + place, + /*is_grad=*/false, + program_id, + global_inner_scope); } else { interpreter_core = paddle::framework::CreateProgramInterpreterCoreInfoToCache( @@ -865,13 +864,12 @@ inline void RunProgramGradAPI( global_inner_scope, place); - interpreter_core = - paddle::framework::CreateNewIRInterpreterCoreInfoToCache( - std::move(res), - place, - /*is_grad=*/true, - program_id, - global_inner_scope); + interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( + std::move(res), + place, + /*is_grad=*/true, + program_id, + global_inner_scope); } else { interpreter_core = paddle::framework::CreateProgramInterpreterCoreInfoToCache( @@ -1041,7 +1039,7 @@ inline void NewIRRunProgramGradAPI( // Step 1. share input_vars & parameters into scope auto kernel_backward_program = paddle::dialect::PdOpLowerToKernelPass(backward_program, place); - interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache( + interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( std::move(kernel_backward_program), place, /*is_grad=*/true, diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index b83568cfdd69a3..ea7b4de4a6b0f6 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -55,13 +55,13 @@ function(pass_library TARGET DEST) ${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry - ${pass_library_DEPS}) + quantize_helper ${pass_library_DEPS}) else() cc_library( ${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry - ${pass_library_DEPS}) + quantize_helper ${pass_library_DEPS}) endif() # add more DEST here, such as train, dist and collect USE_PASS into a file automatically. @@ -122,69 +122,16 @@ cc_library( SRCS data_type.cc DEPS framework_proto) -cc_test( - data_type_test - SRCS data_type_test.cc - DEPS data_type place tensor) - cc_library( tensor SRCS tensor_util.cc DEPS place memory data_type device_context phi) -cc_test( - tensor_test - SRCS tensor_test.cc - DEPS tensor isfinite_op) -if(WITH_GPU) - nv_test( - tensor_util_test - SRCS tensor_util_test.cc tensor_util_test.cu - DEPS tensor dlpack_tensor isfinite_op) -elseif(WITH_ROCM) - hip_test( - tensor_util_test - SRCS tensor_util_test.cc tensor_util_test.cu - DEPS tensor dlpack_tensor isfinite_op) -else() - cc_test( - tensor_util_test - SRCS tensor_util_test.cc - DEPS tensor dlpack_tensor isfinite_op) -endif() - -cc_test( - copy_same_tensor_test - SRCS copy_same_tensor_test.cc - DEPS tensor) - -cc_test( - eigen_test - SRCS eigen_test.cc - DEPS tensor) - cc_library( lod_tensor SRCS lod_tensor.cc DEPS phi place tensor framework_proto version) -cc_test( - lod_tensor_test - SRCS lod_tensor_test.cc - DEPS phi lod_tensor memory) - -if(WITH_GPU) - nv_test( - lod_tensor_gpu_test - SRCS lod_tensor_test.cu - DEPS lod_tensor) -elseif(WITH_ROCM) - hip_test( - lod_tensor_gpu_test - SRCS lod_tensor_test.cu - DEPS lod_tensor) -endif() - cc_library( garbage_collector SRCS garbage_collector.cc @@ -194,15 +141,6 @@ cc_library( reader SRCS reader.cc DEPS lod_tensor phi) -cc_test( - reader_test - SRCS reader_test.cc - DEPS reader) - -cc_test( - threadpool_test - SRCS threadpool_test.cc - DEPS phi) cc_library( var_type_traits @@ -221,11 +159,6 @@ if(WITH_MKLDNN) add_dependencies(var_type_traits mkldnn) endif() -cc_test( - var_type_traits_test - SRCS var_type_traits_test.cc - DEPS var_type_traits) - set(BRPC_DEPS "") if(WITH_PSCORE) set(BRPC_DEPS ${EXTERNAL_BRPC_DEPS}) @@ -249,39 +182,15 @@ cc_library( device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor scope ${BRPC_DEPS}) -cc_test( - device_worker_test - SRCS device_worker_test.cc - DEPS device_worker) - cc_library( scope_pool SRCS scope_pool.cc DEPS scope) -cc_test( - scope_test - SRCS scope_test.cc - DEPS scope) -cc_test( - variable_test - SRCS variable_test.cc - DEPS tensor var_type_traits) cc_library( data_device_transform SRCS data_device_transform.cc DEPS tensor) -if(WITH_GPU) - nv_test( - data_device_transform_test - SRCS data_device_transform_test.cu - DEPS operator op_registry device_context phi scope) -elseif(WITH_ROCM) - hip_test( - data_device_transform_test - SRCS data_device_transform_test.cu - DEPS operator op_registry device_context phi scope) -endif() if(WITH_GPU) if(WIN32) @@ -299,47 +208,27 @@ if(WITH_GPU) SRCS data_type_transform.cu DEPS tensor) endif() - nv_test( - data_type_transform_test - SRCS data_type_transform_test.cc data_type_transform_test.cu - DEPS data_type_transform) elseif(WITH_ROCM) hip_library( data_type_transform SRCS data_type_transform.cu DEPS tensor) - hip_test( - data_type_transform_test - SRCS data_type_transform_test.cc data_type_transform_test.cu - DEPS data_type_transform) elseif(WITH_XPU) cc_library( data_type_transform SRCS data_type_transform.cc DEPS tensor xpulib) - cc_test( - data_type_transform_test - SRCS data_type_transform_test.cc - DEPS data_type_transform) else() cc_library( data_type_transform SRCS data_type_transform.cc DEPS tensor) - cc_test( - data_type_transform_test - SRCS data_type_transform_test.cc - DEPS data_type_transform) endif() cc_library( data_layout_transform SRCS data_layout_transform.cc DEPS tensor phi) -cc_test( - data_layout_transform_test - SRCS data_layout_transform_test.cc - DEPS data_layout_transform) cc_library( data_transform @@ -357,18 +246,6 @@ cc_library( attribute SRCS attribute.cc DEPS framework_proto enforce) -cc_test( - attribute_test - SRCS attribute_test.cc - DEPS attribute framework_proto proto_desc) -cc_test( - program_desc_test - SRCS program_desc_test.cc - DEPS proto_desc device_context) -cc_test( - op_desc_test - SRCS op_desc_test.cc - DEPS proto_desc) cc_library( op_version_proto SRCS op_version_proto.cc @@ -378,19 +255,11 @@ cc_library( op_version_registry SRCS op_version_registry.cc DEPS op_version_proto framework_proto) -cc_test( - op_version_registry_test - SRCS op_version_registry_test.cc - DEPS op_version_registry) cc_library( op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute ops_extra_info glog auto_parallel_proto) -cc_test( - op_proto_maker_test - SRCS op_proto_maker_test.cc - DEPS op_proto_maker) cc_library( no_need_buffer_vars_inference SRCS no_need_buffer_vars_inference.cc @@ -410,11 +279,6 @@ if(WITH_MKLDNN) add_dependencies(shape_inference mkldnn) endif() -cc_test( - no_need_buffer_vars_inference_test - SRCS no_need_buffer_vars_inference_test.cc - DEPS no_need_buffer_vars_inference layer) - cc_library( transfer_scope_cache SRCS transfer_scope_cache.cc @@ -503,20 +367,7 @@ else() type_info) endif() -cc_test( - operator_test - SRCS operator_test.cc - DEPS operator op_registry device_context) -cc_test( - operator_exception_test - SRCS operator_exception_test.cc - DEPS operator op_registry device_context) - cc_library(version SRCS version.cc) -cc_test( - version_test - SRCS version_test.cc - DEPS version) add_library(proto_desc_base OBJECT var_desc.cc op_desc.cc block_desc.cc program_desc.cc) @@ -556,31 +407,11 @@ cc_library( op_call_stack SRCS op_call_stack.cc DEPS op_proto_maker enforce) -cc_test( - op_call_stack_test - SRCS op_call_stack_test.cc - DEPS op_call_stack) cc_library( program_utils SRCS program_utils.cc DEPS proto_desc) -cc_test( - program_utils_test - SRCS program_utils_test.cc - DEPS proto_desc program_utils) - -if(WITH_GPU) - nv_test( - op_registry_test - SRCS op_registry_test.cc - DEPS op_registry) -elseif(WITH_ROCM) - hip_test( - op_registry_test - SRCS op_registry_test.cc - DEPS op_registry) -endif() if(WITH_PYTHON) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) @@ -1051,102 +882,24 @@ cc_library( SRCS executor_cache.cc DEPS parallel_executor standalone_executor pir_adaptor pd_inplace_pass pd_op_to_kernel_pass pir) -if(WITH_PSCORE) - get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) - if(WITH_HETERPS) - cc_test( - dist_multi_trainer_test - SRCS dist_multi_trainer_test.cc - DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS} - graph_gpu_wrapper) - cc_test( - heter_pipeline_trainer_test - SRCS heter_pipeline_trainer_test.cc - DEPS conditional_block_op - generated_op - heter_listen_and_serv_op - executor - heter_server - gloo_wrapper - phi - ${RPC_DEPS} - graph_gpu_wrapper) - else() - cc_test( - dist_multi_trainer_test - SRCS dist_multi_trainer_test.cc - DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS}) - cc_test( - heter_pipeline_trainer_test - SRCS heter_pipeline_trainer_test.cc - DEPS conditional_block_op - generated_op - heter_listen_and_serv_op - executor - heter_server - gloo_wrapper - phi - ${RPC_DEPS}) - endif() -else() - cc_test( - dist_multi_trainer_test - SRCS dist_multi_trainer_test.cc - DEPS conditional_block_op executor gloo_wrapper) -endif() cc_library( prune SRCS prune.cc DEPS framework_proto auto_parallel_proto proto_desc) -cc_test( - prune_test - SRCS prune_test.cc - DEPS op_info prune recurrent_op device_context) -cc_test( - var_type_inference_test - SRCS var_type_inference_test.cc - DEPS op_registry proto_desc) cc_library( selected_rows_utils SRCS selected_rows_utils.cc DEPS phi device_context) -cc_test( - selected_rows_utils_test - SRCS selected_rows_utils_test.cc - DEPS selected_rows_utils) - -cc_test( - op_kernel_type_test - SRCS op_kernel_type_test.cc - DEPS place device_context framework_proto op_kernel_type) -cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) - -cc_test(tuple_test SRCS tuple_test.cc) - -cc_test(inlined_vector_test SRCS inlined_vector_test.cc) cc_library( dlpack_tensor SRCS dlpack_tensor.cc DEPS tensor dlpack) -cc_test( - dlpack_tensor_test - SRCS dlpack_tensor_test.cc - DEPS dlpack_tensor glog) cc_library( op_compatible_info SRCS op_compatible_info.cc DEPS string_helper proto_desc) -cc_test_old( - op_compatible_info_test - SRCS - op_compatible_info_test.cc - DEPS - op_compatible_info - proto_desc - string_helper - glog) cc_library( infershape_utils @@ -1160,10 +913,6 @@ cc_library( phi_utils op_info shape_inference) -cc_test( - infershape_utils_test - SRCS infershape_utils_test.cc - DEPS infershape_utils phi) # Get the current working branch execute_process( @@ -1215,15 +964,3 @@ set(FLUID_FRAMEWORK_MODULES custom_operator) cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES}) - -if(WITH_TESTING AND TEST selected_rows_utils_test) - set_tests_properties(selected_rows_utils_test PROPERTIES TIMEOUT 120) -endif() - -cc_test(scope_guard_test SRCS scope_guard_test.cc) -cc_test( - phi_utils_test - SRCS phi_utils_test.cc - DEPS phi_utils) - -cc_test(convert_utils_test SRCS convert_utils_test.cc) diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index def0d1742ba951..ce085688dd5825 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -48,6 +48,10 @@ class TEST_API BlockDesc { int32_t Parent() const { return desc_->parent_idx(); } + void SetParent(int32_t parent_id) const { + return desc_->set_parent_idx(parent_id); + } + int32_t ForwardBlockID() const { return desc_->forward_block_idx(); } VarDesc *Var(const std::string &name_bytes); diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 50a16d8f686e78..31db91f5517bd8 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -105,6 +105,7 @@ message AMPConfig { optional bool use_fp16_guard = 11 [ default = true ]; optional bool use_optimizer_fp16 = 12 [ default = false ]; // auto parallel effective only + optional bool use_pure_bf16 = 13 [ default = false ]; } message LocalSGDConfig { diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 2e1eb0a58fe5a5..687256c7d9c511 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -324,7 +324,7 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( return core; } -std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( +std::shared_ptr CreatePirInterpreterCoreInfoToCache( std::unique_ptr<::pir::Program> ir_program, const platform::Place &place, bool is_grad, @@ -543,6 +543,9 @@ std::unique_ptr<::pir::Program> ConstructBackwardIrProgram( if (FLAGS_new_ir_apply_inplace_pass) { ::pir::PassManager pm(::pir::IrContext::Instance(), 3); pm.AddPass(::pir::CreateInplacePass()); + if (VLOG_IS_ON(6)) { + pm.EnableIRPrinting(); + } pm.Run(res.get()); } diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index d30ed6396e65ef..bd73af80e812e5 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -243,7 +243,7 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( int64_t program_id, framework::Scope* scope); -std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( +std::shared_ptr CreatePirInterpreterCoreInfoToCache( std::unique_ptr<::pir::Program> ir_prog, const platform::Place& place, bool is_grad, diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 92d316fdea0a31..305a11805c9b06 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -59,6 +59,10 @@ cc_library( placement_pass_base SRCS placement_pass_base.cc DEPS pass) +cc_library( + quantize_helper + SRCS quantize_helper.cc + DEPS graph graph_helper) cc_library( coalesce_grad_tensor_pass @@ -237,7 +241,11 @@ if(WITH_XPU) xpu_pass_utils SRCS xpu/pass_utils.cc DEPS pass xpu_quant_utils) - set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils) + cc_library( + xpu_graph_pattern_detector + SRCS xpu/xpu_graph_pattern_detector.cc + DEPS graph_pattern_detector) + set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils xpu_graph_pattern_detector) pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) @@ -247,6 +255,8 @@ if(WITH_XPU) # pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_bias_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(xpu_quantize_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(xpu_quantize_squash_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index 14f42b129effa5..d29ef0f9ad1fad 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -523,7 +523,6 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { vars_should_not_low_precision.insert(in_var_node->Var()->Name()); } } - // when op_1 only support cpu kernel. if op_2's intput var is op_1's // output var, then op_2 should not run at low precision. if (GetOpOriginalType(op_type) != "feed" && @@ -687,6 +686,16 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert( if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { return true; } + } else if (GetOpOriginalType(op_desc->Type()) == "quantize_linear" || + GetOpOriginalType(op_desc->Type()) == "dequantize_linear") { + auto vecs = op_desc->Input("Scale"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("ZeroPoint"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } } } @@ -733,6 +742,11 @@ bool AutoMixedPrecisionPass::OutputVarsNotConvert( } void AutoMixedPrecisionPass::SetVarPrecision() const { + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL(scope, + platform::errors::PreconditionNotMet( + "During the auto_mixed_precision_pass, the scope " + "should not be null.")); for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) { @@ -749,7 +763,21 @@ void AutoMixedPrecisionPass::SetVarPrecision() const { if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue; if (!VarNodeHasDtype(real_in_var_node)) continue; if (InputVarsNotConvert(op_node, in_var_name)) continue; - + // Judge the real tensor is same to variable, Paddle-Slim weight use + // fp32 variable to save int8 tensor. + if (real_in_var_node->Var()->Persistable()) { + auto* tensor = scope->Var(real_in_var_node->Name()) + ->GetMutable(); + if (framework::TransToProtoVarType(tensor->type()) != + real_in_var_node->Var()->GetDataType()) { + VLOG(3) << "[AutoMixedPrecisionPass] variable " + << real_in_var_node->Name() << "'s proto data type " + << real_in_var_node->Var()->GetDataType() + << " is different from real dense tensor " + << framework::TransToProtoVarType(tensor->type()); + continue; + } + } if (real_in_var_node->Var()->Persistable()) { real_in_var_node->Var()->SetDataType( framework::TransToProtoVarType(low_precision_)); diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index 286f7f08cdfc97..916d577d23d606 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -19,6 +19,7 @@ #include #include #include +#include "paddle/fluid/framework/ir/quantize_helper.h" namespace paddle { namespace framework { @@ -94,6 +95,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { scope, platform::errors::InvalidArgument( "Scope in DeleteQuantDequantLinearOpPass should not be null.")); + std::unordered_map> var_quant_scales{}; + // Create pattern patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(), pattern_name); @@ -141,7 +144,11 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op(); any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), input_scale); - + if (!var_quant_scales.count(quantize_linear_op_x->Var()->Name())) { + var_quant_scales.insert( + std::make_pair(quantize_linear_op_x->Var()->Name(), + std::vector({input_scale}))); + } // link x to any_op2 any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), quantize_linear_op_x->Var()->Name()); @@ -161,6 +168,9 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { }; gpd(graph, handler); AddStatis(found_count); + + SaveQuantInfoInTheGraph( + graph, "has_quant_info", "var_quant_scales", var_quant_scales); } } // namespace ir diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc index cf5c9a2c94cf9b..87f2de2a59e0d9 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" #include "glog/logging.h" @@ -35,18 +36,20 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { true, platform::errors::InvalidArgument( "Graph must have kParamScopeAttr attribute.")); - + VLOG(3) << "Handle delete weight dequant linear op pass ..."; auto& scope = graph->Get(kParamScopeAttr); bool is_int8 = false; std::unordered_set nodes2rm; + std::unordered_map> var_quant_scales{}; for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); if (op->Type() == "dequantize_linear") { - Node *weight_var_node = nullptr, *calcu_op_node = nullptr, - *while_op_node = nullptr; + Node* weight_var_node = nullptr; + Node* calcu_op_node = nullptr; + Node* while_op_node = nullptr; Node *dequantized_weight_var_node = nullptr, *scale_var_node = nullptr; // 1. Judge whether for dequant weight and find // weight_var_node/scale_var_node @@ -59,9 +62,12 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { scale_var_node = input_node; } } else { - return; + break; } } + if (weight_var_node == nullptr || scale_var_node == nullptr) { + continue; + } // 2. Find next_op_node // For while op: delete its input which is related to dequantized // For calculation op: set weight scale as their attributes @@ -106,7 +112,7 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { } } else { PADDLE_THROW(platform::errors::Unimplemented( - "The dtype of quantization scale must be FP32/16, " + "The dtype of quantization scale must be FP32/FP16, " "but received %d, which is not supported.", weight_scale_tensor->dtype())); } @@ -125,14 +131,34 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { calcu_op_desc->SetAttr("weight_scale", weight_scale[0]); } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Delete Weight Dequant Linear Op Pass is not supported " - "for " - "per-channel quantization")); + std::vector weights_shape = + weight_var_node->Var()->GetShape(); + quant_axis = quant_axis >= 0 + ? quant_axis + : quant_axis + weights_shape.size(); + PADDLE_ENFORCE_EQ( + weight_scale_nums, + weights_shape[quant_axis], + platform::errors::InvalidArgument( + "When quant_axis != -1, it means using per_channel " + "dequantization. In this situation, the number of " + "weight_scale should be equal with " + "weights_shape[quant_axis=%d]=%ld , but received " + "%d.", + quant_axis, + weights_shape[quant_axis], + weight_scale_nums)); + calcu_op_desc->SetAttr("weight_scale", weight_scale); } + if (!var_quant_scales.count(weight_var_node->Var()->Name())) { + var_quant_scales.insert(std::make_pair( + weight_var_node->Var()->Name(), weight_scale)); + } + calcu_op_desc->RenameInput( dequantized_weight_var_node->Var()->Name(), weight_var_node->Var()->Name()); + calcu_op_desc->Flush(); } } } @@ -153,6 +179,8 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { } GraphSafeRemoveNodes(graph, nodes2rm); + SaveQuantInfoInTheGraph( + graph, "has_quant_info", "var_quant_scales", var_quant_scales); graph->Set("enable_int8", new bool(is_int8)); } } // namespace ir diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 3596f4e0f0e29e..e42334aac05933 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -412,6 +412,20 @@ class Graph { return sub_graphs_.size(); } + std::vector AttrNames() const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->AttrNames(); + } + } + std::vector res; + res.reserve(attrs_.size()); + for (auto &attr : attrs_) { + res.push_back(attr.first); + } + return res; + } + private: // TODO(levi): delete this interface after when we can convert all // blocks into sub_graphs. diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc index 770a3a7a1d117d..d0fb6d58443ae5 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc @@ -15,7 +15,6 @@ #include #include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h" -#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/imperative/type_defs.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/quantize_helper.cc b/paddle/fluid/framework/ir/quantize_helper.cc new file mode 100644 index 00000000000000..08f2cc457ef2c2 --- /dev/null +++ b/paddle/fluid/framework/ir/quantize_helper.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/quantize_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SaveQuantInfoInTheGraph( + ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + const std::unordered_map>& info_map) { + const std::string suffix = "_" + key_suffix + "_" + flag; + if (!graph->Has(flag)) { + graph->Set(flag, new bool(true)); + } + for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { + graph->Set(iter->first + suffix, new std::vector(iter->second)); + } +} + +std::unordered_map> GetQuantInfoFromTheGraph( + ir::Graph* graph, const std::string& flag, const std::string& key_suffix) { + std::unordered_map> info_map; + const std::string suffix = "_" + key_suffix + "_" + flag; + if (graph->Has(flag)) { + std::vector attr_names = graph->AttrNames(); + for (auto fake_name : attr_names) { + size_t pos = fake_name.find(suffix); + if (pos != std::string::npos) { + std::string name = fake_name.substr(0, pos); + auto scales_vector = graph->Get>(fake_name); + info_map.insert(std::make_pair(name, scales_vector)); + } + } + } + return info_map; +} + +bool AreScalesPresentForNodes( + std::unordered_map>* var_quant_scales, + std::initializer_list nodes) { + bool present = true; + for (auto node : nodes) { + if (var_quant_scales->count(node->Name()) == 0) { + present = false; + } + } + return present; +} + +float GetScaleValueForNode( + std::unordered_map>* var_quant_scales, + Node* node) { + return var_quant_scales->at(node->Name())[0]; +} + +std::vector GetScaleVecValueForNode( + std::unordered_map>* var_quant_scales, + Node* node) { + return var_quant_scales->at(node->Name()); +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/quantize_helper.h b/paddle/fluid/framework/ir/quantize_helper.h new file mode 100644 index 00000000000000..4876cd35a1cf3a --- /dev/null +++ b/paddle/fluid/framework/ir/quantize_helper.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SaveQuantInfoInTheGraph( + ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + const std::unordered_map>& info_map); + +std::unordered_map> GetQuantInfoFromTheGraph( + ir::Graph* graph, const std::string& flag, const std::string& key_suffix); + +bool AreScalesPresentForNodes( + std::unordered_map>* var_quant_scales, + std::initializer_list nodes); + +float GetScaleValueForNode( + std::unordered_map>* var_quant_scales, + Node* node); + +std::vector GetScaleVecValueForNode( + std::unordered_map>* var_quant_scales, + Node* node); + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc index ef8759153b0ccf..1a56e4d6604312 100644 --- a/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc @@ -127,6 +127,7 @@ int CastMixedPrecisionOpFusePass::ApplyCastBeforePass( GraphPatternDetector gpd; patterns::CastBeforePattern pattern( gpd.mutable_pattern(), name_scope_, mixed_precision_op_type); + auto* scope = param_scope(); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -136,7 +137,22 @@ int CastMixedPrecisionOpFusePass::ApplyCastBeforePass( GET_IR_NODE(cast); GET_IR_NODE(cast_out); GET_IR_NODE(mixed_precision_op); - + // Note: conv2d_xpu/fc_xpu not support float32/int8/float16, can not fuse. + if (mixed_precision_op_type == "conv2d_xpu") { + auto filter_name = mixed_precision_op->Op()->Input("filter")[0]; + auto filter_data_type = + scope->FindVar(filter_name)->GetMutable()->dtype(); + if (filter_data_type == phi::DataType::INT8) { + return; + } + } else if (mixed_precision_op_type == "fc_xpu") { + auto w_name = mixed_precision_op->Op()->Input("w")[0]; + auto w_data_type = + scope->FindVar(w_name)->GetMutable()->dtype(); + if (w_data_type == phi::DataType::INT8) { + return; + } + } mixed_precision_op->Op()->RenameInput(cast_out->Name(), cast_in->Name()); IR_NODE_LINK_TO(cast_in, mixed_precision_op); @@ -155,6 +171,7 @@ int CastMixedPrecisionOpFusePass::ApplyCastAfterPass( GraphPatternDetector gpd; patterns::CastAfterPattern pattern( gpd.mutable_pattern(), name_scope_, mixed_precision_op_type); + auto* scope = param_scope(); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -164,7 +181,30 @@ int CastMixedPrecisionOpFusePass::ApplyCastAfterPass( GET_IR_NODE(cast_in); GET_IR_NODE(cast); GET_IR_NODE(cast_out); - + // Note: conv2d_xpu/fc_xpu not support float16/int8/float32, can not fuse. + if (mixed_precision_op_type == "conv2d_xpu") { + auto filter_name = mixed_precision_op->Op()->Input("filter")[0]; + auto filter_data_type = + scope->FindVar(filter_name)->GetMutable()->dtype(); + auto x_name = mixed_precision_op->Op()->Input("x")[0]; + auto* x_node = FindNodeWithName(graph, x_name); + if (filter_data_type == phi::DataType::INT8 && + x_node->Var()->GetDataType() == + proto::VarType::Type::VarType_Type_FP16) { + return; + } + } else if (mixed_precision_op_type == "fc_xpu") { + auto w_name = mixed_precision_op->Op()->Input("w")[0]; + auto w_data_type = + scope->FindVar(w_name)->GetMutable()->dtype(); + auto x_name = mixed_precision_op->Op()->Input("x")[0]; + auto* x_node = FindNodeWithName(graph, x_name); + if (w_data_type == phi::DataType::INT8 && + x_node->Var()->GetDataType() == + proto::VarType::Type::VarType_Type_FP16) { + return; + } + } mixed_precision_op->Op()->RenameOutput(cast_in->Name(), cast_out->Name()); int out_dtype = proto::VarType::Type::VarType_Type_FP32; mixed_precision_op->Op()->SetAttr("out_dtype", out_dtype); diff --git a/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc index 784d5d4ec029f8..51ebb63c563dcf 100644 --- a/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/conv2d_transpose_xpu_fuse_pass.cc @@ -377,8 +377,14 @@ int Conv2dTransposeXPUFusePass::ApplyImpl(ir::Graph* graph, // filter max Node* filter_int16 = nullptr; Node* filter_max = nullptr; - PrepareWeight( - graph, scope, block, conv_filter, &filter_int16, &filter_max, false); + PrepareWeight(graph, + scope, + block, + conv_filter, + &filter_int16, + &filter_max, + false, + std::vector({})); // output && output max std::string conv2d_xpu_out_name; if (!act_type.empty()) { diff --git a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc index 502c275a419d35..89a558c6601f15 100644 --- a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "glog/logging.h" @@ -19,6 +20,7 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -355,6 +357,57 @@ class Conv2dXPUFusePass : public FusePassBase { bool with_branch_x, bool with_branch_y) const; + Node* GetNodeFromNodesMap( + const std::map>& nodes_map, + std::string pattern_node_name, + std::string node_name) const; + + void CreateFusionWeightsAndBias( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + bool with_conv_bias, + bool with_bn, + bool with_scale, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionInputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionBranch( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionOutputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::string act_type, + std::unordered_map>* var_quant_scales) + const; + + const std::unordered_set support_quant_op_type_{"conv2d", + "conv2d_xpu"}; const std::string name_scope_{"conv2d_xpu_fuse_pass"}; }; @@ -401,6 +454,532 @@ void Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_subgraph_count); } +Node* Conv2dXPUFusePass::GetNodeFromNodesMap( + const std::map>& nodes_map, + std::string pattern_node_name, + std::string node_name) const { + auto iter = nodes_map.find(pattern_node_name); + PADDLE_ENFORCE_EQ( + iter != nodes_map.end(), + true, + platform::errors::InvalidArgument("nodes_map[%s] not found in nodes_map", + pattern_node_name.c_str())); + auto node_map = iter->second; + auto node_iter = node_map.find(node_name); + PADDLE_ENFORCE_EQ(node_iter != node_map.end(), + true, + platform::errors::InvalidArgument( + "nodes_map[%s][%s] not found in nodes_map", + pattern_node_name.c_str(), + node_name.c_str())); + return node_iter->second; +} + +void Conv2dXPUFusePass::CreateFusionWeightsAndBias( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + bool with_conv_bias, + bool with_bn, + bool with_scale, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + auto* conv_filter = GetNodeFromNodesMap(nodes_map, "conv", "conv_filter"); + PADDLE_ENFORCE_EQ(conv_filter != nullptr, + true, + platform::errors::InvalidArgument( + "conv_filter node ptr can not be null")); + + // transfilter fp16 --> fp32 + auto* filter_t = + scope->FindVar(conv_filter->Name())->GetMutable(); + auto filter_len = filter_t->numel(); + auto filter_dtype = filter_t->dtype(); + if (filter_dtype == phi::DataType::FLOAT16) { + CastToFp32(filter_t, nullptr); + } + + // Get Weight scale in int8 scene + std::vector weight_scale{}; + if (AreScalesPresentForNodes(var_quant_scales, {conv_filter})) { + weight_scale = GetScaleVecValueForNode(var_quant_scales, conv_filter); + } + // Create fusion_bias_node + auto filter_dims = filter_t->dims(); + Node* fusion_bias_node = nullptr; + if (with_conv_bias) { + auto* ew_bias_add_y = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_y"); + PADDLE_ENFORCE_EQ(ew_bias_add_y != nullptr, + true, + platform::errors::InvalidArgument( + "ew_bias_add_y node ptr can not be null")); + auto* ew_bias_add_y_t = + scope->FindVar(ew_bias_add_y->Name())->GetMutable(); + auto ew_bias_add_y_dims = ew_bias_add_y_t->dims(); + PADDLE_ENFORCE_EQ(filter_dims[0], + ew_bias_add_y_dims[0], + platform::errors::InvalidArgument( + "the shape[%d] of elewise bias tensor " + "must equal out_channel[%d] of conv", + ew_bias_add_y_dims[0], + filter_dims[0])); + PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node); + } + + if (with_bn) { + auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + PADDLE_ENFORCE_EQ( + bn != nullptr, + true, + platform::errors::InvalidArgument("bn node ptr can not be null")); + auto* bn_bias = GetNodeFromNodesMap(nodes_map, "bn", "bn_bias"); + PADDLE_ENFORCE_EQ( + bn_bias != nullptr, + true, + platform::errors::InvalidArgument("bn_bias node ptr can not be null")); + auto* bn_scale = GetNodeFromNodesMap(nodes_map, "bn", "bn_scale"); + PADDLE_ENFORCE_EQ( + bn_scale != nullptr, + true, + platform::errors::InvalidArgument("bn_scale node ptr can not be null")); + auto* bn_var = GetNodeFromNodesMap(nodes_map, "bn", "bn_var"); + PADDLE_ENFORCE_EQ( + bn_var != nullptr, + true, + platform::errors::InvalidArgument("bn_var node ptr can not be null")); + auto* bn_mean = GetNodeFromNodesMap(nodes_map, "bn", "bn_mean"); + PADDLE_ENFORCE_EQ( + bn_mean != nullptr, + true, + platform::errors::InvalidArgument("bn_mean node ptr can not be null")); + + auto bn_bias_t = + scope->Var(bn_bias->Name())->GetMutable(); + PADDLE_ENFORCE_EQ( + filter_dims[0], + bn_bias_t->dims()[0], + platform::errors::InvalidArgument("the shape[%d] of bn bias tensor " + "must equal out_channel[%d] of conv", + bn_bias_t->dims()[0], + filter_dims[0])); + auto bn_scale_t = + scope->Var(bn_scale->Name())->GetMutable(); + auto bn_mean_t = + scope->Var(bn_mean->Name())->GetMutable(); + auto bn_var_t = scope->Var(bn_var->Name())->GetMutable(); + float* bn_scale_ptr = bn_scale_t->data(); + float* bn_bias_ptr = bn_bias_t->data(); + float* bn_mean_ptr = bn_mean_t->data(); + float* bn_var_ptr = bn_var_t->data(); + auto mean_len = bn_mean_t->numel(); + auto filter_stride = filter_len / mean_len; + float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); + if (!with_conv_bias) { // prev node is conv + PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); + } + + auto fusion_bias_t = + scope->Var(fusion_bias_node->Name())->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + // recompute bias and weights + for (int i = 0; i < mean_len; ++i) { + bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); + } + // recompute the weights + if (op_weights_precision != "int8") { + float* filter_ptr = filter_t->data(); + for (int i = 0; i < mean_len; ++i) { + for (int j = 0; j < filter_stride; j++) { + filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; + } + } + } else { + int8_t* filter_ptr = filter_t->data(); + PADDLE_ENFORCE_EQ( + weight_scale.size(), + mean_len, + platform::errors::InvalidArgument( + "Weight max_scale size must equal batch_norm sacle/mean size.")); + for (int i = 0; i < mean_len; i++) { + weight_scale[i] *= fabs(bn_scale_ptr[i]); + } + for (int i = 0; i < mean_len; i++) { + if (bn_scale_ptr[i] < 0) { + for (int j = 0; j < filter_stride; ++j) { + filter_ptr[i * filter_stride + j] *= -1; + } + } + } + } + // recompute bias + if (!with_conv_bias) { + for (int i = 0; i < mean_len; ++i) { + fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } else { + for (int i = 0; i < mean_len; ++i) { + fusion_bias_ptr[i] = + bn_bias_ptr[i] + + (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } + } + + // deal with scale op + if (with_scale) { + auto* scale = GetNodeFromNodesMap(nodes_map, "scale", "scale"); + PADDLE_ENFORCE_EQ( + scale != nullptr, + true, + platform::errors::InvalidArgument("scale node ptr can not be null")); + auto bias_len = filter_dims[0]; + float scale_val_ = 1.f; + float bias_val_ = 0.f; + scale_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); + bias_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("bias")); + bool bias_after_scale_ = + PADDLE_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale")); + // recompute bias as scale op + auto fusion_bias_t = + scope->GetVar(fusion_bias_node->Name())->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + for (int i = 0; i < bias_len; ++i) { + if (bias_after_scale_) { + fusion_bias_ptr[i] = fusion_bias_ptr[i] * scale_val_ + bias_val_; + } else { + fusion_bias_ptr[i] = (fusion_bias_ptr[i] + bias_val_) * scale_val_; + } + } + // recompute weight as scale op + if (op_weights_precision != "int8") { + float* filter_ptr = filter_t->data(); + for (int i = 0; i < filter_len; ++i) { + filter_ptr[i] *= scale_val_; + } + } else { + for (size_t i = 0; i < weight_scale.size(); i++) { + weight_scale[i] *= scale_val_; + } + } + } + + (*fusion_nodes_map)["bias"] = fusion_bias_node; + + Node* filter_intx = nullptr; + Node* filter_max = nullptr; + Node* scale_max = nullptr; + if (op_weights_precision != "int8") { + PrepareWeight(graph, + scope, + block, + conv_filter, + &filter_intx, + &filter_max, + false, + weight_scale); + } else { + PrepareWeight(graph, + scope, + block, + conv_filter, + &filter_intx, + &filter_max, + false, + weight_scale); + } + + bool is_per_channel_need_create_scale_max_node = + !weight_scale.empty() && !IsPerTensorQuant(weight_scale); + if (is_per_channel_need_create_scale_max_node) { + phi::DenseTensor ones_weight_max_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + int max_ptr_size = weight_scale.empty() + ? phi::backends::xpu::get_xpu_max_ptr_size(-1) + : weight_scale.size(); + ones_weight_max_tensor.set_type(phi::DataType::FLOAT32); + ones_weight_max_tensor.Resize({max_ptr_size}); + std::vector ones_weight(max_ptr_size, 1.0); + memcpy(cpu_ctx->Alloc(&ones_weight_max_tensor), + ones_weight.data(), + max_ptr_size * sizeof(float)); + + std::string scale_max_name = conv_filter->Name() + "_scale_max"; + VarDesc scale_max_desc(scale_max_name); + scale_max_desc.SetPersistable(true); + scale_max_desc.SetShape(vectorize(ones_weight_max_tensor.dims())); + scale_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + scale_max = graph->CreateVarNode(&scale_max_desc); + auto* block_scale_max_desc = block->Var(scale_max_name); + block_scale_max_desc->SetPersistable(scale_max_desc.Persistable()); + block_scale_max_desc->SetShape(scale_max_desc.GetShape()); + block_scale_max_desc->SetDataType(scale_max_desc.GetDataType()); + Assign(ones_weight_max_tensor, + scope->Var(scale_max_name)->GetMutable()); + } + + (*fusion_nodes_map)["filter"] = filter_intx; + if (is_per_channel_need_create_scale_max_node) { + (*fusion_nodes_map)["filter_max"] = scale_max; + (*fusion_nodes_map)["scale_max"] = filter_max; + } else { + (*fusion_nodes_map)["filter_max"] = filter_max; + (*fusion_nodes_map)["scale_max"] = scale_max; + } +} + +void Conv2dXPUFusePass::CreateFusionInputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + auto* input = GetNodeFromNodesMap(nodes_map, "conv", "input"); + PADDLE_ENFORCE_EQ( + input != nullptr, + true, + platform::errors::InvalidArgument("conv input node ptr can not be null")); + // input max + std::string conv_input_max_name = input->Name() + "_input_max"; + Node* conv2d_xpu_input_max = nullptr; + if (op_weights_precision == "int8") { + PADDLE_ENFORCE_EQ(AreScalesPresentForNodes(var_quant_scales, {input}), + true, + platform::errors::InvalidArgument( + "When conv op is running in int8 precision, the " + "scales of input var should be present in!")); + float input_scale = GetScaleValueForNode(var_quant_scales, input); + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc conv_input_max_desc(conv_input_max_name); + conv_input_max_desc.SetPersistable(true); + conv_input_max_desc.SetShape({static_cast(max_ptr_size)}); + conv_input_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + conv2d_xpu_input_max = graph->CreateVarNode(&conv_input_max_desc); + auto input_max_tensor = + scope->Var(conv_input_max_name)->GetMutable(); + input_max_tensor->set_type(phi::DataType::FLOAT32); + input_max_tensor->Resize({max_ptr_size}); + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + std::vector input_scales(max_ptr_size, input_scale); + memcpy(cpu_ctx->Alloc(input_max_tensor), + input_scales.data(), + max_ptr_size * sizeof(float)); + } + (*fusion_nodes_map)["x"] = input; + (*fusion_nodes_map)["x_max"] = conv2d_xpu_input_max; +} + +void Conv2dXPUFusePass::CreateFusionBranch( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* ew_branch_add = + GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add"); + if (ew_branch_add) { + auto* ew_branch_add_in = + GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add_in"); + PADDLE_ENFORCE_EQ(ew_branch_add_in != nullptr, + true, + platform::errors::InvalidArgument( + "ew_branch_add_in node ptr can not be null")); + (*fusion_nodes_map)["branch"] = ew_branch_add_in; + // ew_branch_add_max + std::string ew_branch_add_max_name = + ew_branch_add_in->Name() + "branch_max"; + Node* ew_branch_add_max = FindNodeWithName(graph, ew_branch_add_max_name); + if (op_weights_precision == "int8" && !ew_branch_add_max) { + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc ew_branch_add_in_max_desc(ew_branch_add_max_name); + ew_branch_add_in_max_desc.SetPersistable(true); + ew_branch_add_in_max_desc.SetShape({static_cast(max_ptr_size)}); + ew_branch_add_in_max_desc.SetDataType( + proto::VarType::Type::VarType_Type_FP32); + ew_branch_add_max = graph->CreateVarNode(&ew_branch_add_in_max_desc); + PADDLE_ENFORCE_EQ( + AreScalesPresentForNodes(var_quant_scales, {ew_branch_add_in}), + true, + platform::errors::InvalidArgument( + "When conv op is running in int8 precision with branch add, the " + "scales of branch var should be present in!")); + float ew_branch_add_scale = + GetScaleValueForNode(var_quant_scales, ew_branch_add_in); + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + auto ew_branch_add_max_tensor = + scope->Var(ew_branch_add_max_name)->GetMutable(); + ew_branch_add_max_tensor->set_type(phi::DataType::FLOAT32); + ew_branch_add_max_tensor->Resize({max_ptr_size}); + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + std::vector ew_branch_add_scales(max_ptr_size, + ew_branch_add_scale); + memcpy(cpu_ctx->Alloc(ew_branch_add_max_tensor), + ew_branch_add_scales.data(), + max_ptr_size * sizeof(float)); + } + (*fusion_nodes_map)["branch_max"] = ew_branch_add_max; + } +} + +void Conv2dXPUFusePass::CreateFusionOutputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::string act_type, + std::unordered_map>* var_quant_scales) + const { + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + // output && output max + std::string conv2d_xpu_out_name; + Node* conv2d_out_var_node = nullptr; + + auto* ew_branch_add = + GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add"); + auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + auto* scale = GetNodeFromNodesMap(nodes_map, "scale", "scale"); + auto* ew_bias_add = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add"); + if (!act_type.empty()) { + auto* act_out = GetNodeFromNodesMap(nodes_map, "act", "act_out"); + PADDLE_ENFORCE_EQ( + act_out != nullptr, + true, + platform::errors::InvalidArgument("act_out node ptr can not be null")); + conv2d_xpu_out_name = act_out->Name(); + conv2d_out_var_node = act_out; + auto* act = GetNodeFromNodesMap(nodes_map, "act", "act"); + PADDLE_ENFORCE_EQ( + act != nullptr, + true, + platform::errors::InvalidArgument("act node ptr can not be null")); + } else if (ew_branch_add) { + auto* ew_branch_add_out = + GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add_out"); + PADDLE_ENFORCE_EQ(ew_branch_add_out != nullptr, + true, + platform::errors::InvalidArgument( + "ew_branch_add_out node ptr can not be null")); + conv2d_xpu_out_name = ew_branch_add_out->Name(); + conv2d_out_var_node = ew_branch_add_out; + PADDLE_ENFORCE_EQ(ew_branch_add != nullptr, + true, + platform::errors::InvalidArgument( + "ew_branch_add node ptr can not be null")); + } else if (scale) { + auto* scale_out = GetNodeFromNodesMap(nodes_map, "scale", "scale_out"); + PADDLE_ENFORCE_EQ(scale_out != nullptr, + true, + platform::errors::InvalidArgument( + "scale_out node ptr can not be null")); + conv2d_xpu_out_name = scale_out->Name(); + conv2d_out_var_node = scale_out; + } else if (bn) { + auto* bn_out = GetNodeFromNodesMap(nodes_map, "bn", "bn_out"); + PADDLE_ENFORCE_EQ( + bn_out != nullptr, + true, + platform::errors::InvalidArgument("bn_out node ptr can not be null")); + conv2d_xpu_out_name = bn_out->Name(); + conv2d_out_var_node = bn_out; + } else if (ew_bias_add) { + auto* ew_bias_add_out = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_out"); + PADDLE_ENFORCE_EQ(ew_bias_add_out != nullptr, + true, + platform::errors::InvalidArgument( + "ew_bias_add_out node ptr can not be null")); + conv2d_xpu_out_name = ew_bias_add_out->Name(); + conv2d_out_var_node = ew_bias_add_out; + } else { + auto* conv_out = GetNodeFromNodesMap(nodes_map, "conv", "conv_out"); + PADDLE_ENFORCE_EQ( + conv_out != nullptr, + true, + platform::errors::InvalidArgument("conv_out node ptr can not be null")); + conv2d_xpu_out_name = conv_out->Name(); + conv2d_out_var_node = conv_out; + auto* conv = GetNodeFromNodesMap(nodes_map, "conv", "conv"); + PADDLE_ENFORCE_EQ( + conv != nullptr, + true, + platform::errors::InvalidArgument("conv node ptr can not be null")); + } + (*fusion_nodes_map)["out"] = conv2d_out_var_node; + + // Create out max in + if (op_weights_precision == "int8" && + AreScalesPresentForNodes(var_quant_scales, {conv2d_out_var_node})) { + std::string conv_out_max_in_name = conv2d_xpu_out_name + "_max_in"; + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc conv_out_max_in_desc(conv_out_max_in_name); + conv_out_max_in_desc.SetPersistable(true); + conv_out_max_in_desc.SetShape({static_cast(max_ptr_size)}); + conv_out_max_in_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + Node* conv2d_xpu_out_max_in = graph->CreateVarNode(&conv_out_max_in_desc); + auto* block_out_max_in_desc = block->Var(conv_out_max_in_name); + block_out_max_in_desc->SetPersistable(conv_out_max_in_desc.Persistable()); + block_out_max_in_desc->SetShape(conv_out_max_in_desc.GetShape()); + block_out_max_in_desc->SetDataType(conv_out_max_in_desc.GetDataType()); + + float output_scale = + GetScaleValueForNode(var_quant_scales, conv2d_out_var_node); + phi::DenseTensor out_max_in_cpu_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + out_max_in_cpu_tensor.set_type(phi::DataType::FLOAT32); + out_max_in_cpu_tensor.Resize({max_ptr_size}); + std::vector output_scales(max_ptr_size, output_scale); + memcpy(cpu_ctx->Alloc(&out_max_in_cpu_tensor), + output_scales.data(), + max_ptr_size * sizeof(float)); + Assign(out_max_in_cpu_tensor, + scope->Var(conv_out_max_in_name)->GetMutable()); + (*fusion_nodes_map)["out_max_in"] = conv2d_xpu_out_max_in; + } + + // Create out max + std::string conv_out_max_name = conv2d_xpu_out_name + "_max"; + VarDesc conv_out_max_desc(conv_out_max_name); + Node* conv2d_xpu_out_max = graph->CreateVarNode(&conv_out_max_desc); + (*fusion_nodes_map)["out_max"] = conv2d_xpu_out_max; +} + int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, const std::string& conv_type, const std::string& act_type, @@ -419,18 +998,23 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, with_scale, with_branch_x, with_branch_y); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + std::unordered_map> var_quant_scales = + GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales"); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle Conv2dXPUFusePass fuse"; - /* declare operator node's name */ + std::map> nodes_map; GET_IR_NODE(conv); GET_IR_NODE(ew_bias_add); GET_IR_NODE(bn); GET_IR_NODE(scale); GET_IR_NODE(ew_branch_add); GET_IR_NODE(act); - /* declare variable node's name*/ + /* Get variable node's name*/ GET_IR_NODE(input); GET_IR_NODE(conv_filter); GET_IR_NODE(conv_out); @@ -449,166 +1033,132 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, GET_IR_NODE(ew_branch_add_in); GET_IR_NODE(ew_branch_add_out); GET_IR_NODE(act_out); + + nodes_map.insert({"conv", + {{"conv", conv}, + {"conv_filter", conv_filter}, + {"input", input}, + {"conv_out", conv_out}}}); + nodes_map.insert({"ew_bias_add", + {{"ew_bias_add", ew_bias_add}, + {"ew_bias_add_y", ew_bias_add_y}, + {"ew_bias_add_out", ew_bias_add_out}}}); + nodes_map.insert({"bn", + {{"bn", bn}, + {"bn_bias", bn_bias}, + {"bn_mean", bn_mean}, + {"bn_scale", bn_scale}, + {"bn_var", bn_var}, + {"bn_out", bn_out}, + {"bn_var_out", bn_var_out}, + {"bn_mean_out", bn_mean_out}, + {"bn_saved_var", bn_saved_var}, + {"bn_saved_mean", bn_saved_mean}}}); + nodes_map.insert({"scale", {{"scale", scale}, {"scale_out", scale_out}}}); + nodes_map.insert({"ew_branch_add", + {{"ew_branch_add", ew_branch_add}, + {"ew_branch_add_in", ew_branch_add_in}, + {"ew_branch_add_out", ew_branch_add_out}}}); + nodes_map.insert({"act", {{"act", act}, {"act_out", act_out}}}); + + std::map fusion_nodes_map{{"x", nullptr}, + {"x_max", nullptr}, + {"filter", nullptr}, + {"filter_max", nullptr}, + {"bias", nullptr}, + {"branch", nullptr}, + {"branch_max", nullptr}, + {"scale_max", nullptr}, + {"out_max_in", nullptr}, + {"out", nullptr}, + {"out_max", nullptr}}; + + auto filter_data_type = scope->FindVar(conv_filter->Name()) + ->GetMutable() + ->dtype(); + std::string op_weights_precision = "float32"; + if (filter_data_type == phi::DataType::INT8) { + op_weights_precision = "int8"; + } else if (filter_data_type == phi::DataType::FLOAT16) { + op_weights_precision = "float16"; + } + VLOG(4) << "Conv2d fusion fuse pass is running on " << op_weights_precision + << " precision!"; auto* block = conv->Op()->Block(); - auto* scope = param_scope(); - PADDLE_ENFORCE_NOT_NULL( - scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); - - // recompute bias and weight for conv2d_xpu op - auto* filter_t = - scope->FindVar(conv_filter->Name())->GetMutable(); - // conv_filter fp16 --> fp32 - auto filter_len = filter_t->numel(); - auto filter_dtype = filter_t->dtype(); - int out_dtype = proto::VarType::Type::VarType_Type_FP32; - if (filter_dtype == phi::DataType::FLOAT16) { - out_dtype = proto::VarType::Type::VarType_Type_FP16; - CastToFp32(filter_t, nullptr); - } - - auto filter_dims = filter_t->dims(); - bool has_bias = with_bn || with_conv_bias; - // Create conv_fusion_bias (conv bias) variable - Node* fusion_bias_node = nullptr; - if (has_bias) { - if (with_conv_bias) { - auto* ew_bias_add_y_t = scope->FindVar(ew_bias_add_y->Name()) - ->GetMutable(); - auto ew_bias_add_y_dims = ew_bias_add_y_t->dims(); - PADDLE_ENFORCE_EQ(filter_dims[0], - ew_bias_add_y_dims[0], - platform::errors::InvalidArgument( - "the shape[%d] of elewise bias tensor " - "must equal out_channel[%d] of conv", - ew_bias_add_y_dims[0], - filter_dims[0])); - PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node); - } - if (with_bn) { - auto bn_bias_t = - scope->Var(bn_bias->Name())->GetMutable(); - PADDLE_ENFORCE_EQ(filter_dims[0], - bn_bias_t->dims()[0], - platform::errors::InvalidArgument( - "the shape[%d] of bn bias tensor " - "must equal out_channel[%d] of conv", - bn_bias_t->dims()[0], - filter_dims[0])); - auto bn_scale_t = - scope->Var(bn_scale->Name())->GetMutable(); - auto bn_mean_t = - scope->Var(bn_mean->Name())->GetMutable(); - auto bn_var_t = - scope->Var(bn_var->Name())->GetMutable(); - float* filter_ptr = - filter_t->mutable_data(paddle::platform::CPUPlace()); - float* bn_scale_ptr = - bn_scale_t->mutable_data(paddle::platform::CPUPlace()); - float* bn_bias_ptr = - bn_bias_t->mutable_data(paddle::platform::CPUPlace()); - float* bn_mean_ptr = - bn_mean_t->mutable_data(paddle::platform::CPUPlace()); - float* bn_var_ptr = - bn_var_t->mutable_data(paddle::platform::CPUPlace()); - auto mean_len = bn_mean_t->numel(); - auto filter_stride = filter_len / mean_len; - float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); - if (!with_conv_bias) { // prev node is conv - PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); - } - auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) - ->GetMutable(); - float* fusion_bias_ptr = - fusion_bias_t->mutable_data(paddle::platform::CPUPlace()); - // recompute bias and weights - if (!with_conv_bias) { // prev node is conv - for (int i = 0; i < mean_len; ++i) { - bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i]; - for (int j = 0; j < filter_stride; j++) { - filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; - } - } - } else { - for (int i = 0; i < mean_len; ++i) { - bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - fusion_bias_ptr[i] = - bn_bias_ptr[i] + - (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; - for (int j = 0; j < filter_stride; j++) { - filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; - } - } - } - } + CreateFusionWeightsAndBias(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + with_conv_bias, + with_bn, + with_scale, + op_weights_precision, + &var_quant_scales); + CreateFusionInputs(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + &var_quant_scales); + CreateFusionBranch(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + &var_quant_scales); + CreateFusionOutputs(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + act_type, + &var_quant_scales); + + framework::OpDesc conv2d_xpu_op_desc(block); + conv2d_xpu_op_desc.SetType("conv2d_xpu"); + conv2d_xpu_op_desc.SetInput("x", {fusion_nodes_map["x"]->Name()}); + if (fusion_nodes_map["x_max"]) { + conv2d_xpu_op_desc.SetInput("x_max", {fusion_nodes_map["x_max"]->Name()}); } - // deal with scale op - if (with_scale) { - auto bias_len = filter_dims[0]; - float scale_val_ = 1.f; - float bias_val_ = 0.f; - scale_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); - bias_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("bias")); - bool bias_after_scale_ = - PADDLE_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale")); - // recompute bias as scale op - auto fusion_bias_t = scope->GetVar(fusion_bias_node->Name()) - ->GetMutable(); - float* fusion_bias_ptr = - fusion_bias_t->mutable_data(paddle::platform::CPUPlace()); - for (int i = 0; i < bias_len; ++i) { - if (bias_after_scale_) { - fusion_bias_ptr[i] = fusion_bias_ptr[i] * scale_val_ + bias_val_; - } else { - fusion_bias_ptr[i] = (fusion_bias_ptr[i] + bias_val_) * scale_val_; - } - } - // recompute weight as scale op - float* filter_ptr = - filter_t->mutable_data(paddle::platform::CPUPlace()); - for (int i = 0; i < filter_len; ++i) { - filter_ptr[i] *= scale_val_; - } + conv2d_xpu_op_desc.SetInput("filter", {fusion_nodes_map["filter"]->Name()}); + conv2d_xpu_op_desc.SetInput("filter_max", + {fusion_nodes_map["filter_max"]->Name()}); + if (fusion_nodes_map["scale_max"]) { + conv2d_xpu_op_desc.SetInput("scale_max", + {fusion_nodes_map["scale_max"]->Name()}); } - // filter max - Node* filter_int16 = nullptr; - Node* filter_max = nullptr; - PrepareWeight( - graph, scope, block, conv_filter, &filter_int16, &filter_max, false); - // output && output max - std::string conv2d_xpu_out_name; - if (!act_type.empty()) { - conv2d_xpu_out_name = act_out->Name(); - } else if (ew_branch_add) { - conv2d_xpu_out_name = ew_branch_add_out->Name(); - } else if (scale) { - conv2d_xpu_out_name = scale_out->Name(); - } else if (bn) { - conv2d_xpu_out_name = bn_out->Name(); - } else if (ew_bias_add) { - conv2d_xpu_out_name = ew_bias_add_out->Name(); - } else { - conv2d_xpu_out_name = conv_out->Name(); + if (fusion_nodes_map["out_max_in"]) { + conv2d_xpu_op_desc.SetInput("out_max_in", + {fusion_nodes_map["out_max_in"]->Name()}); } - std::string conv2d_xpu_out_max_name = conv2d_xpu_out_name + "_max"; - VarDesc conv2d_xpu_out_max_desc(conv2d_xpu_out_max_name); - Node* conv2d_xpu_out_max = graph->CreateVarNode(&conv2d_xpu_out_max_desc); - // Generate conv2d_xpu op - framework::OpDesc conv2d_xpu_op_desc(block); - // set input&output var - conv2d_xpu_op_desc.SetType("conv2d_xpu"); - conv2d_xpu_op_desc.SetInput("x", {input->Name()}); - conv2d_xpu_op_desc.SetInput("filter", {filter_int16->Name()}); - conv2d_xpu_op_desc.SetInput("filter_max", {filter_max->Name()}); - conv2d_xpu_op_desc.SetOutput("out", {conv2d_xpu_out_name}); - conv2d_xpu_op_desc.SetOutput("out_max", {conv2d_xpu_out_max_name}); - // set fusion_bias input node - if (has_bias) { - conv2d_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()}); + conv2d_xpu_op_desc.SetOutput("out", {fusion_nodes_map["out"]->Name()}); + conv2d_xpu_op_desc.SetOutput("out_max", + {fusion_nodes_map["out_max"]->Name()}); + if (with_conv_bias || with_bn) { + PADDLE_ENFORCE_EQ( + fusion_nodes_map["bias"] != nullptr, + true, + platform::errors::InvalidArgument( + "fusion_nodes_map['bias'] node ptr can not be null")); + conv2d_xpu_op_desc.SetInput("bias", {fusion_nodes_map["bias"]->Name()}); } // set ew_branch_add input node if (ew_branch_add != nullptr) { - conv2d_xpu_op_desc.SetInput("branch", {ew_branch_add_in->Name()}); + PADDLE_ENFORCE_EQ( + fusion_nodes_map["branch"] != nullptr, + true, + platform::errors::InvalidArgument( + "fusion_nodes_map['branch'] node ptr can not be null")); + conv2d_xpu_op_desc.SetInput("branch", + {fusion_nodes_map["branch"]->Name()}); + if (fusion_nodes_map["branch_max"]) { + conv2d_xpu_op_desc.SetInput("branch_max", + {fusion_nodes_map["branch_max"]->Name()}); + } } // set attrs of conv2d_xpu float act_param_ = 0.0f; @@ -646,57 +1196,54 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, "strides", PADDLE_GET_CONST(std::vector, conv->Op()->GetAttr("strides"))); conv2d_xpu_op_desc.SetAttr("paddings", conv_paddings); - conv2d_xpu_op_desc.SetAttr("out_dtype", out_dtype); + // out_dtype is same to input precision + conv2d_xpu_op_desc.SetAttr("out_dtype", + fusion_nodes_map["x"]->Var()->GetDataType()); + // Link node auto* conv2d_xpu = graph->CreateOpNode(&conv2d_xpu_op_desc); - IR_NODE_LINK_TO(input, conv2d_xpu); - IR_NODE_LINK_TO(filter_int16, conv2d_xpu); - IR_NODE_LINK_TO(filter_max, conv2d_xpu); - if (ew_bias_add || bn) { - SAFE_IR_NODE_LINK_TO(fusion_bias_node, conv2d_xpu); - } - if (ew_branch_add_in) { - IR_NODE_LINK_TO(ew_branch_add_in, conv2d_xpu); - } - if (act_out) { - IR_NODE_LINK_TO(conv2d_xpu, act_out); - } else if (ew_branch_add_out) { - IR_NODE_LINK_TO(conv2d_xpu, ew_branch_add_out); - } else if (scale_out) { - IR_NODE_LINK_TO(conv2d_xpu, scale_out); - } else if (bn_out) { - IR_NODE_LINK_TO(conv2d_xpu, bn_out); - } else if (ew_bias_add_out) { - IR_NODE_LINK_TO(conv2d_xpu, ew_bias_add_out); - } else { - IR_NODE_LINK_TO(conv2d_xpu, conv_out); + IR_NODE_LINK_TO(fusion_nodes_map["x"], conv2d_xpu); + if (fusion_nodes_map["x_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["x_max"], conv2d_xpu); } - IR_NODE_LINK_TO(conv2d_xpu, conv2d_xpu_out_max); - // delete useless node - std::unordered_set delete_nodes = {conv}; - if (act != nullptr) { - delete_nodes.insert(act); + IR_NODE_LINK_TO(fusion_nodes_map["filter"], conv2d_xpu); + IR_NODE_LINK_TO(fusion_nodes_map["filter_max"], conv2d_xpu); + if (fusion_nodes_map["scale_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["scale_max"], conv2d_xpu); } - if (ew_branch_add != nullptr) { - delete_nodes.insert(ew_branch_add); + if (fusion_nodes_map["bias"]) { + SAFE_IR_NODE_LINK_TO(fusion_nodes_map["bias"], conv2d_xpu); + } + if (fusion_nodes_map["branch"]) { + IR_NODE_LINK_TO(fusion_nodes_map["branch"], conv2d_xpu); + } + if (fusion_nodes_map["branch_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["branch_max"], conv2d_xpu); + } + if (fusion_nodes_map["out_max_in"]) { + IR_NODE_LINK_TO(fusion_nodes_map["out_max_in"], conv2d_xpu); + } + IR_NODE_LINK_TO(conv2d_xpu, fusion_nodes_map["out"]); + IR_NODE_LINK_TO(conv2d_xpu, fusion_nodes_map["out_max"]); + // delete useless node + std::unordered_set delete_nodes; + if (conv != nullptr) { + delete_nodes.insert(conv); } if (scale != nullptr) { delete_nodes.insert(scale); } if (bn != nullptr) { delete_nodes.insert(bn); - delete_nodes.insert(bn_bias); - delete_nodes.insert(bn_var); - delete_nodes.insert(bn_mean); - delete_nodes.insert(bn_scale); - delete_nodes.insert(bn_var_out); - delete_nodes.insert(bn_mean_out); - delete_nodes.insert(bn_saved_var); - delete_nodes.insert(bn_saved_mean); } if (ew_bias_add != nullptr) { delete_nodes.insert(ew_bias_add); - delete_nodes.insert(ew_bias_add_y); + } + if (ew_branch_add != nullptr) { + delete_nodes.insert(ew_branch_add); + } + if (act != nullptr) { + delete_nodes.insert(act); } GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc index 4c8424b7df08fd..373275706700f1 100644 --- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -244,9 +245,68 @@ class FcXPUFusePass : public FusePassBase { bool with_bn, const std::string& act_type) const; + void CreateFusionWeightsAndBias( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + std::string mul_type, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + bool with_bias, + bool with_bn, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionOutputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + void CreateFusionInputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const; + + Node* GetNodeFromNodesMap( + const std::map>& nodes_map, + std::string pattern_node_name, + std::string node_name) const; + const std::string name_scope_{"fc_xpu_fuse_pass"}; }; +Node* FcXPUFusePass::GetNodeFromNodesMap( + const std::map>& nodes_map, + std::string pattern_node_name, + std::string node_name) const { + auto iter = nodes_map.find(pattern_node_name); + PADDLE_ENFORCE_EQ( + iter != nodes_map.end(), + true, + platform::errors::InvalidArgument("nodes_map[%s] not found in nodes_map", + pattern_node_name.c_str())); + auto node_map = iter->second; + auto node_iter = node_map.find(node_name); + PADDLE_ENFORCE_EQ(node_iter != node_map.end(), + true, + platform::errors::InvalidArgument( + "nodes_map[%s][%s] not found in nodes_map", + pattern_node_name.c_str(), + node_name.c_str())); + return node_iter->second; +} + void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); @@ -275,6 +335,368 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_subgraph_count); } +void FcXPUFusePass::CreateFusionWeightsAndBias( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + std::string mul_type, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + bool with_bias, + bool with_bn, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* mul = GetNodeFromNodesMap(nodes_map, "mul", "mul"); + PADDLE_ENFORCE_EQ( + mul != nullptr, + true, + platform::errors::InvalidArgument("mul node ptr can not be null")); + auto* mul_w = GetNodeFromNodesMap(nodes_map, "mul", "mul_w"); + PADDLE_ENFORCE_EQ( + mul_w != nullptr, + true, + platform::errors::InvalidArgument("mul_w node ptr can not be null")); + + // transfilter fp16 --> fp32 + auto* filter_t = + scope->FindVar(mul_w->Name())->GetMutable(); + auto filter_len = filter_t->numel(); + auto filter_dtype = filter_t->dtype(); + if (filter_dtype == phi::DataType::FLOAT16) { + CastToFp32(filter_t, nullptr); + } + + bool transpose_w = false; + if (mul_type == "matmul") { + transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y")); + } else if (mul_type == "matmul_v2") { + transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y")); + } + // Get Weight scale in int8 scene + std::vector weight_scale{}; + if (AreScalesPresentForNodes(var_quant_scales, {mul_w})) { + weight_scale = GetScaleVecValueForNode(var_quant_scales, mul_w); + } + // Create fusion_bias_node + Node* fusion_bias_node = nullptr; + if (with_bias) { + auto* ew_bias_add_bias = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_bias"); + PADDLE_ENFORCE_EQ(ew_bias_add_bias != nullptr, + true, + platform::errors::InvalidArgument( + "ew_bias_add_bias node ptr can not be null")); + PrepareBias(graph, scope, block, ew_bias_add_bias, &fusion_bias_node); + } + + if (with_bn) { + auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + PADDLE_ENFORCE_EQ( + bn != nullptr, + true, + platform::errors::InvalidArgument("bn node ptr can not be null")); + auto* bn_bias = GetNodeFromNodesMap(nodes_map, "bn", "bn_bias"); + PADDLE_ENFORCE_EQ( + bn_bias != nullptr, + true, + platform::errors::InvalidArgument("bn_bias node ptr can not be null")); + auto* bn_scale = GetNodeFromNodesMap(nodes_map, "bn", "bn_scale"); + PADDLE_ENFORCE_EQ( + bn_scale != nullptr, + true, + platform::errors::InvalidArgument("bn_scale node ptr can not be null")); + auto* bn_var = GetNodeFromNodesMap(nodes_map, "bn", "bn_var"); + PADDLE_ENFORCE_EQ( + bn_var != nullptr, + true, + platform::errors::InvalidArgument("bn_var node ptr can not be null")); + auto* bn_mean = GetNodeFromNodesMap(nodes_map, "bn", "bn_mean"); + PADDLE_ENFORCE_EQ( + bn_mean != nullptr, + true, + platform::errors::InvalidArgument("bn_mean node ptr can not be null")); + + auto bn_bias_t = + scope->Var(bn_bias->Name())->GetMutable(); + auto bn_scale_t = + scope->Var(bn_scale->Name())->GetMutable(); + auto bn_mean_t = + scope->Var(bn_mean->Name())->GetMutable(); + auto bn_var_t = scope->Var(bn_var->Name())->GetMutable(); + float* bn_scale_ptr = bn_scale_t->data(); + float* bn_bias_ptr = bn_bias_t->data(); + float* bn_mean_ptr = bn_mean_t->data(); + float* bn_var_ptr = bn_var_t->data(); + auto mean_len = bn_mean_t->numel(); + auto filter_stride = filter_len / mean_len; + float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); + if (!with_bias) { // prev node is conv + PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); + } + + auto fusion_bias_t = + scope->Var(fusion_bias_node->Name())->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + // recompute bias and weights + for (int i = 0; i < mean_len; ++i) { + bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); + } + // recompute the weights + if (op_weights_precision != "int8") { + float* filter_ptr = filter_t->data(); + for (int i = 0; i < mean_len; ++i) { + for (int j = 0; j < filter_stride; j++) { + filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; + } + } + } else { + int8_t* filter_ptr = filter_t->data(); + PADDLE_ENFORCE_EQ( + weight_scale.size(), + mean_len, + platform::errors::InvalidArgument( + "Weight max_scale size must equal batch_norm sacle/mean size.")); + for (int i = 0; i < mean_len; i++) { + weight_scale[i] *= fabs(bn_scale_ptr[i]); + } + for (int i = 0; i < mean_len; i++) { + if (bn_scale_ptr[i] < 0) { + for (int j = 0; j < filter_stride; ++j) { + filter_ptr[i * filter_stride + j] *= -1; + } + } + } + } + // recompute bias + if (!with_bias) { + for (int i = 0; i < mean_len; ++i) { + fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } else { + for (int i = 0; i < mean_len; ++i) { + fusion_bias_ptr[i] = + bn_bias_ptr[i] + + (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; + } + } + } + + (*fusion_nodes_map)["bias"] = fusion_bias_node; + + Node* filter_intx = nullptr; + Node* filter_max = nullptr; + Node* scale_max = nullptr; + if (op_weights_precision != "int8") { + PrepareWeight(graph, + scope, + block, + mul_w, + &filter_intx, + &filter_max, + !transpose_w, + weight_scale); + } else { + PrepareWeight(graph, + scope, + block, + mul_w, + &filter_intx, + &filter_max, + !transpose_w, + weight_scale); + } + + bool is_per_channel_need_create_scale_max_node = + !weight_scale.empty() && !IsPerTensorQuant(weight_scale); + if (is_per_channel_need_create_scale_max_node) { + phi::DenseTensor ones_weight_max_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + int max_ptr_size = weight_scale.empty() + ? phi::backends::xpu::get_xpu_max_ptr_size(-1) + : weight_scale.size(); + ones_weight_max_tensor.set_type(phi::DataType::FLOAT32); + ones_weight_max_tensor.Resize({max_ptr_size}); + std::vector ones_weight(max_ptr_size, 1.0); + memcpy(cpu_ctx->Alloc(&ones_weight_max_tensor), + ones_weight.data(), + max_ptr_size * sizeof(float)); + + std::string scale_max_name = mul_w->Name() + "_scale_max"; + VarDesc scale_max_desc(scale_max_name); + scale_max_desc.SetPersistable(true); + scale_max_desc.SetShape(vectorize(ones_weight_max_tensor.dims())); + scale_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + scale_max = graph->CreateVarNode(&scale_max_desc); + auto* block_scale_max_desc = block->Var(scale_max_name); + block_scale_max_desc->SetPersistable(scale_max_desc.Persistable()); + block_scale_max_desc->SetShape(scale_max_desc.GetShape()); + block_scale_max_desc->SetDataType(scale_max_desc.GetDataType()); + Assign(ones_weight_max_tensor, + scope->Var(scale_max_name)->GetMutable()); + } + + (*fusion_nodes_map)["w"] = filter_intx; + if (is_per_channel_need_create_scale_max_node) { + (*fusion_nodes_map)["w_max"] = scale_max; + (*fusion_nodes_map)["scale_max"] = filter_max; + } else { + (*fusion_nodes_map)["w_max"] = filter_max; + (*fusion_nodes_map)["scale_max"] = scale_max; + } +} + +void FcXPUFusePass::CreateFusionOutputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + auto* mul = GetNodeFromNodesMap(nodes_map, "mul", "mul"); + PADDLE_ENFORCE_EQ( + mul != nullptr, + true, + platform::errors::InvalidArgument("mul node ptr can not be null")); + // output && output max + std::string fc_xpu_out_name; + Node* fc_out_var_node = nullptr; + + auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + auto* ew_bias_add = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add"); + auto* act = GetNodeFromNodesMap(nodes_map, "act", "act"); + if (act) { + auto* act_out = GetNodeFromNodesMap(nodes_map, "act", "act_out"); + PADDLE_ENFORCE_EQ( + act_out != nullptr, + true, + platform::errors::InvalidArgument("act_out node ptr can not be null")); + fc_xpu_out_name = act_out->Name(); + fc_out_var_node = act_out; + } else if (bn) { + auto* bn_out = GetNodeFromNodesMap(nodes_map, "bn", "bn_out"); + PADDLE_ENFORCE_EQ( + bn_out != nullptr, + true, + platform::errors::InvalidArgument("bn_out node ptr can not be null")); + fc_xpu_out_name = bn_out->Name(); + fc_out_var_node = bn_out; + } else if (ew_bias_add) { + auto* ew_bias_add_out = + GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_out"); + PADDLE_ENFORCE_EQ(ew_bias_add_out != nullptr, + true, + platform::errors::InvalidArgument( + "ew_bias_add_out node ptr can not be null")); + fc_xpu_out_name = ew_bias_add_out->Name(); + fc_out_var_node = ew_bias_add_out; + } else { + auto* mul_out = GetNodeFromNodesMap(nodes_map, "mul", "mul_out"); + PADDLE_ENFORCE_EQ( + mul_out != nullptr, + true, + platform::errors::InvalidArgument("mul_out node ptr can not be null")); + fc_xpu_out_name = mul_out->Name(); + fc_out_var_node = mul_out; + } + (*fusion_nodes_map)["out"] = fc_out_var_node; + + // Create out max in + if (op_weights_precision == "int8" && + AreScalesPresentForNodes(var_quant_scales, {fc_out_var_node})) { + std::string fc_out_max_in_name = fc_xpu_out_name + "_max_in"; + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc fc_out_max_in_desc(fc_out_max_in_name); + fc_out_max_in_desc.SetPersistable(true); + fc_out_max_in_desc.SetShape({static_cast(max_ptr_size)}); + fc_out_max_in_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + Node* fc_xpu_out_max_in = graph->CreateVarNode(&fc_out_max_in_desc); + auto* block_out_max_in_desc = block->Var(fc_out_max_in_name); + block_out_max_in_desc->SetPersistable(fc_out_max_in_desc.Persistable()); + block_out_max_in_desc->SetShape(fc_out_max_in_desc.GetShape()); + block_out_max_in_desc->SetDataType(fc_out_max_in_desc.GetDataType()); + + float output_scale = + GetScaleValueForNode(var_quant_scales, fc_out_var_node); + phi::DenseTensor out_max_in_cpu_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + out_max_in_cpu_tensor.set_type(phi::DataType::FLOAT32); + out_max_in_cpu_tensor.Resize({max_ptr_size}); + std::vector output_scales(max_ptr_size, output_scale); + memcpy(cpu_ctx->Alloc(&out_max_in_cpu_tensor), + output_scales.data(), + max_ptr_size * sizeof(float)); + Assign(out_max_in_cpu_tensor, + scope->Var(fc_out_max_in_name)->GetMutable()); + (*fusion_nodes_map)["out_max_in"] = fc_xpu_out_max_in; + } + + // Create out max + std::string fc_out_max_name = fc_xpu_out_name + "_max"; + VarDesc fc_out_max_desc(fc_out_max_name); + Node* fc_xpu_out_max = graph->CreateVarNode(&fc_out_max_desc); + (*fusion_nodes_map)["out_max"] = fc_xpu_out_max; +} + +void FcXPUFusePass::CreateFusionInputs( + ir::Graph* graph, + Scope* scope, + BlockDesc* block, + const std::map>& nodes_map, + std::map* fusion_nodes_map, + std::string op_weights_precision, + std::unordered_map>* var_quant_scales) + const { + // Get Node + auto* mul = GetNodeFromNodesMap(nodes_map, "mul", "mul"); + PADDLE_ENFORCE_EQ( + mul != nullptr, + true, + platform::errors::InvalidArgument("mul node ptr can not be null")); + auto* mul_x = GetNodeFromNodesMap(nodes_map, "mul", "mul_x"); + PADDLE_ENFORCE_EQ( + mul_x != nullptr, + true, + platform::errors::InvalidArgument("mul_x node ptr can not be null")); + // x max + std::string mul_x_max_name = mul_x->Name() + "_max"; + Node* mul_x_max = nullptr; + if (op_weights_precision == "int8") { + PADDLE_ENFORCE_EQ(AreScalesPresentForNodes(var_quant_scales, {mul_x}), + true, + platform::errors::InvalidArgument( + "When fc op is running in int8 precision, the scales " + "of input var should be present in!")); + float input_scale = GetScaleValueForNode(var_quant_scales, mul_x); + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + VarDesc x_max_desc(mul_x_max_name); + x_max_desc.SetPersistable( + true); // Need depends on ir_params_sync_among_devices_pass copy to xpu + // device + x_max_desc.SetShape({static_cast(max_ptr_size)}); + x_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + mul_x_max = graph->CreateVarNode(&x_max_desc); + auto input_max_tensor = + scope->Var(mul_x_max_name)->GetMutable(); + input_max_tensor->set_type(phi::DataType::FLOAT32); + input_max_tensor->Resize({max_ptr_size}); + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + std::vector input_scales(max_ptr_size, input_scale); + memcpy(cpu_ctx->Alloc(input_max_tensor), + input_scales.data(), + max_ptr_size * sizeof(float)); + } + (*fusion_nodes_map)["x"] = mul_x; + (*fusion_nodes_map)["x_max"] = mul_x_max; +} + int FcXPUFusePass::ApplyImpl(ir::Graph* graph, const std::string& mul_type, bool with_bias, @@ -287,7 +709,9 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, with_bias, with_bn, act_type); - + auto* scope = param_scope(); + std::unordered_map> var_quant_scales = + GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales"); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { @@ -311,108 +735,96 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, GET_IR_NODE(bn_saved_mean); GET_IR_NODE(act); GET_IR_NODE(act_out); - auto* block = mul->Op()->Block(); - auto* scope = param_scope(); - - auto* filter_t = - scope->FindVar(mul_w->Name())->GetMutable(); - // weight fp16 --> fp32 - auto filter_dtype = filter_t->dtype(); - int out_dtype = proto::VarType::Type::VarType_Type_FP32; - if (filter_dtype == phi::DataType::FLOAT16) { - out_dtype = proto::VarType::Type::VarType_Type_FP16; - CastToFp32(filter_t, nullptr); - } - auto filter_dims = filter_t->dims(); - - bool transpose_w = false; - if (mul_type == "matmul") { - transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y")); - } else if (mul_type == "matmul_v2") { - transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y")); - } - - bool has_bias = with_bn || with_bias; - Node* fusion_bias_node = nullptr; - if (has_bias) { - if (bias != nullptr) { - PrepareBias(graph, scope, block, bias, &fusion_bias_node); - } - if (bn != nullptr) { - auto bn_bias_t = - scope->Var(bn_bias->Name())->GetMutable(); - auto bn_scale_t = - scope->Var(bn_scale->Name())->GetMutable(); - auto bn_mean_t = - scope->Var(bn_mean->Name())->GetMutable(); - auto bn_var_t = - scope->Var(bn_var->Name())->GetMutable(); - float* mul_w_ptr = filter_t->data(); - float* bn_scale_ptr = bn_scale_t->data(); - float* bn_bias_ptr = bn_bias_t->data(); - float* bn_mean_ptr = bn_mean_t->data(); - float* bn_var_ptr = bn_var_t->data(); - auto mean_len = bn_mean_t->numel(); - auto filter_h = filter_dims[0]; - auto filter_w = filter_dims[1]; - float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); - if (fusion_bias_node == nullptr) { // prev node is conv - PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); - } - auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) - ->GetMutable(); - float* fusion_bias_ptr = fusion_bias_t->data(); - // recompute bias and weights - if (bias == nullptr) { - for (int i = 0; i < mean_len; ++i) { - bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - fusion_bias_ptr[i] += (0.f - bn_mean_ptr[i]) * bn_scale_ptr[i]; - for (int j = 0; j < filter_h; j++) { - mul_w_ptr[j * filter_w + i] *= bn_scale_ptr[i]; - } - } - } else { - for (int i = 0; i < mean_len; ++i) { - bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - bn_bias_ptr[i] += - (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; - for (int j = 0; j < filter_h; j++) { - mul_w_ptr[j * filter_w + i] *= bn_scale_ptr[i]; - } - } - memcpy(fusion_bias_ptr, bn_bias_ptr, mean_len * sizeof(float)); - } - } - } + std::map> nodes_map; + nodes_map.insert({"mul", + {{"mul", mul}, + {"mul_x", mul_x}, + {"mul_w", mul_w}, + {"mul_out", mul_out}}}); + nodes_map.insert({"ew_bias_add", + {{"ew_bias_add", add}, + {"ew_bias_add_bias", bias}, + {"ew_bias_add_out", add_out}}}); + nodes_map.insert({"bn", + {{"bn", bn}, + {"bn_bias", bn_bias}, + {"bn_mean", bn_mean}, + {"bn_scale", bn_scale}, + {"bn_var", bn_var}, + {"bn_out", bn_out}, + {"bn_var_out", bn_var_out}, + {"bn_mean_out", bn_mean_out}, + {"bn_saved_var", bn_saved_var}, + {"bn_saved_mean", bn_saved_mean}}}); + nodes_map.insert({"act", {{"act", act}, {"act_out", act_out}}}); - Node* mul_w_int16 = nullptr; - Node* mul_w_max = nullptr; - PrepareWeight( - graph, scope, block, mul_w, &mul_w_int16, &mul_w_max, !transpose_w); - - std::string fc_out_name; - if (act_out) { - fc_out_name = act_out->Name(); - } else if (bn) { - fc_out_name = bn_out->Name(); - } else if (add_out) { - fc_out_name = add_out->Name(); - } else { - fc_out_name = mul_out->Name(); + std::map fusion_nodes_map{{"x", nullptr}, + {"x_max", nullptr}, + {"w", nullptr}, + {"w_max", nullptr}, + {"bias", nullptr}, + {"scale_max", nullptr}, + {"out_max_in", nullptr}, + {"out", nullptr}, + {"out_max", nullptr}}; + auto filter_data_type = + scope->FindVar(mul_w->Name())->GetMutable()->dtype(); + std::string op_weights_precision = "float32"; + if (filter_data_type == phi::DataType::INT8) { + op_weights_precision = "int8"; + } else if (filter_data_type == phi::DataType::FLOAT16) { + op_weights_precision = "float16"; } - std::string fc_out_max_name = fc_out_name + "_max"; - VarDesc fc_out_max_desc(fc_out_max_name); - Node* fc_out_max = graph->CreateVarNode(&fc_out_max_desc); + VLOG(4) << "FC fusion fuse pass is running on " << op_weights_precision + << " precision!"; + auto* block = mul->Op()->Block(); + CreateFusionWeightsAndBias(graph, + scope, + block, + mul_type, + nodes_map, + &fusion_nodes_map, + with_bias, + with_bn, + op_weights_precision, + &var_quant_scales); + CreateFusionInputs(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + &var_quant_scales); + CreateFusionOutputs(graph, + scope, + block, + nodes_map, + &fusion_nodes_map, + op_weights_precision, + &var_quant_scales); // Generate fc_xpu op framework::OpDesc fc_xpu_op_desc(block); fc_xpu_op_desc.SetType("fc_xpu"); - fc_xpu_op_desc.SetInput("x", {mul_x->Name()}); - fc_xpu_op_desc.SetInput("w", {mul_w_int16->Name()}); - fc_xpu_op_desc.SetInput("w_max", {mul_w_max->Name()}); - if (has_bias) { - fc_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()}); + fc_xpu_op_desc.SetInput("x", {fusion_nodes_map["x"]->Name()}); + if (fusion_nodes_map["x_max"]) { + fc_xpu_op_desc.SetInput("x_max", {fusion_nodes_map["x_max"]->Name()}); + } + fc_xpu_op_desc.SetInput("w", {fusion_nodes_map["w"]->Name()}); + fc_xpu_op_desc.SetInput("w_max", {fusion_nodes_map["w_max"]->Name()}); + if (fusion_nodes_map["bias"]) { + fc_xpu_op_desc.SetInput("bias", {fusion_nodes_map["bias"]->Name()}); + } + if (fusion_nodes_map["scale_max"]) { + fc_xpu_op_desc.SetInput("scale_max", + {fusion_nodes_map["scale_max"]->Name()}); } + if (fusion_nodes_map["out_max_in"]) { + fc_xpu_op_desc.SetInput("out_max_in", + {fusion_nodes_map["out_max_in"]->Name()}); + } + fc_xpu_op_desc.SetOutput("out", {fusion_nodes_map["out"]->Name()}); + fc_xpu_op_desc.SetOutput("out_max", {fusion_nodes_map["out_max"]->Name()}); fc_xpu_op_desc.SetAttr( "in_num_col_dims", static_cast(mul_x->Var()->GetShape().size() - 1)); @@ -440,48 +852,41 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, "act_alpha", PADDLE_GET_CONST(float, act->Op()->GetAttr("slope"))); } } - fc_xpu_op_desc.SetAttr("out_dtype", out_dtype); - fc_xpu_op_desc.SetOutput("out", {fc_out_name}); - fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name}); + // out_dtype is same to input precision + fc_xpu_op_desc.SetAttr("out_dtype", + fusion_nodes_map["x"]->Var()->GetDataType()); auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc); - IR_NODE_LINK_TO(mul_x, fc_xpu); - IR_NODE_LINK_TO(mul_w_int16, fc_xpu); - IR_NODE_LINK_TO(mul_w_max, fc_xpu); - if (bias || bn) { - SAFE_IR_NODE_LINK_TO(fusion_bias_node, fc_xpu); + IR_NODE_LINK_TO(fusion_nodes_map["x"], fc_xpu); + if (fusion_nodes_map["x_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["x_max"], fc_xpu); } - if (act_out) { - IR_NODE_LINK_TO(fc_xpu, act_out); - } else if (bn_out) { - IR_NODE_LINK_TO(fc_xpu, bn_out); - } else if (add_out) { - IR_NODE_LINK_TO(fc_xpu, add_out); - } else { - IR_NODE_LINK_TO(fc_xpu, mul_out); + IR_NODE_LINK_TO(fusion_nodes_map["w"], fc_xpu); + IR_NODE_LINK_TO(fusion_nodes_map["w_max"], fc_xpu); + if (fusion_nodes_map["scale_max"]) { + IR_NODE_LINK_TO(fusion_nodes_map["scale_max"], fc_xpu); } - IR_NODE_LINK_TO(fc_xpu, fc_out_max); + if (fusion_nodes_map["bias"]) { + IR_NODE_LINK_TO(fusion_nodes_map["bias"], fc_xpu); + } + if (fusion_nodes_map["out_max_in"]) { + IR_NODE_LINK_TO(fusion_nodes_map["out_max_in"], fc_xpu); + } + IR_NODE_LINK_TO(fc_xpu, fusion_nodes_map["out"]); + IR_NODE_LINK_TO(fc_xpu, fusion_nodes_map["out_max"]); // delete useless node std::unordered_set delete_nodes; - if (act != nullptr && add != nullptr) { - delete_nodes = {mul, mul_out, add, add_out, act}; - } else if (act) { - delete_nodes = {mul, mul_out, act}; - } else if (add) { - delete_nodes = {mul, mul_out, add}; - } else { - delete_nodes = {mul}; + if (mul != nullptr) { + delete_nodes.insert(mul); } if (bn != nullptr) { delete_nodes.insert(bn); - delete_nodes.insert(bn_bias); - delete_nodes.insert(bn_var); - delete_nodes.insert(bn_mean); - delete_nodes.insert(bn_scale); - delete_nodes.insert(bn_var_out); - delete_nodes.insert(bn_mean_out); - delete_nodes.insert(bn_saved_var); - delete_nodes.insert(bn_saved_mean); + } + if (add != nullptr) { + delete_nodes.insert(add); + } + if (act != nullptr) { + delete_nodes.insert(act); } GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc index 725f4e6a86a495..47bf2b06be9d97 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc @@ -424,11 +424,23 @@ int FusedMultiTransformerXPUPass::FusedMultiTransformerXPUQuant( nullptr, platform::errors::Fatal("w node should not be nullptr")); if (quant_post_dynamic_weight_precision == 0) { - PrepareWeight( - graph, scope, block, w_node, &w_intx, &w_max, need_transpose); + PrepareWeight(graph, + scope, + block, + w_node, + &w_intx, + &w_max, + need_transpose, + std::vector({})); } else { - PrepareWeight( - graph, scope, block, w_node, &w_intx, &w_max, need_transpose); + PrepareWeight(graph, + scope, + block, + w_node, + &w_intx, + &w_max, + need_transpose, + std::vector({})); } w_nodes->push_back(w_node); w_intx_nodes->push_back(w_intx); diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc index 1a9db472bc2cc5..9b552bac36f2d1 100644 --- a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc +++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc @@ -67,6 +67,7 @@ struct LinkConv2dPattern : public PatternBase { PATTERN_DECL_NODE(fusion_op); // declare variable node's name PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(filter); PATTERN_DECL_NODE(branch); private: @@ -79,14 +80,19 @@ LinkConv2dPattern::LinkConv2dPattern(PDPattern* pattern, : PatternBase(pattern, name_scope, name_scope), with_branch_(with_branch) { auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op("conv2d_xpu"); + auto* x = pattern->NewNode(x_repr())->assert_is_op_input("conv2d_xpu", "x"); + auto* filter = pattern->NewNode(filter_repr()) + ->assert_is_op_input("conv2d_xpu", "filter") + ->assert_is_persistable_var(); PDNode* branch = nullptr; if (with_branch_) { branch = pattern->NewNode(branch_repr()) ->assert_is_op_input("conv2d_xpu", "branch"); - fusion_op->LinksFrom({branch}); + fusion_op->LinksFrom({x, branch, filter}); + } else { + fusion_op->LinksFrom({x, filter}); } - fusion_op->LinksFrom({x}); } struct LinkFcPattern : public PatternBase { @@ -96,18 +102,30 @@ struct LinkFcPattern : public PatternBase { PATTERN_DECL_NODE(fusion_op); // declare variable node's name PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(w); }; LinkFcPattern::LinkFcPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, name_scope) { auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op("fc_xpu"); - auto* x = pattern->NewNode(x_repr())->assert_is_op_input("fc_xpu", "x"); - fusion_op->LinksFrom({x}); + auto* x = pattern->NewNode(x_repr())->assert_is_op_input("fc_xpu", "x"); + auto* w = pattern->NewNode(w_repr()) + ->assert_is_op_input("fc_xpu", "w") + ->assert_is_persistable_var(); + fusion_op->LinksFrom({x, w}); } } // namespace patterns +bool LinkXPUOpMaxPass::IsQuant(Node* weight_node) const { + auto w_dtype = param_scope() + ->FindVar(weight_node->Name()) + ->GetMutable() + ->dtype(); + return w_dtype == phi::DataType::INT8; +} + void LinkXPUOpMaxPass::LinkAddActMax(ir::Graph* graph) const { GraphPatternDetector gpd; patterns::LinkAddActPattern pattern(gpd.mutable_pattern(), name_scope_); @@ -155,15 +173,18 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const { patterns::LinkConv2dPattern pattern( gpd.mutable_pattern(), name_scope_, with_branch); int found_subgraph_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle LinkConv2dMax"; - /* declare operator node's name */ + /* get operator node's name */ GET_IR_NODE(fusion_op); - /* declare variable node's name*/ + /* get variable node's name*/ GET_IR_NODE(x); + GET_IR_NODE(filter); GET_IR_NODE(branch); + if (IsQuant(filter)) { + return; + } auto* fusion_op_desc = fusion_op->Op(); bool fusion_op_has_branch = fusion_op_desc->HasInput("branch"); if (fusion_op_has_branch) { @@ -177,7 +198,12 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const { auto preop_max_var_name = x_pre_op->Output("out_max"); for (auto max_node : x->inputs[0]->outputs) { if (preop_max_var_name[0] == max_node->Name()) { - fusion_op_desc->SetInput("x_max", {max_node->Name()}); + if (fusion_op_desc->HasInput("x_max")) { + auto x_max_old_name = fusion_op_desc->Input("x_max")[0]; + fusion_op_desc->RenameInput(x_max_old_name, max_node->Name()); + } else { + fusion_op_desc->SetInput("x_max", {max_node->Name()}); + } IR_NODE_LINK_TO(max_node, fusion_op); } } @@ -205,14 +231,16 @@ void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const { GraphPatternDetector gpd; patterns::LinkFcPattern pattern(gpd.mutable_pattern(), name_scope_); int found_subgraph_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle LinkFcMax"; - /* declare operator node's name */ + /* get operator node's name */ GET_IR_NODE(fusion_op); - /* declare variable node's name*/ + /* get variable node's name*/ GET_IR_NODE(x); + GET_IR_NODE(w); + + if (IsQuant(w)) return; auto* fusion_op_desc = fusion_op->Op(); auto* x_pre_op = x->inputs[0]->Op(); if (x->inputs.size() > 0 && x->inputs[0]->IsOp() && @@ -220,7 +248,12 @@ void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const { auto preop_max_var_name = x_pre_op->Output("out_max"); for (auto max_node : x->inputs[0]->outputs) { if (preop_max_var_name[0] == max_node->Name()) { - fusion_op_desc->SetInput("x_max", {max_node->Name()}); + if (fusion_op_desc->HasInput("x_max")) { + auto x_max_old_name = fusion_op_desc->Input("x_max")[0]; + fusion_op_desc->RenameInput(x_max_old_name, max_node->Name()); + } else { + fusion_op_desc->SetInput("x_max", {max_node->Name()}); + } IR_NODE_LINK_TO(max_node, fusion_op); } } diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h index cad199ce573bb9..a71a2e19cf430d 100644 --- a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h +++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.h @@ -102,6 +102,7 @@ Fused subgraph: */ void LinkAddActMax(ir::Graph* graph) const; + bool IsQuant(Node* weight_node) const; const std::string name_scope_{"link_xpu_op_max_pass"}; }; diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc index 255c1f5d47a4c3..04439608aaa237 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc @@ -561,7 +561,8 @@ void MultiEncoderXPUFusePass::PrepareQKVWeight(Graph* graph, &q_w_fp32_t, &k_w_fp32_t, &v_w_fp32_t}; phi::ConcatKernel(*cpu_ctx, in_tensors, 0, &qkv_w_int16_t); - PrepareWeight(&qkv_w_int16_t, &qkv_w_max_t, false); + ConvertWithQuant( + &qkv_w_int16_t, &qkv_w_max_t, false, std::vector({})); size_t qkv_w_int16_hash = HashTensor(qkv_w_int16_t); size_t qkv_w_max_hash = HashTensor(qkv_w_max_t); std::string qkv_w_int16_name = std::to_string(qkv_w_int16_hash); @@ -813,16 +814,17 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( &qkv_w_int16, &qkv_w_max); -#define PREPARE_QKV_MATMUL_W(idx_) \ - Node* qkv_matmul_##idx_##_w_int16 = nullptr; \ - Node* qkv_matmul_##idx_##_w_max = nullptr; \ - PrepareWeight(graph, \ - scope, \ - block, \ - qkv_matmul_##idx_##_w, \ - &qkv_matmul_##idx_##_w_int16, \ - &qkv_matmul_##idx_##_w_max, \ - true); +#define PREPARE_QKV_MATMUL_W(idx_) \ + Node* qkv_matmul_##idx_##_w_int16 = nullptr; \ + Node* qkv_matmul_##idx_##_w_max = nullptr; \ + PrepareWeight(graph, \ + scope, \ + block, \ + qkv_matmul_##idx_##_w, \ + &qkv_matmul_##idx_##_w_int16, \ + &qkv_matmul_##idx_##_w_max, \ + true, \ + std::vector({})); PREPARE_QKV_MATMUL_W(1); PREPARE_QKV_MATMUL_W(2); PREPARE_QKV_MATMUL_W(3); diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc index eeb0e23e19ecde..c6dc2913153990 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.cc +++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc @@ -121,102 +121,123 @@ size_t HashTensor(const phi::DenseTensor& in) { template size_t HashTensor(const phi::DenseTensor& in); template size_t HashTensor(const phi::DenseTensor& in); +template size_t HashTensor(const phi::DenseTensor& in); std::string GetPrefixWithoutHash(const std::string& name) { std::size_t found = name.find("_#"); return found == std::string::npos ? name : name.substr(0, found); } -template +template void PrepareWeight(Graph* graph, Scope* scope, BlockDesc* block, - Node* src, - Node** dst, - Node** dst_max, - bool transpose) { - auto src_name = src->Name(); - auto* src_tensor = scope->Var(src_name)->GetMutable(); - phi::DenseTensor dst_tensor; - Assign(*src_tensor, &dst_tensor); - phi::DenseTensor dst_max_tensor; - PrepareWeight(&dst_tensor, &dst_max_tensor, transpose); - - size_t dst_hash = HashTensor(dst_tensor); - size_t dst_max_hash = HashTensor(dst_max_tensor); - std::string pre_name = GetPrefixWithoutHash(src_name); - std::string dst_name = pre_name + "_#" + std::to_string(dst_hash); - std::string dst_max_name = pre_name + "_max_#" + std::to_string(dst_max_hash); - *dst = FindNodeWithName(graph, dst_name); - if (*dst == nullptr) { - // Create dst node - // Update dst var_desc in block - VarDesc dst_desc(dst_name); - dst_desc.SetPersistable(true); - dst_desc.SetShape(vectorize(dst_tensor.dims())); - dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype())); - *dst = graph->CreateVarNode(&dst_desc); - auto* block_dst_desc = block->Var(dst_name); - block_dst_desc->SetPersistable(dst_desc.Persistable()); - block_dst_desc->SetShape(dst_desc.GetShape()); - block_dst_desc->SetDataType(dst_desc.GetDataType()); - // Create dst_max node - // Update dst_max var_desc in block - VarDesc dst_max_desc(dst_max_name); - dst_max_desc.SetPersistable(true); - dst_max_desc.SetShape(vectorize(dst_max_tensor.dims())); - dst_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); - *dst_max = graph->CreateVarNode(&dst_max_desc); - auto* block_dst_max_desc = block->Var(dst_max_name); - block_dst_max_desc->SetPersistable(dst_max_desc.Persistable()); - block_dst_max_desc->SetShape(dst_max_desc.GetShape()); - block_dst_max_desc->SetDataType(dst_max_desc.GetDataType()); - + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales) { + auto weight_name = weight->Name(); + auto* weight_tensor = scope->Var(weight_name)->GetMutable(); + phi::DenseTensor dst_weight_tensor; + Assign(*weight_tensor, &dst_weight_tensor); + phi::DenseTensor dst_weight_max_tensor; + ConvertWeightWrapper( + &dst_weight_tensor, &dst_weight_max_tensor, transpose, weight_scales); + size_t dst_weight_hash = HashTensor(dst_weight_tensor); + size_t dst_weight_max_hash = HashTensor(dst_weight_max_tensor); + std::string pre_name = GetPrefixWithoutHash(weight_name); + std::string dst_weight_name = + pre_name + "_#" + std::to_string(dst_weight_hash); + std::string dst_weight_max_name = + pre_name + "_max_#" + std::to_string(dst_weight_max_hash); + *dst_weight = FindNodeWithName(graph, dst_weight_name); + if (*dst_weight == nullptr) { + // Create dst_weight node + // Update dst_weight var_desc in block + VarDesc dst_weight_desc(dst_weight_name); + dst_weight_desc.SetPersistable(true); + dst_weight_desc.SetShape(vectorize(dst_weight_tensor.dims())); + dst_weight_desc.SetDataType( + framework::TransToProtoVarType(dst_weight_tensor.dtype())); + *dst_weight = graph->CreateVarNode(&dst_weight_desc); + auto* block_dst_weight_desc = block->Var(dst_weight_name); + block_dst_weight_desc->SetPersistable(dst_weight_desc.Persistable()); + block_dst_weight_desc->SetShape(dst_weight_desc.GetShape()); + block_dst_weight_desc->SetDataType(dst_weight_desc.GetDataType()); + // Create dst_weight_max node + // Update dst_weight_max var_desc in block + VarDesc dst_weight_max_desc(dst_weight_max_name); + dst_weight_max_desc.SetPersistable(true); + dst_weight_max_desc.SetShape(vectorize(dst_weight_max_tensor.dims())); + dst_weight_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + *dst_weight_max = graph->CreateVarNode(&dst_weight_max_desc); + auto* block_dst_weight_max_desc = block->Var(dst_weight_max_name); + block_dst_weight_max_desc->SetPersistable( + dst_weight_max_desc.Persistable()); + block_dst_weight_max_desc->SetShape(dst_weight_max_desc.GetShape()); + block_dst_weight_max_desc->SetDataType(dst_weight_max_desc.GetDataType()); // Find dst/dst_max variable in scope - auto* dst_var = scope->FindVar(dst_name); - if (dst_var == nullptr) { - // Create dst/dst_max variable/tensor - Assign(dst_tensor, scope->Var(dst_name)->GetMutable()); - Assign(dst_max_tensor, - scope->Var(dst_max_name)->GetMutable()); + auto* dst_weight_var = scope->FindVar(dst_weight_name); + if (dst_weight_var == nullptr) { + // Create dst_weight/dst_weight_max variable/tensor + Assign(dst_weight_tensor, + scope->Var(dst_weight_name)->GetMutable()); + Assign(dst_weight_max_tensor, + scope->Var(dst_weight_max_name)->GetMutable()); } else { // Share the same variable PADDLE_ENFORCE_NOT_NULL( - scope->FindVar(dst_max_name), - platform::errors::Fatal( - "dst_max(%s) variable should not be nullptr if dst(%s) " - "variable is exist. (src_name is %s)", - dst_max_name, - dst_name, - src_name)); + scope->FindVar(dst_weight_max_name), + platform::errors::Fatal("dst_weight_max(%s) variable should not be " + "nullptr if dst_weight(%s) " + "variable is exist. (weight_name is %s)", + dst_weight_max_name, + dst_weight_name, + weight_name)); } } else { - *dst_max = FindNodeWithName(graph, dst_max_name); + *dst_weight_max = FindNodeWithName(graph, dst_weight_max_name); PADDLE_ENFORCE_NOT_NULL( - *dst_max, - platform::errors::Fatal( - "dst_max(%s) variable should not be nullptr if dst(%s) " - "variable is exist. (src_name is %s)", - dst_max_name, - dst_name, - src_name)); + *dst_weight_max, + platform::errors::Fatal("dst_weight_max(%s) variable should not be " + "nullptr if dst_weight(%s) " + "variable is exist. (weight_name is %s)", + dst_weight_max_name, + dst_weight_name, + weight_name)); } } -template void PrepareWeight(Graph* graph, - Scope* scope, - BlockDesc* block, - Node* src, - Node** dst, - Node** dst_max, - bool transpose); -template void PrepareWeight(Graph* graph, - Scope* scope, - BlockDesc* block, - Node* src, - Node** dst, - Node** dst_max, - bool transpose); +template void PrepareWeight( + Graph* graph, + Scope* scope, + BlockDesc* block, + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales); + +template void PrepareWeight( + Graph* graph, + Scope* scope, + BlockDesc* block, + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales); + +template void PrepareWeight( + Graph* graph, + Scope* scope, + BlockDesc* block, + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales); void PrepareBias( Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) { diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.h b/paddle/fluid/framework/ir/xpu/pass_utils.h index d1e7b218a0b468..668519c8eb4065 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.h +++ b/paddle/fluid/framework/ir/xpu/pass_utils.h @@ -57,18 +57,62 @@ std::vector FindOpNodeByInputName(Graph* graph, template size_t HashTensor(const phi::DenseTensor& in); -template +template ::value, Tcpu>::type* + ptr = nullptr> +void ConvertWeightWrapper(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { + ConvertWithQuant(weight, weight_max, transpose, weight_scales); +} + +template ::value, Tcpu>::type* + ptr = nullptr> +void ConvertWeightWrapper(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { + ConvertWithoutQuant(weight, weight_max, transpose, weight_scales); +} + +// 1. Quant weight from fp32 to int16/int31/int8 +// 2. Weight data is in-place update. +// 3. Generate weight max tensor +template void PrepareWeight(Graph* graph, Scope* scope, BlockDesc* block, - Node* src, - Node** dst, - Node** dst_max, - bool transpose); + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + bool transpose, + const std::vector& weight_scales); void PrepareBias( Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst); +inline std::string FindOutputNameByVarName(framework::OpDesc* op, + const std::string& searched_name) { + std::string ret; + for (const auto& name : op->OutputNames()) + for (const auto& output_name : op->Output(name)) + if (output_name == searched_name) ret = name; + return ret; +} + +inline std::string FindInputNameByVarName(framework::OpDesc* op, + const std::string& searched_name) { + std::string ret; + for (const auto& name : op->InputNames()) + for (const auto& input_name : op->Input(name)) + if (input_name == searched_name) ret = name; + return ret; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc index fcda50051a3627..a137a006e9f708 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.cc +++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc @@ -64,9 +64,12 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) { case phi::DataType::FLOAT32: phi::TransposeKernel(*cpu_ctx, *in, axis, out_ptr); break; + case phi::DataType::INT8: + phi::TransposeKernel(*cpu_ctx, *in, axis, out_ptr); + break; default: PADDLE_THROW(platform::errors::InvalidArgument( - "Only support fp16 and fp32, but received dtype is %s.", + "Only support fp16/fp32/int8, but received dtype is %s.", phi::DataTypeToString(in->dtype()))); break; } @@ -258,15 +261,30 @@ void QuantFP32ToIntX(const float* src_ptr, } } -template -void PrepareWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose) { +template < + typename Tcpu, + typename Txpu, + typename std::enable_if::value, Tcpu>::type* ptr> +void ConvertWithQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { + LOG(FATAL) << "Not support for Tcpu is " + << phi::CppTypeToDataType::Type(); +} + +template < + typename Tcpu, + typename Txpu, + typename std::enable_if::value, Tcpu>::type* ptr> +void ConvertWithQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { // Convert fp16 to fp32 phi::DenseTensor weight_fp32; CastToFp32(weight, &weight_fp32); - // Transpose if (transpose) { Transpose2D(&weight_fp32); } @@ -286,17 +304,74 @@ void PrepareWeight(phi::DenseTensor* weight, max_ptr_size * sizeof(float)); // Quant - weight->set_type(phi::CppTypeToDataType::Type()); + weight->set_type(phi::CppTypeToDataType::Type()); weight->Resize(weight_fp32.dims()); - QuantFP32ToIntX(weight_data, cpu_ctx->Alloc(weight), max_val, size); + QuantFP32ToIntX( + weight_data, cpu_ctx->Alloc(weight), max_val, size); } -template void PrepareWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose); -template void PrepareWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose); +template +void ConvertWithoutQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales) { + if (transpose) { + Transpose2D(weight); + } + if (std::is_same::value || std::is_same::value) { + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + int max_ptr_size = weight_scales.empty() + ? phi::backends::xpu::get_xpu_max_ptr_size(-1) + : weight_scales.size(); + weight_max->set_type(phi::DataType::FLOAT32); + weight_max->Resize({max_ptr_size}); + if (!weight_scales.empty()) { + memcpy(cpu_ctx->Alloc(weight_max), + weight_scales.data(), + max_ptr_size * sizeof(float)); + } else { + LOG(FATAL) << "weight scales cannot be empty!"; + } + } else { + LOG(FATAL) << "Only support int8<->int8 and int16<->int16 convert."; + } +} + +template void ConvertWithQuant( + phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +template void ConvertWithQuant( + phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +template void ConvertWithoutQuant( + phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +bool IsPerTensorQuant(const std::vector& weight_max) { + bool per_tensor = true; + PADDLE_ENFORCE_GT( + weight_max.size(), + 0, + platform::errors::InvalidArgument( + "Op's channel size: [%d] should great than zero", weight_max.size())); + auto first = weight_max[0]; + for (size_t i = 1; i < weight_max.size(); ++i) { + if (std::abs(first - weight_max[i]) > 1e-6) { + per_tensor = false; + break; + } + } + return per_tensor; +} } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.h b/paddle/fluid/framework/ir/xpu/quant_utils.h index b417fa03323db8..1a2952c6145424 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.h +++ b/paddle/fluid/framework/ir/xpu/quant_utils.h @@ -27,13 +27,31 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); -// 1. Quant weight from fp32 to int16/int31 -// 2. Weight data is in-place update. -// 3. Generate weight max tensor template -void PrepareWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose); +void ConvertWithoutQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +template ::value, Tcpu>::type* + ptr = nullptr> +void ConvertWithQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +template ::value, + Tcpu>::type* ptr = nullptr> +void ConvertWithQuant(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose, + const std::vector& weight_scales); + +bool IsPerTensorQuant(const std::vector& weight_max); } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc index 8383501c30b8f9..8fa4a377175a73 100644 --- a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc @@ -286,9 +286,6 @@ void MapMatmulV2ToMatmulXPUPass::MapMatmulV2ToMatmul(ir::Graph* graph) const { desc.SetAttr("transpose_X", matmul_v2->Op()->GetAttr("trans_x")); desc.SetAttr("transpose_Y", matmul_v2->Op()->GetAttr("trans_y")); desc.SetAttr("alpha", 1.0f); - if (matmul_v2->Op()->HasAttr("use_mkldnn")) { - desc.SetAttr("use_mkldnn", matmul_v2->Op()->GetAttr("use_mkldnn")); - } auto matmul_node = graph->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_x, matmul_node); IR_NODE_LINK_TO(matmul_y, matmul_node); diff --git a/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.cc b/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.cc new file mode 100644 index 00000000000000..f1d2752321aadc --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +PDNode *patterns::DequantXPUAny::operator()() { + auto *dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize_xpu"); + + auto *dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize_xpu", "y"); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + dequant_op->LinksTo({dequant_out}); + next_op->LinksFrom({dequant_out}); + + return dequant_out; +} + +PDNode *patterns::QuantXPUAny::operator()() { + auto *quant_in = pattern->NewNode(quant_in_repr()) + ->AsInput() + ->assert_is_op_input("quantize_xpu", "x"); + auto *quant_op = + pattern->NewNode(quant_op_repr())->assert_is_op("quantize_xpu"); + + auto *quant_out = pattern->NewNode(quant_out_repr()) + ->AsOutput() + ->assert_is_op_output("quantize_xpu", "y"); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + quant_op->LinksFrom({quant_in}).LinksTo({quant_out}); + next_op->LinksFrom({quant_out}); + + return quant_out; +} + +PDNode *patterns::DequantQuantXPUAny::operator()() { + auto *dequant_in = pattern->NewNode(dequant_in_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_xpu", "x"); + + auto *dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize_xpu"); + + auto *dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize_xpu", "y"); + + auto *quant_op = pattern->NewNode(quant_op_repr()) + ->assert_is_op("quantize_xpu") + ->AsIntermediate(); + + auto *quant_out = pattern->NewNode(quant_out_repr()) + ->AsOutput() + ->assert_is_op_output("quantize_xpu"); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + + dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); + quant_op->LinksFrom({dequant_out}).LinksTo({quant_out}); + next_op->LinksFrom({quant_out}); + + return quant_out; +} + +PDNode *patterns::OpDequantXPU::operator()() { + auto any_op = pattern->NewNode(any_op_repr())->assert_is_op(); + auto *dequant_in = pattern->NewNode(dequant_in_repr()) + ->assert_is_op_input("dequantize_xpu", "x"); + auto *dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize_xpu"); + auto dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize_xpu", "y"); + + any_op->LinksTo({dequant_in}); + dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); + return dequant_out; +} + +PDNode *patterns::MultipleQuantizeXPU::operator()() { + auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput(); + + // find nodes that are inputs to quantize operators + prev_out->assert_more([&](Node *node) { + int counter = static_cast(std::count_if( + node->outputs.begin(), node->outputs.end(), [&](Node const *iter) { + return iter && iter->IsOp() && iter->Op()->Type() == "quantize_xpu"; + })); + return (counter > 1); + }); + + return prev_out; +} + +} // namespace patterns +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h b/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h new file mode 100644 index 00000000000000..c849b2a24bb48c --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h @@ -0,0 +1,96 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +// Dequantize + anyOP +// This quantize is used for getting number of ops the Dequantize's +// output is an input to. +struct DequantXPUAny : public PatternBase { + DequantXPUAny(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dequant_xpu_any") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); + PATTERN_DECL_NODE(next_op); +}; + +// Quantize + anyOP +struct QuantXPUAny : public PatternBase { + QuantXPUAny(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "quant_xpu_any") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(quant_in); + PATTERN_DECL_NODE(quant_op); + PATTERN_DECL_NODE(quant_out); + PATTERN_DECL_NODE(next_op); +}; + +// Dequantize + Quantize + anyOP +// This pattern is used for squashing the dequantize-quantize pairs. +struct DequantQuantXPUAny : public PatternBase { + DequantQuantXPUAny(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dequant_quant_xpu_any") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(dequant_in); + PATTERN_DECL_NODE(dequant_max_in); + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); + PATTERN_DECL_NODE(quant_max_in); + PATTERN_DECL_NODE(quant_op); + PATTERN_DECL_NODE(quant_out); + PATTERN_DECL_NODE(next_op); +}; + +// Op + Dequant +// named nodes: +// any_op, dequant_in +// dequant_op, dequant_out +struct OpDequantXPU : public PatternBase { + OpDequantXPU(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "op_dequant_xpu") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(any_op); + PATTERN_DECL_NODE(dequant_in); + PATTERN_DECL_NODE(dequant_max_in); + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); +}; + +// anyOp + more then one quantize op +// This pattern is used for squashing multiple quantize with the same scale. +struct MultipleQuantizeXPU : public PatternBase { + MultipleQuantizeXPU(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "multiple_quantize_xpu") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(prev_out); +}; + +} // namespace patterns +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc new file mode 100644 index 00000000000000..761f17a92e299b --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc @@ -0,0 +1,280 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/quantize_helper.h" +#include "paddle/utils/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +static void UnlinkNodes(ir::Node* a, ir::Node* b) { + a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b), + a->outputs.end()); + b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a), + b->inputs.end()); +} + +static void MarkAndLogCannotQuantizeOp(Node* op, + const char* details = nullptr) { + std::stringstream msg_ss; + msg_ss << "Cannot quantize operator " << op->Name() + << " (type: " << op->Op()->Type() << ", id: " << op->id() << ")."; + if (details) msg_ss << " " << details; + VLOG(2) << msg_ss.str().c_str(); +} +void XPUQuantizeOpPass::GetQuantInfo(Graph* graph) const { + var_quant_scales_ = + GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales"); +} + +void XPUQuantizeOpPass::QuantizeInput(Graph* g, + Node* op, + Node* input, + std::string input_arg_name) const { + auto inputs = op->Op()->InputNames(); + bool name_found = + std::find(inputs.begin(), inputs.end(), input_arg_name) != inputs.end(); + PADDLE_ENFORCE_EQ(name_found, + true, + platform::errors::InvalidArgument( + "Var(%s) isn't the input of the %s operator.", + input_arg_name, + op->Op()->Type())); + + // Create quantize output variable + VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); + auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc); + quantize_out_node->Var()->SetDataType( + proto::VarType::Type::VarType_Type_INT8); + + // Create a quantize op node + float scale = GetScaleValueForNode(&var_quant_scales_, input); + OpDesc q_desc; + q_desc.SetType("quantize_xpu"); + q_desc.SetInput("x", std::vector({input->Name()})); + q_desc.SetOutput("y", std::vector({quantize_out_node->Name()})); + q_desc.SetAttr("out_dtype", + static_cast(proto::VarType::Type::VarType_Type_INT8)); + q_desc.SetAttr("scale", static_cast(scale)); + auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. + + // Update op's input + op->Op()->SetInput(input_arg_name, + std::vector({quantize_out_node->Name()})); + + // Link quantize op + UnlinkNodes(input, op); + IR_NODE_LINK_TO(input, quantize_op); + IR_NODE_LINK_TO(quantize_op, quantize_out_node); + IR_NODE_LINK_TO(quantize_out_node, op); +} + +void XPUQuantizeOpPass::DequantizeOutput(Graph* g, + Node* op, + Node* output, + std::string output_arg_name) const { + auto outputs = op->Op()->OutputNames(); + bool name_found = + std::find(outputs.begin(), outputs.end(), output_arg_name) != + outputs.end(); + PADDLE_ENFORCE_EQ(name_found, + true, + platform::errors::InvalidArgument( + "Var(%s) isn't the output of the %s operator.", + output_arg_name, + op->Op()->Type())); + + // Create dequantize input variable + VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); + auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc); + dequantize_in_node->Var()->SetDataType( + proto::VarType::Type::VarType_Type_INT8); + + float scale = GetScaleValueForNode(&var_quant_scales_, output); + // Create a quantize op node + OpDesc deq_desc; + deq_desc.SetType("dequantize_xpu"); + deq_desc.SetInput("x", + std::vector({dequantize_in_node->Name()})); + deq_desc.SetOutput("y", std::vector({output->Name()})); + deq_desc.SetAttr("out_dtype", static_cast(output->Var()->GetDataType())); + deq_desc.SetAttr("scale", static_cast(scale)); + auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. + + // Update op's input + op->Op()->SetOutput(output_arg_name, + std::vector({dequantize_in_node->Name()})); + + // Link dequantize op + UnlinkNodes(op, output); + IR_NODE_LINK_TO(op, dequantize_in_node); + IR_NODE_LINK_TO(dequantize_in_node, dequantize_op); + IR_NODE_LINK_TO(dequantize_op, output); +} + +void XPUQuantizeOpPass::QuantizeConv(ir::Graph* graph) const { + for (auto* n : graph->Nodes()) { + if (n->IsOp()) { + auto* op = n->Op(); + if (op->Type() != "conv2d_xpu") { + continue; + } + Node* w_var_node = nullptr; + Node* x_var_node = nullptr; + Node* out_var_node = nullptr; + Node* branch_var_node = nullptr; + + for (auto* input_node : n->inputs) { + if (!input_node->IsVar()) { + continue; + } + if (input_node->Var()->Name() == op->Input("x")[0]) { + x_var_node = input_node; + } else if (input_node->Var()->Name() == op->Input("filter")[0]) { + w_var_node = input_node; + } else if (op->HasInput("branch") && + input_node->Var()->Name() == op->Input("branch")[0]) { + branch_var_node = input_node; + } + } + + for (auto* output_node : n->outputs) { + if (!output_node->IsVar()) { + continue; + } + if (output_node->Var()->Name() == op->Output("out")[0]) { + out_var_node = output_node; + } + } + if (!AreScalesPresentForNodes(&var_quant_scales_, + {x_var_node, w_var_node})) { + MarkAndLogCannotQuantizeOp(n, "No scale available for the operator"); + return; + } + + QuantizeInput(graph, n, x_var_node, "x"); + auto has_output_scale = + AreScalesPresentForNodes(&var_quant_scales_, {out_var_node}); + bool has_branch = branch_var_node != nullptr; + + // Note: Conv2d fusion requires branch datatype is same as output + // datatype, so we should consider branch/output together. + if (has_branch) { + bool has_branch_scale = + AreScalesPresentForNodes(&var_quant_scales_, {branch_var_node}); + if (has_output_scale && has_branch_scale) { + QuantizeInput(graph, n, branch_var_node, "branch"); + DequantizeOutput(graph, n, out_var_node, "out"); + // Note: out_dtype attr must be set, because if dequantize_output, we + // consider the kernel out_dtype as int8. + n->Op()->SetAttr( + "out_dtype", + static_cast(proto::VarType::Type::VarType_Type_INT8)); + } else { + n->Op()->SetAttr("out_dtype", x_var_node->Var()->GetDataType()); + } + } else { + if (has_output_scale) { + DequantizeOutput(graph, n, out_var_node, "out"); + // Note: out_dtype attr must be set, because if dequantize_output, we + // consider the kernel out_dtype as int8. + n->Op()->SetAttr( + "out_dtype", + static_cast(proto::VarType::Type::VarType_Type_INT8)); + } else { + n->Op()->SetAttr("out_dtype", x_var_node->Var()->GetDataType()); + } + } + } + } +} + +void XPUQuantizeOpPass::QuantizeFC(ir::Graph* graph) const { + for (auto* n : graph->Nodes()) { + if (n->IsOp()) { + auto* op = n->Op(); + if (op->Type() != "fc_xpu") { + continue; + } + Node* w_var_node = nullptr; + Node* x_var_node = nullptr; + Node* out_var_node = nullptr; + + for (auto* input_node : n->inputs) { + if (!input_node->IsVar()) { + continue; + } + if (input_node->Var()->Name() == op->Input("x")[0]) { + x_var_node = input_node; + } else if (input_node->Var()->Name() == op->Input("w")[0]) { + w_var_node = input_node; + } + } + + for (auto* output_node : n->outputs) { + if (!output_node->IsVar()) { + continue; + } + if (output_node->Var()->Name() == op->Output("out")[0]) { + out_var_node = output_node; + } + } + if (!AreScalesPresentForNodes(&var_quant_scales_, + {x_var_node, w_var_node})) { + MarkAndLogCannotQuantizeOp(n, "No scale available for the operator"); + return; + } + + QuantizeInput(graph, n, x_var_node, "x"); + + auto has_output_scale = + AreScalesPresentForNodes(&var_quant_scales_, {out_var_node}); + if (has_output_scale) { + DequantizeOutput(graph, n, out_var_node, "out"); + n->Op()->SetAttr( + "out_dtype", + static_cast(proto::VarType::Type::VarType_Type_INT8)); + } else { + n->Op()->SetAttr("out_dtype", x_var_node->Var()->GetDataType()); + } + } + } +} + +void XPUQuantizeOpPass::ApplyImpl(ir::Graph* graph) const { + VLOG(3) << "Insert quantize/dequantize op to the graph."; + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init(name_scope_, graph); + PADDLE_ENFORCE_NOT_NULL( + param_scope(), + platform::errors::InvalidArgument("Scope cannot be nullptr.")); + + GetQuantInfo(graph); + QuantizeConv(graph); + QuantizeFC(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(xpu_quantize_op_pass, paddle::framework::ir::XPUQuantizeOpPass); diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.h b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.h new file mode 100644 index 00000000000000..28d0f42e76bde2 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.h @@ -0,0 +1,62 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Quantize all supported operators. + */ +class XPUQuantizeOpPass : public FusePassBase { + public: + virtual ~XPUQuantizeOpPass() {} + + protected: + void ApplyImpl(Graph* graph) const override; + void QuantizeConv(Graph* graph) const; + void QuantizeFC(Graph* graph) const; + + private: + void QuantizeInput(Graph* g, + Node* op, + Node* input, + std::string input_arg_name) const; + + void DequantizeOutput(Graph* g, + Node* op, + Node* output, + std::string output_arg_name) const; + + void GetQuantInfo(Graph* graph) const; + + mutable std::unordered_map> var_quant_scales_; + const std::string name_scope_{"xpu_quantize_op_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc new file mode 100644 index 00000000000000..6161293bf7fb76 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc @@ -0,0 +1,281 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file eint8_outcept in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either eint8_outpress or +// implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.h" + +#include +#include + +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/xpu_graph_pattern_detector.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/utils/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +XPUQuantizeSquashPass::XPUQuantizeSquashPass() {} + +void XPUQuantizeSquashPass::FindNodesToKeep( + Graph* graph, + std::unordered_map* nodes_keep_counter) const { + GraphPatternDetector gpd; + patterns::DequantXPUAny deq_any_pattern{gpd.mutable_pattern(), + "dequant_xpu_any"}; + deq_any_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, deq_any_pattern); + + if (nodes_keep_counter->find(dequant_out) == nodes_keep_counter->end()) + (*nodes_keep_counter)[dequant_out] = 1; + else + (*nodes_keep_counter)[dequant_out] += 1; + + found_count++; + }; + gpd(graph, handler); + AddStatis(found_count); +} + +void XPUQuantizeSquashPass::DequantQuantSquash( + Graph* graph, + std::unordered_map* nodes_keep_counter) const { + GraphPatternDetector gpd; + patterns::DequantQuantXPUAny squash_pattern{gpd.mutable_pattern(), + "dequant_quant_xpu_any"}; + squash_pattern(); + + int found_dequant_quant_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern); + GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern); + + auto* next_op_desc = next_op->Op(); + float dequant_scale = + PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("scale")); + float quant_scale = + PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("scale")); + + PADDLE_ENFORCE_NE( + nodes_keep_counter->find(dequant_out), + nodes_keep_counter->end(), + platform::errors::NotFound("The dequant output node is not found.")); + + // check if dequantize op should be kept or removed, decrease the counter + bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1; + + if (dequant_scale == quant_scale) { + // squash dequantize-quantize to nothing + auto quant_out_var_name = quant_out->Name(); + for (auto input_name : next_op_desc->InputNames()) { + auto& input_names = next_op_desc->MutableInputs()->at(input_name); + std::replace(input_names.begin(), + input_names.end(), + quant_out_var_name, + dequant_in->Name()); + next_op_desc->SetInput(input_name, input_names); + } + if (keep_dequant) + GraphSafeRemoveNodes(graph, {quant_op, quant_out}); + else + GraphSafeRemoveNodes(graph, + {dequant_op, quant_op, dequant_out, quant_out}); + + IR_NODE_LINK_TO(dequant_in, next_op); + + found_dequant_quant_count++; + } + }; + gpd(graph, handler); + AddStatis(found_dequant_quant_count); + PrettyLogDetail("--- squashed %d dequantize-quantize pairs", + found_dequant_quant_count); +} + +void XPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::OpDequantXPU op_dequant_pattern{gpd.mutable_pattern(), + "op_dequant_xpu"}; + op_dequant_pattern(); + + int found_op_dequant_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash op-dequant ops pair"; + GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern); + + if (dequant_in->outputs.size() == 1) { + // Find the name of the output linking any_op to dequant_in + std::string output_name = + FindOutputNameByVarName(any_op->Op(), dequant_in->Name()); + + if (output_name.empty()) return; + any_op->Op()->SetAttr("out_dtype", dequant_out->Var()->GetDataType()); + any_op->Op()->SetOutput(output_name, + std::vector({dequant_out->Name()})); + IR_NODE_LINK_TO(any_op, dequant_out); + GraphSafeRemoveNodes(graph, {dequant_in, dequant_op}); + found_op_dequant_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_op_dequant_squash_count); + PrettyLogDetail("--- squashed %d dequant with ops", + found_op_dequant_squash_count); +} + +// conv2d_xpu, fc_xpu +void XPUQuantizeSquashPass::QuantOpSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::QuantXPUAny quant_any_pattern{gpd.mutable_pattern(), + "quant_xpu_any"}; + quant_any_pattern(); + + int found_quant_op_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash op-dequant ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, quant_any_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_any_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, quant_any_pattern); + GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, quant_any_pattern); + + if (quant_out->outputs.size() == 1) { + std::string input_name = + FindInputNameByVarName(next_op->Op(), quant_out->Name()); + + if (input_name.empty()) return; + // Only support quant + conv2d_xpu/fc_xpu fusion + if (!(next_op->Op()->Type() == "conv2d_xpu" || + next_op->Op()->Type() == "fc_xpu")) { + return; + } + next_op->Op()->SetInput(input_name, + std::vector({quant_in->Name()})); + IR_NODE_LINK_TO(quant_in, next_op); + GraphSafeRemoveNodes(graph, {quant_out, quant_op}); + found_quant_op_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_quant_op_squash_count); + PrettyLogDetail("--- squashed %d quantize with ops", + found_quant_op_squash_count); +} + +void XPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::MultipleQuantizeXPU multiple_quantize_pattern{ + gpd.mutable_pattern(), "multiple_quantize_xpu"}; + multiple_quantize_pattern(); + + int found_multiple_quantize_squash_count = 0; + int removed_quantize = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "fuse multiple quantize ops"; + + GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, multiple_quantize_pattern); + + auto* first_quant_op = *(std::find_if( + prev_out->outputs.begin(), prev_out->outputs.end(), [&](Node* node) { + return (node->IsOp() && node->Op()->Type() == "quantize_xpu"); + })); + auto* first_quant_out = first_quant_op->outputs[0]; + float scale = first_quant_op->Op()->GetAttrIfExists("scale"); + + PADDLE_ENFORCE_NE(scale, + 0, + platform::errors::InvalidArgument( + "Quantize scale(%f) should not be equal 0.", scale)); + + for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) { + auto quant_op = prev_out->outputs[iter]; + if (quant_op->IsOp() && quant_op->Op()->Type() == "quantize_xpu" && + quant_op->id() != first_quant_op->id() && + quant_op->Op()->GetAttrIfExists("scale") == scale) { + auto quant_out = quant_op->outputs[0]; + auto last_op = quant_out->outputs[0]; + auto last_op_op = last_op->Op(); + + std::string last_op_input_name = + FindInputNameByVarName(last_op_op, quant_out->Name()); + + PADDLE_ENFORCE_NE( + last_op_input_name.empty(), + true, + platform::errors::NotFound("Operator after quantize operator(%s) " + "should have quantize output as input.", + quant_out->Name())); + + // update the next operator input, + // by replacing quant_out with first_quant_out + auto last_op_names = last_op->Op()->Inputs().at(last_op_input_name); + std::replace(last_op_names.begin(), + last_op_names.end(), + quant_out->Name(), + first_quant_out->Name()); + last_op_op->SetInput(last_op_input_name, + std::vector(last_op_names)); + + IR_NODE_LINK_TO(first_quant_out, last_op); + GraphSafeRemoveNodes(graph, {quant_op, quant_out}); + removed_quantize++; + } + } + found_multiple_quantize_squash_count++; + }; + gpd(graph, handler); + AddStatis(found_multiple_quantize_squash_count); + PrettyLogDetail("--- squashed %d quantize op", removed_quantize); +} + +void XPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, + platform::errors::InvalidArgument( + "The graph in function XPUQuantizeSquashPass::ApplyImpl is null.")); + FusePassBase::Init("xpu_quantize_squash_pass", graph); + + std::unordered_map nodes_keep_counter; + FindNodesToKeep(graph, &nodes_keep_counter); + DequantQuantSquash(graph, &nodes_keep_counter); + OpDequantSquash(graph); + // QuantOpSquash(graph); // If the quant op is fused into conv2d_xpu, the + // performance will become worse. + MultipleQuantizeSquash(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(xpu_quantize_squash_pass, + paddle::framework::ir::XPUQuantizeSquashPass); diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.h new file mode 100644 index 00000000000000..d3f37dd42010d0 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Squash dequantize->quantize pair pattern into requantize op + */ + +class XPUQuantizeSquashPass : public FusePassBase { + public: + XPUQuantizeSquashPass(); + virtual ~XPUQuantizeSquashPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + /* + * For each dequantize's output find the number of operators it is an input to + */ + void FindNodesToKeep( + Graph* graph, + std::unordered_map* nodes_keep_counter) const; + + /* + * Squash dequantize-quantize ops pairs into nothing + */ + void DequantQuantSquash( + Graph* graph, + std::unordered_map* nodes_keep_counter) const; + + /* + * Squash dequant if the previous operator support fp32 out + */ + void OpDequantSquash(Graph* graph) const; + + /* + * Squash quantize if several quatize ops have the same scale + */ + void MultipleQuantizeSquash(Graph* graph) const; + + /* + * Squash quantize if is before conv2d_xpu/fc_xpuy + */ + void QuantOpSquash(Graph* graph) const; + + const std::string name_scope_{"squash"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 9f8e9ed80ca46d..02d7ee2962fc3b 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/naive_executor.h" +#include #include #include #include @@ -51,6 +52,26 @@ void NaiveExecutor::Prepare(Scope *scope, CreateOps(program_desc, block_id, with_feed_fetch_ops); } +void NaiveExecutor::PrepareInterpreterCore( + Scope *scope, + const ProgramDesc &program_desc, + const framework::interpreter::ExecutionConfig &execution_config) { + interpreter_core_ = std::make_unique( + place_, program_desc.Block(0), scope, execution_config); +} + +void NaiveExecutor::RunInterpreterCore( + const std::vector &feed_names, bool need_fetch) { + platform::ScopedFlushDenormal flush; +#ifdef PADDLE_WITH_NVTX + platform::CudaNvtxRangePush("model", platform::NvtxRangeColor::Yellow); +#endif + interpreter_core_->Run(feed_names, need_fetch); +#ifdef PADDLE_WITH_NVTX + platform::CudaNvtxRangePop(); +#endif +} + void NaiveExecutor::Run() { #ifdef PADDLE_WITH_DNNL platform::AttachPointerHashToMKLDNNKey(this, place_); @@ -190,6 +211,9 @@ phi::DenseTensor *NaiveExecutor::FindTensor(const std::string &name) { void NaiveExecutor::RegisterOutputHook(const HookFunc &hookfunc) { output_hookfuncs_.push_back(hookfunc); + if (interpreter_core_) { + interpreter_core_->SetOutputHooks(output_hookfuncs_); + } } void NaiveExecutor::RegisterInputHook(const HookFunc &hookfunc) { diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index 85f98046285b34..70e8c92ca4fbe7 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -26,6 +26,9 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" +#include "paddle/fluid/framework/new_executor/interpretercore.h" + namespace paddle { namespace framework { @@ -52,6 +55,12 @@ class NaiveExecutor { int block_id, bool with_feed_fetch_ops); + void PrepareInterpreterCore( + Scope* scope, + const ProgramDesc& program_desc, + const framework::interpreter::ExecutionConfig& execution_config = + framework::interpreter::ExecutionConfig{}); + // Create variables before head. // Create parameters if persistable is true, or create the temporary variables // instead. @@ -63,6 +72,9 @@ class NaiveExecutor { // Run all the operators. void Run(); + void RunInterpreterCore(const std::vector& feed_names = {}, + bool need_fetch = false); + // Get an tensor to operating directly, without the need for feed_ops. phi::DenseTensor* FindTensor(const std::string& name); @@ -96,6 +108,8 @@ class NaiveExecutor { std::unordered_map> reuse_cache_; std::vector cluster_buffer_; + + std::unique_ptr interpreter_core_; }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index 2716846b0e4de2..995de57c747f81 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -6,7 +6,7 @@ add_subdirectory(pir_adaptor) set(STANDALONE_EXECUTOR_SRCS feed_fetch_utils.cc interpretercore.cc new_executor_defs.cc - standalone_executor.cc program_interpreter.cc new_ir_interpreter.cc) + standalone_executor.cc program_interpreter.cc pir_interpreter.cc) set(STANDALONE_EXECUTOR_DEPS interpreter diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc index 0c6442cd1f9d3f..e549b243f87ec4 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc @@ -17,7 +17,7 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/cinn/hlir/framework/instruction.h" -#include "paddle/cinn/hlir/framework/new_ir_compiler.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" #include "paddle/cinn/runtime/cuda/cuda_util.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/framework/paddle2cinn/transform_type.h" @@ -26,7 +26,7 @@ namespace paddle { namespace framework { class CinnJitInstruction::FnPtrImpl { - using CUDAJITInfo = cinn::hlir::framework::newir::CUDAJITInfo; + using CUDAJITInfo = cinn::hlir::framework::pir::CUDAJITInfo; public: explicit FnPtrImpl(const CUDAJITInfo& cuda_jit_info) diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc index 2422597ece0d1a..c66e10a822056a 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc @@ -16,8 +16,8 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" @@ -93,13 +93,12 @@ CondInstruction::CondInstruction(size_t id, VLOG(6) << "finish process inputs outputs index"; Scope* true_scope = &(value_exec_info->GetScope()->NewScope()); - true_branch_inter_ = - new NewIRInterpreter(place, - {}, - true_branch_block, - true_scope, - value_exec_info->NewChild(true_scope), - {}); + true_branch_inter_ = new PirInterpreter(place, + {}, + true_branch_block, + true_scope, + value_exec_info->NewChild(true_scope), + {}); std::set true_skip_gc_names_set; for (auto value : GetYiedOpInputs(true_branch_block)) { @@ -118,12 +117,12 @@ CondInstruction::CondInstruction(size_t id, Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); false_branch_inter_ = - new NewIRInterpreter(place, - {}, - false_branch_block, - false_scope, - value_exec_info->NewChild(false_scope), - {}); + new PirInterpreter(place, + {}, + false_branch_block, + false_scope, + value_exec_info->NewChild(false_scope), + {}); std::set false_skip_gc_names_set; for (auto value : GetYiedOpInputs(false_branch_block)) { @@ -149,7 +148,7 @@ CondInstruction::~CondInstruction() { } void CondInstruction::CopyBranchOutput( - const std::vector& var_names, const NewIRInterpreter* inter) { + const std::vector& var_names, const PirInterpreter* inter) { for (size_t i = 0; i < var_names.size(); ++i) { auto* inner_var = inter->InnerScope()->GetVar(var_names[i]); diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h index 469c0ed0ae1ab8..79af374ecdd326 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h @@ -24,7 +24,7 @@ namespace paddle { namespace framework { class Scope; class Value; -class NewIRInterpreter; +class PirInterpreter; class ValueExecutionInfo; class CondInstruction : public InstructionBase { @@ -44,7 +44,7 @@ class CondInstruction : public InstructionBase { private: void CopyBranchOutput(const std::vector& var_names, - const NewIRInterpreter* inter); + const PirInterpreter* inter); ::pir::Operation* op_; @@ -54,9 +54,9 @@ class CondInstruction : public InstructionBase { std::vector output_vars_; - NewIRInterpreter* true_branch_inter_; + PirInterpreter* true_branch_inter_; - NewIRInterpreter* false_branch_inter_; + PirInterpreter* false_branch_inter_; std::vector true_branch_outputs_; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index 4066bc7afb3dc6..79f93ef3425a69 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -31,6 +31,7 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/pir/core/block_argument.h" #include "paddle/pir/dialect/control_flow/ir/cf_ops.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" @@ -231,6 +232,9 @@ std::vector GetOutsideOpInputs( inner_outputs.insert(op->result(i)); } } + for (size_t arg_id = 0; arg_id < block->args_size(); ++arg_id) { + inner_outputs.insert(block->argument(arg_id)); + } std::vector outside_op_inputs; for (auto op : (*block)) { diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc index b511ad1f602320..48226391c509d5 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc @@ -16,8 +16,8 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" @@ -76,28 +76,11 @@ WhileInstruction::WhileInstruction(size_t id, } body_block_ = while_op.body_block(); - auto body_block_outputs = GetYiedOpInputs(body_block_); - - Scope* body_scope = &(parent_exe_info->GetScope()->NewScope()); - auto body_exe_info = parent_exe_info->NewChild(body_scope); - for (size_t i = 0; i < body_block_->args_size(); ++i) { - auto var_name = "body_block_arg_" + std::to_string(i); - body_scope->Var(var_name); - body_exe_info->Add(body_block_->argument(i), var_name); - } - body_inter_ = std::unique_ptr(new NewIRInterpreter( - place, {}, body_block_, body_scope, body_exe_info, {})); - - std::set body_skip_gc_names_set; - for (auto value : body_block_outputs) { - body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value)); - body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value)); - } - body_inter_->SetSkipGcVars(body_skip_gc_names_set); std::unordered_map> inputs; GetInputIds(op, *parent_exe_info, &inputs); - + auto body_outside_inputs = + GetOutsideOpInputs(body_block_, *parent_exe_info, &inputs); SetInputs(inputs); std::unordered_map> outputs; @@ -116,6 +99,29 @@ WhileInstruction::WhileInstruction(size_t id, } } SetOutputs(outputs); + + Scope* body_scope = &(parent_exe_info->GetScope()->NewScope()); + auto body_exe_info = parent_exe_info->NewChild(body_scope); + for (size_t i = 0; i < body_block_->args_size(); ++i) { + auto var_name = "body_block_arg_" + std::to_string(i); + body_scope->Var(var_name); + body_exe_info->Add(body_block_->argument(i), var_name); + } + body_inter_ = std::unique_ptr(new PirInterpreter( + place, {}, body_block_, body_scope, body_exe_info, {})); + + std::set body_skip_gc_names_set; + auto body_block_outputs = GetYiedOpInputs(body_block_); + for (auto value : body_block_outputs) { + body_outputs_.push_back(body_inter_->GetNameByValue(value)); + body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value)); + body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value)); + } + for (auto value : body_outside_inputs) { + body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value)); + body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value)); + } + body_inter_->SetSkipGcVars(body_skip_gc_names_set); } void WhileInstruction::CopyInputsToOutputs() { @@ -138,10 +144,10 @@ void WhileInstruction::PassArgsToBodyBlock() { void WhileInstruction::GetValueFromBodyBlock() { cond_var_->GetMutable()->ShareDataWith( body_inter_->local_scope() - ->GetVar(body_skip_gc_names_[0]) + ->GetVar(body_outputs_[0]) ->Get()); for (size_t i = 0; i < outputs_.size(); ++i) { - auto& out_var_name = body_skip_gc_names_[i + 1]; + auto& out_var_name = body_outputs_[i + 1]; auto* out_var = body_inter_->local_scope()->GetVar(out_var_name); outputs_[i]->GetMutable()->ShareDataWith( out_var->Get()); diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.h b/paddle/fluid/framework/new_executor/instruction/while_instruction.h index d486c8206c5026..1c9cfabbde2867 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.h @@ -24,7 +24,7 @@ namespace paddle { namespace framework { class Scope; class Value; -class NewIRInterpreter; +class PirInterpreter; class ValueExecutionInfo; /// The execute semantics of while op ['output' = while_op('cond', 'intput')] @@ -65,7 +65,8 @@ class WhileInstruction : public InstructionBase { std::vector inputs_; std::vector outputs_; - std::unique_ptr body_inter_; + std::unique_ptr body_inter_; + std::vector body_outputs_; std::vector body_skip_gc_names_; ::pir::Block* body_block_; diff --git a/paddle/fluid/framework/new_executor/interpreter/execution_config.h b/paddle/fluid/framework/new_executor/interpreter/execution_config.h index 828678fa59da1f..def76235331f15 100644 --- a/paddle/fluid/framework/new_executor/interpreter/execution_config.h +++ b/paddle/fluid/framework/new_executor/interpreter/execution_config.h @@ -29,6 +29,7 @@ struct ExecutionConfig { bool used_for_cinn{false}; bool used_for_control_flow_op{false}; bool used_for_jit{false}; + bool used_for_inference{false}; size_t device_num_threads{0}; size_t host_num_threads{0}; diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 8c51e310b054c5..34ec3d7ac3b4ed 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -598,6 +598,11 @@ void BuildOpFuncList(const platform::Place& place, for (size_t i = 0; i < ops.size(); ++i) { auto op = ops[i].get(); const std::string& op_type = op->Type(); + if (execution_config.used_for_inference) { + if (op_type == "feed" || op_type == "fetch") { + continue; + } + } VLOG(6) << "Build OpFuncNode from : " << op_type; diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc index bebeb142d473f1..2cf615d99b1ba3 100644 --- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc +++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc @@ -50,6 +50,7 @@ std::set OpsCanSkipedFakeAllocInStaticBuild = { "create_py_reader", "depend", "fetch_v2", + "print", "send_v2", "nop"}; @@ -170,10 +171,10 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) { std::stringstream ss; ss << "The following OPs are unable to static build:\n"; for (auto& item : invalid_ops) { - ss << item.first << " [in_black_list = " << (item.second >> 6 & 1) - << ", is_operator_base = " << (item.second >> 5 & 1) - << ", is_custom_op = " << (item.second >> 4 & 1) - << ", use_mkldnn = " << (item.second >> 3 & 1) + ss << item.first << " [in_black_list = " << (item.second >> 5 & 1) + << ", is_operator_base = " << (item.second >> 4 & 1) + << ", is_custom_op = " << (item.second >> 3 & 1) + << ", use_mkldnn = " << (item.second >> 2 & 1) << ", sub_block_can_not_static_build = " << (item.second >> 1 & 1) << "]\n"; } diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 8e052d3b2685e0..1a1ee56e17525f 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/framework/new_executor/interpretercore.h" -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/new_executor/program_interpreter.h" #include "paddle/pir/core/program.h" #include "paddle/pir/core/value.h" @@ -54,7 +54,7 @@ InterpreterCore::InterpreterCore( framework::Scope* scope, const ExecutionConfig& execution_config) { VLOG(4) << "InterpreterCore(): " << this << " on " << place; - impl_ = std::make_unique( + impl_ = std::make_unique( place, fetch_var_names, ir_block, scope, execution_config); } diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index 3ae75ffd870088..900ac962bc022a 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -424,10 +424,12 @@ void HandleForSpecialOp(pir::Operation* op, // change opreand name to param_name auto orig_name = value_exe_info->GetValue2VarName().at(value); - if (value_exe_info->GetScope()->FindVar(var_name) == nullptr) { - const_cast(value_exe_info->GetScope()) - ->Rename(orig_name, var_name); + if (value_exe_info->GetScope()->FindVar(var_name) != nullptr) { + const_cast(value_exe_info->GetScope())->EraseVars({var_name}); + VLOG(1) << "var " << var_name << " has been removed from scope"; } + const_cast(value_exe_info->GetScope())->Rename(orig_name, var_name); + VLOG(8) << "var " << orig_name << " has been renamed to " << var_name; value_exe_info->Rename(value, var_name, orig_name); } diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index ce0484567b64f0..1968cf758910ba 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -379,6 +379,7 @@ void BuildPhiContext(pir::Operation* op, } // EmplaceBackOutputs + VLOG(8) << "ctx->EmplaceBackOutput: "; for (size_t i = 0; i < op->num_results(); ++i) { pir::Value out_ptr = op->result(i); if (!IsInvalid(out_ptr)) { @@ -399,11 +400,15 @@ void BuildPhiContext(pir::Operation* op, ctx->EmplaceBackOutput(OutType(const_cast( &(inner_scope->FindVar(value_exec_info.GetVarName(out_ptr)) ->Get())))); + VLOG(8) << "ctx->EmplaceBackOutput DenseTensor: " + << value_exec_info.GetVarName(out_ptr); } else if (out_ptr.type() .isa()) { ctx->EmplaceBackOutput(OutType(const_cast( &(inner_scope->FindVar(value_exec_info.GetVarName(out_ptr)) ->Get())))); + VLOG(8) << "ctx->EmplaceBackOutput SelectedRows: " + << value_exec_info.GetVarName(out_ptr); } else if (out_ptr.type().isa()) { OutListType outputs; auto& variable_array = @@ -423,6 +428,8 @@ void BuildPhiContext(pir::Operation* op, variable_array[i]->Type())); } } + VLOG(8) << "ctx->EmplaceBackOutput VariableRefArray: " + << value_exec_info.GetVarName(out_ptr); ctx->EmplaceBackOutputs(outputs); } else { PADDLE_THROW( diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc similarity index 90% rename from paddle/fluid/framework/new_executor/new_ir_interpreter.cc rename to paddle/fluid/framework/new_executor/pir_interpreter.cc index e527b9d254b8ce..ee05bf95789980 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" +#include #include #include "paddle/utils/flags.h" @@ -54,6 +55,13 @@ #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/core/builtin_attribute.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); +#endif PHI_DECLARE_bool(enable_new_ir_in_executor); PHI_DECLARE_bool(enable_new_ir_in_executor_trace_run); @@ -61,12 +69,11 @@ PHI_DECLARE_bool(enable_new_ir_in_executor_trace_run); namespace paddle { namespace framework { -NewIRInterpreter::NewIRInterpreter( - const platform::Place& place, - const std::vector& fetch_var_names, - const ::pir::Block* ir_block, - framework::Scope* scope, - const ExecutionConfig& execution_config) +PirInterpreter::PirInterpreter(const platform::Place& place, + const std::vector& fetch_var_names, + const ::pir::Block* ir_block, + framework::Scope* scope, + const ExecutionConfig& execution_config) : place_(place), execution_config_(execution_config), var_scope_(scope), @@ -74,7 +81,7 @@ NewIRInterpreter::NewIRInterpreter( ir_block_(ir_block), ir_stream_analyzer_(place), fetch_var_names_(fetch_var_names) { - VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_; + VLOG(4) << "PirInterpreter(): " << this << " on " << place_; static_build_ = FLAGS_new_executor_static_build && !FLAGS_new_executor_use_cuda_graph && @@ -118,11 +125,12 @@ NewIRInterpreter::NewIRInterpreter( value_exe_info_ = std::make_shared(InnerScope()); std::stringstream ss; - ss << this; + ss << this + << std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); BuildScope(*ir_block_, ss.str(), value_exe_info_.get()); } -NewIRInterpreter::NewIRInterpreter( +PirInterpreter::PirInterpreter( const platform::Place& place, const std::vector& fetch_var_names, const ::pir::Block* ir_block, @@ -136,7 +144,7 @@ NewIRInterpreter::NewIRInterpreter( ir_block_(ir_block), ir_stream_analyzer_(place), fetch_var_names_(fetch_var_names) { - VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_; + VLOG(4) << "PirInterpreter(): " << this << " on " << place_; static_build_ = FLAGS_new_executor_static_build && !FLAGS_new_executor_use_cuda_graph && @@ -184,11 +192,11 @@ NewIRInterpreter::NewIRInterpreter( BuildScope(*ir_block_, ss.str(), value_exe_info_.get()); } -NewIRInterpreter::~NewIRInterpreter() { +PirInterpreter::~PirInterpreter() { // cancle gc's thread gc_.reset(nullptr); async_work_queue_.reset(); - VLOG(4) << "~NewIRInterpreter(): " << this << " on " << place_; + VLOG(4) << "~PirInterpreter(): " << this << " on " << place_; #ifdef PADDLE_WITH_DNNL // Clear mkl-dnn cache, @@ -197,13 +205,12 @@ NewIRInterpreter::~NewIRInterpreter() { #endif } -void NewIRInterpreter::SetCopyProgram(std::shared_ptr prog) { +void PirInterpreter::SetCopyProgram(std::shared_ptr prog) { PADDLE_THROW(platform::errors::Unimplemented( - "SetCopyProgram is not implemented in NewIRInterpreter.")); + "SetCopyProgram is not implemented in PirInterpreter.")); } -void NewIRInterpreter::SetSkipGcVars( - const std::set& skip_gc_vars) { +void PirInterpreter::SetSkipGcVars(const std::set& skip_gc_vars) { PADDLE_ENFORCE_EQ( execution_config_.skip_gc_vars.empty(), true, @@ -214,7 +221,7 @@ void NewIRInterpreter::SetSkipGcVars( execution_config_.skip_gc_vars = skip_gc_vars; } -void NewIRInterpreter::SetJitInputVars( +void PirInterpreter::SetJitInputVars( const std::set& jit_input_vars) { PADDLE_ENFORCE_EQ( execution_config_.jit_input_vars.empty(), @@ -226,15 +233,15 @@ void NewIRInterpreter::SetJitInputVars( execution_config_.jit_input_vars = jit_input_vars; } -const std::set& NewIRInterpreter::JitInputVars() const { +const std::set& PirInterpreter::JitInputVars() const { return execution_config_.jit_input_vars; } -const VariableScope* NewIRInterpreter::GetVariableScope() const { +const VariableScope* PirInterpreter::GetVariableScope() const { return &var_scope_; } -void NewIRInterpreter::reset_scope(Scope* new_scope) { +void PirInterpreter::reset_scope(Scope* new_scope) { var_scope_.SetScope(new_scope); scope_ = new_scope; for (size_t i = 0; i < value_exe_info_->GetVarList().size(); i++) { @@ -244,7 +251,7 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) { } // The index should be assured valid, cause the InterpreterCore may not be // fully built, but was still cached and used. For example, see unit test - // `test_assert.py`, it may exit before `NewIRInterpreter::Convert`, + // `test_assert.py`, it may exit before `PirInterpreter::Convert`, // but still was cached and used by later tests. for (size_t i = 0; i < std::min(refs_.size(), value_exe_info_->GetVarList().size()); @@ -253,16 +260,16 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) { } } -const Scope* NewIRInterpreter::local_scope() const { return local_scope_; } +const Scope* PirInterpreter::local_scope() const { return local_scope_; } -void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) { - async_work_queue_ = reinterpret_cast(src)->GetWorkQueue(); +void PirInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) { + async_work_queue_ = reinterpret_cast(src)->GetWorkQueue(); VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src << ") to InterpreterCore(" << this << ")"; } -void NewIRInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) { - const NewIRInterpreter& impl = dynamic_cast(src); +void PirInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) { + const PirInterpreter& impl = dynamic_cast(src); if (is_shared_results_build_ || !impl.IsSharedResultsBuild()) { return; } @@ -277,25 +284,25 @@ void NewIRInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) { } const interpreter::NewIrDependencyBuilder& -NewIRInterpreter::GetNewIrDependencyBuilder() const { +PirInterpreter::GetNewIrDependencyBuilder() const { return ir_dependency_builder_; } -std::shared_ptr> NewIRInterpreter::GetDependencyCount() +std::shared_ptr> PirInterpreter::GetDependencyCount() const { return dependecy_count_; } -const interpreter::NewIrStreamAnalyzer& -NewIRInterpreter::GetNewIrStreamAnalyzer() const { +const interpreter::NewIrStreamAnalyzer& PirInterpreter::GetNewIrStreamAnalyzer() + const { return ir_stream_analyzer_; } -bool NewIRInterpreter::IsSharedResultsBuild() const { +bool PirInterpreter::IsSharedResultsBuild() const { return is_shared_results_build_; } -std::shared_ptr NewIRInterpreter::GetWorkQueue() { +std::shared_ptr PirInterpreter::GetWorkQueue() { if (async_work_queue_ == nullptr) { async_work_queue_ = std::make_shared( execution_config_.host_num_threads, @@ -305,7 +312,7 @@ std::shared_ptr NewIRInterpreter::GetWorkQueue() { return async_work_queue_; } -void NewIRInterpreter::PrepareForCUDAGraphCapture() { +void PirInterpreter::PrepareForCUDAGraphCapture() { if (!FLAGS_new_executor_use_cuda_graph) return; #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_EQ( @@ -330,7 +337,7 @@ void NewIRInterpreter::PrepareForCUDAGraphCapture() { #endif } -void NewIRInterpreter::CheckCUDAGraphBeforeRun( +void PirInterpreter::CheckCUDAGraphBeforeRun( const std::vector& feed_names) { #ifdef PADDLE_WITH_CUDA if (platform::IsCUDAGraphCapturing()) { @@ -354,7 +361,7 @@ void NewIRInterpreter::CheckCUDAGraphBeforeRun( #endif } -void NewIRInterpreter::ClearLoDTensorArrayInLocalScope() { +void PirInterpreter::ClearLoDTensorArrayInLocalScope() { auto vars = local_scope_->LocalVars(); for (auto var : vars) { if (var->IsType()) { @@ -364,7 +371,7 @@ void NewIRInterpreter::ClearLoDTensorArrayInLocalScope() { } } -std::string NewIRInterpreter::GetDepsString() const { +std::string PirInterpreter::GetDepsString() const { std::stringstream ss; auto downstream_map = ir_dependency_builder_.OpDownstreamMap(); ss << "Note: when static_dep is 1, it is ok that the dynamic_dep will not " @@ -383,17 +390,17 @@ std::string NewIRInterpreter::GetDepsString() const { return ss.str(); } -bool NewIRInterpreter::HasLocalScope() const { return local_scope_ != nullptr; } +bool PirInterpreter::HasLocalScope() const { return local_scope_ != nullptr; } -Scope* NewIRInterpreter::InnerScope() const { +Scope* PirInterpreter::InnerScope() const { return local_scope_ != nullptr ? local_scope_ : scope_; } -std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const { +std::string PirInterpreter::GetNameByValue(::pir::Value value) const { return value_exe_info_->GetVarName(value); } -void NewIRInterpreter::UpdateSyncOpNum() { +void PirInterpreter::UpdateSyncOpNum() { int64_t sync_op_num = 0; for (auto& ins : vec_instruction_base_) { if (ins->KernelType() == OpFuncType::kCpuSync || @@ -405,7 +412,7 @@ void NewIRInterpreter::UpdateSyncOpNum() { VLOG(4) << "Update sync op num, sync op num is: " << sync_op_num_; } -void NewIRInterpreter::UpdateNcclOpNum() { +void PirInterpreter::UpdateNcclOpNum() { static std::set nccl_op_set = { "pd_op.c_softmax_with_cross_entropy", "pd_op.c_allgather", @@ -496,7 +503,7 @@ void NewIRInterpreter::UpdateNcclOpNum() { // ->(sync_run)-> OP(B) OP(O) ->(direct_run)-> OP(C) ->(direct_run)-> OP(D) If B // is run before C, B may always block to wait for A to finish executing, but in // fact, C can be executed first during this time. -void NewIRInterpreter::AnalyseExecuteOrderForTrace( +void PirInterpreter::AnalyseExecuteOrderForTrace( std::map> op_downstream_map, InstructionSchedulingPriorityLess compare) { VLOG(4) << "Analyze the execution order of Trace scheduling mode."; @@ -556,7 +563,7 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace( /// For new ir /// /// ======================== /// -void NewIRInterpreter::BuildInstruction() { +void PirInterpreter::BuildInstruction() { VLOG(6) << "Build Instructions for new ir ... "; vec_instruction_base_.clear(); size_t op_idx = 0; @@ -613,7 +620,7 @@ void NewIRInterpreter::BuildInstruction() { } } -std::string NewIRInterpreter::DebugValueInfo() { +std::string PirInterpreter::DebugValueInfo() { std::stringstream os; os << "value info of interpretercore " << this << "\n" << "value -> var_name -> id -> variable*" @@ -641,7 +648,7 @@ std::string NewIRInterpreter::DebugValueInfo() { return os.str(); } -void NewIRInterpreter::BuildInstructionDependences() { +void PirInterpreter::BuildInstructionDependences() { // analysis the dependences between instructions, add next_instr_list to each // instr, and set the dependecy_count_ size_t instr_num = vec_instruction_base_.size(); @@ -697,7 +704,7 @@ void NewIRInterpreter::BuildInstructionDependences() { } } -void NewIRInterpreter::RecordMemcpyD2H(InstructionBase* instr_node) { +void PirInterpreter::RecordMemcpyD2H(InstructionBase* instr_node) { // NOTE(zhiqiu): hot fix for jit input var if (instr_node->Name() == "pd_op.memcpy_d2h") { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); @@ -712,7 +719,7 @@ void NewIRInterpreter::RecordMemcpyD2H(InstructionBase* instr_node) { } } -void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { +void PirInterpreter::RecordStreamForGC(InstructionBase* instr) { #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) PADDLE_THROW(platform::errors::Unimplemented( "RecordStreamForGC is only implemented when compiled with GPU.")); @@ -730,6 +737,29 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { gpuStream_t stream = reinterpret_cast(instr->DeviceContext()).stream(); +// TODO(lizhiyu): Only analyse the 'send_v2' for GPT pp strategy right now. +// To support all the operators for communicating in the future. +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + if (instr->Name() == "pd_op.send_v2") { + ::pir::Operation* op = instr->Operation(); + if (op->HasAttribute("use_calc_stream") && + op->attribute<::pir::BoolAttribute>("use_calc_stream").data() == + false) { + int ring_id = op->attribute<::pir::Int32Attribute>("ring_id").data(); + if (FLAGS_dynamic_static_unified_comm) { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + stream = static_cast( + comm_context_manager.Get(std::to_string(ring_id))) + ->GetStream(); + } else { + stream = platform::NCCLCommContext::Instance() + .Get(ring_id, instr->DeviceContext().GetPlace()) + ->stream(); + } + } + } +#endif auto TensorRecordStream = [&stream](phi::DenseTensor& tensor) { auto allocation = tensor.Holder(); if (allocation == nullptr) { @@ -827,7 +857,7 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { #endif } -void NewIRInterpreter::CheckGC(InstructionBase* instr) { +void PirInterpreter::CheckGC(InstructionBase* instr) { platform::RecordEvent record( "CheckGC", platform::TracerEventType::UserDefined, 10); @@ -855,8 +885,8 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) { } } -void NewIRInterpreter::CalculateLastLiveOps() { - VLOG(4) << "NewIRInterpreter(): " << this << " start CalculateLastLiveOps"; +void PirInterpreter::CalculateLastLiveOps() { + VLOG(4) << "PirInterpreter(): " << this << " start CalculateLastLiveOps"; // calculate last_live_ops_ for (size_t op_idx = 0; op_idx < vec_instruction_base_.size(); ++op_idx) { InstructionBase* instr = vec_instruction_base_[op_idx].get(); @@ -967,7 +997,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { VLOG(4) << "done CalculateLastLiveOps"; } -void NewIRInterpreter::ConstructEventForJitInput() { +void PirInterpreter::ConstructEventForJitInput() { for (size_t i = 0; i < dependecy_count_->size(); ++i) { if ((*dependecy_count_)[i] == 0) { InstructionBase* inst = vec_instruction_base_[i].get(); @@ -991,7 +1021,7 @@ void NewIRInterpreter::ConstructEventForJitInput() { } } -paddle::framework::FetchList NewIRInterpreter::Run( +paddle::framework::FetchList PirInterpreter::Run( const std::vector& feed_names, const std::vector& feed_tensors) { auto FeedInput = [&] { @@ -1098,8 +1128,8 @@ paddle::framework::FetchList NewIRInterpreter::Run( } } -FetchList NewIRInterpreter::Run(const std::vector& feed_names, - bool need_fetch) { +FetchList PirInterpreter::Run(const std::vector& feed_names, + bool need_fetch) { SetDeviceId(place_); CheckCUDAGraphBeforeRun(feed_names); @@ -1188,7 +1218,7 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, } } -void NewIRInterpreter::TraceRunImpl() { +void PirInterpreter::TraceRunImpl() { // lazy initialization of gc, do not create gc is the program only run once if (!gc_) { gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_base_); @@ -1201,7 +1231,7 @@ void NewIRInterpreter::TraceRunImpl() { VLOG(4) << "Done TraceRunInstructionList"; } -void NewIRInterpreter::MultiThreadRunImpl() { +void PirInterpreter::MultiThreadRunImpl() { // lazy initialization of gc, do not create gc is the program only run once if (!gc_) { gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_base_); @@ -1215,7 +1245,7 @@ void NewIRInterpreter::MultiThreadRunImpl() { VLOG(4) << "Done MultiThreadRunInstructionList"; } -void NewIRInterpreter::TraceRunInstructionList( +void PirInterpreter::TraceRunInstructionList( const std::vector>& vec_instr) { unfinished_op_number_ = vec_instr.size(); if (unfinished_op_number_ == 0) { @@ -1259,7 +1289,7 @@ void NewIRInterpreter::TraceRunInstructionList( VLOG(4) << "Done TraceRunInstructionList"; } -void NewIRInterpreter::MultiThreadRunInstructionList( +void PirInterpreter::MultiThreadRunInstructionList( const std::vector>& vec_instr) { unfinished_op_number_ = vec_instr.size(); if (unfinished_op_number_ == 0) { @@ -1340,7 +1370,7 @@ void NewIRInterpreter::MultiThreadRunInstructionList( } } -void NewIRInterpreter::RunInstructionBaseAsync(size_t instr_id) { +void PirInterpreter::RunInstructionBaseAsync(size_t instr_id) { // NOTE(Ruibiao): Due to the uncertain order in multi-threading asynchronous // scheduling, the priority order involved cross-thread scheduling is not // guaranteed. Only Ops scheduled by the same AddTask call have the guarantee @@ -1374,8 +1404,8 @@ void NewIRInterpreter::RunInstructionBaseAsync(size_t instr_id) { } } -void NewIRInterpreter::RunNextInstructions(InstructionBase* instr, - SchedulingQueue* reserved_next_ops) { +void PirInterpreter::RunNextInstructions(InstructionBase* instr, + SchedulingQueue* reserved_next_ops) { platform::RecordEvent record( "RunNextInstructions", platform::TracerEventType::UserDefined, 10); @@ -1400,7 +1430,7 @@ void NewIRInterpreter::RunNextInstructions(InstructionBase* instr, } } -void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { +void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { platform::RecordEvent instruction_event( instr_node->Name(), platform::TracerEventType::Operator, 1); @@ -1467,7 +1497,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { } } -void NewIRInterpreter::PreAnalysis() { +void PirInterpreter::PreAnalysis() { BuildInstructionDependences(); VLOG(4) << "Done BuildInstructionDependences"; @@ -1493,14 +1523,14 @@ void NewIRInterpreter::PreAnalysis() { VLOG(4) << "Done UpdateNcclOpNum"; } -void NewIRInterpreter::Build( +void PirInterpreter::Build( const std::vector& feed_names, std::vector* op_func_nodes) { PADDLE_THROW(platform::errors::Unimplemented( - "Build is not implemented in NewIRInterpreter.")); + "Build is not implemented in PirInterpreter.")); } -::pir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { +::pir::Value PirInterpreter::GetValueByName(const std::string& var_name) { for (auto kv : value_exe_info_->GetValue2VarName()) { if (kv.second == var_name) { return kv.first; @@ -1509,7 +1539,7 @@ ::pir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { return nullptr; } -void NewIRInterpreter::SolvePersisableVarNames() { +void PirInterpreter::SolvePersisableVarNames() { VLOG(6) << "SolvePersisableVarNames"; for (auto kv : value_exe_info_->GetValue2VarName()) { ::pir::Value value = kv.first; diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/pir_interpreter.h similarity index 90% rename from paddle/fluid/framework/new_executor/new_ir_interpreter.h rename to paddle/fluid/framework/new_executor/pir_interpreter.h index 3a128791cdfce6..80052308e87432 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/pir_interpreter.h @@ -25,7 +25,7 @@ class Block; namespace paddle { namespace framework { class ValueExecutionInfo; -class NewIRInterpreter : public InterpreterBaseImpl { +class PirInterpreter : public InterpreterBaseImpl { using ExecutionConfig = interpreter::ExecutionConfig; using InstructionSchedulingPriorityLess = std::function; using SchedulingQueue = @@ -34,20 +34,20 @@ class NewIRInterpreter : public InterpreterBaseImpl { InstructionSchedulingPriorityLess>; public: - NewIRInterpreter(const platform::Place& place, - const std::vector& fetch_var_names, - const ::pir::Block* ir_block, - Scope* scope, - const ExecutionConfig& execution_config = ExecutionConfig()); - - NewIRInterpreter(const platform::Place& place, - const std::vector& fetch_var_names, - const ::pir::Block* ir_block, - Scope* scope, - std::shared_ptr value_exe_info, - const ExecutionConfig& execution_config = ExecutionConfig()); - - ~NewIRInterpreter(); + PirInterpreter(const platform::Place& place, + const std::vector& fetch_var_names, + const ::pir::Block* ir_block, + Scope* scope, + const ExecutionConfig& execution_config = ExecutionConfig()); + + PirInterpreter(const platform::Place& place, + const std::vector& fetch_var_names, + const ::pir::Block* ir_block, + Scope* scope, + std::shared_ptr value_exe_info, + const ExecutionConfig& execution_config = ExecutionConfig()); + + ~PirInterpreter(); paddle::framework::FetchList Run( const std::vector& feed_names, diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 294bc28c3ff2dd..2df562e7bef18c 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -110,8 +110,9 @@ void ProgramInterpreter::RunImpl() { interpreter::ResetAtomicGuard guard(&deps_, &refs_); - if ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && - (sync_op_num_ == 0)) { + if (execution_config_.used_for_inference || + ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && + (sync_op_num_ == 0))) { VLOG(4) << "Tracing Instruction List"; TraceInstructionList(vec_instruction_); } else { @@ -857,6 +858,10 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) { : var_scope_.GetMutableScope(); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope); + if (op->Type() == "while") { + op->SetOutputHooks(hookfuncs_); + } + auto op_with_kernel = dynamic_cast(op); { // If it is OperatorBase, InferShape do nothing. @@ -1446,7 +1451,7 @@ bool ProgramInterpreter::HasLocalScope() const { // miss. When a model is all KQueueAsync type OPs, all OPs will be distributed // to the DeviceThread for execution, and the multithreading scheduling will not // have any benefits. Therefore, in the dynamic to static, when the number of -// KQueueAsync Ops is 0, we choose Trace mode. +// KQueueSync Ops is 0, we choose Trace mode. void ProgramInterpreter::TraceInstructionList( const std::vector& vec_instr) { unfinished_op_number_ = vec_instr.size(); diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index f06bee2c884e31..f0d2bc14b3b5b7 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -55,7 +55,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, const std::string& job_type = job->Type(); std::shared_ptr program = nullptr; std::shared_ptr<::pir::Program> ir_program = nullptr; - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_new_ir_in_executor) { ir_program = plan_.IrProgram(job_type); } else { program = std::make_shared(*(plan_.Program(job_type))); @@ -80,10 +80,6 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, // TODO(phlrain) we only support cpu for now if (FLAGS_enable_new_ir_in_executor) { std::shared_ptr<::pir::Program> base_program = ir_program; - if (!FLAGS_enable_pir_api) { - VLOG(6) << "begin to translate" << std::endl; - base_program = paddle::TranslateLegacyProgramToProgram(*program); - } auto block = base_program->block(); for (auto it = block->begin(); it != block->end(); ++it) { if ((*it)->isa()) { diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index c3d4c3329016ad..cc9c2c2e6f5f5f 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -577,6 +577,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(apply_optim_); CP_MEMBER(skip_load_params_); + CP_MEMBER(use_new_executor_); + if (use_gpu_) { PADDLE_ENFORCE_EQ(use_xpu_, false, diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a098bc524f2555..8b6e3317a8a96b 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -98,6 +98,10 @@ #include "paddle/phi/backends/xpu/xpu_info.h" #endif +#ifdef PADDLE_WITH_NVTX +#include "paddle/fluid/platform/device/gpu/cuda/cuda_profiler.h" +#endif + namespace paddle { namespace { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -334,15 +338,13 @@ bool AnalysisPredictor::Init( const std::shared_ptr &parent_scope, const std::shared_ptr &program) { VLOG(3) << "Predictor::init()"; +#ifdef PADDLE_WITH_NVTX if (config_.with_profile_) { LOG(WARNING) << "Profiler is activated, which might affect the performance"; - auto tracking_device = config_.use_gpu() ? platform::ProfilerState::kAll - : platform::ProfilerState::kCPU; - platform::EnableProfiler(tracking_device); - } else { - VLOG(2) << "Profiler is deactivated, and no profiling report will be " - "generated."; + platform::CudaProfilerStart(); + platform::NvprofEnableRecordEvent(); } +#endif if (!status_is_cloned_) { root_predictor_id_ = predictor_id_; @@ -702,6 +704,20 @@ bool AnalysisPredictor::PrepareExecutor() { executor_->Prepare( sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_); + if (config_.new_executor_enabled()) { + framework::interpreter::ExecutionConfig execution_config; + execution_config.create_local_scope = false; + execution_config.used_for_inference = true; + auto input_names = GetInputNames(); + execution_config.skip_gc_vars.insert(input_names.begin(), + input_names.end()); + auto output_names = GetOutputNames(); + execution_config.skip_gc_vars.insert(output_names.begin(), + output_names.end()); + executor_->PrepareInterpreterCore( + sub_scope_, *inference_program_, execution_config); + } + if (config_.enable_memory_optim_) { auto *pass_res_info = inference::analysis::PassResultInfoForRuntime::Instance(); @@ -1082,8 +1098,6 @@ bool AnalysisPredictor::Run(const std::vector &inputs, if (config_.use_mkldnn_) MkldnnPreSet(inputs); #endif VLOG(3) << "Predictor::predict"; - inference::Timer timer; - timer.tic(); // set feed variable framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get(); PADDLE_ENFORCE_NOT_NULL( @@ -1107,9 +1121,13 @@ bool AnalysisPredictor::Run(const std::vector &inputs, HookCollectShapeRangeInfo(); } - // Run the inference program - // if share variables, we need not create variables - executor_->Run(); + if (config_.new_executor_enabled()) { + executor_->RunInterpreterCore(); + } else { + // Run the inference program + // if share variables, we need not create variables + executor_->Run(); + } // get fetch variable if (!GetFetch(output_data, scope)) { @@ -1117,8 +1135,6 @@ bool AnalysisPredictor::Run(const std::vector &inputs, return false; } - VLOG(3) << "predict cost: " << timer.toc() << "ms"; - // All the containers in the scope will be hold in inference, but the // operators assume that the container will be reset after each batch. // Here is a bugfix, collect all the container variables, and reset then to a @@ -1178,9 +1194,13 @@ bool AnalysisPredictor::Run(const std::vector &inputs, HookCollectShapeRangeInfo(); } - // Run the inference program - // if share variables, we need not create variables - executor_->Run(); + if (config_.new_executor_enabled()) { + executor_->RunInterpreterCore(); + } else { + // Run the inference program + // if share variables, we need not create variables + executor_->Run(); + } inference::DisplayMemoryInfo(place_, "after run"); @@ -2094,11 +2114,7 @@ bool AnalysisPredictor::ZeroCopyRun() { #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) if (config_.dist_config().use_dist_model()) { VLOG(3) << "ZeroCopyRun will use the fleet executor."; - inference::Timer timer; - timer.tic(); fleet_exe_->Run(config_.dist_config().carrier_id()); - VLOG(3) << "Fleet executor inf runs once use: " - << std::to_string(timer.toc()) << "ms"; return true; } #endif @@ -2155,7 +2171,11 @@ bool AnalysisPredictor::ZeroCopyRun() { } #endif - executor_->Run(); + if (config_.new_executor_enabled()) { + executor_->RunInterpreterCore(); + } else { + executor_->Run(); + } inference::DisplayMemoryInfo(place_, "after run"); #ifdef PADDLE_WITH_XPU @@ -2607,10 +2627,12 @@ AnalysisPredictor::~AnalysisPredictor() { // NOLINT SaveTrtCalibToDisk(); } #endif +#ifdef PADDLE_WITH_NVTX if (config_.with_profile_) { - platform::DisableProfiler(platform::EventSortingKey::kTotal, - "./profile.log"); + platform::NvprofDisableRecordEvent(); + platform::CudaProfilerStop(); } +#endif if (sub_scope_) { if (framework::global_transfer_scope_key().find(sub_scope_) != framework::global_transfer_scope_key().end()) { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index beecfc9743b104..fb5be9125cb3c2 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -564,15 +564,12 @@ class AnalysisPredictor : public PaddlePredictor { std::shared_ptr scope_; framework::Scope *sub_scope_{nullptr}; std::shared_ptr inference_program_; - framework::OpCompatibleMap op_compatible_map_; std::vector feeds_; std::map feed_names_; // Sorted according to the idx. std::map idx2feeds_; std::vector fetches_; std::map idx2fetches_; - std::once_flag register_input_hook_flag_; - std::once_flag register_output_hook_flag_; phi::DataType model_precision_{phi::DataType::FLOAT32}; @@ -592,16 +589,14 @@ class AnalysisPredictor : public PaddlePredictor { details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; // A mutex help to make Clone thread safe. std::mutex clone_mutex_; + static int clone_num_; - // For memory optimization. - const size_t max_shape_collect_count_{1000}; - int need_collect_var_shapes_{-1}; // -1 for default, 0 for false, 1 for true. - std::vector>> batch_var_shapes_; int predictor_id_; int root_predictor_id_{-1}; private: - std::vector hookfuncs_; + std::once_flag register_input_hook_flag_; + std::once_flag register_output_hook_flag_; std::vector output_hookfuncs_; std::vector input_hookfuncs_; // Some status here that help to determine the status inside the predictor. @@ -609,7 +604,6 @@ class AnalysisPredictor : public PaddlePredictor { std::map>> shape_info_; std::map>> shape_tensor_value_; - static int clone_num_; bool private_context_{false}; void *predictor_stream_{nullptr}; diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.h b/paddle/fluid/inference/api/mkldnn_quantizer.h index a44da8085f35b9..17fe7fff3aa21a 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.h +++ b/paddle/fluid/inference/api/mkldnn_quantizer.h @@ -21,7 +21,6 @@ #include #include -#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/api_impl.h" diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index ccefb05896d3f3..94215dddc6ccea 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -880,6 +880,10 @@ struct PD_INFER_DECL AnalysisConfig { /// int tensorrt_optimization_level() { return trt_optimization_level_; } + void EnableNewExecutor(bool x = true) { use_new_executor_ = x; } + + bool new_executor_enabled() const { return use_new_executor_; } + void EnableDlnne( int min_subgraph_size = 3, int max_batch_size = 1, @@ -1305,6 +1309,8 @@ struct PD_INFER_DECL AnalysisConfig { bool use_feed_fetch_ops_{true}; bool ir_debug_{false}; + bool use_new_executor_{false}; + bool specify_input_name_{false}; int cpu_math_library_num_threads_{1}; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index c7f3f87a4d192d..25c2e0988c4199 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -507,6 +507,8 @@ void CpuPassStrategy::EraseFcMkldnnPasses() { XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { passes_.assign({ + "delete_quant_dequant_linear_op_pass", + "delete_weight_dequant_linear_op_pass", "delete_assign_op_pass", "delete_dropout_op_pass", "delete_concat_op_pass", @@ -559,9 +561,11 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fast_where_xpu_fuse_pass", "elementwise_mul_add_fuse_pass", "link_xpu_op_max_pass", - "delete_isolated_node_pass", // "auto_mixed_precision_pass", "cast_mixed_precision_op_fuse_pass", + "xpu_quantize_op_pass", + "xpu_quantize_squash_pass", + "delete_isolated_node_pass", "inplace_op_var_pass", }); use_xpu_ = true; diff --git a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu index 1033dc65f2dcc6..b3b0cd35fb300b 100644 --- a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu @@ -15,10 +15,10 @@ #include #include "glog/logging.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/utils.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/flags.h" diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 11c2743117586b..fe65cb2327255e 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -60,40 +60,56 @@ const std::unordered_set ProgramTranslator::unsupported_ops = { static std::vector GetCondOpIds(const BlockDesc& src_block, uint64_t first_id) { - std::vector op_list = {first_id}; - if (((first_id + 1) < src_block.OpSize()) && - (src_block.Op(static_cast(first_id + 1))->Type() == "logical_not")) { - op_list.emplace_back(first_id + 1); + uint64_t temp_id = first_id; + // add conditional_block + std::vector op_list = {temp_id}; + temp_id++; + // add logical_not + if ((temp_id < src_block.OpSize()) && + (src_block.Op(static_cast(temp_id))->Type() == "logical_not")) { + op_list.emplace_back(temp_id); + temp_id++; } - if (((first_id + 2) < src_block.OpSize()) && - (src_block.Op(static_cast(first_id + 2))->Type() == + // add conditional_block + if ((temp_id < src_block.OpSize()) && + (src_block.Op(static_cast(temp_id))->Type() == "conditional_block")) { - op_list.emplace_back(first_id + 2); + op_list.emplace_back(temp_id); + temp_id++; } - if (((first_id + 3) < src_block.OpSize()) && - (src_block.Op(static_cast(first_id + 3))->Type() == "cast")) { - op_list.emplace_back(first_id + 3); + // add cast + if ((temp_id < src_block.OpSize()) && + (src_block.Op(static_cast(temp_id))->Type() == "cast")) { + op_list.emplace_back(temp_id); + temp_id++; } // Note(zhangbo): Some output variables are input, without select_input op. - std::vector output_names = - src_block.Op(static_cast(first_id))->Output("Out"); - std::vector input_names = - src_block.Op(static_cast(first_id))->Input("Input"); - std::vector diffs(output_names.size()); - auto iter = std::set_difference(output_names.begin(), - output_names.end(), - input_names.begin(), - input_names.end(), - diffs.begin()); - diffs.resize(iter - diffs.begin()); - size_t output_size = diffs.size(); - for (size_t i = 0; i < output_size; i++) { - if (((first_id + 4 + i) < src_block.OpSize()) && - (src_block.Op(static_cast(first_id + 4 + i))->Type() == - "select_input")) { - op_list.emplace_back(first_id + 4 + i); + std::vector init_op_list; + while (temp_id < src_block.OpSize()) { + if ((src_block.Op(static_cast(temp_id))->Type() == "fill_constant") || + (src_block.Op(static_cast(temp_id))->Type() == "assign_value")) { + init_op_list.emplace_back(temp_id); + temp_id++; + } else { + break; + } + } + std::vector select_input_op_list; + while (temp_id < src_block.OpSize()) { + if (src_block.Op(static_cast(temp_id))->Type() == "select_input") { + select_input_op_list.emplace_back(temp_id); + temp_id++; + } else { + break; } } + + if (select_input_op_list.size() > 0) { + op_list.insert(op_list.end(), init_op_list.begin(), init_op_list.end()); + } + op_list.insert( + op_list.end(), select_input_op_list.begin(), select_input_op_list.end()); + return op_list; } @@ -114,66 +130,86 @@ const std::string& ConditionBlockCombination::CondVarName() const { return op_list_[0]->Input("Cond")[0]; } -size_t ConditionBlockCombination::OutputSize() const { - std::vector output_names = op_list_[0]->Output("Out"); - std::vector input_names = op_list_[0]->Input("Input"); - std::vector diffs(output_names.size()); - auto iter = std::set_difference(output_names.begin(), - output_names.end(), - input_names.begin(), - input_names.end(), - diffs.begin()); - diffs.resize(iter - diffs.begin()); - return diffs.size(); -} - -std::vector<::paddle::framework::VarDesc*> +std::vector> ConditionBlockCombination::OutputVars() const { - std::vector<::paddle::framework::VarDesc*> outputs; - if (this->OutputSize() > 0) { - for (size_t i = 4; i < op_list_.size(); i++) { - outputs.emplace_back(op_list_[i]->Block()->FindVarRecursive( - op_list_[i]->Output("Out")[0])); + std::vector<::paddle::framework::VarDesc*> if_outputs; + std::vector<::paddle::framework::VarDesc*> true_block_outputs; + std::vector<::paddle::framework::VarDesc*> false_block_outputs; + for (::paddle::framework::OpDesc* op : op_list_) { + if (op->Type() == "select_input") { + if_outputs.emplace_back( + op->Block()->FindVarRecursive(op->Output("Out")[0])); + true_block_outputs.emplace_back( + op->Block()->FindVarRecursive(op->Input("X")[1])); + false_block_outputs.emplace_back( + op->Block()->FindVarRecursive(op->Input("X")[0])); } } - return outputs; + return {if_outputs, true_block_outputs, false_block_outputs}; +} + +size_t ConditionBlockCombination::MainOutputSize() const { + return OutputVars()[0].size(); } std::vector ConditionBlockCombination::TrueBlockOutputVarNames() const { - std::vector output_names = op_list_[0]->Output("Out"); - std::vector input_names = op_list_[0]->Input("Input"); - std::vector diffs(output_names.size()); - auto iter = std::set_difference(output_names.begin(), - output_names.end(), - input_names.begin(), - input_names.end(), - diffs.begin()); - diffs.resize(iter - diffs.begin()); - return diffs; + std::vector output_names; + for (::paddle::framework::OpDesc* op : op_list_) { + if (op->Type() == "select_input") { + output_names.emplace_back(op->Input("X")[1]); + } + } + return output_names; } -std::vector ConditionBlockCombination::FalseBlockOutputVarNames() - const { - if (op_list_.size() > 1) { - std::vector output_names = op_list_[2]->Output("Out"); - std::vector input_names = op_list_[2]->Input("Input"); - std::vector diffs(output_names.size()); - auto iter = std::set_difference(output_names.begin(), - output_names.end(), - input_names.begin(), - input_names.end(), - diffs.begin()); - diffs.resize(iter - diffs.begin()); - return diffs; - } - return {""}; +std::vector<::paddle::framework::OpDesc*> +ConditionBlockCombination::TrueBlockInitOps() const { + std::vector<::paddle::framework::OpDesc*> init_ops; + std::vector output_names = TrueBlockOutputVarNames(); + for (::paddle::framework::OpDesc* op : op_list_) { + if ((op->Type() == "fill_constant") || (op->Type() == "assign_value")) { + auto out_name = op->Output("Out")[0]; + if (std::find(output_names.begin(), output_names.end(), out_name) != + output_names.end()) { + init_ops.emplace_back(op); + } + } + } + return init_ops; } int ConditionBlockCombination::TrueBlockId() const { return op_list_[0]->GetBlockAttrId("sub_block"); } +std::vector ConditionBlockCombination::FalseBlockOutputVarNames() + const { + std::vector output_names; + for (::paddle::framework::OpDesc* op : op_list_) { + if (op->Type() == "select_input") { + output_names.emplace_back(op->Input("X")[0]); + } + } + return output_names; +} + +std::vector<::paddle::framework::OpDesc*> +ConditionBlockCombination::FalseBlockInitOps() const { + std::vector<::paddle::framework::OpDesc*> init_ops; + std::vector output_names = FalseBlockOutputVarNames(); + for (::paddle::framework::OpDesc* op : op_list_) { + if ((op->Type() == "fill_constant") || (op->Type() == "assign_value")) { + auto out_name = op->Output("Out")[0]; + if (std::find(output_names.begin(), output_names.end(), out_name) != + output_names.end()) { + init_ops.emplace_back(op); + } + } + } + return init_ops; +} + int ConditionBlockCombination::FalseBlockId() const { if (op_list_.size() > 1) { return op_list_[2]->GetBlockAttrId("sub_block"); @@ -210,10 +246,9 @@ bool ConditionBlockCombination::Verify( return false; } } else { - if (op_list[id]->Type() != "select_input") { - return false; - } - if (op_list[id]->Input("Mask")[0] != op_list[3]->Output("Out")[0]) { + if ((op_list[id]->Type() != "select_input") && + (op_list[id]->Type() != "fill_constant") && + (op_list[id]->Type() != "assign_value")) { return false; } } @@ -304,9 +339,10 @@ void ProgramTranslator::TranslateBlock( uint64_t start_id, uint64_t end_id, TranslationContext* translation_ctx, - pir::Block* dest_block, + pir::Block* dst_block, bool for_cond_block, - std::vector skip_cond_assign) { + const std::vector& cond_sub_block_outputs, + const std::vector<::paddle::framework::OpDesc*>& cond_init_ops) { VLOG(8) << "=============>start to translate a block"; PADDLE_ENFORCE( (src_block.OpSize() >= end_id) && (start_id <= end_id), @@ -318,7 +354,7 @@ void ProgramTranslator::TranslateBlock( src_block.OpSize())); std::unordered_map translate_completed; - std::vector assign_inputs; + std::map assign_output_2_input; for (uint64_t op_id = start_id; op_id < end_id; op_id++) { if (translate_completed.count(op_id) && translate_completed.at(op_id)) { continue; @@ -333,49 +369,59 @@ void ProgramTranslator::TranslateBlock( "Not support translated %s op", op->Type())); if (op->Type() == "conditional_block") { - std::vector cond_op_list = {op}; std::vector cond_op_ids = GetCondOpIds(src_block, op_id); ConditionBlockCombination cond_op_combination(src_block, cond_op_ids); pir::Operation* if_op = TranslateCondIfOperation( - cond_op_combination, translation_ctx, dest_block); + cond_op_combination, translation_ctx, dst_block); for (auto cond_id : cond_op_ids) { translate_completed[cond_id] = true; } VLOG(10) << "[op translated][conditional_block]" << if_op; } else if (op->Type() == "while") { - TranslateWhileOperation(op, translation_ctx, dest_block); + TranslateWhileOperation(op, translation_ctx, dst_block); } else { if (for_cond_block && op->Type() == "assign" && - std::count(skip_cond_assign.begin(), - skip_cond_assign.end(), + std::count(cond_sub_block_outputs.begin(), + cond_sub_block_outputs.end(), op->Output("Out")[0])) { - assign_inputs.push_back(op->Input("X")[0]); + assign_output_2_input[op->Output("Out")[0]] = op->Input("X")[0]; translate_completed[op_id] = true; } else { - TranslateGeneralOperation(op, translation_ctx, dest_block); + TranslateGeneralOperation(op, translation_ctx, dst_block); translate_completed[op_id] = true; } } } + // NOTE(zhangbo): If conditional_block operator has output, the cf.yeild // operator needs to be inserted if (for_cond_block) { + // insert init ops + for (::paddle::framework::OpDesc* init_op : cond_init_ops) { + TranslateGeneralOperation(init_op, translation_ctx, dst_block); + } + // insert yeild op std::vector yeild_inputs; - for (size_t id = 0; id < assign_inputs.size(); id++) { - yeild_inputs.emplace_back((*translation_ctx)[assign_inputs[id]].value); + for (auto output_name : cond_sub_block_outputs) { + if (assign_output_2_input.count(output_name) != 0) { + yeild_inputs.emplace_back( + (*translation_ctx)[assign_output_2_input[output_name]].value); + } else { + yeild_inputs.emplace_back((*translation_ctx)[output_name].value); + } } pir::AttributeMap attribute_map; auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); pir::Operation* yeild_op = pir::Operation::Create(yeild_inputs, attribute_map, {}, yeild_info); - dest_block->push_back(yeild_op); + dst_block->push_back(yeild_op); } } pir::Operation* ProgramTranslator::TranslateCondIfOperation( const ConditionBlockCombination& cond_ops, TranslationContext* translation_ctx, - pir::Block* dest_block) { + pir::Block* dst_block) { auto& type_translator = TypeTranslator::instance(); auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); std::vector op_inputs = { @@ -386,7 +432,7 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( std::vector op_output_types; std::vector<::paddle::framework::VarDesc*> output_vardescs = - cond_ops.OutputVars(); + cond_ops.OutputVars()[0]; for (auto var_desc : output_vardescs) { IR_ENFORCE(var_desc != nullptr, "[control flow] Output should not be null"); pir::Type translated_var_type = @@ -403,7 +449,7 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( VariableDefiningInfo(operation->result(i))); } - dest_block->push_back(operation); + dst_block->push_back(operation); VLOG(4) << "[general op][conditional_block] IfOp creation end."; if (cond_ops.TrueBlockId() != -1) { @@ -420,7 +466,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( true_block_context, true_region.front(), true, - cond_ops.TrueBlockOutputVarNames()); + cond_ops.TrueBlockOutputVarNames(), + cond_ops.TrueBlockInitOps()); } VLOG(4) << "[general op][conditional_block] IfOp true block translate end."; @@ -436,7 +483,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( false_block_context, false_region.front(), true, - cond_ops.FalseBlockOutputVarNames()); + cond_ops.FalseBlockOutputVarNames(), + cond_ops.FalseBlockInitOps()); } VLOG(4) << "[general op][conditional_block] IfOp false block translate end."; @@ -448,7 +496,7 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( void ProgramTranslator::TranslateWhileOperation( const OpDesc* op, TranslationContext* translation_ctx, - pir::Block* dest_block) { + pir::Block* dst_block) { VLOG(8) << "=============>Start to translate while op:" << op; auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); int index = static_cast(sub_block.OpSize()) - 1; @@ -488,7 +536,7 @@ void ProgramTranslator::TranslateWhileOperation( } pir::Operation* while_op = pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1); - dest_block->push_back(while_op); + dst_block->push_back(while_op); while_op->region(0).push_back(body_block); TranslateBlock(sub_block, 0, index + 1, translation_ctx, body_block); @@ -518,7 +566,7 @@ void ProgramTranslator::TranslateWhileOperation( void ProgramTranslator::TranslateGeneralOperation( const OpDesc* src_op, TranslationContext* translation_ctx, - pir::Block* dest_block) { + pir::Block* dst_block) { auto& op_translator = OpTranslator::instance(); OpTranslateFn& fn = op_translator[src_op->Type()]; if (src_op->Type() == "shadow_output") { @@ -526,7 +574,7 @@ void ProgramTranslator::TranslateGeneralOperation( return; } } - pir::Operation* operation = fn(ctx_, translation_ctx, *src_op, dest_block); + pir::Operation* operation = fn(ctx_, translation_ctx, *src_op, dst_block); VLOG(10) << "[op translated][general]" << operation << "end"; } diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 97c7ae1ec86879..fb7df9204d37ad 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -50,14 +50,25 @@ class ConditionBlockCombination { public: ConditionBlockCombination(const ::paddle::framework::BlockDesc& src_block, const std::vector& op_ids); + const std::string& CondVarName() const; - int TrueBlockId() const; - int FalseBlockId() const; - size_t OutputSize() const; - std::vector<::paddle::framework::VarDesc*> OutputVars() const; + + std::vector> OutputVars() const; + + size_t MainOutputSize() const; + std::vector TrueBlockOutputVarNames() const; + + std::vector<::paddle::framework::OpDesc*> TrueBlockInitOps() const; + + int TrueBlockId() const; + std::vector FalseBlockOutputVarNames() const; + std::vector<::paddle::framework::OpDesc*> FalseBlockInitOps() const; + + int FalseBlockId() const; + private: bool Verify(const std::vector<::paddle::framework::OpDesc*>& op_list); @@ -127,16 +138,18 @@ class ProgramTranslator { static const std::unordered_set unsupported_ops; - void TranslateBlock(const BlockDesc& src_block, - uint64_t start_id, - uint64_t end_id, - TranslationContext* translation_ctx, - pir::Block* dest_block, - bool for_cond_block = false, - std::vector skip_cond_assign = {}); + void TranslateBlock( + const BlockDesc& src_block, + uint64_t start_id, + uint64_t end_id, + TranslationContext* translation_ctx, + pir::Block* dst_block, + bool for_cond_block = false, + const std::vector& cond_sub_block_outputs = {}, + const std::vector<::paddle::framework::OpDesc*>& cond_init_ops = {}); void TranslateGeneralOperation(const OpDesc* src_op, TranslationContext* translation_ctx, - pir::Block* dest_block); + pir::Block* dst_block); void GetParameterForSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block); void SetStopGradientAttributeForAllValue(const BlockDesc& block); @@ -146,10 +159,10 @@ class ProgramTranslator { pir::Operation* TranslateCondIfOperation( const ConditionBlockCombination& cond_ops, TranslationContext* translation_ctx, - pir::Block* dest_block); + pir::Block* dst_block); void TranslateWhileOperation(const OpDesc* op, TranslationContext* translation_ctx, - pir::Block* dest_block); + pir::Block* dst_block); }; } // namespace translator diff --git a/paddle/fluid/ir_adaptor/translator/translate.cc b/paddle/fluid/ir_adaptor/translator/translate.cc index 0f98e557743fcb..7a7081fe1acbf2 100644 --- a/paddle/fluid/ir_adaptor/translator/translate.cc +++ b/paddle/fluid/ir_adaptor/translator/translate.cc @@ -34,8 +34,9 @@ std::unique_ptr TranslateLegacyProgramToProgram( auto program = std::make_unique(ctx); translator::ProgramTranslator program_translator(&legacy_program, program.get()); + VLOG(6) << "begin to translate"; program_translator.Translate(); - + VLOG(6) << "translate done"; return program; } diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 1d45cee7154095..270e0debbdb1b6 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -35,8 +35,6 @@ namespace operators { void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNorm"); - OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "BatchNorm"); - OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "BatchNorm"); OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "BatchNorm"); OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance", "BatchNorm"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BatchNorm"); @@ -118,48 +116,54 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ? x_dims[1] : x_dims[x_dims.size() - 1]); - auto scale_dim = ctx->GetInputDim("Scale"); - auto bias_dim = ctx->GetInputDim("Bias"); + if (ctx->HasInput("Scale")) { + auto scale_dim = ctx->GetInputDim("Scale"); + PADDLE_ENFORCE_EQ( + scale_dim.size(), + 1UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of scale must equal to 1." + "But received: the shape of scale is [%s], the dimension " + "of scale is [%d]", + scale_dim, + scale_dim.size())); + } - PADDLE_ENFORCE_EQ( - scale_dim.size(), - 1UL, - platform::errors::InvalidArgument( - "ShapeError: the dimension of scale must equal to 1." - "But received: the shape of scale is [%s], the dimension " - "of scale is [%d]", - scale_dim, - scale_dim.size())); - PADDLE_ENFORCE_EQ(bias_dim.size(), - 1UL, - platform::errors::InvalidArgument( - "ShapeError: the dimension of bias must equal to 1." - "But received: the shape of bias is [%s],the dimension " - "of bias is [%d]", - bias_dim, - bias_dim.size())); + if (ctx->HasInput("Bias")) { + auto bias_dim = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ( + bias_dim.size(), + 1UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of bias must equal to 1." + "But received: the shape of bias is [%s],the dimension " + "of bias is [%d]", + bias_dim, + bias_dim.size())); + } bool check = true; - if ((!ctx->IsRuntime()) && - (phi::product(scale_dim) <= 0 || phi::product(bias_dim) <= 0)) { + if (!ctx->HasInput("Scale") || !ctx->HasInput("Bias") || + ((!ctx->IsRuntime()) && (phi::product(ctx->GetInputDim("Scale")) <= 0 || + phi::product(ctx->GetInputDim("Bias")) <= 0))) { check = false; } if (check) { - PADDLE_ENFORCE_EQ(scale_dim[0], + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C, platform::errors::InvalidArgument( "ShapeError: the shape of scale must equal to [%d]" "But received: the shape of scale is [%d]", C, - scale_dim[0])); - PADDLE_ENFORCE_EQ(bias_dim[0], + ctx->GetInputDim("Scale")[0])); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C, platform::errors::InvalidArgument( "ShapeError: the shape of bias must equal to [%d]" "But received: the shape of bias is [%d]", C, - bias_dim[0])); + ctx->GetInputDim("Bias")[0])); } ctx->SetOutputDim("Y", x_dims); ctx->ShareLoD("X", "Y"); @@ -185,16 +189,20 @@ phi::KernelKey BatchNormOp::GetExpectedKernelType( if (input_data_type == framework::proto::VarType::FP64) { bn_param_type = framework::proto::VarType::FP64; } - PADDLE_ENFORCE_EQ( - bn_param_type, - framework::TransToProtoVarType( - ctx.Input("Scale")->dtype()), - platform::errors::InvalidArgument("Scale input should be of float type")); - PADDLE_ENFORCE_EQ( - bn_param_type, - framework::TransToProtoVarType( - ctx.Input("Bias")->dtype()), - platform::errors::InvalidArgument("Bias input should be of float type")); + if (ctx.HasInput("Scale")) { + PADDLE_ENFORCE_EQ(bn_param_type, + framework::TransToProtoVarType( + ctx.Input("Scale")->dtype()), + platform::errors::InvalidArgument( + "Scale input should be of float type")); + } + if (ctx.HasInput("Bias")) { + PADDLE_ENFORCE_EQ(bn_param_type, + framework::TransToProtoVarType( + ctx.Input("Bias")->dtype()), + platform::errors::InvalidArgument( + "Bias input should be of float type")); + } PADDLE_ENFORCE_EQ( bn_param_type, framework::TransToProtoVarType( @@ -205,7 +213,6 @@ phi::KernelKey BatchNormOp::GetExpectedKernelType( ctx.Input("Variance")->dtype()), platform::errors::InvalidArgument( "Variance input should be of float type")); - return phi::KernelKey(input_data_type, ctx.GetPlace()); } @@ -257,10 +264,12 @@ void BatchNormOpMaker::Make() { AddInput("X", "The input tensor"); AddInput("Scale", "Scale is a 1-dimensional tensor of size C " - "that is applied to the output"); + "that is applied to the output") + .AsDispensable(); AddInput("Bias", "Bias is a 1-dimensional tensor of size C " - "that is applied to the output"); + "that is applied to the output") + .AsDispensable(); AddInput("Mean", "The global mean (for training) or " "estimated mean (for testing)"); diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc index bdb774c62bdc4c..11b51602d4d75a 100644 --- a/paddle/fluid/operators/collective/alltoall_op.cu.cc +++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/alltoall_op.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/utils.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" @@ -27,6 +27,8 @@ PHI_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace operators { +using phi::distributed::GetPartialTensor; + template class AllToAllOpCUDAKernel : public framework::OpKernel { public: @@ -103,9 +105,9 @@ class AllToAllOpCUDAKernel : public framework::OpKernel { if (comm_ctx) { comm_ctx->GroupStart(); for (auto i = 0; i < nranks; ++i) { - auto send_buf = distributed::GetPartialTensor(*x, offset, send_numel); + auto send_buf = GetPartialTensor(*x, offset, send_numel); comm_ctx->Send(send_buf, send_numel, i, stream); - auto recv_buf = distributed::GetPartialTensor(*out, offset, send_numel); + auto recv_buf = GetPartialTensor(*out, offset, send_numel); comm_ctx->Recv(&recv_buf, send_numel, i, stream); offset += send_numel; } diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc index 06be523a50b27c..bd105c35886cb0 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc +++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_allgather_op.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc index 37616be1128f7f..d13179cbae48b1 100644 --- a/paddle/fluid/operators/collective/c_concat_op.cu.cc +++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/fluid/operators/collective/c_concat_op.h" #include -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/fluid/operators/math/concat_and_split.h" diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index 50988fd3814831..737784d96c0ee2 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc index 9da5d6ad1d840f..cd1cf0c0176363 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_reducescatter_op.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc index ea5a0dda1fd973..7f4b4f6734de0c 100644 --- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_scatter_op.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc index e296a4d218f1f9..d95c194452174e 100644 --- a/paddle/fluid/operators/collective/global_gather_op.cu.cc +++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include "paddle/fluid/operators/collective/global_gather_op.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #endif -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" @@ -279,7 +279,7 @@ struct GlobalGatherProcessGroupFunctor { out->mutable_data(out_dims, place); for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + distributed::ProcessGroupNCCL::GroupStart(); for (auto j = 0; j < nranks; ++j) { int idx = i + j * n_expert; if (cpu_global_count_data[idx]) { @@ -299,7 +299,7 @@ struct GlobalGatherProcessGroupFunctor { /*sync_op*/ true); } } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + distributed::ProcessGroupNCCL::GroupEnd(); } #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc index 45d91dc724108f..d8cd6d4be5f54a 100644 --- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include "paddle/fluid/operators/collective/global_scatter_op.h" #include "paddle/phi/core/distributed/comm_context_manager.h" -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/convert_utils.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" @@ -286,7 +286,7 @@ struct GlobalScatterProcessGroupFunctor { out->mutable_data(out_dims, place); for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + distributed::ProcessGroupNCCL::GroupStart(); for (auto j = 0; j < nranks; ++j) { int idx = i + j * n_expert; if (cpu_local_count_data[idx]) { @@ -306,7 +306,7 @@ struct GlobalScatterProcessGroupFunctor { recv_ptr += cpu_global_count_data[idx]; } } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + distributed::ProcessGroupNCCL::GroupEnd(); } #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc index cf353c12ffa491..b0cdabce48503a 100644 --- a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc @@ -23,7 +23,6 @@ limitations under the License. */ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" namespace paddle { diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc index 2a6aea1c7a13af..c8844058696e14 100644 --- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc @@ -23,7 +23,6 @@ limitations under the License. */ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" namespace paddle { diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc index 67089a18c8e4fc..39858b3ed37a26 100644 --- a/paddle/fluid/operators/collective/partial_send_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc @@ -23,8 +23,6 @@ limitations under the License. */ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/fluid/distributed/collective/utils.h" -#include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" namespace paddle { diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 42c41effb80ed2..b8bb9a123fba3d 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -13,7 +13,6 @@ register_operators( yolo_box_head_op yolo_box_post_op fusion_group_op - fusion_gru_op fusion_lstm_op fused_bn_add_activation_op fused_attention_op @@ -27,8 +26,6 @@ register_operators( fused_gate_attention_op resnet_basic_block_op) -# fusion_gru_op does not have CUDA kernel -op_library(fusion_gru_op) op_library(fusion_lstm_op) if(WITH_AVX AND AVX512F_FOUND diff --git a/paddle/fluid/operators/fused/fused_attention_utils.h b/paddle/fluid/operators/fused/fused_attention_utils.h index c059a194d0ea53..7d17041133bcd7 100644 --- a/paddle/fluid/operators/fused/fused_attention_utils.h +++ b/paddle/fluid/operators/fused/fused_attention_utils.h @@ -23,7 +23,6 @@ PHI_DECLARE_bool(dynamic_static_unified_comm); #endif -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/errors.h" diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index ba12bdc8b9d7f2..40717402846db5 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -24,7 +24,6 @@ limitations under the License. */ #include -#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/fused/attention_layer_norm.h" diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc deleted file mode 100644 index 541233949b5d22..00000000000000 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ /dev/null @@ -1,565 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_gru_op.h" - -#include // for memcpy -#include -#include - -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/fc_functor.h" -#include "paddle/phi/kernels/funcs/jit/kernels.h" -#include "paddle/phi/kernels/funcs/sequence2batch.h" - -namespace paddle { -namespace operators { - -void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru"); - auto x_dims = ctx->GetInputDim("X"); - auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) - ? phi::flatten_to_2d(x_dims, 1) - : x_dims; - PADDLE_ENFORCE_EQ( - x_mat_dims.size(), - 2, - platform::errors::InvalidArgument("The size of input X dims should be 2, " - "or 3 with second dimension equal to " - "1, but now Input X dim is:[%s] ", - x_dims)); - - auto wx_dims = ctx->GetInputDim("WeightX"); - PADDLE_ENFORCE_EQ(wx_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(WeightX) should be 2, but received " - "WeightX dim size is:%d, WeightX dim is:[%s] ", - wx_dims.size(), - wx_dims)); - PADDLE_ENFORCE_EQ( - wx_dims[0], - x_mat_dims[1], - platform::errors::InvalidArgument( - "The first dimension of flattened WeightX" - "should equal to last dimension of flattened input X, but " - "received fattened WeightX dimension is:%d, flattened X dimension " - "is:%d", - wx_dims[0], - x_mat_dims[1])); - - int frame_size = static_cast(wx_dims[1] / 3); - auto wh_dims = ctx->GetInputDim("WeightH"); - - PADDLE_ENFORCE_EQ(wh_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(WeightH) should be 2, but received " - "WeightH dim size is:%d, WeightH dim is:[%s]", - wh_dims.size(), - wh_dims)); - PADDLE_ENFORCE_EQ(wh_dims[0], - frame_size, - platform::errors::InvalidArgument( - "The first dimension of WeightH " - "should equal to frame_size, but received WeightH's " - "first dimension is: " - "%d, frame size is:%d", - wh_dims[0], - frame_size)); - PADDLE_ENFORCE_EQ(wh_dims[1], - 3 * frame_size, - platform::errors::InvalidArgument( - "The second dimension of Input(WeightH) " - "should equal to 3 * frame_size, but received WeightH " - "is:%d, frame size is:%d", - wh_dims[1], - frame_size)); - - if (ctx->HasInput("H0")) { - auto h0_dims = ctx->GetInputDim("H0"); - PADDLE_ENFORCE_EQ(h0_dims[1], - frame_size, - platform::errors::InvalidArgument( - "The width of H0 must be equal to frame_size, but " - "receiced the width of H0 is:%d, frame size is:%d", - h0_dims[1], - frame_size)); - } - if (ctx->HasInput("Bias")) { - auto b_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(b_dims.size(), - 2, - platform::errors::InvalidArgument( - "The rank of Input(Bias) should be 2, but received " - "Bias rank is:%d, Bias dim is:[%s]", - b_dims.size(), - b_dims)); - PADDLE_ENFORCE_EQ(b_dims[0], - 1, - platform::errors::InvalidArgument( - "The first dimension of Input(Bias) should be 1, but " - "received Bias first dim is:%d, Bias dim is:[%s]", - b_dims[0], - b_dims)); - PADDLE_ENFORCE_EQ(b_dims[1], - frame_size * 3, - platform::errors::InvalidArgument( - "The shape of Bias must be [1, frame_size * 3], but " - "received bias dim is:[%s], frame size is:%d", - b_dims, - frame_size)); - } - framework::DDim out_dims({x_mat_dims[0], frame_size}); - ctx->SetOutputDim("Hidden", out_dims); - ctx->ShareLoD("X", "Hidden"); - int xx_width = 0; - if (ctx->Attrs().Get("use_seq")) { - xx_width = static_cast(wx_dims[1]); - } else { - xx_width = static_cast(x_mat_dims[1] > wx_dims[1] ? wx_dims[1] - : x_mat_dims[1]); - OP_INOUT_CHECK( - ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0", "fusion_gru"); - OP_INOUT_CHECK( - ctx->HasOutput("BatchedInput"), "Output", "BatchedInput", "fusion_gru"); - OP_INOUT_CHECK( - ctx->HasOutput("BatchedOut"), "Output", "BatchedOut", "fusion_gru"); - ctx->SetOutputDim("BatchedInput", {x_mat_dims[0], wx_dims[1]}); - ctx->SetOutputDim("BatchedOut", out_dims); - } - ctx->SetOutputDim("XX", {x_mat_dims[0], xx_width}); - ctx->ShareLoD("X", "XX"); -} - -phi::KernelKey FusionGRUOp::GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return phi::KernelKey(data_type, ctx.GetPlace()); -} - -void FusionGRUOpMaker::Make() { - AddInput( - "X", - "(phi::DenseTensor) the input is a LodTensor, which support " - "variable-time length input sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T X M), where T is the " - "total time steps in this mini-batch, M is the dim size of x."); - AddInput( - "H0", - "(phi::DenseTensor, optional) The initial hidden state is an optional " - "input. This is a tensor with shape (N x D), where N is the " - "batch size, D is the hidden size.") - .AsDispensable(); - AddInput("WeightX", - "(phi::DenseTensor) The FC weight with shape (M x 3D)," - "where M is the dim size of x, D is the hidden size. "); - AddInput( - "WeightH", - "(phi::DenseTensor) (D x 3D) Same as GRUOp, where D is the hidden size. " - "This weight is not exactly D x 3D as: {W_update, W_reset, W_state}" - "Acutally they are D x 2D and D x D two part weights." - "{W_update, W_reset; W_state}" - "{D x (D + D); D x D}"); - AddInput("Bias", - "(phi::DenseTensor, optional) (1 x 3D)." - "Almost same as GRUOp." - "Note: if have FC bias it should be added on this bias.") - .AsDispensable(); - AddOutput("ReorderedH0", - "(phi::DenseTensor) (N x D), which N is the min-batch size.") - .AsIntermediate(); - AddOutput("XX", - "(phi::DenseTensor) the result after X * WeightX (size is T x 3D)" - " or batched_X (size is T x M), this will be automatically chosen," - " where T is the total time steps in this mini-batch," - " D is the hidden size, M is the dim size of x input.") - .AsIntermediate(); - AddOutput("BatchedInput", - "(phi::DenseTensor) This is the batched result of input X" - "or the batched result after fc, shape (T x 3D)") - .AsIntermediate(); - AddOutput("BatchedOut", "(phi::DenseTensor) (T X D) save batched hidden.") - .AsIntermediate(); - AddOutput("Hidden", "(phi::DenseTensor) (T x D) Same as GRUOp"); - AddAttr("activation", - "(string, default tanh) " - "The activation type used for output candidate {h}_t.") - .SetDefault("tanh"); - AddAttr( - "gate_activation", - "(string, default sigmoid) " - "The activation type used in update gate and reset gate.") - .SetDefault("sigmoid"); - AddAttr("is_reverse", - "(bool, default: False) " - "whether to compute reversed GRU.") - .SetDefault(false); - AddAttr("use_seq", - "(bool, default: True) " - "whether to use seq mode to compute GRU.") - .SetDefault(true); - AddAttr("origin_mode", - "bool" - "use origin mode in article https://arxiv.org/abs/1412.3555") - .SetDefault(false); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddAttr( - "mkldnn_data_type", - "(string, default \"float32\"). Data type of mkldnn kernel") - .SetDefault("float32") - .InEnum({"float32", "int8", "bfloat16"}); - AddAttr("Scale_data", - "Scale to be used for int8 input/output data." - "Only used with MKL-DNN INT8.") - .SetDefault(1.0f); - AddAttr("Shift_data", - "Shift to be used for int8 input/output data." - "Only used with MKL-DNN INT8.") - .SetDefault(0.0f); - AddAttr>("Scale_weights", - "Scale_weights to be used for int8 weights data." - "Only used with MKL-DNN INT8.") - .SetDefault({1.0f}); - AddAttr("force_fp32_output", - "(bool, default false) Force INT8 kernel output FP32, only " - "used in MKL-DNN INT8") - .SetDefault(false); - AddComment(R"DOC( -The Fusion complete GRU Operator. -This operator fuse the fully-connected operator into GRU, -more details can refer to GRU op. -)DOC"); -} - -template -class FusionGRUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - if (ctx.Attr("use_seq")) { - SeqCompute(ctx); - } else { - BatchCompute(ctx); - } - } - -#define INIT_BASE_DEFINES \ - auto* x = ctx.Input("X"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* xx = ctx.Output("XX"); \ - auto x_lod = x->lod(); \ - auto x_dims = x->dims(); /* T x M*/ \ - auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) \ - ? phi::flatten_to_2d(x_dims, 1) \ - : x_dims; \ - auto wh_dims = wh->dims(); /* D x 3D*/ \ - const int total_T = x_mat_dims[0]; \ - const int D3 = wh_dims[1] - -#define INIT_OTHER_DEFINES \ - auto* h0 = ctx.Input("H0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* bias = ctx.Input("Bias"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - bool is_reverse = ctx.Attr("is_reverse"); \ - const int M = x_mat_dims[1]; \ - const int D = wh_dims[0]; \ - const int D2 = D * 2; \ - const phi::jit::gru_attr_t attr( \ - D, \ - phi::jit::to_kerneltype(ctx.Attr("gate_activation")), \ - phi::jit::to_kerneltype(ctx.Attr("activation"))); \ - phi::jit::gru_t one_step; \ - auto ComputeH1 = phi::jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - auto ComputeHtPart1 = phi::jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - auto ComputeHtPart2 = phi::jit::KernelFuncs, \ - platform::CPUPlace>::Cache() \ - .At(attr); \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - auto place = ctx.GetPlace(); \ - T* xx_data = xx->mutable_data(place) - - void SeqCompute(const framework::ExecutionContext& ctx) const { - INIT_BASE_DEFINES; - INIT_OTHER_DEFINES; - const int N = static_cast(x_lod[0].size() - 1); - const T* h0_data = h0 ? h0->data() : nullptr; - const T* wh_state_data = wh_data + D * D2; - T* hidden_out_data = hidden_out->mutable_data(place); - - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - - phi::funcs::FCFunctor fc; - fc(dev_ctx, - total_T, - D3, - M, - x_data, - wx_data, - xx_data, - bias ? bias->data() : nullptr); - - int xx_offset = D3; - int gate_offset = D; - if (is_reverse) { - const int offset = (total_T - 1) * D; - xx_data = xx_data + offset * 3; - hidden_out_data = hidden_out_data + offset; - xx_offset = -D3; - gate_offset = -D; - } - auto move_step = [&]() { - xx_data = xx_data + xx_offset; - hidden_out_data = hidden_out_data + gate_offset; - }; - for (int i = 0; i < N; ++i) { - int bid = is_reverse ? N - 1 - i : i; - int seq_len = static_cast(x_lod[0][bid + 1] - x_lod[0][bid]); - const T* prev_hidden_data = nullptr; - int tstart = 0; - if (h0_data) { - prev_hidden_data = h0_data + bid * D; - } else { - one_step.gates = xx_data; - one_step.ht = hidden_out_data; - ComputeH1(&one_step, &attr); - prev_hidden_data = hidden_out_data; - tstart = 1; - move_step(); - } - for (int step = tstart; step < seq_len; ++step) { - // gemm prev * (Wu + Wr) - blas.GEMM(CblasNoTrans, - CblasNoTrans, - 1, - D2, - D, - static_cast(1), - prev_hidden_data, - D, - wh_data, - D2, - static_cast(1), - xx_data, - D3); - one_step.gates = xx_data; - one_step.ht_1 = prev_hidden_data; - one_step.ht = hidden_out_data; - ComputeHtPart1(&one_step, &attr); - // gemm rt * Ws - blas.GEMM(CblasNoTrans, - CblasNoTrans, - 1, - D, - D, - static_cast(1), - hidden_out_data, - D, - wh_state_data, - D, - static_cast(1), - xx_data + D2, - D3); - ComputeHtPart2(&one_step, &attr); - // save prev - prev_hidden_data = hidden_out_data; - move_step(); - } - } - } - - void BatchCompute(const framework::ExecutionContext& ctx) const { - INIT_BASE_DEFINES; - if (x_lod[0].size() == 2) { - xx->Resize({total_T, D3}); - SeqCompute(ctx); - return; - } - INIT_OTHER_DEFINES; - auto* reordered_h0 = ctx.Output("ReorderedH0"); - auto* batched_input = ctx.Output("BatchedInput"); - auto* batched_out = ctx.Output("BatchedOut"); - T* batched_input_data = batched_input->mutable_data(place); - T* batched_out_data = batched_out->mutable_data(place); - hidden_out->mutable_data(place); - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - phi::funcs::LoDTensor2BatchFunctor to_batch; - - phi::funcs::FCFunctor fc; - if (M > D3) { - fc(dev_ctx, - total_T, - D3, - M, - x_data, - wx_data, - xx_data, - bias ? bias->data() : nullptr); - to_batch(dev_ctx, *xx, batched_input, true, is_reverse); - } else { - to_batch(dev_ctx, *x, xx, true, is_reverse); - batched_input->set_lod(xx->lod()); - fc(dev_ctx, - total_T, - D3, - M, - xx_data, - wx_data, - batched_input_data, - bias ? bias->data() : nullptr); - } - - auto batched_lod = batched_input->lod(); - const auto& seq_order = batched_lod[2]; - const int max_bs = static_cast(seq_order.size()); - reordered_h0->Resize({max_bs, D}); - - int tstart = 0; - T* prev_hidden_data = nullptr; - if (h0) { - // reorder h0 - T* reordered_h0_data = reordered_h0->mutable_data(place); - const T* h0_data = h0->data(); - prev_hidden_data = reordered_h0_data; - size_t sz = sizeof(T) * D; - for (int i = 0; i < max_bs; ++i) { - std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); - reordered_h0_data += D; - } - } else { - // compute without h0 - T* cur_in_data = batched_input_data; - T* cur_out_data = batched_out_data; - // W: {W_update, W_reset; W_state} - for (int i = 0; i < max_bs; ++i) { - one_step.gates = cur_in_data; - one_step.ht = cur_out_data; - ComputeH1(&one_step, &attr); - // add offset - cur_in_data += D3; - cur_out_data += D; - } - tstart = 1; - prev_hidden_data = batched_out_data; - } - // Then start from next - const T* wh_state_data = wh_data + D * D2; - const auto& batch_starts = batched_lod[0]; - const int max_seq_len = static_cast(batch_starts.size() - 1); - batched_input_data = batched_input_data + tstart * max_bs * D3; - batched_out_data = batched_out_data + tstart * max_bs * D; - for (int step = tstart; step < max_seq_len; ++step) { - const int cur_bs = - static_cast(batch_starts[step + 1] - batch_starts[step]); - // gemm prev * (Wu + Wr) - blas.GEMM(CblasNoTrans, - CblasNoTrans, - cur_bs, - D2, - D, - static_cast(1), - prev_hidden_data, - D, - wh_data, - D2, - static_cast(1), - batched_input_data, - D3); - - T* cur_batched_data = batched_input_data; - T* cur_out_data = batched_out_data; - T* cur_prev_hidden_data = prev_hidden_data; - for (int i = 0; i < cur_bs; ++i) { - one_step.gates = cur_batched_data; - one_step.ht_1 = cur_prev_hidden_data; - one_step.ht = cur_out_data; - ComputeHtPart1(&one_step, &attr); - - cur_batched_data += D3; - cur_prev_hidden_data += D; - cur_out_data += D; - } - - cur_batched_data = batched_input_data; - cur_out_data = batched_out_data; - blas.GEMM(CblasNoTrans, - CblasNoTrans, - cur_bs, - D, - D, - static_cast(1), - cur_out_data, - D, - wh_state_data, - D, - static_cast(1), - cur_batched_data + D2, - D3); - - cur_prev_hidden_data = prev_hidden_data; - for (int i = 0; i < cur_bs; ++i) { - one_step.gates = cur_batched_data; - one_step.ht_1 = cur_prev_hidden_data; - one_step.ht = cur_out_data; - ComputeHtPart2(&one_step, &attr); - cur_batched_data += D3; - cur_prev_hidden_data += D; - cur_out_data += D; - } - prev_hidden_data = batched_out_data; - batched_out_data = cur_out_data; - batched_input_data = cur_batched_data; - } - - phi::funcs::Batch2LoDTensorFunctor to_seq; - batched_out->set_lod(batched_lod); - to_seq(dev_ctx, *batched_out, hidden_out); - } -#undef INIT_OTHER_DEFINES -#undef INIT_BASE_DEFINES -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - fusion_gru, CPU, ALL_LAYOUT, ops::FusionGRUKernel, float, double) {} - -/* ========================== register checkpoint ===========================*/ -REGISTER_OP_VERSION(fusion_gru) - .AddCheckpoint( - R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "Scale_weights", - "The added attribute 'Scale_weights' is not yet " - "registered.", - std::vector{1.0f})); diff --git a/paddle/fluid/operators/fused/fusion_gru_op.h b/paddle/fluid/operators/fused/fusion_gru_op.h deleted file mode 100644 index e811df655099d8..00000000000000 --- a/paddle/fluid/operators/fused/fusion_gru_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class FusionGRUOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override; -}; - -class FusionGRUOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc deleted file mode 100644 index de70a5b6b5cf59..00000000000000 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ /dev/null @@ -1,290 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h" - -#include // for min, max -#include - -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/fc_functor.h" - -namespace paddle { -namespace operators { - -void FusionSeqConvEltAddReluOp::InferShape( - framework::InferShapeContext* ctx) const { - OP_INOUT_CHECK( - ctx->HasInput("X"), "Input", "X", "fusion_seqconv_eltadd_relu"); - OP_INOUT_CHECK( - ctx->HasInput("Filter"), "Input", "Filter", "fusion_seqconv_eltadd_relu"); - OP_INOUT_CHECK( - ctx->HasInput("Bias"), "Input", "Bias", "fusion_seqconv_eltadd_relu"); - - OP_INOUT_CHECK( - ctx->HasOutput("Out"), "Output", "Out", "fusion_seqconv_eltadd_relu"); - OP_INOUT_CHECK(ctx->HasOutput("ColMat"), - "Output", - "ColMat", - "fusion_seqconv_eltadd_relu"); - - auto x_dims = ctx->GetInputDim("X"); - auto w_dims = ctx->GetInputDim("Filter"); - int context_length = ctx->Attrs().Get("contextLength"); - PADDLE_ENFORCE_EQ(ctx->Attrs().Get("contextStride"), - 1, - platform::errors::InvalidArgument( - "Currently, FusionSeqConvEltAddReluOp only supports " - "contextStride=1, but received value is: %d.", - ctx->Attrs().Get("contextStride"))); - - PADDLE_ENFORCE_EQ( - x_dims.size(), - 2, - platform::errors::InvalidArgument( - "Input(X) should be 2-D tensor, but reveiced value is: %d.", - x_dims.size())); - - PADDLE_ENFORCE_EQ( - w_dims.size(), - 2, - platform::errors::InvalidArgument( - "Filter should be 2-D tensor, but reveiced value is: %d.", - w_dims.size())); - - PADDLE_ENFORCE_EQ(w_dims[0], - context_length * x_dims[1], - platform::errors::InvalidArgument( - "Filter's height should be equal to context_length * " - "input_hidden_size, but received Filter height is: %d," - "context_length is: %d, input_hidden_size is: %d.", - w_dims[0], - context_length, - x_dims[1])); - - PADDLE_ENFORCE_GT( - context_length + ctx->Attrs().Get("contextStart"), - 0, - platform::errors::InvalidArgument( - "contextStart size should be smaller than contextLength, " - "but received context_length is: %d, contextStart is: " - "%d.", - context_length, - ctx->Attrs().Get("contextStart"))); - - ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]}); - ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]}); - ctx->ShareLoD("X", "Out"); -} - -phi::KernelKey FusionSeqConvEltAddReluOp::GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace()); -} - -void FusionSeqConvEltAddReluOpMaker::Make() { - AddInput( - "X", - "(phi::DenseTensor) the input is a LodTensor, which support " - "variable-time length input sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T X M), where T is the " - "total time steps in this mini-batch, M is the dim size of x."); - // PaddingData only support false yet, should be ensured at pass. - AddInput( - "Filter", - "(phi::DenseTensor) same as the input(Filter) of sequence conv op is an " - "learnable parameter." - "This is a tensor with shape (K, N), where K is the " - "context_length * dim size of x, N is the output feature size."); - AddInput( - "Bias", - "(phi::DenseTensor) the learnable weights. shape (1, N), where N is the " - "output feature size"); - AddOutput( - "Out", - "(phi::DenseTensor) the output(Out) is a LodTensor, which support " - "variable-time length output sequence. The underlying tensor in " - "this phi::DenseTensor is a matrix with shape (T, N), where, T is the " - "total time steps in this mini-batch, N is the output feature size."); - AddOutput("ColMat", - "(phi::DenseTensor) (T, K), where T is where T is the " - "total time steps in this mini-batch, K is height of Filter") - .AsIntermediate(); - AddAttr("contextLength", - "(int) the contextLength of FusionSeqConvEltAddReluOp is the " - "height of the convolution kernel.") - .GreaterThan(0); - AddAttr("contextStart", - "(int, default:0) the contextStart of FusionSeqConvEltAddReluOp " - "represents the beginning of the convolution of the number of " - "rows of sequence, which can be negative. The negative number " - "means to pad contextStart time-steps of zeros or learnable " - "parameters at the beginning of each instance. The positive " - "number means to skip contextStart time-steps of each " - "instance.") - .SetDefault(0); - AddAttr( - "contextStride", - "(int, default:1) the contextStride of FusionSeqConvEltAddReluOp " - "represents the stride length of convolution kernel. " - "Currently, FusionSeqConvEltAddReluOp only supports" - "contextStride=1.") - .SetDefault(1) - .GreaterThan(0); - AddComment(R"DOC( -Fusion Sequence Conv and ElementwiseAdd Operator. -)DOC"); -} - -template -class FusionSeqConvEltAddReluKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* w = ctx.Input("Filter"); - auto* b = ctx.Input("Bias"); - auto* y = ctx.Output("Out"); - auto* col = ctx.Output("ColMat"); - - auto x_lod = x->lod(); - auto x_dims = phi::vectorize(x->dims()); - auto w_dims = phi::vectorize(w->dims()); - PADDLE_ENFORCE_EQ( - b->numel(), - w_dims[1], - platform::errors::InvalidArgument( - "bias size should be equal to weights feature size, but received " - "bias size is: %d, weights feature size is: %d.", - b->numel(), - w_dims[1])); - PADDLE_ENFORCE_EQ( - x_lod.size(), - 1UL, - platform::errors::InvalidArgument( - "Only support one level sequence now, but received value is: %d.", - x_lod.size())); - - const T* x_data = x->data(); - const T* w_data = w->data(); - const T* b_data = b->data(); - T* y_data = y->mutable_data(ctx.GetPlace()); - T* col_data = col->mutable_data(ctx.GetPlace()); - - int context_start = ctx.Attr("contextStart"); - int context_length = ctx.Attr("contextLength"); - int up_pad = std::max(0, -context_start); - int down_pad = std::max(0, context_start + context_length - 1); - // im2col - int src_mat_w = static_cast(x_dims[1]); - int src_mat_w_sz = src_mat_w * sizeof(T); - int col_mat_w = static_cast(w_dims[0]); - int col_mat_w_sz = col_mat_w * sizeof(T); - for (int i = 0; i < static_cast(x_lod[0].size()) - 1; ++i) { - int st = static_cast(x_lod[0][i]); - int ed = static_cast(x_lod[0][i + 1]); - const T* src_data = x_data + st * src_mat_w; - T* dst_data = col_data + st * col_mat_w; - int seq_len = ed - st; - if (seq_len > up_pad + down_pad) { - // zero all up_pad and fill data - std::memset(dst_data, 0, up_pad * col_mat_w_sz); - dst_data = dst_data + up_pad * src_mat_w; - int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz; - for (int j = 0; j < up_pad; ++j) { - // blas.VCOPY? - std::memcpy(dst_data, src_data, copy_size); - dst_data += (col_mat_w - src_mat_w); - copy_size += src_mat_w_sz; - } - // fill data - if (context_start > 0) { - src_data += context_start * src_mat_w; - } - for (int j = 0; j < seq_len - up_pad - down_pad; ++j) { - std::memcpy(dst_data, src_data, copy_size); - dst_data += col_mat_w; - src_data += src_mat_w; - } - // zero all down_pad and fill data - std::memset(dst_data, 0, down_pad * col_mat_w_sz); - copy_size -= src_mat_w_sz; - for (int j = 0; j < down_pad; ++j) { - if (copy_size < 0) { - copy_size = 0; - } - std::memcpy(dst_data, src_data, copy_size); - dst_data += col_mat_w; - src_data += src_mat_w; - copy_size -= src_mat_w_sz; - } - } else { - std::memset(dst_data, 0, seq_len * col_mat_w_sz); - dst_data = dst_data + up_pad * src_mat_w; - int zero_sz = up_pad * src_mat_w_sz; - int cur_src_sz = seq_len * src_mat_w_sz; - for (int j = 0; j < std::min(up_pad, seq_len); ++j) { - int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz); - std::memcpy(dst_data, src_data, copy_size); - dst_data += (col_mat_w - src_mat_w); - zero_sz -= src_mat_w_sz; - } - // from bottom - dst_data = col_data + ed * col_mat_w; - src_data = x_data + st * src_mat_w; - if (context_start > 0) { - src_data += context_start * src_mat_w; - } - zero_sz = down_pad * src_mat_w_sz; - for (int j = 1; j <= std::min(down_pad, seq_len); ++j) { - int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz); - if (copy_size < 0) { - copy_size = 0; - } - std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T), - src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w, - copy_size); - dst_data -= col_mat_w; - zero_sz -= src_mat_w_sz; - } - } - } - auto& dev_ctx = ctx.template device_context(); - phi::funcs::FCFunctor fc; - fc(dev_ctx, - x_dims[0], - w_dims[1], - w_dims[0], - col_data, - w_data, - y_data, - b_data, - true); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_seqconv_eltadd_relu, - ops::FusionSeqConvEltAddReluOp, - ops::FusionSeqConvEltAddReluOpMaker); - -PD_REGISTER_STRUCT_KERNEL(fusion_seqconv_eltadd_relu, - CPU, - ALL_LAYOUT, - ops::FusionSeqConvEltAddReluKernel, - float, - double) {} diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h deleted file mode 100644 index 42e0c57b1133aa..00000000000000 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class FusionSeqConvEltAddReluOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override; -}; - -class FusionSeqConvEltAddReluOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc deleted file mode 100644 index 03b5971b1482ae..00000000000000 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h" - -#include - -#include "paddle/phi/backends/cpu/cpu_info.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/cpu_vec.h" -#include "paddle/phi/kernels/funcs/fc_functor.h" - -namespace paddle { -namespace operators { - -void FusionSeqExpandConcatFCOp::InferShape( - framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), - 1UL, - platform::errors::InvalidArgument( - "Inputs(X) of FusionSeqExpandConcatFCOp should larger " - "than 1, but received value is: %d.", - ctx->Inputs("X").size())); - OP_INOUT_CHECK(ctx->HasInput("FCWeight"), - "Input", - "FCWeight", - "fusion_seqexpand_concat_fc"); - OP_INOUT_CHECK( - ctx->HasOutput("Out"), "Output", "Out", "fusion_seqexpand_concat_fc"); - OP_INOUT_CHECK( - ctx->HasOutput("FCOut"), "Output", "FCOut", "fusion_seqexpand_concat_fc"); - - auto ins_dims = ctx->GetInputsDim("X"); - auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D - PADDLE_ENFORCE_EQ( - w_dims.size(), - 2, - platform::errors::InvalidArgument( - "Input(FCWeight)'s rank must be 2, but received value is: %d.", - w_dims.size())); - const int D = static_cast(w_dims[1]); - int sum = static_cast(ins_dims[0][1]); - for (size_t i = 1; i < ins_dims.size(); ++i) { - sum += static_cast(ins_dims[i][1]); - } - PADDLE_ENFORCE_EQ( - sum, - w_dims[0], - platform::errors::InvalidArgument("FC height should be sum of all inputs " - "width, but received FC height is: %d, " - "sum of all inputs width is: %d.", - w_dims[0], - sum)); - if (ctx->HasInput("FCBias")) { - auto b_dims = ctx->GetInputDim("FCBias"); - PADDLE_ENFORCE_EQ( - b_dims.size() == 1 || b_dims.size() == 2, - true, - platform::errors::InvalidArgument( - "FCBias dim should be 1 or 2, but received value is: %d.", - b_dims.size())); - if (b_dims.size() == 1) { - PADDLE_ENFORCE_EQ(b_dims[0], - D, - platform::errors::InvalidArgument( - "FCBias shapes must be %d when FCBias dim = 1, but " - "received value is: %d.", - D, - b_dims[0])); - } else { - PADDLE_ENFORCE_EQ(b_dims[0], - 1, - platform::errors::InvalidArgument( - "FCBias shapes must be 1x%d, when FCBias dim = 2, " - "but received dim[0] is: %d.", - D, - b_dims[0])); - PADDLE_ENFORCE_EQ(b_dims[1], - D, - platform::errors::InvalidArgument( - "FCBias shapes must be 1x%d, when FCBias dim = 2, " - "but received dim[1] is: %d.", - D, - b_dims[1])); - } - } - - ctx->SetOutputDim("Out", {ins_dims[0][0], D}); - // fcout should be reshape when run since can not get lod in infershape - // explicit share the ref lod - ctx->ShareLoD("X", "Out", 0); -} - -phi::KernelKey FusionSeqExpandConcatFCOp::GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace()); -} - -void FusionSeqExpandConcatFCOpMaker::Make() { - AddInput("X", - "(phi::DenseTensor) input LodDTensors, the first one must be have " - "ref lod " - "for sequence expand, and the rest input should have same lod.") - .AsDuplicable(); - AddInput("FCWeight", "(phi::DenseTensor) the weights of fc."); - AddInput("FCBias", "(phi::DenseTensor, optional) the bias of fc.") - .AsDispensable(); - AddOutput("Out", "(phi::DenseTensor) Output LodTensor."); - AddOutput( - "FCOut", - "(phi::DenseTensor) the intermediate tensor to keep the result of fc." - "Shape is (N x D), where N is the batch size, D is the output dim of fc") - .AsIntermediate(); - AddAttr("fc_activation", - "(string, default: identity)" - "The activation for the result of fc." - "`identity` by default.") - .SetDefault("identity") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddComment(R"DOC( -Fusion Sequence expand + concat + fc Operator. - -All below conditions should be meet: - -The ref_level of seq_expand should be 0. - -The ref lod of seq_expand level is the first input of concat. - -The other inputs should have same lod and same batch size of ref lod. - -The seq len of other inputs should be 1. - -The concat axis should be 1. - -)DOC"); -} - -template -class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto ins = ctx.MultiInput("X"); - auto* w = ctx.Input("FCWeight"); - auto* b = ctx.Input("FCBias"); - auto* out = ctx.Output("Out"); - auto* fc_out = ctx.Output("FCOut"); - - auto* ref_in = ins[0]; - auto ref_lod = ref_in->lod(); - auto in1_lod = ins[1]->lod(); - auto ref_dims = ref_in->dims(); // T x M0 - auto in1_dims = ins[1]->dims(); // N x M1 - auto w_dims = w->dims(); - const int N = static_cast(ref_lod[0].size() - 1); - const int total_T = static_cast(ref_dims[0]); - const int M0 = static_cast(ref_dims[1]); - const int M1 = static_cast(in1_dims[1]); - const int D = static_cast(w_dims[1]); - - // some check and fcout should be reshape here - // since infershape can not get lod info - PADDLE_ENFORCE_EQ( - ref_lod.size(), - 1UL, - platform::errors::InvalidArgument( - "Only support input lod size is 1, but received value is: %d.", - ref_lod.size())); - PADDLE_ENFORCE_EQ( - in1_lod.size(), - 1UL, - platform::errors::InvalidArgument( - "Only support input lod size is 1, but received value is: %d.", - in1_lod.size())); - PADDLE_ENFORCE_EQ(static_cast(in1_lod[0].size() - 1), - N, - platform::errors::InvalidArgument( - "Batch size of all inputs should be equal to %d, but " - "received value is: %d.", - N, - static_cast(in1_lod[0].size() - 1))); - PADDLE_ENFORCE_EQ( - static_cast(in1_lod[0][N]), - N, - platform::errors::InvalidArgument("Seq_length of other inputs should " - "be %d, but received value is: %d.", - N, - static_cast(in1_lod[0][N]))); - PADDLE_ENFORCE_EQ( - in1_dims[0], - N, - platform::errors::InvalidArgument( - "input height should be batch size: %d, but received value is %d.", - N, - in1_dims[0])); - for (size_t i = 2; i < ins.size(); ++i) { - PADDLE_ENFORCE_EQ(ins[i]->dims()[0], - N, - platform::errors::InvalidArgument( - "All other inputs height should be equal to %d, " - "but received value is: %d.", - N, - ins[i]->dims()[0])); - PADDLE_ENFORCE_EQ(ins[i]->lod(), - in1_lod, - platform::errors::InvalidArgument( - "All other inputs should have same lod: %d, but " - "received value is: %d.", - in1_lod, - ins[i]->lod())); - } - fc_out->Resize({N, D}); - - std::function fc_act; - auto& fc_act_str = ctx.Attr("fc_activation"); - if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { - phi::funcs::VecActivations act_functor; - fc_act = act_functor(fc_act_str); - } else { - phi::funcs::VecActivations act_functor; - fc_act = act_functor(fc_act_str); - } - - const T* ref_in_data = ref_in->data(); - const T* in1_data = ins[1]->data(); - const T* w_data = w->data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); - - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - - phi::funcs::FCFunctor fc; - fc(dev_ctx, - total_T, - D, - M0, - ref_in_data, - w_data, - out_data, - b ? b->data() : NULL); - w_data = w_data + M0 * D; - // first write on - blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data); - w_data = w_data + M1 * D; - for (size_t i = 2; i < ins.size(); ++i) { - // add on - const T* in_data = ins[i]->data(); - const int K = static_cast(ins[i]->dims()[1]); - blas.GEMM(CblasNoTrans, - CblasNoTrans, - N, - D, - K, - static_cast(1), - in_data, - K, - w_data, - D, - static_cast(1), - fc_out_data, - D); - w_data = w_data + K * D; - } - T* cur_out_data = out_data; - for (int i = 0; i < N; ++i) { - int seq_len = static_cast(ref_lod[0][i + 1] - ref_lod[0][i]); - T* src = fc_out_data + i * D; - for (int step = 0; step < seq_len; ++step) { - blas.VADD(D, cur_out_data, src, cur_out_data); - cur_out_data = cur_out_data + D; - } - } - fc_act(total_T * D, out_data, out_data); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_seqexpand_concat_fc, - ops::FusionSeqExpandConcatFCOp, - ops::FusionSeqExpandConcatFCOpMaker); - -PD_REGISTER_STRUCT_KERNEL(fusion_seqexpand_concat_fc, - CPU, - ALL_LAYOUT, - ops::FusionSeqExpandConcatFCOpKernel, - float, - double) {} diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h deleted file mode 100644 index 7438b6c7174873..00000000000000 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override; -}; - -class FusionSeqExpandConcatFCOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - void Make() override; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc deleted file mode 100644 index 5ec5e8081bb6f1..00000000000000 --- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc +++ /dev/null @@ -1,387 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/operators/fused/fusion_gru_op.h" -#include "paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h" -#include "paddle/phi/backends/onednn/onednn_reuse.h" -#include "paddle/phi/core/expect.h" - -namespace paddle { -namespace operators { - -using phi::OneDNNContext; -using phi::funcs::OneDNNGetDataType; -using phi::funcs::OneDNNMemDesc; -using phi::funcs::RNNReorderType; -using OneDNNMemoryFormat = dnnl::memory::format_tag; - -template -class GRUMKLDNNHandler : public RNNMKLDNNHandler { - public: - GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, - const OneDNNContext& dev_ctx, - const dnnl::engine onednn_engine, - platform::Place cpu_place UNUSED, - const phi::DenseTensor* input, - const phi::DenseTensor* weight_h, - const phi::DenseTensor* h0, - const bool is_reverse, - const int64_t N, - const int64_t Ti, - const int64_t IC, - const int64_t OC, - const std::string& unique_name UNUSED) - : RNNMKLDNNHandler( - ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - input, - weight_h, - h0, - is_reverse, - N, - Ti, - IC, - OC, - 3, - ctx.InputName("X") + ctx.InputName("WeightH")) { - const bool is_INT8 = std::is_same::value; - - if (unlikely(!this->isCached())) { - // oneDNN kernel has hardcoded activation functions - PADDLE_ENFORCE_EQ( - ctx.Attr("gate_activation"), - "sigmoid", - platform::errors::Unimplemented( - "oneDNN fusion_gru supports only sigmoid as a gate activation.")); - PADDLE_ENFORCE_EQ( - ctx.Attr("activation"), - "tanh", - platform::errors::Unimplemented( - "oneDNN fusion_gru supports only tanh as an activation.")); - - // Weights for int8 kernel are of a type s8 - const auto weights_dt = - is_INT8 ? dnnl::memory::data_type::s8 : OneDNNGetDataType(); - - // oneDNN RNN dimensions - const int64_t D = 1; // Directions - const int64_t L = 1; // Layers (PP supports only 1 stacked layer) - const int64_t G = 3; // Number of Gates, 3 for GRU - - // Create memory descriptors - auto input_md = OneDNNMemDesc( - {Ti, N, IC}, OneDNNGetDataType(), OneDNNMemoryFormat::ntc); - auto weight_x_md = - OneDNNMemDesc({L, D, IC, G, OC}, weights_dt, OneDNNMemoryFormat::any); - auto weight_h_md = - OneDNNMemDesc({L, D, OC, G, OC}, weights_dt, OneDNNMemoryFormat::any); - auto bias_md = OneDNNMemDesc( - {L, D, G, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldgo); - auto hidden_md = OneDNNMemDesc( - {Ti, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ntc); - auto h0_md = OneDNNMemDesc( - {L, D, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldnc); - - // Create GRU oneDNN primitive - const auto direction = - is_reverse ? dnnl::rnn_direction::unidirectional_right2left - : dnnl::rnn_direction::unidirectional_left2right; - - this->AcquireForwardPrimitiveDescriptor( - this->attr_, - dnnl::prop_kind::forward_inference, - direction, - input_md, - h0_md, - weight_x_md, - weight_h_md, - bias_md, - hidden_md, - dnnl::memory::desc()); - } - } - - template - std::shared_ptr AcquireWeightXMemory( - const phi::DenseTensor* weight_x, const bool origin_mode) { - const std::string wx_key = this->memory_key_ + "@weight_x"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); - - if (!memory_p) { - auto user_md = OneDNNMemDesc({1, 1, this->IC, this->G, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldigo); - auto user_memory = dnnl::memory(user_md, this->engine_); - - auto* weight_x_data = reinterpret_cast(user_memory.get_data_handle()); - memcpy(weight_x_data, - weight_x->data(), - sizeof(U) * this->IC * this->G * this->OC); - - if (origin_mode == false) { - for (int64_t i = 0; i < this->IC; ++i) { - for (int64_t j = 0; j < this->OC; ++j) { - U minus_one(-1.0f); - weight_x_data[j] = minus_one * weight_x_data[j]; - } - weight_x_data += 3 * this->OC; - } - } - - memory_p = std::make_shared( - this->fwd_pd_->weights_layer_desc(), this->engine_); - - auto& astream = OneDNNContext::tls().get_stream(); - dnnl::reorder(user_memory, *memory_p, this->attr_) - .execute(astream, user_memory, *memory_p); - - this->dev_ctx_.SetBlob(wx_key, memory_p); - } - return memory_p; - } - - template - std::shared_ptr AcquireWeightHMemory( - const phi::DenseTensor* weight_h, const bool origin_mode) { - const std::string wh_key = this->memory_key_ + "@weight_h"; - auto memory_p = - std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); - - if (!memory_p) { - auto user_md = OneDNNMemDesc({1, 1, this->OC, this->G, this->OC}, - OneDNNGetDataType(), - OneDNNMemoryFormat::ldigo); - auto user_memory = dnnl::memory(user_md, this->engine_); - - // Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to - // oneDNN format [OC, 3OC] - auto* weight_h_data = reinterpret_cast(user_memory.get_data_handle()); - auto* user_weight_h_data = weight_h->data(); - - auto src1_iter = user_weight_h_data; - auto src2_iter = user_weight_h_data + 2 * this->OC * this->OC; - - for (int64_t c = 0; c < this->OC; ++c) { - memcpy(weight_h_data, src1_iter, 2 * this->OC * sizeof(U)); - memcpy(weight_h_data + 2 * this->OC, src2_iter, this->OC * sizeof(U)); - - src1_iter += 2 * this->OC; - src2_iter += this->OC; - weight_h_data += 3 * this->OC; - } - - weight_h_data = reinterpret_cast(user_memory.get_data_handle()); - - if (origin_mode == false) { - for (int64_t i = 0; i < this->OC; ++i) { - for (int64_t j = 0; j < this->OC; ++j) { - U minus_one(-1.0f); - weight_h_data[j] = minus_one * weight_h_data[j]; - } - weight_h_data += 3 * this->OC; - } - } - - memory_p = std::make_shared( - this->fwd_pd_->weights_iter_desc(), this->engine_); - - auto& astream = OneDNNContext::tls().get_stream(); - dnnl::reorder(user_memory, *memory_p, this->attr_) - .execute(astream, user_memory, *memory_p); - - this->dev_ctx_.SetBlob(wh_key, memory_p); - } - return memory_p; - } - - std::shared_ptr AcquireBiasMemory(const phi::DenseTensor* bias, - const bool origin_mode) { - const std::string bias_key = this->memory_key_ + "@bias"; - auto memory_p = std::static_pointer_cast( - this->dev_ctx_.GetBlob(bias_key)); - - if (!memory_p) { - memory_p = std::make_shared(this->fwd_pd_->bias_desc(), - this->engine_); - auto* bias_data = reinterpret_cast(memory_p->get_data_handle()); - if (bias) { - const float* user_bias_data = - bias->data(); // Bias in oneDNN is always float - memcpy(bias_data, user_bias_data, sizeof(float) * this->G * this->OC); - } else { - // oneDNN always need bias memory, if it's not provided in PP, let - // oneDNN allocate memory and set it to 0 - memset(bias_data, 0, sizeof(float) * this->G * this->OC); - } - - if (origin_mode == false && bias) { - for (int64_t i = 0; i < this->OC; ++i) { - bias_data[i] *= -1; - } - } - this->dev_ctx_.SetBlob(bias_key, memory_p); - } - return memory_p; - } -}; - -template -class FusionGRUMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const bool is_bf16 = std::is_same::value; - const bool force_fp32_output = ctx.Attr("force_fp32_output"); - - // BF16 does not support force output - if (!is_bf16 && force_fp32_output) { // NOLINT - RunKernel(ctx); - } else { - RunKernel(ctx); - } - } - - template - void RunKernel(const framework::ExecutionContext& ctx) const { - auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - // Get Tensors - const auto* input = ctx.Input("X"); - const auto* h0 = ctx.Input("H0"); - const auto* weight_x = ctx.Input("WeightX"); - const auto* weight_h = ctx.Input("WeightH"); - const auto* bias = ctx.Input("Bias"); - auto* hidden = ctx.Output("Hidden"); - auto x_dims = input->dims(); - auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) - ? phi::flatten_to_2d(x_dims, 1) - : x_dims; - // Get attributes - const bool is_reverse = ctx.Attr("is_reverse"); - const bool origin_mode = ctx.Attr("origin_mode"); - - // Get tensor dimensions - const auto x_mat_dims_vec = phi::vectorize(x_mat_dims); - const auto weight_h_dims = phi::vectorize(weight_h->dims()); - const auto& input_lod = input->lod()[0]; - - // Calculate RNN dimensions - const int64_t N = input_lod.size() - 1; // Number of sentences (batches) - const int64_t Ti = // Max length of the sentence in a batch - [&input_lod]() { - size_t res = 0; - for (size_t i = 0; i < (input_lod.size() - 1); ++i) { - res = std::max(res, input_lod[i + 1] - input_lod[i]); - } - return res; - }(); - const int64_t IC = x_mat_dims_vec[1]; // Input channels - const int64_t OC = weight_h_dims[0]; // Output channels - - GRUMKLDNNHandler handler( - ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - input, - weight_h, - h0, - is_reverse, - N, - Ti, - IC, - OC, - ctx.InputName("X") + ctx.InputName("WeightH")); - - auto input_memory_p = - handler.AcquireInputMemoryWithReorder(input, is_reverse); - - std::shared_ptr h0_memory_p, weight_h_memory_p, - weight_x_memory_p; - - if (framework::TransToProtoVarType(weight_h->dtype()) == - paddle::framework::proto::VarType_Type_FP32) { - h0_memory_p = handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory(weight_x, origin_mode); - weight_h_memory_p = - handler.template AcquireWeightHMemory(weight_h, origin_mode); - } else if (framework::TransToProtoVarType(weight_h->dtype()) == - paddle::framework::proto::VarType_Type_BF16) { - h0_memory_p = - handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory( - weight_x, origin_mode); - weight_h_memory_p = - handler.template AcquireWeightHMemory( - weight_h, origin_mode); - } else { - h0_memory_p = handler.template AcquireH0Memory(h0); - weight_x_memory_p = - handler.template AcquireWeightXMemory(weight_x, origin_mode); - weight_h_memory_p = - handler.template AcquireWeightHMemory(weight_h, origin_mode); - } - - auto bias_memory_p = handler.AcquireBiasMemory(bias, origin_mode); - auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); - - std::unordered_map gru_args = { - {DNNL_ARG_SRC_LAYER, *input_memory_p}, - {DNNL_ARG_SRC_ITER, *h0_memory_p}, - {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, - {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, - {DNNL_ARG_BIAS, *bias_memory_p}, - {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; - - auto gru_forward_p = handler.AcquireForwardPrimitive(); - - auto& astream = OneDNNContext::tls().get_stream(); - gru_forward_p->execute(astream, gru_args); - astream.wait(); - - auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle(); - auto* hidden_data = - phi::funcs::to_void_cast(hidden->mutable_data(ctx.GetPlace())); - if (handler.is_NTC()) { - handler.reorderRNNdata(hidden_onednn_data, - hidden_data, - input_lod, - is_reverse, - RNNReorderType::NTC_PP); - } else { - handler.reorderRNNdata(hidden_onednn_data, - hidden_data, - input_lod, - is_reverse, - RNNReorderType::TNC_PP); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(fusion_gru, - MKLDNN, - phi::CPUPlace, - ops::FusionGRUMKLDNNKernel, - ops::FusionGRUMKLDNNKernel, - ops::FusionGRUMKLDNNKernel); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index a672f5ac99aa8f..6b0a36fc564721 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -22,6 +22,7 @@ #include "paddle/phi/core/cuda_stream.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/utils/data_type.h" @@ -29,8 +30,6 @@ #include "paddle/phi/kernels/funcs/tensor_to_string.h" #include "paddle/utils/optional.h" -#include "paddle/fluid/distributed/collective/utils.h" - #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/flags.h" @@ -2404,9 +2403,9 @@ void DistributedFusedLambKernel( if (num_devices > 1) { // ncclAllGather if (local_comm_ctx) { - auto send_buf = paddle::distributed::GetPartialTensor( + auto send_buf = distributed::GetPartialTensor( *fp32_param_out, fp32_offset, fp32_numel_each_device); - auto recv_buf = paddle::distributed::GetPartialTensor( + auto recv_buf = distributed::GetPartialTensor( *fp32_param_out, 0, fp32_numel_each_device); local_comm_ctx->AllGather(&recv_buf, send_buf, stream); } else { @@ -2442,9 +2441,9 @@ void DistributedFusedLambKernel( if (num_devices > 1) { // ncclAllGather if (local_comm_ctx) { - auto send_buf = paddle::distributed::GetPartialTensor( + auto send_buf = distributed::GetPartialTensor( *fp16_param_out, fp16_offset, fp16_numel_each_device); - auto recv_buf = paddle::distributed::GetPartialTensor( + auto recv_buf = distributed::GetPartialTensor( *fp16_param_out, 0, fp16_numel_each_device); local_comm_ctx->AllGather(&recv_buf, send_buf, stream); } else { diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 4777c5ab4971d7..86cae425e9d14e 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -358,12 +358,10 @@ def GenBuildOutputs( """ CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name}; if ({name}_.dyn_cast().owner()->isa()) {{ - {name} = std::move(phi::IntArray({name}_.dyn_cast().owner() + {name} = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + {name}_.dyn_cast().owner() ->dyn_cast() - .attribute("value") - .dyn_cast() - .data() - .GetData())); + .attribute("value")))); }} else if ({name}_.type().isa()) {{ size_t {name}_size = {name}_.type().dyn_cast().size(); {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); @@ -378,12 +376,10 @@ def GenBuildOutputs( CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector {name}; if ({name}_.dyn_cast().owner()->isa()) {{ - {name} = {name}_.dyn_cast().owner() + {name} = paddle::dialect::GetInt64Vector( + {name}_.dyn_cast().owner() ->dyn_cast() - .attribute("value") - .dyn_cast() - .data() - .GetData(); + .attribute("value")); }} else if ({name}_.type().isa()) {{ size_t {name}_size = {name}_.type().dyn_cast().size(); {name} = std::vector({name}_size, -1); @@ -696,41 +692,36 @@ def gen_build_func_str( ) GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data(); """ GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().AsString(); """ GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name}; for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast().at(i).dyn_cast<{inner_type}>().{data_name}()); }} """ GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().GetData(); """ GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - PADDLE_ENFORCE( + IR_ENFORCE( attributes.find("{attribute_name}") != attributes.end(), - phi::errors::NotFound( - "'{attribute_name}' Attribute is expected for {op_name}. ")); + "'{attribute_name}' Attribute is expected for {op_name}. "); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().to<{attr_type}>(); """ diff --git a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py index c760d7fb85b84e..50cfa79cdd443c 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py @@ -20,13 +20,13 @@ CPP_FILE_TEMPLATE = """ #include "paddle/fluid/pir/drr/ir_operation_factory.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +{op_header} #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" namespace pir {{ namespace drr {{ -void OperationFactory::RegisterGeneratedOpCreator() {{ +void OperationFactory::Register{dialect}GeneratedOpCreator() {{ {body} }} @@ -41,7 +41,7 @@ [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) {{ - return rewriter.Build( + return rewriter.Build<{namespace}::{op_class_name}>( {params_code}); }}); """ @@ -63,6 +63,12 @@ }}); """ +Dialect2NameSpaceMap = {"pd_op": "paddle::dialect", "cinn_op": "cinn::dialect"} +Dialect2OpHeaderMap = { + "pd_op": "#include \"paddle/fluid/pir/dialect/operator/ir/pd_op.h\"", + "cinn_op": "#include \"paddle/cinn/hlir/dialect/operator/ir/cinn_op.h\"", +} + class OpCreatorCodeGen: def __init__(self, op_yaml_files, op_compat_yaml_file, dialect_name): @@ -107,6 +113,7 @@ def gen_cpp_file_code(self, cpp_file_path): if len(op_info_item.mutable_attribute_name_list) == 0: body_code += NORMAL_FUNCTION_TEMPLATE.format( op_name=ir_op_name, + namespace=Dialect2NameSpaceMap[self.dialect_name], op_class_name=(to_pascal_case(phi_op_name) + "Op"), params_code=", ".join(params_no_mutable_attr), ) @@ -139,7 +146,13 @@ def gen_cpp_file_code(self, cpp_file_path): ) with open(cpp_file_path, 'w') as f: - f.write(CPP_FILE_TEMPLATE.format(body=body_code)) + f.write( + CPP_FILE_TEMPLATE.format( + dialect=to_pascal_case(self.dialect_name), + op_header=Dialect2OpHeaderMap[self.dialect_name], + body=body_code, + ) + ) def ParseArguments(): diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 8983ffa38b5629..185a874615d393 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1038,8 +1038,6 @@ def OpGenerator( and op_info.op_phi_name[0] not in vjp_interface_black_list ): op_interfaces += ["paddle::dialect::VjpInterface"] - if op_info.op_phi_name[0] in decomp_interface_declare_gen_op_list: - op_interfaces += ["paddle::dialect::DecompInterface"] exclusive_interface_str = gen_exclusive_interface_str( op_info, op_info_items ) @@ -1056,6 +1054,9 @@ def OpGenerator( # If op has inplace info, we will generate inplace op and non-inplace op. for op_name in op_info.op_phi_name: + if op_name in decomp_interface_declare_gen_op_list: + op_interfaces += ["paddle::dialect::DecompInterface"] + exclusive_interface_str += "\n static std::vector> Decomp(pir::Operation* op);" if op_name in PD_MANUAL_OP_LIST: continue if op_kernel_map is None: diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 299d4197b79475..6d7c5224e3803e 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list - # generator interfaces from vjp_interface_black_list import vjp_interface_black_list @@ -23,65 +21,44 @@ fn(infer_meta); }} """ +CHECK_INPUT_TEMPLATE = """ + PADDLE_ENFORCE_EQ( + inputs_.size(), + {inputs_size}, + platform::errors::InvalidArgument("{op_name} op's inputs size should be {inputs_size}, but now is %d.", inputs_.size())); + PADDLE_ENFORCE_EQ( + outputs.size(), + {outputs_size}, + platform::errors::InvalidArgument("{op_name} op's outputs size should be {outputs_size}, but now is %d.", outputs.size())); +""" OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """ - {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" + {input_type} {input_name}(std::make_shared({vjp_param_name}[{input_idx}][0]));""" OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """ - pir::CombineOp combine_op_obj_{input_name} = - op_obj.{input_name}().dyn_cast().owner()->dyn_cast(); std::vector {input_name}; - for (size_t idx = 0; idx < combine_op_obj_{input_name}.inputs().size(); idx++) {{ + for (size_t idx = 0; idx < {vjp_param_name}[{input_idx}].size(); idx++) {{ {input_name}.emplace_back( - std::make_shared(combine_op_obj_{input_name}.inputs()[idx])); + std::make_shared({vjp_param_name}[{input_idx}][idx])); }}""" OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE = """ paddle::optional {input_name}; - if (!IsEmptyValue(op_obj.{input_name}())){{ - {input_name} = paddle::make_optional(Tensor(std::make_shared(op_obj.{input_name}()))); + if (!IsEmptyValue({vjp_param_name}[{input_idx}][0])){{ + {input_name} = paddle::make_optional(Tensor(std::make_shared({vjp_param_name}[{input_idx}][0]))); }}""" OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE = """ paddle::optional> {input_name}; - if (!IsEmptyValue(op_obj.{input_name}())){{ - pir::CombineOp combine_op_obj = - op_obj.{input_name}().dyn_cast().owner()->dyn_cast(); - std::vector optional_{input_name}; - for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{ + std::vector optional_{input_name}; + if (!IsEmptyValue({vjp_param_name}[{input_idx}][0])){{ + for (size_t idx = 0; idx < {vjp_param_name}[{input_idx}].size(); idx++) {{ optional_{input_name}.emplace_back( - std::make_shared(combine_op_obj.inputs()[idx])); + std::make_shared({vjp_param_name}[{input_idx}][idx])); }} {input_name} = paddle::make_optional>(optional_{input_name}); }}""" -OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """ - Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" - -OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ - std::vector {output_grad_name}; - for (size_t idx = 0; idx < out_grads[{index}].size(); idx++) {{ - {output_grad_name}.emplace_back( - std::make_shared(out_grads[{index}][idx])); - }}""" - -OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE = """ - paddle::optional {output_grad_name}; - if (!IsEmptyValue(out_grads[{idx1}][{idx2}])){{ - {output_grad_name} = paddle::make_optional(Tensor(std::make_shared(out_grads[{idx1}][{idx2}]))); - }}""" - -OP_VJP_FORWARD_OPTIONAL_VECTOR_OUTPUT_GRAD_TEMPLATE = """ - paddle::optional> {output_grad_name}; - std::vector optional_{output_grad_name}; - if (!IsEmptyValue(out_grads[{index}])){{ - for (size_t idx = 0; idx < out_grads[{index}].size(); idx++) {{ - optional_{output_grad_name}.emplace_back( - std::make_shared(out_grads[{index}][idx])); - }} - {output_grad_name} = paddle::make_optional>(optional_{output_grad_name}); - }}""" - OP_VJP_ATTRIBUTE_TEMPLATE = """ {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();""" @@ -111,12 +88,10 @@ }""" OP_VJP_DEFINE_TEMPLATE = """ -std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients){{ - {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); (void)op_obj; - +std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients){{ +{check_param} VLOG(6) << "Prepare inputs of {op_grad_name}"; -{forward_input_output_code} -{forward_output_grad_code} +{backward_input_code} VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; {attribute_code} @@ -144,74 +119,66 @@ def gen_op_vjp_str( op_grad_info, ): bw_input_list = op_grad_info.input_name_list - forward_input_output_code = '' - forward_output_grad_code = '' + fwd_input_and_mutable_attr_name_list = ( + op_info.input_name_list + op_info.mutable_attribute_name_list + ) + + backward_input_code = '' build_args_str = '' grad_idx = -1 for idx in range(len(bw_input_list)): - build_args_str += bw_input_list[idx] + ", " + bw_input_name = bw_input_list[idx] + build_args_str += bw_input_name + ", " input_type = input_types_map[op_grad_info.input_type_list[idx]] - if ( - bw_input_list[idx] in op_info.input_name_list - or bw_input_list[idx] in op_info.output_name_list - ): - if op_grad_info.input_optional_list[idx] == 'true': - if input_type == 'Tensor': - forward_input_output_code += ( - OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], - ) - ) - else: - forward_input_output_code += ( - OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], - ) + + vjp_param_name = '' + index_0 = -1 + if bw_input_name in fwd_input_and_mutable_attr_name_list: + vjp_param_name = 'inputs_' + index_0 = fwd_input_and_mutable_attr_name_list.index(bw_input_name) + elif bw_input_name in op_info.output_name_list: + vjp_param_name = 'outputs' + index_0 = op_info.output_name_list.index(bw_input_name) + else: + vjp_param_name = 'out_grads' + grad_idx += 1 + index_0 = grad_idx + if op_grad_info.input_optional_list[idx] == 'true': + if input_type == 'Tensor': + backward_input_code += ( + OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format( + vjp_param_name=vjp_param_name, + input_name=bw_input_name, + input_idx=index_0, ) + ) else: - if input_type == 'Tensor': - forward_input_output_code += ( - OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( - input_type=input_type, - input_name=bw_input_list[idx], - ) - ) - else: - forward_input_output_code += ( - OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], - ) + backward_input_code += ( + OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format( + vjp_param_name=vjp_param_name, + input_name=bw_input_name, + input_idx=index_0, ) + ) else: - grad_idx += 1 - if op_grad_info.input_optional_list[idx] == 'true': - if input_type == 'Tensor': - forward_input_output_code += ( - OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE.format( - output_grad_name=bw_input_list[idx], - idx1=grad_idx, - idx2=0, - ) - ) - else: - forward_input_output_code += OP_VJP_FORWARD_OPTIONAL_VECTOR_OUTPUT_GRAD_TEMPLATE.format( - output_grad_name=bw_input_list[idx], index=grad_idx + if input_type == 'Tensor': + backward_input_code += ( + OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( + vjp_param_name=vjp_param_name, + input_type=input_type, + input_name=bw_input_name, + input_idx=index_0, ) + ) else: - if input_type == 'Tensor': - forward_output_grad_code += ( - OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format( - output_grad_name=bw_input_list[idx], - idx1=grad_idx, - idx2=0, - ) - ) - else: - forward_input_output_code += ( - OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE.format( - output_grad_name=bw_input_list[idx], index=grad_idx - ) + backward_input_code += ( + OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format( + vjp_param_name=vjp_param_name, + input_name=bw_input_name, + input_idx=index_0, ) + ) + op_attribute_list = op_grad_info.attribute_name_list attribute_code = '' build_attr_str = '' @@ -221,8 +188,12 @@ def gen_op_vjp_str( if op_attribute_list[idx] in op_info.mutable_attribute_name_list: attribute_code += ( OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( + vjp_param_name='inputs_', input_type="Tensor", input_name=op_attribute_list[idx], + input_idx=fwd_input_and_mutable_attr_name_list.index( + op_attribute_list[idx] + ), ) ) build_args_str += op_attribute_list[idx] + ", " @@ -272,14 +243,19 @@ def gen_op_vjp_str( inputs_list=build_args_str, ) stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE + check_param = CHECK_INPUT_TEMPLATE.format( + op_name=op_phi_name_format, + inputs_size=len(fwd_input_and_mutable_attr_name_list), + outputs_size=len(op_info.output_name_list), + out_grads_size=grad_idx + 1, + ) str = OP_VJP_DEFINE_TEMPLATE.format( + check_param=check_param, op_class_name=op_class_name, op_grad_name=op_grad_name, op_phi_name=op_phi_name, - res_size=len(op_info.input_name_list), - forward_input_output_code=forward_input_output_code, - forward_output_grad_code=forward_output_grad_code, + backward_input_code=backward_input_code, attribute_code=attribute_code, call_vjp_code=call_vjp_code, stop_gradient_input_grad_code=stop_gradient_input_grad_code, @@ -317,7 +293,5 @@ def gen_exclusive_interface_str(op_info, op_info_items): " static void InferMeta( phi::InferMetaContext *infer_meta );" ) if op_info.op_phi_name[0] not in vjp_interface_black_list: - exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" - if op_info.op_phi_name[0] in decomp_interface_declare_gen_op_list: - exclusive_interface_str += "\n static std::vector> Decomp(pir::Operation* op);" + exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py index 3a2515f278915a..f42a73347d13ad 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py @@ -19,8 +19,8 @@ VLOG(4) << "Verifying inputs:"; {{ auto input_size = num_operands(); - PADDLE_ENFORCE_EQ(input_size, {inputs_size}u, - phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", input_size));{inputs_type_check} + IR_ENFORCE(input_size == {inputs_size}u, + "The size %d of inputs must be equal to {inputs_size}.", input_size);{inputs_type_check} }} VLOG(4) << "Verifying attributes:"; {{{attributes_check} @@ -28,8 +28,8 @@ VLOG(4) << "Verifying outputs:"; {{ auto output_size = num_results(); - PADDLE_ENFORCE_EQ(output_size, {outputs_size}u, - phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", output_size));{outputs_type_check} + IR_ENFORCE(output_size == {outputs_size}u, + "The size %d of outputs must be equal to {outputs_size}.", output_size);{outputs_type_check} }} VLOG(4) << "End Verifying for: {op_name}."; }} @@ -40,83 +40,83 @@ """ INPUT_TYPE_CHECK_TEMPLATE = """ - PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));""" + IR_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), + "Type validation failed for the {index}th input.");""" INPUT_VECTORTYPE_CHECK_TEMPLATE = """ if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ - PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE(vec_type[i].isa<{standard}>(), + "Type validation failed for the {index}th input."); }} }} else {{ - PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), + "Type validation failed for the {index}th input."); }}""" INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ if (auto val = (*this)->operand({index})) {{ - PADDLE_ENFORCE(val.type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE(val.type().isa<{standard}>(), + "Type validation failed for the {index}th input."); }}""" INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ if (auto val = (*this)->operand({index})) {{ if (auto vec_type = val.type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ - PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE(vec_type[i].isa<{standard}>(), + "Type validation failed for the {index}th input."); }} }} else {{ - PADDLE_ENFORCE(val.type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + IR_ENFORCE(val.type().isa<{standard}>(), + "Type validation failed for the {index}th input."); }} }}""" ATTRIBUTE_CHECK_TEMPLATE = """ - PADDLE_ENFORCE(attributes.count("{attribute_name}")>0, - phi::errors::PreconditionNotMet("{attribute_name} does not exist.")); - PADDLE_ENFORCE(attributes.at("{attribute_name}").isa<{standard}>(), - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not {standard}.")); + IR_ENFORCE(attributes.count("{attribute_name}")>0, + "{attribute_name} does not exist."); + IR_ENFORCE(attributes.at("{attribute_name}").isa<{standard}>(), + "Type of attribute: {attribute_name} is not {standard}."); """ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """ - PADDLE_ENFORCE(attributes.count("{attribute_name}")>0, - phi::errors::PreconditionNotMet("{attribute_name} does not exist.")); - PADDLE_ENFORCE(attributes.at("{attribute_name}").isa(), - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not pir::ArrayAttribute.")); + IR_ENFORCE(attributes.count("{attribute_name}")>0, + "{attribute_name} does not exist."); + IR_ENFORCE(attributes.at("{attribute_name}").isa(), + "Type of attribute: {attribute_name} is not pir::ArrayAttribute."); for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ - PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast().at(i).isa<{standard}>(), - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); + IR_ENFORCE(attributes.at("{attribute_name}").dyn_cast().at(i).isa<{standard}>(), + "Type of attribute: {attribute_name} is not right."); }}""" OUTPUT_TYPE_CHECK_TEMPLATE = """ - PADDLE_ENFORCE((*this)->result({index}).type().isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));""" + IR_ENFORCE((*this)->result({index}).type().isa<{standard}>(), + "Type validation failed for the {index}th output.");""" OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """ auto output_{index}_type = (*this)->result({index}).type(); if (auto vec_type = output_{index}_type.dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ - PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(vec_type[i].isa<{standard}>(), + "Type validation failed for the {index}th output."); }} }} else {{ - PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(output_{index}_type.isa<{standard}>(), + "Type validation failed for the {index}th output."); }}""" OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ if (auto output_{index}_type = (*this)->result({index}).type()) {{ - PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(output_{index}_type.isa<{standard}>(), + "Type validation failed for the {index}th output."); }}""" OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ if (auto output_{index}_type = (*this)->result({index}).type()) {{ if (auto vec_type = output_{index}_type.dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ - PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(vec_type[i].isa<{standard}>(), + "Type validation failed for the {index}th output."); }} }} else {{ - PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(), - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + IR_ENFORCE(output_{index}_type.isa<{standard}>(), + "Type validation failed for the {index}th output."); }} }}""" diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index e2d17e7f118023..66dc7bbbdf323f 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -68,7 +68,23 @@ OPS_API_TEMPLATE = """ {{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},""" -NEED_GEN_STATIC_ONLY_APIS = ['fetch'] +NEED_GEN_STATIC_ONLY_APIS = [ + 'fetch', + 'fused_embedding_eltwise_layernorm', + 'fused_fc_elementwise_layernorm', + 'fused_multi_transformer_xpu', + 'fused_scale_bias_relu_conv_bnstats', + 'fusion_transpose_flatten_concat', + 'generate_sequence_xpu', + 'layer_norm_act_xpu', + 'multi_encoder_xpu', + 'multihead_matmul', + 'squeeze_excitation_block', + 'yolo_box_xpu', + 'fusion_gru', + 'fusion_seqconv_eltadd_relu', + 'fusion_seqexpand_concat_fc', +] NO_NEED_GEN_STATIC_ONLY_APIS = [ 'add_n_', @@ -88,6 +104,7 @@ 'fused_bn_add_activation_', 'fused_feedforward', 'fused_scale_bias_relu_conv_bnstats', + 'memcpy', 'print', 'recv_v2', 'rnn_', diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py index c63e0c4e418338..a7841e4d6d8afb 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py @@ -24,12 +24,6 @@ vjp_interface_black_list = [ - 'frobenius_norm', - 'write_to_array', - 'fused_attention', - 'fused_feedforward', - 'set_value', - 'set_value_with_tensor', 'silu_grad', 'fused_dropout_add', 'fused_rotary_position_embedding', diff --git a/paddle/fluid/pir/dialect/operator/interface/infermeta.h b/paddle/fluid/pir/dialect/operator/interface/infermeta.h index 958d2df369ed9b..fe0f50a456008a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infermeta.h +++ b/paddle/fluid/pir/dialect/operator/interface/infermeta.h @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #pragma once #include "paddle/phi/core/infermeta_utils.h" diff --git a/paddle/fluid/pir/dialect/operator/interface/interface.cc b/paddle/fluid/pir/dialect/operator/interface/interface.cc index 8a4049ff09544b..24c12050cd0524 100644 --- a/paddle/fluid/pir/dialect/operator/interface/interface.cc +++ b/paddle/fluid/pir/dialect/operator/interface/interface.cc @@ -20,6 +20,8 @@ namespace paddle { namespace dialect { std::vector> VjpInterface::Vjp( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { std::vector> out_grads_value; @@ -30,7 +32,7 @@ std::vector> VjpInterface::Vjp( } out_grads_value.emplace_back(std::move(grad_value)); } - return impl_->vjp_(op, out_grads_value, stop_gradients); + return impl_->vjp_(op, inputs, outputs, out_grads_value, stop_gradients); } } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/interface/vjp.h b/paddle/fluid/pir/dialect/operator/interface/vjp.h index 4f2292c7b6c02f..44d1731359beb5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/vjp.h +++ b/paddle/fluid/pir/dialect/operator/interface/vjp.h @@ -22,11 +22,15 @@ class VjpInterface : public pir::OpInterfaceBase { struct Concept { explicit Concept(std::vector> (*vjp)( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients)) : vjp_(vjp) {} std::vector> (*vjp_)( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients); }; @@ -35,9 +39,11 @@ class VjpInterface : public pir::OpInterfaceBase { struct Model : public Concept { static std::vector> Vjp( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { - return ConcreteOp::Vjp(op, out_grads, stop_gradients); + return ConcreteOp::Vjp(op, inputs, outputs, out_grads, stop_gradients); } Model() : Concept(Vjp) {} @@ -49,13 +55,17 @@ class VjpInterface : public pir::OpInterfaceBase { std::vector> Vjp( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { - return impl_->vjp_(op, out_grads, stop_gradients); + return impl_->vjp_(op, inputs, outputs, out_grads, stop_gradients); } std::vector> Vjp( pir::Operation* op, + const std::vector>& inputs, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients); diff --git a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt index 6c07f558e61abc..26343fa7249682 100644 --- a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt @@ -98,7 +98,7 @@ set(api_source_file_tmp ${api_source_file}.tmp) add_custom_command( OUTPUT ${api_header_file} ${api_source_file} COMMAND - ${PYTHON_EXECUTABLE} ${api_gen_file} --op_yaml_files ${api_gen_yaml_files} + ${PYTHON_EXECUTABLE} ${api_gen_file} --op_yaml_files ${op_yaml_files} --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace} --api_def_h_file ${api_header_file_tmp} --api_def_cc_file ${api_source_file_tmp} @@ -129,10 +129,9 @@ set(python_c_source_file_tmp ${python_c_source_file}.tmp) add_custom_command( OUTPUT ${python_c_header_file} ${python_c_source_file} COMMAND - ${PYTHON_EXECUTABLE} ${python_c_gen_file} --op_yaml_files - ${api_gen_yaml_files} --op_compat_yaml_file ${op_compat_yaml_file} - --namespaces "paddle,pybind" --python_c_def_h_file - ${python_c_header_file_tmp} --python_c_def_cc_file + ${PYTHON_EXECUTABLE} ${python_c_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --namespaces "paddle,pybind" + --python_c_def_h_file ${python_c_header_file_tmp} --python_c_def_cc_file ${python_c_source_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${python_c_header_file_tmp} ${python_c_header_file} @@ -160,9 +159,9 @@ set(ops_api_source_file_tmp ${ops_api_source_file}.tmp) add_custom_command( OUTPUT ${ops_api_source_file} COMMAND - ${PYTHON_EXECUTABLE} ${ops_api_gen_file} --op_yaml_files - ${api_gen_yaml_files} --op_compat_yaml_file ${op_compat_yaml_file} - --namespaces "paddle,pybind" --ops_api_file ${ops_api_source_file_tmp} + ${PYTHON_EXECUTABLE} ${ops_api_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --namespaces "paddle,pybind" + --ops_api_file ${ops_api_source_file_tmp} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ops_api_source_file_tmp} ${ops_api_source_file} COMMENT "copy_if_different ${ops_api_source_file}" diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 317ce64feea084..83ced4c1458fe1 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -51,6 +51,8 @@ class AddNOp : public pir::Op> Vjp( pir::Operation *op, + const std::vector> &inputs_, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); }; diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc index 80c13ac89def13..2d94be8c8bcb3e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc @@ -29,22 +29,28 @@ using IntArray = paddle::experimental::IntArray; std::vector> AddNOp::Vjp( pir::Operation* op, + const std::vector>& inputs_, + const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { - AddNOp op_obj = op->dyn_cast(); - VLOG(6) << "Prepare inputs of add_n_grad"; + PADDLE_ENFORCE_EQ( + inputs_.size(), + 1u, + platform::errors::InvalidArgument( + "addn op's inputs size should be 1 but now is %d", inputs_.size())); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1u, + platform::errors::InvalidArgument( + "addn op's outputs size should be 1 but now is %d", outputs.size())); PADDLE_ENFORCE( - op_obj.inputs() != nullptr, - paddle::platform::errors::Fatal("addn op's inputs can't be null")); - pir::CombineOp combine_op_obj = op_obj.inputs() - .dyn_cast() - .owner() - ->dyn_cast(); + inputs_[0].size() != 0, + paddle::platform::errors::Fatal("addn op's inputs[0] can't be null")); std::vector inputs; - for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) { + for (size_t idx = 0; idx < inputs_[0].size(); idx++) { inputs.emplace_back( - std::make_shared(combine_op_obj.inputs()[idx])); + std::make_shared(inputs_[0][idx])); } Tensor out_grad(std::make_shared(out_grads[0][0])); diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc index f10db043d1523d..3134214cf9029b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc @@ -35,6 +35,8 @@ phi::Scalar ScalarAttribute::data() { return phi::Scalar(dyn_cast().data()); } else if (isa()) { return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); } else if (isa()) { return phi::Scalar(dyn_cast().data()); } else if (isa()) { diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.h b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h index 6b9edf98cb56a3..0b0973a5205c85 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h @@ -48,6 +48,7 @@ class ScalarAttribute : public pir::Attribute { (val.type_id() == pir::FloatAttribute::type_id()) || (val.type_id() == pir::DoubleAttribute::type_id()) || (val.type_id() == pir::Int32Attribute::type_id()) || + (val.type_id() == pir::IndexAttribute::type_id()) || (val.type_id() == pir::Int64Attribute::type_id()) || (val.type_id() == pir::StrAttribute::type_id()); } diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 899863d58aba12..649b6886f54179 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -81,6 +81,16 @@ args : (Tensor[] x) output : Tensor(out) +- op : memcpy + args : (Tensor x, int dst_place_type) + output : Tensor(out) + infer_meta: + func: UnchangedInferMeta + param: [x] + kernel: + func : memcpy + param: [x, dst_place_type] + - op : print args : (Tensor in, int first_n, str message, int summarize, bool print_tensor_name = true, bool print_tensor_type = true, bool print_tensor_shape = true, bool print_tensor_layout = true, bool print_tensor_lod = true, str print_phase = "BOTH", bool is_forward = true) output : Tensor(out) @@ -130,7 +140,7 @@ param : [x, ring_id, dynamic_shape, peer, use_calc_stream] - op : set_value - args : (Tensor x, int64_t[] starts, int64_t[] ends, int64_t[] steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) + args : (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) output : Tensor(out) infer_meta: func: SetValueInferMeta @@ -142,7 +152,7 @@ backward: set_value_grad - op : set_value_with_tensor - args : (Tensor x, Tensor values, int64_t[] starts, int64_t[] ends, int64_t[] steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) + args : (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) output : Tensor(out) infer_meta: func: SetValueInferMeta diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index 95e3d99bd573b2..81213383e3fcff 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -19,7 +19,7 @@ optional: linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln1_out, ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance, dropout2_out, ln1_scale_grad, ln1_bias_grad, ln2_scale_grad, ln2_bias_grad, linear2_bias_grad - backward_op : set_value_grad - args : (Tensor out_grad, Tensor values, int64_t[] starts, int64_t[] ends, int64_t[] steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) + args : (Tensor out_grad, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) output : Tensor(x_grad), Tensor(values_grad) infer_meta: func: SetValueGradInferMeta diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 0aa2eaf143f7e9..c6bee0270d82eb 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -203,5 +203,24 @@ bool IsEmptyValue(const pir::Value& value) { return !value.impl() || !value.type(); } +std::vector GetInt64Vector(const pir::Attribute& attr) { + PADDLE_ENFORCE_EQ(attr.isa(), + true, + phi::errors::PreconditionNotMet( + "attribute MUST be a pir::ArrayAttribute")); + auto attr_vec = attr.dyn_cast().AsVector(); + + std::vector vec_int64; + for (auto vec_element : attr_vec) { + PADDLE_ENFORCE_EQ( + vec_element.isa(), + true, + phi::errors::PreconditionNotMet("element MUST be a Int64Attribute")); + vec_int64.push_back(vec_element.dyn_cast().data()); + } + + return vec_int64; +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index 1c228e7e850834..e35d7fa74cc649 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -133,5 +133,7 @@ bool IsLegacyOp(const std::string& name); bool IsEmptyValue(const pir::Value& value); +std::vector GetInt64Vector(const pir::Attribute& attr); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/drr/CMakeLists.txt b/paddle/fluid/pir/drr/CMakeLists.txt index c1b524dda69a6a..1d90762ed22066 100644 --- a/paddle/fluid/pir/drr/CMakeLists.txt +++ b/paddle/fluid/pir/drr/CMakeLists.txt @@ -23,6 +23,12 @@ set(fused_op_backward_yaml_file ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml ) +set(cinn_op_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/generated/ops.parsed.yaml) + +set(cinn_op_yaml_source_file + ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/operator/ir/ops.yaml) + set(parsed_op_dir ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated) @@ -39,6 +45,12 @@ set(op_creator_file_tmp ${op_creator_file}.tmp) set(dialect_name pd_op) +set(cinn_op_creator_file + ${PADDLE_BINARY_DIR}/paddle/fluid/pir/drr/cinn_op_factory_generated.cc) +set(cinn_op_creator_file_tmp ${cinn_op_creator_file}.tmp) + +set(cinn_dialect_name cinn_op) + add_custom_command( OUTPUT ${op_creator_file} COMMAND @@ -59,7 +71,27 @@ add_custom_command( pd_op_dialect_op VERBATIM) +if(WITH_CINN AND NOT CINN_ONLY) + add_custom_command( + OUTPUT ${cinn_op_creator_file} + COMMAND + ${PYTHON_EXECUTABLE} ${op_creator_gen_file} --op_yaml_files + ${cinn_op_yaml_file} --op_compat_yaml_file ${op_compat_yaml_file} + --dialect_name ${cinn_dialect_name} --op_creator_file + ${cinn_op_creator_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${cinn_op_creator_file_tmp} + ${cinn_op_creator_file} + COMMENT "copy_if_different ${cinn_op_creator_file}" + DEPENDS ${op_creator_gen_file} ${op_compat_yaml_file} + ${cinn_op_yaml_source_file} pd_op_dialect_op cinn_op_dialect + VERBATIM) + set(CINN_SOURCE_FILE ${cinn_op_creator_file}) + + set(CINN_DEPS cinn_op_dialect) + +endif() + cc_library( drr - SRCS ${DRR_SRCS} ${op_creator_file} - DEPS pd_op_dialect pir) + SRCS ${DRR_SRCS} ${op_creator_file} ${CINN_SOURCE_FILE} + DEPS pd_op_dialect ${CINN_DEPS} pir) diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/api/drr_pattern_base.h index d5f19ff3e6e9be..1a84c42800373b 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_base.h +++ b/paddle/fluid/pir/drr/api/drr_pattern_base.h @@ -28,12 +28,13 @@ class DrrPatternBase { // Define the Drr Pattern. virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0; - std::unique_ptr> Build( + std::unique_ptr Build( pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const { DrrPatternContext drr_context; this->operator()(&drr_context); - return std::make_unique>( - drr_context, ir_context, benefit); + std::string pattern_name = pir::get_type_name(); + return std::make_unique( + pattern_name, drr_context, ir_context, benefit); } }; diff --git a/paddle/fluid/pir/drr/attr_type_uilts.h b/paddle/fluid/pir/drr/attr_type_uilts.h index fb989fe063b771..28b26ba26a2a12 100644 --- a/paddle/fluid/pir/drr/attr_type_uilts.h +++ b/paddle/fluid/pir/drr/attr_type_uilts.h @@ -43,6 +43,8 @@ PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, pir::ArrayAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, paddle::dialect::IntArrayAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray, + paddle::dialect::IntArrayAttribute); template struct IrAttrbuteCreator { diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc new file mode 100644 index 00000000000000..6304220fc72ffc --- /dev/null +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc @@ -0,0 +1,524 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" + +namespace pir { +namespace drr { + +bool DrrRewritePattern::MatchAndRewrite( + pir::Operation* op, + PatternRewriter& rewriter) const { // NOLINT + std::shared_ptr src_match_ctx = + std::make_shared(); + if (PatternGraphMatch(op, src_match_ctx.get())) { + VLOG(4) << "DRR pattern (" << pattern_name_ << ") is matched in program."; + PatternGraphRewrite(*src_match_ctx, rewriter); + return true; + } + return false; +} + +bool DrrRewritePattern::PatternGraphMatch( + pir::Operation* op, MatchContextImpl* source_pattern_match_ctx) const { + VLOG(6) << "PatternGraphMatch Start: op(" << op->name() << ")"; + const OpCall* anchor = source_pattern_graph_->AnchorNode(); + std::unordered_map> + bind_map = + FindCandidateIrOutputOp(op, anchor, *(source_pattern_graph_.get())); + if (bind_map.empty()) { + return false; + } + std::vector drr_output_sequence; + std::vector ir_output_sequence; + std::unordered_map output_op_map; + for (auto pair : bind_map) { + drr_output_sequence.push_back(pair.first); + } + // using dfs to obtain the arrangement of all candidate ir ops + auto permute = [&](auto&& permute, size_t index) -> bool { + if (index == drr_output_sequence.size()) { + // avoiding duplicate binding of ir op + std::unordered_set ir_output_set; + for (Operation* op : ir_output_sequence) { + auto pr = ir_output_set.insert(op); + if (pr.second == false) { + return false; + } + } + // new match_ctx + std::shared_ptr match_ctx = + std::make_shared(); + std::transform(drr_output_sequence.begin(), + drr_output_sequence.end(), + ir_output_sequence.begin(), + std::inserter(output_op_map, output_op_map.end()), + [](const OpCall* drr_op, Operation* ir_op) { + return std::make_pair(drr_op, ir_op); + }); + if (MatchFromOutputToInput( + output_op_map, *(source_pattern_graph_.get()), match_ctx)) { + *source_pattern_match_ctx = *match_ctx; + return true; + } + return false; + } + for (auto* ir_op : bind_map[drr_output_sequence[index]]) { + ir_output_sequence.push_back(ir_op); + if (permute(permute, index + 1)) { + return true; + } + ir_output_sequence.pop_back(); + } + return false; + }; + + return permute(permute, 0); +} + +std::unordered_map> +DrrRewritePattern::FindCandidateIrOutputOp( + pir::Operation* op, + const OpCall* anchor, + const SourcePatternGraph& source_pattern_graph) const { + // get source pattern output op + std::unordered_set drr_output_op_set = + source_pattern_graph.OutputNodes(); + std::unordered_map> + output_op_bind_map{{anchor, {op}}}; + if (drr_output_op_set.size() == 1) { + return output_op_bind_map; + } + std::unordered_set drr_visited_ops{anchor}; + DfsVisitor( + anchor, op, drr_output_op_set, &drr_visited_ops, &output_op_bind_map); + if (output_op_bind_map.size() != drr_output_op_set.size()) { + return {}; + } + return output_op_bind_map; +} + +void DrrRewritePattern::DfsVisitor( + const OpCall* drr_op, + pir::Operation* ir_op, + const std::unordered_set& drr_output_op_set, + std::unordered_set* drr_visited_ops, + std::unordered_map>* + output_op_bind_map) const { + VLOG(6) << "DfsVisitor Start: drr op(" << drr_op->name() << ")" + << "ir op(" << ir_op->name() << ")"; + if (drr_op->name() != ir_op->name()) { + return; + } + // check op input's size + const auto& drr_op_input_tensors = drr_op->inputs(); + auto ir_op_input_value_size = ir_op->num_operands(); + if (drr_op_input_tensors.size() != ir_op_input_value_size) { + return; + } + // check op output's size + const auto& drr_op_output_tensors = drr_op->outputs(); + auto ir_op_output_value_size = ir_op->num_results(); + if (drr_op_output_tensors.size() != ir_op_output_value_size) { + return; + } + // check producer op + for (size_t i = 0; i < drr_op_input_tensors.size(); ++i) { + // case 1: drr_op_input_tensor is the input tensor of source pattern + if (drr_op_input_tensors[i]->producer() == nullptr) { + // dfs source pattern input tensor other child op + auto ir_input_tensor = ir_op->operand(i).source(); + for (auto drr_bro_op : drr_op_input_tensors[i]->consumers()) { + if (drr_visited_ops->count(drr_bro_op)) { + continue; + } + for (auto it = ir_input_tensor.use_begin(); + it != ir_input_tensor.use_end(); + ++it) { + auto* ir_bro_op = it.owner(); + if (drr_bro_op->name() == ir_bro_op->name()) { + drr_visited_ops->insert(drr_bro_op); + DfsVisitor(drr_bro_op, + ir_bro_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_bro_op); + } + } + } + continue; + } + // case 2: have producer op + const auto& drr_producer_op = drr_op_input_tensors[i]->producer(); + if (drr_visited_ops->count(drr_producer_op)) { + continue; + } + auto ir_operand_value = ir_op->operand(i).source(); + if (drr_op_input_tensors[i]->consumers().size() != + ir_operand_value.use_count()) { + return; + } + auto* ir_producer_op = ir_operand_value.dyn_cast().owner(); + drr_visited_ops->insert(drr_producer_op); + DfsVisitor(drr_producer_op, + ir_producer_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_producer_op); + } + if (drr_output_op_set.count(drr_op)) { + (*output_op_bind_map)[drr_op].insert(ir_op); + return; + } + // check child ops + for (size_t i = 0; i < drr_op_output_tensors.size(); ++i) { + const auto& drr_child_ops = drr_op_output_tensors[i]->consumers(); + auto ir_output_value = ir_op->result(i); + if (drr_child_ops.size() != ir_output_value.use_count()) { + return; + } + for (auto* drr_child_op : drr_child_ops) { + for (auto it = ir_output_value.use_begin(); + it != ir_output_value.use_end(); + ++it) { + auto* ir_child_op = it.owner(); + if (drr_child_op->name() == ir_child_op->name()) { + if (drr_visited_ops->count(drr_child_op)) { + continue; + } + drr_visited_ops->insert(drr_child_op); + DfsVisitor(drr_child_op, + ir_child_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_child_op); + } + } + } + } // check child ops + return; +} + +bool DrrRewritePattern::MatchFromOutputToInput( + std::unordered_map output_op_map, + const SourcePatternGraph& source_pattern_graph, + const std::shared_ptr& source_pattern_match_ctx) const { + VLOG(6) << "MatchFromOutputToInput Start"; + std::unordered_set drr_visited; + std::unordered_set ir_visited; + std::queue drr_q; + std::queue ir_q; + bool matched = true; + size_t step = 0; + for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) { + VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @" + << it->second << ") in source_pattern_graph "; + drr_q.push(it->first); + drr_visited.insert(it->first); + ir_q.push(it->second); + ir_visited.insert(it->second); + } + while (!drr_q.empty()) { + if (!matched) break; + auto* drr_node = drr_q.front(); + auto* ir_node = ir_q.front(); + drr_q.pop(); + ir_q.pop(); + if (drr_node->name() != ir_node->name()) { + matched = false; + break; + } + const auto& drr_input_tensors = drr_node->inputs(); + auto ir_input_value_size = ir_node->num_operands(); + if (drr_input_tensors.size() != ir_input_value_size) { + matched = false; + break; + } + if (drr_node->outputs().size() != ir_node->num_results()) { + matched = false; + break; + } + source_pattern_match_ctx->BindIrOperation( + drr_node, std::make_shared(ir_node)); + // binding input_tensor of current_op + for (size_t i = 0; i < drr_input_tensors.size(); ++i) { + source_pattern_match_ctx->BindIrValue( + drr_input_tensors[i]->name(), + std::make_shared(ir_node->operand(i).source())); + auto* drr_producer_op = drr_input_tensors[i]->producer(); + if (drr_producer_op == nullptr) { + continue; + } + auto* ir_producer_op = + ir_node->operand(i).source().dyn_cast().owner(); + if (drr_input_tensors[i]->consumers().size() != + ir_node->operand(i).source().use_count()) { + matched = false; + break; + } + // bfs producer_op of current_op + if (!drr_visited.count(drr_producer_op)) { + drr_q.push(drr_producer_op); + ir_q.push(ir_producer_op); + drr_visited.insert(drr_producer_op); + ir_visited.insert(ir_producer_op); + } + } + // binding output tensor of current_op + auto drr_op_output_tensor = drr_node->outputs(); + for (size_t j = 0; j < drr_op_output_tensor.size(); j++) { + source_pattern_match_ctx->BindIrValue( + drr_op_output_tensor[j]->name(), + std::make_shared(ir_node->result(j))); + } + ++step; + } + + if (matched) { + IR_ENFORCE(step == source_pattern_graph.CountOfOpCalls()); + } else { + return matched; + } + + MatchContext match_context{source_pattern_match_ctx}; + for (const auto& constraint : constraints_) { + matched = constraint(match_context); + if (!matched) break; + } + + return matched; +} + +void DrrRewritePattern::PatternGraphRewrite( + const MatchContextImpl& source_pattern_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + VLOG(6) << "Create Operations in result_pattern_graph"; + MatchContextImpl res_match_ctx = CreateOperations(*source_pattern_graph_, + *result_pattern_graph_, + source_pattern_match_ctx, + rewriter); + VLOG(6) << "Process Assign Tensor"; + RebindIrTensorForAssignTensor(*result_pattern_graph_, &res_match_ctx); + VLOG(6) << "Replace Output Values in source_pattern_graph by Output Values " + "in result_pattern_graph"; + ReplaceOutputTensor(source_pattern_match_ctx, res_match_ctx, rewriter); + VLOG(6) << "Delete Operations in source_pattern_graph"; + DeleteSourcePatternOp(*source_pattern_graph_, + *result_pattern_graph_, + source_pattern_match_ctx, + rewriter); +} + +MatchContextImpl DrrRewritePattern::CreateOperations( + const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + MatchContextImpl res_match_ctx; + // add input tensors info for res_match_ctx + for (const auto& in_tensor : result_pattern_graph.input_tensors()) { + IR_ENFORCE(result_pattern_graph.id2owend_tensor().count(in_tensor), + "Drr input tensor [%s] must exists in result pattern graph.", + in_tensor); + if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { + res_match_ctx.BindIrValue( + in_tensor, + std::make_shared(src_match_ctx.GetIrValue(in_tensor))); + } + } + + if (result_pattern_graph.CountOfOpCalls() == 1) { + CreateOperation(*result_pattern_graph.owned_op_call()[0], + src_match_ctx, + rewriter, + &res_match_ctx); + return res_match_ctx; + } + + std::vector> temp_program; + std::unordered_map op_2_temp_program_index; + for (Operation* op : *rewriter.block()) { + op_2_temp_program_index[op] = temp_program.size(); + temp_program.push_back({op}); + } + + // topo order visit result_pattern_graph + GraphTopo graph_topo_visit(&result_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { + // set insert point + size_t max_input_op_index = 0; + Operation* max_index_op = nullptr; + for (const Tensor* input : op_call.inputs()) { + if (input->is_none()) { + continue; + } + Value ir_val = res_match_ctx.GetIrValue(input->name()).get(); + if (ir_val) { + Operation* ir_input_op = ir_val.dyn_cast().owner(); + if (max_input_op_index < op_2_temp_program_index[ir_input_op]) { + max_input_op_index = op_2_temp_program_index[ir_input_op]; + max_index_op = ir_input_op; + } else if (max_input_op_index == op_2_temp_program_index[ir_input_op]) { + const auto& ops_vec = temp_program[max_input_op_index]; + for (auto it = ops_vec.rbegin(); it != ops_vec.rend(); it++) { + if (*it == max_index_op) { + break; + } else if (*it == ir_input_op) { + max_index_op = ir_input_op; + break; + } else { + // do nothing + } + } + } else { + // do nothing + } + } + } + if (max_input_op_index == 0UL) { + VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; + Operation* source_patter_first_op = + src_match_ctx.Operation(source_pattern_graph.owned_op_call()[0].get()) + .get(); + max_input_op_index = op_2_temp_program_index[source_patter_first_op]; + rewriter.SetInsertionPoint(source_patter_first_op); + } else { + rewriter.SetInsertionPointAfter(max_index_op); + } + + Operation* new_op = + CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx); + op_2_temp_program_index[new_op] = max_input_op_index + 1; + temp_program[max_input_op_index + 1].push_back(new_op); + }); + + return res_match_ctx; +} + +void DrrRewritePattern::RebindIrTensorForAssignTensor( + const ResultPatternGraph& result_pattern_graph, + MatchContextImpl* res_match_ctx) const { + const auto& tensor_assign_map = result_pattern_graph.tensor_assign_map(); + for (const auto& kv : tensor_assign_map) { + const auto& src_tensor_name = kv.first; + const auto& dst_tensor_name = kv.second; + res_match_ctx->BindIrValue( + src_tensor_name, + std::make_shared(res_match_ctx->GetIrValue(dst_tensor_name))); + } +} + +void DrrRewritePattern::ReplaceOutputTensor( + const MatchContextImpl& src_match_ctx, + const MatchContextImpl& res_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + for (const auto& output_name : result_pattern_graph_->output_tensors()) { + if (source_pattern_graph_->id2owend_tensor().count(output_name)) { + const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); + const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); + rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); + } else { + LOG(WARNING) << "The output tensor (" << output_name + << ") in the result_pattern_graph is not the tensor" + " in source_pattern_graph."; + } + } +} + +void DrrRewritePattern::DeleteSourcePatternOp( + const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + std::vector topo_order_ops; + GraphTopo graph_topo_visit(&source_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder( + [&topo_order_ops](const OpCall& op_call) { + topo_order_ops.push_back(&op_call); + }); + + // Filter the operations which are replaced by result pattern + // 1. Filter operations by forward walk + std::unordered_set forward_visited_tensor_set( + result_pattern_graph.input_tensors()); + std::unordered_set forward_deleted_ops; + std::for_each(topo_order_ops.begin(), + topo_order_ops.end(), + [&forward_deleted_ops, + &forward_visited_tensor_set](const OpCall* op_call) { + if (op_call->inputs().empty()) { + forward_deleted_ops.insert(op_call); + for (const auto* output : op_call->outputs()) { + forward_visited_tensor_set.insert(output->name()); + } + } + for (const auto* input : op_call->inputs()) { + if (forward_visited_tensor_set.count(input->name())) { + forward_deleted_ops.insert(op_call); + for (const auto* output : op_call->outputs()) { + forward_visited_tensor_set.insert(output->name()); + } + break; + } + } + }); + // 2. Filter operations by backward walk and merge the forward result + std::unordered_set backward_visited_tensor_set( + result_pattern_graph.output_tensors()); + std::vector deleted_ops; + std::unordered_set deleted_ops_set; + std::for_each(topo_order_ops.rbegin(), + topo_order_ops.rend(), + [&deleted_ops, + &deleted_ops_set, + &backward_visited_tensor_set, + &forward_deleted_ops](const OpCall* op_call) { + bool all_comsumer_deleted = true; + bool from_backward_visited_tensor = false; + for (const auto* output : op_call->outputs()) { + if (backward_visited_tensor_set.count(output->name())) { + from_backward_visited_tensor = true; + } else if (output->consumers().empty()) { + continue; + } else { + all_comsumer_deleted = false; + } + } + if (all_comsumer_deleted && from_backward_visited_tensor && + forward_deleted_ops.count(op_call)) { + deleted_ops_set.insert(op_call); + deleted_ops.push_back(op_call); + for (const auto* input : op_call->inputs()) { + backward_visited_tensor_set.insert(input->name()); + } + } + }); + + // Delete Operation with topo order from output tensors. + for (const auto* op_call : deleted_ops) { + IR_ENFORCE(src_match_ctx.operation_map().count(op_call), + "Drr OpCall [%s] must exists in match context.", + op_call->name()); + auto* op = src_match_ctx.operation_map().at(op_call)->get(); + VLOG(6) << "Delete (" << op_call->name() << " @" << op_call << " :@" << op + << ") in source_pattern_graph "; + rewriter.EraseOp(op); + } +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/drr_rewrite_pattern.h index c17feb0eaad052..2c51dcf339b472 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.h +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.h @@ -34,10 +34,10 @@ namespace pir { namespace drr { -template class DrrRewritePattern : public pir::RewritePattern { public: - explicit DrrRewritePattern(const DrrPatternContext& drr_context, + explicit DrrRewritePattern(const std::string& pattern_name, + const DrrPatternContext& drr_context, pir::IrContext* context, pir::PatternBenefit benefit = 1) : pir::RewritePattern( @@ -45,6 +45,7 @@ class DrrRewritePattern : public pir::RewritePattern { benefit, context, {}), + pattern_name_(pattern_name), source_pattern_graph_(drr_context.source_pattern_graph()), constraints_(drr_context.constraints()), result_pattern_graph_(drr_context.result_pattern_graph()) { @@ -54,97 +55,16 @@ class DrrRewritePattern : public pir::RewritePattern { } bool MatchAndRewrite(pir::Operation* op, - PatternRewriter& rewriter) const override { // NOLINT - std::shared_ptr src_match_ctx = - std::make_shared(); - if (PatternGraphMatch(op, src_match_ctx.get())) { - VLOG(4) << "DRR pattern (" << pir::get_type_name() - << ") is matched in program."; - PatternGraphRewrite(*src_match_ctx, rewriter); - return true; - } - return false; - } + PatternRewriter& rewriter) const override; // // NOLINT private: bool PatternGraphMatch(pir::Operation* op, - MatchContextImpl* source_pattern_match_ctx) const { - VLOG(6) << "PatternGraphMatch Start: op(" << op->name() << ")"; - const OpCall* anchor = source_pattern_graph_->AnchorNode(); - std::unordered_map> - bind_map = - FindCandidateIrOutputOp(op, anchor, *(source_pattern_graph_.get())); - if (bind_map.empty()) { - return false; - } - std::vector drr_output_sequence; - std::vector ir_output_sequence; - std::unordered_map output_op_map; - for (auto pair : bind_map) { - drr_output_sequence.push_back(pair.first); - } - // using dfs to obtain the arrangement of all candidate ir ops - auto permute = [&](auto&& permute, size_t index) -> bool { - if (index == drr_output_sequence.size()) { - // avoiding duplicate binding of ir op - std::unordered_set ir_output_set; - for (Operation* op : ir_output_sequence) { - auto pr = ir_output_set.insert(op); - if (pr.second == false) { - return false; - } - } - // new match_ctx - std::shared_ptr match_ctx = - std::make_shared(); - std::transform(drr_output_sequence.begin(), - drr_output_sequence.end(), - ir_output_sequence.begin(), - std::inserter(output_op_map, output_op_map.end()), - [](const OpCall* drr_op, Operation* ir_op) { - return std::make_pair(drr_op, ir_op); - }); - if (MatchFromOutputToInput( - output_op_map, *(source_pattern_graph_.get()), match_ctx)) { - *source_pattern_match_ctx = *match_ctx; - return true; - } - return false; - } - for (auto* ir_op : bind_map[drr_output_sequence[index]]) { - ir_output_sequence.push_back(ir_op); - if (permute(permute, index + 1)) { - return true; - } - ir_output_sequence.pop_back(); - } - return false; - }; - - return permute(permute, 0); - } + MatchContextImpl* source_pattern_match_ctx) const; std::unordered_map> - FindCandidateIrOutputOp( - pir::Operation* op, - const OpCall* anchor, - const SourcePatternGraph& source_pattern_graph) const { - // get source pattern output op - std::unordered_set drr_output_op_set = - source_pattern_graph.OutputNodes(); - std::unordered_map> - output_op_bind_map{{anchor, {op}}}; - if (drr_output_op_set.size() == 1) { - return output_op_bind_map; - } - std::unordered_set drr_visited_ops{anchor}; - DfsVisitor( - anchor, op, drr_output_op_set, &drr_visited_ops, &output_op_bind_map); - if (output_op_bind_map.size() != drr_output_op_set.size()) { - return {}; - } - return output_op_bind_map; - } + FindCandidateIrOutputOp(pir::Operation* op, + const OpCall* anchor, + const SourcePatternGraph& source_pattern_graph) const; void DfsVisitor( const OpCall* drr_op, @@ -152,413 +72,38 @@ class DrrRewritePattern : public pir::RewritePattern { const std::unordered_set& drr_output_op_set, std::unordered_set* drr_visited_ops, std::unordered_map>* - output_op_bind_map) const { - VLOG(6) << "DfsVisitor Start: drr op(" << drr_op->name() << ")" - << "ir op(" << ir_op->name() << ")"; - if (drr_op->name() != ir_op->name()) { - return; - } - // check op input's size - const auto& drr_op_input_tensors = drr_op->inputs(); - auto ir_op_input_value_size = ir_op->num_operands(); - if (drr_op_input_tensors.size() != ir_op_input_value_size) { - return; - } - // check op output's size - const auto& drr_op_output_tensors = drr_op->outputs(); - auto ir_op_output_value_size = ir_op->num_results(); - if (drr_op_output_tensors.size() != ir_op_output_value_size) { - return; - } - // check producer op - for (size_t i = 0; i < drr_op_input_tensors.size(); ++i) { - // case 1: drr_op_input_tensor is the input tensor of source pattern - if (drr_op_input_tensors[i]->producer() == nullptr) { - // dfs source pattern input tensor other child op - auto ir_input_tensor = ir_op->operand(i).source(); - for (auto drr_bro_op : drr_op_input_tensors[i]->consumers()) { - if (drr_visited_ops->count(drr_bro_op)) { - continue; - } - for (auto it = ir_input_tensor.use_begin(); - it != ir_input_tensor.use_end(); - ++it) { - auto* ir_bro_op = it.owner(); - if (drr_bro_op->name() == ir_bro_op->name()) { - drr_visited_ops->insert(drr_bro_op); - DfsVisitor(drr_bro_op, - ir_bro_op, - drr_output_op_set, - drr_visited_ops, - output_op_bind_map); - drr_visited_ops->erase(drr_bro_op); - } - } - } - continue; - } - // case 2: have producer op - const auto& drr_producer_op = drr_op_input_tensors[i]->producer(); - if (drr_visited_ops->count(drr_producer_op)) { - continue; - } - auto ir_operand_value = ir_op->operand(i).source(); - if (drr_op_input_tensors[i]->consumers().size() != - ir_operand_value.use_count()) { - return; - } - auto* ir_producer_op = ir_operand_value.dyn_cast().owner(); - drr_visited_ops->insert(drr_producer_op); - DfsVisitor(drr_producer_op, - ir_producer_op, - drr_output_op_set, - drr_visited_ops, - output_op_bind_map); - drr_visited_ops->erase(drr_producer_op); - } - if (drr_output_op_set.count(drr_op)) { - (*output_op_bind_map)[drr_op].insert(ir_op); - return; - } - // check child ops - for (size_t i = 0; i < drr_op_output_tensors.size(); ++i) { - const auto& drr_child_ops = drr_op_output_tensors[i]->consumers(); - auto ir_output_value = ir_op->result(i); - if (drr_child_ops.size() != ir_output_value.use_count()) { - return; - } - for (auto* drr_child_op : drr_child_ops) { - for (auto it = ir_output_value.use_begin(); - it != ir_output_value.use_end(); - ++it) { - auto* ir_child_op = it.owner(); - if (drr_child_op->name() == ir_child_op->name()) { - if (drr_visited_ops->count(drr_child_op)) { - continue; - } - drr_visited_ops->insert(drr_child_op); - DfsVisitor(drr_child_op, - ir_child_op, - drr_output_op_set, - drr_visited_ops, - output_op_bind_map); - drr_visited_ops->erase(drr_child_op); - } - } - } - } // check child ops - return; - } + output_op_bind_map) const; bool MatchFromOutputToInput( std::unordered_map output_op_map, const SourcePatternGraph& source_pattern_graph, - const std::shared_ptr& source_pattern_match_ctx) const { - VLOG(6) << "MatchFromOutputToInput Start"; - std::unordered_set drr_visited; - std::unordered_set ir_visited; - std::queue drr_q; - std::queue ir_q; - bool matched = true; - size_t step = 0; - for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) { - VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @" - << it->second << ") in source_pattern_graph "; - drr_q.push(it->first); - drr_visited.insert(it->first); - ir_q.push(it->second); - ir_visited.insert(it->second); - } - while (!drr_q.empty()) { - if (!matched) break; - auto* drr_node = drr_q.front(); - auto* ir_node = ir_q.front(); - drr_q.pop(); - ir_q.pop(); - if (drr_node->name() != ir_node->name()) { - matched = false; - break; - } - const auto& drr_input_tensors = drr_node->inputs(); - auto ir_input_value_size = ir_node->num_operands(); - if (drr_input_tensors.size() != ir_input_value_size) { - matched = false; - break; - } - if (drr_node->outputs().size() != ir_node->num_results()) { - matched = false; - break; - } - source_pattern_match_ctx->BindIrOperation( - drr_node, std::make_shared(ir_node)); - // binding input_tensor of current_op - for (size_t i = 0; i < drr_input_tensors.size(); ++i) { - source_pattern_match_ctx->BindIrValue( - drr_input_tensors[i]->name(), - std::make_shared(ir_node->operand(i).source())); - auto* drr_producer_op = drr_input_tensors[i]->producer(); - if (drr_producer_op == nullptr) { - continue; - } - auto* ir_producer_op = - ir_node->operand(i).source().dyn_cast().owner(); - if (drr_input_tensors[i]->consumers().size() != - ir_node->operand(i).source().use_count()) { - matched = false; - break; - } - // bfs producer_op of current_op - if (!drr_visited.count(drr_producer_op)) { - drr_q.push(drr_producer_op); - ir_q.push(ir_producer_op); - drr_visited.insert(drr_producer_op); - ir_visited.insert(ir_producer_op); - } - } - // binding output tensor of current_op - auto drr_op_output_tensor = drr_node->outputs(); - for (size_t j = 0; j < drr_op_output_tensor.size(); j++) { - source_pattern_match_ctx->BindIrValue( - drr_op_output_tensor[j]->name(), - std::make_shared(ir_node->result(j))); - } - ++step; - } - - if (matched) { - IR_ENFORCE(step == source_pattern_graph.CountOfOpCalls()); - } else { - return matched; - } - - MatchContext match_context{source_pattern_match_ctx}; - for (const auto& constraint : constraints_) { - matched = constraint(match_context); - if (!matched) break; - } - - return matched; - } + const std::shared_ptr& source_pattern_match_ctx) const; void PatternGraphRewrite(const MatchContextImpl& source_pattern_match_ctx, - pir::PatternRewriter& rewriter) const { // NOLINT - VLOG(6) << "Create Operations in result_pattern_graph"; - MatchContextImpl res_match_ctx = CreateOperations(*source_pattern_graph_, - *result_pattern_graph_, - source_pattern_match_ctx, - rewriter); - VLOG(6) << "Process Assign Tensor"; - RebindIrTensorForAssignTensor(*result_pattern_graph_, &res_match_ctx); - VLOG(6) << "Replace Output Values in source_pattern_graph by Output Values " - "in result_pattern_graph"; - ReplaceOutputTensor(source_pattern_match_ctx, res_match_ctx, rewriter); - VLOG(6) << "Delete Operations in source_pattern_graph"; - DeleteSourcePatternOp(*source_pattern_graph_, - *result_pattern_graph_, - source_pattern_match_ctx, - rewriter); - } + pir::PatternRewriter& rewriter) const; // NOLINT private: MatchContextImpl CreateOperations( const SourcePatternGraph& source_pattern_graph, const ResultPatternGraph& result_pattern_graph, const MatchContextImpl& src_match_ctx, - pir::PatternRewriter& rewriter) const { // NOLINT - MatchContextImpl res_match_ctx; - // add input tensors info for res_match_ctx - for (const auto& in_tensor : result_pattern_graph.input_tensors()) { - IR_ENFORCE(result_pattern_graph.id2owend_tensor().count(in_tensor), - "Drr input tensor [%s] must exists in result pattern graph.", - in_tensor); - if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { - res_match_ctx.BindIrValue( - in_tensor, - std::make_shared(src_match_ctx.GetIrValue(in_tensor))); - } - } - - if (result_pattern_graph.CountOfOpCalls() == 1) { - CreateOperation(*result_pattern_graph.owned_op_call()[0], - src_match_ctx, - rewriter, - &res_match_ctx); - return res_match_ctx; - } - - std::vector> temp_program; - std::unordered_map op_2_temp_program_index; - for (Operation* op : *rewriter.block()) { - op_2_temp_program_index[op] = temp_program.size(); - temp_program.push_back({op}); - } - - // topo order visit result_pattern_graph - GraphTopo graph_topo_visit(&result_pattern_graph); - graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { - // set insert point - size_t max_input_op_index = 0; - Operation* max_index_op = nullptr; - for (const Tensor* input : op_call.inputs()) { - if (input->is_none()) { - continue; - } - Value ir_val = res_match_ctx.GetIrValue(input->name()).get(); - if (ir_val) { - Operation* ir_input_op = ir_val.dyn_cast().owner(); - if (max_input_op_index < op_2_temp_program_index[ir_input_op]) { - max_input_op_index = op_2_temp_program_index[ir_input_op]; - max_index_op = ir_input_op; - } else if (max_input_op_index == - op_2_temp_program_index[ir_input_op]) { - const auto& ops_vec = temp_program[max_input_op_index]; - for (auto it = ops_vec.rbegin(); it != ops_vec.rend(); it++) { - if (*it == max_index_op) { - break; - } else if (*it == ir_input_op) { - max_index_op = ir_input_op; - break; - } else { - // do nothing - } - } - } else { - // do nothing - } - } - } - if (max_input_op_index == 0UL) { - VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; - Operation* source_patter_first_op = - src_match_ctx - .Operation(source_pattern_graph.owned_op_call()[0].get()) - .get(); - max_input_op_index = op_2_temp_program_index[source_patter_first_op]; - rewriter.SetInsertionPoint(source_patter_first_op); - } else { - rewriter.SetInsertionPointAfter(max_index_op); - } - - Operation* new_op = - CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx); - op_2_temp_program_index[new_op] = max_input_op_index + 1; - temp_program[max_input_op_index + 1].push_back(new_op); - }); - - return res_match_ctx; - } + pir::PatternRewriter& rewriter) const; // NOLINT void RebindIrTensorForAssignTensor( const ResultPatternGraph& result_pattern_graph, - MatchContextImpl* res_match_ctx) const { - const auto& tensor_assign_map = result_pattern_graph.tensor_assign_map(); - for (const auto& kv : tensor_assign_map) { - const auto& src_tensor_name = kv.first; - const auto& dst_tensor_name = kv.second; - res_match_ctx->BindIrValue( - src_tensor_name, - std::make_shared( - res_match_ctx->GetIrValue(dst_tensor_name))); - } - } + MatchContextImpl* res_match_ctx) const; void ReplaceOutputTensor(const MatchContextImpl& src_match_ctx, const MatchContextImpl& res_match_ctx, - pir::PatternRewriter& rewriter) const { // NOLINT - for (const auto& output_name : result_pattern_graph_->output_tensors()) { - if (source_pattern_graph_->id2owend_tensor().count(output_name)) { - const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); - const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); - rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); - } else { - LOG(WARNING) << "The output tensor (" << output_name - << ") in the result_pattern_graph is not the tensor" - " in source_pattern_graph."; - } - } - } + pir::PatternRewriter& rewriter) const; // NOLINT void DeleteSourcePatternOp(const SourcePatternGraph& source_pattern_graph, const ResultPatternGraph& result_pattern_graph, const MatchContextImpl& src_match_ctx, - pir::PatternRewriter& rewriter) const { // NOLINT - std::vector topo_order_ops; - GraphTopo graph_topo_visit(&source_pattern_graph); - graph_topo_visit.WalkGraphNodesTopoOrder( - [&topo_order_ops](const OpCall& op_call) { - topo_order_ops.push_back(&op_call); - }); - - // Filter the operations which are replaced by result pattern - // 1. Filter operations by forward walk - std::unordered_set forward_visited_tensor_set( - result_pattern_graph.input_tensors()); - std::unordered_set forward_deleted_ops; - std::for_each(topo_order_ops.begin(), - topo_order_ops.end(), - [&forward_deleted_ops, - &forward_visited_tensor_set](const OpCall* op_call) { - if (op_call->inputs().empty()) { - forward_deleted_ops.insert(op_call); - for (const auto* output : op_call->outputs()) { - forward_visited_tensor_set.insert(output->name()); - } - } - for (const auto* input : op_call->inputs()) { - if (forward_visited_tensor_set.count(input->name())) { - forward_deleted_ops.insert(op_call); - for (const auto* output : op_call->outputs()) { - forward_visited_tensor_set.insert(output->name()); - } - break; - } - } - }); - // 2. Filter operations by backward walk and merge the forward result - std::unordered_set backward_visited_tensor_set( - result_pattern_graph.output_tensors()); - std::vector deleted_ops; - std::unordered_set deleted_ops_set; - std::for_each(topo_order_ops.rbegin(), - topo_order_ops.rend(), - [&deleted_ops, - &deleted_ops_set, - &backward_visited_tensor_set, - &forward_deleted_ops](const OpCall* op_call) { - bool all_comsumer_deleted = true; - bool from_backward_visited_tensor = false; - for (const auto* output : op_call->outputs()) { - if (backward_visited_tensor_set.count(output->name())) { - from_backward_visited_tensor = true; - } else if (output->consumers().empty()) { - continue; - } else { - all_comsumer_deleted = false; - } - } - if (all_comsumer_deleted && from_backward_visited_tensor && - forward_deleted_ops.count(op_call)) { - deleted_ops_set.insert(op_call); - deleted_ops.push_back(op_call); - for (const auto* input : op_call->inputs()) { - backward_visited_tensor_set.insert(input->name()); - } - } - }); - - // Delete Operation with topo order from output tensors. - for (const auto* op_call : deleted_ops) { - IR_ENFORCE(src_match_ctx.operation_map().count(op_call), - "Drr OpCall [%s] must exists in match context.", - op_call->name()); - auto* op = src_match_ctx.operation_map().at(op_call)->get(); - VLOG(6) << "Delete (" << op_call->name() << " @" << op_call << " :@" << op - << ") in source_pattern_graph "; - rewriter.EraseOp(op); - } - } + pir::PatternRewriter& rewriter) const; // NOLINT private: + const std::string pattern_name_; const std::shared_ptr source_pattern_graph_; const std::vector constraints_; const std::shared_ptr result_pattern_graph_; diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc index 5355a8977e8c53..135b274b010937 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.cc +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -83,6 +83,9 @@ static pir::Attribute CreateIrAttribute(const std::any& obj) { } else if (obj.type() == typeid(std::vector)) { return IrAttrbuteCreator>()( std::any_cast>(obj)); + } else if (obj.type() == typeid(phi::IntArray)) { + return IrAttrbuteCreator()( + std::any_cast(obj)); } else { PADDLE_THROW( phi::errors::Unimplemented("Type error. CreateIrAttribute for type(%s) " diff --git a/paddle/fluid/pir/drr/ir_operation_factory.h b/paddle/fluid/pir/drr/ir_operation_factory.h index b38b5cd6a12b32..ed472be8408102 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.h +++ b/paddle/fluid/pir/drr/ir_operation_factory.h @@ -54,13 +54,18 @@ class OperationFactory { private: OperationFactory() { - RegisterGeneratedOpCreator(); + RegisterPdOpGeneratedOpCreator(); +#ifdef PADDLE_WITH_CINN + RegisterCinnOpGeneratedOpCreator(); +#endif RegisterManualOpCreator(); } void RegisterManualOpCreator(); - void RegisterGeneratedOpCreator(); - + void RegisterPdOpGeneratedOpCreator(); +#ifdef PADDLE_WITH_CINN + void RegisterCinnOpGeneratedOpCreator(); +#endif std::unordered_map op_creator_map; }; diff --git a/paddle/fluid/pir/transforms/build_cinn_pass.cc b/paddle/fluid/pir/transforms/build_cinn_pass.cc index 2d0fef35cb454c..f15183fd1af036 100644 --- a/paddle/fluid/pir/transforms/build_cinn_pass.cc +++ b/paddle/fluid/pir/transforms/build_cinn_pass.cc @@ -32,7 +32,7 @@ #include "paddle/pir/pass/pass_registry.h" #include "paddle/cinn/frontend/op_mapper_registry.h" -#include "paddle/cinn/hlir/framework/new_ir/utils.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/utils/flags.h" PD_DECLARE_string(allow_cinn_ops); @@ -43,7 +43,7 @@ using GroupOpsVec = std::vector; // The delim(`;`) that is used to split the FLAGS_allow_cinn_ops // & FLAGS_deny_cinn_ops. constexpr char kDelim[] = ";"; -using CompatibleInfo = cinn::hlir::framework::newir::CompatibleInfo; +using CompatibleInfo = cinn::hlir::framework::pir::CompatibleInfo; // OpTransInfo contains informations used to detect subgraphs // supported by the CINN compiler. diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc new file mode 100644 index 00000000000000..6025a3f7d1c3a9 --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +class FusedDropoutAddPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(), + {{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"seed", pat.Attr("seed")}, + {"fix_seed", pat.Attr("fix_seed")}}); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + + dropout({&pat.Tensor("x"), &pat.Tensor("seed_tensor")}, + {&pat.Tensor("dropout_out"), &pat.Tensor("mask")}); + pat.Tensor("add_out") = add(pat.Tensor("dropout_out"), pat.Tensor("y")); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &fused_dropout_add = + res.Op(paddle::dialect::FusedDropoutAddOp::name(), + {{{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"seed", pat.Attr("seed")}, + {"fix_seed", pat.Attr("fix_seed")}}}); + fused_dropout_add( + {&res.Tensor("x"), &res.Tensor("y"), &res.Tensor("seed_tensor")}, + {&res.Tensor("add_out"), &res.Tensor("mask")}); + } +}; + +class FusedDropoutGradAddGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(), + {{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"seed", pat.Attr("seed")}, + {"fix_seed", pat.Attr("fix_seed")}}); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + const auto &dropout_grad = pat.Op(paddle::dialect::DropoutGradOp::name(), + {{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}}); + + dropout({&pat.Tensor("x"), &pat.Tensor("seed_tensor")}, + {&pat.Tensor("dropout_out"), &pat.Tensor("mask")}); + pat.Tensor("add_out") = add(pat.Tensor("dropout_out"), pat.Tensor("y")); + add_grad({&pat.Tensor("dropout_out"), + &pat.Tensor("y"), + &pat.Tensor("add_out_grad")}, + {&pat.Tensor("dropout_out_grad"), &pat.Tensor("y_grad")}); + dropout_grad({&pat.Tensor("mask"), &pat.Tensor("dropout_out_grad")}, + {&pat.Tensor("x_grad")}); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &fused_dropout_add = + res.Op(paddle::dialect::FusedDropoutAddOp::name(), + {{{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"seed", pat.Attr("seed")}, + {"fix_seed", pat.Attr("fix_seed")}}}); + + const auto &fused_dropout_add_grad = + res.Op(paddle::dialect::FusedDropoutAddGradOp::name(), + {{{"p", pat.Attr("p")}, + {"is_test", pat.Attr("is_test")}, + {"mode", pat.Attr("mod")}, + {"fix_seed", pat.Attr("fix_seed")}}}); + + fused_dropout_add( + {&res.Tensor("x"), &res.Tensor("y"), &res.Tensor("seed_tensor")}, + {&res.Tensor("add_out"), &res.Tensor("mask")}); + fused_dropout_add_grad({&res.Tensor("mask"), &res.Tensor("add_out_grad")}, + {&res.Tensor("x_grad"), &res.Tensor("y_grad")}); + } +}; + +class FusedDropoutAddPass : public pir::Pass { + public: + FusedDropoutAddPass() : pir::Pass("fused_dropout_add_pass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + + ps.Add(FusedDropoutAddPattern().Build(context)); + ps.Add(FusedDropoutGradAddGradPattern().Build(context)); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateFusedDropoutAddPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(fused_dropout_add_pass, FusedDropoutAddPass); diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h new file mode 100644 index 00000000000000..3d78e6fe7b3b29 --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateFusedDropoutAddPass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index 0823867b444888..71fe6b6476302a 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h" - +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/api/drr_pattern_base.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" @@ -25,10 +25,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { public: void operator()(pir::drr::DrrPatternContext *ctx) const override { pir::drr::SourcePattern pat = ctx->SourcePattern(); - const auto &matmul = pat.Op("pd_op.matmul", + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); - const auto &add = pat.Op("pd_op.add"); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); @@ -43,10 +43,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { return "none"; }); - const auto &fused_gemm_epilogue = res.Op("pd_op.fused_gemm_epilogue", - {{{"trans_x", pat.Attr("trans_x")}, - {"trans_y", pat.Attr("trans_y")}, - {"activation", act_attr}}}); + const auto &fused_gemm_epilogue = + res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); fused_gemm_epilogue( {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out")}); @@ -58,14 +59,14 @@ class FusedLinearGradPattern public: void operator()(pir::drr::DrrPatternContext *ctx) const override { pir::drr::SourcePattern pat = ctx->SourcePattern(); - const auto &matmul = pat.Op("pd_op.matmul", + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); - const auto &matmul_grad = pat.Op("pd_op.matmul_grad", + const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); - const auto &add = pat.Op("pd_op.add"); - const auto &add_grad = pat.Op("pd_op.add_grad"); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); @@ -84,12 +85,13 @@ class FusedLinearGradPattern res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { return "none"; }); - const auto &fused_gemm_epilogue = res.Op("pd_op.fused_gemm_epilogue", - {{{"trans_x", pat.Attr("trans_x")}, - {"trans_y", pat.Attr("trans_y")}, - {"activation", act_attr}}}); + const auto &fused_gemm_epilogue = + res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); const auto &fused_gemm_epilogue_grad = - res.Op("pd_op.fused_gemm_epilogue_grad", + res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x")}, {"trans_y", pat.Attr("trans_y")}, {"activation_grad", act_attr}}}); @@ -112,19 +114,20 @@ class FusedLinearGeluGradPattern void operator()(pir::drr::DrrPatternContext *ctx) const override { pir::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fused_gemm_epilogue = - pat.Op("pd_op.fused_gemm_epilogue", + pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, {"trans_y", pat.Attr("trans_y1")}, {"activation", pat.Attr("act1")}}}); const auto &fused_gemm_epilogue_grad1 = - pat.Op("pd_op.fused_gemm_epilogue_grad", + pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x2")}, {"trans_y", pat.Attr("trans_y2")}, {"activation_grad", pat.Attr("act2")}}}); fused_gemm_epilogue( {&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")}, {&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")}); - pat.Tensor("out") = pat.Op("pd_op.gelu")(pat.Tensor("fuse_out")); + pat.Tensor("out") = + pat.Op(paddle::dialect::GeluOp::name())(pat.Tensor("fuse_out")); fused_gemm_epilogue_grad1({&pat.Tensor("x1"), &pat.Tensor("w1"), @@ -133,8 +136,8 @@ class FusedLinearGeluGradPattern {&pat.Tensor("x1_grad"), &pat.Tensor("w1_grad"), &pat.Tensor("bias1_grad")}); - pat.Tensor("gelu_dx") = pat.Op("pd_op.gelu_grad")(pat.Tensor("fuse_out"), - pat.Tensor("x1_grad")); + pat.Tensor("gelu_dx") = pat.Op(paddle::dialect::GeluGradOp::name())( + pat.Tensor("fuse_out"), pat.Tensor("x1_grad")); pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { return match_ctx.Attr("act1") == "none" && @@ -147,7 +150,7 @@ class FusedLinearGeluGradPattern return "gelu"; }); const auto &fused_gemm_epilogue_new = - res.Op("pd_op.fused_gemm_epilogue", + res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, {"trans_y", pat.Attr("trans_y1")}, {"activation", act_attr}}}); @@ -156,7 +159,7 @@ class FusedLinearGeluGradPattern return "gelu_grad"; }); const auto &fused_gemm_epilogue_grad_new = - res.Op("pd_op.fused_gemm_epilogue_grad", + res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x2")}, {"trans_y", pat.Attr("trans_y2")}, {"activation_grad", act_grad_attr}}}); @@ -179,17 +182,17 @@ class FusedLinearReluGradPattern void operator()(pir::drr::DrrPatternContext *ctx) const override { pir::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fused_gemm_epilogue = - pat.Op("pd_op.fused_gemm_epilogue", + pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, {"trans_y", pat.Attr("trans_y1")}, {"activation", pat.Attr("act1")}}}); const auto &fused_gemm_epilogue_grad = - pat.Op("pd_op.fused_gemm_epilogue_grad", + pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x2")}, {"trans_y", pat.Attr("trans_y2")}, {"activation_grad", pat.Attr("act2")}}}); const auto &fused_gemm_epilogue_grad1 = - pat.Op("pd_op.fused_gemm_epilogue_grad", + pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x3")}, {"trans_y", pat.Attr("trans_y3")}, {"activation_grad", pat.Attr("act3")}}}); @@ -226,7 +229,7 @@ class FusedLinearReluGradPattern return "relu"; }); const auto &fused_gemm_epilogue_new = - res.Op("pd_op.fused_gemm_epilogue", + res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, {"trans_y", pat.Attr("trans_y1")}, {"activation", act_attr}}}); @@ -235,7 +238,7 @@ class FusedLinearReluGradPattern return "relu_grad"; }); const auto &fused_gemm_epilogue_grad1_new = - res.Op("pd_op.fused_gemm_epilogue_grad", + res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(), {{{"trans_x", pat.Attr("trans_x2")}, {"trans_y", pat.Attr("trans_y2")}, {"activation_grad", act_grad_attr}}}); diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc new file mode 100644 index 00000000000000..ccc8a58c24aef5 --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -0,0 +1,370 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +namespace { + +// add_grad + matmul_grad + add_ -> matmul + fused_liner_param_gard_add +class FusedMatmulAddGradAddPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + + add_grad( + {&pat.Tensor("out"), &pat.Tensor("bias"), &pat.Tensor("addout_grad")}, + {&pat.Tensor("out_grad"), &pat.Tensor("dbias")}); + matmul_grad( + {&pat.Tensor("x"), &pat.Tensor("weight"), &pat.Tensor("out_grad")}, + {&pat.Tensor("x_grad"), &pat.Tensor("weight_grad")}); + pat.Tensor("add_out") = + add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + const auto &x_trans = match_ctx.Attr("trans_x"); + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape() && + match_ctx.Tensor("out").Shape() == + match_ctx.Tensor("addout_grad").Shape() && + x_trans == false); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &dx_matmul_trans_y_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + const auto &y_trans = match_ctx.Attr("trans_y"); + return !y_trans; + }); + + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &false_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + + const auto &matmul = res.Op( + paddle::dialect::MatmulOp::name(), + {{"transpose_x", false_attr}, {"transpose_y", dx_matmul_trans_y_attr}}); + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + + matmul({&res.Tensor("addout_grad"), &res.Tensor("weight")}, + {&res.Tensor("x_grad")}); + fused_linear_param_grad_add({&res.Tensor("x"), + &res.Tensor("addout_grad"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("add_out"), &res.Tensor("dbias")}); + } +}; + +// matmul_grad + add_ -> matmul + fused_liner_param_gard_add +class FusedMatmulGradAddPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + + matmul_grad( + {&pat.Tensor("x"), &pat.Tensor("weight"), &pat.Tensor("out_grad")}, + {&pat.Tensor("x_grad"), &pat.Tensor("weight_grad")}); + pat.Tensor("add_out") = + add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + const auto &x_trans = match_ctx.Attr("trans_x"); + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape() && + x_trans == false); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &dx_matmul_trans_y_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + const auto &y_trans = match_ctx.Attr("trans_y"); + return !y_trans; + }); + + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &false_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + + const auto &matmul = res.Op( + paddle::dialect::MatmulOp::name(), + {{"transpose_x", false_attr}, {"transpose_y", dx_matmul_trans_y_attr}}); + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + + matmul({&res.Tensor("out_grad"), &res.Tensor("weight")}, + {&res.Tensor("x_grad")}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("out_grad"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); + } +}; + +// matmul + 0 = add_(0,1) -> fused_liner_param_gard_add +class FusedMatmulAddaPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + + matmul({&pat.Tensor("x"), &pat.Tensor("out_grad")}, + {&pat.Tensor("weight_grad")}); + pat.Tensor("add_out") = + add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape()); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &false_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("out_grad"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); + } +}; + +// matmul + 1 = add_(1,0) -> fused_liner_param_gard_add +class FusedMatmulAddbPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + + matmul({&pat.Tensor("x"), &pat.Tensor("out_grad")}, + {&pat.Tensor("weight_grad")}); + pat.Tensor("add_out") = + add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape()); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &false_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", false_attr}}}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("out_grad"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); + } +}; + +// add_grad + matmul + 0 = add_(0,1) -> fused_liner_param_gard_add +class FusedMatmulAddGradAddaPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + add_grad({&pat.Tensor("out"), &pat.Tensor("bias"), &pat.Tensor("dadd_out")}, + {&pat.Tensor("dout"), &pat.Tensor("dbias")}); + matmul({&pat.Tensor("x"), &pat.Tensor("dout")}, + {&pat.Tensor("weight_grad")}); + pat.Tensor("dweight_out") = + add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape() && + match_ctx.Tensor("out").Shape() == + match_ctx.Tensor("dadd_out").Shape()); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("dadd_out"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); + } +}; + +// add_grad + matmul + 1 = add_(1,0) -> fused_liner_param_gard_add +class FusedMatmulAddGradAddbPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name()); + const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add_ = pat.Op(paddle::dialect::Add_Op::name()); + add_grad({&pat.Tensor("out"), &pat.Tensor("bias"), &pat.Tensor("dadd_out")}, + {&pat.Tensor("dout"), &pat.Tensor("dbias")}); + matmul({&pat.Tensor("x"), &pat.Tensor("dout")}, + {&pat.Tensor("weight_grad")}); + pat.Tensor("dweight_out") = + add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("weight_grad").Shape() == + match_ctx.Tensor("dweight").Shape() && + match_ctx.Tensor("out").Shape() == + match_ctx.Tensor("dadd_out").Shape()); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &muti_precision_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + return !(match_ctx.Tensor("dweight").Dtype() == + match_ctx.Tensor("weight_grad").Dtype()); + }); + const auto &true_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &fused_linear_param_grad_add = res.Op( + paddle::dialect::FusedLinearParamGradAddOp::name(), + {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); + fused_linear_param_grad_add( + {&res.Tensor("x"), + &res.Tensor("dadd_out"), + &res.Tensor("dweight"), + &res.NoneTensor()}, + {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); + } +}; + +class FusedLinearParamGradAddPass : public pir::Pass { + public: + FusedLinearParamGradAddPass() + : pir::Pass("fused_linear_param_grad_add_pass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(FusedMatmulAddGradAddPattern().Build(context)); + ps.Add(FusedMatmulGradAddPattern().Build(context)); + ps.Add(FusedMatmulAddaPattern().Build(context)); + ps.Add(FusedMatmulAddbPattern().Build(context)); + ps.Add(FusedMatmulAddGradAddaPattern().Build(context)); + ps.Add(FusedMatmulAddGradAddbPattern().Build(context)); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateFusedLinearParamGradAddPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(fused_linear_param_grad_add_pass, FusedLinearParamGradAddPass); diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h new file mode 100644 index 00000000000000..f4b17e8993a187 --- /dev/null +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateFusedLinearParamGradAddPass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index 760a78c1952ab1..a15bbebeef5d02 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include +#include +#include + #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" @@ -21,11 +24,16 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/phi/core/flags.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/operation.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" +PHI_DECLARE_string(ir_inplace_kernel_blacklist); + namespace details { // NOTE(zhangbo): Which kind of value can be deleted? // (1) Value's type needs to be AllocatedDenseTensorType or @@ -53,7 +61,44 @@ static bool CanBeDeleted(pir::Value value) { static bool CanDoInplace(const std::unordered_set& eager_dels, pir::Value input, pir::Value output) { - if (input.type() != output.type()) { + if (!input.type() || !output.type()) { + return false; + } + + if (input.type().isa() && + output.type().isa()) { + auto input_alloc_tensor_type = + input.type().dyn_cast(); + auto output_alloc_tensor_type = + output.type().dyn_cast(); + + if (input_alloc_tensor_type.dtype() != output_alloc_tensor_type.dtype()) { + VLOG(9) << " -- input's dtype != output's dtype, can't do inplace"; + return false; + } + + int64_t in_numel = 1; + int64_t out_numel = 1; + for (int i = 0; i < input_alloc_tensor_type.dims().size(); i++) { + if (input_alloc_tensor_type.dims()[i] == -1) { + VLOG(9) << " -- input's shape has -1, can't do inplace"; + return false; + } + in_numel *= input_alloc_tensor_type.dims()[i]; + } + + for (int i = 0; i < output_alloc_tensor_type.dims().size(); i++) { + if (output_alloc_tensor_type.dims()[i] == -1) { + VLOG(9) << " -- output's shape has -1, can't do inplace"; + return false; + } + out_numel *= output_alloc_tensor_type.dims()[i]; + } + if (in_numel != out_numel) { + VLOG(9) << " -- input's numel != output's numel, can't do inplace"; + return false; + } + } else if (input.type() != output.type()) { VLOG(9) << " -- input's type != output's type, can't do inplace"; return false; } @@ -140,21 +185,19 @@ static void GetEagerDelValueOfOp( for (size_t i = 0; i < op->num_operands(); ++i) { auto input = op->operand_source(i); - if (skip_dels.count(input) > 0 || !input || !CanBeDeleted(input) || - IsNoNeedBuffer(op, input)) { + if (skip_dels.count(input) > 0 || !input || !CanBeDeleted(input)) { VLOG(6) << "The " << i << "-th input value of the Operation(" << upper_op_name << ") can not be deleted."; VLOG(8) << " -- skip dels: " << skip_dels.count(input); VLOG(8) << " -- value is null: " << !input; VLOG(8) << " -- can be deleted: " << !CanBeDeleted(input); - VLOG(8) << " -- is no_need_buffer: " << IsNoNeedBuffer(op, input); continue; } (*del_value_2_op)[input] = op; } - for (size_t i = 0; i < op->num_results(); ++i) { - pir::Value output = op->result(i); + for (auto& result : op->results()) { + pir::Value output = result; if (output && CanBeDeleted(output)) { (*del_value_2_op)[output] = op; } @@ -206,8 +249,8 @@ static std::unordered_map GetInplaceOps( VLOG(6) << op->name() << "is not a kernel_dialect op, inplace only support " "kernel_dialect operators"; - for (size_t i = 0; i < op->num_results(); ++i) { - visited_values.insert(op->result(i)); + for (auto& result : op->results()) { + visited_values.insert(result); } continue; } @@ -226,8 +269,8 @@ static std::unordered_map GetInplaceOps( .dyn_cast() .data() .backend() == phi::Backend::CPU)) { - for (size_t i = 0; i < op->num_results(); ++i) { - visited_values.insert(op->result(i)); + for (auto& result : op->results()) { + visited_values.insert(result); } continue; } @@ -238,9 +281,9 @@ static std::unordered_map GetInplaceOps( for (size_t i = 0; i < op->num_operands(); ++i) { reused_input_values.insert(op->operand_source(i)); } - for (size_t i = 0; i < op->num_results(); ++i) { - reused_output_values.insert(op->result(i)); - visited_values.insert(op->result(i)); + for (auto& result : op->results()) { + reused_output_values.insert(result); + visited_values.insert(result); } continue; } @@ -248,7 +291,16 @@ static std::unordered_map GetInplaceOps( pir::OpInfo upper_inplace_op_info = pir::IrContext::Instance()->GetRegisteredOpInfo(upper_op_name + "_"); - if (eager_dels.count(op) == 0 || (!upper_inplace_op_info)) { + std::regex reg(","); + std::unordered_set elems{ + std::sregex_token_iterator(FLAGS_ir_inplace_kernel_blacklist.begin(), + FLAGS_ir_inplace_kernel_blacklist.end(), + reg, + -1), + std::sregex_token_iterator()}; + elems.erase(""); + + if (elems.count(upper_op_name)) { VLOG(6) << upper_op_name << "'s value can't delete or doesn't have inplace op, so that " "can't do inplace."; @@ -257,6 +309,19 @@ static std::unordered_map GetInplaceOps( } continue; } + if (eager_dels.count(op) == 0 || (!upper_inplace_op_info) || + upper_op_name == "pd_op.transpose") { + // NOTE(wanghuancoder): pd_op.transpose is not an + // inplace op, only strided transpose support + // inplace in dygraph + VLOG(6) << upper_op_name + << "'s value can't delete or doesn't have inplace op, so that " + "can't do inplace."; + for (auto& result : op->results()) { + visited_values.insert(result); + } + continue; + } auto upper_inplace_op_interface = upper_inplace_op_info @@ -310,8 +375,13 @@ static std::unordered_map GetInplaceOps( << " will change to inplace version op: " << upper_op_name + "_"; } - for (size_t i = 0; i < op->num_results(); ++i) { - visited_values.insert(op->result(i)); + for (auto& result : op->results()) { + visited_values.insert(result); + } + } + if (!FLAGS_ir_inplace_kernel_blacklist.empty()) { + for (auto i : inplace_ops) { + std::cout << i.second << std::endl; } } return inplace_ops; diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 3ac3db56cfd41d..deffc55035bf06 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -74,7 +74,6 @@ const std::unordered_set SpecialLowerOps = { "pd_op.if", "pd_op.while", "cf.yield", - "cf.cond_yield", "cinn_runtime.jit_kernel"}; bool NeedFallBackCpu(const pir::Operation* op, @@ -679,10 +678,34 @@ phi::KernelKey GetKernelKey( phi::KernelKey res(kernel_backend, kernel_layout, kernel_data_type); + // kernel backend infered incorrectly from memcpy op operands, + // case that place from (not GPU) to GPU. + // We handle this special case by following code to fix up the problem. + // This could be further improved if we had another method. + if (!platform::is_gpu_place(place)) { + if (op->isa()) { + VLOG(6) << "MemcpyOp need a special handle"; + int dst_place_type = op->attribute("dst_place_type") + .dyn_cast() + .data(); + if (dst_place_type == 1) { + res.set_backend(phi::Backend::GPU); + } + } + } + if (op->isa()) { res.set_dtype(phi::DataType::FLOAT32); VLOG(8) << "LoadCombineOp's kernel data type must be FLOAT32"; } + + if (op->isa() || + op->isa()) { + res.set_dtype(phi::DataType::FLOAT32); + VLOG(8) << "CSyncCommStream_Op/CSyncCommStreamOp's kernel data type must " + "be FLOAT32"; + } + if (NeedFallBackCpu((op), kernel_fn_str, res)) { res.set_backend(phi::Backend::CPU); VLOG(8) << "kernel backend must be on CPU when need fallback"; @@ -815,6 +838,15 @@ void HandleForWhileOp( ctx, map_op_pair, map_value_pair); + + (*map_op_pair)[op_item] = new_while_op; + + // only deal with single output + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + (*map_value_pair)[op_item->result(i)] = new_while_op->result(i); + } + } } pir::Value GetNewInput( diff --git a/paddle/fluid/platform/device/gpu/nccl_helper.h b/paddle/fluid/platform/device/gpu/nccl_helper.h index 6afcd2eb7cd972..8afcfc9f2b7005 100644 --- a/paddle/fluid/platform/device/gpu/nccl_helper.h +++ b/paddle/fluid/platform/device/gpu/nccl_helper.h @@ -32,6 +32,8 @@ #ifdef PADDLE_WITH_RCCL #include "paddle/fluid/platform/dynload/rccl.h" #endif +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/fluid/platform/device_event_base.h b/paddle/fluid/platform/device_event_base.h index 03fd7d4bb13f05..828b54c44a2dd3 100644 --- a/paddle/fluid/platform/device_event_base.h +++ b/paddle/fluid/platform/device_event_base.h @@ -18,6 +18,7 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +#include "paddle/utils/test_macros.h" namespace paddle { namespace platform { @@ -213,7 +214,7 @@ struct EventCreateFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventCreateFunctionRegisterer \ __reg_event_create_##device_type##__(func); \ - int TouchDeviceEventCreate##device_type() { \ + TEST_API int TouchDeviceEventCreate##device_type() { \ __reg_event_create_##device_type##__.Touch(); \ return 0; \ } @@ -233,7 +234,7 @@ struct EventRecordFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventRecordFunctionRegisterer \ __reg_event_record_##device_type##__(func); \ - int TouchDeviceEventRecord##device_type() { \ + TEST_API int TouchDeviceEventRecord##device_type() { \ __reg_event_record_##device_type##__.Touch(); \ return 0; \ } @@ -253,7 +254,7 @@ struct EventQueryFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventQueryFunctionRegisterer \ __reg_event_query_##device_type##__(func); \ - int TouchDeviceEventQuery##device_type() { \ + TEST_API int TouchDeviceEventQuery##device_type() { \ __reg_event_query_##device_type##__.Touch(); \ return 0; \ } @@ -273,7 +274,7 @@ struct EventFinishFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventFinishFunctionRegisterer \ __reg_event_finish_##device_type##__(func); \ - int TouchDeviceEventFinish##device_type() { \ + TEST_API int TouchDeviceEventFinish##device_type() { \ __reg_event_finish_##device_type##__.Touch(); \ return 0; \ } @@ -293,7 +294,7 @@ struct EventSetFinishedFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventSetFinishedFunctionRegisterer \ __reg_event_finished_setter_##device_type##__(func); \ - int TouchDeviceEventSetFinished##device_type() { \ + TEST_API int TouchDeviceEventSetFinished##device_type() { \ __reg_event_finished_setter_##device_type##__.Touch(); \ return 0; \ } @@ -315,7 +316,7 @@ struct EventWaitFunctionRegisterer : public framework::Registrar { static ::paddle::platform::EventWaitFunctionRegisterer \ __reg_event_wait_##waiter_type##event_type##__(func); \ - int TouchDeviceEventWait##waiter_type##event_type() { \ + TEST_API int TouchDeviceEventWait##waiter_type##event_type() { \ __reg_event_wait_##waiter_type##event_type##__.Touch(); \ return 0; \ } @@ -335,7 +336,7 @@ struct EventResetFunctionRegisterer : public framework::Registrar { "REGISTER_EVENT_RESET_FUNCTION must be called in global namespace"); \ static ::paddle::platform::EventResetFunctionRegisterer \ __reg_event_resetter_##device_type##__(func); \ - int TouchDeviceEventReset##device_type() { \ + TEST_API int TouchDeviceEventReset##device_type() { \ __reg_event_resetter_##device_type##__.Touch(); \ return 0; \ } diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h index c2052719dd56c3..d9516c9f4de4e8 100644 --- a/paddle/fluid/platform/dynload/nccl.h +++ b/paddle/fluid/platform/dynload/nccl.h @@ -31,6 +31,7 @@ namespace dynload { __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ __macro(ncclCommCuDevice); \ @@ -42,6 +43,7 @@ namespace dynload { __macro(ncclGroupEnd); \ __macro(ncclReduce); \ __macro(ncclReduceScatter); \ + __macro(ncclCommGetAsyncError); \ __macro(ncclGetErrorString); NCCL_RAND_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_NCCL_WRAP) diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 64c431b3d237fe..6e12d6fa464cc7 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -1212,8 +1212,8 @@ void scatter_grad(const Tensor& index, template void batch_norm_grad(const Tensor& x, - const Tensor& scale, - const Tensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean_out, const paddle::optional& variance_out, const Tensor& saved_mean, @@ -1306,14 +1306,20 @@ void batch_norm_grad(const Tensor& x, if (x_grad) { if (use_global_stats) { - auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad; + auto nhwc_x_grad = rsqrt_var * nhwc_out_grad; + if (scale) { + nhwc_x_grad = scale.get() * nhwc_x_grad; + } auto nchw_x_grad = transpose(nhwc_x_grad, nhwc_to_nchw_dim); if (need_cast) { nchw_x_grad = cast(nchw_x_grad, x.dtype()); } set_output(nchw_x_grad, x_grad); } else { - auto part1 = scale * rsqrt_var; + auto part1 = rsqrt_var; + if (scale) { + part1 = scale.get() * part1; + } auto mean_temp1 = nhwc_out_grad_sum / nhw; auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; auto part2 = @@ -1343,14 +1349,19 @@ void batch_norm_grad(const Tensor& x, auto nhwc_sum_dout_mul_diff = sum( out_grad_data * (x_data - mean_data), reduce_axis, dtype, false); if (use_global_stats) { - auto x_grad_data = scale * rsqrt_var * out_grad_data; + auto x_grad_data = rsqrt_var * out_grad_data; + if (scale) { + x_grad_data = scale.get() * x_grad_data; + } if (need_cast) { x_grad_data = cast(x_grad_data, x.dtype()); } set_output(x_grad_data, x_grad); } else { - auto part1 = scale * rsqrt_var; - + auto part1 = rsqrt_var; + if (scale) { + part1 = scale.get() * part1; + } auto mean_temp1 = out_grad_data_sum / nhw; auto mean_temp2 = nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index 02e6c58f97af63..19f73bcf0eda3d 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -12,6 +12,7 @@ #include "paddle/pir/core/operation.h" #include "paddle/phi/core/flags.h" #include "paddle/utils/optional.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" PHI_DECLARE_string(tensor_operants_mode); @@ -57,7 +58,7 @@ if({{i.name}}_define_op->name() != "pd_op.full_int_array"){ "We don't support dynamic tensors attribute {{i.name}} for {{api_name}} composite " "for now. ")); } -auto {{i.name}} = {{i.name}}_define_op->attribute("value").dyn_cast().data(); +auto {{i.name}} = phi::IntArray(paddle::dialect::GetInt64Vector({{i.name}}_define_op->attribute("value"))); {% endif %} {% endif %} {% endfor %} diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 2dfeb89bef5c42..cc5cb3f326bf65 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -42,6 +42,7 @@ set(PYBIND_DEPS pd_op_dialect program_translator pd_inplace_pass + fusion_passes pir new_profiler jit_layer diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 09d76e33d69c1e..785e80a3abeaab 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -35,6 +35,7 @@ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" +#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/core/dense_tensor.h" @@ -571,33 +572,7 @@ void BindAutoParallel(py::module *m) { "reshard", [](py::handle py_tensor, const TensorDistAttr &dist_attr) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); - auto dev_ctx = phi::DeviceContextPool::Instance().Get(tensor.place()); - std::shared_ptr dist_out_ptr = nullptr; - if (phi::distributed::DistTensor::classof(tensor.impl().get())) { - auto tensor_in = tensor.impl(); - if (tensor_in) { - phi::distributed::DistTensor *dist_tensor = - static_cast(tensor_in.get()); - if (dist_tensor->dist_attr() != dist_attr) { - VLOG(6) << "reshard func, reshard tensor from " - << dist_tensor->dist_attr() << " to " << dist_attr; - auto *func = phi::distributed::ChooseProperReshardFunction( - *dist_tensor, dist_attr); - dist_out_ptr = func->Eval(dev_ctx, *dist_tensor, dist_attr); - } else { - dist_out_ptr = - std::static_pointer_cast( - tensor_in); - } - } - return paddle::Tensor(dist_out_ptr); - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "The input tensor of shard function should be " - "``phi::distributed::DistTensor``. " - "However it's %s", - typeid(tensor.impl().get()).name())); - } + return reshard_ad_function(tensor, dist_attr); }, py::return_value_policy::reference); diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index 82408b5236936e..64ff801d464f4c 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -38,6 +38,9 @@ namespace paddle { namespace pybind { void BindCommContextManager(py::module *m) { + auto P2POption = py::class_(*m, "P2POption") + .def(py::init<>()); + auto CommContextManager = py::class_>( @@ -49,6 +52,12 @@ void BindCommContextManager(py::module *m) { .def_static( "create_nccl_comm_context", &phi::distributed::CommContextManager::CreateNCCLCommContext, + py::arg("store"), + py::arg("unique_comm_key"), + py::arg("rank"), + py::arg("size"), + py::arg("hash_key") = "", + py::arg("p2p_opt") = nullptr, py::call_guard()) #endif #if defined(PADDLE_WITH_GLOO) diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 259aa1f5dac493..5c492815f108d0 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -23,7 +23,6 @@ limitations under the License. */ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/reducer.h" -#include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/imperative/layer.h" @@ -31,6 +30,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/process_group_utils.h" #include "paddle/phi/api/all.h" +#include "paddle/phi/core/distributed/types.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/process_group_nccl.h" @@ -265,8 +265,8 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto in_dense = *p_in_tensor; - auto *dev_ctx = self.GetDeviceContext(in_tensor.place()); auto task = self.AllGather(out_dense, in_dense, sync_op); + auto *dev_ctx = self.GetDeviceContext(in_tensor.place()); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); task->UpdateWaitChain(*dev_ctx); return task; @@ -320,8 +320,6 @@ void BindDistributed(py::module *m) { auto in_dense = *p_in_tensor; // in_tensor_list should not be empty - auto *dev_ctx = - self.GetDeviceContext(in_tensor_list.back().place()); int world_size = self.GetSize(); auto task = self.AllToAll(out_dense, @@ -329,6 +327,8 @@ void BindDistributed(py::module *m) { GetDefaultSplitSizes(*out_dense, world_size), GetDefaultSplitSizes(in_dense, world_size), sync_op); + auto *dev_ctx = + self.GetDeviceContext(in_tensor_list.back().place()); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); task->UpdateWaitChain(*dev_ctx); return task; @@ -542,11 +542,11 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto in_dense = *p_in_tensor; - auto *dev_ctx = - self.GetDeviceContext(in_tensor.place(), use_calc_stream); distributed::GatherOptions gather_opts{dst}; auto task = self.Gather( out_dense, in_dense, gather_opts, sync_op, use_calc_stream); + auto *dev_ctx = + self.GetDeviceContext(in_tensor.place(), use_calc_stream); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); if (!use_calc_stream && dev_ctx->GetPlace() != platform::CPUPlace()) { @@ -582,8 +582,7 @@ void BindDistributed(py::module *m) { opts.reduce_op = op; auto dense = std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.AllReduce(tensors, tensors, opts); + return self.AllReduce(dense.get(), *dense, opts, false); }, py::arg("tensor"), py::arg("op") = distributed::ReduceOp::SUM, @@ -599,8 +598,7 @@ void BindDistributed(py::module *m) { opts.source_rank = source_rank; auto dense = std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Broadcast(tensors, tensors, opts); + return self.Broadcast(dense.get(), *dense, opts, false); }, py::arg("tensor"), py::arg("source_rank"), @@ -614,8 +612,7 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Send(tensors, dst); + return self.Send(*dense, dst, false); }, py::arg("tensor"), py::arg("dst"), @@ -629,8 +626,7 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Recv(tensors, src); + return self.Recv(dense.get(), src, false); }, py::arg("tensor"), py::arg("src"), @@ -647,9 +643,7 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto out_dense = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.AllGather(in_tensors, out_tensors); + return self.AllGather(out_dense.get(), *in_dense, false); }, py::arg("in"), py::arg("out"), @@ -695,9 +689,14 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto out_dense = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.AllToAll(in_tensors, out_tensors); + + int world_size = self.GetSize(); + return self.AllToAll( + out_dense.get(), + *in_dense, + GetDefaultSplitSizes(*out_dense, world_size), + GetDefaultSplitSizes(*in_dense, world_size), + false); }, py::arg("in"), py::arg("out"), @@ -741,8 +740,7 @@ void BindDistributed(py::module *m) { opts.root_rank = dst; auto dense = std::dynamic_pointer_cast( in_tensor.impl()); - std::vector tensors = {*dense}; - return self.Reduce(tensors, tensors, opts); + return self.Reduce(dense.get(), *dense, opts, false); }, py::arg("tensor"), py::arg("dst"), @@ -763,9 +761,7 @@ void BindDistributed(py::module *m) { in_tensor.impl()); auto out_dense = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.Scatter(in_tensors, out_tensors, opts); + return self.Scatter(out_dense.get(), *in_dense, opts, false); }, py::arg("in"), py::arg("out"), @@ -788,12 +784,11 @@ void BindDistributed(py::module *m) { auto p_in_tensor = std::dynamic_pointer_cast( in_tensor.impl()); auto in_dense = *p_in_tensor; - - auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true); auto task = self.AllGather(out_dense, in_dense, /*sync_op*/ true, /*use_calc_stream*/ true); + auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); return task; }, @@ -900,8 +895,6 @@ void BindDistributed(py::module *m) { auto in_dense = *p_in_tensor; // in_tensor_list should not be empty - auto *dev_ctx = self.GetDeviceContext( - in_tensor_list.back().place(), /*use_calc_stream*/ true); int world_size = self.GetSize(); auto task = self.AllToAll(out_dense, @@ -910,6 +903,8 @@ void BindDistributed(py::module *m) { GetDefaultSplitSizes(in_dense, world_size), /*sync_op*/ true, /*use_calc_stream*/ true); + auto *dev_ctx = self.GetDeviceContext( + in_tensor_list.back().place(), /*use_calc_stream*/ true); SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); return task; }, @@ -1239,6 +1234,7 @@ void BindDistributed(py::module *m) { py::arg("rank"), py::arg("world_size"), py::arg("group_id") = 0, + py::arg("timeout") = 30 * 60 * 1000, py::call_guard()) .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart) .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd); diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index a30f01084a060f..fb5fd57e26255d 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -220,20 +220,46 @@ void InitTensorWithNumpyValue(TensorObject* self, "EmptyTensorInitializer is " "forbidden. Please check your code and make sure you new a " "eager tensor before init it with NumPy.")); + phi::DenseTensor* impl_ptr = static_cast(self->tensor.impl().get()); - if (platform::is_cpu_place(place)) { SetTensorFromPyArray(impl_ptr, array, place, zero_copy); } else if (platform::is_xpu_place(place)) { +#if defined(PADDLE_WITH_XPU) + phi::backends::xpu::SetXPUDeviceId(place.device); + VLOG(4) << "CurrentDeviceId: " + << phi::backends::xpu::GetXPUCurrentDeviceId() << " from " + << static_cast(place.device); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with XPU if use XPUPlace.")); +#endif SetTensorFromPyArray(impl_ptr, array, place, zero_copy); } else if (platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + phi::backends::gpu::SetDeviceId(place.device); + VLOG(4) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() + << " from " << static_cast(place.device); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU if use CUDAPlace.")); +#endif SetTensorFromPyArray( impl_ptr, array, place, zero_copy); } else if (platform::is_cuda_pinned_place(place)) { SetTensorFromPyArray( impl_ptr, array, place, zero_copy); } else if (platform::is_custom_place(place)) { +#if defined(PADDLE_WITH_CUSTOM_DEVICE) + phi::DeviceManager::SetDevice(place); + VLOG(4) << "CurrentDeviceId: " + << phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from " + << static_cast(place.device); +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace.")); +#endif SetTensorFromPyArray( impl_ptr, array, place, zero_copy); } else { @@ -455,7 +481,7 @@ std::string ParseName(std::unordered_map kws_map, } } else { if (flag_kwargs) { - if ((kws_map["name"] == nullptr) || (kws_map["name"] == Py_None)) { + if ((kws_map["name"] == NULL) || (kws_map["name"] == Py_None)) { act_name = egr::Controller::Instance().GenerateUniqueName(unique_name_prefix); } else { diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index 85afc274623ea5..8552f1e7208b8c 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -28,14 +28,26 @@ static PyObject *eager_api_linear(PyObject *self, auto x = GetTensorFromArgs("linear", "X", args, 0, false); auto weight = GetTensorFromArgs("linear", "weight", args, 1, false); auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true); + tstate = PyEval_SaveThread(); + if (bias.initialized()) { + const phi::distributed::ProcessMesh *mesh = nullptr; + if (InputsContainDistTensor(&mesh, x, weight, bias)) { + ConvertAllInputsToDistTensor(mesh, x, weight, bias); + } + auto mm_out = matmul_ad_func(x, weight, false, false); auto out = add_ad_func(mm_out, bias); PyEval_RestoreThread(tstate); tstate = nullptr; return ToPyObject(out); } else { + const phi::distributed::ProcessMesh *mesh = nullptr; + if (InputsContainDistTensor(&mesh, x, weight)) { + ConvertAllInputsToDistTensor(mesh, x, weight); + } + auto mm_out = matmul_ad_func(x, weight, false, false); PyEval_RestoreThread(tstate); tstate = nullptr; diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 199d05d2c98007..5c9f6b6e8a9452 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -893,10 +893,17 @@ static PyObject* tensor_clear_gradient(TensorObject* self, selected_rows->mutable_rows()->clear(); selected_rows->mutable_value()->clear(); } - } else if (grad->is_dense_tensor()) { + } else if (grad->is_dense_tensor() || grad->is_dist_tensor()) { if (grad->initialized()) { + phi::DenseTensor* grad_t = nullptr; + if (grad->is_dense_tensor()) { + grad_t = static_cast(grad->impl().get()); + } else { + grad_t = + static_cast(grad->impl().get()) + ->unsafe_mutable_value(); + } if (set_to_zero) { - auto* grad_t = static_cast(grad->impl().get()); auto* dev_ctx = platform::DeviceContextPool::Instance().Get(grad_t->place()); phi::funcs::set_constant(*dev_ctx, grad_t, 0.0); @@ -908,9 +915,7 @@ static PyObject* tensor_clear_gradient(TensorObject* self, } else { VLOG(4) << "Gradient of " << self->tensor.name() << " is initialized, will be released."; - auto dense_tensor = - std::dynamic_pointer_cast(grad->impl()); - dense_tensor->MoveMemoryHolder(); + grad_t->MoveMemoryHolder(); } } } diff --git a/paddle/fluid/pybind/eval_frame.c b/paddle/fluid/pybind/eval_frame.c index 5b4f216be24dc7..6a647ae50818f1 100644 --- a/paddle/fluid/pybind/eval_frame.c +++ b/paddle/fluid/pybind/eval_frame.c @@ -458,6 +458,7 @@ inline static PyObject *eval_custom_code_py311_plus(PyThreadState *tstate, // Create a new function object from code object. Refer to MAKE_FUNCTION. PyFunctionObject *func = (PyFunctionObject *)PyFunction_New((PyObject *)code, frame->f_globals); + Py_INCREF(func); #if PY_VERSION_HEX < 0x030c0000 Py_XINCREF(frame->f_func->func_closure); func->func_closure = frame->f_func->func_closure; diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 019b5098feb75f..39c22f9301457c 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -855,6 +855,9 @@ void BindAnalysisConfig(py::module *m) { .def("enable_memory_optim", &AnalysisConfig::EnableMemoryOptim, py::arg("x") = true) + .def("enable_new_executor", + &AnalysisConfig::EnableNewExecutor, + py::arg("x") = true) .def("enable_profile", &AnalysisConfig::EnableProfile) .def("disable_glog_info", &AnalysisConfig::DisableGlogInfo) .def("glog_info_disabled", &AnalysisConfig::glog_info_disabled) diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 25f6936ab1c386..5fb99070324917 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -36,6 +36,8 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" +#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" #include "paddle/fluid/pir/transforms/inplace_pass.h" #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/block.h" @@ -66,6 +68,10 @@ using pir::Value; using pybind11::return_value_policy; USE_PASS(dead_code_elimination_pass); +USE_PASS(attention_fuse_pass); +USE_PASS(fused_gemm_epilogue_pass); +USE_PASS(fused_dropout_add_pass); +USE_PASS(fused_linear_param_grad_add_pass); USE_PASS(inplace_pass); PHI_DECLARE_bool(print_ir); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 0308d06d9305e1..680064ee615e1b 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -191,6 +191,7 @@ void BindBlockDesc(pybind11::module *m) { std::string name = byte_name; return self.HasVarRecursive(name); }) + .def("set_parent_idx", &pd::BlockDesc::SetParent) .def( "find_var", [](pd::BlockDesc &self, pybind11::bytes byte_name) { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index dcae0104f35598..7d676ef6e189cd 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -695,6 +695,8 @@ void BindVjp(pybind11::module *m) { m->def( "call_vjp", [](pir::Operation &fwd_op, + const std::vector> &inputs, + const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients) { py::list res; @@ -704,8 +706,8 @@ void BindVjp(pybind11::module *m) { vjp_interface, phi::errors::InvalidArgument( "The vjp function is not registered in %s op ", fwd_op.name())); - std::vector> vjp_res = - vjp_interface.Vjp(&fwd_op, out_grads, stop_gradients); + std::vector> vjp_res = vjp_interface.Vjp( + &fwd_op, inputs, outputs, out_grads, stop_gradients); PADDLE_ENFORCE_EQ( stop_gradients.size(), vjp_res.size(), diff --git a/paddle/phi/api/include/tensor_utils.h b/paddle/phi/api/include/tensor_utils.h index 56ed9ae12feb45..83d02a7f716d00 100644 --- a/paddle/phi/api/include/tensor_utils.h +++ b/paddle/phi/api/include/tensor_utils.h @@ -17,6 +17,10 @@ limitations under the License. */ #include #include "paddle/phi/api/include/tensor.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#endif namespace paddle { @@ -50,4 +54,22 @@ PADDLE_API Tensor from_blob(void* data, const phi::Place& place = phi::Place(), const Deleter& deleter = nullptr); +#ifdef PADDLE_WITH_DISTRIBUTE +/** + * @brief Reshard a DistTensor by given DistAttr. + * + * @note Input of `Reshard` should be a `paddle::Tensor` whose impl is + * shared_ptr of DistTensor. According to the given DistAttr, input will be + * reshard to wanted distributed state. And it will return shared_ptr of a new + * DistTensor as outptut. + * + * @param input The input tensor to be resharded. + * @param dist_attr The dist_attr to be resharded. + * @return Shared_ptr of a new DistTensor + */ +// TODO(GhostScreaming): All APIs should call this unified function later. +PADDLE_API std::shared_ptr reshard( + const paddle::Tensor& input, + const phi::distributed::TensorDistAttr& dist_attr); +#endif } // namespace paddle diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 71257dc588dac1..3f62c52eaed1c3 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -539,7 +539,6 @@ phi::distributed::DistMetaTensor MakeDistMetaTensor( phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) { if (out) { - // TODO(chenweihang): now all dist case are nullptr if (out->impl() == nullptr) { auto dist_t = std::make_shared(phi::DDim(), dist_attr); @@ -617,6 +616,7 @@ void SetReplicatedDistAttrForOutput( phi::distributed::DistTensor* out, const phi::distributed::ProcessMesh& process_mesh) { if (out) { + // For inplace output, we also need to set replicated dist attr auto dist_attr = phi::distributed::TensorDistAttr(phi::vectorize(out->dims())); dist_attr.set_process_mesh(process_mesh); diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 8ba76b64f5f7af..1a8d92c2d90406 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -661,6 +661,19 @@ ReshardApiInputToReplicatedKernelInput( return nullptr; } +paddle::optional> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const paddle::optional& tensor, + const phi::distributed::TensorDistAttr& dist_attr) { + if (tensor) { + VLOG(6) << "Optional ApiIn to Replicated KernelIn."; + return paddle::make_optional>( + ReshardApiInputToReplicatedKernelInput(dev_ctx, *tensor, dist_attr)); + } + return paddle::none; +} + void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) { if (out_tensor->dist_attr().is_partial()) { @@ -705,18 +718,6 @@ void ReshardKernelOutputToApiOutput( } } -std::shared_ptr PrepareDataForDistTensor( - const Tensor& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel) { - return PrepareDataForDistTensor( - std::static_pointer_cast(input.impl()), - target_args_def, - transform_flag, - is_stride_kernel); -} - std::shared_ptr PrepareDataForDistTensor( const std::shared_ptr& input, const phi::TensorArgDef& target_args_def, @@ -752,6 +753,35 @@ std::shared_ptr PrepareDataForDistTensor( return nullptr; } +paddle::optional> +PrepareDataForDistTensor( + const paddle::optional>& + input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel) { + if (input) { + VLOG(6) << "PrepareDataForDistTensor for optional return transformed dist " + "tensor"; + return paddle::make_optional>( + PrepareDataForDistTensor( + *input, target_args_def, transform_flag, is_stride_kernel)); + } + return paddle::none; +} + +std::shared_ptr PrepareDataForDistTensor( + const Tensor& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel) { + return PrepareDataForDistTensor( + std::static_pointer_cast(input.impl()), + target_args_def, + transform_flag, + is_stride_kernel); +} + std::vector> PrepareDataForDistTensor(const std::vector& input, const phi::TensorArgDef& target_args_def, diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 25c0e4137aa7f6..712f568479d2e8 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -186,6 +186,12 @@ ReshardApiInputToReplicatedKernelInput( const Tensor& tensor, const phi::distributed::TensorDistAttr& dist_attr); +paddle::optional> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const paddle::optional& tensor, + const phi::distributed::TensorDistAttr& dist_attr); + void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor); @@ -195,13 +201,21 @@ void ReshardKernelOutputToApiOutput( Tensor* dst_tensor); std::shared_ptr PrepareDataForDistTensor( - const Tensor& input, + const std::shared_ptr& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel); + +paddle::optional> +PrepareDataForDistTensor( + const paddle::optional>& + input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); std::shared_ptr PrepareDataForDistTensor( - const std::shared_ptr& input, + const Tensor& input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 74ee1e380dcc4a..cbcf38e376ba37 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -27,7 +27,11 @@ limitations under the License. */ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" // clang-format off - +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/api/lib/data_transform.h" +#endif namespace paddle { namespace experimental { // declare cast api @@ -87,9 +91,7 @@ void Tensor::copy_(const Tensor &src, VLOG(8) << "Src is empty, skip copy"; return; } - // Prepare copy kernel key and outputs - auto kernel_key_set = ParseKernelKeyByInputArgs(src); - KernelType kernel_type = ParseKernelTypeByInputArgs(src); + VLOG(3) << "Deep copy Tensor from " << src.name() << " to " << name(); if (initialized()) { PADDLE_ENFORCE_EQ(dtype(), @@ -114,6 +116,12 @@ void Tensor::copy_(const Tensor &src, "Copy cannot be performed!", target_place, place())); + } + + // Prepare copy kernel key and outputs + auto kernel_key_set = ParseKernelKeyByInputArgs(src); + KernelType kernel_type = ParseKernelTypeByInputArgs(src); + if (initialized()) { kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place())); } else { @@ -129,6 +137,42 @@ void Tensor::copy_(const Tensor &src, place.GetType() == target_place.GetType() ? target_place : place); if (kernel_type == KernelType::DENSE_TENSOR_KENREL) { +#ifdef PADDLE_WITH_DISTRIBUTE + bool run_auto_parallel = AllInputsAreDistTensor(src); + bool rank_is_in_current_mesh = false; + if (run_auto_parallel) { + auto mesh = std::static_pointer_cast( + src.impl())->dist_attr().process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + auto meta_dist_input_x = MakeDistMetaTensor(*src.impl()); + + auto dist_out = SetKernelDistOutput(this, meta_dist_input_x.dist_attr()); + auto dense_out = dist_out->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) { + *dense_out = phi::DenseTensor( + std::make_shared(nullptr, + 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + } + + phi::MetaTensor meta_dist_out(dist_out); + phi::UnchangedInferMeta(MakeMetaTensor(*(src.impl_)), &meta_dist_out); + + if (rank_is_in_current_mesh) { + auto dist_input_x = static_cast( + src.impl().get());; + + auto input_x = &dist_input_x->value(); + + phi::MetaTensor meta_dense_out(dense_out); + phi::UnchangedInferMeta(MakeMetaTensor(*input_x), &meta_dense_out); + + phi::Copy(*dev_ctx, *input_x, target_place, blocking, dense_out); + } + return; + } +#endif SetKernelOutput(this); phi::MetaTensor meta_out(impl_.get()); phi::UnchangedInferMeta( diff --git a/paddle/phi/api/lib/tensor_utils.cc b/paddle/phi/api/lib/tensor_utils.cc index b8d25e4f22b100..1adb7b638c7894 100644 --- a/paddle/phi/api/lib/tensor_utils.cc +++ b/paddle/phi/api/lib/tensor_utils.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/api/include/tensor_utils.h" +#include "glog/logging.h" #include "paddle/phi/api/lib/api_registry.h" #include "paddle/phi/core/dense_tensor.h" @@ -105,4 +106,38 @@ PADDLE_API Tensor from_blob(void* data, return Tensor(std::make_shared(alloc, meta)); } +#ifdef PADDLE_WITH_DISTRIBUTE +PD_REGISTER_API(reshard) + +PADDLE_API std::shared_ptr reshard( + const paddle::Tensor& input, + const phi::distributed::TensorDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ(input.is_dist_tensor(), + true, + phi::errors::InvalidArgument( + "The input tensor of ReshardFunction should be " + "``phi::distributed::DistTensor``. " + "However it's %s", + typeid(input.impl().get()).name())); + auto dev_ctx = phi::distributed::GetDistTensorDeviceContext( + std::static_pointer_cast(input.impl())); + auto input_tensor_impl = input.impl(); + std::shared_ptr dist_out_ptr = nullptr; + if (input_tensor_impl) { + phi::distributed::DistTensor* dist_tensor = + static_cast(input_tensor_impl.get()); + if (dist_tensor->dist_attr() != dist_attr) { + VLOG(6) << "reshard func, reshard tensor from " + << dist_tensor->dist_attr() << " to " << dist_attr; + auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, + dist_attr); + dist_out_ptr = func->Eval(dev_ctx, *dist_tensor, dist_attr); + } else { + dist_out_ptr = std::static_pointer_cast( + input_tensor_impl); + } + } + return dist_out_ptr; +} +#endif } // namespace paddle diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index b54307861b3674..45186294ce979f 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -62,14 +62,23 @@ optional : bias, x_max - op : conv2d_xpu - args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, int act_type, float act_param, DataType out_dtype) + args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, Tensor scale_max, Tensor out_max_in, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, int act_type, float act_param, DataType out_dtype) output : Tensor(out), Tensor(out_max) infer_meta : func : Conv2dXPUInferMeta kernel : func : conv2d_xpu data_type : x - optional : bias, branch, branch_max ,x_max + optional : bias, branch, branch_max ,x_max, scale_max, out_max_in + +- op : dequantize_xpu + args : (Tensor x, DataType out_dtype, float scale = 1.0f) + output : Tensor(y) + infer_meta : + func : DeQuantizeXPUInferMeta + kernel : + func : dequantize_xpu + data_type: x - op : embedding_with_eltwise_add_xpu args : (Tensor[] ids, Tensor[] tables, Tensor mask, int64_t padding_idx) @@ -101,14 +110,14 @@ data_type : x - op : fc_xpu - args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype) + args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, Tensor scale_max, Tensor out_max_in, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype) output : Tensor(out), Tensor(out_max) infer_meta : func : FcXPUInferMeta kernel : func : fc_xpu data_type : x - optional : bias, x_max + optional : bias, x_max, scale_max, out_max_in - op : fused_bias_act args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_dtype = "default", float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0) @@ -207,6 +216,38 @@ func : fused_scale_bias_relu_conv_bnstats data_type : x +- op : fusion_gru + args : (Tensor x, Tensor h0, Tensor weight_x, Tensor weight_h, Tensor bias, str activation = "tanh", str gate_activation = "sigmoid", bool is_reverse = false, bool use_seq = true, bool origin_mode = false, bool use_mkldnn = false, str mkldnn_data_type = "float32", float scale_data = 1.0f, float shift_data = 0.0f, float[] scale_weights = {1.0f}, bool force_fp32_output = false) + output : Tensor(reordered_h0), Tensor(xx), Tensor(batched_input), Tensor(batched_out), Tensor(hidden) + infer_meta : + func : FusionGRUInferMeta + kernel : + func : fusion_gru + data_type : x + optional : h0, bias + intermediate : reordered_h0, xx, batched_input, batched_out + +- op : fusion_seqconv_eltadd_relu + args : (Tensor x, Tensor filter, Tensor bias, int context_length, int context_start = 0, int context_stride = 1) + output : Tensor(out), Tensor(col_mat) + infer_meta : + func : FusionSeqConvEltAddReluInferMeta + kernel : + func : fusion_seqconv_eltadd_relu + data_type : x + intermediate : col_mat + +- op : fusion_seqexpand_concat_fc + args : (Tensor[] x, Tensor fc_weight, Tensor fc_bias, str fc_activation="identity") + output : Tensor(out), Tensor(fc_out) + infer_meta : + func : FusionSeqExpandConcatFCInferMeta + kernel : + func : fusion_seqexpand_concat_fc + data_type : x + optional : fc_bias + intermediate : fc_out + - op : fusion_transpose_flatten_concat args : (Tensor[] x, int[] trans_axis, int flatten_axis, int concat_axis) output : Tensor(out) @@ -254,6 +295,15 @@ data_type : input optional : bias_qk +- op : quantize_xpu + args : (Tensor x, DataType out_dtype, float scale = 1.0f) + output : Tensor(y) + infer_meta : + func : QuantizeXPUInferMeta + kernel : + func : quantize_xpu + data_type : x + - op : squeeze_excitation_block args : (Tensor x, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] act_type, float[] act_param, int[] filter_dims) output : Tensor(out) diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index 5e7cff92131712..4d04604aea3e0d 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -764,7 +764,21 @@ def gene_optional_vec_dense_input( input_tensor_code = ( input_tensor_code + f""" -{code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});""" +{code_indent} // inplace vector of tensors should also be transferred to CPU when kernel has fallen back +{code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name}; +{code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name}_vec; +{code_indent} if (kernel_result.has_fallback_cpu) {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), actual_kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); +{code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_vec){{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::optional>({PREFIX_TENSOR_NAME}{input_name}_vec->size()); +{code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}_vec->size(); ++i) {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name}->at(i) = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); +{code_indent} }} +{code_indent} }} +{code_indent} }} +{code_indent} else {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name}); +{code_indent} }}""" ) else: input_name_tensor_map[input_name].append( @@ -773,7 +787,7 @@ def gene_optional_vec_dense_input( input_tensor_code = ( input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), actual_kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); {code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name}; {code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_vec){{ {code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::optional>({PREFIX_TENSOR_NAME}{input_name}_vec->size()); @@ -802,7 +816,19 @@ def gene_vec_dense_input( input_tensor_code = ( input_tensor_code + f""" -{code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});""" +{code_indent} // inplace vector of tensors should also be transferred to CPU when kernel has fallen back +{code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name}; +{code_indent} std::unique_ptr> {PREFIX_TENSOR_NAME}{input_name}_vec; +{code_indent} if (kernel_result.has_fallback_cpu) {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), actual_kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); +{code_indent} {PREFIX_TENSOR_NAME}{input_name}.resize({PREFIX_TENSOR_NAME}{input_name}_vec->size()); +{code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); +{code_indent} }} +{code_indent} }} +{code_indent} else {{ +{code_indent} {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name}); +{code_indent} }}""" ) else: input_name_tensor_map[input_name].append( @@ -811,7 +837,7 @@ def gene_vec_dense_input( input_tensor_code = ( input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, GetKernelInputArgDef(kernel.InputAt({kernel_param.index(input_name)}), actual_kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); {code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size()); {code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{ {code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); @@ -1243,7 +1269,9 @@ def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False): {code_indent} phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type); {code_indent} }} {code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel; -{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); +{code_indent} // add actual_kernel_backend to select actual kernel backend after a potential falling-back to CPU +{code_indent} Backend actual_kernel_backend = kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend; +{code_indent} auto* dev_ctx = GetDeviceContextByBackend(actual_kernel_backend); {input_tensors} {output_create} {pre_save_stride} diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index 7a71555c1156fd..6d7bfbb232eadb 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -274,7 +274,10 @@ def gene_output( output_create = ( output_create + f""" -{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});""" +{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code}); +{code_indent} if (kernel_result.has_fallback_cpu) {{ +{code_indent} TransDataBackend(kernel_out_{i}, actual_kernel_backend, kernel_out_{i}); +{code_indent} }}""" ) else: @@ -406,6 +409,9 @@ def declare_extension_api(): return """ namespace paddle { PD_DECLARE_API(from_blob); +#ifdef PADDLE_WITH_DISTRIBUTE +PD_DECLARE_API(reshard); +#endif } // namespace paddle """ diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index 2bf886ab7fa5ef..d6c90584cb540f 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -44,19 +44,20 @@ # TODO(chenweihang): add view support later MAIN_DIST_BRANCH_TEMPLATE = """ // Auto Parallel condition - if (use_dist_branch) {{ + if (run_auto_parallel) {{ // 1. InferSpmd (Infer DistAttr of Inputs&Outputs){} // 2. Create API Output & Prepare Dist and Dense Output{} // 3. Infer DistTensor's Global Shape{}\n - if (!computation_clip_for_pp){{ + if (rank_is_in_current_mesh){{ // 4. Select Kernel{} // 5. Reshard Input{}\n // 6. PrepareData (DataTransform & Prepare Dense Input){} // 7. Infer Local DenseTensor Meta{} // 8. DenseTensor Kernel Call{} // 9. Reshard Partial Output to Replicated (Temporary){}\n - }} - // 10. Return + }}\n + // 10. Set Output Dist Attr For Default Impl{}\n + // 11. Return {} }} """ @@ -65,21 +66,23 @@ # 1. Non computation rank clip GET_MESH_TEMPLATE = """ auto mesh = std::static_pointer_cast({}impl())->dist_attr().process_mesh(); - computation_clip_for_pp = !phi::distributed::IsCurRankInMesh(mesh);""" + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh);""" # Auto Parallel condition AUTO_PARALLEL_COND_TEMPLATE = """ - bool use_dist_branch = AllInputsAreDistTensor({input_args}); - bool computation_clip_for_pp = false; - if (use_dist_branch) {{{mesh} + bool run_auto_parallel = AllInputsAreDistTensor({input_args}); + bool rank_is_in_current_mesh = true; + if (run_auto_parallel) {{{mesh} }} - if (!computation_clip_for_pp) {{{kernel_code} + if (rank_is_in_current_mesh) {{{kernel_code} }} """ # 1. InferSPMD SINGLE_DIST_META_IN_TEMPLATE = """ - auto meta_dist_input_{} = MakeDistMetaTensor(*{}.impl());""" + auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());""" +OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE = """ + auto meta_dist_input_{name} = {name} ? MakeDistMetaTensor(*(*{name}).impl()) : phi::distributed::DistMetaTensor();""" INFER_SPMD_TEMPLATE = """ auto spmd_info = phi::distributed::{}({}); """ @@ -100,7 +103,7 @@ SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """ auto dist_out = SetKernelDistOutput(&api_output); auto dense_out = dist_out->unsafe_mutable_value(); - if (computation_clip_for_pp) {{ + if (!rank_is_in_current_mesh) {{ *dense_out = phi::DenseTensor( std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), phi::DenseTensorMeta()); @@ -108,8 +111,8 @@ """ MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """ auto dist_out_{idx} = SetKernelDistOutput({out}); - auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value(); - if (computation_clip_for_pp) {{ + auto dense_out_{idx} = dist_out_{idx} ? dist_out_{idx}->unsafe_mutable_value() : nullptr; + if (!rank_is_in_current_mesh) {{ *dense_out_{idx} = phi::DenseTensor( std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), phi::DenseTensorMeta()); @@ -118,7 +121,7 @@ SINGLE_OUT_CREATION_TEMPLATE = """ auto dist_out = SetKernelDistOutput(&api_output, spmd_info.second[0]); auto dense_out = dist_out->unsafe_mutable_value(); - if (computation_clip_for_pp) {{ + if (!rank_is_in_current_mesh) {{ *dense_out = phi::DenseTensor( std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), phi::DenseTensorMeta()); @@ -127,7 +130,7 @@ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """ auto dist_out_{idx} = SetKernelDistOutput({out}, spmd_info.second[{idx}]); auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value(); - if (computation_clip_for_pp) {{ + if (!rank_is_in_current_mesh) {{ *dense_out_{idx} = phi::DenseTensor( std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), phi::DenseTensorMeta()); @@ -138,7 +141,7 @@ std::vector dense_out(dist_out.size()); for (size_t i = 0; i < dist_out.size(); ++i) {{ dense_out[i] = const_cast(&dist_out[i]->value()); - if (computation_clip_for_pp) {{ + if (!rank_is_in_current_mesh) {{ *dense_out[i] = phi::DenseTensor( std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), phi::DenseTensorMeta()); @@ -150,7 +153,7 @@ std::vector dense_out_{out_name}(dist_out_{out_name}.size()); for (size_t i = 0; i < dist_out_{out_name}.size(); ++i) {{ dense_out_{out_name}[i] = const_cast(&dist_out_{out_name}[i]->value()); - if (computation_clip_for_pp) {{ + if (!rank_is_in_current_mesh) {{ *dense_out_{out_name}[i] = phi::DenseTensor( std::make_shared(nullptr, 0, phi::distributed::GetDefaultPlace()), phi::DenseTensorMeta()); @@ -214,15 +217,6 @@ INFER_GLOBAL_SHAPE_TEMPLATE = """ phi::{}({}{}); """ -# Dist Branch will not generated in the API that doesn't have input tensor. -SET_SINGLE_OUT_REPLICATED_DIST_ATTR = """ - SetReplicatedDistAttrForOutput({}, spmd_info.first[0].process_mesh());""" -SET_VECTOR_OUT_REPLICATED_DIST_ATTR = """ - auto current_process_mesh = spmd_info.first[0].process_mesh(); - for (size_t i = 0; i < dist_out.size(); ++i) {{ - SetReplicatedDistAttrForOutput(dist_out[i], current_process_mesh); - }} -""" # 4. Select Kernel KERNEL_SELECTION_TEMPLATE = """ @@ -265,6 +259,10 @@ }} """ OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE = """ + dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional((*dist_input_{name})->value()) : paddle::none; +""" +OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ auto dist_input_{name} = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional(dist_input_{name}->value()) : paddle::none; """ @@ -354,8 +352,21 @@ RESHARD_P2R_MULTI_SINGLE_OUTPUT_TEMPLATE = """ ReshardOutputPartialAxisToReplicated(dev_ctx, dist_out_{});""" UNSUPPORTED_RESHARD_OUTPUT_COMMENT_TEMPLATE = """ - // API `{}` does not need to support ReshardOutput now + // API `{}` does not need to support ReshardOutput now.""" + +# 10. Set Output DistAttr for Default impl +# Dist Branch will not generated in the API that doesn't have input tensor. +CURRENT_PROCESS_MESH_TEMPLATE = """ + auto current_process_mesh = spmd_info.first[0].process_mesh();""" +SET_SINGLE_OUT_REPLICATED_DIST_ATTR_TEMPLATE = """ + SetReplicatedDistAttrForOutput({}, current_process_mesh);""" +SET_VECTOR_OUT_REPLICATED_DIST_ATTR_TEMPLATE = """ + for (size_t i = 0; i < {name}.size(); ++i) {{ + SetReplicatedDistAttrForOutput({name}[i], current_process_mesh); + }} """ +NONEED_TO_SET_DIST_ATTR_COMMENT_TEMPLATE = """ + // API `{}` does not need to set DistAttr for output.""" # BaseAPI members: # inputs: @@ -681,7 +692,7 @@ def generate_specialized_infer_spmd_code(self) -> str: if param in input_names: if self.inputs['input_info'][param] == "const Tensor&": input_decl_code += SINGLE_DIST_META_IN_TEMPLATE.format( - param, param + name=param ) input_args_code += "meta_dist_input_" + param + ", " else: @@ -722,14 +733,20 @@ def generate_general_infer_spmd_code(self) -> str: if param in input_names: if self.inputs['input_info'][param] == "const Tensor&": input_decl_code += SINGLE_DIST_META_IN_TEMPLATE.format( - param, param + name=param ) input_args_code += "meta_dist_input_" + param + ", " elif ( self.inputs['input_info'][param] - == "const std::vector&" - or self.inputs['input_info'][param] == "const paddle::optional&" + ): + input_decl_code += ( + OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE.format(name=param) + ) + input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] + == "const std::vector&" or self.inputs['input_info'][param] == "const paddle::optional>&" ): @@ -937,17 +954,12 @@ def generate_infer_global_shape_code(self) -> str: # 3. get meta tensor output args output_decl_code = "" output_args_code = "" - set_out_dist_attr_code = "" for i, out_name in enumerate(self.dist_output_args): if self.outputs['types'][i] == 'std::vector': output_decl_code += VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE.format( name=out_name ) output_args_code += f"{out_name}_meta_ptr_vec, " - if self.generate_general_infer_spmd is True: - set_out_dist_attr_code += ( - SET_VECTOR_OUT_REPLICATED_DIST_ATTR - ) else: output_decl_code += SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE.format( out_name, out_name @@ -958,10 +970,6 @@ def generate_infer_global_shape_code(self) -> str: output_args_code += ( f"{out_name} ? &meta_{out_name} : nullptr, " ) - if self.generate_general_infer_spmd is True: - set_out_dist_attr_code += ( - SET_SINGLE_OUT_REPLICATED_DIST_ATTR.format(out_name) - ) output_args_code = output_args_code[:-2] return ( @@ -970,7 +978,6 @@ def generate_infer_global_shape_code(self) -> str: + INFER_GLOBAL_SHAPE_TEMPLATE.format( infer_meta_func_code, input_args_code, output_args_code ) - + set_out_dist_attr_code ) def generate_kernel_selection_code(self) -> str: @@ -991,7 +998,11 @@ def generate_reshard_input_code(self) -> str: for i, param in enumerate(kernel_params): if param in input_names: - if self.inputs['input_info'][param] == "const Tensor&": + if ( + self.inputs['input_info'][param] == "const Tensor&" + or self.inputs['input_info'][param] + == "const paddle::optional&" + ): if self.generate_general_infer_spmd is True: input_reshard_code += ( SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( @@ -1077,11 +1088,20 @@ def generate_optional_single_dense_input( if kernel_param is None: kernel_param = input_names + attr_names - input_tensor_code += OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE.format( - name=input_name, - index=kernel_param.index(input_name), - trans_flag=trans_flag, - ) + if self.generate_infer_spmd is True: + input_tensor_code += OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE.format( + name=input_name, + index=kernel_param.index(input_name), + trans_flag=trans_flag, + ) + else: + input_tensor_code += ( + OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format( + name=input_name, + index=kernel_param.index(input_name), + trans_flag=trans_flag, + ) + ) return input_tensor_code @@ -1348,6 +1368,29 @@ def generate_reshard_partial_out_to_replicated_code(self) -> str: return reshard_p2r_code + def generate_output_dist_attr_setting(self) -> str: + set_out_dist_attr_code = "" + if self.generate_general_infer_spmd is True: + set_out_dist_attr_code += CURRENT_PROCESS_MESH_TEMPLATE + for i, out_name in enumerate(self.dist_output_args): + if self.outputs['types'][i] == 'std::vector': + set_out_dist_attr_code += ( + SET_VECTOR_OUT_REPLICATED_DIST_ATTR_TEMPLATE.format( + name=out_name + ) + ) + else: + set_out_dist_attr_code += ( + SET_SINGLE_OUT_REPLICATED_DIST_ATTR_TEMPLATE.format( + out_name + ) + ) + else: + set_out_dist_attr_code = ( + NONEED_TO_SET_DIST_ATTR_COMMENT_TEMPLATE.format(self.api) + ) + return set_out_dist_attr_code + def generate_return_code(self) -> str: return self.gene_return_code() @@ -1365,6 +1408,7 @@ def generate_auto_paralel_branch(self) -> str: self.generate_infer_meta_code(), self.generate_kernel_call_code(), self.generate_reshard_partial_out_to_replicated_code(), + self.generate_output_dist_attr_setting(), self.generate_return_code(), ) diff --git a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py index 9368d6908b33cc..759beed6eb7e39 100644 --- a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py @@ -24,19 +24,20 @@ MAIN_DIST_BRANCH_TEMPLATE = """ // Auto Parallel condition - if (use_dist_branch) {{ + if (run_auto_parallel) {{ // 1. InferSpmd (Infer DistAttr of Inputs&Outputs){} // 2. Create Temporary Output & Prepare Dist and Dense Output{} // 3. Infer DistTensor's Global Shape{}\n - if (!computation_clip_for_pp){{ - // 4. Select Kernel{} - // 5. Reshard Input{}\n - // 6. PrepareData (DataTransform & Prepare Dense Input){} - // 7. Infer Local DenseTensor Meta{} - // 8. DenseTensor Kernel Call{} - // 9. Reshard Partial Output to Replicated (Temporary){}\n + // 4. Set Output Dist Attr For Default Impl{}\n + if (rank_is_in_current_mesh){{ + // 5. Select Kernel{} + // 6. Reshard Input{}\n + // 7. PrepareData (DataTransform & Prepare Dense Input){} + // 8. Infer Local DenseTensor Meta{} + // 9. DenseTensor Kernel Call{} + // 10. Reshard Partial Output to Replicated (Temporary){}\n }} - // 10. Return + // 11. Return {} }} """ @@ -265,6 +266,7 @@ def generate_auto_paralel_branch(self) -> str: self.generate_infer_spmd_code(), self.generate_output_creation_code(), self.generate_infer_global_shape_code(), + self.generate_output_dist_attr_setting(), self.generate_kernel_selection_code(), self.generate_reshard_input_code(), self.generate_prepare_data_code(), diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 47eda81f5d0ca0..7453c7ec49a908 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -86,7 +86,7 @@ kernel : func : batch_norm_double_grad data_type : x - optional : out_mean, out_variance, grad_x_grad, grad_scale_grad, grad_bias_grad + optional : scale, out_mean, out_variance, grad_x_grad, grad_scale_grad, grad_bias_grad inplace : (grad_out -> grad_out_grad) - backward_op : batch_norm_grad @@ -99,7 +99,7 @@ kernel : func : batch_norm_grad data_type : out_grad - optional : mean_out, variance_out, reserve_space + optional : scale, bias, mean_out, variance_out, reserve_space composite: batch_norm_grad(x, scale, bias, mean_out, variance_out, saved_mean, saved_variance, reserve_space, out_grad, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics) backward : batch_norm_double_grad @@ -246,8 +246,8 @@ invoke : zeros_like(out_grad) - backward_op : frobenius_norm_grad - forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out) - args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all) + forward : frobenius_norm(Tensor x, IntArray axis, bool keep_dim, bool reduce_all) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis, bool keep_dim, bool reduce_all) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -380,6 +380,7 @@ infer_meta : func : UnchangedInferMeta param: [x] + spmd_rule : ReductionGradInferSpmd kernel : func : mean_grad backward : mean_double_grad @@ -702,6 +703,7 @@ infer_meta : func : UnchangedInferMeta param : [x] + spmd_rule : ReductionGradInferSpmd kernel : func : sum_grad composite : sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 01acb338c987bd..6b1206f617bcaa 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -122,7 +122,7 @@ data_type : x view : (mean -> mean_out), (variance -> variance_out) backward : batch_norm_grad - optional : reserve_space + optional : scale, bias, reserve_space - op : c_allgather args : (Tensor x, int ring_id, int nranks, bool use_calc_stream) @@ -214,7 +214,7 @@ inplace : (x -> out) - op : c_sync_comm_stream - args : (Tensor x) + args : (Tensor x, int ring_id) output : Tensor(out) infer_meta : func : UnchangedInferMeta @@ -452,10 +452,10 @@ inplace: (x -> out) - op : frobenius_norm - args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) + args : (Tensor x, IntArray axis, bool keep_dim, bool reduce_all) output : Tensor(out) infer_meta : - func : ReduceInferMetaBase + func : ReduceIntArrayAxisInferMetaBase kernel : func : frobenius_norm backward : frobenius_norm_grad @@ -726,6 +726,7 @@ output : Tensor(out) infer_meta : func : ReduceIntArrayAxisInferMeta + spmd_rule : ReductionMeanInferSpmdDynamic kernel : func : mean backward : mean_grad @@ -1015,6 +1016,7 @@ output : Tensor(out) infer_meta : func : SumInferMeta + spmd_rule : ReductionSumInferSpmdDynamic kernel : func : sum data_type : x diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index f7d3878e44847f..93eefa553f7bca 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -987,8 +987,8 @@ extra : attrs : [bool use_mkldnn = false, bool use_cudnn = false] -- op : exponential_ - backward : exponential__grad +- op : exponential_ (exponential) + backward : exponential__grad (exponential_grad) inputs : x : X outputs : @@ -1263,8 +1263,6 @@ scale : Scale outputs : out : Out - attrs : - epsilon : epsilon - op : fused_fc_elementwise_layernorm inputs : @@ -1278,11 +1276,6 @@ out : Out mean : Mean variance : Variance - attrs : - x_num_col_dims : x_num_col_dims - activation_type : activation_type - epsilon : epsilon - begin_norm_axis : begin_norm_axis - op : fused_feedforward backward: fused_feedforward_grad @@ -1332,15 +1325,51 @@ extra : attrs : [str data_format = "AnyLayout"] -- op : fusion_transpose_flatten_concat +- op : fusion_gru inputs : x : X + h0 : H0 + weight_x : WeightX + weight_h : WeightH + bias : Bias + outputs : + reordered_h0 : ReorderedH0 + xx : XX + batched_input : BatchedInput + batched_out : BatchedOut + hidden : Hidden + attrs : + scale_data : Scale_data + shift_data : Shift_data + scale_weights : Scale_weights + +- op : fusion_seqconv_eltadd_relu + inputs : + x : X + filter : Filter + bias : Bias outputs : out : Out + col_mat : ColMat attrs : - trans_axis : trans_axis - flatten_axis : flatten_axis - concat_axis : concat_axis + context_length : contextLength + context_start : contextStart + context_stride : contextStride + +- op : fusion_seqexpand_concat_fc + inputs : + x : X + fc_weight : FCWeight + fc_bias : FCBias + outputs : + out : Out + fc_out : FCOut + +- op : fusion_transpose_flatten_concat + inputs : + x : X + outputs : + out : Out - op : gather backward : gather_grad @@ -3275,6 +3304,12 @@ attrs: pivot : pivots +- op: memcpy + inputs: + x: X + outputs: + out: Out + - op: memcpy_d2h inputs : x : X diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index aaf6c4e1445ef4..b3c6d31c710ec1 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1012,10 +1012,10 @@ backward : frame_grad - op : full_int_array - args : (IntArray value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) + args : (int64_t[] value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) output: Tensor(out) infer_meta : - func : CreateIntArrayInferMeta + func : CreateVecShapeInferMeta param : [value, dtype] kernel : func : full_int_array diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 9f8def740385b4..2dc2657491068e 100755 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -256,7 +256,7 @@ args : (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) output : Tensor(out) infer_meta : - func : ReduceInferMetaBase + func : ReduceIntArrayAxisInferMetaBase kernel : func : frobenius_norm param : [x, axis, keepdim, reduce_all] diff --git a/paddle/phi/backends/dynload/nccl.h b/paddle/phi/backends/dynload/nccl.h index 6c73c562caa697..91b6f5dcd58dc5 100644 --- a/paddle/phi/backends/dynload/nccl.h +++ b/paddle/phi/backends/dynload/nccl.h @@ -44,6 +44,7 @@ extern void* nccl_dso_handle; __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ __macro(ncclCommCuDevice); \ @@ -55,6 +56,7 @@ extern void* nccl_dso_handle; __macro(ncclGroupEnd); \ __macro(ncclReduce); \ __macro(ncclReduceScatter); \ + __macro(ncclCommGetAsyncError); \ __macro(ncclGetErrorString); NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP) diff --git a/paddle/phi/backends/dynload/rccl.h b/paddle/phi/backends/dynload/rccl.h index 9232d387d2d19d..e1018a3f253fa5 100644 --- a/paddle/phi/backends/dynload/rccl.h +++ b/paddle/phi/backends/dynload/rccl.h @@ -44,6 +44,7 @@ extern void* rccl_dso_handle; __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ __macro(ncclCommCuDevice); \ @@ -55,6 +56,7 @@ extern void* rccl_dso_handle; __macro(ncclGroupEnd); \ __macro(ncclReduce); \ __macro(ncclReduceScatter); \ + __macro(ncclCommGetAsyncError); \ __macro(ncclGetErrorString); RCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_RCCL_WRAP) diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 6252bbc54c9334..a9246e8a75303d 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -1519,8 +1519,11 @@ class PoolingOneDNNHandler } if (adaptive) { - ComputeAdaptivePoolParameters( - src_tz, &copied_kernel_size, &copied_strides); + ComputeAdaptivePoolParameters(src_tz, + onednn_paddings[0], + onednn_paddings[1], + &copied_kernel_size, + &copied_strides); } bool is_test = dev_ctx.HasDnnAttr("is_test") @@ -1612,8 +1615,11 @@ class PoolingOneDNNHandler } if (adaptive) { - ComputeAdaptivePoolParameters( - diff_src_tz, &copied_kernel_size, &copied_strides); + ComputeAdaptivePoolParameters(src_tz, + onednn_paddings[0], + onednn_paddings[1], + &copied_kernel_size, + &copied_strides); } memory::dims dilation = {0, 0}; @@ -1672,23 +1678,45 @@ class PoolingOneDNNHandler return mem_p; } - static void ComputeAdaptivePoolParameters(const std::vector& src_tz, - std::vector* kernel_size, - std::vector* strides) { + static void ComputeAdaptivePoolParameters( + const std::vector& src_tz, + const std::vector& padding_l, + const std::vector& padding_r, + std::vector* kernel_size, + std::vector* strides) { // https://github.com/oneapi-src/oneDNN/tree/bkocot/adaptive-pooling/rfcs/20200818-adaptive-pooling auto IH = static_cast(src_tz[src_tz.size() - 2]); auto IW = static_cast(src_tz[src_tz.size() - 1]); auto OH = static_cast(kernel_size->at(0)); auto OW = static_cast(kernel_size->at(1)); - strides->at(0) = - static_cast(floor((IH * 2.0) / OH) - floor(IH / OH)); - strides->at(1) = - static_cast(floor((IW * 2.0) / OW) - floor(IW / OW)); - kernel_size->at(0) = - static_cast(ceil((IH * 2.0) / OH) - floor(IH / OH)); - kernel_size->at(1) = - static_cast(ceil((IW * 2.0) / OW) - floor(IW / OW)); + /* + The previous calculation formula is given by OneDNN rfc, but in some odd + cases(mod(I/O)>=O/2) there will be problems with the calculation results. + Now change the formula to the general calculation formula of + AdaptivePool when in mod(I/O)>=O/2 case: + stride=floor(input_size/output_size) + kernel_size=input_size-(output_size-1)*stride + */ + int mod_H = IH - floor(IH / OH) * OH; + int mod_W = IW - floor(IW / OW) * OW; + if (2 * mod_H < OH && 2 * mod_W < OW) { + strides->at(0) = + static_cast(floor((IH * 2.0) / OH) - floor(IH / OH)); + strides->at(1) = + static_cast(floor((IW * 2.0) / OW) - floor(IW / OW)); + kernel_size->at(0) = + static_cast(ceil((IH * 2.0) / OH) - floor(IH / OH)); + kernel_size->at(1) = + static_cast(ceil((IW * 2.0) / OW) - floor(IW / OW)); + } else { + strides->at(0) = static_cast(floor(IH / OH)); + strides->at(1) = static_cast(floor(IW / OW)); + kernel_size->at(0) = static_cast( + IH + padding_l[0] + padding_r[0] - floor((OH - 1) * strides->at(0))); + kernel_size->at(1) = static_cast( + IW + padding_l[1] + padding_r[1] - floor((OW - 1) * strides->at(1))); + } } private: diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 74a8cf0bc1150e..c5d6c36ad26ef5 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -176,7 +176,9 @@ XPUOpMap& get_kl2_ops() { {"conv1d_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"conv2d_xpu", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT8})}, {"conv3d_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"conv3d", @@ -210,6 +212,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32})}, {"depthwise_conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"dequantize_xpu", + XPUKernelSet({phi::DataType::INT16, phi::DataType::INT8})}, {"diag_v2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, @@ -317,7 +321,9 @@ XPUOpMap& get_kl2_ops() { {"fast_layernorm_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fc_xpu", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT8})}, {"fill", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, @@ -617,6 +623,8 @@ XPUOpMap& get_kl2_ops() { {"prelu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"prod_raw", XPUKernelSet({phi::DataType::FLOAT32})}, + {"quantize_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"range", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64, @@ -777,14 +785,19 @@ XPUOpMap& get_kl2_ops() { {"split", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, {"split_with_num", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sqrt", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"square_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -797,6 +810,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze2", XPUKernelSet({phi::DataType::FLOAT64, @@ -806,6 +821,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze", XPUKernelSet({phi::DataType::FLOAT64, @@ -814,6 +830,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze_grad", XPUKernelSet({phi::DataType::FLOAT64, @@ -822,6 +840,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"stack", XPUKernelSet({phi::DataType::FLOAT32, @@ -894,24 +914,28 @@ XPUOpMap& get_kl2_ops() { {"transpose2_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, @@ -935,7 +959,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze2", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -944,7 +969,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -952,7 +978,9 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, - phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -960,8 +988,9 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT32, phi::DataType::FLOAT16, - phi::DataType::FLOAT32})}, + phi::DataType::BFLOAT16})}, {"unstack", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 29a85493958949..858f7189cae6db 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -759,14 +759,21 @@ XPUOpMap& get_kl3_ops() { {"split", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, {"split_with_num", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sqrt", + XPUKernelSet({ + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, + })}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"square_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -779,6 +786,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze2", XPUKernelSet({phi::DataType::FLOAT64, @@ -788,6 +797,7 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze", XPUKernelSet({phi::DataType::FLOAT64, @@ -796,6 +806,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze_grad", XPUKernelSet({phi::DataType::FLOAT64, @@ -804,6 +816,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"stack", XPUKernelSet({phi::DataType::FLOAT32, @@ -876,24 +890,28 @@ XPUOpMap& get_kl3_ops() { {"transpose2_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, @@ -917,7 +935,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze2", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -926,7 +945,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -934,7 +954,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, - phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -942,8 +964,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT32, phi::DataType::FLOAT16, - phi::DataType::FLOAT32})}, + phi::DataType::BFLOAT16})}, {"unstack", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt index 12c59059c7c322..8e58ab4bf840e6 100644 --- a/paddle/phi/core/distributed/CMakeLists.txt +++ b/paddle/phi/core/distributed/CMakeLists.txt @@ -5,7 +5,9 @@ add_subdirectory(auto_parallel) set(DISTRIBUTED_COMMON_SRCS comm_context_manager.cc) if(WITH_NCCL OR WITH_RCCL) - list(APPEND DISTRIBUTED_COMMON_SRCS nccl_comm_context.cc) + list(APPEND DISTRIBUTED_COMMON_SRCS comm_task_manager.cc) + list(APPEND DISTRIBUTED_COMMON_SRCS nccl_comm_context.cc nccl_comm_task.cc + nccl_tools.cc) endif() if(WITH_GLOO) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h index efbf38d28f9f0a..30757c5a1cdaa6 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h @@ -22,6 +22,8 @@ namespace distributed { class DistMetaTensor : public MetaTensor { public: + DistMetaTensor() : MetaTensor() {} + // supporting implicit construction is easier to use DistMetaTensor(TensorBase* tensor) // NOLINT : MetaTensor(tensor) {} diff --git a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc index 9d5d8f43f76708..960807000974e7 100644 --- a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc @@ -73,15 +73,14 @@ int64_t FindFirstDiffShardAxis(const TensorDistAttr& in_dist_attr, bool SameNdMeshReshardFunction::IsSuitable( const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; - - flag &= (in.dist_attr().process_mesh() == out_dist_attr.process_mesh()); - flag &= (out_dist_attr.process_mesh().ndim() > 1); + RESHARD_SHORTCUT_IF_FALSE(in.dist_attr().process_mesh() == + out_dist_attr.process_mesh()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.process_mesh().ndim() > 1); // check the input and output dims_mapping is not equal - flag &= in.dist_attr() != out_dist_attr; + RESHARD_SHORTCUT_IF_FALSE(in.dist_attr() != out_dist_attr); - return flag; + return true; } void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx, @@ -121,7 +120,8 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx, // 1.3 Calculate the input one dim dist attr TensorDistAttr in_one_dim_dist_attr(vectorize(in.dims())); in_one_dim_dist_attr.set_process_mesh(sub_mesh); - in_one_dim_dist_attr.set_partial_status(std::vector{0}); + in_one_dim_dist_attr.set_partial_status(std::vector{0}, + kv.second); // 1.4 Calculate the output one dim dist attr TensorDistAttr out_one_dim_dist_attr(vectorize(in.dims())); diff --git a/paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.cc index f9aaa6f8adf7f4..53ea82569eb5e8 100644 --- a/paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.cc @@ -20,25 +20,25 @@ #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" #include "paddle/phi/kernels/all_reduce_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { namespace distributed { bool PToRReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; - - flag &= in.dist_attr().is_partial(); - flag &= out_dist_attr.is_replicated(); + RESHARD_SHORTCUT_IF_FALSE(in.dist_attr().is_partial()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated()); const auto& in_process_mesh = in.dist_attr().process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - return flag; + return true; } void PToRReshardFunction::Eval(DeviceContext* dev_ctx, @@ -50,9 +50,18 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); const auto& in_partial_status = in_dist_attr.partial_status(); + auto in_reduce_type = in_partial_status.at(0); + bool reduce_mean = false; auto dtype = in.dtype(); - int64_t reduce_type = static_cast(in_partial_status.at(0)); + if (in_reduce_type == ReduceType::kRedAvg) { + in_reduce_type = ReduceType::kRedSum; + reduce_mean = true; + } + int64_t reduce_type = static_cast(in_reduce_type); + VLOG(3) << "Transfer from partial to replicated status with reduce type " + << reduce_type; + RESHARD_FUNCTOR_WITH_COMM(dev_ctx, AllReduce, dtype, @@ -61,6 +70,24 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx, reduce_type, GetMutableTensor(out)); + if (reduce_mean) { + VLOG(3) << "Do reduce mean after all reduce sum"; + DenseTensor tensor_of_num_process; + IntArray shape({1}); + RESHARD_FUNCTOR(dev_ctx, + Full, + in.dtype(), + shape, + static_cast(in_process_ids.size()), + &tensor_of_num_process); + RESHARD_FUNCTOR(dev_ctx, + Divide, + dtype, + out->value(), + tensor_of_num_process, + GetMutableTensor(out)); + } + SetDistProps(out, in.dims(), out_dist_attr); } diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc index 77569c1ecfbac0..0af1c2b625a844 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc @@ -27,20 +27,19 @@ namespace distributed { bool RToPReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - flag &= in_dist_attr.is_replicated(); - flag &= out_dist_attr.is_partial(); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_replicated()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_partial()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - return flag; + return true; } void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx, diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc index bc6cb393a15b86..3adf488efca4e4 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc @@ -25,20 +25,19 @@ namespace distributed { bool RToSReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - flag &= in_dist_attr.is_replicated(); - flag &= out_dist_attr.is_shard(); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_replicated()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - return flag; + return true; } void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc index ce52f0a203680f..97c6d59cc27552 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc @@ -15,10 +15,13 @@ #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" #include "glog/logging.h" +#include "paddle/phi/backends/context_pool.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/store/store_utils.h" +#include "paddle/phi/core/enforce.h" namespace phi { namespace distributed { @@ -142,6 +145,24 @@ bool IsCurRankInMesh(const ProcessMesh& process_mesh) { process_ids.end()); } +// Only Input is DistTensor and current device id isn't in DistTensor's mesh +// will return true. +bool NeedComputationClipForPP( + const std::shared_ptr& tensor_impl) { + PADDLE_ENFORCE_EQ( + phi::distributed::DistTensor::classof(tensor_impl.get()), + true, + phi::errors::InvalidArgument( + "The input tensor of NeedComputationClipForPP should be " + "``phi::distributed::DistTensor``. " + "However it's %s", + typeid(tensor_impl.get()).name())); + return !IsCurRankInMesh( + std::static_pointer_cast(tensor_impl) + ->dist_attr() + .process_mesh()); +} + Place GetDefaultPlace() { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (phi::backends::gpu::GetGPUDeviceCount() >= 0) { @@ -151,5 +172,13 @@ Place GetDefaultPlace() { return paddle::CPUPlace(); } +phi::DeviceContext* GetDistTensorDeviceContext( + const std::shared_ptr& input) { + // TODO(GhostScreaming): pipeline parallel may create an undefined middle grad + // tensor. In such case, we need to get default place. + auto place = input && input->defined() ? input->place() : GetDefaultPlace(); + return phi::DeviceContextPool::Instance().Get(place); +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h index a40b62c182f318..f4a4cd68ce5e19 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h @@ -22,6 +22,9 @@ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/visit_type.h" namespace phi { @@ -32,8 +35,14 @@ class ProcessMesh; bool IsCurRankInMesh(const ProcessMesh& process_mesh); +bool NeedComputationClipForPP( + const std::shared_ptr& tensor_impl); + Place GetDefaultPlace(); +phi::DeviceContext* GetDistTensorDeviceContext( + const std::shared_ptr& input); + int64_t GetLocalRankInParticipate(const std::vector& process_ids, int64_t global_rank = -1); @@ -145,5 +154,12 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx, } while (0) #endif +#define RESHARD_SHORTCUT_IF_FALSE(expr) \ + do { \ + if (!(expr)) { \ + return false; \ + } \ + } while (0) + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc index 361c83d64a007f..a5f8ce455871d7 100644 --- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc @@ -28,30 +28,28 @@ namespace distributed { bool SToRReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - // TODO(GhostScreaming): Fix problems of using uninitialized DistTensor's - // local_dims const auto& in_dims_mapping = in_dist_attr.dims_mapping(); + const auto& in_dims_mapping = in_dist_attr.dims_mapping(); - flag &= in_dist_attr.is_shard(); - flag &= out_dist_attr.is_replicated(); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - // TODO(GhostScreaming): Fix problems of using uninitialized DistTensor's - // local_dims Ensure the tensor is balanced split, or we need send/recv rather - // than all_gather int split_axis = - // GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first; int64_t - // num_of_process = in_process_mesh.size(); flag &= - // (in.local_dims()[static_cast(split_axis)] * num_of_process == - // in.dims()[static_cast(split_axis)]); + // Ensure the tensor is balanced split, or we need send/recv rather than + // all_gather + int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first; + int64_t num_of_process = in_process_mesh.size(); + RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast(split_axis)] * + num_of_process == + in.dims()[static_cast(split_axis)]); - return flag; + return true; } void SToRReshardFunction::Eval(DeviceContext* dev_ctx, diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc index 3aafe1dc7fbeea..d90903220d8725 100644 --- a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc @@ -28,20 +28,19 @@ namespace distributed { bool SToSReshardFunction::IsSuitable(const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - flag &= in_dist_attr.is_shard(); - flag &= out_dist_attr.is_shard(); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard()); + RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh.ndim() == 1); - flag &= (out_process_mesh.ndim() == 1); - flag &= (in_process_mesh == out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh); - return flag; + return true; } void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, diff --git a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc index e7aed9ae788b04..5360368469f562 100644 --- a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc @@ -46,18 +46,20 @@ std::vector GetUnionProcessIds(std::vector in_process_ids, bool SameStatusReshardFunction::IsSuitable( const DistTensor& in, const TensorDistAttr& out_dist_attr) { - bool flag = true; const auto& in_dist_attr = in.dist_attr(); - flag &= (in_dist_attr.dims_mapping() == out_dist_attr.dims_mapping()); - flag &= (in_dist_attr.partial_dims() == out_dist_attr.partial_dims()); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.dims_mapping() == + out_dist_attr.dims_mapping()); + RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.partial_dims() == + out_dist_attr.partial_dims()); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - flag &= (in_process_mesh != out_process_mesh); - flag &= (in_process_mesh.shape() == out_process_mesh.shape()); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh); + RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() == + out_process_mesh.shape()); - return flag; + return true; } void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx, diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 338ee4b4bad177..2a5b336f34e256 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -33,6 +33,7 @@ #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/distributed/nccl_tools.h" #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/phi/core/distributed/xccl_comm_context.h" @@ -56,18 +57,19 @@ void CommContextManager::CreateNCCLCommContext( const std::string& unique_comm_key, int rank, int size, - const std::string& hash_key) { + const std::string& hash_key, + const P2POption* p2p_opt) { auto& comm_context_manager = CommContextManager::GetInstance(); if (comm_context_manager.Has(unique_comm_key)) { return; } ncclUniqueId nccl_id; - if (rank == 0) { + if (rank == 0 || (p2p_opt && p2p_opt->is_p2p_op && p2p_opt->p2p_rank == 0)) { PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id)); } std::string unique_key = "NCCLCommContext/" + unique_comm_key + hash_key; - if (rank == 0) { + if (rank == 0 || (p2p_opt && p2p_opt->is_p2p_op && p2p_opt->p2p_rank == 0)) { std::vector nccl_id_wrapper( reinterpret_cast(&nccl_id), reinterpret_cast(&nccl_id) + NCCL_UNIQUE_ID_BYTES); @@ -77,6 +79,14 @@ void CommContextManager::CreateNCCLCommContext( std::memcpy(&nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size()); } + if (p2p_opt) { + rank = p2p_opt->rank; + size = p2p_opt->num_ranks; + } + VLOG(3) << "init NCCLCommContext rank: " << rank << ", size: " << size + << ", unique_comm_key: " << unique_comm_key + << ", unique_key: " << unique_key + << ", nccl_id: " << SerializeNCCLUniqueId(nccl_id); auto nccl_comm_context = std::make_unique(rank, size, nccl_id); if (CommContextManager::device_id != -1) { diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index 69e58a96e18e1a..2229786db38551 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -30,6 +30,13 @@ namespace phi { namespace distributed { +struct P2POption { + bool is_p2p_op; + int p2p_rank; + int num_ranks; + int rank; +}; + class Store; class CommContextManager { @@ -62,7 +69,8 @@ class CommContextManager { const std::string& unique_comm_key, int rank, int size, - const std::string& hash_key = ""); + const std::string& hash_key = "", + const P2POption* opt = nullptr); #endif #if defined(PADDLE_WITH_GLOO) diff --git a/paddle/phi/core/distributed/comm_task.h b/paddle/phi/core/distributed/comm_task.h new file mode 100644 index 00000000000000..3673c7a9e21aab --- /dev/null +++ b/paddle/phi/core/distributed/comm_task.h @@ -0,0 +1,158 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/phi/core/distributed/utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/macros.h" + +#if defined(PADDLE_WITH_RCCL) +#include "paddle/phi/backends/dynload/rccl.h" +#endif +#if defined(PADDLE_WITH_NCCL) +#include "paddle/phi/backends/dynload/nccl.h" +#endif + +namespace phi { +namespace distributed { + +class Store; +class CommTask { + public: + CommTask(const std::string& backend = "", + const phi::Place& place = phi::Place(), + int rank = -1, + int size = 0, + int gid = 0, + uint64_t seq = 0, + int64_t numel = 0, + ncclComm_t nccl_comm = nullptr, + gpuStream_t nccl_stream = nullptr, + CommType comm_type = CommType::UNKNOWN) + : backend_(backend), + place_(place), + rank_(rank), + size_(size), + gid_(gid), + seq_(seq), + numel_(numel), + nccl_comm_(nccl_comm), + nccl_stream_(nccl_stream), + comm_type_(comm_type) { + const char* global_rank = std::getenv("PADDLE_TRAINER_ID"); + PADDLE_ENFORCE_NOT_NULL( + global_rank, + phi::errors::NotFound( + "The environment variable 'PADDLE_TRAINER_ID' cannot be found.")); + global_rank_ = std::atoi(global_rank); + } + virtual ~CommTask() = default; + + std::string UniqueKey() { + return "op:" + CommTypeToString(comm_type_) + + ",gid:" + std::to_string(gid_) + ",seq:" + std::to_string(seq_); + } + + std::string GetBackend() { return backend_; } + phi::Place GetPlace() { return place_; } + int GetGlobalRank() { return global_rank_; } + int GetRank() { return rank_; } + int GetSize() { return size_; } + int GetGid() { return gid_; } + int64_t GetNumel() { return numel_; } + uint64_t GetSeq() { return seq_; } + CommType GetCommType() { return comm_type_; } + bool GetTraceUpdated() { return start_trace_updated_; } + void SetTraceUpdated() { start_trace_updated_ = true; } + std::chrono::time_point GetStartTime() { + return start_time_; + } + std::shared_ptr GetStore() { return store_; } + void SetStore(std::shared_ptr store) { store_ = store; } + + ncclComm_t nccl_comm() { return nccl_comm_; } + gpuStream_t nccl_stream() { return nccl_stream_; } + + virtual std::string GetTraceMsg() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return ""; + } + virtual void StartRecord() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + virtual void EndRecord() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + + virtual std::string GetCommErrors() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return ""; + } + virtual bool IsStarted() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return false; + } + virtual bool IsTimeout() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return false; + } + virtual bool IsCompleted() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return false; + } + virtual void AbortComm() { + PADDLE_THROW( + phi::errors::Unimplemented("%s is not implemented.", __func__)); + return; + } + + protected: + std::string backend_; + phi::Place place_; + int global_rank_; + int rank_; + int size_; + int gid_; + uint64_t seq_{0}; + int64_t numel_; + ncclComm_t nccl_comm_; + gpuStream_t nccl_stream_; + CommType comm_type_; + bool start_trace_updated_{false}; + + bool completed_ = false; + bool aborted_{false}; + std::chrono::time_point start_time_; + std::shared_ptr store_; + + private: + DISABLE_COPY_AND_ASSIGN(CommTask); +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/comm_task_manager.cc b/paddle/phi/core/distributed/comm_task_manager.cc new file mode 100644 index 00000000000000..37083119b59f59 --- /dev/null +++ b/paddle/phi/core/distributed/comm_task_manager.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(PADDLE_WITH_GLOO) +#include + +#include "paddle/phi/core/distributed/gloo_comm_context.h" +#include "paddle/phi/core/distributed/gloo_utils.h" +#include "paddle/phi/core/distributed/store/gloo_store.h" +#endif + +#include "paddle/phi/core/distributed/comm_context_manager.h" + +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/distributed/store/store.h" +#include "paddle/phi/core/enforce.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/phi/core/distributed/comm_task_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/distributed/trace_utils.h" +#endif + +namespace phi { +namespace distributed { + +std::thread CommTaskManager::comm_task_loop_thread_; +const int64_t CommTaskManager::loop_thread_sleep_millis = 10000; + +std::atomic CommTaskManager::terminated_; +std::mutex CommTaskManager::comm_task_list_mutex_; +std::condition_variable CommTaskManager::comm_task_list_cv_; +std::list> CommTaskManager::comm_task_list_; +std::unordered_map> + CommTaskManager::init_comm_task_map_; +std::unordered_map> + CommTaskManager::start_comm_task_map_; + +CommTaskManager::CommTaskManager() { + terminated_.store(false); + comm_task_loop_thread_ = std::thread(&CommTaskManager::CommTaskLoop, this); + LOG(INFO) << "CommTaskManager init success"; +} +CommTaskManager::~CommTaskManager() { + terminated_.store(true); + + if (comm_task_loop_thread_.joinable()) { + comm_task_loop_thread_.join(); + comm_task_list_cv_.notify_one(); + } + LOG(INFO) << "CommTaskManager destruct success."; +} + +void CommTaskManager::CommTaskEnqueue(std::shared_ptr comm_task) { + if (!terminated_.load()) { + std::lock_guard lock(comm_task_list_mutex_); + comm_task_list_.emplace_back(std::move(comm_task)); + } +} + +void CommTaskManager::CommTaskLoop() { + bool done = false; + while (!terminated_.load() || !done) { + std::unique_lock lock(comm_task_list_mutex_); + comm_task_list_cv_.wait_for( + lock, + std::chrono::milliseconds(loop_thread_sleep_millis), + [&]() -> bool { return terminated_.load(); }); + for (auto iter = comm_task_list_.begin(); iter != comm_task_list_.end();) { + auto task = *iter; + if (task->IsTimeout()) { + if (!task->IsStarted()) { + LOG(ERROR) << "Find timeout init but not start task: " + << task->GetTraceMsg() << ",comm:" << task->nccl_comm() + << ",stream:" << task->nccl_stream(); + std::string task_key = task->UniqueKey(); + init_comm_task_map_[task_key] = task; + } else if (!task->IsCompleted()) { + LOG(ERROR) << "Find timeout start but not finish task: " + << task->GetTraceMsg() << ",comm:" << task->nccl_comm() + << ",stream:" << task->nccl_stream(); + std::string task_key = task->UniqueKey(); + start_comm_task_map_[task_key] = task; + } + iter = comm_task_list_.erase(iter); + } else { + ++iter; + } + } + + for (auto iter = init_comm_task_map_.begin(); + iter != init_comm_task_map_.end();) { + auto task = iter->second; + if (task->IsStarted()) { + std::string task_key = task->UniqueKey(); + start_comm_task_map_[task_key] = task; + iter = init_comm_task_map_.erase(iter); + LOG(INFO) << "Start timeout task: " << task->GetTraceMsg(); + } else { + ++iter; + } + } + + for (auto iter = start_comm_task_map_.begin(); + iter != start_comm_task_map_.end();) { + auto task = iter->second; + if (task->IsCompleted()) { + iter = start_comm_task_map_.erase(iter); + LOG(INFO) << "Finish timeout task: " << task->GetTraceMsg(); + } else { + ++iter; + } + } + + if (comm_task_list_.empty() && init_comm_task_map_.empty() && + start_comm_task_map_.empty()) { + done = true; + } + } +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/comm_task_manager.h b/paddle/phi/core/distributed/comm_task_manager.h new file mode 100644 index 00000000000000..58be0026dd0721 --- /dev/null +++ b/paddle/phi/core/distributed/comm_task_manager.h @@ -0,0 +1,72 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/phi/core/distributed/comm_context.h" +#include "paddle/phi/core/distributed/comm_task.h" +#include "paddle/phi/core/macros.h" + +namespace phi { +namespace distributed { + +enum ErrorHandlingMode { NoHandling = 0, TearDown = 1 }; + +class Store; + +class CommTaskManager { + public: + CommTaskManager(); + ~CommTaskManager(); + + public: + static CommTaskManager& GetInstance() { + static CommTaskManager instance; + return instance; + } + + void CommTaskEnqueue(std::shared_ptr comm_task); + + private: + void CommTaskLoop(); + + static std::thread comm_task_loop_thread_; + static const int64_t loop_thread_sleep_millis; + + static std::atomic terminated_; + + static std::mutex comm_task_list_mutex_; + static std::condition_variable comm_task_list_cv_; + static std::list> comm_task_list_; + // not start task + static std::unordered_map> + init_comm_task_map_; + // start but not finish task + static std::unordered_map> + start_comm_task_map_; + std::shared_ptr store_; + bool store_error_{false}; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc index bd49f0cff17086..d1d92c98fb0fd6 100644 --- a/paddle/phi/core/distributed/nccl_comm_context.cc +++ b/paddle/phi/core/distributed/nccl_comm_context.cc @@ -19,6 +19,7 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" #include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/nccl_tools.h" #include "paddle/phi/core/distributed/utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/utils/data_type.h" @@ -31,9 +32,9 @@ constexpr bool FLAGS_enable_nccl_dynamic_check = false; NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id) : CommContext(rank, size) { - PADDLE_ENFORCE_GPU_SUCCESS( + NCCL_CHECK( phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_)); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&nccl_version_)); + NCCL_CHECK(phi::dynload::ncclGetVersion(&nccl_version_)); } int NCCLCommContext::GetNcclVersion() { return nccl_version_; } @@ -76,14 +77,13 @@ void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor, if (FLAGS_enable_nccl_dynamic_check) { NCCLDynamicCheck::CheckShape(*out_tensor, root, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclBroadcast(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - root, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclBroadcast(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + root, + nccl_comm_, + stream)); } void NCCLCommContext::AllGather(phi::DenseTensor* out_tensor, @@ -100,13 +100,12 @@ void NCCLCommContext::AllGather(phi::DenseTensor* out_tensor, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclAllGather(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclAllGather(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + nccl_comm_, + stream)); } void NCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, @@ -123,14 +122,13 @@ void NCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclReduceScatter(in_tensor.data(), - out_tensor->data(), - out_tensor->numel(), - ToNCCLDataType(in_tensor.type()), - reduce_type, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclReduceScatter(in_tensor.data(), + out_tensor->data(), + out_tensor->numel(), + ToNCCLDataType(in_tensor.type()), + reduce_type, + nccl_comm_, + stream)); } void NCCLCommContext::Send(const phi::DenseTensor& in_tensor, @@ -143,13 +141,12 @@ void NCCLCommContext::Send(const phi::DenseTensor& in_tensor, NCCLDynamicCheck::CheckShape(in_tensor, rank_, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclSend(in_tensor.data(), - count, - ToNCCLDataType(in_tensor.dtype()), - peer, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclSend(in_tensor.data(), + count, + ToNCCLDataType(in_tensor.dtype()), + peer, + nccl_comm_, + stream)); VLOG(3) << "rank " << GetRank() << " send " << phi::product(in_tensor.dims()) << " to " << peer; } @@ -163,13 +160,12 @@ void NCCLCommContext::Recv(phi::DenseTensor* out_tensor, NCCLDynamicCheck::CheckShape(*out_tensor, peer, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclRecv(out_tensor->data(), - count, - ToNCCLDataType(out_tensor->dtype()), - peer, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclRecv(out_tensor->data(), + count, + ToNCCLDataType(out_tensor->dtype()), + peer, + nccl_comm_, + stream)); VLOG(3) << "rank " << GetRank() << " recv " << phi::product(out_tensor->dims()) << " from " << peer; } @@ -189,14 +185,13 @@ void NCCLCommContext::AllReduce(phi::DenseTensor* out_tensor, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclAllReduce(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - reduce_type, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclAllReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + reduce_type, + nccl_comm_, + stream)); } void NCCLCommContext::Reduce(phi::DenseTensor* out_tensor, @@ -215,15 +210,14 @@ void NCCLCommContext::Reduce(phi::DenseTensor* out_tensor, rank_, nccl_comm_); } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclReduce(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - reduce_type, - root, - nccl_comm_, - stream)); + NCCL_CHECK(phi::dynload::ncclReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + reduce_type, + root, + nccl_comm_, + stream)); } void NCCLCommContext::GroupStart() { diff --git a/paddle/phi/core/distributed/nccl_comm_task.cc b/paddle/phi/core/distributed/nccl_comm_task.cc new file mode 100644 index 00000000000000..f82f39c1954a3d --- /dev/null +++ b/paddle/phi/core/distributed/nccl_comm_task.cc @@ -0,0 +1,219 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/nccl_comm_task.h" + +#include "gflags/gflags.h" +#include "glog/logging.h" + +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/distributed/nccl_tools.h" +#include "paddle/phi/core/distributed/trace_utils.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { +namespace distributed { + +NCCLCommTask::NCCLCommTask(const phi::Place& place, + int rank, + int size, + int gid, + uint64_t seq, + int64_t numel, + bool sync_op, + bool use_calc_stream, + ncclComm_t nccl_comm, + gpuStream_t stream, + CommType comm_type, + int64_t timeout) + : CommTask("NCCL", + place, + rank, + size, + gid, + seq, + numel, + nccl_comm, + stream, + comm_type), + sync_op_(sync_op), + use_calc_stream_(use_calc_stream) { + start_trace_updated_ = false; + start_event_created_ = false; + end_event_created_ = false; + start_time_ = std::chrono::steady_clock::now(); + timeout_ = std::chrono::milliseconds(timeout); +} + +void NCCLCommTask::StartRecord() { + backends::gpu::GPUDeviceGuard guard(place_.device); + if (!start_event_created_) { +#ifdef PADDLE_WITH_CUDA + CUDA_CHECK(cudaEventCreateWithFlags(&nccl_start_event_, cuda_event_flags_)); +#else // PADDLE_WITH_HIP + HIP_CHECK(hipEventCreateWithFlags(&nccl_start_event_, hip_event_flags_)); +#endif + start_event_created_ = true; + } +#ifdef PADDLE_WITH_CUDA + CUDA_CHECK(cudaEventRecord(nccl_start_event_, nccl_stream_)); +#else // PADDLE_WITH_HIP + HIP_CHECK(hipEventRecord(nccl_start_event_, nccl_stream_)); +#endif +} +void NCCLCommTask::EndRecord() { + backends::gpu::GPUDeviceGuard guard(place_.device); + if (!end_event_created_) { +#ifdef PADDLE_WITH_CUDA + CUDA_CHECK(cudaEventCreateWithFlags(&nccl_end_event_, cuda_event_flags_)); +#else // PADDLE_WITH_HIP + HIP_CHECK(hipEventCreateWithFlags(&nccl_end_event_, hip_event_flags_)); +#endif + end_event_created_ = true; + } +#ifdef PADDLE_WITH_CUDA + CUDA_CHECK(cudaEventRecord(nccl_end_event_, nccl_stream_)); +#else // PADDLE_WITH_HIP + HIP_CHECK(hipEventRecord(nccl_end_event_, nccl_stream_)); +#endif +} + +bool NCCLCommTask::CudaEventQuery(gpuEvent_t event) { +#ifdef PADDLE_WITH_CUDA + cudaError_t ret = cudaEventQuery(event); + if (ret == cudaSuccess) { + return true; + } else if (ret != cudaErrorNotReady) { + CUDA_CHECK(ret); + } else { + // ignore and clear the error if not ready + CUDA_CHECK(cudaGetLastError()); + } +#else // PADDLE_WITH_HIP + hipError_t ret = hipEventQuery(event); + if (ret == hipSuccess) { + return true; + } else if (ret != hipErrorNotReady) { + HIP_CHECK(ret); + } else { + // ignore and clear the error if not ready + HIP_CHECK(hipGetLastError()); + } +#endif + return false; +} + +std::string GetNCCLErrorDetail(ncclResult_t result) { + std::string detail; + std::string last_error; +#ifdef ENABLE_NCCL_GET_LAST_ERROR + last_error = + ", Last error: " + std::string(phi::dynload::ncclGetLastError(NULL)); +#endif + switch (result) { + case ncclUnhandledCudaError: + detail = "ncclUnhandledCudaError: Call to CUDA function failed."; + break; + case ncclSystemError: + detail = + "ncclSystemError: System call (e.g. socket, malloc) or external " + "library call failed or device error. "; +#ifndef NCCL_REMOTE_ERROR + // Before ncclRemoteError was created, unexpected remote disconnect was + // categorized as ncclSystemError + detail += "It can be also caused by unexpected exit of a remote peer."; +#endif + break; + case ncclInternalError: + detail = "ncclInternalError: Internal check failed."; + break; + case ncclInvalidArgument: + detail = "ncclInvalidArgument: Invalid value for an argument."; + break; + case ncclInvalidUsage: + detail = + "ncclInvalidUsage: This usually reflects invalid usage of NCCL " + "library."; + break; +#ifdef NCCL_REMOTE_ERROR + case ncclRemoteError: + detail = + "ncclRemoteError: A call failed possibly due to a network error or a " + "remote process exiting prematurely."; + break; +#endif + default: + detail = "Unknown NCCL error!"; + } + return detail + last_error; +} + +std::string NCCLCommTask::GetCommErrors() { + std::unique_lock lock(mutex_); + if (!comm_error_.empty()) { + return comm_error_; + } + + ncclResult_t nccl_async_error; + NCCL_CHECK( + phi::dynload::ncclCommGetAsyncError(nccl_comm_, &nccl_async_error)); + if (nccl_async_error != ncclSuccess) { + comm_error_ = + "\n\t Find nccl comm error: " + GetNCCLErrorDetail(nccl_async_error); + } + return comm_error_; +} + +bool NCCLCommTask::IsStarted() { return CudaEventQuery(nccl_start_event_); } + +bool NCCLCommTask::IsCompleted() { return CudaEventQuery(nccl_end_event_); } + +bool NCCLCommTask::IsTimeout() { + auto current_timepoint = std::chrono::steady_clock::now(); + return std::chrono::duration_cast( + current_timepoint - start_time_) >= timeout_; +} + +void NCCLCommTask::AbortComm() { + std::unique_lock lock(mutex_); + if (aborted_) { + return; + } + NCCL_CHECK(phi::dynload::ncclCommAbort(nccl_comm_)); + + aborted_ = true; + nccl_comm_ = nullptr; + return; +} + +std::string NCCLCommTask::GetTraceMsg() { + auto current_timepoint = std::chrono::steady_clock::now(); + auto time_elapsed = std::chrono::duration_cast( + current_timepoint - start_time_); + return "op:" + CommTypeToString(comm_type_) + ",gid:" + std::to_string(gid_) + + ",seq:" + std::to_string(seq_) + + ",started:" + std::to_string(IsStarted()) + + ",completed:" + std::to_string(IsCompleted()) + + ",global_rank:" + std::to_string(global_rank_) + + ",local_rank:" + std::to_string(rank_) + + ",size:" + std::to_string(size_) + ",numel:" + std::to_string(numel_) + + ",sync_op:" + std::to_string(sync_op_) + + ",use_calc_stream:" + std::to_string(use_calc_stream_) + + ",timeout:" + std::to_string(timeout_.count()) + + ",is_timeout:" + std::to_string(IsTimeout()) + + ",time_elapsed:" + std::to_string(time_elapsed.count()); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/nccl_comm_task.h b/paddle/phi/core/distributed/nccl_comm_task.h new file mode 100644 index 00000000000000..9fe71670c2f88b --- /dev/null +++ b/paddle/phi/core/distributed/nccl_comm_task.h @@ -0,0 +1,89 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/core/distributed/comm_context.h" +#include "paddle/phi/core/distributed/comm_task.h" +#include "paddle/phi/core/distributed/utils.h" +#include "paddle/phi/core/macros.h" + +#if defined(PADDLE_WITH_RCCL) +#include "paddle/phi/backends/dynload/rccl.h" +#else +#include "paddle/phi/backends/dynload/nccl.h" +#endif + +namespace phi { +class DenseTensor; +namespace distributed { + +static int64_t DefaultTimeout = 30 * 60 * 1000; + +class NCCLCommTask : public CommTask { + public: + NCCLCommTask(const phi::Place& place = phi::Place(), + int rank = -1, + int size = 0, + int gid = 0, + uint64_t seq = 0, + int64_t numel = 0, + bool sync_op = true, + bool use_calc_stream = false, + ncclComm_t = nullptr, + gpuStream_t = nullptr, + CommType comm_type = CommType::UNKNOWN, + int64_t timeout = DefaultTimeout); + ~NCCLCommTask() = default; + + // check whether the nccl kernel started + bool IsStarted() override; + bool IsTimeout() override; + bool IsCompleted() override; + + std::string GetTraceMsg() override; + std::string GetCommErrors() override; + void AbortComm() override; + + void StartRecord(); + void EndRecord(); + + bool CudaEventQuery(gpuEvent_t event); + + protected: + std::mutex mutex_; + std::chrono::milliseconds timeout_; + +#ifdef PADDLE_WITH_CUDA + unsigned int cuda_event_flags_ = cudaEventDisableTiming; +#else // PADDLE_WITH_HIP + unsigned int hip_event_flags_ = hipEventDisableTiming; +#endif + + bool sync_op_; + bool use_calc_stream_; + + bool start_event_created_; + bool end_event_created_; + gpuEvent_t nccl_start_event_; + gpuEvent_t nccl_end_event_; + + std::string comm_error_; + + private: + DISABLE_COPY_AND_ASSIGN(NCCLCommTask); +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/fluid/distributed/collective/nccl_tools.cc b/paddle/phi/core/distributed/nccl_tools.cc similarity index 51% rename from paddle/fluid/distributed/collective/nccl_tools.cc rename to paddle/phi/core/distributed/nccl_tools.cc index 940c8d47ccb882..e419cfca905fa5 100644 --- a/paddle/fluid/distributed/collective/nccl_tools.cc +++ b/paddle/phi/core/distributed/nccl_tools.cc @@ -12,14 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/distributed/collective/nccl_tools.h" +#include "paddle/phi/core/distributed/nccl_tools.h" #include #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" -namespace paddle { +#if NCCL_VERSION_CODE >= 21300 +#define ENABLE_NCCL_GET_LAST_ERROR +#define NCCL_REMOTE_ERROR +#endif + +namespace phi { namespace distributed { ncclRedOp_t ToNCCLRedType(ReduceOp reduction) { @@ -47,5 +52,43 @@ std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID) { return oss.str(); } +std::string NCCLDTypeToString(ncclDataType_t dtype) { +#define PD_NCCL_DTYPE_TO_STR(__nccl_dtype, __str_dtype) \ + if (dtype == __nccl_dtype) return __str_dtype; + PD_NCCL_DTYPE_TO_STR(ncclFloat, "float32"); + PD_NCCL_DTYPE_TO_STR(ncclFloat32, "float32"); + PD_NCCL_DTYPE_TO_STR(ncclHalf, "float16"); + PD_NCCL_DTYPE_TO_STR(ncclFloat16, "float16"); +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 + PD_NCCL_DTYPE_TO_STR(ncclBfloat16, "bfloat16"); +#endif + PD_NCCL_DTYPE_TO_STR(ncclDouble, "float64"); + PD_NCCL_DTYPE_TO_STR(ncclFloat64, "float64"); + + PD_NCCL_DTYPE_TO_STR(ncclInt8, "int8"); + PD_NCCL_DTYPE_TO_STR(ncclChar, "int8"); + PD_NCCL_DTYPE_TO_STR(ncclUint8, "uint8"); + PD_NCCL_DTYPE_TO_STR(ncclInt32, "int32"); + PD_NCCL_DTYPE_TO_STR(ncclInt, "int32"); + PD_NCCL_DTYPE_TO_STR(ncclUint32, "uint32"); + PD_NCCL_DTYPE_TO_STR(ncclInt64, "int64"); + PD_NCCL_DTYPE_TO_STR(ncclUint64, "uint64"); + +#undef PD_NCCL_DTYPE_TO_STR + PADDLE_THROW(phi::errors::InvalidArgument( + "This datatype %d in nccl is not supported.", static_cast(dtype))); +} + +std::string NCCLRedTypeToString(ncclRedOp_t op) { + if (op == ncclSum) return "SUM"; + if (op == ncclProd) return "PROD"; + if (op == ncclMin) return "MIN"; + if (op == ncclMax) return "MAX"; +#if NCCL_VERSION_CODE >= 21000 + if (op == ncclAvg) return "AVG"; +#endif + return "UDF_" + std::to_string(op); +} + } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/core/distributed/nccl_tools.h b/paddle/phi/core/distributed/nccl_tools.h new file mode 100644 index 00000000000000..0ab380a4177838 --- /dev/null +++ b/paddle/phi/core/distributed/nccl_tools.h @@ -0,0 +1,77 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/core/distributed/types.h" + +#ifdef PADDLE_WITH_RCCL +#include +#include "paddle/phi/backends/dynload/rccl.h" +#else +#include +#include "paddle/phi/backends/dynload/nccl.h" +#endif + +namespace phi { +namespace distributed { + +#define NCCL_CHECK(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + PADDLE_THROW( \ + phi::errors::External("Failed, NCCL error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + phi::dynload::ncclGetErrorString(r))); \ + } \ + } while (0) + +#ifdef PADDLE_WITH_NCCL +#define CUDA_CHECK(expr) \ + do { \ + cudaError_t r = expr; \ + if (r != cudaSuccess) { \ + PADDLE_THROW(phi::errors::External("Failed, cuda error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(r))); \ + } \ + } while (0) +#else // PADDLE_WITH_RCCL +#define HIP_CHECK(expr) \ + do { \ + hipError_t r = expr; \ + if (r != hipSuccess) { \ + PADDLE_THROW(phi::errors::External("Failed, hip error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + hipGetErrorString(r))); \ + } \ + } while (0) +#endif + +ncclRedOp_t ToNCCLRedType(ReduceOp reduction); + +std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID); + +std::string NCCLDTypeToString(ncclDataType_t dtype); + +std::string NCCLRedTypeToString(ncclRedOp_t op); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/store/store.cc b/paddle/phi/core/distributed/store/store.cc index 7e7db8895b99f1..5987b694b4e51e 100644 --- a/paddle/phi/core/distributed/store/store.cc +++ b/paddle/phi/core/distributed/store/store.cc @@ -28,6 +28,11 @@ std::vector Store::get(const std::string& key) { errors::InvalidArgument("Implement the get method in the subclass.")); } +bool Store::check(const std::string& key) { + PADDLE_THROW( + errors::InvalidArgument("Implement the get method in the subclass.")); +} + void Store::wait(const std::string& key) { PADDLE_THROW( errors::InvalidArgument("Implement the wait method in the subclass.")); diff --git a/paddle/phi/core/distributed/store/store.h b/paddle/phi/core/distributed/store/store.h index fa509586eefdf2..4ecd4cb8b5d995 100644 --- a/paddle/phi/core/distributed/store/store.h +++ b/paddle/phi/core/distributed/store/store.h @@ -29,6 +29,7 @@ class Store { virtual int64_t add(const std::string& key, int64_t value); virtual std::vector get(const std::string& key); + virtual bool check(const std::string& key); virtual void wait(const std::string& key); virtual void set(const std::string& key, const std::vector& value); diff --git a/paddle/phi/core/distributed/store/tcp_store.cc b/paddle/phi/core/distributed/store/tcp_store.cc index 6fbe2aa6761e2c..46af21fa943562 100644 --- a/paddle/phi/core/distributed/store/tcp_store.cc +++ b/paddle/phi/core/distributed/store/tcp_store.cc @@ -110,6 +110,19 @@ void MasterDaemon::_do_get(SocketType socket) { tcputils::send_vector(socket, value); } +void MasterDaemon::_do_check(SocketType socket) { + std::string key = tcputils::receive_string(socket); + VLOG(4) << "MasterDaemon::_do_check key(" << key << ") " + << GetSockName(socket); + + auto iter = _store.find(key); + if (iter != _store.end()) { + tcputils::send_value(socket, ReplyType::READY); + } else { + tcputils::send_value(socket, ReplyType::NOT_READY); + } +} + #ifndef _WIN32 void MasterDaemon::InitControlFd() { PADDLE_ENFORCE_NE( @@ -190,6 +203,9 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { case Command::GET: _do_get(fds[i].fd); break; + case Command::CHECK: + _do_check(fds[i].fd); + break; case Command::SET: _do_set(fds[i].fd); break; @@ -420,6 +436,17 @@ std::vector TCPStore::get(const std::string& key) { return _client->receive_vector(); } +bool TCPStore::check(const std::string& key) { + _client->send_command_for_key(Command::CHECK, _key_prefix + key); + VLOG(3) << "TCPStore check."; + auto response = _client->receive_value(); + if (response == ReplyType::READY) { + return true; + } else { + return false; + } +} + void TCPStore::wait(const std::string& key) { ReplyType reply; // NOLINT VLOG(7) << "TCPStore wait."; diff --git a/paddle/phi/core/distributed/store/tcp_store.h b/paddle/phi/core/distributed/store/tcp_store.h index 0f17bc9b58bd45..4cc3a1933bd5d1 100644 --- a/paddle/phi/core/distributed/store/tcp_store.h +++ b/paddle/phi/core/distributed/store/tcp_store.h @@ -37,8 +37,8 @@ namespace phi { namespace distributed { -enum class ReplyType { WAITING, STOP_WAIT }; -enum class Command { ADD, GET, SET, WAIT, STOP }; +enum class ReplyType { WAITING, STOP_WAIT, READY, NOT_READY }; +enum class Command { ADD, GET, CHECK, SET, WAIT, STOP }; namespace detail { @@ -59,6 +59,7 @@ class MasterDaemon { void _do_add(SocketType socket); void _do_wait(SocketType socket); void _do_get(SocketType socket); + void _do_check(SocketType socket); void _do_set(SocketType socket); void _notify_waiting_sockets(const std::string&); SocketType _listen_socket; @@ -130,6 +131,7 @@ class TCPStore : public Store { int64_t add(const std::string& key, int64_t value) override; std::vector get(const std::string& key) override; + bool check(const std::string& key) override; void wait(const std::string& key) override; void set(const std::string& key, const std::vector& value) override; diff --git a/paddle/phi/core/distributed/trace_utils.h b/paddle/phi/core/distributed/trace_utils.h new file mode 100644 index 00000000000000..7a34055a987bce --- /dev/null +++ b/paddle/phi/core/distributed/trace_utils.h @@ -0,0 +1,187 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/distributed/store/store.h" +#include "paddle/utils/string/split.h" + +namespace phi { +namespace distributed { + +enum TraceEventType { + TraceEventStart, + TraceEventEnd, +}; + +using TraceMap = + std::map>>; + +inline std::string GetTraceStartKey(const std::string& backend, + int rank, + int gid) { + return backend + "_" + std::to_string(rank) + "_" + std::to_string(gid) + + "_trace_start"; +} + +inline std::string GetTraceEndKey(const std::string& backend, + int rank, + int gid) { + return backend + "_" + std::to_string(rank) + "_" + std::to_string(gid) + + "_trace_end"; +} + +inline std::string GetExceptionMsgFromExceptionPtr( + const std::exception_ptr& exception_ptr) { + if (exception_ptr == nullptr) { + return "No exception found"; + } + try { + std::rethrow_exception(exception_ptr); + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "Unknown exception type"; + } +} + +inline bool UpdateTraceMsg(std::shared_ptr store, + const std::string& key, + uint64_t seq, + const std::string& comm_type) { + std::vector value(comm_type.size() + sizeof(seq) + 1); + memcpy(value.data(), &seq, sizeof(seq)); + memcpy(value.data() + sizeof(seq), comm_type.data(), comm_type.size()); + try { + store->set(key, value); + return true; + } catch (...) { + LOG(ERROR) << "Store is down while updating trace msg, with seq: " << seq + << ", key " << key; + return false; + } +} + +inline bool ParseTraceValue(std::shared_ptr store, + const std::string& key, + uint64_t* seq, + std::string* comm_type) { + try { + std::vector value = store->get(key); + memcpy(seq, value.data(), sizeof(*seq)); + std::string type_value( + reinterpret_cast(value.data() + sizeof(*seq))); + *comm_type = type_value; + return true; + } catch (...) { + LOG(ERROR) << "Store is down while parsing trace value, with key: " << key; + return false; + } +} + +inline std::string RanksToString(const std::vector& ranks) { + std::string result; + for (int rank : ranks) { + if (result.empty()) { + result += std::to_string(rank); + } else { + result += ", " + std::to_string(rank); + } + } + return result; +} + +inline std::string AnalyzeTraceMsg(const TraceMap& trace_map, int gid) { + uint64_t lag_seq = trace_map.begin()->first; + std::vector start_ranks; + std::vector end_ranks; + for (auto& p : trace_map.begin()->second) { + if (p.second.second == TraceEventStart) { + start_ranks.emplace_back(p.first); + } else { + end_ranks.emplace_back(p.first); + } + } + + std::string result = "\n\t The ranks that has desync problem are: "; + if (start_ranks.size()) { + result += "[" + RanksToString(start_ranks) + + "] joined but do not finish collective seq: " + + std::to_string(lag_seq) + " in group_id: " + std::to_string(gid); + } + if (end_ranks.size()) { + result += ", ranks [" + RanksToString(end_ranks) + + "] finished collective seq: " + std::to_string(lag_seq) + + ", but didnt join seq: " + std::to_string(lag_seq + 1) + + " in group_id: " + std::to_string(gid); + } + return result; +} + +inline std::string GenerateTraceMsg(std::shared_ptr store, + const std::string& backend, + int curr_rank, + int group_id, + int world_size) { + std::string result; + TraceMap trace_map; + + uint64_t curr_seq; + std::string curr_comm_type; + + for (int rank = 0; rank < world_size; ++rank) { + uint64_t seq_start = 0; + { + std::string trace_start_key = GetTraceStartKey(backend, rank, group_id); + if (!store->check(trace_start_key)) { + continue; + } + + std::string comm_type; + if (!ParseTraceValue(store, trace_start_key, &seq_start, &comm_type)) { + return result; + } + trace_map[seq_start].emplace(rank, + std::make_pair(comm_type, TraceEventStart)); + if (rank == curr_rank) { + curr_seq = seq_start; + curr_comm_type = std::move(comm_type); + } + } + { + std::string trace_end_key = GetTraceEndKey(backend, rank, group_id); + if (!store->check(trace_end_key)) { + continue; + } + + uint64_t seq = 0; + std::string comm_type; + if (!ParseTraceValue(store, trace_end_key, &seq, &comm_type)) { + return result; + } + if (seq == seq_start) { + trace_map[seq][rank].second = TraceEventEnd; + } + } + } + result += "\n\t Problem summary: rank: " + std::to_string(curr_rank) + + " timeout at collective: " + curr_comm_type + + ", group_id: " + std::to_string(group_id) + + ", seq: " + std::to_string(curr_seq); + result += AnalyzeTraceMsg(trace_map, group_id); + return result; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/fluid/distributed/collective/types.h b/paddle/phi/core/distributed/types.h similarity index 97% rename from paddle/fluid/distributed/collective/types.h rename to paddle/phi/core/distributed/types.h index bd20f2705f22fb..3d4d074efd735f 100644 --- a/paddle/fluid/distributed/collective/types.h +++ b/paddle/phi/core/distributed/types.h @@ -20,7 +20,7 @@ #include "paddle/phi/common/place.h" -namespace paddle { +namespace phi { namespace distributed { // TODO(shenliang03): To support AVG for reduce @@ -58,4 +58,4 @@ struct ReduceScatterOptions { }; } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/core/distributed/utils.h b/paddle/phi/core/distributed/utils.h index f635b7d99fa610..40b28bb2a3e6f5 100644 --- a/paddle/phi/core/distributed/utils.h +++ b/paddle/phi/core/distributed/utils.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ // limitations under the License. #pragma once - #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -28,13 +27,119 @@ inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor& tensor, return tensor_flattened.Slice(offset, offset + numel); } -#define NCCL_CHECK(cmd) \ - do { \ - ncclResult_t r = cmd; \ - if (r != ncclSuccess) { \ - exit(EXIT_FAILURE); \ - } \ - } while (0) +inline void* GetPointerByOffset(void* raw_pointer, + size_t offset, + phi::DataType type) { + if (type == phi::DataType::FLOAT32) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::FLOAT64) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::FLOAT16) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::INT32) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::INT64) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::INT8) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::UINT8) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::BOOL) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else if (type == phi::DataType::BFLOAT16) { + return reinterpret_cast(reinterpret_cast(raw_pointer) + + offset); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Datatype %s in NCCL is not supported.", type)); + } + return nullptr; +} + +inline void CheckSizeOnEachRank(const phi::DDim& tensor_dim, + const std::vector& size_on_each_rank, + int world_size) { + int length_size_on_each_rank = size_on_each_rank.size(); + PADDLE_ENFORCE_EQ( + length_size_on_each_rank, + world_size, + phi::errors::InvalidArgument( + "The length of size_on_each_rank must be equal to world_size.")); + + int64_t sum_size_on_each_rank = + std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0); + PADDLE_ENFORCE_EQ( + sum_size_on_each_rank, + tensor_dim[0], + phi::errors::InvalidArgument( + "The sum of size_on_each_rank must be equal to tensor's dim[0].")); +} + +enum class CommType : std::uint8_t { + BROADCAST = 0, + ALLREDUCE = 1, + ALLREDUCE_SPARSE = 2, // TODO(shenliang03): to support sparse in allreduce + REDUCE = 3, + ALLGATHER = 4, + GATHER = 5, + SCATTER = 6, + REDUCE_SCATTER = 7, + ALLTOALL = 8, + SEND = 9, + RECV = 10, + BARRIER = 11, + UNKNOWN = 100, +}; + +inline bool IsP2POP(CommType comm_type, bool is_batch_p2p = false) { + if (is_batch_p2p) { + return false; + } else { + return comm_type == CommType::SEND || comm_type == CommType::RECV; + } +} + +inline std::string CommTypeToString(CommType CommType) { + switch (CommType) { + case CommType::BROADCAST: + return "Broadcast"; + case CommType::ALLREDUCE: + return "AllReduce"; + case CommType::ALLREDUCE_SPARSE: + return "AllReduce_Sparse"; + case CommType::REDUCE: + return "Reduce"; + case CommType::ALLGATHER: + return "AllGather"; + case CommType::GATHER: + return "Gather"; + case CommType::SCATTER: + return "Scatter"; + case CommType::REDUCE_SCATTER: + return "ReduceScatter"; + case CommType::ALLTOALL: + return "AllToAll"; + case CommType::SEND: + return "Send"; + case CommType::RECV: + return "Recv"; + case CommType::BARRIER: + return "Barrier"; + case CommType::UNKNOWN: + return "Unknown"; + default: + return "Unknown"; + } + return "Unknown"; +} } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index c7a0a81c7fb4f4..d91b5de90a584d 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1134,10 +1134,17 @@ PHI_DEFINE_EXPORTED_bool(gpugraph_debug_gpu_memory, * Example: * Note: nccl blocking wait. */ + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PHI_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); #endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PHI_DEFINE_EXPORTED_bool(benchmark_nccl, + false, + "enable nccl debug mode to synchronize nccl comm"); +#endif + /** * Autotune related FLAG * Name: FLAGS_use_autotune @@ -1312,6 +1319,11 @@ PHI_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass, "Whether to apply inplace pass on lowering " "::pir::Program to Kernel Dialect"); +PHI_DEFINE_EXPORTED_string( + ir_inplace_kernel_blacklist, + "", + "It controls the ir inplace kernel subset do not use."); + PHI_DEFINE_EXPORTED_bool(enable_record_memory, false, "Enable memory recorder"); PHI_DEFINE_EXPORTED_bool( @@ -1350,3 +1362,18 @@ PHI_DEFINE_EXPORTED_bool(dynamic_static_unified_comm, "Whether to use new communication library in auto " "parallel and static mode."); #endif // FLAGS_dynamic_static_unified_comm + +/** + * ProcessGroupNCCL related FLAG + * Name: enable_async_trace + * Since Version: + * Value Range: bool, default=false + * Example: + * Note: enable nccl async trace. + */ + +PHI_DEFINE_EXPORTED_bool(enable_async_trace, + false, + "enable collective async trace"); + +PHI_DEFINE_EXPORTED_int32(async_trace_count, 5, "collective async trace count"); diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 0aca25103f80a7..e7062879573c54 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -228,6 +228,8 @@ void Conv2dXPUInferMeta(const MetaTensor& x, const MetaTensor& bias, const MetaTensor& branch, const MetaTensor& branch_max, + const MetaTensor& scale_max, + const MetaTensor& out_max_in, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -378,6 +380,8 @@ void FcXPUInferMeta(const MetaTensor& x, const MetaTensor& w, const MetaTensor& w_max, const MetaTensor& bias, + const MetaTensor& scale_max, + const MetaTensor& out_max_in, int in_num_col_dims, bool transpose_x, float alpha, @@ -1914,8 +1918,8 @@ void FusedEmbeddingEltWiseLayerNormInferMeta( auto dim_output = phi::make_ddim({batch, seq_len, hidden}); out->set_dims(dim_output); - // out->share_lod(ids); - // context->ShareLoD("Ids", /*->*/ "Out"); + out->share_lod(*ids[0]); + out->set_dtype((*embs[0]).dtype()); } void FusionTransposeFlattenConcatInferMeta( @@ -1977,6 +1981,7 @@ void FusionTransposeFlattenConcatInferMeta( out_dims[concat_axis] = -1; } out->set_dims(phi::make_ddim(out_dims)); + out->set_dtype((*x[0]).dtype()); } void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, @@ -2158,13 +2163,304 @@ void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, } out->set_dims(y_dims); + out->set_dtype(x.dtype()); if (mean) { + mean->set_dtype(x.dtype()); mean->set_dims({dim_0}); } if (variance) { variance->set_dims({dim_0}); + variance->set_dtype(x.dtype()); } out->share_lod(x); } +void FusionGRUInferMeta(const MetaTensor& x, + const MetaTensor& h0, + const MetaTensor& weight_x, + const MetaTensor& weight_h, + const MetaTensor& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + MetaTensor* reordered_h0, + MetaTensor* xx, + MetaTensor* batched_input, + MetaTensor* batched_out, + MetaTensor* hidden) { + std::string mkldnn_data_type_list[] = {"float32", "int8", "bfloat16"}; + PADDLE_ENFORCE_EQ( + std::find(std::begin(mkldnn_data_type_list), + std::end(mkldnn_data_type_list), + mkldnn_data_type) != std::end(mkldnn_data_type_list), + true, + phi::errors::InvalidArgument("The mkldnn_data_type shoule be [float32, " + "int8, bfloat16], but found %s.", + mkldnn_data_type.c_str())); + + DDim x_dims = x.dims(); + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) + ? phi::flatten_to_2d(x_dims, 1) + : x_dims; + PADDLE_ENFORCE_EQ( + x_mat_dims.size(), + 2, + phi::errors::InvalidArgument("The size of input X dims should be 2, " + "or 3 with second dimension equal to " + "1, but now Input X dim is:[%s] ", + x_dims)); + + auto wx_dims = weight_x.dims(); + PADDLE_ENFORCE_EQ(wx_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(WeightX) should be 2, but received " + "WeightX dim size is:%d, WeightX dim is:[%s] ", + wx_dims.size(), + wx_dims)); + PADDLE_ENFORCE_EQ( + wx_dims[0], + x_mat_dims[1], + phi::errors::InvalidArgument( + "The first dimension of flattened WeightX" + "should equal to last dimension of flattened input X, but " + "received fattened WeightX dimension is:%d, flattened X dimension " + "is:%d", + wx_dims[0], + x_mat_dims[1])); + + int frame_size = static_cast(wx_dims[1] / 3); + auto wh_dims = weight_h.dims(); + + PADDLE_ENFORCE_EQ(wh_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(WeightH) should be 2, but received " + "WeightH dim size is:%d, WeightH dim is:[%s]", + wh_dims.size(), + wh_dims)); + PADDLE_ENFORCE_EQ(wh_dims[0], + frame_size, + phi::errors::InvalidArgument( + "The first dimension of WeightH " + "should equal to frame_size, but received WeightH's " + "first dimension is: " + "%d, frame size is:%d", + wh_dims[0], + frame_size)); + PADDLE_ENFORCE_EQ(wh_dims[1], + 3 * frame_size, + phi::errors::InvalidArgument( + "The second dimension of Input(WeightH) " + "should equal to 3 * frame_size, but received WeightH " + "is:%d, frame size is:%d", + wh_dims[1], + frame_size)); + + if (h0) { + auto h0_dims = h0.dims(); + PADDLE_ENFORCE_EQ(h0_dims[1], + frame_size, + phi::errors::InvalidArgument( + "The width of H0 must be equal to frame_size, but " + "receiced the width of H0 is:%d, frame size is:%d", + h0_dims[1], + frame_size)); + reordered_h0->set_dtype(x.dtype()); + } + if (bias) { + auto b_dims = bias.dims(); + PADDLE_ENFORCE_EQ(b_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(Bias) should be 2, but received " + "Bias rank is:%d, Bias dim is:[%s]", + b_dims.size(), + b_dims)); + PADDLE_ENFORCE_EQ(b_dims[0], + 1, + phi::errors::InvalidArgument( + "The first dimension of Input(Bias) should be 1, but " + "received Bias first dim is:%d, Bias dim is:[%s]", + b_dims[0], + b_dims)); + PADDLE_ENFORCE_EQ(b_dims[1], + frame_size * 3, + phi::errors::InvalidArgument( + "The shape of Bias must be [1, frame_size * 3], but " + "received bias dim is:[%s], frame size is:%d", + b_dims, + frame_size)); + } + DDim out_dims({x_mat_dims[0], frame_size}); + hidden->set_dims(out_dims); + hidden->share_lod(x); + hidden->set_dtype(x.dtype()); + int xx_width = 0; + if (use_seq) { + xx_width = static_cast(wx_dims[1]); + } else { + xx_width = static_cast(x_mat_dims[1] > wx_dims[1] ? wx_dims[1] + : x_mat_dims[1]); + batched_input->set_dims({x_mat_dims[0], wx_dims[1]}); + batched_input->set_dtype(x.dtype()); + batched_out->set_dims(out_dims); + batched_out->set_dtype(x.dtype()); + } + xx->set_dims({x_mat_dims[0], xx_width}); + xx->set_dtype(x.dtype()); + xx->share_lod(x); +} + +void FusionSeqConvEltAddReluInferMeta(const MetaTensor& x, + const MetaTensor& filter, + const MetaTensor& bias, + const int context_length, + const int context_start, + const int context_stride, + MetaTensor* out, + MetaTensor* col_mat) { + auto x_dims = x.dims(); + auto w_dims = filter.dims(); + PADDLE_ENFORCE_GT( + context_length, + 0, + phi::errors::InvalidArgument("context_length should be greater than 0, " + "but received context_length is: %d", + context_length)); + PADDLE_ENFORCE_EQ(context_stride, + 1, + phi::errors::InvalidArgument( + "Currently, FusionSeqConvEltAddReluOp only supports " + "contextStride=1, but received value is: %d.", + context_stride)); + + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(X) should be 2-D tensor, but reveiced value is: %d.", + x_dims.size())); + + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2, + phi::errors::InvalidArgument( + "Filter should be 2-D tensor, but reveiced value is: %d.", + w_dims.size())); + + PADDLE_ENFORCE_EQ(w_dims[0], + context_length * x_dims[1], + phi::errors::InvalidArgument( + "Filter's height should be equal to context_length * " + "input_hidden_size, but received Filter height is: %d," + "context_length is: %d, input_hidden_size is: %d.", + w_dims[0], + context_length, + x_dims[1])); + + PADDLE_ENFORCE_GT( + context_length + context_start, + 0, + phi::errors::InvalidArgument( + "contextStart size should be smaller than contextLength, " + "but received context_length is: %d, contextStart is: " + "%d.", + context_length, + context_start)); + out->set_dims({x_dims[0], w_dims[1]}); + col_mat->set_dims({x_dims[0], w_dims[0]}); + out->share_lod(x); + col_mat->set_dtype(x.dtype()); + out->set_dtype(x.dtype()); +} + +void FusionSeqExpandConcatFCInferMeta(const std::vector& x, + const MetaTensor& fc_weight, + const MetaTensor& fc_bias, + const std::string& fc_activation, + MetaTensor* out, + MetaTensor* fc_out) { + PADDLE_ENFORCE_GT(x.size(), + 1UL, + phi::errors::InvalidArgument( + "Inputs(X) of FusionSeqExpandConcatFCOp should larger " + "than 1, but received value is: %d.", + x.size())); + + std::vector ins_dims; + ins_dims.reserve(x.size()); + std::transform(x.begin(), + x.end(), + std::back_inserter(ins_dims), + [](const MetaTensor* var) { return var->dims(); }); + + auto w_dims = fc_weight.dims(); // (M0+M1+M2+..) x D + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(FCWeight)'s rank must be 2, but received value is: %d.", + w_dims.size())); + const int D = static_cast(w_dims[1]); + int sum = static_cast(ins_dims[0][1]); + for (size_t i = 1; i < ins_dims.size(); ++i) { + sum += static_cast(ins_dims[i][1]); + } + PADDLE_ENFORCE_EQ( + sum, + w_dims[0], + phi::errors::InvalidArgument("FC height should be sum of all inputs " + "width, but received FC height is: %d, " + "sum of all inputs width is: %d.", + w_dims[0], + sum)); + if (fc_bias) { + auto b_dims = fc_bias.dims(); + PADDLE_ENFORCE_EQ( + b_dims.size() == 1 || b_dims.size() == 2, + true, + phi::errors::InvalidArgument( + "FCBias dim should be 1 or 2, but received value is: %d.", + b_dims.size())); + if (b_dims.size() == 1) { + PADDLE_ENFORCE_EQ(b_dims[0], + D, + phi::errors::InvalidArgument( + "FCBias shapes must be %d when FCBias dim = 1, but " + "received value is: %d.", + D, + b_dims[0])); + } else { + PADDLE_ENFORCE_EQ(b_dims[0], + 1, + phi::errors::InvalidArgument( + "FCBias shapes must be 1x%d, when FCBias dim = 2, " + "but received dim[0] is: %d.", + D, + b_dims[0])); + PADDLE_ENFORCE_EQ(b_dims[1], + D, + phi::errors::InvalidArgument( + "FCBias shapes must be 1x%d, when FCBias dim = 2, " + "but received dim[1] is: %d.", + D, + b_dims[1])); + } + } + fc_out->set_dtype((*x[0]).dtype()); + out->set_dims({ins_dims[0][0], D}); + out->set_dtype((*x[0]).dtype()); + // fcout should be reshape when run since can not get lod in infershape + // explicit share the ref lod + out->share_lod(*x[0]); +} } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index c022a4257e4dc3..b6b9c64314ca83 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -62,6 +62,8 @@ void Conv2dXPUInferMeta(const MetaTensor& x, const MetaTensor& bias, const MetaTensor& branch, const MetaTensor& branch_max, + const MetaTensor& scale_max, + const MetaTensor& out_max_in, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -86,6 +88,8 @@ void FcXPUInferMeta(const MetaTensor& x, const MetaTensor& w, const MetaTensor& w_max, const MetaTensor& bias, + const MetaTensor& scale_max, + const MetaTensor& out_max_in, int in_num_col_dims, bool transpose_x, float alpha, @@ -515,4 +519,41 @@ void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, MetaTensor* variance, MetaConfig config = MetaConfig()); +void FusionGRUInferMeta(const MetaTensor& x, + const MetaTensor& h0, + const MetaTensor& weight_x, + const MetaTensor& weight_h, + const MetaTensor& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + MetaTensor* reordered_h0, + MetaTensor* xx, + MetaTensor* batched_input, + MetaTensor* batched_out, + MetaTensor* hidden); + +void FusionSeqConvEltAddReluInferMeta(const MetaTensor& x, + const MetaTensor& filter, + const MetaTensor& bias, + const int context_length, + const int context_start, + const int context_stride, + MetaTensor* out, + MetaTensor* col_mat); + +void FusionSeqExpandConcatFCInferMeta(const std::vector& x, + const MetaTensor& fc_weight, + const MetaTensor& fc_bias, + const std::string& fc_activation, + MetaTensor* out, + MetaTensor* fc_out); } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 0cd5534a9c44ab..cece7dd8807933 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -617,48 +617,52 @@ void BatchNormInferMeta(const MetaTensor& x, (data_layout == DataLayout::kNCHW) ? x_dims[1] : x_dims[x_dims.size() - 1]); - auto scale_dim = scale.dims(); - auto bias_dim = bias.dims(); + if (scale) { + PADDLE_ENFORCE_EQ( + scale.dims().size(), + 1UL, + phi::errors::InvalidArgument( + "ShapeError: the dimension of scale must equal to 1." + "But received: the shape of scale is [%s], the dimension " + "of scale is [%d]", + scale.dims().size(), + scale.dims().size())); + } - PADDLE_ENFORCE_EQ( - scale_dim.size(), - 1UL, - phi::errors::InvalidArgument( - "ShapeError: the dimension of scale must equal to 1." - "But received: the shape of scale is [%s], the dimension " - "of scale is [%d]", - scale_dim, - scale_dim.size())); - PADDLE_ENFORCE_EQ(bias_dim.size(), - 1UL, - phi::errors::InvalidArgument( - "ShapeError: the dimension of bias must equal to 1." - "But received: the shape of bias is [%s],the dimension " - "of bias is [%d]", - bias_dim, - bias_dim.size())); + if (bias) { + PADDLE_ENFORCE_EQ( + bias.dims().size(), + 1UL, + phi::errors::InvalidArgument( + "ShapeError: the dimension of bias must equal to 1." + "But received: the shape of bias is [%s],the dimension " + "of bias is [%d]", + bias.dims(), + bias.dims().size())); + } bool check = true; - if ((!config.is_runtime) && - (phi::product(scale_dim) <= 0 || phi::product(bias_dim) <= 0)) { + if (!scale || !bias || + ((!config.is_runtime) && + (phi::product(scale.dims()) <= 0 || phi::product(bias.dims()) <= 0))) { check = false; } if (check) { - PADDLE_ENFORCE_EQ(scale_dim[0], + PADDLE_ENFORCE_EQ(scale.dims()[0], C, phi::errors::InvalidArgument( "ShapeError: the shape of scale must equal to [%d]" "But received: the shape of scale is [%d]", C, - scale_dim[0])); - PADDLE_ENFORCE_EQ(bias_dim[0], + scale.dims()[0])); + PADDLE_ENFORCE_EQ(bias.dims()[0], C, phi::errors::InvalidArgument( "ShapeError: the shape of bias must equal to [%d]" "But received: the shape of bias is [%d]", C, - bias_dim[0])); + bias.dims()[0])); } y->set_dims(x_dims); mean_out->set_dims({C}); diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 0e3ac3fb5ca2c8..fa791c2f80a3eb 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -41,13 +41,11 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) { CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); } -void CreateIntArrayInferMeta(const IntArray& data, +void CreateVecShapeInferMeta(const std::vector& shape, DataType dtype, MetaTensor* out) { - CreateInferMetaBase({static_cast(data.GetData().size())}, - dtype, - DataLayout::NCHW, - out); + CreateInferMetaBase( + {static_cast(shape.size())}, dtype, DataLayout::NCHW, out); } void CreateInferMetaBase(const std::vector& shape, diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 2f9c9a69a13f1e..1a765eff7c1111 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -35,7 +35,7 @@ void AssignValueInferMeta(const std::vector& shape, DataType dtype, MetaTensor* out); -void CreateIntArrayInferMeta(const IntArray& data, +void CreateVecShapeInferMeta(const std::vector& shape, DataType dtype, MetaTensor* out); diff --git a/paddle/phi/infermeta/spmd_rules/reduction.cc b/paddle/phi/infermeta/spmd_rules/reduction.cc index 24c90a1792341d..24fc64484f418e 100644 --- a/paddle/phi/infermeta/spmd_rules/reduction.cc +++ b/paddle/phi/infermeta/spmd_rules/reduction.cc @@ -29,8 +29,15 @@ using phi::distributed::auto_parallel::str_join; ////////////////// Utils Functions ////////////////// std::string GetOutputNotation(int input_ndim, const std::string& input_axes, - std::vector reduce_dims, + std::vector reduce_dims, bool keep_dim) { + // if input_axes is empty means reduce all + if (reduce_dims.empty()) { + for (int i = 0; i < input_ndim; ++i) { + reduce_dims.emplace_back(i); + } + } + // convert the negative dim value to normal dim value for (auto& reduce_dim : reduce_dims) { if (reduce_dim < 0) { @@ -40,7 +47,7 @@ std::string GetOutputNotation(int input_ndim, std::string output_axes = ""; for (int i = 0; i < input_ndim; i++) { - std::vector::iterator iter = + std::vector::iterator iter = std::find(reduce_dims.begin(), reduce_dims.end(), i); if (iter != reduce_dims.end()) { // if i is reduce dim, the corresponding input axis @@ -58,9 +65,10 @@ std::string GetOutputNotation(int input_ndim, return output_axes; } -SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, - const std::vector& axis, - bool keep_dim) { +SpmdInfo ReductionInferSpmdBase(const DistMetaTensor& x, + const std::vector& axis, + bool keep_dim, + int reduce_type) { // Step0: Verify input args based on reduction logic auto x_shape = phi::vectorize(x.dims()); int x_ndim = x_shape.size(); @@ -102,8 +110,8 @@ SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, // Step3.1 Output Partial std::vector partial_on_dims = ResoluteOutputPartialDimension(axis_to_dim_map, out_axes); - out_dist_attr.set_partial_status( - partial_on_dims /*, handle reduce_type in future */); + out_dist_attr.set_partial_status(partial_on_dims, + static_cast(reduce_type)); // Step3.2 handle input tensor partial (TODO) // If the op is a linear op, i.e. `linearity` is true, it supports @@ -116,14 +124,37 @@ SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, VLOG(4) << "Input0 shape: [" << str_join(x_shape) << "] " << "dims_mapping: [" << str_join(x_dims_mapping) << "]"; VLOG(4) << "Output dims_mapping: [" + str_join(out_dims_mapping) + "] " - << "partial_on_dims: [" + str_join(partial_on_dims) + "]\n\n"; + << "partial_on_dims: [" + str_join(partial_on_dims) + << " with reduce_type " << reduce_type << "]\n\n"; return {{x_dist_attr_src}, {out_dist_attr}}; } +SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, + const std::vector& axis, + bool keep_dim) { + return ReductionInferSpmdBase( + x, axis, keep_dim, static_cast(ReduceType::kRedSum)); +} + +SpmdInfo ReductionMeanInferSpmdDynamic(const DistMetaTensor& x, + const IntArray& axis, + bool keep_dim) { + return ReductionInferSpmdBase( + x, axis.GetData(), keep_dim, static_cast(ReduceType::kRedAvg)); +} + +SpmdInfo ReductionSumInferSpmdDynamic(const DistMetaTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim) { + return ReductionInferSpmdBase( + x, axis.GetData(), keep_dim, static_cast(ReduceType::kRedSum)); +} + SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out, - const std::vector& axis, + const std::vector& axis, bool keep_dim) { // Step0: Verify input args based on reduction logic auto x_shape = phi::vectorize(x.dims()); @@ -174,5 +205,44 @@ SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x, return {{x_dist_attr_dst}, {out_dist_attr_src}}; } +SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad, + const IntArray& axis, + bool keep_dim, + bool reduce_all) { + TensorDistAttr x_dist_attr = out_grad.dist_attr(); + TensorDistAttr x_grad_dist_attr = out_grad.dist_attr(); + + std::vector x_dim = phi::vectorize(x.dims()); + std::vector out_grad_dim = phi::vectorize(out_grad.dims()); + + if (x_dim.size() != out_grad_dim.size()) { + auto dims_mapping = x_dist_attr.dims_mapping(); + auto axis_value = axis.GetData(); + + for (size_t i = 0; i < axis_value.size(); ++i) { + if (axis_value[i] < 0) { + axis_value[i] += x_dim.size(); + } + } + std::sort(axis_value.begin(), axis_value.end()); + + // if the input_axes is empty means to reduce all + if (axis_value.empty()) { + for (size_t i = 0; i < x_dim.size(); ++i) { + axis_value.emplace_back(i); + } + } + + for (const auto& axis : axis_value) { + dims_mapping.insert(dims_mapping.begin() + axis, -1); + } + x_dist_attr.set_dims_mapping(dims_mapping); + x_grad_dist_attr.set_dims_mapping(dims_mapping); + } + + return {{x_dist_attr, out_grad.dist_attr()}, {x_grad_dist_attr}}; +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/reduction.h b/paddle/phi/infermeta/spmd_rules/reduction.h index ed9341ddc6904b..e010abbb1f60c7 100644 --- a/paddle/phi/infermeta/spmd_rules/reduction.h +++ b/paddle/phi/infermeta/spmd_rules/reduction.h @@ -16,6 +16,7 @@ limitations under the License. */ #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/type_defs.h" @@ -23,13 +24,32 @@ namespace phi { namespace distributed { SpmdInfo ReductionInferSpmd(const DistMetaTensor& x, - const std::vector& axis, + const std::vector& axis, bool keep_dim); +// This infer spmd function only use in dynamic mode for it uses +// IntArray as parameter. The IntArray may contain vector of tensor +// which is not support in static mode. So we separate these two and +// use dynamic infer_spmd invoke static infer_spmd function. +SpmdInfo ReductionMeanInferSpmdDynamic(const DistMetaTensor& x, + const IntArray& axis, + bool keep_dim); + +SpmdInfo ReductionSumInferSpmdDynamic(const DistMetaTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim); + SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out, - const std::vector& axis, + const std::vector& axis, bool keep_dim); +SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad, + const IntArray& axis, + bool keep_dim, + bool reduce_all); + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/replicated.cc b/paddle/phi/infermeta/spmd_rules/replicated.cc index fce5a1d89b263e..b2b3b019be0391 100644 --- a/paddle/phi/infermeta/spmd_rules/replicated.cc +++ b/paddle/phi/infermeta/spmd_rules/replicated.cc @@ -54,7 +54,11 @@ SpmdInfo ReplicatedInferSpmd(const std::vector& ins, // Step3: Merge and get Inputs' Batch Axis New Dims Mapping. std::vector dst_input_dist_attrs; for (int64_t i = 0; i < ninputs; i++) { + // `ndim == -1` means input is nullptr int ndim = ins[i]->dims().size(); + if (ndim == -1) { + continue; + } TensorDistAttr dist_attr_dst = CopyTensorDistAttrForOutput(ins[i]->dist_attr()); std::vector dst_dims_maping = GetReplicatedDimsmapping(ndim); @@ -64,6 +68,9 @@ SpmdInfo ReplicatedInferSpmd(const std::vector& ins, VLOG(4) << "ReplicatedSpmd InferForward:"; for (int64_t i = 0; i < ninputs; i++) { + if (ins[i]->dims().size() == -1) { + continue; + } VLOG(4) << "Input" << std::to_string(i) << " shape: [" << str_join(phi::vectorize(ins[i]->dims())) << "] " << "src_dims_mapping: [" diff --git a/paddle/phi/infermeta/spmd_rules/reshape.cc b/paddle/phi/infermeta/spmd_rules/reshape.cc index 4c95b846c87d03..42e946c7321610 100644 --- a/paddle/phi/infermeta/spmd_rules/reshape.cc +++ b/paddle/phi/infermeta/spmd_rules/reshape.cc @@ -50,7 +50,7 @@ std::vector InferTargetShape(const std::vector& shape, PADDLE_ENFORCE_EQ( product, len, - phi::errors::InvalidArgument("The total size are not matched")); + phi::errors::InvalidArgument("The total size are not matched.")); return std::vector(shape); } else { std::vector new_shape(shape); @@ -59,7 +59,7 @@ std::vector InferTargetShape(const std::vector& shape, PADDLE_ENFORCE_EQ(len % infer_size, 0, phi::errors::InvalidArgument( - "The total is not diviable by infer_size")); + "The total is not diviable by infer_size.")); new_shape[infer_idx] = infer_size; return new_shape; } @@ -143,8 +143,11 @@ std::vector MakeReshapeDimTrans( SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x, const std::vector& shape) { // Step0: Verify input args based on reshape logic - auto src_shape = phi::vectorize(x.dims()); - int x_ndim = src_shape.size(); + VLOG(2) << "Debug Info for reshape"; + VLOG(2) << "shape: " << str_join(shape); + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = x_shape.size(); + int out_ndim = shape.size(); auto x_dist_attr_src = x.dist_attr(); std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( @@ -154,20 +157,31 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x, "dims_mapping size [%d] are not matched.", x_ndim, x_dims_mapping.size())); + VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(x_shape) << "]"; + VLOG(4) << "Out shape: [" << str_join(shape) << "]"; // Step1: Build the transformation from // the original shape to the target shape + // handle the case of dynamic shape, like [-1, -1, ...] --> [0, 0, ...]. + // This is used in inference but reshape allows only one '-1' in the + // target shape, so set the shape to a special value '256' + for (int i = 0; i < x_ndim; i++) { + if (x_shape[i] == -1) { + x_shape[i] = 256; + } + } + // handle the '0' values in target shape, '0' indicates // that the target shape is equal to the source shape std::vector tgt_shape(shape); - for (int64_t i = 0, n = static_cast(tgt_shape.size()); i < n; i++) { + for (int64_t i = 0; i < out_ndim; i++) { if (tgt_shape[i] == 0) { - tgt_shape[i] = src_shape[i]; + tgt_shape[i] = x_shape[i]; } } - std::vector trans = MakeReshapeDimTrans(src_shape, tgt_shape); + std::vector trans = MakeReshapeDimTrans(x_shape, tgt_shape); // Step2: Infer the dims mapping of input (if reshard is // needed) and output from the dimension transformation. @@ -181,17 +195,14 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x, TensorDistAttr out_dist_attr(x_dist_attr_src); out_dist_attr.set_dims_mapping(dims_mapping_vec[1]); - VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(src_shape) - << "] Out shape: [" << str_join(tgt_shape) << "]"; VLOG(4) << "Transformation from input to output:"; for (int64_t i = 0, n = static_cast(trans.size()); i < n; i++) { DimTrans* t = trans[i]; VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string(); } VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping) - << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) - << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1]) - << "]\n\n"; + << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]"; + VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n"; CleanUp(); @@ -201,9 +212,12 @@ SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x, SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out, const std::vector& shape) { + VLOG(2) << "Debug Info for reshape_reverse"; + VLOG(2) << "shape: " << str_join(shape); // Step0: Verify input args based on reshape logic auto x_shape = phi::vectorize(x.dims()); auto out_shape = phi::vectorize(out.dims()); + int x_ndim = x_shape.size(); int out_ndim = out_shape.size(); auto out_dist_attr_src = out.dist_attr(); std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); @@ -214,14 +228,39 @@ SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x, "dims_mapping size [%d] are not matched.", out_ndim, out_dims_mapping.size())); + VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape) + << "], X shape: [" << str_join(x_shape) << "]"; // Step1: Build the transformation from the output shape // to original shape. This function infers the dims mapping // from output to input, we first get the transformation // from output to input so that we can infer the dims mapping // with the map from output axes to input axes. - // Shapes in InferSpmdReverse don't contain -1 or 0, so they will - // not be modified and we can directly use them. + + // handle the case of dynamic shape, like [-1, -1, ...] --> [0, 0, ...]. + // This is used in inference but reshape allows only one '-1' in the + // target shape, so set the shape to a special value '256' + for (int i = 0; i < x_ndim; i++) { + if (x_shape[i] == -1) { + x_shape[i] = 256; + } + } + + // handle the '0' values in target shape, '0' indicates + // that the target shape is equal to the source shape + std::vector tgt_shape(shape); + for (int64_t i = 0; i < out_ndim; i++) { + if (shape[i] == 0) { + out_shape[i] = x_shape[i]; + } + } + + // The out_shape may contain '-1', which will cause error + // when inferring the transformation from out_shape to + // x_shape, so infer the '-1' value before inferrng DimTrans + int64_t nelm = std::accumulate( + x_shape.begin(), x_shape.end(), 1, std::multiplies()); + out_shape = InferTargetShape(out_shape, nelm); std::vector trans = MakeReshapeDimTrans(out_shape, x_shape); // Step2: Infer the dims mapping of input with @@ -236,8 +275,6 @@ SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x, TensorDistAttr x_dist_attr(x.dist_attr()); x_dist_attr.set_dims_mapping(dims_mapping_vec[1]); - VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape) - << "] X shape: [" << str_join(x_shape) << "]"; VLOG(4) << "Transformation from output to input:"; for (int64_t i = 0, n = trans.size(); i < n; i++) { DimTrans* t = trans[i]; diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 10ad71f520cf1d..b3f0a87fefba09 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/reduction.h" #include "paddle/phi/infermeta/spmd_rules/replicated.h" #include "paddle/phi/infermeta/spmd_rules/reshape.h" +#include "paddle/phi/infermeta/spmd_rules/slice.h" #include "paddle/phi/infermeta/spmd_rules/softmax.h" #include "paddle/phi/infermeta/spmd_rules/split.h" #include "paddle/phi/infermeta/spmd_rules/squeeze.h" @@ -522,6 +523,11 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd), PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse)); +// slice rule +PD_REGISTER_SPMD_RULE(slice, + PD_INFER_SPMD(phi::distributed::SliceInferSpmd), + PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse)); + // transpose rule PD_REGISTER_SPMD_RULE( transpose, diff --git a/paddle/phi/infermeta/spmd_rules/slice.cc b/paddle/phi/infermeta/spmd_rules/slice.cc new file mode 100644 index 00000000000000..d73fdfe8629efa --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/slice.cc @@ -0,0 +1,176 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/slice.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo SliceInferSpmd(const DistMetaTensor& input, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis) { + auto input_shape = phi::vectorize(input.dims()); + int input_ndim = input_shape.size(); + auto input_dist_attr_src = input.dist_attr(); + std::vector input_dims_mapping = input_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + input_ndim, + input_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor Input's rank [%d] and Input's " + "dims_mapping size [%d] are not matched.", + input_ndim, + input_dims_mapping.size())); + + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + std::string input_axes = alphabet.substr(0, input_ndim); + std::string special_axes = alphabet.substr(input_ndim); + + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + input_axes[axis] = special_axes[i]; + } + + std::string out_axes(input_axes); + + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + out_axes[axis] = '1'; + } + + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{input_axes, input_dims_mapping}}); + + std::vector out_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map); + + TensorDistAttr out_dist_attr = + CopyTensorDistAttrForOutput(input_dist_attr_src); + out_dist_attr.set_dims_mapping(out_dims_mapping); + + TensorDistAttr input_dist_attr_dst(input_dist_attr_src); + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + input_dims_mapping[axis] = -1; + } + input_dist_attr_dst.set_dims_mapping(input_dims_mapping); + + VLOG(4) << "SliceInferSpmd:"; + VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes; + VLOG(4) << "Input shape: [" << str_join(input_shape) << "] " + << "src_dims_mapping: [" + << str_join(input_dist_attr_src.dims_mapping()) << "] " + << "dst_dims_mapping: [" << str_join(input_dims_mapping) << "]"; + VLOG(4) << "Output" + << " dims_mapping: [" << str_join(out_dims_mapping) << "]"; + VLOG(4) << std::endl; + + return {{input_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input, + const DistMetaTensor& output, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis) { + auto output_shape = phi::vectorize(output.dims()); + int out_ndim = output_shape.size(); + auto out_dist_attr = output.dist_attr(); + int out_dims_mapping_size = out_dist_attr.dims_mapping().size(); + auto input_shape = phi::vectorize(input.dims()); + int input_ndim = input_shape.size(); + auto input_dist_attr = input.dist_attr(); + std::vector input_dims_mapping = input_dist_attr.dims_mapping(); + + PADDLE_ENFORCE_EQ( + input_ndim, + out_ndim, + phi::errors::InvalidArgument("The Tensor Input's rank [%d] is not equal " + "to the Tensor Output's rank [%d]", + input_ndim, + out_ndim)); + + PADDLE_ENFORCE_EQ( + out_ndim, + out_dims_mapping_size, + phi::errors::InvalidArgument("The Tensor Output's rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + out_ndim, + out_dims_mapping_size)); + + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + std::string input_axes = alphabet.substr(0, input_ndim); + std::string special_axes = alphabet.substr(input_ndim); + + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + input_axes[axis] = special_axes[i]; + } + + std::string out_axes(input_axes); + + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + out_axes[axis] = special_axes[i]; + } + + std::vector>> axes_sharding_info; + std::vector out_dims_mapping = output.dist_attr().dims_mapping(); + axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping)); + + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors(axes_sharding_info); + + input_dims_mapping = GetDimsMappingForAxes(input_axes, axis_to_dim_map, true); + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + input_dims_mapping[axis] = -1; + } + input_dist_attr.set_dims_mapping(input_dims_mapping); + out_dims_mapping = GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); + for (int i = 0; i < static_cast(axes.size()); i++) { + int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + out_dims_mapping[axis] = -1; + } + out_dist_attr.set_dims_mapping(out_dims_mapping); + + VLOG(4) << "SliceInferSpmdReverse:"; + VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes; + VLOG(4) << "Output" + << " shape: [" << str_join(phi::vectorize(output.dims())) << "] " + << "src_dims_mapping: [" + << str_join(output.dist_attr().dims_mapping()) << "] " + << "dst_dims_mapping: [" << str_join(out_dist_attr.dims_mapping()) + << "]"; + VLOG(4) << "Input shape: [" << str_join(input_shape) << "] " + << "dims_mapping: [" << str_join(input_dims_mapping) << "]\n\n"; + + return {{input_dist_attr}, {out_dist_attr}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/slice.h b/paddle/phi/infermeta/spmd_rules/slice.h new file mode 100644 index 00000000000000..5a49ad9e0c48d6 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/slice.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo SliceInferSpmd(const DistMetaTensor& input, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis); + +SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input, + const DistMetaTensor& output, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index 31bfba2a0d4338..dc6141f3ec0ce2 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -135,7 +135,7 @@ TensorDistAttr CopyTensorDistAttrForOutput( TensorDistAttr new_dist_attr = TensorDistAttr(); new_dist_attr.set_process_mesh(src_dist_attr.process_mesh()); new_dist_attr.set_batch_dim(src_dist_attr.batch_dim()); - new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims()); + // new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims()); // new_dist_attr.set_annotated(false); TODO unset field is false by default. new_dist_attr.clean_partial_status(); // in partial-stage I, partial is allow // to propagate diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 243f0b232395e4..8873a617ef303f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -672,6 +672,15 @@ void DecodeJpegInferMeta(const MetaTensor& x, } } +void DeQuantizeXPUInferMeta(const MetaTensor& x, + DataType out_dtype, + float scale, + MetaTensor* y) { + auto x_dims = x.dims(); + y->set_dims(x_dims); + y->set_dtype(out_dtype); +} + void DiagEmbedInferMeta( const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out) { auto x_dims = x.dims(); @@ -3263,6 +3272,7 @@ void ReduceInferMeta(const MetaTensor& x, if (axis.empty()) { reduce_all = true; } + ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out); } @@ -3768,6 +3778,15 @@ void FillSplitOutDims(const MetaTensor& x, } } +void QuantizeXPUInferMeta(const MetaTensor& x, + DataType out_dtype, + float scale, + MetaTensor* y) { + auto x_dims = x.dims(); + y->set_dims(x_dims); + y->set_dtype(out_dtype); +} + void SplitInferMeta(const MetaTensor& x, const IntArray& sections, const Scalar& axis, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index d79b53a71097e4..8a28d454e42f79 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -145,6 +145,11 @@ void DecodeJpegInferMeta(const MetaTensor& x, const std::string& mode, MetaTensor* out); +void DeQuantizeXPUInferMeta(const MetaTensor& x, + DataType out_dtype, + float scale, + MetaTensor* y); + void DiagEmbedInferMeta( const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out); @@ -453,6 +458,11 @@ void QrInferMeta(const MetaTensor& x, MetaTensor* q, MetaTensor* r); +void QuantizeXPUInferMeta(const MetaTensor& x, + DataType out_dtype, + float scale, + MetaTensor* y); + void WeightQuantizeInferMeta(const MetaTensor& x, const std::string& algo, MetaTensor* out, diff --git a/paddle/phi/kernels/batch_norm_grad_kernel.h b/paddle/phi/kernels/batch_norm_grad_kernel.h index ec4753604283fe..fc3d2f3d9886ac 100644 --- a/paddle/phi/kernels/batch_norm_grad_kernel.h +++ b/paddle/phi/kernels/batch_norm_grad_kernel.h @@ -23,8 +23,8 @@ namespace phi { template void BatchNormGradFunctor(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -45,8 +45,8 @@ void BatchNormGradFunctor(const Context& dev_ctx, template void BatchNormGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -67,7 +67,7 @@ template void BatchNormDoubleGradKernel( const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, + const paddle::optional& scale, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, diff --git a/paddle/phi/kernels/batch_norm_kernel.h b/paddle/phi/kernels/batch_norm_kernel.h index edae79941f535e..b81f9b03700960 100644 --- a/paddle/phi/kernels/batch_norm_kernel.h +++ b/paddle/phi/kernels/batch_norm_kernel.h @@ -25,8 +25,8 @@ void BatchNormKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& mean, const DenseTensor& variance, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, bool is_test, float momentum, float epsilon, @@ -57,8 +57,8 @@ void BatchNormInferKernel(const Context& dev_ctx, template void phi::BatchNormGradFunctor( \ const ::phi::backend##Context& dev_ctx, \ const DenseTensor& x, \ - const DenseTensor& scale, \ - const DenseTensor& bias, \ + const paddle::optional& scale, \ + const paddle::optional& bias, \ const paddle::optional& mean, \ const paddle::optional& variance, \ const DenseTensor& saved_mean, \ diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index 32d06c354a1c20..7dc8f39da05132 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -38,8 +38,8 @@ using ConstEigenVectorArrayMap = template void BatchNormGradFunctor(const Context& ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -139,8 +139,6 @@ void BatchNormGradFunctor(const Context& ctx, inv_var_data = saved_variance.data(); } - ConstEigenVectorArrayMap scale_arr(scale.data(), C); - ConstEigenVectorArrayMap bias_arr(bias.data(), C); ConstEigenVectorArrayMap mean_arr(mean_data, C); ConstEigenVectorArrayMap inv_var_arr(inv_var_data, C); @@ -167,6 +165,20 @@ void BatchNormGradFunctor(const Context& ctx, phi::Copy(ctx, *d_y, ctx.GetPlace(), false, d_x); return; } + auto* Scale = scale.get_ptr(); + auto* Bias = bias.get_ptr(); + Eigen::Array scale_arr(C); + Eigen::Array bias_arr(C); + if (Scale) { + scale_arr = ConstEigenVectorArrayMap(Scale->data(), C); + } else { + scale_arr.setOnes(); + } + if (Bias) { + bias_arr = ConstEigenVectorArrayMap(Bias->data(), C); + } else { + bias_arr.setZero(); + } int scale_coefff = use_global_stats ? 1 : N * sample_size; const auto scale_inv_var_nhw = scale_arr * inv_var_arr / scale_coefff; @@ -295,8 +307,8 @@ void BatchNormGradFunctor(const Context& ctx, template void BatchNormGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -338,7 +350,7 @@ template void BatchNormDoubleGradKernel( const Context& ctx, const DenseTensor& x, - const DenseTensor& scale, + const paddle::optional& scale, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -357,7 +369,7 @@ void BatchNormDoubleGradKernel( DenseTensor* scale_grad, DenseTensor* y_grad_grad) { const auto* X = &x; - const auto* Scale = &scale; + const auto* Scale = scale.get_ptr(); const auto* dY = &y_grad; const auto* Saved_mean = &saved_mean; const auto* Saved_variance = &saved_variance; diff --git a/paddle/phi/kernels/cpu/batch_norm_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_kernel.cc index 4db0e2f3f53781..e6acb16a89185a 100644 --- a/paddle/phi/kernels/cpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_kernel.cc @@ -37,8 +37,8 @@ void BatchNormKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& mean, const DenseTensor& variance, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, bool is_test, float momentum, float epsilon, @@ -167,11 +167,27 @@ void BatchNormKernel(const Context& ctx, // ((x - est_mean) * (inv_var) * scale + bias // formula transform ====> // (x * inv_var * scale) + (bias - est_mean * inv_var * scale) - ConstEigenVectorArrayMap scale_arr(scale.data(), C); - ConstEigenVectorArrayMap bias_arr(bias.data(), C); - Eigen::Array new_scale = inv_std * scale_arr; - Eigen::Array new_bias = - bias_arr - mean_arr * inv_std * scale_arr; + auto* Scale = scale.get_ptr(); + auto* Bias = bias.get_ptr(); + Eigen::Array new_scale(C); + Eigen::Array new_bias(C); + if (Scale && Bias) { + ConstEigenVectorArrayMap scale_arr(Scale->data(), C); + ConstEigenVectorArrayMap bias_arr(Bias->data(), C); + new_scale = inv_std * scale_arr; + new_bias = bias_arr - mean_arr * inv_std * scale_arr; + } else if (Scale) { + ConstEigenVectorArrayMap scale_arr(Scale->data(), C); + new_scale = inv_std * scale_arr; + new_bias = -(mean_arr * inv_std * scale_arr); + } else if (Bias) { + ConstEigenVectorArrayMap bias_arr(Bias->data(), C); + new_scale = inv_std; + new_bias = bias_arr - mean_arr * inv_std; + } else { + new_scale = inv_std; + new_bias = -(mean_arr * inv_std); + } switch (data_layout) { case DataLayout::kNCHW: { diff --git a/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc index a0e2611f92cfcf..1a9a737866153e 100644 --- a/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_divide_grad_kernel.cc @@ -45,6 +45,9 @@ PD_REGISTER_KERNEL(divide_grad, phi::DivideGradKernel, float, double, + int8_t, + uint8_t, + int16_t, int, int64_t, phi::dtype::complex, diff --git a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc index 20aae406136a2b..a5fc4552bfbf28 100644 --- a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc @@ -59,6 +59,9 @@ PD_REGISTER_KERNEL(divide, phi::DivideKernel, float, double, + int8_t, + uint8_t, + int16_t, int, int64_t, complex64, diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc index 5e37d3dfa262ff..bb2533490cfc29 100644 --- a/paddle/phi/kernels/cpu/full_kernel.cc +++ b/paddle/phi/kernels/cpu/full_kernel.cc @@ -88,13 +88,14 @@ void FullLikeKernel(const Context& dev_ctx, template void FullIntArrayKernel(const Context& dev_ctx, - const IntArray& val, + const std::vector& shape, DataType dtype UNUSED, DenseTensor* out) { - out->Resize(phi::make_ddim({static_cast(val.GetData().size())})); + out->Resize(phi::make_ddim({static_cast(shape.size())})); T* out_data = dev_ctx.template Alloc(out); - for (size_t i = 0; i < val.GetData().size(); ++i) { - out_data[i] = static_cast(val.GetData()[i]); + for (size_t i = 0; i < shape.size(); ++i) { + int64_t val = shape[i]; + out_data[i] = static_cast(val); } } diff --git a/paddle/phi/kernels/elementwise_divide_kernel.h b/paddle/phi/kernels/elementwise_divide_kernel.h index c5c9993826b541..8a78435950c0fb 100644 --- a/paddle/phi/kernels/elementwise_divide_kernel.h +++ b/paddle/phi/kernels/elementwise_divide_kernel.h @@ -25,14 +25,24 @@ void DivideKernel(const Context& dev_ctx, const DenseTensor& y, DenseTensor* out); +template +void Divide(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* dense_out) { + MetaTensor meta_out(dense_out); + ElementwiseInferMeta(x, y, &meta_out); + if (x.initialized()) { + DivideKernel(dev_ctx, x, y, dense_out); + } +} + template DenseTensor Divide(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y) { DenseTensor dense_out; - MetaTensor meta_out(&dense_out); - ElementwiseInferMeta(x, y, &meta_out); - DivideKernel(dev_ctx, x, y, &dense_out); + Divide(dev_ctx, x, y, &dense_out); return dense_out; } diff --git a/paddle/phi/kernels/frobenius_norm_grad_kernel.h b/paddle/phi/kernels/frobenius_norm_grad_kernel.h index 65db8dd9e0a108..78494c4423f7e5 100644 --- a/paddle/phi/kernels/frobenius_norm_grad_kernel.h +++ b/paddle/phi/kernels/frobenius_norm_grad_kernel.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -25,7 +26,7 @@ void FrobeniusNormGradKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DenseTensor* dx); diff --git a/paddle/phi/kernels/frobenius_norm_kernel.h b/paddle/phi/kernels/frobenius_norm_kernel.h index 30122cb416094d..45ddb6123b85da 100644 --- a/paddle/phi/kernels/frobenius_norm_kernel.h +++ b/paddle/phi/kernels/frobenius_norm_kernel.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -23,7 +24,7 @@ namespace phi { template void FrobeniusNormKernel(const Context& ctx, const DenseTensor& x, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DenseTensor* out); diff --git a/paddle/phi/kernels/full_kernel.h b/paddle/phi/kernels/full_kernel.h index cef58433e9e04f..b10e02658fe754 100644 --- a/paddle/phi/kernels/full_kernel.h +++ b/paddle/phi/kernels/full_kernel.h @@ -92,7 +92,7 @@ DenseTensor FullLike(const Context& dev_ctx, template void FullIntArrayKernel(const Context& dev_ctx, - const IntArray& val, + const std::vector& shape, DataType dtype, DenseTensor* out); diff --git a/paddle/phi/kernels/fusion/cpu/fusion_gru_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_gru_kernel.cc new file mode 100644 index 00000000000000..3b140091fc69c4 --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/fusion_gru_kernel.cc @@ -0,0 +1,439 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include // for memcpy +#include +#include + +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/fc_functor.h" +#include "paddle/phi/kernels/funcs/jit/kernels.h" +#include "paddle/phi/kernels/funcs/sequence2batch.h" + +namespace phi { +namespace fusion { + +#define INIT_BASE_DEFINES \ + auto x_lod = x.lod(); \ + auto x_dims = x.dims(); /* T x M*/ \ + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) \ + ? phi::flatten_to_2d(x_dims, 1) \ + : x_dims; \ + auto wh_dims = weight_h.dims(); /* D x 3D*/ \ + const int total_T = x_mat_dims[0]; \ + const int D3 = wh_dims[1] + +#define INIT_OTHER_DEFINES \ + const int M = x_mat_dims[1]; \ + const int D = wh_dims[0]; \ + const int D2 = D * 2; \ + const phi::jit::gru_attr_t attr(D, \ + phi::jit::to_kerneltype(gate_activation), \ + phi::jit::to_kerneltype(activation)); \ + phi::jit::gru_t one_step; \ + auto ComputeH1 = \ + phi::jit::KernelFuncs, phi::CPUPlace>::Cache() \ + .At(attr); \ + auto ComputeHtPart1 = phi::jit::KernelFuncs, \ + phi::CPUPlace>::Cache() \ + .At(attr); \ + auto ComputeHtPart2 = phi::jit::KernelFuncs, \ + phi::CPUPlace>::Cache() \ + .At(attr); \ + const T* x_data = x.data(); \ + const T* wx_data = weight_x.data(); \ + const T* wh_data = weight_h.data(); \ + T* xx_data = dev_ctx.template Alloc(xx) + +template +void SeqCompute(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + INIT_BASE_DEFINES; + INIT_OTHER_DEFINES; + const int N = static_cast(x_lod[0].size() - 1); + const T* h0_data = h0 ? h0->data() : nullptr; + const T* wh_state_data = wh_data + D * D2; + T* hidden_out_data = dev_ctx.template Alloc(hidden); + + auto blas = phi::funcs::GetBlas(dev_ctx); + + phi::funcs::FCFunctor fc; + fc(dev_ctx, + total_T, + D3, + M, + x_data, + wx_data, + xx_data, + bias ? bias->data() : nullptr); + + int xx_offset = D3; + int gate_offset = D; + if (is_reverse) { + const int offset = (total_T - 1) * D; + xx_data = xx_data + offset * 3; + hidden_out_data = hidden_out_data + offset; + xx_offset = -D3; + gate_offset = -D; + } + auto move_step = [&]() { + xx_data = xx_data + xx_offset; + hidden_out_data = hidden_out_data + gate_offset; + }; + for (int i = 0; i < N; ++i) { + int bid = is_reverse ? N - 1 - i : i; + int seq_len = static_cast(x_lod[0][bid + 1] - x_lod[0][bid]); + const T* prev_hidden_data = nullptr; + int tstart = 0; + if (h0_data) { + prev_hidden_data = h0_data + bid * D; + } else { + one_step.gates = xx_data; + one_step.ht = hidden_out_data; + ComputeH1(&one_step, &attr); + prev_hidden_data = hidden_out_data; + tstart = 1; + move_step(); + } + for (int step = tstart; step < seq_len; ++step) { + // gemm prev * (Wu + Wr) + blas.GEMM(CblasNoTrans, + CblasNoTrans, + 1, + D2, + D, + static_cast(1), + prev_hidden_data, + D, + wh_data, + D2, + static_cast(1), + xx_data, + D3); + one_step.gates = xx_data; + one_step.ht_1 = prev_hidden_data; + one_step.ht = hidden_out_data; + ComputeHtPart1(&one_step, &attr); + // gemm rt * Ws + blas.GEMM(CblasNoTrans, + CblasNoTrans, + 1, + D, + D, + static_cast(1), + hidden_out_data, + D, + wh_state_data, + D, + static_cast(1), + xx_data + D2, + D3); + ComputeHtPart2(&one_step, &attr); + // save prev + prev_hidden_data = hidden_out_data; + move_step(); + } + } +} + +template +void BatchCompute(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + INIT_BASE_DEFINES; + if (x_lod[0].size() == 2) { + xx->Resize({total_T, D3}); + SeqCompute(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + return; + } + INIT_OTHER_DEFINES; + T* batched_input_data = dev_ctx.template Alloc(batched_input); + T* batched_out_data = dev_ctx.template Alloc(batched_out); + dev_ctx.template Alloc(hidden); + auto blas = phi::funcs::GetBlas(dev_ctx); + phi::funcs::LoDTensor2BatchFunctor to_batch; + + phi::funcs::FCFunctor fc; + if (M > D3) { + fc(dev_ctx, + total_T, + D3, + M, + x_data, + wx_data, + xx_data, + bias ? bias->data() : nullptr); + to_batch(dev_ctx, *xx, batched_input, true, is_reverse); + } else { + to_batch(dev_ctx, x, xx, true, is_reverse); + batched_input->set_lod(xx->lod()); + fc(dev_ctx, + total_T, + D3, + M, + xx_data, + wx_data, + batched_input_data, + bias ? bias->data() : nullptr); + } + + auto batched_lod = batched_input->lod(); + const auto& seq_order = batched_lod[2]; + const int max_bs = static_cast(seq_order.size()); + reordered_h0->Resize({max_bs, D}); + + int tstart = 0; + T* prev_hidden_data = nullptr; + if (h0) { + // reorder h0 + T* reordered_h0_data = dev_ctx.template Alloc(reordered_h0); + const T* h0_data = h0->data(); + prev_hidden_data = reordered_h0_data; + size_t sz = sizeof(T) * D; + for (int i = 0; i < max_bs; ++i) { + std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); + reordered_h0_data += D; + } + } else { + // compute without h0 + T* cur_in_data = batched_input_data; + T* cur_out_data = batched_out_data; + // W: {W_update, W_reset; W_state} + for (int i = 0; i < max_bs; ++i) { + one_step.gates = cur_in_data; + one_step.ht = cur_out_data; + ComputeH1(&one_step, &attr); + // add offset + cur_in_data += D3; + cur_out_data += D; + } + tstart = 1; + prev_hidden_data = batched_out_data; + } + // Then start from next + const T* wh_state_data = wh_data + D * D2; + const auto& batch_starts = batched_lod[0]; + const int max_seq_len = static_cast(batch_starts.size() - 1); + batched_input_data = batched_input_data + tstart * max_bs * D3; + batched_out_data = batched_out_data + tstart * max_bs * D; + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = + static_cast(batch_starts[step + 1] - batch_starts[step]); + // gemm prev * (Wu + Wr) + blas.GEMM(CblasNoTrans, + CblasNoTrans, + cur_bs, + D2, + D, + static_cast(1), + prev_hidden_data, + D, + wh_data, + D2, + static_cast(1), + batched_input_data, + D3); + + T* cur_batched_data = batched_input_data; + T* cur_out_data = batched_out_data; + T* cur_prev_hidden_data = prev_hidden_data; + for (int i = 0; i < cur_bs; ++i) { + one_step.gates = cur_batched_data; + one_step.ht_1 = cur_prev_hidden_data; + one_step.ht = cur_out_data; + ComputeHtPart1(&one_step, &attr); + + cur_batched_data += D3; + cur_prev_hidden_data += D; + cur_out_data += D; + } + + cur_batched_data = batched_input_data; + cur_out_data = batched_out_data; + blas.GEMM(CblasNoTrans, + CblasNoTrans, + cur_bs, + D, + D, + static_cast(1), + cur_out_data, + D, + wh_state_data, + D, + static_cast(1), + cur_batched_data + D2, + D3); + + cur_prev_hidden_data = prev_hidden_data; + for (int i = 0; i < cur_bs; ++i) { + one_step.gates = cur_batched_data; + one_step.ht_1 = cur_prev_hidden_data; + one_step.ht = cur_out_data; + ComputeHtPart2(&one_step, &attr); + cur_batched_data += D3; + cur_prev_hidden_data += D; + cur_out_data += D; + } + prev_hidden_data = batched_out_data; + batched_out_data = cur_out_data; + batched_input_data = cur_batched_data; + } + + phi::funcs::Batch2LoDTensorFunctor to_seq; + batched_out->set_lod(batched_lod); + to_seq(dev_ctx, *batched_out, hidden); +} + +template +void FusionGRUKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + if (use_seq) { + SeqCompute(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + } else { + BatchCompute(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL( + fusion_gru, CPU, ALL_LAYOUT, phi::fusion::FusionGRUKernel, float, double) {} diff --git a/paddle/phi/kernels/fusion/cpu/fusion_seqconv_eltadd_relu_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_seqconv_eltadd_relu_kernel.cc new file mode 100644 index 00000000000000..fbe2ea8d12bc27 --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/fusion_seqconv_eltadd_relu_kernel.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include // for min, max +#include + +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/fc_functor.h" + +namespace phi { +namespace fusion { + +template +void FusionSeqConvEltAddReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& filter, + const DenseTensor& bias, + const int context_length, + const int context_start, + const int context_stride, + DenseTensor* out, + DenseTensor* col_mat) { + auto x_lod = x.lod(); + auto x_dims = phi::vectorize(x.dims()); + auto w_dims = phi::vectorize(filter.dims()); + PADDLE_ENFORCE_EQ( + bias.numel(), + w_dims[1], + phi::errors::InvalidArgument( + "bias size should be equal to weights feature size, but received " + "bias size is: %d, weights feature size is: %d.", + bias.numel(), + w_dims[1])); + PADDLE_ENFORCE_EQ( + x_lod.size(), + 1UL, + phi::errors::InvalidArgument( + "Only support one level sequence now, but received value is: %d.", + x_lod.size())); + + const T* x_data = x.data(); + const T* w_data = filter.data(); + const T* b_data = bias.data(); + T* y_data = dev_ctx.template Alloc(out); + T* col_data = dev_ctx.template Alloc(col_mat); + + int up_pad = std::max(0, -context_start); + int down_pad = std::max(0, context_start + context_length - 1); + // im2col + int src_mat_w = static_cast(x_dims[1]); + int src_mat_w_sz = src_mat_w * sizeof(T); + int col_mat_w = static_cast(w_dims[0]); + int col_mat_w_sz = col_mat_w * sizeof(T); + for (int i = 0; i < static_cast(x_lod[0].size()) - 1; ++i) { + int st = static_cast(x_lod[0][i]); + int ed = static_cast(x_lod[0][i + 1]); + const T* src_data = x_data + st * src_mat_w; + T* dst_data = col_data + st * col_mat_w; + int seq_len = ed - st; + if (seq_len > up_pad + down_pad) { + // zero all up_pad and fill data + std::memset(dst_data, 0, up_pad * col_mat_w_sz); + dst_data = dst_data + up_pad * src_mat_w; + int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz; + for (int j = 0; j < up_pad; ++j) { + // blas.VCOPY? + std::memcpy(dst_data, src_data, copy_size); + dst_data += (col_mat_w - src_mat_w); + copy_size += src_mat_w_sz; + } + // fill data + if (context_start > 0) { + src_data += context_start * src_mat_w; + } + for (int j = 0; j < seq_len - up_pad - down_pad; ++j) { + std::memcpy(dst_data, src_data, copy_size); + dst_data += col_mat_w; + src_data += src_mat_w; + } + // zero all down_pad and fill data + std::memset(dst_data, 0, down_pad * col_mat_w_sz); + copy_size -= src_mat_w_sz; + for (int j = 0; j < down_pad; ++j) { + if (copy_size < 0) { + copy_size = 0; + } + std::memcpy(dst_data, src_data, copy_size); + dst_data += col_mat_w; + src_data += src_mat_w; + copy_size -= src_mat_w_sz; + } + } else { + std::memset(dst_data, 0, seq_len * col_mat_w_sz); + dst_data = dst_data + up_pad * src_mat_w; + int zero_sz = up_pad * src_mat_w_sz; + int cur_src_sz = seq_len * src_mat_w_sz; + for (int j = 0; j < std::min(up_pad, seq_len); ++j) { + int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz); + std::memcpy(dst_data, src_data, copy_size); + dst_data += (col_mat_w - src_mat_w); + zero_sz -= src_mat_w_sz; + } + // from bottom + dst_data = col_data + ed * col_mat_w; + src_data = x_data + st * src_mat_w; + if (context_start > 0) { + src_data += context_start * src_mat_w; + } + zero_sz = down_pad * src_mat_w_sz; + for (int j = 1; j <= std::min(down_pad, seq_len); ++j) { + int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz); + if (copy_size < 0) { + copy_size = 0; + } + std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T), + src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w, + copy_size); + dst_data -= col_mat_w; + zero_sz -= src_mat_w_sz; + } + } + } + phi::funcs::FCFunctor fc; + fc(dev_ctx, + x_dims[0], + w_dims[1], + w_dims[0], + col_data, + w_data, + y_data, + b_data, + true); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fusion_seqconv_eltadd_relu, + CPU, + ALL_LAYOUT, + phi::fusion::FusionSeqConvEltAddReluKernel, + float, + double) {} diff --git a/paddle/phi/kernels/fusion/cpu/fusion_seqexpand_concat_fc_kernel.cc b/paddle/phi/kernels/fusion/cpu/fusion_seqexpand_concat_fc_kernel.cc new file mode 100644 index 00000000000000..d5eb7894455f1d --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/fusion_seqexpand_concat_fc_kernel.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/backends/cpu/cpu_info.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/cpu_vec.h" +#include "paddle/phi/kernels/funcs/fc_functor.h" + +namespace phi { +namespace fusion { +template +void FusionSeqExpandConcatFCKernel(const Context& dev_ctx, + const std::vector& x, + const DenseTensor& fc_weight, + const paddle::optional& fc_bias, + const std::string& fc_activation, + DenseTensor* out, + DenseTensor* fc_out) { + auto* ref_in = x[0]; + auto ref_lod = ref_in->lod(); + auto in1_lod = x[1]->lod(); + auto ref_dims = ref_in->dims(); // T x M0 + auto in1_dims = x[1]->dims(); // N x M1 + auto w_dims = fc_weight.dims(); + const int N = static_cast(ref_lod[0].size() - 1); + const int total_T = static_cast(ref_dims[0]); + const int M0 = static_cast(ref_dims[1]); + const int M1 = static_cast(in1_dims[1]); + const int D = static_cast(w_dims[1]); + + // some check and fcout should be reshape here + // since infershape can not get lod info + PADDLE_ENFORCE_EQ( + ref_lod.size(), + 1UL, + phi::errors::InvalidArgument( + "Only support input lod size is 1, but received value is: %d.", + ref_lod.size())); + PADDLE_ENFORCE_EQ( + in1_lod.size(), + 1UL, + phi::errors::InvalidArgument( + "Only support input lod size is 1, but received value is: %d.", + in1_lod.size())); + PADDLE_ENFORCE_EQ(static_cast(in1_lod[0].size() - 1), + N, + phi::errors::InvalidArgument( + "Batch size of all inputs should be equal to %d, but " + "received value is: %d.", + N, + static_cast(in1_lod[0].size() - 1))); + PADDLE_ENFORCE_EQ( + static_cast(in1_lod[0][N]), + N, + phi::errors::InvalidArgument("Seq_length of other inputs should " + "be %d, but received value is: %d.", + N, + static_cast(in1_lod[0][N]))); + PADDLE_ENFORCE_EQ( + in1_dims[0], + N, + phi::errors::InvalidArgument( + "input height should be batch size: %d, but received value is %d.", + N, + in1_dims[0])); + for (size_t i = 2; i < x.size(); ++i) { + PADDLE_ENFORCE_EQ(x[i]->dims()[0], + N, + phi::errors::InvalidArgument( + "All other inputs height should be equal to %d, " + "but received value is: %d.", + N, + x[i]->dims()[0])); + PADDLE_ENFORCE_EQ(x[i]->lod(), + in1_lod, + phi::errors::InvalidArgument( + "All other inputs should have same lod: %d, but " + "received value is: %d.", + in1_lod, + x[i]->lod())); + } + fc_out->Resize({N, D}); + + std::function fc_act; + if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) { + phi::funcs::VecActivations act_functor; + fc_act = act_functor(fc_activation); + } else { + phi::funcs::VecActivations act_functor; + fc_act = act_functor(fc_activation); + } + + const T* ref_in_data = ref_in->data(); + const T* in1_data = x[1]->data(); + const T* w_data = fc_weight.data(); + T* out_data = dev_ctx.template Alloc(out); + T* fc_out_data = dev_ctx.template Alloc(fc_out); + + auto blas = phi::funcs::GetBlas(dev_ctx); + + phi::funcs::FCFunctor fc; + fc(dev_ctx, + total_T, + D, + M0, + ref_in_data, + w_data, + out_data, + fc_bias ? fc_bias->data() : NULL); + w_data = w_data + M0 * D; + // first write on + blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data); + w_data = w_data + M1 * D; + for (size_t i = 2; i < x.size(); ++i) { + // add on + const T* in_data = x[i]->data(); + const int K = static_cast(x[i]->dims()[1]); + blas.GEMM(CblasNoTrans, + CblasNoTrans, + N, + D, + K, + static_cast(1), + in_data, + K, + w_data, + D, + static_cast(1), + fc_out_data, + D); + w_data = w_data + K * D; + } + T* cur_out_data = out_data; + for (int i = 0; i < N; ++i) { + int seq_len = static_cast(ref_lod[0][i + 1] - ref_lod[0][i]); + T* src = fc_out_data + i * D; + for (int step = 0; step < seq_len; ++step) { + blas.VADD(D, cur_out_data, src, cur_out_data); + cur_out_data = cur_out_data + D; + } + } + fc_act(total_T * D, out_data, out_data); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fusion_seqexpand_concat_fc, + CPU, + ALL_LAYOUT, + phi::fusion::FusionSeqExpandConcatFCKernel, + float, + double) {} diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu index 3b9618db02db05..894903fb0fab83 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu @@ -59,11 +59,7 @@ void FusedBatchNormAddActGradKernel(const Context &dev_ctx, DenseTensor *z_grad, DenseTensor *scale_grad, DenseTensor *bias_grad) { -#if CUDNN_VERSION < 7401 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_bn_add_activation operator is not supported on GPU " - "when CUDNN version < 7.4.1")); -#endif +#if defined(PADDLE_WITH_CUDA) and CUDNN_VERSION >= 7401 bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; PADDLE_ENFORCE_EQ(is_gpu_place, true, @@ -208,6 +204,11 @@ void FusedBatchNormAddActGradKernel(const Context &dev_ctx, phi::dynload::cudnnDestroyTensorDescriptor(data_desc_)); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "The fused_bn_add_activation operator is not supported on GPU " + "when CUDNN version < 7.4.1")); +#endif } } // namespace fusion diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu index 7b5b4119cf9705..52152476e4aca1 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_kernel.cu @@ -59,11 +59,7 @@ void FusedBatchNormAddActKernel(const Context &dev_ctx, DenseTensor *saved_mean, DenseTensor *saved_variance, DenseTensor *reserve_space) { -#if CUDNN_VERSION < 7401 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_bn_add_activation operator is not supported on GPU " - "when CUDNN version < 7.4.1")); -#endif +#if defined(PADDLE_WITH_CUDA) and CUDNN_VERSION >= 7401 bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; PADDLE_ENFORCE_EQ(is_gpu_place, true, @@ -210,6 +206,11 @@ void FusedBatchNormAddActKernel(const Context &dev_ctx, phi::dynload::cudnnDestroyTensorDescriptor(data_desc_)); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "The fused_bn_add_activation operator is not supported on GPU " + "when CUDNN version < 7.4.1")); +#endif } } // namespace fusion diff --git a/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu b/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu index 954fbd67b96abc..b71f814fd4c985 100644 --- a/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fusion_transpose_flatten_concat_kernel.cu @@ -37,6 +37,7 @@ void TransposeFlattenConcatFusionKernel( const int flatten_axis, const int concat_axis, DenseTensor* out) { +#if defined(PADDLE_WITH_CUDA) dev_ctx.template Alloc(out, out->numel() * sizeof(T)); auto odims = out->dims(); @@ -114,6 +115,10 @@ void TransposeFlattenConcatFusionKernel( phi::dynload::cudnnDestroyTensorDescriptor(in_desc)); PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnDestroyTensorDescriptor(out_desc)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "The fusion_transpose_flatten_concat operator is not supported on HIP.")); +#endif } } // namespace fusion diff --git a/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc b/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc new file mode 100644 index 00000000000000..e3fa939aad7537 --- /dev/null +++ b/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc @@ -0,0 +1,638 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/expect.h" +#include "paddle/phi/core/utils/data_type.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +using phi::OneDNNContext; +using phi::funcs::CreateKey; +using phi::funcs::OneDNNGetDataType; +using phi::funcs::OneDNNMemDesc; +using phi::funcs::RNNReorderType; +using OneDNNMemoryFormat = dnnl::memory::format_tag; + +template +class GRUOneDNNHandler + : public phi::funcs::OneDNNHandlerT { + public: + GRUOneDNNHandler(const OneDNNContext& dev_ctx, + const dnnl::engine onednn_engine, + phi::Place cpu_place UNUSED, + const phi::DenseTensor* input, + const phi::DenseTensor* weight_h, + const phi::DenseTensor* h0, + const bool is_reverse, + const float scale_data, + const float shift_data, + const std::string& gate_activation, + const std::string& activation, + const std::vector& scale_weights, + const int64_t N, + const int64_t Ti, + const int64_t IC, + const int64_t OC) + : phi::funcs::OneDNNHandlerT( + dev_ctx, + dev_ctx.GetEngine(), + cpu_place, + CreateKey(dev_ctx, + dev_ctx.GetInputsName("X")[0] + + dev_ctx.GetInputsName("WeightH")[0], + OneDNNGetDataType(), + Ti)), + N(N), + Ti(Ti), + IC(IC), + OC(OC), + G(3) { + std::string unique_name = + dev_ctx.GetInputsName("X")[0] + dev_ctx.GetInputsName("WeightH")[0]; + // Create memory key without Ti because weights, bias and h0 memories + // do not depend on Ti size but primitive and input/output memory do + memory_key_ = phi::funcs::ExtendKeyWithThreadInfoIfNeeded( + dev_ctx, CreateKey(dev_ctx, unique_name, OneDNNGetDataType())); + // Is it int8 kernel + const bool is_INT8 = std::is_same::value; + if (is_INT8) { + const int weights_scale_mask = + 0 + + (1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo` + + + (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo` + + attr_.set_rnn_data_qparams(scale_data, shift_data); + attr_.set_rnn_weights_qparams(weights_scale_mask, scale_weights); + } + + if (unlikely(!this->isCached())) { + // oneDNN kernel has hardcoded activation functions + PADDLE_ENFORCE_EQ( + gate_activation, + "sigmoid", + phi::errors::Unimplemented( + "oneDNN fusion_gru supports only sigmoid as a gate activation.")); + PADDLE_ENFORCE_EQ( + activation, + "tanh", + phi::errors::Unimplemented( + "oneDNN fusion_gru supports only tanh as an activation.")); + + // Weights for int8 kernel are of a type s8 + const auto weights_dt = + is_INT8 ? dnnl::memory::data_type::s8 : OneDNNGetDataType(); + + // oneDNN RNN dimensions + const int64_t D = 1; // Directions + const int64_t L = 1; // Layers (PP supports only 1 stacked layer) + const int64_t G = 3; // Number of Gates, 3 for GRU + + // Create memory descriptors + auto input_md = OneDNNMemDesc( + {Ti, N, IC}, OneDNNGetDataType(), OneDNNMemoryFormat::ntc); + auto weight_x_md = + OneDNNMemDesc({L, D, IC, G, OC}, weights_dt, OneDNNMemoryFormat::any); + auto weight_h_md = + OneDNNMemDesc({L, D, OC, G, OC}, weights_dt, OneDNNMemoryFormat::any); + auto bias_md = OneDNNMemDesc( + {L, D, G, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldgo); + auto hidden_md = OneDNNMemDesc( + {Ti, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ntc); + auto h0_md = OneDNNMemDesc( + {L, D, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldnc); + + // Create GRU oneDNN primitive + const auto direction = + is_reverse ? dnnl::rnn_direction::unidirectional_right2left + : dnnl::rnn_direction::unidirectional_left2right; + + this->AcquireForwardPrimitiveDescriptor( + this->attr_, + dnnl::prop_kind::forward_inference, + direction, + input_md, + h0_md, + weight_x_md, + weight_h_md, + bias_md, + hidden_md, + dnnl::memory::desc()); + } + } + + bool is_NTC() { return this->is_NTC(this->fwd_pd_->dst_desc()); } + + bool is_NTC(const dnnl::memory::desc& md) { + auto ntc_md = dnnl::memory::desc( + md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::ntc); + return md == ntc_md; + } + + void reorderRNNdata(void* input_data, + void* output_data, + std::vector lod, + const bool is_reverse, + RNNReorderType reorder_type) { + switch (reorder_type) { + // Reorder input memory [WORDS, C] + LoD -> [N, T, C] + case RNNReorderType::PP_NTC: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]) * IC; + const auto offset = is_reverse ? (Ti * IC - num_elements) : 0; + memcpy(output_data_iter + n * Ti * IC + offset, + input_data_iter, + sizeof(T) * num_elements); + input_data_iter += num_elements; + } + } break; + // Reorder input memory [WORDS, C] + LoD -> [T, N, C] + case RNNReorderType::PP_TNC: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]); + const auto offset = is_reverse ? (Ti - num_elements) : 0; + for (size_t t = 0; t < num_elements; ++t) { + memcpy(output_data_iter + (t + offset) * N * IC + n * IC, + input_data_iter, + sizeof(T) * IC); + input_data_iter += IC; + } + } + } break; + // Reorder output values to PP format [N, T, C] -> [WORDS, C] + case RNNReorderType::NTC_PP: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]) * OC; + const auto offset = is_reverse ? (Ti * OC - num_elements) : 0; + memcpy(output_data_iter, + input_data_iter + n * Ti * OC + offset, + sizeof(T_out) * num_elements); + output_data_iter += num_elements; + } + } break; + // Reorder output values to PP format [T, N, C] -> [WORDS, C] + case RNNReorderType::TNC_PP: { + auto* input_data_iter = reinterpret_cast(input_data); + auto* output_data_iter = reinterpret_cast(output_data); + for (int n = 0; n < N; ++n) { + const auto num_elements = lod[n + 1] - lod[n]; + const auto offset = is_reverse ? (Ti - num_elements) : 0; + for (size_t t = 0; t < num_elements; ++t) { + memcpy(output_data_iter, + input_data_iter + (t + offset) * N * OC + n * OC, + sizeof(T_out) * OC); + output_data_iter += OC; + } + } + } break; + } + } + + std::shared_ptr AcquireInputMemoryWithReorder( + const phi::DenseTensor* input, const bool is_reverse) { + const auto name = this->key_ + "@input_mem"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->src_desc(), + this->engine_); + this->dev_ctx_.SetBlob(name, memory_p); + } + + const auto& input_lod = input->lod()[0]; + auto* x_data = phi::funcs::to_void_cast(input->data()); + + auto* x_onednn_data = memory_p->get_data_handle(); + memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC); + + if (is_NTC(this->fwd_pd_->src_desc())) { + reorderRNNdata( + x_data, x_onednn_data, input_lod, is_reverse, RNNReorderType::PP_NTC); + } else { + reorderRNNdata( + x_data, x_onednn_data, input_lod, is_reverse, RNNReorderType::PP_TNC); + } + return memory_p; + } + + std::shared_ptr AcquireOutputMemory() { + const auto name = this->key_ + "@output_mem"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->dst_desc(), + this->engine_); + this->dev_ctx_.SetBlob(name, memory_p); + } + return memory_p; + } + + // H0 is for now persistable + template + std::shared_ptr AcquireH0Memory(const phi::DenseTensor* h0) { + const std::string h0_key = memory_key_ + "@h0"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(h0_key)); + + if (!memory_p) { + auto user_h0_memory = dnnl::memory(); + if (h0) { + user_h0_memory = dnnl::memory( + {{1, 1, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldnc}, + this->engine_, + phi::funcs::to_void_cast(h0->data())); + } else { + user_h0_memory = dnnl::memory( + {{1, 1, N, OC}, OneDNNGetDataType(), OneDNNMemoryFormat::ldnc}, + this->engine_); + memset(user_h0_memory.get_data_handle(), 0, sizeof(U) * N * OC); + } + memory_p = std::make_shared(this->fwd_pd_->src_iter_desc(), + this->engine_); + + auto& astream = phi::OneDNNContext::tls().get_stream(); + dnnl::reorder(user_h0_memory, *memory_p, attr_) + .execute(astream, user_h0_memory, *memory_p); + + this->dev_ctx_.SetBlob(h0_key, memory_p); + } + return memory_p; + } + + template + std::shared_ptr AcquireWeightXMemory( + const phi::DenseTensor* weight_x, const bool origin_mode) { + const std::string wx_key = this->memory_key_ + "@weight_x"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); + + if (!memory_p) { + auto user_md = OneDNNMemDesc({1, 1, this->IC, this->G, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + auto* weight_x_data = reinterpret_cast(user_memory.get_data_handle()); + memcpy(weight_x_data, + weight_x->data(), + sizeof(U) * this->IC * this->G * this->OC); + + if (origin_mode == false) { + for (int64_t i = 0; i < this->IC; ++i) { + for (int64_t j = 0; j < this->OC; ++j) { + U minus_one(-1.0f); + weight_x_data[j] = minus_one * weight_x_data[j]; + } + weight_x_data += 3 * this->OC; + } + } + + memory_p = std::make_shared( + this->fwd_pd_->weights_layer_desc(), this->engine_); + + auto& astream = OneDNNContext::tls().get_stream(); + dnnl::reorder(user_memory, *memory_p, this->attr_) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wx_key, memory_p); + } + return memory_p; + } + + template + std::shared_ptr AcquireWeightHMemory( + const phi::DenseTensor* weight_h, const bool origin_mode) { + const std::string wh_key = this->memory_key_ + "@weight_h"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); + + if (!memory_p) { + auto user_md = OneDNNMemDesc({1, 1, this->OC, this->G, this->OC}, + OneDNNGetDataType(), + OneDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + // Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to + // oneDNN format [OC, 3OC] + auto* weight_h_data = reinterpret_cast(user_memory.get_data_handle()); + auto* user_weight_h_data = weight_h->data(); + + auto src1_iter = user_weight_h_data; + auto src2_iter = user_weight_h_data + 2 * this->OC * this->OC; + + for (int64_t c = 0; c < this->OC; ++c) { + memcpy(weight_h_data, src1_iter, 2 * this->OC * sizeof(U)); + memcpy(weight_h_data + 2 * this->OC, src2_iter, this->OC * sizeof(U)); + + src1_iter += 2 * this->OC; + src2_iter += this->OC; + weight_h_data += 3 * this->OC; + } + + weight_h_data = reinterpret_cast(user_memory.get_data_handle()); + + if (origin_mode == false) { + for (int64_t i = 0; i < this->OC; ++i) { + for (int64_t j = 0; j < this->OC; ++j) { + U minus_one(-1.0f); + weight_h_data[j] = minus_one * weight_h_data[j]; + } + weight_h_data += 3 * this->OC; + } + } + + memory_p = std::make_shared( + this->fwd_pd_->weights_iter_desc(), this->engine_); + + auto& astream = OneDNNContext::tls().get_stream(); + dnnl::reorder(user_memory, *memory_p, this->attr_) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wh_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireBiasMemory(const phi::DenseTensor* bias, + const bool origin_mode) { + const std::string bias_key = this->memory_key_ + "@bias"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(bias_key)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->bias_desc(), + this->engine_); + auto* bias_data = reinterpret_cast(memory_p->get_data_handle()); + if (bias) { + const float* user_bias_data = + bias->data(); // Bias in oneDNN is always float + memcpy(bias_data, user_bias_data, sizeof(float) * this->G * this->OC); + } else { + // oneDNN always need bias memory, if it's not provided in PP, let + // oneDNN allocate memory and set it to 0 + memset(bias_data, 0, sizeof(float) * this->G * this->OC); + } + + if (origin_mode == false && bias) { + for (int64_t i = 0; i < this->OC; ++i) { + bias_data[i] *= -1; + } + } + this->dev_ctx_.SetBlob(bias_key, memory_p); + } + return memory_p; + } + + protected: + // RNN dimensions + // N - Batch Size + // Ti - Max sentence length + // IC - Input Channels + // OC - Output Channels + // G - Number of gates + const int64_t N, Ti, IC, OC, G; + + // Memory size of weights, bias and h0 does not depend + // on Ti size, thus we need another key to cache them + std::string memory_key_; + dnnl::primitive_attr attr_; +}; + +template +void RunKernel(const phi::OneDNNContext& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto x_dims = x.dims(); + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) + ? phi::flatten_to_2d(x_dims, 1) + : x_dims; + + // Get tensor dimensions + const auto x_mat_dims_vec = phi::vectorize(x_mat_dims); + const auto weight_h_dims = phi::vectorize(weight_h.dims()); + const auto& input_lod = x.lod()[0]; + + // Calculate RNN dimensions + const int64_t N = input_lod.size() - 1; // Number of sentences (batches) + const int64_t Ti = // Max length of the sentence in a batch + [&input_lod]() { + size_t res = 0; + for (size_t i = 0; i < (input_lod.size() - 1); ++i) { + res = std::max(res, input_lod[i + 1] - input_lod[i]); + } + return res; + }(); + const int64_t IC = x_mat_dims_vec[1]; // Input channels + const int64_t OC = weight_h_dims[0]; // Output channels + + GRUOneDNNHandler handler(dev_ctx, + onednn_engine, + dev_ctx.GetPlace(), + &x, + &weight_h, + h0.get_ptr(), + is_reverse, + scale_data, + shift_data, + gate_activation, + activation, + scale_weights, + N, + Ti, + IC, + OC); + auto input_memory_p = handler.AcquireInputMemoryWithReorder(&x, is_reverse); + + std::shared_ptr h0_memory_p, weight_h_memory_p, + weight_x_memory_p; + + if (phi::TransToProtoVarType(weight_h.dtype()) == phi::ProtoDataType::FP32) { + h0_memory_p = handler.template AcquireH0Memory(h0.get_ptr()); + weight_x_memory_p = + handler.template AcquireWeightXMemory(&weight_x, origin_mode); + weight_h_memory_p = + handler.template AcquireWeightHMemory(&weight_h, origin_mode); + } else if (phi::TransToProtoVarType(weight_h.dtype()) == + phi::ProtoDataType::BF16) { + h0_memory_p = + handler.template AcquireH0Memory(h0.get_ptr()); + weight_x_memory_p = + handler.template AcquireWeightXMemory( + &weight_x, origin_mode); + weight_h_memory_p = + handler.template AcquireWeightHMemory( + &weight_h, origin_mode); + } else { + h0_memory_p = handler.template AcquireH0Memory(h0.get_ptr()); + weight_x_memory_p = + handler.template AcquireWeightXMemory(&weight_x, origin_mode); + weight_h_memory_p = + handler.template AcquireWeightHMemory(&weight_h, origin_mode); + } + + auto bias_memory_p = handler.AcquireBiasMemory(bias.get_ptr(), origin_mode); + auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); + + std::unordered_map gru_args = { + {DNNL_ARG_SRC_LAYER, *input_memory_p}, + {DNNL_ARG_SRC_ITER, *h0_memory_p}, + {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, + {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, + {DNNL_ARG_BIAS, *bias_memory_p}, + {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; + + auto gru_forward_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + gru_forward_p->execute(astream, gru_args); + astream.wait(); + + auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle(); + auto* hidden_tmp_data = dev_ctx.template Alloc(hidden); + auto* hidden_data = phi::funcs::to_void_cast(hidden_tmp_data); + if (handler.is_NTC()) { + handler.reorderRNNdata(hidden_onednn_data, + hidden_data, + input_lod, + is_reverse, + RNNReorderType::NTC_PP); + } else { + handler.reorderRNNdata(hidden_onednn_data, + hidden_data, + input_lod, + is_reverse, + RNNReorderType::TNC_PP); + } +} + +template +void FusionGRUKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& h0, + const DenseTensor& weight_x, + const DenseTensor& weight_h, + const paddle::optional& bias, + const std::string& activation, + const std::string& gate_activation, + const bool is_reverse, + const bool use_seq, + const bool origin_mode, + const bool use_mkldnn, + const std::string& mkldnn_data_type, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + DenseTensor* reordered_h0, + DenseTensor* xx, + DenseTensor* batched_input, + DenseTensor* batched_out, + DenseTensor* hidden) { + const bool is_bf16 = std::is_same::value; + // BF16 does not support force output + if (!is_bf16 && force_fp32_output) { // NOLINT + RunKernel(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + } else { + RunKernel(dev_ctx, + x, + h0, + weight_x, + weight_h, + bias, + activation, + gate_activation, + is_reverse, + use_seq, + origin_mode, + use_mkldnn, + mkldnn_data_type, + scale_data, + shift_data, + scale_weights, + force_fp32_output, + reordered_h0, + xx, + batched_input, + batched_out, + hidden); + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fusion_gru, + OneDNN, + ONEDNN, + phi::fusion::FusionGRUKernel, + float, + phi::dtype::bfloat16, + uint8_t) {} diff --git a/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc index 43caa13698b48f..6ba3d84b5eb0b8 100644 --- a/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "glog/logging.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" namespace phi { namespace fusion { @@ -32,6 +35,8 @@ void Conv2dXPUKernelImpl(const Context& ctx, const paddle::optional& bias, const paddle::optional& branch, const paddle::optional& branch_max, + const paddle::optional& scale_max, + const paddle::optional& out_max_in, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -66,14 +71,19 @@ void Conv2dXPUKernelImpl(const Context& ctx, int out_c = static_cast(filter_dims[0]); int win_h = static_cast(filter_dims[2]); int win_w = static_cast(filter_dims[3]); - auto* input_data = reinterpret_cast(x.data()); const float* input_max_data = x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data(); auto* filter_data = reinterpret_cast(filter.data()); auto* filter_max_data = filter_max.data(); + auto* scale_max_data = scale_max.get_ptr() == nullptr + ? nullptr + : scale_max.get_ptr()->data(); const XPUTypeOut* branch_data = nullptr; + const float* branch_max_data = branch_max.get_ptr() == nullptr + ? nullptr + : branch_max.get_ptr()->data(); auto* branch_tensor = branch.get_ptr(); xpu::ctx_guard RAII_GUARD(ctx.x_context()); if (branch_tensor != nullptr) { @@ -92,14 +102,15 @@ void Conv2dXPUKernelImpl(const Context& ctx, branch_data = branch_data_temp; } } - const float* branch_max_data = branch_max.get_ptr() == nullptr - ? nullptr - : branch_max.get_ptr()->data(); + const float* bias_data = bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data(); auto* out_data = reinterpret_cast(ctx.template Alloc(out)); auto* out_max_data = ctx.template Alloc(out_max); + out_max_data = out_max_in.get_ptr() != nullptr + ? const_cast(out_max_in.get_ptr()->data()) + : out_max_data; xpu::Activation_t act(static_cast(act_type)); if (act_type == xpu::Activation_t::LEAKY_RELU) { act.leaky_alpha = act_param; @@ -131,7 +142,7 @@ void Conv2dXPUKernelImpl(const Context& ctx, /* const TY* branch */ branch_data, /* const baidu::xpu::api::Activation_t& act */ act, /* const float* branch_maxptr */ branch_max_data, - /* const float* scale */ nullptr); + /* const float* scale */ scale_max_data); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu"); } @@ -145,6 +156,8 @@ void Conv2dXPUKernelImpl(const Context& ctx, bias, \ branch, \ branch_max, \ + scale_max, \ + out_max_in, \ paddings, \ dilations, \ strides, \ @@ -164,6 +177,8 @@ void Conv2dXPUKernel(const Context& ctx, const paddle::optional& bias, const paddle::optional& branch, const paddle::optional& branch_max, + const paddle::optional& scale_max, + const paddle::optional& out_max_in, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -174,14 +189,118 @@ void Conv2dXPUKernel(const Context& ctx, DataType out_dtype, DenseTensor* out, DenseTensor* out_max) { - if (out_dtype == DataType::FLOAT32) { - CONV2D_XPU_KERNEL_IMPL(T, int16_t, float, int16_t); - } else if (out_dtype == DataType::FLOAT16) { - CONV2D_XPU_KERNEL_IMPL(T, int16_t, dtype::float16, int16_t); - } else { - PADDLE_THROW(phi::errors::Unimplemented("Not support out_dtype is %s.", - DataTypeToString(out_dtype))); + // Dont use template T param + VLOG(4) << "Conv kernel type: " << x.dtype() << " ," << filter.dtype() << " ," + << out_dtype; + if (x.dtype() == DataType::FLOAT32) { + // float32/float16 kernel + if (filter.dtype() == DataType::INT16) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(float, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL(float, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else if (filter.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(float, int8_t, float, int8_t); + } else if (out_dtype == DataType::INT8) { + CONV2D_XPU_KERNEL_IMPL(float, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + return; } + + if (x.dtype() == DataType::FLOAT16) { + // float16 kernel + if (filter.dtype() == DataType::INT16) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(phi::dtype::float16, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL( + phi::dtype::float16, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else if (filter.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL( + phi::dtype::float16, int8_t, dtype::float16, int8_t); + } else if (out_dtype == DataType::INT8) { + CONV2D_XPU_KERNEL_IMPL(phi::dtype::float16, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + return; + } + + if (x.dtype() == DataType::INT8) { + if (filter.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(int8_t, int8_t, float, int8_t); + } else if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL(int8_t, int8_t, dtype::float16, int8_t); + } else if (out_dtype == DataType::INT8) { + CONV2D_XPU_KERNEL_IMPL(int8_t, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); + } + return; + } + + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(filter.dtype()), + DataTypeToString(out_dtype))); } } // namespace fusion @@ -192,4 +311,5 @@ PD_REGISTER_KERNEL(conv2d_xpu, ALL_LAYOUT, phi::fusion::Conv2dXPUKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int8_t) {} diff --git a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc index 6a6721194e9a84..d6153eff096cb5 100644 --- a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "glog/logging.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" @@ -29,6 +30,8 @@ void FcXPUKernelImpl(const Context& ctx, const DenseTensor& w, const DenseTensor& w_max, const paddle::optional& bias, + const paddle::optional& scale_max, + const paddle::optional& out_max_in, int in_num_col_dims, bool transpose_x, float alpha, @@ -53,7 +56,13 @@ void FcXPUKernelImpl(const Context& ctx, bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data(); auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + auto* scale_max_data = scale_max.get_ptr() == nullptr + ? nullptr + : scale_max.get_ptr()->data(); auto* out_max_data = ctx.template Alloc(out_max); + out_max_data = out_max_in.get_ptr() != nullptr + ? const_cast(out_max_in.get_ptr()->data()) + : out_max_data; xpu::Activation_t act(static_cast(act_type)); if (act_type == xpu::Activation_t::LEAKY_RELU) { act.leaky_alpha = act_alpha; @@ -80,7 +89,9 @@ void FcXPUKernelImpl(const Context& ctx, alpha, // alpha beta, // beta bias_data, // bias - act); + act, // act + scale_max_data); // scale + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu"); } @@ -92,6 +103,8 @@ void FcXPUKernelImpl(const Context& ctx, w, \ w_max, \ bias, \ + scale_max, \ + out_max_in, \ in_num_col_dims, \ transpose_x, \ alpha, \ @@ -108,6 +121,8 @@ void FcXPUKernel(const Context& ctx, const DenseTensor& w, const DenseTensor& w_max, const paddle::optional& bias, + const paddle::optional& scale_max, + const paddle::optional& out_max_in, int in_num_col_dims, bool transpose_x, float alpha, @@ -117,14 +132,119 @@ void FcXPUKernel(const Context& ctx, DataType out_dtype, DenseTensor* out, DenseTensor* out_max) { - if (out_dtype == DataType::FLOAT32) { - FC_XPU_KERNEL_IMPL(T, int16_t, float, int16_t); - } else if (out_dtype == DataType::FLOAT16) { - FC_XPU_KERNEL_IMPL(T, int16_t, dtype::float16, int16_t); - } else { - PADDLE_THROW(phi::errors::Unimplemented("Not support out_dtype is %s.", - DataTypeToString(out_dtype))); + // Dont use template T param + VLOG(4) << "Fc kernel type: " << x.dtype() << " ," << w.dtype() << " ," + << out_dtype; + if (x.dtype() == DataType::FLOAT32) { + // float32/float16 kernel + if (w.dtype() == DataType::INT16) { + if (out_dtype == DataType::FLOAT32) { + FC_XPU_KERNEL_IMPL(float, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL(float, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else if (w.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT32) { + FC_XPU_KERNEL_IMPL(float, int8_t, float, int8_t); + } else if (out_dtype == DataType::INT8) { + FC_XPU_KERNEL_IMPL(float, int8_t, int8_t, int8_t); + } else if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL(float, int8_t, dtype::float16, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + return; + } + + if (x.dtype() == DataType::FLOAT16) { + // float16 kernel + if (w.dtype() == DataType::INT16) { + if (out_dtype == DataType::FLOAT32) { + FC_XPU_KERNEL_IMPL(phi::dtype::float16, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL( + phi::dtype::float16, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else if (w.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL(phi::dtype::float16, int8_t, dtype::float16, int8_t); + } else if (out_dtype == DataType::INT8) { + FC_XPU_KERNEL_IMPL(phi::dtype::float16, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + return; } + + if (x.dtype() == DataType::INT8) { + if (w.dtype() == DataType::INT8) { + if (out_dtype == DataType::FLOAT32) { + FC_XPU_KERNEL_IMPL(int8_t, int8_t, float, int8_t); + } else if (out_dtype == DataType::FLOAT16) { + FC_XPU_KERNEL_IMPL(int8_t, int8_t, dtype::float16, int8_t); + } else if (out_dtype == DataType::INT8) { + FC_XPU_KERNEL_IMPL(int8_t, int8_t, int8_t, int8_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is " + "%s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); + } + return; + } + + PADDLE_THROW(phi::errors::Unimplemented( + "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.", + DataTypeToString(x.dtype()), + DataTypeToString(w.dtype()), + DataTypeToString(out_dtype))); } } // namespace fusion @@ -135,4 +255,5 @@ PD_REGISTER_KERNEL(fc_xpu, ALL_LAYOUT, phi::fusion::FcXPUKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int8_t) {} diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 78c3723ceedcbd..c3c353859728b7 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -22,6 +22,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/batch_norm_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/norm_utils.cu.h" @@ -487,8 +488,8 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData( template void BatchNormGradFunctor(const Context &ctx, const DenseTensor &x, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, const paddle::optional &mean, const paddle::optional &variance, const DenseTensor &saved_mean, @@ -549,23 +550,41 @@ void BatchNormGradFunctor(const Context &ctx, ctx.template Alloc>(d_bias); } + auto *Scale = scale.get_ptr(); + auto *Bias = bias.get_ptr(); + + phi::DenseTensor new_scale; + phi::DenseTensor new_bias; + + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(ctx, {C}, static_cast(1)); + } + + if (Bias) { + new_bias = bias.get(); + } else { + new_bias = phi::Full(ctx, {C}, static_cast(0)); + } + PADDLE_ENFORCE_EQ( - scale.dims().size(), + new_scale.dims().size(), 1UL, phi::errors::InvalidArgument( "The size of scale's dimensions must equal to 1. But received: " "the size of scale's dimensions is [%d], the dimensions of scale " "is [%s].", - scale.dims().size(), - scale.dims())); + new_scale.dims().size(), + new_scale.dims())); PADDLE_ENFORCE_EQ( - scale.dims()[0], + new_scale.dims()[0], C, phi::errors::InvalidArgument( "The first dimension of scale must equal to Channels[%d]. But " "received: the first dimension of scale is [%d]", C, - scale.dims()[0])); + new_scale.dims()[0])); auto dtype = phi::backends::gpu::CudnnDataType::type; #ifdef PADDLE_WITH_HIP @@ -713,8 +732,8 @@ void BatchNormGradFunctor(const Context &ctx, if (is_inplace) { inplace_functor(compute_format, transformed_x.data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), saved_mean_data, saved_var_data, epsilon, @@ -735,7 +754,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), saved_mean_data, saved_var_data, C, @@ -750,7 +769,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), saved_mean_data, saved_var_data, C, @@ -880,7 +899,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), dscale, dbias, mean_ptr, @@ -897,7 +916,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), saved_mean_data, saved_var_data, C, @@ -912,7 +931,7 @@ void BatchNormGradFunctor(const Context &ctx, <<>>( transformed_d_y.template data(), transformed_x.template data(), - scale.template data>(), + new_scale.template data>(), saved_mean_data, saved_var_data, C, @@ -969,7 +988,8 @@ void BatchNormGradFunctor(const Context &ctx, /*dxDesc=*/data_desc_, /*dxData=*/ctx.template Alloc(&transformed_d_x), /*dBnScaleBiasDesc=*/bn_param_desc_, - /*bnScaleData=*/scale.template data>(), + /*bnScaleData=*/ + new_scale.template data>(), /*bnBiasData=*/nullptr, /*dBnScaleData=*/ ctx.template Alloc>(d_scale), @@ -1000,7 +1020,7 @@ void BatchNormGradFunctor(const Context &ctx, data_desc_, ctx.template Alloc(&transformed_d_x), bn_param_desc_, - scale.template data>(), + new_scale.template data>(), ctx.template Alloc>(d_scale), ctx.template Alloc>(d_bias), epsilon, @@ -1023,7 +1043,7 @@ void BatchNormGradFunctor(const Context &ctx, BNBackwardData <<>>( d_y->data(), - scale.data>(), + new_scale.data>(), saved_mean_data, x.data(), saved_var_data, @@ -1051,7 +1071,7 @@ void BatchNormGradFunctor(const Context &ctx, BNBackwardData <<>>( d_y->data(), - scale.data>(), + new_scale.data>(), saved_mean_data, x.data(), saved_var_data, @@ -1080,7 +1100,7 @@ void BatchNormGradFunctor(const Context &ctx, BNBackwardData <<>>( d_y->data(), - scale.data>(), + new_scale.data>(), saved_mean_data, x.data(), saved_var_data, @@ -1134,8 +1154,8 @@ void BatchNormGradFunctor(const Context &ctx, auto px = x; inplace_functor(data_layout, ctx.template Alloc(&px), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), running_mean_data, running_var_data, epsilon, @@ -1152,14 +1172,15 @@ void BatchNormGradFunctor(const Context &ctx, if (data_layout == DataLayout::kNHWC) { if (d_x) { KeBNBackwardData - <<>>(d_y->data(), - scale.data>(), - running_var_data, - epsilon, - C, - H * W, - num, - d_x->data()); + <<>>( + d_y->data(), + new_scale.data>(), + running_var_data, + epsilon, + C, + H * W, + num, + d_x->data()); } if (d_scale && d_bias) { KeBNBackwardScaleBias @@ -1178,14 +1199,15 @@ void BatchNormGradFunctor(const Context &ctx, } else { if (d_x) { KeBNBackwardData - <<>>(d_y->data(), - scale.data>(), - running_var_data, - epsilon, - C, - H * W, - num, - d_x->data()); + <<>>( + d_y->data(), + new_scale.data>(), + running_var_data, + epsilon, + C, + H * W, + num, + d_x->data()); } if (d_scale && d_bias) { KeBNBackwardScaleBias @@ -1205,14 +1227,15 @@ void BatchNormGradFunctor(const Context &ctx, } else { if (d_x) { KeBNBackwardData - <<>>(d_y->data(), - scale.data>(), - running_var_data, - epsilon, - C, - H * W, - num, - d_x->data()); + <<>>( + d_y->data(), + new_scale.data>(), + running_var_data, + epsilon, + C, + H * W, + num, + d_x->data()); } if (d_scale && d_bias) { dim3 block; @@ -1262,8 +1285,8 @@ void BatchNormGradFunctor(const Context &ctx, template void BatchNormGradKernel(const Context &dev_ctx, const DenseTensor &x, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, const paddle::optional &mean, const paddle::optional &variance, const DenseTensor &saved_mean, @@ -1305,7 +1328,7 @@ template void BatchNormDoubleGradKernel( const Context &ctx, const DenseTensor &x, - const DenseTensor &scale, + const paddle::optional &scale, const paddle::optional &mean, const paddle::optional &variance, const DenseTensor &saved_mean, @@ -1338,10 +1361,20 @@ void BatchNormDoubleGradKernel( running_mean = mean.get_ptr(); running_variance = variance.get_ptr(); } + const auto &x_dims = x.dims(); + int N, C, H, W, D; + phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); + auto *Scale = scale.get_ptr(); + phi::DenseTensor new_scale; + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(ctx, {C}, static_cast(1)); + } phi::funcs::NormDoubleGradFunctor(ctx, data_layout, &x, - &scale, + &new_scale, &y_grad, &saved_mean, &saved_variance, diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 3b73935699babb..20aa02a5f24856 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -29,6 +29,7 @@ namespace cub = hipcub; #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/batch_norm_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/norm_utils.cu.h" @@ -515,8 +516,8 @@ void BatchNormKernel(const Context &ctx, const DenseTensor &x, const DenseTensor &mean, const DenseTensor &variance, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, bool is_test, float momentum, float epsilon_f, @@ -551,6 +552,24 @@ void BatchNormKernel(const Context &ctx, auto dtype = phi::backends::gpu::CudnnDataType::type; + auto *Scale = scale.get_ptr(); + auto *Bias = bias.get_ptr(); + + phi::DenseTensor new_scale; + phi::DenseTensor new_bias; + + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(ctx, {C}, static_cast(1)); + } + + if (Bias) { + new_bias = bias.get(); + } else { + new_bias = phi::Full(ctx, {C}, static_cast(0)); + } + #ifdef PADDLE_WITH_HIP auto compute_format = data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW; @@ -722,8 +741,8 @@ void BatchNormKernel(const Context &ctx, transformed_x.template data(), est_mean->template data>(), est_var->template data>(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -735,8 +754,8 @@ void BatchNormKernel(const Context &ctx, transformed_x.template data(), est_mean->template data>(), est_var->template data>(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -779,8 +798,8 @@ void BatchNormKernel(const Context &ctx, transformed_x.template data(), est_mean->template data>(), est_var->template data>(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -803,8 +822,8 @@ void BatchNormKernel(const Context &ctx, est_mean->template data>(), // est_var->template data>(), inv_var_ptr, - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -816,8 +835,8 @@ void BatchNormKernel(const Context &ctx, transformed_x.template data(), est_mean->template data>(), est_var->template data>(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -838,8 +857,8 @@ void BatchNormKernel(const Context &ctx, data_desc_, ctx.template Alloc(&transformed_y), bn_param_desc_, - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), est_mean->template data>(), est_var->template data>(), epsilon)); @@ -884,8 +903,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -900,8 +919,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1002,8 +1021,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining2DCompStat <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1021,8 +1040,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining2DWriteRes<<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1063,8 +1082,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining2DChannelLastCompStat <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1083,8 +1102,8 @@ void BatchNormKernel(const Context &ctx, BNForwardTraining2DChannelLastWriteRes <<>>( transformed_x.template data(), - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), C, N, H * W * D, @@ -1155,8 +1174,8 @@ void BatchNormKernel(const Context &ctx, data_desc_, transformed_y.template data(), bn_param_desc_, - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), this_factor, ctx.template Alloc>(mean_out), ctx.template Alloc>(variance_out), @@ -1180,8 +1199,8 @@ void BatchNormKernel(const Context &ctx, data_desc_, ctx.template Alloc(&transformed_y), bn_param_desc_, - scale.template data>(), - bias.template data>(), + new_scale.template data>(), + new_bias.template data>(), this_factor, ctx.template Alloc>(mean_out), ctx.template Alloc>(variance_out), diff --git a/paddle/phi/kernels/gpu/contiguous_kernel.cu b/paddle/phi/kernels/gpu/contiguous_kernel.cu index 357e104afb01c8..49b253effd9451 100644 --- a/paddle/phi/kernels/gpu/contiguous_kernel.cu +++ b/paddle/phi/kernels/gpu/contiguous_kernel.cu @@ -31,12 +31,12 @@ __global__ void ContiguousCaseZeroFunc( blockDim.z * blockDim.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x; - float coordinate[6] = {threadIdx.x, - threadIdx.y, - threadIdx.z, - blockIdx.x, - blockIdx.y, - blockIdx.z}; + int64_t coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; #pragma unroll for (int dim = N - 1; dim >= 0; --dim) { diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 1f1453a0c64088..783d94e8e7bb27 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -324,6 +324,9 @@ PD_REGISTER_KERNEL(divide_grad, phi::dtype::float16, phi::dtype::bfloat16, double, + int8_t, + uint8_t, + int16_t, int, int64_t, phi::dtype::complex, diff --git a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu index f2be0f073a87de..5bb59357bc976a 100644 --- a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu @@ -24,14 +24,14 @@ namespace phi { template void FrobeniusNormKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { - reduce_all = recompute_reduce_all(x, dims, reduce_all); + reduce_all = recompute_reduce_all(x, dims.GetData(), reduce_all); auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); SqrtKernel(dev_ctx, *out, out); } diff --git a/paddle/phi/kernels/gpu/reduce_kernel.cu b/paddle/phi/kernels/gpu/reduce_kernel.cu index 969a3dd1d9ca58..d9714d37febd9b 100644 --- a/paddle/phi/kernels/gpu/reduce_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_kernel.cu @@ -370,6 +370,8 @@ PD_REGISTER_KERNEL(sum_grad, double, phi::dtype::float16, phi::dtype::bfloat16, + int8_t, + uint8_t, int16_t, int, int64_t, diff --git a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h index 385ea68e6e7075..7954441f30c2b3 100644 --- a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h @@ -25,13 +25,13 @@ void FrobeniusNormGradKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DenseTensor* dx) { - reduce_all = recompute_reduce_all(x, axis, reduce_all); + reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all); ReduceGradKernel( - ctx, x, out, dout, axis, keep_dim, reduce_all, dx); + ctx, x, out, dout, axis.GetData(), keep_dim, reduce_all, dx); } } // namespace phi diff --git a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h index 7dbc3ab3af7ba6..eab028a1caccfc 100644 --- a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h +++ b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h @@ -23,13 +23,13 @@ namespace phi { template void FrobeniusNormKernel(const Context& ctx, const DenseTensor& x, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DenseTensor* out) { - reduce_all = recompute_reduce_all(x, axis, reduce_all); + reduce_all = recompute_reduce_all(x, axis.GetData(), reduce_all); Reduce( - ctx, x, reduce_all, axis, keep_dim, x.dtype(), out); + ctx, x, reduce_all, axis.GetData(), keep_dim, x.dtype(), out); } } // namespace phi diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu index 584e026241bde3..d40f1bd7a7062b 100644 --- a/paddle/phi/kernels/kps/elementwise_kernel.cu +++ b/paddle/phi/kernels/kps/elementwise_kernel.cu @@ -307,6 +307,9 @@ PD_REGISTER_KERNEL(divide, phi::DivideKernel, float, double, + int8_t, + uint8_t, + int16_t, int, int64_t, float16, diff --git a/paddle/phi/kernels/kps/reduce_kernel.cu b/paddle/phi/kernels/kps/reduce_kernel.cu index 1bc00cf11cbdb2..506bd36e828bc5 100644 --- a/paddle/phi/kernels/kps/reduce_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_kernel.cu @@ -369,6 +369,8 @@ PD_REGISTER_KERNEL(sum_raw, double, float16, bfloat16, + int8_t, + uint8_t, int16_t, int, int64_t, diff --git a/paddle/phi/kernels/legacy/kps/elementwise_kernel.cu b/paddle/phi/kernels/legacy/kps/elementwise_kernel.cu index f07164bc16885b..394d525b15f0f0 100644 --- a/paddle/phi/kernels/legacy/kps/elementwise_kernel.cu +++ b/paddle/phi/kernels/legacy/kps/elementwise_kernel.cu @@ -77,6 +77,9 @@ PD_REGISTER_KERNEL(divide_raw, phi::DivideRawKernel, float, double, + int8_t, + uint8_t, + int16_t, int, int64_t, float16, diff --git a/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc index e3e0fef11e9133..e97d5c5f96cb52 100644 --- a/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/batch_norm_grad_kernel.cc @@ -21,8 +21,8 @@ template void phi::BatchNormGradFunctor( \ const ::phi::backend##Context& dev_ctx, \ const DenseTensor& x, \ - const DenseTensor& scale, \ - const DenseTensor& bias, \ + const paddle::optional& scale, \ + const paddle::optional& bias, \ const paddle::optional& mean, \ const paddle::optional& variance, \ const DenseTensor& saved_mean, \ @@ -45,8 +45,8 @@ namespace phi { template void BatchNormGradFunctor(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, @@ -63,8 +63,10 @@ void BatchNormGradFunctor(const Context& dev_ctx, DenseTensor* x_grad, DenseTensor* scale_grad, DenseTensor* bias_grad) { + auto Scale = scale.get_ptr(); + auto Bias = bias.get_ptr(); funcs::BatchNormOneDNNHandler handler( - dev_ctx.GetEngine(), dev_ctx.GetPlace(), epsilon, &x, &scale, &y_grad); + dev_ctx.GetEngine(), dev_ctx.GetPlace(), epsilon, &x, Scale, &y_grad); T* diff_scale_data = dev_ctx.template Alloc(scale_grad); T* diff_shift_data = dev_ctx.template Alloc(bias_grad); @@ -73,7 +75,7 @@ void BatchNormGradFunctor(const Context& dev_ctx, auto mean_memory = handler.AcquireMeanMemory(&saved_mean); auto variance_memory = handler.AcquireVarianceMemory(&saved_variance); auto diff_dst_memory = handler.AcquireDiffDstMemory(&y_grad); - auto scaleshift_mems = handler.AcquireScaleShiftMemory(&scale, &bias); + auto scaleshift_mems = handler.AcquireScaleShiftMemory(Scale, Bias); auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad); auto diff_scaleshift_mems = handler.AcquireDiffScaleShiftMemory(diff_scale_data, diff_shift_data); @@ -100,8 +102,8 @@ void BatchNormGradFunctor(const Context& dev_ctx, template void BatchNormGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, const paddle::optional& mean, const paddle::optional& variance, const DenseTensor& saved_mean, diff --git a/paddle/phi/kernels/onednn/batch_norm_kernel.cc b/paddle/phi/kernels/onednn/batch_norm_kernel.cc index 61172c074e26a0..56bb933359f96f 100644 --- a/paddle/phi/kernels/onednn/batch_norm_kernel.cc +++ b/paddle/phi/kernels/onednn/batch_norm_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/batch_norm_kernel.h" +#include "glog/logging.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" @@ -28,8 +29,8 @@ void BatchNormKernel(const Context &dev_ctx, const DenseTensor &x, const DenseTensor &mean, const DenseTensor &variance, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, bool is_test, float momentum, float epsilon, @@ -49,6 +50,8 @@ void BatchNormKernel(const Context &dev_ctx, ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("fuse_with_relu")) : false; + auto Scale = scale.get_ptr(); + auto Bias = bias.get_ptr(); funcs::BatchNormOneDNNHandler handler(dev_ctx.GetEngine(), dev_ctx.GetPlace(), &x, @@ -58,7 +61,7 @@ void BatchNormKernel(const Context &dev_ctx, test_mode); auto src_memory = handler.AcquireSrcMemory(&x); - auto scaleshift_mems = handler.AcquireScaleShiftMemory(&scale, &bias); + auto scaleshift_mems = handler.AcquireScaleShiftMemory(Scale, Bias); auto dst_memory = handler.AcquireDstMemory(y); auto batch_norm_p = handler.AcquireForwardPrimitive(); @@ -87,7 +90,7 @@ void BatchNormKernel(const Context &dev_ctx, astream.wait(); if (!global_stats) { - const unsigned int C = phi::vectorize(scale.dims())[0]; + const unsigned int C = phi::vectorize(Scale->dims())[0]; // mkldnn only compute stats for current batch // so we need compute momentum stats via Eigen lib diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc index 6f2dc34673f670..10495a286df362 100644 --- a/paddle/phi/kernels/reduce_sum_kernel.cc +++ b/paddle/phi/kernels/reduce_sum_kernel.cc @@ -66,6 +66,8 @@ PD_REGISTER_KERNEL(sum, int16_t, int, int64_t, + uint8_t, + int8_t, complex64, complex128) { kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); diff --git a/paddle/phi/kernels/squeeze_grad_kernel.cc b/paddle/phi/kernels/squeeze_grad_kernel.cc index 473acf9d7a1d15..a8a788e817472b 100644 --- a/paddle/phi/kernels/squeeze_grad_kernel.cc +++ b/paddle/phi/kernels/squeeze_grad_kernel.cc @@ -76,6 +76,7 @@ PD_REGISTER_KERNEL(squeeze_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/squeeze_kernel.cc b/paddle/phi/kernels/squeeze_kernel.cc index d495b040921b59..a8d24423fcb45c 100644 --- a/paddle/phi/kernels/squeeze_kernel.cc +++ b/paddle/phi/kernels/squeeze_kernel.cc @@ -116,6 +116,7 @@ PD_REGISTER_KERNEL(squeeze_infer, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, @@ -129,6 +130,7 @@ PD_REGISTER_KERNEL(squeeze, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/unsqueeze_grad_kernel.cc b/paddle/phi/kernels/unsqueeze_grad_kernel.cc index 3c119db2c73d6e..d26753ece47cdc 100644 --- a/paddle/phi/kernels/unsqueeze_grad_kernel.cc +++ b/paddle/phi/kernels/unsqueeze_grad_kernel.cc @@ -77,6 +77,7 @@ PD_REGISTER_KERNEL(unsqueeze_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/unsqueeze_kernel.cc b/paddle/phi/kernels/unsqueeze_kernel.cc index c08c31da4ef0ce..3e1c8f8cc15e12 100644 --- a/paddle/phi/kernels/unsqueeze_kernel.cc +++ b/paddle/phi/kernels/unsqueeze_kernel.cc @@ -27,7 +27,7 @@ void UnsqueezeInferKernel(const Context& dev_ctx, DenseTensor* out) { auto x_dims = x.dims(); auto out_dims = out->dims(); - if (axes.FromTensor()) { + if (axes.FromTensor() && out->dims()[0] == -1) { out_dims = funcs::GetUnsqueezeShape(axes.GetData(), x_dims); } out->Resize(out_dims); @@ -124,6 +124,7 @@ PD_REGISTER_KERNEL(unsqueeze_infer, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, @@ -137,6 +138,7 @@ PD_REGISTER_KERNEL(unsqueeze, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index efac9b30ae2eb2..c9b1136793e5eb 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -580,8 +580,13 @@ PD_REGISTER_KERNEL(leaky_relu, phi::LeakyReluKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - sqrt, XPU, ALL_LAYOUT, phi::SqrtKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(sqrt, + XPU, + ALL_LAYOUT, + phi::SqrtKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL( tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc index 09e62bbfd4bde1..863bc2759b39a3 100644 --- a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/batch_norm_grad_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" @@ -72,8 +73,8 @@ static int CalculateInvVar(xpu::Context *ctx, template void BatchNormGradKernel(const Context &dev_ctx, const DenseTensor &x, - const DenseTensor &scale, - const DenseTensor &bias, + const paddle::optional &scale, + const paddle::optional &bias, const paddle::optional &mean, const paddle::optional &variance, const DenseTensor &saved_mean, @@ -133,9 +134,27 @@ void BatchNormGradKernel(const Context &dev_ctx, W = W * D; + auto *Scale = scale.get_ptr(); + auto *Bias = bias.get_ptr(); + + phi::DenseTensor new_scale; + phi::DenseTensor new_bias; + + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(dev_ctx, {C}, static_cast(1)); + } + + if (Bias) { + new_bias = bias.get(); + } else { + new_bias = phi::Full(dev_ctx, {C}, static_cast(0)); + } + const auto *x_data = reinterpret_cast(x.data()); const auto *d_y_data = reinterpret_cast(y_grad.data()); - const auto *scale_data = scale.data(); + const auto *scale_data = new_scale.data(); // init output XPUType *x_grad_data = nullptr; @@ -151,22 +170,22 @@ void BatchNormGradKernel(const Context &dev_ctx, } PADDLE_ENFORCE_EQ( - scale.dims().size(), + new_scale.dims().size(), 1UL, phi::errors::InvalidArgument( "The size of scale's dimensions must equal to 1. But received: " "the size of scale's dimensions is [%d], the dimensions of scale " "is [%s].", - scale.dims().size(), - scale.dims())); + new_scale.dims().size(), + new_scale.dims())); PADDLE_ENFORCE_EQ( - scale.dims()[0], + new_scale.dims()[0], C, phi::errors::InvalidArgument( "The first dimension of scale must equal to Channels[%d]. But " "received: the first dimension of scale is [%d]", C, - scale.dims()[0])); + new_scale.dims()[0])); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); @@ -203,8 +222,8 @@ void BatchNormGradKernel(const Context &dev_ctx, : saved_mean.data(); r = CalculateInvBNY(dev_ctx.x_context(), x_fp32_data, - scale.data(), - bias.data(), + new_scale.data(), + new_bias.data(), mean_data, inv_std_data, N, diff --git a/paddle/phi/kernels/xpu/batch_norm_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_kernel.cc index e2f2d28182b67d..2abb1686daed98 100644 --- a/paddle/phi/kernels/xpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/batch_norm_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" @@ -25,8 +26,8 @@ void BatchNormKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& mean, const DenseTensor& variance, - const DenseTensor& scale, - const DenseTensor& bias, + const paddle::optional& scale, + const paddle::optional& bias, bool is_test, float momentum, float epsilon, @@ -69,9 +70,27 @@ void BatchNormKernel(const Context& dev_ctx, W = W * D; + auto* Scale = scale.get_ptr(); + auto* Bias = bias.get_ptr(); + + phi::DenseTensor new_scale; + phi::DenseTensor new_bias; + + if (Scale) { + new_scale = scale.get(); + } else { + new_scale = phi::Full(dev_ctx, {C}, static_cast(1)); + } + + if (Bias) { + new_bias = bias.get(); + } else { + new_bias = phi::Full(dev_ctx, {C}, static_cast(0)); + } + const auto* x_data = reinterpret_cast(x.data()); - const auto* scale_data = scale.data(); - const auto* bias_data = bias.data(); + const auto* scale_data = new_scale.data(); + const auto* bias_data = new_bias.data(); // alloc memory auto* y_data = reinterpret_cast(dev_ctx.template Alloc(y)); diff --git a/paddle/phi/kernels/xpu/dequantization_kernel.cc b/paddle/phi/kernels/xpu/dequantization_kernel.cc new file mode 100644 index 00000000000000..9dc9868e75fd96 --- /dev/null +++ b/paddle/phi/kernels/xpu/dequantization_kernel.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void DeQuantizeKernelImpl(const Context& ctx, + const DenseTensor& x, + float scale, + DenseTensor* y) { + using XPUInX = typename XPUTypeTrait::Type; + using XPUOutY = typename XPUTypeTrait::Type; + + auto* y_data = ctx.template Alloc(y); + const auto* x_data = x.data(); + int64_t len = x.numel(); + int max_ptr_size = ctx.x_context()->max_ptr_size(); + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + auto max_data = RAII_GUARD.alloc_l3_or_gm(max_ptr_size); + int r = xpu::constant(ctx.x_context(), max_data, max_ptr_size, scale); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + r = xpu::dequantization( + ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + len, + max_data); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "dequantization"); +} + +template +void DeQuantizeKernel(const Context& ctx, + const DenseTensor& x, + DataType out_dtype, + float scale, + DenseTensor* y) { + switch (out_dtype) { + case DataType::FLOAT32: + DeQuantizeKernelImpl(ctx, x, scale, y); + break; + case DataType::FLOAT16: + DeQuantizeKernelImpl(ctx, x, scale, y); + break; + default: + PADDLE_THROW(phi::errors::Unavailable( + "Not supported dequantize data type from %d -> %d ", + x.dtype(), + out_dtype)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + dequantize_xpu, XPU, ALL_LAYOUT, phi::DeQuantizeKernel, int16_t, int8_t) {} diff --git a/paddle/phi/kernels/xpu/quantization_kernel.cc b/paddle/phi/kernels/xpu/quantization_kernel.cc new file mode 100644 index 00000000000000..32b28b034e2dab --- /dev/null +++ b/paddle/phi/kernels/xpu/quantization_kernel.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void QuantizeKernelImpl(const Context& ctx, + const DenseTensor& x, + float scale, + DenseTensor* y) { + using XPUInX = typename XPUTypeTrait::Type; + using XPUOutY = typename XPUTypeTrait::Type; + + auto* y_data = ctx.template Alloc(y); + const auto* x_data = x.data(); + int64_t len = x.numel(); + int max_ptr_size = ctx.x_context()->max_ptr_size(); + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + auto max_data = RAII_GUARD.alloc_l3_or_gm(max_ptr_size); + int r = xpu::constant(ctx.x_context(), max_data, max_ptr_size, scale); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + r = xpu::quantization( + ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + len, + max_data); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "quantization"); +} + +template +void QuantizeKernel(const Context& ctx, + const DenseTensor& x, + DataType out_dtype, + float scale, + DenseTensor* y) { + switch (out_dtype) { + case DataType::INT16: + QuantizeKernelImpl(ctx, x, scale, y); + break; + case DataType::INT8: + QuantizeKernelImpl(ctx, x, scale, y); + break; + default: + PADDLE_THROW(phi::errors::Unavailable( + "Not supported quantize data type from %d -> %d ", + x.dtype(), + out_dtype)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(quantize_xpu, + XPU, + ALL_LAYOUT, + phi::QuantizeKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/split_kernel.cc b/paddle/phi/kernels/xpu/split_kernel.cc index 11a20f6f179466..e3aeb7ffdfbe32 100644 --- a/paddle/phi/kernels/xpu/split_kernel.cc +++ b/paddle/phi/kernels/xpu/split_kernel.cc @@ -74,7 +74,8 @@ PD_REGISTER_KERNEL(split, float, int64_t, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(split_with_num, XPU, ALL_LAYOUT, @@ -82,4 +83,5 @@ PD_REGISTER_KERNEL(split_with_num, float, int64_t, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc index 043d2c8e3df5ad..71b2187bddce10 100644 --- a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc @@ -65,6 +65,7 @@ PD_REGISTER_KERNEL(transpose_grad, phi::TransposeGradKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int64_t, int, bool) {} diff --git a/paddle/phi/kernels/xpu/transpose_kernel.cc b/paddle/phi/kernels/xpu/transpose_kernel.cc index 398a2281dcea8b..dd985ddc7ebc58 100644 --- a/paddle/phi/kernels/xpu/transpose_kernel.cc +++ b/paddle/phi/kernels/xpu/transpose_kernel.cc @@ -60,6 +60,7 @@ PD_REGISTER_KERNEL(transpose, phi::TransposeKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int64_t, int, bool) {} diff --git a/paddle/pir/core/builder.cc b/paddle/pir/core/builder.cc index 6a1608c84ab857..2484e02f5156e8 100644 --- a/paddle/pir/core/builder.cc +++ b/paddle/pir/core/builder.cc @@ -73,6 +73,9 @@ DoubleAttribute Builder::double_attr(double value) { Int32Attribute Builder::int32_attr(int32_t value) { return Int32Attribute::get(context_, value); } +IndexAttribute Builder::index_attr(int64_t value) { + return IndexAttribute::get(context_, value); +} Int64Attribute Builder::int64_attr(int64_t value) { return Int64Attribute::get(context_, value); } diff --git a/paddle/pir/core/builder.h b/paddle/pir/core/builder.h index 72c8494cf89067..ae1887230c6661 100644 --- a/paddle/pir/core/builder.h +++ b/paddle/pir/core/builder.h @@ -39,6 +39,7 @@ class BoolAttribute; class FloatAttribute; class DoubleAttribute; class Int32Attribute; +class IndexAttribute; class Int64Attribute; class ArrayAttribute; class PointerAttribute; @@ -131,6 +132,7 @@ class Builder { IR_API FloatAttribute float_attr(float value); IR_API DoubleAttribute double_attr(double value); IR_API Int32Attribute int32_attr(int32_t value); + IR_API IndexAttribute index_attr(int64_t value); IR_API Int64Attribute int64_attr(int64_t value); IR_API ArrayAttribute array_attr(const std::vector &value); IR_API PointerAttribute pointer_attr(void *value); diff --git a/paddle/pir/core/builtin_attribute.cc b/paddle/pir/core/builtin_attribute.cc index e14a424c32c8e5..0958e247984140 100644 --- a/paddle/pir/core/builtin_attribute.cc +++ b/paddle/pir/core/builtin_attribute.cc @@ -24,6 +24,8 @@ double DoubleAttribute::data() const { return storage()->data(); } int32_t Int32Attribute::data() const { return storage()->data(); } +int64_t IndexAttribute::data() const { return storage()->data(); } + int64_t Int64Attribute::data() const { return storage()->data(); } void* PointerAttribute::data() const { return storage()->data(); } @@ -86,6 +88,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::BoolAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::FloatAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::DoubleAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int32Attribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::IndexAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int64Attribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::PointerAttribute) diff --git a/paddle/pir/core/builtin_attribute.h b/paddle/pir/core/builtin_attribute.h index 7d3f86144915cb..b09bff8750c402 100644 --- a/paddle/pir/core/builtin_attribute.h +++ b/paddle/pir/core/builtin_attribute.h @@ -55,6 +55,15 @@ class IR_API Int32Attribute : public Attribute { int32_t data() const; }; +class IR_API IndexAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(IndexAttribute, IndexAttributeStorage); + + int64_t data() const; +}; + class IR_API Int64Attribute : public Attribute { public: using Attribute::Attribute; @@ -123,6 +132,7 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::FloatAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::DoubleAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int32Attribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int64Attribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::IndexAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::PointerAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::TypeAttribute) diff --git a/paddle/pir/core/builtin_attribute_storage.h b/paddle/pir/core/builtin_attribute_storage.h index fd9dd6eb871283..2ab13326d3ebc6 100644 --- a/paddle/pir/core/builtin_attribute_storage.h +++ b/paddle/pir/core/builtin_attribute_storage.h @@ -52,6 +52,7 @@ DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t); +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(IndexAttributeStorage, int64_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(TypeAttributeStorage, Type); diff --git a/paddle/pir/core/builtin_dialect.cc b/paddle/pir/core/builtin_dialect.cc index 60575da6d9472c..0fef066ec47271 100644 --- a/paddle/pir/core/builtin_dialect.cc +++ b/paddle/pir/core/builtin_dialect.cc @@ -46,6 +46,7 @@ void BuiltinDialect::initialize() { DoubleAttribute, PointerAttribute, Int32Attribute, + IndexAttribute, Int64Attribute, ArrayAttribute, TypeAttribute>(); diff --git a/paddle/pir/core/builtin_op.h b/paddle/pir/core/builtin_op.h index 19ca96b0526928..64649f29175e6f 100644 --- a/paddle/pir/core/builtin_op.h +++ b/paddle/pir/core/builtin_op.h @@ -204,7 +204,7 @@ class IR_API ConstantOp : public Op { Type output_type); void VerifySig() const; - + OpResult out() { return result(0); } Attribute value() const; }; diff --git a/paddle/pir/core/builtin_type_interfaces.h b/paddle/pir/core/builtin_type_interfaces.h index f1df893f89e3ff..40ad58313a0d3d 100644 --- a/paddle/pir/core/builtin_type_interfaces.h +++ b/paddle/pir/core/builtin_type_interfaces.h @@ -40,27 +40,17 @@ class ShapedTypeInterface : public TypeInterfaceBase { template struct Model : public Concept { - static inline DataType getElementType(Type type) { + static inline DataType GetElementType(Type type) { return pir::cast(type).dtype(); } - static inline DDim getShape(Type type) { + static inline DDim GetShape(Type type) { return pir::cast(type).dims(); } - Model() : Concept(getElementType, getShape) {} + Model() : Concept(GetElementType, GetShape) {} }; - /// Constructor - ShapedTypeInterface(std::nullptr_t) // NOLINT - : TypeInterfaceBase(Type()), impl_(nullptr) {} - - explicit ShapedTypeInterface(Type type = Type()) - : TypeInterfaceBase(type), - impl_(type - ? type.abstract_type().GetInterfaceImpl() - : nullptr) {} - ShapedTypeInterface(Type type, Concept *impl) : TypeInterfaceBase(type), impl_(impl) {} diff --git a/paddle/pir/core/infer_type_op_interface.cc b/paddle/pir/core/infer_type_op_interface.cc new file mode 100644 index 00000000000000..b238daca2045ff --- /dev/null +++ b/paddle/pir/core/infer_type_op_interface.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/core/infer_type_op_interface.h" + +namespace pir { + +bool InferShapedTypeOpInterface::ReifyReturnTypeShapes( + Builder& builder, + std::vector operands, + std::vector& reified_return_shapes) { + return impl_->reify_return_type_shapes( + builder, operands, reified_return_shapes); +} +} // namespace pir + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::InferShapedTypeOpInterface) diff --git a/paddle/pir/core/infer_type_op_interface.h b/paddle/pir/core/infer_type_op_interface.h new file mode 100644 index 00000000000000..6acef20c023404 --- /dev/null +++ b/paddle/pir/core/infer_type_op_interface.h @@ -0,0 +1,72 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/op_base.h" + +// Type inference is currently modelled executionally for operation creation +// using the `InferMetaInterface`. While `InferShapedTypeOpInterface` is used to +// implement the shape and element type inference. The return type can often be +// deduced from the deduced return shape and elemental type (queryable from +// `InferShapedTypeOpInterface`) and so type inference for tensor types can be +// implemented with `InferShapedTypeOpInterface`. + +namespace pir { + +class InferShapedTypeOpInterface + : public pir::OpInterfaceBase { + public: + /// Defined these methods with the interface. + struct Concept { + explicit Concept(bool (*reify_return_type_shapes)( + Builder& builder, // NOLINT + std::vector operands, // NOLINT + std::vector& reified_return_shapes)) // NOLINT + : reify_return_type_shapes(reify_return_type_shapes) {} + bool (*reify_return_type_shapes)( + Builder& builder, + std::vector operands, + std::vector& reified_return_shapes); // NOLINT + }; + + template + struct Model : public Concept { + static inline bool ReifyReturnTypeShapes( + Builder& builder, // NOLINT + std::vector operands, // NOLINT + std::vector& reified_return_shapes) { // NOLINT + return ConcreteOp::ReifyReturnTypeShapes( + builder, operands, reified_return_shapes); + } + + Model() : Concept(ReifyReturnTypeShapes) {} + }; + + /// Constructor + InferShapedTypeOpInterface(Operation* op, Concept* impl) + : pir::OpInterfaceBase(op), impl_(impl) {} + + bool ReifyReturnTypeShapes( + Builder& builder, // NOLINT + std::vector operands, // NOLINT + std::vector& reified_return_shapes); // NOLINT + + private: + Concept* impl_; +}; + +} // namespace pir + +IR_DECLARE_EXPLICIT_TYPE_ID(pir::InferShapedTypeOpInterface) diff --git a/paddle/pir/core/ir_printer.cc b/paddle/pir/core/ir_printer.cc index 81cb3b4bcf2244..d5d4819256bf64 100644 --- a/paddle/pir/core/ir_printer.cc +++ b/paddle/pir/core/ir_printer.cc @@ -110,6 +110,8 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { os << "(Int32)" << i.data(); } else if (auto i = attr.dyn_cast()) { os << "(Int64)" << i.data(); + } else if (auto i = attr.dyn_cast()) { + os << "(Index)" << i.data(); } else if (auto p = attr.dyn_cast()) { os << "(Pointer)" << p.data(); } else if (auto arr = attr.dyn_cast()) { diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index f0710ff5ec6297..217a34a6315369 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -141,6 +141,7 @@ class Op : public OpBase { using InterfaceList = typename Filter>::Type; + // TODO(zhangbopd): Use classof static ConcreteOp dyn_cast(Operation *op) { if (op && op->info().id() == TypeId::get()) { return ConcreteOp(op); diff --git a/paddle/pir/core/op_result.h b/paddle/pir/core/op_result.h index 8860473fe33395..c6639eff442750 100644 --- a/paddle/pir/core/op_result.h +++ b/paddle/pir/core/op_result.h @@ -30,6 +30,7 @@ class IR_API OpResult : public Value { public: OpResult(std::nullptr_t ptr = nullptr) : Value(ptr){}; // NOLINT Operation *owner() const; + // Return the result index of this op result. uint32_t index() const; bool operator==(const OpResult &other) const; diff --git a/paddle/pir/core/op_trait.cc b/paddle/pir/core/op_trait.cc index ccea4e3f06d9b9..94d800e2944f2a 100644 --- a/paddle/pir/core/op_trait.cc +++ b/paddle/pir/core/op_trait.cc @@ -16,9 +16,9 @@ #include "paddle/pir/core/enforce.h" #include "paddle/pir/core/type_util.h" -namespace pir::detail { +namespace { -void VerifySameOperandsShapeTrait(Operation *op) { +void VerifySameOperandsShapeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsShapeTrait for : " << op->name(); IR_ENFORCE(op->num_operands() > 0, @@ -39,7 +39,7 @@ void VerifySameOperandsShapeTrait(Operation *op) { op->name()); } -void VerifySameOperandsAndResultShapeTrait(Operation *op) { +void VerifySameOperandsAndResultShapeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsAndResultShapeTrait for : " << op->name(); IR_ENFORCE(op->num_operands() > 0, @@ -73,7 +73,7 @@ void VerifySameOperandsAndResultShapeTrait(Operation *op) { op->name()); } -void VerifySameOperandsElementTypeTrait(Operation *op) { +void VerifySameOperandsElementTypeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsElementTypeTrait for : " << op->name(); IR_ENFORCE(op->num_operands() > 0, @@ -91,7 +91,7 @@ void VerifySameOperandsElementTypeTrait(Operation *op) { } } -void VerifySameOperandsAndResultElementTypeTrait(Operation *op) { +void VerifySameOperandsAndResultElementTypeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsAndResultElementTypeTrait for : " << op->name(); @@ -126,7 +126,7 @@ void VerifySameOperandsAndResultElementTypeTrait(Operation *op) { } } -void VerifySameOperandsAndResultTypeTrait(Operation *op) { +void VerifySameOperandsAndResultTypeTrait(pir::Operation *op) { VLOG(4) << "Verify SameOperandsAndResultTypeTrait for : " << op->name(); IR_ENFORCE(op->num_operands() > 0, @@ -169,7 +169,7 @@ void VerifySameOperandsAndResultTypeTrait(Operation *op) { } } -void VerifySameTypeOperandsTrait(Operation *op) { +void VerifySameTypeOperandsTrait(pir::Operation *op) { VLOG(4) << "Verify SameTypeOperandsTrait for : " << op->name(); // For zero or only one operand. @@ -186,7 +186,40 @@ void VerifySameTypeOperandsTrait(Operation *op) { } } -} // namespace pir::detail +void VerifyOneResultTrait(pir::Operation *op) { + IR_ENFORCE(op->num_results() == 1, + "Op %s with OneResultTrait requires 1 result, but got %u results.", + op->name(), + op->num_results()); +} +} // namespace + +namespace pir { +void SameOperandsShapeTrait::Verify(Operation *op) { + return VerifySameOperandsShapeTrait(op); +} + +void SameOperandsAndResultShapeTrait::Verify(Operation *op) { + return VerifySameOperandsAndResultShapeTrait(op); +} + +void SameOperandsElementTypeTrait::Verify(Operation *op) { + return VerifySameOperandsElementTypeTrait(op); +} + +void SameOperandsAndResultElementTypeTrait::Verify(Operation *op) { + return VerifySameOperandsAndResultElementTypeTrait(op); +} + +void SameOperandsAndResultTypeTrait::Verify(Operation *op) { + return VerifySameOperandsAndResultTypeTrait(op); +} +void SameTypeOperandsTrait::Verify(Operation *op) { + return VerifySameTypeOperandsTrait(op); +} + +void OneResultTrait::Verify(Operation *op) { return VerifyOneResultTrait(op); } +} // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsShapeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultShapeTrait) @@ -194,3 +227,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsElementTypeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::OneResultTrait) diff --git a/paddle/pir/core/op_trait.h b/paddle/pir/core/op_trait.h index 760799fd16165d..b55352c1647657 100644 --- a/paddle/pir/core/op_trait.h +++ b/paddle/pir/core/op_trait.h @@ -18,15 +18,6 @@ namespace pir { -namespace detail { -void VerifySameOperandsShapeTrait(Operation *op); -void VerifySameOperandsAndResultShapeTrait(Operation *op); -void VerifySameOperandsElementTypeTrait(Operation *op); -void VerifySameOperandsAndResultElementTypeTrait(Operation *op); -void VerifySameOperandsAndResultTypeTrait(Operation *op); -void VerifySameTypeOperandsTrait(Operation *op); -} // namespace detail - /// /// \brief Provides verification for ops that are known to have the /// same operand shape. @@ -35,9 +26,7 @@ class SameOperandsShapeTrait : public pir::OpTraitBase { public: explicit SameOperandsShapeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsShapeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -49,9 +38,7 @@ class SameOperandsAndResultShapeTrait public: explicit SameOperandsAndResultShapeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsAndResultShapeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -63,9 +50,7 @@ class SameOperandsElementTypeTrait public: explicit SameOperandsElementTypeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsElementTypeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -77,9 +62,7 @@ class SameOperandsAndResultElementTypeTrait public: explicit SameOperandsAndResultElementTypeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsAndResultElementTypeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -93,9 +76,7 @@ class SameOperandsAndResultTypeTrait explicit SameOperandsAndResultTypeTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameOperandsAndResultTypeTrait(op); - } + static void Verify(Operation *op); }; /// @@ -106,9 +87,26 @@ class SameTypeOperandsTrait : public pir::OpTraitBase { public: explicit SameTypeOperandsTrait(pir::Operation *op) : pir::OpTraitBase(op) {} - static void Verify(Operation *op) { - return detail::VerifySameTypeOperandsTrait(op); + static void Verify(Operation *op); +}; + +/// +/// \brief This trait provides return value APIs for ops that are known to have +/// a single result returned by GetType(). +/// +class OneResultTrait : public OpTraitBase { + public: + // Replace all uses of 'this' value with the new value, updating anything + // in the IR that uses 'this' to use the other value instead. + void ReplaceAllUsesWith(Value new_value) { + this->operation()->result(0).ReplaceAllUsesWith(new_value); + } + + // Replace all uses of 'this' value with the result of 'op'. + void ReplaceAllUsesWith(Operation *op) { + this->operation()->ReplaceAllUsesWith(op->result(0)); } + static void Verify(Operation *op); }; } // namespace pir @@ -119,3 +117,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsElementTypeTrait) IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait) IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait) IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(pir::OneResultTrait) diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index 0dedeafc9ae710..1a6666fcc2a9b3 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -123,7 +123,12 @@ Operation *Operation::Create(const std::vector &inputs, // 0. Verify if (op_info) { - op_info.VerifySig(op); + try { + op_info.VerifySig(op); + } catch (const pir::IrNotMetException &e) { + op->Destroy(); + throw e; + } } return op; } @@ -292,6 +297,21 @@ std::string Operation::name() const { auto p_name = info_.name(); return p_name ? p_name : ""; } + +void Operation::Erase() { + if (auto *parent = GetParent()) + parent->erase(*this); + else + Destroy(); +} + +bool Operation::use_empty() { + auto res = results(); + return std::all_of(res.begin(), res.end(), [](OpResult result) { + return result.use_empty(); + }); +} + void Operation::ReplaceAllUsesWith(const std::vector &values) { IR_ENFORCE(num_results_ == values.size(), "the num of result should be the same."); diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index d45fc368d28046..57523921e911b8 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -125,6 +125,16 @@ class IR_API alignas(8) Operation final { pir::OpInfo info() const { return info_; } std::string name() const; + /// + /// \brief Remove this operation from its parent block and delete it. + /// + void Erase(); + + /// + /// \brief Returns true if this operation has no uses. + /// + bool use_empty(); + template T dyn_cast() { return CastUtil::call(this); diff --git a/paddle/pir/core/type.cc b/paddle/pir/core/type.cc index 91933019fb8359..a200a07325bc0f 100644 --- a/paddle/pir/core/type.cc +++ b/paddle/pir/core/type.cc @@ -31,4 +31,6 @@ bool Type::IsIntOrIndex() const { isa() || isa() || isa(); } +bool Type::IsIndex() const { return isa(); } + } // namespace pir diff --git a/paddle/pir/core/type.h b/paddle/pir/core/type.h index c1b2f155e8d5a4..b48da12c12b31d 100644 --- a/paddle/pir/core/type.h +++ b/paddle/pir/core/type.h @@ -120,6 +120,7 @@ class IR_API Type { /// type. /// bool IsIntOrIndex() const; + bool IsIndex() const; protected: const Storage *storage_{nullptr}; diff --git a/paddle/pir/core/type_base.cc b/paddle/pir/core/type_base.cc index aec0d93d9fa69c..3676d4099be814 100644 --- a/paddle/pir/core/type_base.cc +++ b/paddle/pir/core/type_base.cc @@ -30,7 +30,7 @@ void *AbstractType::GetInterfaceImpl(TypeId interface_id) const { VLOG(6) << "Find no interface!"; return nullptr; } - // TODO(zhangbo63): Add LookUp method like: + // TODO(zhangbopd): Add LookUp method like: // return ir::detail::LookUp( // interface_id, num_interfaces_, num_traits_, this); } diff --git a/paddle/pir/core/value.h b/paddle/pir/core/value.h index 96787b973b81ad..1ec07f2eb7a929 100644 --- a/paddle/pir/core/value.h +++ b/paddle/pir/core/value.h @@ -93,7 +93,6 @@ class IR_API Value { protected: detail::ValueImpl *impl_{nullptr}; }; - } // namespace pir namespace std { diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index 4367670156efcd..0353a7610d2b38 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -15,20 +15,24 @@ #include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" -namespace pir { -namespace dialect { +namespace pir::shape { ShapeDialect::ShapeDialect(IrContext *context) : Dialect(name(), context, TypeId::get()) { initialize(); } void ShapeDialect::initialize() { - RegisterOps(); + TensorDimOp, + ShapeOfOp, + FromElementsOp, + ExtractOp, + ConstantOp, + IndexCastOp>(); } void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const { @@ -39,7 +43,6 @@ void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const { } } -} // namespace dialect -} // namespace pir +} // namespace pir::shape -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::ShapeDialect) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ShapeDialect) diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.h b/paddle/pir/dialect/shape/ir/shape_dialect.h index b8fe39bd8d500f..4be71aa0127ce7 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.h +++ b/paddle/pir/dialect/shape/ir/shape_dialect.h @@ -16,8 +16,7 @@ #include "paddle/pir/core/dialect.h" -namespace pir { -namespace dialect { +namespace pir::shape { /// /// \brief Shape Dialect: /// @@ -32,7 +31,6 @@ class IR_API ShapeDialect : public Dialect { void initialize(); }; -} // namespace dialect -} // namespace pir +} // namespace pir::shape -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::ShapeDialect) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ShapeDialect) diff --git a/paddle/pir/dialect/shape/ir/shape_op.cc b/paddle/pir/dialect/shape/ir/shape_op.cc index 885f50d080143e..d7acec75c08971 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.cc +++ b/paddle/pir/dialect/shape/ir/shape_op.cc @@ -18,9 +18,9 @@ #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/enforce.h" -namespace pir::dialect { +namespace pir::shape { -const char *SymbolicDim::attributes_name[attributes_num] = { +const char *SymbolicDimOp::attributes_name[attributes_num] = { "known_negative_one", // value = -1 "known_non_negative", // value >= 0 "known_non_size_one", // value != 1 @@ -28,14 +28,14 @@ const char *SymbolicDim::attributes_name[attributes_num] = { "sym_name", "value"}; // NOLINT -void SymbolicDim::Build(Builder &builder, - OperationArgument &argument, - const std::string &sym_name, - int64_t value, - bool known_non_negative, - bool known_negative_one, - bool known_non_size_one, - bool known_non_size_zero) { +void SymbolicDimOp::Build(Builder &builder, + OperationArgument &argument, + const std::string &sym_name, + int64_t value, + bool known_non_negative, + bool known_negative_one, + bool known_non_size_one, + bool known_non_size_zero) { IrContext *ctx = IrContext::Instance(); auto attr_sym_name = StrAttribute::get(ctx, sym_name); auto attr_value = Int64Attribute::get(ctx, value); @@ -52,57 +52,66 @@ void SymbolicDim::Build(Builder &builder, argument.AddAttribute("known_non_size_zero", attr_known_non_size_zero); } -const std::string SymbolicDim::GetSymName() { +const std::string SymbolicDimOp::GetSymName() { return attribute("sym_name").AsString(); } -int64_t SymbolicDim::GetDimSize() { + +int64_t SymbolicDimOp::GetDimSize() { return attribute("value").data(); } -bool SymbolicDim::GetKnownNonNegative() { + +bool SymbolicDimOp::GetKnownNonNegative() { return attribute("known_non_negative").data(); } -bool SymbolicDim::GetKnownNegativeOne() { + +bool SymbolicDimOp::GetKnownNegativeOne() { return attribute("known_negative_one").data(); } -bool SymbolicDim::GetKnownNonSizeOne() { + +bool SymbolicDimOp::GetKnownNonSizeOne() { return attribute("known_non_size_one").data(); } -bool SymbolicDim::GetKnownNonSizeZero() { + +bool SymbolicDimOp::GetKnownNonSizeZero() { return attribute("known_non_size_zero").data(); } -void SymbolicDim::SetSymName(const std::string &attr_value) { +void SymbolicDimOp::SetSymName(const std::string &attr_value) { operation()->set_attribute( "sym_name", StrAttribute::get(IrContext::Instance(), attr_value)); } -void SymbolicDim::SetDimSize(int64_t attr_value) { + +void SymbolicDimOp::SetDimSize(int64_t attr_value) { operation()->set_attribute( "value", Int64Attribute::get(IrContext::Instance(), attr_value)); } -void SymbolicDim::UpdateKnownNonNegative(bool flag) { +void SymbolicDimOp::UpdateKnownNonNegative(bool flag) { operation()->set_attribute("known_non_negative", BoolAttribute::get(IrContext::Instance(), flag)); } -void SymbolicDim::UpdateKnownNegativeOne(bool flag) { + +void SymbolicDimOp::UpdateKnownNegativeOne(bool flag) { operation()->set_attribute("known_negative_one", BoolAttribute::get(IrContext::Instance(), flag)); } -void SymbolicDim::UpdateKnownNonSizeOne(bool flag) { + +void SymbolicDimOp::UpdateKnownNonSizeOne(bool flag) { operation()->set_attribute("known_non_size_one", BoolAttribute::get(IrContext::Instance(), flag)); } -void SymbolicDim::UpdateKnownNonSizeZero(bool flag) { + +void SymbolicDimOp::UpdateKnownNonSizeZero(bool flag) { operation()->set_attribute("known_non_size_zero", BoolAttribute::get(IrContext::Instance(), flag)); } -bool SymbolicDim::IsDynamic() { +bool SymbolicDimOp::IsDynamic() { return GetDimSize() == ShapedTypeInterface::kDynamic; } -bool SymbolicDim::Merge(SymbolicDim other) { - VLOG(4) << "Try to merge two SymbolicDim ops."; +bool SymbolicDimOp::Merge(SymbolicDimOp other) { + VLOG(4) << "Try to merge two SymbolicDimOp."; if (!IsDynamic() && !other.IsDynamic() && GetDimSize() != other.GetDimSize()) return false; @@ -145,11 +154,11 @@ void DimOp::Build(Builder &builder, argument.output_types.emplace_back(IndexType::get(IrContext::Instance())); } -const std::string DimOp::getName() { +const std::string DimOp::GetName() { return attribute("name").AsString(); } -void DimOp::setName(std::string attrName) { +void DimOp::SetName(std::string attrName) { operation()->set_attribute( "name", StrAttribute::get(IrContext::Instance(), attrName)); } @@ -192,6 +201,7 @@ std::vector TieProductEqualOp::lhs() { } return res; } + std::vector TieProductEqualOp::rhs() { int64_t lhs_len = attribute("lhs_len").data(); int64_t rhs_len = attribute("rhs_len").data(); @@ -203,13 +213,14 @@ std::vector TieProductEqualOp::rhs() { } const char *TieShapeOp::attributes_name[attributes_num] = { - SymbolicDim::GetSymbolicDimAttrName().c_str()}; // NOLINT + SymbolicDimOp::GetSymbolicDimAttrName().c_str()}; // NOLINT void TieShapeOp::Build(Builder &builder, OperationArgument &argument, Value input) { argument.AddInput(input); } + void TieShapeOp::Build(Builder &builder, // NOLINT OperationArgument &argument, // NOLINT Value input, @@ -218,8 +229,6 @@ void TieShapeOp::Build(Builder &builder, // NOLINT argument.AddInputs(dims); } -Value TieShapeOp::value() { return operand_source(0); } - std::vector TieShapeOp::dims() { std::vector res; for (uint32_t i = 1; i < num_operands(); i++) { @@ -261,23 +270,82 @@ void TensorDimOp::Build(Builder &builder, OperationArgument &argument, Value source, int64_t index) { - OpResult indexValue = + OpResult index_value = builder .Build(Int64Attribute::get(IrContext::Instance(), index), IndexType::get(IrContext::Instance())) ->result(0); - argument.AddInputs({source, indexValue}); + argument.AddInputs({source, index_value}); argument.output_types.emplace_back(IndexType::get(IrContext::Instance())); } -Value TensorDimOp::source() { return operand_source(0); } +std::optional TensorDimOp::GetConstantIndex() { + auto op = index().dyn_cast().owner(); + int64_t index = + op->dyn_cast().value().dyn_cast().data(); + return index; +} + +void ShapeOfOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value input) { + argument.AddInput(input); +} + +void FromElementsOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::vector &elements) { + argument.AddInputs(elements); +} -Value TensorDimOp::index() { return operand_source(1); } -} // namespace pir::dialect +std::vector FromElementsOp::elements() { + std::vector elements; + for (uint32_t idx = 0; idx < num_operands(); idx++) { + elements.push_back(operand_source(static_cast(idx))); + } + return elements; +} + +void ExtractOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value tensor, + std::vector indices) { + argument.AddInput(tensor); + argument.AddInputs(indices); +} + +std::vector ExtractOp::indices() { + std::vector indices; + for (uint32_t idx = 1; idx < num_operands(); idx++) { + indices.push_back(operand_source(static_cast(idx))); + } + return indices; +} + +void ConstantIndexOp::Build(Builder &builder, + OperationArgument &argument, + int64_t value) { + ConstantOp::Build( + builder, argument, builder.index_attr(value), builder.index_type()); +} + +void IndexCastOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Type out, + Value in) { + argument.AddInput(in); +} -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::SymbolicDim) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::DimOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp) +} // namespace pir::shape + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::SymbolicDimOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::DimOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::TieProductEqualOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::TieShapeOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::FuncOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::TensorDimOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ShapeOfOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::FromElementsOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ExtractOp); +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ConstantIndexOp); +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::IndexCastOp); diff --git a/paddle/pir/dialect/shape/ir/shape_op.h b/paddle/pir/dialect/shape/ir/shape_op.h index c838624d2566df..31e35f376c55fc 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.h +++ b/paddle/pir/dialect/shape/ir/shape_op.h @@ -14,14 +14,16 @@ #pragma once +#include #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/op_base.h" +#include "paddle/pir/core/op_trait.h" -namespace pir::dialect { +namespace pir::shape { -class IR_API SymbolicDim : public Op { +class IR_API SymbolicDimOp : public Op { public: using Op::Op; static const char *name() { return "shape.symbolic_dim"; } @@ -61,11 +63,11 @@ class IR_API SymbolicDim : public Op { // Sets `known_non_size_zero` to the value of `flag` void UpdateKnownNonSizeZero(bool flag); - // Returns true if this SymbolicDim is not known at compile-time. + // Returns true if this SymbolicDimOp is not known at compile-time. bool IsDynamic(); - // Try to merge two SymbolicDim ops. - bool Merge(SymbolicDim other); + // Try to merge two SymbolicDimOp. + bool Merge(SymbolicDimOp other); static const std::string GetSymbolicDimAttrName() { return "kSymbolicDimAttr"; @@ -86,8 +88,8 @@ class IR_API DimOp : public Op { OperationArgument &argument, // NOLINT const std::string &name); - const std::string getName(); - void setName(std::string attrValue); + const std::string GetName(); + void SetName(std::string attrValue); OpResult out() { return result(0); } void VerifySig() {} }; @@ -130,7 +132,7 @@ class IR_API TieShapeOp : public Op { OperationArgument &argument, // NOLINT Value input, const std::vector &dims); - Value value(); + Value input() { return operand_source(0); } std::vector dims(); void VerifySig() {} }; @@ -150,7 +152,7 @@ class IR_API FuncOp : public Op { void VerifySig() {} }; -class IR_API TensorDimOp : public Op { +class IR_API TensorDimOp : public Op { public: using Op::Op; static const char *name() { return "shape.tensor_dim"; } @@ -166,17 +168,106 @@ class IR_API TensorDimOp : public Op { OperationArgument &argument, // NOLINT Value source, int64_t index); - Value index(); - Value source(); + + Value source() { return operand_source(0); } + Value index() { return operand_source(1); } + OpResult out() { return result(0); } + void VerifySig() {} + std::optional GetConstantIndex(); +}; + +class IR_API ShapeOfOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.shape_of"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value input); + + Value input() { return operand_source(0); } + OpResult out() { return result(0); } + void VerifySig() {} +}; + +class IR_API FromElementsOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.from_elements"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::vector &elements); + + std::vector elements(); + OpResult out() { return result(0); } + void VerifySig() {} +}; + +class IR_API ExtractOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.extract"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value tensor, + std::vector indices); + + Value tensor() { return operand_source(0); } + std::vector indices(); OpResult out() { return result(0); } void VerifySig() {} }; -} // namespace pir::dialect +// Specialization of `constant` op that returns an integer of index type. +class IR_API ConstantIndexOp : public ConstantOp { + public: + using ConstantOp::ConstantOp; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + int64_t value); +}; + +// Cast between index and integer types. +class IR_API IndexCastOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.index_cast"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Type out, + Value in); + + Value in() { return operand_source(0); } + OpResult out() { return result(0); } + void VerifySig() {} +}; -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::SymbolicDim); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::DimOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp); +} // namespace pir::shape + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::SymbolicDimOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::DimOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::TieProductEqualOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::TieShapeOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::FuncOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::TensorDimOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ShapeOfOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::FromElementsOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ExtractOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ConstantIndexOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::IndexCastOp); diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization.cc b/paddle/pir/dialect/shape/transforms/shape_optimization.cc index 54f43c74cb4154..df21e6112a7a3d 100644 --- a/paddle/pir/dialect/shape/transforms/shape_optimization.cc +++ b/paddle/pir/dialect/shape/transforms/shape_optimization.cc @@ -13,41 +13,41 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/dialect/shape/ir/shape_op.h" - #include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/infer_type_op_interface.h" #include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" #include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace pir { namespace { -using PassPipelineRunner = - std::function; bool InsertTieShapeOnValue(pir::Value value, pir::Builder& builder) { // NOLINT - auto ty = value.type().dyn_cast(); + auto type = value.type().dyn_cast(); - if (!ty || ty.dims().size() == 0) return true; - std::vector dimSizes; - for (int64_t dim = 0, rank = ty.dims().size(); dim < rank; ++dim) { - auto dimOp = builder.Build(value, dim); - dimSizes.push_back(dimOp.out()); + if (!type || type.dims().size() == 0) return true; + std::vector dim_sizes; + for (int64_t dim = 0, rank = type.dims().size(); dim < rank; ++dim) { + auto dim_op = builder.Build(value, dim); + dim_sizes.push_back(dim_op.out()); } - builder.Build(value, dimSizes); + builder.Build(value, dim_sizes); return true; } +// Forward declaration bool InsertTieShapeOnRegion(pir::Region* region); bool InsertTieShapeOnOperation(pir::Operation* op, pir::Builder& builder) { // NOLINT - // TODO(zhangbo63): skip more specialized Ops. - if (op->isa() || op->isa()) - return true; + // TODO(zhangbopd): skip more specialized Ops. + if (op->isa() || op->isa()) return true; for (size_t i = 0; i < op->num_regions(); ++i) { if (!InsertTieShapeOnRegion(&(op->region(i)))) return false; @@ -63,7 +63,7 @@ bool InsertTieShapeOnOperation(pir::Operation* op, bool InsertTieShapeOnBlock(pir::Block* block) { pir::Builder builder = pir::Builder(pir::IrContext::Instance(), block, block->begin()); - // TODO(liujinnan): mapping block arguments + // TODO(zhangbopd): mapping block arguments std::vector op_list; for (pir::Operation* op : *block) op_list.push_back(op); @@ -74,18 +74,108 @@ bool InsertTieShapeOnBlock(pir::Block* block) { } bool InsertTieShapeOnRegion(pir::Region* region) { - for (pir::Block* block : *region) { + for (Block* block : *region) { if (!InsertTieShapeOnBlock(block)) return false; } return true; } +struct ExpandShapeOfOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(shape::ShapeOfOp op, + PatternRewriter& rewriter) const override { + // TODO(zhangbopd): Uncomment + // auto type = op.out().type().dyn_cast(); + + // if (!type || !type.dyn_cast().HasStaticShape() || + // !type.dyn_cast().GetElementType().IsIndex()) + // return false; + + // std::vector dim_sizes; + // for (int dim = 0, rank = + // type.dyn_cast().GetShape()[0]; + // dim < rank; + // ++dim) { + // dim_sizes.push_back( + // rewriter.Build(op.input(), dim).out()); + // } + // rewriter.ReplaceOpWithNewOp(op, dim_sizes); + return true; + } +}; + +// Fold dim of an operation that implements the InferShapedTypeOpInterface +template +struct DimOfShapedTypeOpInterfacePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(OpTy dim_op, PatternRewriter& rewriter) const override { + OpResult dim_value = dim_op.source().template dyn_cast(); + if (!dim_value) return false; + + auto shaped_type_op = + dim_value.owner()->dyn_cast(); + + if (!shaped_type_op) return false; + // TODO(zhangbopd): Uncomment + // std::optional dim_index = dim_op.GetConstantIndex(); + // if (!dim_index) return false; + + // std::vector reified_result_shapes; + // if (!shaped_type_op.ReifyReturnTypeShapes( + // rewriter, shaped_type_op->operands(), reified_result_shapes)) + // return false; + + // if (reified_result_shapes.size() != shaped_type_op->num_results()) + // return false; + + // Value result_shape = reified_result_shapes[dim_value.index()]; + // auto result_shape_type = result_shape.type().dyn_cast(); + // auto shaped_type = result_shape_type.dyn_cast(); + // if (!result_shape_type || !shaped_type.GetElementType().IsIntOrIndex()) + // return false; + + // // TODO(zhangbopd): BuildOrFold required. + // std::vector indices; + // indices.push_back(rewriter.Build(*dim_index).out()); + // Value new_value = + // rewriter.Build(result_shape, indices).out(); + + // if (!new_value.type().isa()) + // new_value = + // rewriter.Build(rewriter.index_type(), + // new_value) + // .out(); + + // rewriter.ReplaceOp(dim_op, {new_value}); + return true; + } +}; + bool MaterializeShapeComputation(pir::ModuleOp m) { if (!InsertTieShapeOnRegion(&(m->region(0)))) return false; - // TODO(liujinnan): add rewitter pattern for reifyInferShape. + // TODO(zhangbopd): add rewitter pattern for reifyInferShape. + RewritePatternSet patterns(m.ir_context()); + + patterns.Add>( + patterns.ir_context()); + + IR_ENFORCE(ApplyPatternsGreedily(m, std::move(patterns)), + "fail to materialize shape computation\n"); return true; } +using PassPipelineRunner = + std::function; + +// Returns true if the type is possible to be a shape tensor type. +// Shape tensor type : +// - rank-1 static-shaped tensor type +// - element type of the tensor is int or index +// - number of elements of the tensor < 32, supposing that the +// higiest possible rank is smaller than 32. bool IsCandidateShapeTensorType(Type type) { auto tensor_type = type.dyn_cast(); auto shaped_type = tensor_type.dyn_cast(); @@ -119,21 +209,16 @@ class ShapeComputationIRAnalysis { ModuleOp m_; SymbolicDimMgr& mgr_; - std::unordered_map value_to_sym_dim_; + std::unordered_map value_to_sym_dim_; // shape tensor is the 1D ranked tensor with int/index dtype. - std::unordered_map> shape_tensor_to_sym_dims_; + std::unordered_map> + shape_tensor_to_sym_dims_; - std::unordered_map> dense_tensor_to_sym_dims_; + std::unordered_map> + dense_tensor_to_sym_dims_; }; -// Returns true if the type is possible to be a shape tensor type. -// Shape tensor type : -// - rank-1 static-shaped tensor type -// - element type of the tensor is int or index -// - number of elements of the tensor < 32, supposing that the -// higiest possible rank is smaller than 32. - ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m, SymbolicDimMgr& mgr) : m_(m), mgr_(mgr) {} @@ -163,7 +248,7 @@ bool ShapeComputationIRAnalysis::RunOnRegion(Region* region, func fn) { } bool ShapeComputationIRAnalysis::RunOnBlock(Block* block, func fn) { - // TODO(liujinnan): mapping block arguments + // TODO(zhangbopd): mapping block arguments std::vector op_list; for (Operation* op : *block) op_list.push_back(op); @@ -181,37 +266,37 @@ bool ShapeComputationIRAnalysis::RunOnOperation(Operation* op, func fn) { } bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { - if (op->isa()) return true; - if (op->isa()) { + if (op->isa()) return true; + if (op->isa()) { Value value = op->operand_source(0); - std::vector symbols; - if (op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) { + std::vector symbols; + if (op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) { auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()) + op->attribute(SymbolicDimOp::GetSymbolicDimAttrName()) .AsVector(); for (Attribute attr : attrs) { - auto sym = mgr_.symbolTable().Lookup( + auto sym = mgr_.symbolTable().Lookup( attr.dyn_cast().AsString()); - assert(sym); - SymbolicDim root = mgr_.GetRootSymbolicDim(sym); + IR_ENFORCE(sym); + SymbolicDimOp root = mgr_.GetRootSymbolicDim(sym); symbols.push_back(root); } } else { symbols = mgr_.CreateSymbolicDimsForRankedValue(value); std::vector attrs; - for (SymbolicDim sym : symbols) { + for (SymbolicDimOp sym : symbols) { Attribute rootSymbol = StrAttribute::get(m_->ir_context(), sym.GetSymName()); attrs.push_back(rootSymbol); } - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + op->set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), ArrayAttribute::get(m_->ir_context(), attrs)); } dense_tensor_to_sym_dims_[value] = std::move(symbols); return true; } - for (size_t i = 0; i < op->num_results(); ++i) { - if (!BuildShapeOnValue(op->result(i))) return false; + for (auto& result : op->results()) { + if (!BuildShapeOnValue(result)) return false; } return true; } @@ -219,11 +304,11 @@ bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) { Type type = value.type(); if (type.IsIntOrIndex()) { - SymbolicDim sym = mgr_.NewSymbolicDim(); + SymbolicDimOp sym = mgr_.NewSymbolicDim(); value_to_sym_dim_[value] = sym; } else if (IsCandidateShapeTensorType(type)) { auto shaped_type = type.dyn_cast(); - std::vector symbols; + std::vector symbols; for (size_t i = 0, d = shaped_type.GetShape()[0]; i < d; ++i) symbols.push_back(mgr_.NewSymbolicDim()); shape_tensor_to_sym_dims_[value] = std::move(symbols); @@ -237,7 +322,7 @@ bool ShapeComputationIRAnalysis::ApplyOpConstraint(Operation* op) { IR_ENFORCE(ApplyTieShapeOpConstraint(op), "Fail to apply constraint for tie_shape op"); - // TODO(zhangbo63): add more constraints + // TODO(zhangbopd): add more constraints return true; } @@ -247,7 +332,7 @@ bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { Type type = op->result(0).type(); if (!type.IsIntOrIndex()) return true; - if (auto dim_op = op->dyn_cast()) { + if (auto dim_op = op->dyn_cast()) { int64_t dim_index = dim_op.index() .dyn_cast() .owner() @@ -267,12 +352,12 @@ bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { return false; } } - // TODO(zhangbo63): add support for reifyInferShape. (e.g. mul/add) + // TODO(zhangbopd): add support for reifyInferShape. (e.g. mul/add) return true; } bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { - if (auto tie_shape = op->dyn_cast()) { + if (auto tie_shape = op->dyn_cast()) { auto& value = dense_tensor_to_sym_dims_[op->operand_source(0)]; for (size_t idx = 0; idx < tie_shape.dims().size(); ++idx) { if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[tie_shape.dims()[idx]], @@ -285,7 +370,7 @@ bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { } bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { - // TODO(liujinnan): Do some Canonicalizer. + // TODO(zhangbopd): Do some Canonicalizer. pir::SymbolicDimMgr mgr(m); IR_ENFORCE(mgr.Load(), "SymbolicDimMgr Load failed in OptimizeShapeComputation."); diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc index 07f7cf4129a4d9..6954858bc8956b 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc @@ -50,22 +50,22 @@ bool CompareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) { for (auto op : *(m.block())) { - if (op->isa()) { + if (op->isa()) { symbol_table_ = SymbolTable(op); return; } } Builder builder = Builder(m_.ir_context(), m_.block(), m_.block()->begin()); - dialect::FuncOp func = builder.Build(); + shape::FuncOp func = builder.Build(); symbol_table_ = SymbolTable(func); } bool SymbolicDimMgr::Load() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); + auto func_op = symbol_table_.getOp()->dyn_cast(); + IR_ENFORCE(func_op); for (auto op : *(func_op.block())) { symbol_table_.insert(op); - if (SymbolicDim sym_dim_op = op->dyn_cast()) { + if (SymbolicDimOp sym_dim_op = op->dyn_cast()) { symbol_dim_union_set_[sym_dim_op] = sym_dim_op; symbol_name_set_.insert(sym_dim_op.GetSymName()); } @@ -74,10 +74,10 @@ bool SymbolicDimMgr::Load() { } bool SymbolicDimMgr::LoadShapeConstraintGraph() { - // TODO(liujinnan): add more constraint function. currently, only support + // TODO(zhangbopd): add more constraint function. currently, only support // tie_product_equal. auto constraint_vec = - symbol_table_.Lookup("tie_product_equal"); + symbol_table_.Lookup("tie_product_equal"); if (!constraint_vec.size()) return true; @@ -88,8 +88,8 @@ bool SymbolicDimMgr::LoadShapeConstraintGraph() { if (auto constOp = defining_op->dyn_cast()) { product.factor *= constOp.value().dyn_cast().data(); continue; - } else if (auto dimOp = defining_op->dyn_cast()) { - auto sym = symbol_table_.Lookup(dimOp.getName()); + } else if (auto dim_op = defining_op->dyn_cast()) { + auto sym = symbol_table_.Lookup(dim_op.GetName()); if (!sym) return false; product.symbols.push_back(sym); continue; @@ -139,17 +139,17 @@ bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, SymbolicDimProduct SymbolicDimMgr::SimplifySymbolicDimProduct( const SymbolicDimProduct& x) { - std::vector copied; + std::vector copied; copied.reserve(x.symbols.size()); - for (SymbolicDim op : x.symbols) copied.push_back(GetRootSymbolicDim(op)); + for (SymbolicDimOp op : x.symbols) copied.push_back(GetRootSymbolicDim(op)); std::sort( - copied.begin(), copied.end(), [&](SymbolicDim lhs, SymbolicDim rhs) { + copied.begin(), copied.end(), [&](SymbolicDimOp lhs, SymbolicDimOp rhs) { return CompareSymbolicDimNames(lhs.GetSymName(), rhs.GetSymName()); }); SymbolicDimProduct new_x; new_x.factor = x.factor; - for (SymbolicDim op : copied) { + for (SymbolicDimOp op : copied) { if (!op.IsDynamic()) { new_x.factor *= op.GetDimSize(); } else { @@ -186,13 +186,13 @@ SymbolicDimMgr::SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, new_lhs.factor = lhs.factor / gcd_factor; new_rhs.factor = rhs.factor / gcd_factor; - std::unordered_map lhs_symbol_map; - std::unordered_map rhs_symbol_map; + std::unordered_map lhs_symbol_map; + std::unordered_map rhs_symbol_map; - for (SymbolicDim op : lhs.symbols) ++lhs_symbol_map[op]; - for (SymbolicDim op : rhs.symbols) ++rhs_symbol_map[op]; + for (SymbolicDimOp op : lhs.symbols) ++lhs_symbol_map[op]; + for (SymbolicDimOp op : rhs.symbols) ++rhs_symbol_map[op]; - for (SymbolicDim op : lhs.symbols) { + for (SymbolicDimOp op : lhs.symbols) { auto it = rhs_symbol_map.find(op); if (it != rhs_symbol_map.end() && op.GetKnownNonSizeZero()) { if (--it->second == 0) rhs_symbol_map.erase(it); @@ -201,7 +201,7 @@ SymbolicDimMgr::SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, new_lhs.symbols.push_back(op); } - for (SymbolicDim op : rhs.symbols) { + for (SymbolicDimOp op : rhs.symbols) { auto it = lhs_symbol_map.find(op); if (it != lhs_symbol_map.end() && op.GetKnownNonSizeZero()) { if (--it->second == 0) lhs_symbol_map.erase(it); @@ -224,24 +224,24 @@ const std::string SymbolicDimMgr::GetNextName() { return name; } -SymbolicDim SymbolicDimMgr::NewSymbolicDim(const std::string& name) { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); +SymbolicDimOp SymbolicDimMgr::NewSymbolicDim(const std::string& name) { + auto func_op = symbol_table_.getOp()->dyn_cast(); + IR_ENFORCE(func_op); Builder builder = Builder(m_.ir_context(), func_op.block()); // default settting dim != 0 - dialect::SymbolicDim symbol = - builder.Build(name.empty() ? GetNextName() : name, - ShapedTypeInterface::kDynamic, - false, - false, - false, - true); + SymbolicDimOp symbol = + builder.Build(name.empty() ? GetNextName() : name, + ShapedTypeInterface::kDynamic, + false, + false, + false, + true); symbol_dim_union_set_[symbol] = symbol; symbol_table_.insert(symbol); return symbol; } -SymbolicDim SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { +SymbolicDimOp SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { auto it = constant_symbolic_dim_map_.find(val); if (it == constant_symbolic_dim_map_.end()) { auto name = "C" + std::to_string(val); @@ -257,9 +257,9 @@ SymbolicDim SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { return GetRootSymbolicDim(it->second); } -std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( +std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( Value value) { - std::vector symbols; + std::vector symbols; auto dims = value.type().dyn_cast().dims(); for (int idx = 0; idx < dims.size(); ++idx) { symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic @@ -269,26 +269,26 @@ std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( return symbols; } -SymbolicDim SymbolicDimMgr::GetRootSymbolicDim(SymbolicDim symbol) { - SymbolicDim current = symbol; - std::vector path; +SymbolicDimOp SymbolicDimMgr::GetRootSymbolicDim(SymbolicDimOp symbol) { + SymbolicDimOp current = symbol; + std::vector path; while (symbol_dim_union_set_[current] != current) { path.push_back(current); current = symbol_dim_union_set_[current]; } - for (SymbolicDim sym : path) symbol_dim_union_set_[sym] = current; + for (SymbolicDimOp sym : path) symbol_dim_union_set_[sym] = current; return current; } -bool SymbolicDimMgr::IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhs_root = GetRootSymbolicDim(lhs); - SymbolicDim rhs_root = GetRootSymbolicDim(rhs); +bool SymbolicDimMgr::IsSymbolicDimEqual(SymbolicDimOp lhs, SymbolicDimOp rhs) { + SymbolicDimOp lhs_root = GetRootSymbolicDim(lhs); + SymbolicDimOp rhs_root = GetRootSymbolicDim(rhs); return lhs_root == rhs_root; } -bool SymbolicDimMgr::MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhs_root = GetRootSymbolicDim(lhs); - SymbolicDim rhs_root = GetRootSymbolicDim(rhs); +bool SymbolicDimMgr::MapSymbolicDimEqual(SymbolicDimOp lhs, SymbolicDimOp rhs) { + SymbolicDimOp lhs_root = GetRootSymbolicDim(lhs); + SymbolicDimOp rhs_root = GetRootSymbolicDim(rhs); if (lhs_root != rhs_root) { if (CompareSymbolicDimNames(lhs_root.GetSymName(), rhs_root.GetSymName())) { @@ -315,10 +315,10 @@ SymbolicDimProduct* SymbolicDimMgr::SymbolicDimProductDivide( SymbolicDimProduct* result = new SymbolicDimProduct(); result->factor = new_lhs.factor / new_rhs.factor; - std::unordered_map sym_proc_map; - for (SymbolicDim sym : new_rhs.symbols) ++sym_proc_map[sym]; + std::unordered_map sym_proc_map; + for (SymbolicDimOp sym : new_rhs.symbols) ++sym_proc_map[sym]; - for (SymbolicDim sym : new_lhs.symbols) { + for (SymbolicDimOp sym : new_lhs.symbols) { auto it = sym_proc_map.find(sym); if (it == sym_proc_map.end()) { result->symbols.push_back(sym); @@ -457,13 +457,13 @@ bool SymbolicDimMgr::IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, } bool SymbolicDimMgr::Save() { - using Name2SymbolFn = std::function; + using Name2SymbolFn = std::function; auto update_attrs = [&](ArrayAttribute attrs, Name2SymbolFn fn) { std::vector new_attrs; for (Attribute attr : attrs.AsVector()) { auto sym = fn(attr.dyn_cast().AsString()); - assert(sym); - SymbolicDim root = GetRootSymbolicDim(sym); + IR_ENFORCE(sym); + SymbolicDimOp root = GetRootSymbolicDim(sym); Attribute root_symbol = StrAttribute::get(m_->ir_context(), root.GetSymName()); new_attrs.push_back(root_symbol); @@ -471,41 +471,41 @@ bool SymbolicDimMgr::Save() { return ArrayAttribute::get(m_->ir_context(), new_attrs); }; - // TODO(liujinnan): update attributes attached in DenseTensorType + // TODO(zhangbopd): update attributes attached in DenseTensorType for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + if (!op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + op->attribute(SymbolicDimOp::GetSymbolicDimAttrName()); auto symbolic_shape_attr = update_attrs(attrs, [&](const std::string& name) { - return symbol_table_.Lookup(name); + return symbol_table_.Lookup(name); }); - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + op->set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), symbolic_shape_attr); } if (!UpdateProductEqualityMap()) { return false; } - std::unordered_set used_symbolic_ops; + std::unordered_set used_symbolic_ops; std::vector used_symbol_names; - // TODO(liujinnan): collect uses in value. + // TODO(zhangbopd): collect uses in value. auto collect_used_symbols = [&](ArrayAttribute attrs) { for (Attribute attr : attrs.AsVector()) { - auto sym = symbol_table_.Lookup( + auto sym = symbol_table_.Lookup( attr.dyn_cast().AsString()); - assert(sym); + IR_ENFORCE(sym); if (used_symbolic_ops.insert(sym).second) used_symbol_names.push_back(sym.GetSymName()); } }; for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + if (!op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + op->attribute(SymbolicDimOp::GetSymbolicDimAttrName()); collect_used_symbols(attrs); } - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); + auto func_op = symbol_table_.getOp()->dyn_cast(); + IR_ENFORCE(func_op); for (auto& p : symbol_dim_union_set_) { if (!used_symbolic_ops.count(p.first)) { func_op.block()->erase(*(p.first.operation())); @@ -514,10 +514,11 @@ bool SymbolicDimMgr::Save() { std::vector candidates; for (auto& outter : product_equality_map_) { - if (std::any_of( - outter.first.symbols.begin(), - outter.first.symbols.end(), - [&](SymbolicDim sym) { return used_symbolic_ops.count(sym) == 0; })) + if (std::any_of(outter.first.symbols.begin(), + outter.first.symbols.end(), + [&](SymbolicDimOp sym) { + return used_symbolic_ops.count(sym) == 0; + })) candidates.push_back(outter.first); } @@ -527,7 +528,7 @@ bool SymbolicDimMgr::Save() { for (auto& inner : outter.second) { if (std::any_of(inner.first.symbols.begin(), inner.first.symbols.end(), - [&](SymbolicDim sym) { + [&](SymbolicDimOp sym) { return used_symbolic_ops.count(sym) == 0; })) candidates.push_back(outter.first); @@ -550,35 +551,35 @@ bool SymbolicDimMgr::Save() { } } - std::unordered_map name_to_symbol; - for (SymbolicDim op : used_symbolic_ops) { + std::unordered_map name_to_symbol; + for (SymbolicDimOp op : used_symbolic_ops) { auto name = op.GetSymName(); op.SetSymName(name_mapping[name]); name_to_symbol[name] = op; } for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + if (!op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + op->attribute(SymbolicDimOp::GetSymbolicDimAttrName()); auto symbolic_shape_attr = update_attrs( attrs, [&](const std::string& name) { return name_to_symbol[name]; }); - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + op->set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), symbolic_shape_attr); } - // TODO(liujinnan): update attributes attached to values. + // TODO(zhangbopd): update attributes attached to values. return SaveShapeConstraintGraph(); } bool SymbolicDimMgr::SaveShapeConstraintGraph() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); + auto func_op = symbol_table_.getOp()->dyn_cast(); + IR_ENFORCE(func_op); auto op_it = func_op.block()->rbegin(); while (op_it != func_op.block()->rend()) { - if (((*op_it)->isa()) || - ((*op_it)->isa())) + if (((*op_it)->isa()) || + ((*op_it)->isa())) op_it++; else op_it = decltype(op_it)(func_op.block()->erase(*(*op_it))); @@ -597,8 +598,8 @@ bool SymbolicDimMgr::SaveShapeConstraintGraph() { Int32Type::get(m_->ir_context())) ->result(0)); } - for (SymbolicDim sym : prod.symbols) { - values.push_back(builder.Build(sym.GetSymName()).out()); + for (SymbolicDimOp sym : prod.symbols) { + values.push_back(builder.Build(sym.GetSymName()).out()); } return values; }; @@ -613,7 +614,7 @@ bool SymbolicDimMgr::SaveShapeConstraintGraph() { if (!product_equality_map_[x][y]) continue; auto lhs_operands = build_operands(x); auto rhs_operands = build_operands(y); - builder.Build(lhs_operands, rhs_operands); + builder.Build(lhs_operands, rhs_operands); } } return true; diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h index 5541e8a8ee2f19..9bce0732441247 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h @@ -17,13 +17,13 @@ #include "paddle/pir/dialect/shape/utils/symbol_table.h" namespace pir { -using dialect::SymbolicDim; +using shape::SymbolicDimOp; // Represents a product of symbolic and concrete factors. // Used to prove product equalities symbolically. struct SymbolicDimProduct { // List all symbolic factors that can not be aggregated. - std::vector symbols; + std::vector symbols; // Product of all const factors. int64_t factor = 1; @@ -43,7 +43,7 @@ inline bool operator!=(const SymbolicDimProduct& lhs, } struct SymDimHasher { - size_t operator()(const dialect::SymbolicDim& symbol) const noexcept { + size_t operator()(const SymbolicDimOp& symbol) const noexcept { return std::hash{}(symbol.operation()); } }; @@ -64,29 +64,29 @@ class SymbolicDimMgr { public: explicit SymbolicDimMgr(ModuleOp m); - // Loads pre-defined SymbolicDim ops from the module this mgr runs on. + // Loads pre-defined SymbolicDimOp ops from the module this mgr runs on. bool Load(); // Create a new symbolicDim instance owned by this mgr. - SymbolicDim NewSymbolicDim(const std::string& name = {}); + SymbolicDimOp NewSymbolicDim(const std::string& name = {}); // Create a symbolicDim with static dim size == `val`. - SymbolicDim NewConstantSymbolicDim(int64_t val); + SymbolicDimOp NewConstantSymbolicDim(int64_t val); // Create a symbolicDim with given value. - std::vector CreateSymbolicDimsForRankedValue(Value value); + std::vector CreateSymbolicDimsForRankedValue(Value value); // All symbolic-equal dims form a group. - // Returns the root SymbolicDim of the symbolic-equal symbolic dim group which - // this SymbolicDim belongs to. - SymbolicDim GetRootSymbolicDim(SymbolicDim symbol); + // Returns the root SymbolicDimOp of the symbolic-equal symbolic dim group + // which this SymbolicDimOp belongs to. + SymbolicDimOp GetRootSymbolicDim(SymbolicDimOp symbol); // Returns true if lhs and rhs are known to be equal. - bool IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + bool IsSymbolicDimEqual(SymbolicDimOp lhs, SymbolicDimOp rhs); // Marks lhs and rhs have same size and try to merge lhs & rhs static known // info. Returns false if failed to merge lhs & rhs. - bool MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + bool MapSymbolicDimEqual(SymbolicDimOp lhs, SymbolicDimOp rhs); // Returns the simplified version of SymbolicDimProduct. // This will try to fold some symbolicDim ops with const values. @@ -139,10 +139,10 @@ class SymbolicDimMgr { std::unordered_set symbol_name_set_; - std::unordered_map + std::unordered_map symbol_dim_union_set_; - std::unordered_map constant_symbolic_dim_map_; + std::unordered_map constant_symbolic_dim_map_; // product_equality_map_[A][B] == true : Product[A] == Product[B] using SymbolicDimProductMap = std::unordered_map< diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index d746831835ed89..79f270afdba504 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -50,16 +50,16 @@ ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m) : m_(m), mgr_(m) { mgr_.Load(); for (auto op : *(m_.block())) { - auto tie_shape_op = op->dyn_cast(); + auto tie_shape_op = op->dyn_cast(); if (!tie_shape_op) continue; - Value result = tie_shape_op.value(); + Value result = tie_shape_op.input(); auto& symbols = value_to_sym_dims_[result]; auto attrs = tie_shape_op - .attribute(SymbolicDim::GetSymbolicDimAttrName()) + .attribute(SymbolicDimOp::GetSymbolicDimAttrName()) .AsVector(); for (const auto& attr : attrs) { - auto sym_op = mgr_.symbolTable().Lookup( + auto sym_op = mgr_.symbolTable().Lookup( attr.dyn_cast().AsString()); if (!sym_op) continue; symbols.push_back(sym_op); @@ -90,8 +90,8 @@ bool ShapeConstraintIRAnalysis::IsShapeEqual(Value lhs, Value rhs) { lhs_it->second.size() != rhs_it->second.size()) return false; - std::vector lhs_syms; - std::vector rhs_syms; + std::vector lhs_syms; + std::vector rhs_syms; for (auto sym : lhs_it->second) { lhs_syms.push_back(mgr_.GetRootSymbolicDim(sym)); } diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 0842313962d36b..9ac479548465d4 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -76,7 +76,7 @@ class ShapeConstraintIRAnalysis : public ShapeAnalysis { SymbolicDimMgr mgr_; // Map a ranked memref value to an array of symbolicDims, each represents one // dimension size of the memref value. - std::unordered_map> + std::unordered_map> value_to_sym_dims_; }; diff --git a/paddle/pir/dialect/shape/utils/symbol_table.cc b/paddle/pir/dialect/shape/utils/symbol_table.cc index c4ed0807b0b43b..74c60f3f6b1631 100644 --- a/paddle/pir/dialect/shape/utils/symbol_table.cc +++ b/paddle/pir/dialect/shape/utils/symbol_table.cc @@ -18,13 +18,13 @@ namespace pir { const std::string SymbolTable::insert(Operation* symbol) { std::string name; - if (symbol->isa()) { - name = symbol->dyn_cast().GetSymName(); + if (symbol->isa()) { + name = symbol->dyn_cast().GetSymName(); symbol_table_map_.insert({name, symbol}); } - // TODO(liujinnan): add more constraint_func name branch. - if (symbol->isa()) { + // TODO(zhangbopd): add more constraint_func name branch. + if (symbol->isa()) { name = "tie_product_equal"; symbol_func_map_[name].emplace_back(symbol); } diff --git a/paddle/pir/dialect/shape/utils/symbol_table.h b/paddle/pir/dialect/shape/utils/symbol_table.h index f85ba2cfb8099f..2c71a142c78d14 100644 --- a/paddle/pir/dialect/shape/utils/symbol_table.h +++ b/paddle/pir/dialect/shape/utils/symbol_table.h @@ -28,22 +28,22 @@ namespace pir { -using dialect::SymbolicDim; +using shape::SymbolicDimOp; class SymbolTable { public: explicit SymbolTable(Operation* symbol_table_op) : symbol_table_op_(symbol_table_op) {} SymbolTable() = default; template - typename std::enable_if::value, - SymbolicDim>::type + typename std::enable_if::value, + SymbolicDimOp>::type Lookup(const std::string& name) const { auto it = symbol_table_map_.find(name); - return it != symbol_table_map_.end() ? it->second->dyn_cast() - : SymbolicDim(nullptr); + return it != symbol_table_map_.end() ? it->second->dyn_cast() + : SymbolicDimOp(nullptr); } template - typename std::enable_if::value, + typename std::enable_if::value, std::vector>::type Lookup(const std::string& name) const { std::vector res; diff --git a/paddle/pir/pattern_rewrite/pattern_match.cc b/paddle/pir/pattern_rewrite/pattern_match.cc index 028d0779dbf94f..7b775ba4985813 100644 --- a/paddle/pir/pattern_rewrite/pattern_match.cc +++ b/paddle/pir/pattern_rewrite/pattern_match.cc @@ -116,46 +116,60 @@ void RewriterBase::ReplaceOpWithIf( void RewriterBase::ReplaceOp(Operation* op, const std::vector& new_values) { + // Notify that the rewriter subclass we're about to replace this root. NotifyRootReplaced(op, new_values); + IR_ENFORCE(op->num_results() == new_values.size(), "incorrect # of replacement values"); op->ReplaceAllUsesWith(new_values); + NotifyOperationRemoved(op); - op->GetParent()->erase(*op); + op->Erase(); } void RewriterBase::EraseOp(Operation* op) { - // TODO(wilber): Operation support use_empty. - // IR_ENFORCE(op->use_empty(), "expected 'op' to have no uses"); + IR_ENFORCE(op->use_empty(), "expected 'op' to have no uses"); NotifyOperationRemoved(op); - op->GetParent()->erase(*op); + op->Erase(); } -/// Find uses of `from` and replace it with `to` +// Find uses of `from` and replace it with `to`. void RewriterBase::ReplaceAllUsesWith(Value from, Value to) { - // TODO(wilber): Substitue a low level impl. - from.ReplaceAllUsesWith(to); + for (auto it = from.use_begin(); it != from.use_end();) + UpdateRootInplace(it.owner(), [&]() { (it++)->set_source(to); }); } -// TODO(wilber): iterator maybe should support modify inplace. +// Find uses of `from` and replace them with `to` if the `functor` returns true. void RewriterBase::ReplaceUseIf(Value from, Value to, std::function functor) { - // for (auto it = from.begin(); it != from.end(); ++it) { - // // // TODO: need a lvalue. - // if (functor(*it)) { - // UpdateRootInplace(it.owner(), [&](){it.get().set(to)}); - // } + // Use post-increment operator for iterator since set_source() will change + // `it`. + // TODO(zhangbopd): Uncomment + // for (auto it = from.use_begin(); it != from.use_end();) { + // if (functor(*it)) + // UpdateRootInplace(it.owner(), [&]() { (it++)->set_source(to); }); // } } +// Replace theuses of op with uses of new_op. +// 'op' and 'new_op' are known to have the same number of results void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op) { IR_ENFORCE(op->num_results() == new_op->num_results(), "replacement op doesn't match results of original op"); - // TODO(wilber): Op support results method. - // if (op->num_results() == 1) return ReplaceOp(op, - // new_op->result(0)); return ReplaceOp(op, new_op->GetResults()); + // TODO(zhangbopd): Uncomment + // if (op->num_results() == 1) { + // std::vector new_values; + // new_values.push_back(new_op->result(0)); + // return ReplaceOp(op, new_values); + // } + + // std::vector new_values; + // for (auto res : new_op->results()) { + // new_values.push_back(res); + // } + // return ReplaceOp(op, new_values); } } // namespace pir diff --git a/paddle/pir/pattern_rewrite/pattern_match.h b/paddle/pir/pattern_rewrite/pattern_match.h index 9e7553f4217cae..c1415606c3b24d 100644 --- a/paddle/pir/pattern_rewrite/pattern_match.h +++ b/paddle/pir/pattern_rewrite/pattern_match.h @@ -272,9 +272,16 @@ class RewriterBase : public Builder { virtual void ReplaceOp(Operation* op, const std::vector& new_values); - // template - // OpTy ReplaceOpWithNewOp(Operation *op, Args &&...args); + // Replaces the result op with a new op. + // The result values of the two ops must be the same types. + template + OpTy ReplaceOpWithNewOp(Operation* op, Args&&... args) { + auto new_op = Build(std::forward(args)...); + ReplaceOpWithResultsOfAnotherOp(op, new_op.operation()); + return new_op; + } + // This method erases an operation that is known to have no uses. virtual void EraseOp(Operation* op); IR_API void ReplaceAllUsesWith(Value from, Value to); @@ -327,6 +334,7 @@ class RewritePatternSet { public: explicit RewritePatternSet(IrContext* context) : context_(context) {} + // Construct a RewritePatternSet with the given patterns. RewritePatternSet(IrContext* context, std::unique_ptr pattern) : context_(context) { native_patterns_.emplace_back(std::move(pattern)); @@ -344,7 +352,7 @@ class RewritePatternSet { typename... ConstructorArgs, typename = std::enable_if_t> RewritePatternSet& Add(ConstructorArg&& arg, ConstructorArgs&&... args) { - std::initializer_list{ + (void)std::initializer_list{ (AddImpl({}, std::forward(arg), std::forward(args)...), @@ -359,7 +367,7 @@ class RewritePatternSet { RewritePatternSet& AddWithLabel(const std::vector& debug_labels, ConstructorArg&& arg, ConstructorArgs&&... args) { - std::initializer_list{ + (void)std::initializer_list{ (AddImpl(debug_labels, std::forward(arg), std::forward(args)...), diff --git a/pyproject.toml b/pyproject.toml index 3e8da7d18ed6fd..86b8ee2c804036 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,14 +103,3 @@ ignore = [ "test/dygraph_to_static/test_loop.py" = ["C416", "F821"] # Ignore unnecessary lambda in dy2st unittest test_lambda "test/dygraph_to_static/test_lambda.py" = ["PLC3002"] - -# B017 -"test/auto_parallel/spmd_rules/test_reshape_rule.py" = ["B017"] -"test/dygraph_to_static/test_assert.py" = ["B017"] -"test/legacy_test/test_cuda_max_memory_allocated.py" = ["B017"] -"test/legacy_test/test_cuda_max_memory_reserved.py" = ["B017"] -"test/legacy_test/test_cuda_memory_allocated.py" = ["B017"] -"test/legacy_test/test_cuda_memory_reserved.py" = ["B017"] -"test/legacy_test/test_eigvals_op.py" = ["B017"] -"test/legacy_test/test_tensordot.py" = ["B017"] -"test/legacy_test/test_top_k_v2_op.py" = ["B017"] diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 11a2d07d2096dd..842151d83b3325 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -71,7 +71,6 @@ Tensor.__qualname__ = 'Tensor' import paddle.distributed.fleet # noqa: F401 - from paddle import ( # noqa: F401 distributed, sysconfig, @@ -113,6 +112,7 @@ create_parameter, to_tensor, diag, + diag_embed, diagflat, eye, linspace, @@ -568,6 +568,7 @@ 'subtract', 'diag', 'diagflat', + 'diag_embed', 'isnan', 'scatter_nd_add', 'unstack', diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index ad5a7cc02aef9e..0ab057509e7ed0 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -105,12 +105,12 @@ def prepare_grad_outputs(grad_outputs, outputs, state): if output.shape != grad.shape: raise ValueError( "The shape of grad_output[%d] %s should be the same as the shape of output[%d] %s" - % (i, str(output.shape), i, str(grad.shape)) + % (i, str(grad.shape), i, str(output.shape)) ) if output.dtype != grad.dtype: raise ValueError( "The dtype of grad_output[%d] %s should be the same as the dtype of output[%d] %s" - % (i, str(output.dtype), i, str(grad.dtype)) + % (i, str(grad.dtype), i, str(output.dtype)) ) feedop = grad.get_defining_op() update_bwdop_structure( @@ -328,7 +328,7 @@ def append_backward_ops( if op has grad_op, prepare its grad_op's inputs by value_to_valuegrad, eg: value_to_valuegrad[v3] = [[v3_g]]; - v2_g = call_vjp(op3, [v3_g], [v2_stopgradient]) + v2_g = call_vjp(op3, [[v2]], [[v3]],[[v3_g]], [[v2_stopgradient]]) special pattern 1: @@ -339,7 +339,7 @@ def append_backward_ops( v1 is inside python api, we don't describe it in backward process(state) so v1_grad is inside vjp, we don't describe it in backward process(state) - [[v11_g, v12_g], v2_g] = call_vjp(combine_op, [v3_g], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient) + [[v11_g, v12_g], v2_g] = call_vjp(combine_op, [[v11, v12]], [[v3]],[[v3_g]], [[v11_stopgradient, v12_stopgradient], v2_stop_gradient]) op_vjp is: @@ -358,10 +358,12 @@ def append_backward_ops( else continue to next op. ''' - def make_output_grad(op): + def make_output_with_output_grad(op): zero_flag = [False] * op.num_results() + outputs = [] output_grads = [] for i, value in enumerate(op.results()): + new_value = [value] if ( value in state.value_to_valuegrad and len(state.value_to_valuegrad[value]) > 1 @@ -396,12 +398,15 @@ def make_output_grad(op): # pattern case: # this fwd_op's output is vectorType, it will split to # Type by builtin.split op, so need get from split op's ouput - split_zero_flag, split_output_grad = make_output_grad( - value.first_use().owner() - ) + ( + split_zero_flag, + split_outputs, + split_output_grad, + ) = make_output_with_output_grad(value.first_use().owner()) zero_flag[i] = all(split_zero_flag) grad_values = [value[0] for value in split_output_grad] state.value_to_valuegrad[value] = [grad_values] + new_value = [info[0] for info in split_outputs] else: # first case: # this fwd_op's output didn't used by other fwd_op, @@ -424,35 +429,45 @@ def make_output_grad(op): state.value_to_valuegrad[value] = [[grad_value]] + outputs.append(new_value) output_grads.append(state.value_to_valuegrad[value][0]) - return zero_flag, output_grads + return zero_flag, outputs, output_grads - def make_input_stopgradient(op): + def make_input_with_input_stopgradient(op): + inputs = [] input_grad_stopgradients = [] if op.name() == "builtin.combine": grad_semantic_info = [True for _ in range(op.num_operands())] else: grad_semantic_info = op.get_input_grad_semantics() + for input, grad_semantic in zip( op.operands_source(), grad_semantic_info ): if not grad_semantic: + inputs.append([input]) continue if ( input.get_defining_op() is not None and input.get_defining_op().name() == "builtin.combine" ): - stop_gradient = make_input_stopgradient(input.get_defining_op()) + ( + combine_inputs, + combine_stop_gradient, + ) = make_input_with_input_stopgradient(input.get_defining_op()) + inputs.append([info[0] for info in combine_inputs]) input_grad_stopgradients.append( - [info[0] for info in stop_gradient] + [info[0] for info in combine_stop_gradient] ) else: + inputs.append([input]) if input.get_defining_op() is None or input in no_grad_set: input_grad_stopgradients.append([True]) else: input_grad_stopgradients.append([False]) - return input_grad_stopgradients + + return inputs, input_grad_stopgradients def update_input_grad_map(op, input_grads): i = 0 @@ -494,7 +509,7 @@ def update_input_grad_map(op, input_grads): for op in clear_effective_forward_ops: if paddle.framework.core.has_vjp(op): # prepare output_grad - zero_flag, output_grads = make_output_grad(op) + zero_flag, outputs, output_grads = make_output_with_output_grad(op) # all(zero_flag) support this op has no contribution for grad # should be delete (prune sub_graph) @@ -502,12 +517,15 @@ def update_input_grad_map(op, input_grads): continue # prepare input_grad stop_gradient info. - input_grad_stopgradients = make_input_stopgradient(op) + ( + inputs, + input_grad_stopgradients, + ) = make_input_with_input_stopgradient(op) # create grad_op before_ops_num = len(block.ops) input_grads = paddle.framework.core.call_vjp( - op, output_grads, input_grad_stopgradients + op, inputs, outputs, output_grads, input_grad_stopgradients ) after_ops_num = len(block.ops) diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py index b01c7a70e44066..1f5b414ebb559e 100644 --- a/python/paddle/base/dygraph/tensor_patch_methods.py +++ b/python/paddle/base/dygraph/tensor_patch_methods.py @@ -869,7 +869,7 @@ def cuda(self, device_id=None, blocking=True): if self.place._equals(res_place): return self else: - res = self._copy_to(res_place, True) + res = self._copy_to(res_place, blocking) res.stop_gradient = self.stop_gradient res.persistable = self.persistable return res diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 037657ee0ad94c..01b038f818ed4f 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -21,7 +21,7 @@ import numpy as np -from ..pir import OpResult +from ..pir import OpResult, translate_to_new_ir from . import compiler, core, framework, get_flags, set_flags, unique_name from .data_feeder import convert_dtype from .framework import ( @@ -1006,7 +1006,14 @@ def _get_program_and_executor(self, cached_data): ) else: default_job = core.Job("default") - type_to_program = {"default": new_program.desc} + if get_flags("FLAGS_enable_new_ir_in_executor")[ + 'FLAGS_enable_new_ir_in_executor' + ]: + type_to_program = { + "default": translate_to_new_ir(new_program.desc) + } + else: + type_to_program = {"default": new_program.desc} plan = core.Plan([default_job], type_to_program) new_exe = _StandaloneExecutor(place, plan, scope) diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index e4ab9d0f886626..9e41f5f732760e 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -157,3 +157,11 @@ def set_field_default_config(category, field, default_value): set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True) set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32) set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True) + +######################################### +# model parallel configuration +######################################### +MP_OPTIMIZATION = "mp_optimization" +set_field_default_config( + MP_OPTIMIZATION, "allreduce_matmul_grad_overlapping", False +) diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 6f0a1db1a3bff9..ba9dca8b334d49 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -355,14 +355,18 @@ def _apply_post_optimization( ) params_grads = self._pass_context.get_attr("params_grads") - mp_async_allreduce_in_backward = os.getenv( - "FLAGS_mp_async_allreduce_in_backward" - ) in [1, "1", True, "True"] - if mp_async_allreduce_in_backward: - column_parallel_linear_backward_overlapping_pass = new_pass( - "column_parallel_linear_backward_overlapping", {} - ) - column_parallel_linear_backward_overlapping_pass.apply( + if self._strategy.mp_optimization.allreduce_matmul_grad_overlapping: + if int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1: + self._logger.warning( + "You set mp_optimization.allreduce_matmul_grad_overlapping=True, but you did not set environment " + "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may leads to performance " + "loss. Try to export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance." + ) + + allreduce_matmul_grad_overlapping_pass = new_pass( + "allreduce_matmul_grad_overlapping", {} + ) + allreduce_matmul_grad_overlapping_pass.apply( [main_program], [startup_program], self._pass_context ) @@ -432,8 +436,22 @@ def _apply_post_optimization( and self._strategy.pipeline.enable and use_new_executor() ): + enable_send_recv_overlap = ( + self._strategy.pipeline.enable_send_recv_overlap + ) + if ( + enable_send_recv_overlap + and int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1 + ): + self._logger.warning( + "You set pipeline.enable_send_recv_overlap=True, but you did not set environment " + "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may leads to performance " + "loss. Try to export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance." + ) + main_program._pipeline_opt = {} main_program._pipeline_opt["standalone_opt"] = { + "enable_send_recv_overlap": enable_send_recv_overlap, "schedule_mode": self._strategy.pipeline.schedule_mode, "num_micro_batches": self._strategy.pipeline.accumulate_steps, "pp_degree": len(self._dist_context.process_meshes), diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 1df4663b4fed5c..958d7dc565304d 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -136,6 +136,12 @@ def __init__(self, config_dict=None): super().__init__(category, config_dict) +class MPOptimizationConfig(BaseConfig): + def __init__(self, config_dict=None): + category = constants.MP_OPTIMIZATION + super().__init__(category, config_dict) + + class Strategy(BaseConfig): """ The `Strategy` object is used to configure the parallelization and optimization behaviors. @@ -214,3 +220,6 @@ def __init__(self, config=None): config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None) self.dp_optimization = DPOptimizationConfig(config_dict) + + config_dict = self._config_dict.get(constants.MP_OPTIMIZATION, None) + self.mp_optimization = MPOptimizationConfig(config_dict) diff --git a/python/paddle/distributed/auto_tuner/prune.py b/python/paddle/distributed/auto_tuner/prune.py index abae3f606fee15..976089f9d05f2b 100644 --- a/python/paddle/distributed/auto_tuner/prune.py +++ b/python/paddle/distributed/auto_tuner/prune.py @@ -85,10 +85,6 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None): if mp_degree not in mp_degree_candidates: return True - # prune default candidates - if mp_degree > 8: - return True - return False diff --git a/python/paddle/distributed/auto_tuner/recorder.py b/python/paddle/distributed/auto_tuner/recorder.py index 71c1b08ff3ecdf..11517da529f4fe 100644 --- a/python/paddle/distributed/auto_tuner/recorder.py +++ b/python/paddle/distributed/auto_tuner/recorder.py @@ -70,9 +70,8 @@ def get_best(self, metric, direction, mode=None) -> Tuple[dict, bool]: if first_few >= 5: break return (best_cfg, False) - if ( - isinstance(self.history[0]["max_mem_usage"], str) - or self.history[0]["time"] == -1 + if isinstance(self.history[0]["max_mem_usage"], str) or ( + "time" in self.history[0] and self.history[0]["time"] == -1 ): return (self.history[0], True) return (self.history[0], False) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index a2bac699bb5421..66d82ae1a29149 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -153,8 +153,9 @@ def _new_process_group_impl( if backend == "gloo": pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id) elif backend == "nccl": - pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id) - + pg = core.ProcessGroupNCCL.create( + store, rank, world_size, group_id, genv.pg_timeout + ) elif backend == "xccl": pg = core.ProcessGroupCustom.create( store, genv.device_type, rank, world_size, group_id @@ -240,12 +241,6 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout): # TODO: The method below is a new method for group management, will replace the previous # three in the future. _add_new_group(group) - - # TODO(shenliang03): This is a temporary solution to solve the problem of - # hang caused by tcp - paddle.distributed.barrier(group=group) - if paddle.distributed.get_world_size() > 1: - paddle.distributed.barrier() return group if not backend: diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 4750c6bca66fc6..a8e702ef66e76c 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -930,6 +930,8 @@ def amp_configs(self): use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. + use_pure_bf16(bool): Whether to use the pure bf16 training. Default False. + use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program. Default True. Only takes effect when `use_pure_fp16` is turned on. diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index b6130b55bf6737..bced953eff1397 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -189,16 +189,27 @@ def __init__(self, topology): self._sep_parallel_id = self._get_sep_parallel_id() self.stage_id = self._get_pipe_parallel_id() - assert self._check_vaild_topo(), ( - "Here is an unreasonable topogy setting. world_size: {}, but" - "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}, sep_num: {}".format( - self.nranks, - self._mp_degree, - self._sharding_degree, - self._pp_degree, - self._dp_degree, - self._sep_degree, - ) + assert ( + self._check_vaild_topo() + ), "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}, sep_num: {}".format( + self.nranks, + self._mp_degree, + self._sharding_degree, + self._pp_degree, + self._dp_degree, + ) + + # create comm group for pipe parallel + self._pp_group, self._pp_comm_group = self._set_comm_group("pipe") + # NOTE(shenliang03): In pipeline parallel, we use batch_isend_irecv. + # if batch_isend_irecv is the first collective operation, all ranks of + # the pipeline group must participate in this call. In order to avoid + # this situation, we perform a collective communication in advance and + # create a communicator. + paddle.distributed.all_reduce( + paddle.zeros([1], dtype="int32"), + op=paddle.distributed.ReduceOp.SUM, + group=self._pp_comm_group, ) # create comm group for data parallel @@ -207,9 +218,6 @@ def __init__(self, topology): # create comm group for model parallel self._mp_group, self._mp_comm_group = self._set_comm_group("model") - # create comm group for pipe parallel - self._pp_group, self._pp_comm_group = self._set_comm_group("pipe") - # create comm group for sharding parallel self._sharding_group, self._sharding_comm_group = self._set_comm_group( "sharding" @@ -240,6 +248,11 @@ def __init__(self, topology): ["pipe", "model"] ) + ( + self.sharding_check_group, + self.sharding_check_comm_group, + ) = self._set_check_group("sharding") + # create p2p group self.is_first_stage = self.stage_id == 0 self.is_last_stage = self.stage_id == (self._pp_degree - 1) diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index f7fc29b8d27ab7..626781a4725490 100755 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -85,19 +85,27 @@ def distributed_model(model): if paddle.distributed.get_world_size() <= 1: return model - amp_enable = False strategy = fleet_env._user_defined_strategy if strategy.amp: - amp_enable = True - amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1" - if amp_level.upper() == "O2": + level = ( + "O2" + if strategy.amp_configs['use_pure_fp16'] + or strategy.amp_configs['use_pure_bf16'] + else "O1" + ) + + if level == "O2": model = paddle.amp.decorate( models=model, optimizers=None, level="O2", master_weight=None, save_dtype=None, + dtype="float16" + if strategy.amp_configs['use_pure_fp16'] + else "bfloat16", ) + init_loss_scaling = strategy.amp_configs['init_loss_scaling'] incr_ratio = strategy.amp_configs['incr_ratio'] decr_ratio = strategy.amp_configs['decr_ratio'] diff --git a/python/paddle/distributed/fleet/scaler.py b/python/paddle/distributed/fleet/scaler.py index 463674c9587413..e284563614745e 100755 --- a/python/paddle/distributed/fleet/scaler.py +++ b/python/paddle/distributed/fleet/scaler.py @@ -31,6 +31,7 @@ def unscale_method(self, optimizer): return param_grads = [] + param_grads_bf16 = [] param_grads_fp16 = [] param_grads_fp32 = [] if getattr(optimizer, '_param_groups', None) and isinstance( @@ -53,6 +54,10 @@ def unscale_method(self, optimizer): paddle.float16, ]: param_grads_fp16.append(tgt_grad) + elif tgt_grad.dtype in [ + paddle.bfloat16, + ]: + param_grads_bf16.append(tgt_grad) else: param_grads_fp32.append(tgt_grad) else: @@ -90,10 +95,15 @@ def unscale_method(self, optimizer): paddle.float16, ]: param_grads_fp16.append(tgt_grad) + elif tgt_grad.dtype in [ + paddle.bfloat16, + ]: + param_grads_bf16.append(tgt_grad) else: param_grads_fp32.append(tgt_grad) temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_)) + temp_found_inf_bf16 = to_variable(np.array([0]).astype(np.bool_)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_)) self._found_inf = self._temp_found_inf_value_false if len(param_grads_fp16): @@ -106,6 +116,16 @@ def unscale_method(self, optimizer): self._found_inf = _C_ops.bitwise_or( self._found_inf, temp_found_inf_fp16 ) + if len(param_grads_bf16): + _legacy_C_ops.check_finite_and_unscale( + param_grads_bf16, + self._scale, + param_grads_bf16, + temp_found_inf_bf16, + ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, temp_found_inf_bf16 + ) if len(param_grads_fp32): _legacy_C_ops.check_finite_and_unscale( param_grads_fp32, diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 3a751b5d0c3c89..2d7c44e77b6623 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle # noqa: F401 from paddle.distributed import fleet -from paddle.utils import deprecated # noqa: F401 from . import ( # noqa: F401 hybrid_parallel_util, diff --git a/python/paddle/distributed/launch/controllers/master.py b/python/paddle/distributed/launch/controllers/master.py index d625887b8167f0..27e294907304b5 100644 --- a/python/paddle/distributed/launch/controllers/master.py +++ b/python/paddle/distributed/launch/controllers/master.py @@ -197,8 +197,9 @@ def __init__(self, ctx): host, port = self.endpoint.split(':') if ctx.is_auto_tuner_mode(): - self.etcd_client = ETCDClient(host=host, port=port) - self.client = etcd3.client(host=host, port=port) + self.client = ETCDClient(host=host, port=port) + else: + self.client = etcd3.client(host=host, port=port) def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): ''' @@ -256,22 +257,13 @@ def register_heartbeat(self, job_id, pod_id, ttl=10): self.job_prefix = f'/paddle/{job_id}' self.heartbeat_prefix = f'{self.job_prefix}/heartbeat' - if self.ctx.is_auto_tuner_mode(): - self.etcd_client.delete_prefix(self.job_prefix) - lease = self.etcd_client.lease(ttl) - else: - self.client.delete_prefix(self.job_prefix) - lease = self.client.lease(ttl) + self.client.delete_prefix(self.job_prefix) + lease = self.client.lease(ttl) # self.client.delete_prefix(self.job_prefix) beat_path = f"{self.heartbeat_prefix}/{pod_id}" - if self.ctx.is_auto_tuner_mode(): - self.etcd_client.put( - beat_path, pod_id.encode('latin-1'), lease=lease - ) - else: - self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease) + self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease) def _beat_watch(event): self.ctx.status.restart() diff --git a/python/paddle/distributed/launch/controllers/watcher.py b/python/paddle/distributed/launch/controllers/watcher.py index 25855572620f85..fd5571c39d4434 100644 --- a/python/paddle/distributed/launch/controllers/watcher.py +++ b/python/paddle/distributed/launch/controllers/watcher.py @@ -23,7 +23,7 @@ class Watcher: def __init__(self, ctx): self.ctx = ctx - self.interval = 30 + self.interval = 5 self.gpu_util = [] diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py index 1fc2e6713e1b63..e24984e6f1479c 100644 --- a/python/paddle/distributed/launch/main.py +++ b/python/paddle/distributed/launch/main.py @@ -656,6 +656,7 @@ def launch(): elif "OK" not in status: timeout_flag = False + has_error = False if err & (1 << 0): ctx.logger.warning( f"Read metric failed for parameters: {log_dir}" @@ -665,6 +666,7 @@ def launch(): cur_cfg['time'] = -1 cur_cfg[tuner_cfg['metric_cfg']['name']] = None cur_cfg["max_mem_usage"] = mem if not OOM_flag else "OOM" + has_error = True if err & (1 << 1): ctx.logger.warning(f"Out of memory for parameters: {log_dir}") @@ -673,6 +675,7 @@ def launch(): cur_cfg['time'] = -1 cur_cfg[tuner_cfg['metric_cfg']['name']] = None cur_cfg["max_mem_usage"] = "OOM" + has_error = True # not err & (1 << 1): do not record memory usage when out of memory if err & (1 << 2) and not err & (1 << 1): @@ -684,18 +687,20 @@ def launch(): ) cur_cfg["max_mem_usage"] = None if not OOM_flag else "OOM" - if not err and timeout_flag: + if not has_error and timeout_flag: # for pruner use cur_cfg['time'] = metric cur_cfg[tuner_cfg['metric_cfg']['name']] = metric cur_cfg["max_mem_usage"] = mem if not OOM_flag else "OOM" - if not err and not timeout_flag: + if not has_error and not timeout_flag: cur_cfg['time'] = -1 cur_cfg[tuner_cfg['metric_cfg']['name']] = None cur_cfg["max_mem_usage"] = None if not OOM_flag else "OOM" # record history + if tuner_cfg['metric_cfg']['name'] not in cur_cfg: + cur_cfg[tuner_cfg['metric_cfg']['name']] = None cur_cfg['job_id'] = job_id recorder.add_cfg(**cur_cfg) recorder.store_history(history_file_path) @@ -794,6 +799,8 @@ def launch(): ctx.logger.info(f"AutoTuner ends in {end_time-start_time}s.") logger.info(f"AutoTuner ends in {end_time-start_time}s.") # launch best cfg + if not tuner_cfg.get("run_best", True): + sys.exit() new_args = gen_new_args(raw_args, best_cfg, tuner_cfg, run_best=True) ctx.run_best = True ctx.args.training_script_args = new_args diff --git a/python/paddle/distributed/launch/utils/etcd_client.py b/python/paddle/distributed/launch/utils/etcd_client.py index e4bbf8e1409a4d..a96c7a034fdb18 100644 --- a/python/paddle/distributed/launch/utils/etcd_client.py +++ b/python/paddle/distributed/launch/utils/etcd_client.py @@ -140,3 +140,41 @@ def lease(self, ttl, lease_id=None): if times >= self.retry_times: raise ValueError(f"Lease failed after {self.retry_times} times.") + + def add_watch_prefix_callback(self, key_prefix, callback, **kwargs): + times = 0 + while times < self.retry_times: + try: + return self.client.add_watch_prefix_callback( + key_prefix, callback, **kwargs + ) + break + except Exception as e: + times += 1 + logging.info( + f"Add watch prefix callback failed with exception {e}, retry after 1 second." + ) + time.sleep(1) + + if times >= self.retry_times: + raise ValueError( + f"Add watch prefix callback failed after {self.retry_times} times." + ) + + def cancel_watch(self, watch_id): + times = 0 + while times < self.retry_times: + try: + return self.client.cancel_watch(watch_id) + break + except Exception as e: + times += 1 + logging.info( + f"Cancel watch failed with exception {e}, retry after 1 second." + ) + time.sleep(1) + + if times >= self.retry_times: + raise ValueError( + f"Cancel watch failed after {self.retry_times} times." + ) diff --git a/python/paddle/distributed/launch/utils/nvsmi.py b/python/paddle/distributed/launch/utils/nvsmi.py index 0c51456bf1204f..232ccce2209cce 100644 --- a/python/paddle/distributed/launch/utils/nvsmi.py +++ b/python/paddle/distributed/launch/utils/nvsmi.py @@ -133,7 +133,7 @@ def get_gpu_util(index=None): if index is None or isinstance(index, list) else str(index).split(",") ) - if paddle.device.is_compiled_with_cuda(): + if paddle.device.is_compiled_with_rocm(): return query_rocm_smi(q, index=index, dtype=d) return query_smi(q, index=index, dtype=d) diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 8890ab0bd179ae..34400b1c2f7e37 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -181,7 +181,6 @@ def sync_params_buffers( paddle.distributed.broadcast( coalesced_var, src=src_rank, group=comm_group, sync_op=True ) - for coalesced_var, origin_vars, var_shapes in coalesced_vars: var_len = [np.prod(v_shape) for v_shape in var_shapes] paddle.base.framework._dygraph_tracer().trace_op( @@ -685,6 +684,7 @@ def __init__(self): self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) self._device_type = str(os.getenv("PADDLE_XCCL_BACKEND", "")) + self._pg_timeout = int(os.getenv("PADDLE_PG_TIMEOUT", "1800000")) # imperative only support one gpu or xpu if self._device_type != "": @@ -849,6 +849,24 @@ def nrings(self): """ return self._nrings + @property + def pg_timeout(self): + """ + timeout of process group. + + Its value is equal to the value of the environment variable ``PADDLE_PG_TIMEOUT`` . The default value is 30 minutes. + + Examples: + .. code-block:: python + + # execute this command in terminal: export PADDLE_PG_TIMEOUT=1800000 + import paddle.distributed as dist + + env = dist.ParallelEnv() + # the pg_timeout of process group 1800000 + """ + return self._pg_timeout + # [aliases] Compatible with old method names local_rank = rank nranks = world_size @@ -1098,7 +1116,6 @@ def init_parallel_env(): # TODO(mine): support XPU and other backends. if backend in ["nccl", 'xccl', 'bkcl']: core.CommContextManager.set_device_id(parallel_env.device_id) - paddle.distributed.barrier(group=group) return group node_num = {i.split(":")[0] for i in parallel_env.trainer_endpoints} diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index e2f54d47a4e08c..8c1f4ab6e5350f 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -24,7 +24,7 @@ from .auto_parallel_grad_clip import * # noqa: F403 from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403 from .auto_parallel_pipeline import * # noqa: F403 -from .column_parallel_linear_backward_overlapping import * # noqa: F403 +from .allreduce_matmul_grad_overlapping import * # noqa: F403 from .cpp_pass import * # noqa: F403 from .fuse_all_reduce import * # noqa: F403 from .pipeline_scheduler_pass import * # noqa: F403 diff --git a/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py similarity index 98% rename from python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py rename to python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py index aa5dbd7d267e1c..c6457b612ff81e 100644 --- a/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py +++ b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py @@ -27,8 +27,8 @@ # dY = matmul(X^T, dOut) # # Then the c_allreduce_sum can overlap with the compute of dY. -@register_pass("column_parallel_linear_backward_overlapping") -class ColumnParallelLinearBackwardOverlappingPass(PassBase): +@register_pass("allreduce_matmul_grad_overlapping") +class AllreduceMatmulGradOverlappingPass(PassBase): def __init__(self): super().__init__() self.set_attr("allreduce_stream", None) diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index d5eb98d7422de9..9b2042ce9ea854 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -35,6 +35,8 @@ core.VarDesc.VarType.FETCH_LIST, ] +logger = get_logger(logging.INFO) + # NOTE: Here stream is just a presentation with different name, # it is up to executor to create the exact streams given the name. @@ -264,7 +266,7 @@ def set_skip_gc_vars(num_micro_batches, type_to_program, jobs): required_vars = type_to_required_vars[job_type] micro_batch_id = job.micro_batch_id() skip_gc_vars = required_vars & suffixed_required_vars[micro_batch_id] - get_logger(logging.INFO).info( + logger.debug( f"Skip gc vars for {job_type}-({micro_batch_id}): {skip_gc_vars}" ) diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index ba17a7a50a8ff0..7aed3fb36d3bc7 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -379,10 +379,10 @@ def _partial_programs(self, program): ) for i in range(len(types)): - logger.info( + logger.debug( f"type = {types[i]}, sub_programs = {sub_programs[i]}\n" ) - logger.info(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}") + logger.debug(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}") return types, sub_programs diff --git a/python/paddle/incubate/multiprocessing/__init__.py b/python/paddle/incubate/multiprocessing/__init__.py index 42c7bd7bcf75eb..2498a04014d954 100644 --- a/python/paddle/incubate/multiprocessing/__init__.py +++ b/python/paddle/incubate/multiprocessing/__init__.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing # noqa: F401 - from .reductions import init_reductions __all__ = [] diff --git a/python/paddle/incubate/nn/functional/fused_dropout_add.py b/python/paddle/incubate/nn/functional/fused_dropout_add.py index d191f1682fddac..127cc91d548119 100644 --- a/python/paddle/incubate/nn/functional/fused_dropout_add.py +++ b/python/paddle/incubate/nn/functional/fused_dropout_add.py @@ -16,7 +16,7 @@ from paddle import _C_ops from paddle.base import core from paddle.common_ops_import import default_main_program -from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.framework import LayerHelper, in_dynamic_or_pir_mode def fused_dropout_add( @@ -84,7 +84,7 @@ def fused_dropout_add( "mode argument should be 'downscale_in_infer' or 'upscale_in_train'" ) seed = None - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if default_main_program().random_seed != 0: seed = default_main_program().random_seed out, seed_offset = _C_ops.fused_dropout_add( diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm.py b/python/paddle/incubate/nn/functional/fused_rms_norm.py index 99f9c4e72e77d0..9a95d99b178a72 100644 --- a/python/paddle/incubate/nn/functional/fused_rms_norm.py +++ b/python/paddle/incubate/nn/functional/fused_rms_norm.py @@ -15,7 +15,7 @@ import paddle from paddle import _C_ops -from paddle.framework import LayerHelper, in_dynamic_or_pir_mode +from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode def fused_rms_norm( @@ -64,7 +64,7 @@ def fused_rms_norm( >>> epsilon = 1e-6 >>> paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1) """ - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): return _C_ops.rms_norm( x, bias, @@ -78,7 +78,21 @@ def fused_rms_norm( quant_max_bound, quant_min_bound, ) - + if in_pir_mode(): + out, residual_out = _C_ops.rms_norm( + x, + bias, + residual, + norm_weight, + norm_bias, + epsilon, + begin_norm_axis, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + ) + return (out, residual_out) if residual is not None else out helper = LayerHelper('rms_norm', **locals()) out = None if quant_scale <= 0: diff --git a/python/paddle/jit/api.py b/python/paddle/jit/api.py index 71a70fdfbe2c3f..1a93121bf4b587 100644 --- a/python/paddle/jit/api.py +++ b/python/paddle/jit/api.py @@ -15,18 +15,18 @@ # Temporary disable isort to avoid circular import # This can be removed after the circular import is resolved -# isort: skip_file from __future__ import annotations +import inspect import os import pickle -import warnings import sys -from collections import OrderedDict -import inspect import threading -from typing import Any import types +import warnings +from collections import OrderedDict +from contextlib import contextmanager +from typing import Any import paddle from paddle.base import core, dygraph @@ -40,43 +40,52 @@ program_desc_tracing_guard, switch_to_static_graph, ) -from .dy2static import logging_utils -from .dy2static.convert_call_func import ( - ConversionOptions, - add_ignore_module, -) -from .dy2static.program_translator import ( - ProgramTranslator, - StaticFunction, - ASTStaticFunction, - SymbolicStaticFunction, - unwrap_decorators, -) -from paddle.jit.translated_layer import ( - TranslatedLayer, - INFER_MODEL_SUFFIX, - INFER_PARAMS_SUFFIX, - INFER_PARAMS_INFO_SUFFIX, - INFER_PROPERTY_SUFFIX, -) -from paddle.nn import Layer from paddle.base.executor import Executor, scope_guard from paddle.base.framework import ( Block, + EagerParamBase, + Parameter, Program, Variable, - Parameter, - EagerParamBase, -) -from paddle.base.framework import ( _current_expected_place, _dygraph_guard, _dygraph_tracer, + dygraph_only, ) -from paddle.base.framework import dygraph_only from paddle.base.wrapped_decorator import wrap_decorator -from paddle.static.io import save_inference_model from paddle.framework import in_dynamic_mode +from paddle.nn import Layer +from paddle.static.io import save_inference_model +from paddle.utils.environments import ( + BooleanEnvironmentVariable, + EnvironmentVariableGuard, +) + +from .dy2static import logging_utils +from .dy2static.convert_call_func import ConversionOptions, add_ignore_module +from .dy2static.program_translator import ( + ASTStaticFunction, + ProgramTranslator, + StaticFunction, + SymbolicStaticFunction, + convert_to_static, + unwrap_decorators, +) +from .translated_layer import ( + INFER_MODEL_SUFFIX, + INFER_PARAMS_INFO_SUFFIX, + INFER_PARAMS_SUFFIX, + INFER_PROPERTY_SUFFIX, + TranslatedLayer, +) + +ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True) + + +@contextmanager +def sot_mode_guard(value: bool): + with EnvironmentVariableGuard(ENV_ENABLE_SOT, value): + yield def create_program_from_desc(program_desc): @@ -166,7 +175,7 @@ def __impl__(*args, **kwargs): "We will just return dygraph output." ) return dygraph_func(*args, **kwargs) - static_func = program_translator.get_func(dygraph_func) + static_func = convert_to_static(dygraph_func) return static_func(*args, **kwargs) return __impl__ @@ -298,11 +307,8 @@ def decorated(python_func): nonlocal full_graph if full_graph is None: - flag = os.environ.get("ENABLE_FALL_BACK", None) - if flag == "True" or flag is None: - full_graph = False - else: # False - full_graph = True + flag = ENV_ENABLE_SOT.get() + full_graph = not flag if sys.version_info >= (3, 12) and not full_graph: warnings.warn( diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index f890e1eb7d0233..49844749826ac4 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -214,7 +214,9 @@ def __init__( ) # program_id -> list(scope) - self._scope_cache = {} + self._pir_scope_cache = {} + self._legacy_scope_cache = {} + self._scope_cache = self._legacy_scope_cache self._hooker = None self._backend = kwargs.get('backend', None) self._grad_var_names = {} @@ -267,6 +269,12 @@ def set_hooker(self, hooker): self._hooker = hooker def _get_scope(self, program_id=None, use_scope_cache=False): + if get_flags('FLAGS_enable_new_ir_in_executor')[ + 'FLAGS_enable_new_ir_in_executor' + ]: + self._scope_cache = self._pir_scope_cache + else: + self._scope_cache = self._legacy_scope_cache if use_scope_cache: if program_id not in self._scope_cache: scope = core.Scope() @@ -1157,4 +1165,9 @@ def add_build_strategy_for( builded_program = paddle.static.Program() for var in program.block(0).vars.values(): builded_program.block(0)._clone_variable(var, False) + + # set back the parent_idx of blocks + for origin, current in zip(program.blocks, builded_program.blocks): + current.desc.set_parent_idx(origin.desc.parent) + return builded_program diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 20528aebf9c35f..95a9c2c9fdc91e 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -15,7 +15,6 @@ import collections import inspect import os -import textwrap import threading import warnings import weakref @@ -56,7 +55,6 @@ ALREADY_D2S, NO_SHAPE_VAR_TYPE, ast_to_func, - ast_to_source_code, backend_guard, func_to_source_code, input_specs_compatible, @@ -1762,37 +1760,6 @@ def __init__(self): self.enable_to_static = True def enable(self, enable_to_static): - """ - Enable or disable the converting from imperative to static graph by - ProgramTranslator globally. - - Args: - enable_to_static (bool): True or False to enable or disable converting to static. - - Returns: - None. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - - >>> x = paddle.ones([1, 2]) - >>> x_v = prog_trans.get_output(func, x) - >>> print(x_v) - Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0., 0.]]) - """ check_type( enable_to_static, "enable_to_static", @@ -1801,274 +1768,6 @@ def enable(self, enable_to_static): ) self.enable_to_static = enable_to_static - def get_output(self, dygraph_func, *args, **kwargs): - """ - Returns the output dygraph Tensor for dygraph function. The dygraph - function will be translated into static graph function so the under - beneath numerical result will be calculated by static graph mode. - - Args: - dygraph_func (callable): the dygraph function. - *args (tuple): the input argument of dygraph_func. - **kwargs (dict): the input argument of dygraph_func. - - Returns: - Tensor or tuple of Tensors: the dygraph Tensor containing digital result. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - - >>> x = paddle.ones([1, 2]) - >>> x_v = prog_trans.get_output(func, x) - >>> print(x_v) - Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0., 0.]]) - """ - assert callable( - dygraph_func - ), "Input dygraph_func is not a callable in ProgramTranslator.get_output" - - if not self.enable_to_static: - # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message) - # will show up **only once**. - logging_utils.warn( - "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. " - "We will just return dygraph output. " - "Please call ProgramTranslator.enable(True) if you would like to get static output." - ) - return dygraph_func(*args, **kwargs) - try: - function_spec = FunctionSpec(dygraph_func) - cache_key = CacheKey.from_func_and_args( - function_spec, - args, - kwargs, - getattr(dygraph_func, '__self__', None), - ) - _, partial_program_layer = self._program_cache[cache_key] - - if args and isinstance(args[0], layers.Layer): - # Synchronize self.training attribute. - partial_program_layer.training = args[0].training - args = args[1:] - try: - return partial_program_layer(args) - except BaseException as e: - # NOTE: - # 1. If e is raised in compile time, e should have been attached to ERROR_DATA before; - # 2. If e raised in runtime, e should be attached to ERROR_DATA here. - if not hasattr(e, error.ERROR_DATA): - # runtime error - error.attach_error_data(e, in_runtime=True) - raise - except BaseException as e: - error_data = getattr(e, error.ERROR_DATA, None) - if error_data: - error_data.raise_new_exception() - else: - logging_utils.warn( - "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'" - " if you can't handle this {} yourself.".format(type(e)) - ) - raise e - - def get_func(self, dygraph_func): - """ - Returns a callable function which converts imperative dygraph APIs of - the input dygraph_func into declarative net-building APIs, which means - it doesn't return immediate digital result as get_output does. - Users should handle Program and Executor by themselves. - - Args: - dygraph_func (callable): the dygraph function. - - Returns: - callable: converting imperative dygraph APIs into declarative - net-building APIs. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - >>> static_func = prog_trans.get_func(func) - >>> print(callable(static_func)) - True - """ - assert callable( - dygraph_func - ), "Input dygraph_func is not a callable in ProgramTranslator.get_func" - - if not self.enable_to_static: - logging_utils.warn( - "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will " - "just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output." - ) - return dygraph_func - - static_func = convert_to_static(dygraph_func) - return static_func - - def get_program(self, dygraph_func, *args, **kwargs): - """ - Returns the translated static program and input/output Tensors from - dygraph function. The users can use the program to run by executor. - - Args: - dygraph_func (callable): the dygraph function. - *args (tuple): the input argument of dygraph_func. - **kwargs (dict): the input argument of dygraph_func. - - Returns: - tuple of (main_program, startup_program, inputs, outputs) whose - types are (Program, Program, list of Tensors, list of Tensors). - main_program: the converted main program. - startup_program: the converted startup program. - inputs: list of input Tensors which need to be fed. - outputs: list of output Tensors which users can fetch. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - >>> x = paddle.ones([1, 2]) - >>> main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x) - >>> print([i.name for i in inputs]) - >>> # [u'generated_tensor_0'] the feed input Tensor name representing x - >>> print([o.name for o in outputs]) - >>> # [u'_generated_var_4'] the fetch output Tensor name representing x_v - """ - assert callable( - dygraph_func - ), "Input dygraph_func is not a callable in ProgramTranslator.get_program" - - if not self.enable_to_static: - logging_utils.warn( - "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False." - "We will just return dygraph output. " - "Please call ProgramTranslator.enable(True) if you would like to get static output." - ) - return dygraph_func(*args, **kwargs) - - function_spec = FunctionSpec(dygraph_func) - cache_key = CacheKey.from_func_and_args( - function_spec, args, kwargs, getattr(dygraph_func, '__self__', None) - ) - concrete_program, partial_program_layer = self._program_cache[cache_key] - - # Note: concrete_program hold all input/output infos include non-Variable - input_vars = [ - var - for var in concrete_program.inputs - if isinstance(var, framework.Variable) - ] - output_vars = [ - var - for var in concrete_program.outputs - if isinstance(var, framework.Variable) - ] - - return ( - concrete_program.main_program, - concrete_program.startup_program, - input_vars, - output_vars, - ) - - def get_code(self, dygraph_func): - """ - Returns the translated static function string code from dygraph function. - - Args: - dygraph_func (callable): the dygraph function. - - Returns: - str: the string code of translated static function. - - Examples: - .. code-block:: python - - >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest') - >>> import paddle - >>> def func(x): - ... if paddle.mean(x) > 0: - ... x_v = x - 1 - ... else: - ... x_v = x + 1 - ... return x_v - ... - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - - >>> code = prog_trans.get_code(func) - >>> print(type(code)) - - """ - assert callable( - dygraph_func - ), "Input dygraph_func is not a callable in ProgramTranslator.get_code" - # Gets AST from dygraph function - - unwrap_func = unwrap(dygraph_func) - raw_code = inspect.getsource(unwrap_func) - code = textwrap.dedent(raw_code) - root = gast.parse(code) - - # Transform AST - dygraph_to_static = DygraphToStaticAst() - root = dygraph_to_static.get_static_ast(root) - - # Get source_code - source_code = ast_to_source_code(root) - return source_code - - def get_program_cache(self): - """ - Returns the ProgramCache instance. This method is used by PaddlePaddle - developers to manage program cache in ProgramTranslator. Normal users - don't have to call this method. - - Returns: - ProgramCache: ProgramCache instance of ProgramTranslator. - - Examples: - .. code-block:: python - - >>> import paddle - - >>> prog_trans = paddle.jit.dy2static.program_translator.ProgramTranslator() - >>> prog_cache = prog_trans.get_program_cache() - """ - return self._program_cache - def enable_to_static(enable_to_static_bool): """ diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 1eab7edc738bfd..fd5eba66c76842 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -178,21 +178,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ) -def create_undefined_variable_local(): - helper = LayerHelper('create_undefined_variable', **locals()) - var = helper.create_variable( - name=unique_name.generate("undefined_var"), - shape=[1], - dtype="float64", - type=core.VarDesc.VarType.LOD_TENSOR, - stop_gradient=False, - is_data=True, - need_check_feed=False, - ) - paddle.assign(RETURN_NO_VALUE_MAGIC_NUM, var) - return var - - def create_undefined_variable(): var = data_layer_not_check( unique_name.generate("undefined_var"), [1], "float64" diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index 0859ecfec46b9a..1fd89009200a4b 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -29,6 +29,7 @@ from ...symbolic.statement_ir import Symbol from ...symbolic.symbolic_context import SymbolicTraceContext from ...utils import ( + ENV_SHOW_TRACKERS, NameGenerator, OrderedSet, inner_error_default_handler, @@ -37,7 +38,6 @@ log, log_do, map_if, - show_trackers, tmp_name_guard, ) from .guard import Guard, StringifyExpression, make_guard @@ -341,7 +341,7 @@ def start_compile(self, *ret_vars: VariableBase): self.restore_side_effects(self.side_effects.proxy_variables) self.pycode_gen.gen_enable_eval_frame() - tracker_output_path = show_trackers() + tracker_output_path = ENV_SHOW_TRACKERS.get() if tracker_output_path: from .tracker_viewer import view_tracker diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 6d9ec8829497a5..d9947579dc7d4a 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -30,6 +30,7 @@ from ...profiler import EventGuard, event_register from ...psdb import NO_BREAKGRAPH_CODES from ...utils import ( + ENV_MIN_GRAPH_SIZE, BreakGraphError, FallbackError, InnerError, @@ -37,7 +38,6 @@ SotUndefinedVar, log, log_do, - min_graph_size, ) from ..custom_code import CustomCode from ..instruction_utils import ( @@ -1701,7 +1701,7 @@ def transform(self): # stopped by RETURN_VALUE and has sir len is enough => disable_eval_frame simulate_complete = bool(self.stop_state == "Return") if simulate_complete: - if self._graph.sir_ctx.TOS.graph_size() < min_graph_size(): + if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get(): raise FallbackError( "Fallback after simulate for reasons.", disable_eval_frame=True, diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index a1f26ea622772b..d27d5e9e5ab971 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -13,6 +13,16 @@ # limitations under the License. from .code_status import CodeStatus # noqa: F401 +from .envs import ( # noqa: F401 + ENV_CLEAN_CODE, + ENV_COST_MODEL, + ENV_MIN_GRAPH_SIZE, + ENV_SHOW_TRACKERS, + ENV_SOT_LOG_LEVEL, + ENV_STRICT_MODE, + cost_model_guard, + strict_mode_guard, +) from .exceptions import ( # noqa: F401 BreakGraphError, FallbackError, @@ -35,7 +45,6 @@ SotUndefinedVar, StepInfoManager, StepState, - cost_model, count_if, current_tmp_name_records, execute_time, @@ -55,8 +64,6 @@ map_if, map_if_extend, meta_str, - min_graph_size, no_eval_frame, - show_trackers, tmp_name_guard, ) diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py new file mode 100644 index 00000000000000..303e3af2a20f31 --- /dev/null +++ b/python/paddle/jit/sot/utils/envs.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager + +from paddle.utils.environments import ( + BooleanEnvironmentVariable, + EnvironmentVariableGuard, + IntegerEnvironmentVariable, + StringEnvironmentVariable, +) + +ENV_COST_MODEL = BooleanEnvironmentVariable("COST_MODEL", False) +ENV_MIN_GRAPH_SIZE = IntegerEnvironmentVariable("MIN_GRAPH_SIZE", 10) +ENV_SOT_LOG_LEVEL = IntegerEnvironmentVariable("SOT_LOG_LEVEL", 0) +ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False) +ENV_SHOW_TRACKERS = StringEnvironmentVariable("SHOW_TRACKERS", "") +ENV_CLEAN_CODE = BooleanEnvironmentVariable("CLEAN_CODE", False) + + +@contextmanager +def cost_model_guard(value: bool): + with EnvironmentVariableGuard(ENV_COST_MODEL, value): + yield + + +@contextmanager +def strict_mode_guard(value: bool): + with EnvironmentVariableGuard(ENV_STRICT_MODE, value): + yield diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 912ae7dec2692c..c27ff8e33123b2 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -16,7 +16,6 @@ import builtins import inspect -import os import time import types import weakref @@ -32,6 +31,12 @@ from paddle.framework import Program from paddle.utils import flatten, map_structure +from .envs import ( + ENV_CLEAN_CODE, + ENV_COST_MODEL, + ENV_SOT_LOG_LEVEL, + ENV_STRICT_MODE, +) from .paddle_api_config import ( break_graph_set, paddle_api_list, @@ -41,14 +46,6 @@ T = TypeVar("T") -def cost_model(): - return os.environ.get("COST_MODEL", "False") == "True" - - -def min_graph_size(): - return int(os.environ.get("MIN_GRAPH_SIZE", 10)) - - class Singleton(Generic[T]): def __init__(self, cls: type[T]): self._cls = cls @@ -119,13 +116,13 @@ def next(self): def log(level, *args): - cur_level = int(os.environ.get("SOT_LOG_LEVEL", "0")) + cur_level = ENV_SOT_LOG_LEVEL.get() if level <= cur_level: print(*args, end="") def log_do(level, fn): - cur_level = int(os.environ.get("SOT_LOG_LEVEL", "0")) + cur_level = ENV_SOT_LOG_LEVEL.get() if level <= cur_level: fn() @@ -287,15 +284,11 @@ def meta_str(shape, dtype, stop_gradient): def is_strict_mode(): - return os.environ.get("STRICT_MODE", "0") == "1" - - -def show_trackers() -> str | None: - return os.environ.get("SHOW_TRACKERS", None) + return ENV_STRICT_MODE.get() def is_clean_code() -> bool: - return os.environ.get('CLEAN_CODE', "False") == "True" + return ENV_CLEAN_CODE.get() def list_find_index_by_id(li: list[Any], item: Any) -> int: @@ -623,7 +616,9 @@ class StepInfo: def __init__(self): self.step_count = -1 self.state = ( - StepState.COLLECT_INFO if cost_model() else StepState.RUN_SOT + StepState.COLLECT_INFO + if ENV_COST_MODEL.get() + else StepState.RUN_SOT ) self.dyn_time_costs = [] self.avg_dyn_time = 0 diff --git a/python/paddle/metric/metrics.py b/python/paddle/metric/metrics.py index 2760b448a70276..9506e4db895429 100644 --- a/python/paddle/metric/metrics.py +++ b/python/paddle/metric/metrics.py @@ -17,10 +17,10 @@ import numpy as np import paddle -from paddle import _legacy_C_ops +from paddle import _C_ops, _legacy_C_ops from ..base.data_feeder import check_variable_and_dtype -from ..base.framework import _create_tensor +from ..base.framework import _create_tensor, in_pir_mode from ..base.layer_helper import LayerHelper from ..framework import in_dynamic_mode @@ -807,6 +807,10 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None): ) return _acc + elif in_pir_mode(): + topk_out, topk_indices = paddle.topk(input, k=k) + _acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label) + return _acc helper = LayerHelper("accuracy", **locals()) check_variable_and_dtype( diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 1ef27639abd132..db4c9adf3327ce 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -164,48 +164,6 @@ from . import initializer # noqa: F401 from . import quant # noqa: F401 -# TODO: remove 'diag_embed', 'remove_weight_norm', 'weight_norm' months later. -from paddle.utils import deprecated - - -@deprecated( - since="2.0.0", - update_to="paddle.nn.functional.diag_embed", - level=1, - reason="diag_embed in paddle.nn will be removed in future", -) -def diag_embed(*args): - ''' - alias name of paddle.nn.functional.diag_embed - ''' - return functional.diag_embed(*args) - - -@deprecated( - since="2.0.0", - update_to="paddle.nn.utils.remove_weight_norm", - level=1, - reason="remove_weight_norm in paddle.nn will be removed in future", -) -def remove_weight_norm(*args): - ''' - alias name of paddle.nn.utils.remove_weight_norm - ''' - return utils.remove_weight_norm(*args) - - -@deprecated( - since="2.0.0", - update_to="paddle.nn.utils.weight_norm", - level=1, - reason="weight_norm in paddle.nn will be removed in future", -) -def weight_norm(*args): - ''' - alias name of paddle.nn.utils.weight_norm - ''' - return utils.weight_norm(*args) - __all__ = [ 'BatchNorm', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 608587becd9522..453627b4cf0494 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -183,7 +183,6 @@ 'log_softmax', 'glu', 'gumbel_softmax', - 'diag_embed', 'sequence_mask', 'dropout', 'dropout2d', diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index c74748793a4e9d..fa2b447860903b 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -484,7 +484,7 @@ def leaky_relu(x, negative_slope=0.01, name=None): [-0.02000000, 0. , 1. ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.leaky_relu(x, negative_slope) else: check_variable_and_dtype( diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index 138146f376aeeb..c1e1d8f8137dbb 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -463,7 +463,7 @@ def conv1d( squeeze_aixs = -3 if channel_last else -2 x = unsqueeze(x, axis=[squeeze_aixs]) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if l_type == 'conv2d': out = _C_ops.conv2d( x, diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 757c9059efdd69..f52a1334331208 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -14,141 +14,26 @@ # TODO: define the extention functions -import numpy as np -from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode +from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode, tensor +from paddle.utils import deprecated -from ...base.data_feeder import ( - check_dtype, - check_type, - check_variable_and_dtype, -) +from ...base.data_feeder import check_type, check_variable_and_dtype from ...base.layer_helper import LayerHelper from ...common_ops_import import Variable from ...framework import convert_np_dtype_to_dtype_, core -from ...tensor.creation import assign __all__ = [] +@deprecated( + since="2.5.2", + update_to="paddle.diag_embed", + level=1, + reason="diag_embed in paddle.nn.functional will be removed in future", +) def diag_embed(input, offset=0, dim1=-2, dim2=-1): - """ - Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) - are filled by ``input``. By default, a 2D plane formed by the last two dimensions - of the returned tensor will be selected. - - The argument ``offset`` determines which diagonal is generated: - - - If offset = 0, it is the main diagonal. - - If offset > 0, it is above the main diagonal. - - If offset < 0, it is below the main diagonal. - - Args: - input(Tensor|numpy.ndarray): The input tensor. Must be at least 1-dimensional. The input data type should be float32, float64, int32, int64. - offset(int, optional): Which diagonal to consider. Default: 0 (main diagonal). - dim1(int, optional): The first dimension with respect to which to take diagonal. Default: -2. - dim2(int, optional): The second dimension with respect to which to take diagonal. Default: -1. - - Returns: - Tensor, the output data type is the same as input data type. - - Examples: - .. code-block:: python - - >>> import paddle - >>> import paddle.nn.functional as F - - >>> diag_embed_input = paddle.arange(6) - - >>> diag_embed_output1 = F.diag_embed(diag_embed_input) - >>> print(diag_embed_output1) - Tensor(shape=[6, 6], dtype=int64, place=Place(cpu), stop_gradient=True, - [[0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0], - [0, 0, 2, 0, 0, 0], - [0, 0, 0, 3, 0, 0], - [0, 0, 0, 0, 4, 0], - [0, 0, 0, 0, 0, 5]]) - - >>> diag_embed_output2 = F.diag_embed(diag_embed_input, offset=-1, dim1=0,dim2=1 ) - >>> print(diag_embed_output2) - Tensor(shape=[7, 7], dtype=int64, place=Place(cpu), stop_gradient=True, - [[0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0], - [0, 0, 2, 0, 0, 0, 0], - [0, 0, 0, 3, 0, 0, 0], - [0, 0, 0, 0, 4, 0, 0], - [0, 0, 0, 0, 0, 5, 0]]) - - >>> diag_embed_input_2dim = paddle.reshape(diag_embed_input,[2,3]) - >>> print(diag_embed_input_2dim) - Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, - [[0, 1, 2], - [3, 4, 5]]) - >>> diag_embed_output3 = F.diag_embed(diag_embed_input_2dim,offset= 0, dim1=0, dim2=2 ) - >>> print(diag_embed_output3) - Tensor(shape=[3, 2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, - [[[0, 0, 0], - [3, 0, 0]], - [[0, 1, 0], - [0, 4, 0]], - [[0, 0, 2], - [0, 0, 5]]]) - """ - if not isinstance(input, Variable): - input = assign(input) - - if in_dynamic_mode(): - return _C_ops.diag_embed(input, offset, dim1, dim2) - - inputs = {'Input': [input]} - attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} - - def __check_input(input, offset, dim1, dim2): - check_dtype( - input.dtype, - 'Input', - ['int32', 'int64', 'float16', 'float32', 'float64'], - 'diag_embed', - ) - - input_shape = list(input.shape) - assert len(input_shape) >= 1, ( - "Input must be at least 1-dimensional, " - "But received Input's dimensional: %s.\n" % len(input_shape) - ) - - assert np.abs(dim1) <= len(input_shape), ( - "Dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" - % (-(len(input_shape) + 1), len(input_shape), dim1) - ) - - assert np.abs(dim2) <= len(input_shape), ( - "Dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" - % (-(len(input_shape) + 1), len(input_shape), dim2) - ) - - dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1 + 1 - dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2 + 1 - assert dim1_ != dim2_, ( - "dim1 and dim2 cannot be the same dimension." - "But received dim1 = %d, dim2 = %d\n" % (dim1, dim2) - ) - - __check_input(input, offset, dim1, dim2) - helper = LayerHelper("diag_embed", **locals()) - - out = helper.create_variable_for_type_inference(dtype=input.dtype) - - helper.append_op( - type='diag_embed', - inputs={'Input': [input]}, - attrs={'offset': offset, 'dim1': dim1, 'dim2': dim2}, - outputs={'Out': [out]}, - ) - out.stop_gradient = True - return out + return tensor.diag_embed(input, offset, dim1, dim2) def sequence_mask(x, maxlen=None, dtype='int64', name=None): diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 11b85df5d1377a..98da4e717feb3f 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -45,6 +45,15 @@ def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True): g_enable_mem_efficient = original_enable_mem_efficient +# special for XPU device +def get_triangle_upper_mask(x): + mask = paddle.full_like(x, -1e4) + mask.stop_gradient = True + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + def _math_attention( query, key, @@ -65,11 +74,19 @@ def _math_attention( product = paddle.matmul( x=query * (head_dim**-0.5), y=key, transpose_y=True ) - weights = ( - paddle.incubate.softmax_mask_fuse_upper_triangle(product) - if causal - else F.softmax(product) - ) + + if not causal: + weights = F.softmax(product) + else: + # special for XPU device + place = paddle.get_device() + if "xpu" in place: + # softmax_mask_fuse_upper_triangle is not supported on XPU, use plain implementation + mask = get_triangle_upper_mask(product) + product = product + mask + weights = F.softmax(product) + else: + weights = paddle.incubate.softmax_mask_fuse_upper_triangle(product) if dropout_rate > 0.0: weights = F.dropout( weights, dropout_rate, training=training, mode="upscale_in_train" @@ -183,10 +200,22 @@ def flash_attention( >>> import paddle - >>> paddle.seed(1) + >>> paddle.seed(2023) >>> q = paddle.rand((1, 128, 2, 16)) >>> output = paddle.nn.functional.flash_attention.flash_attention(q, q, q, 0.9, False, False) + >>> print(output) + (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[[0.34992966, 0.34456208, 0.45826620, ..., 0.39883569, + 0.42132431, 0.39157745], + [0.76687670, 0.65837246, 0.69117945, ..., 0.82817286, + 0.76690865, 0.71485823]], + ..., + [[0.71662450, 0.57275224, 0.57053083, ..., 0.48108247, + 0.53336465, 0.54540104], + [0.59137970, 0.51350880, 0.50449550, ..., 0.38860250, + 0.40526697, 0.60541755]]]]), None) + """ head_dim = query.shape[3] sdp_func_name = _select_sdp(head_dim) @@ -340,11 +369,12 @@ def flash_attn_unpadded( .. code-block:: python >>> import paddle - >>> paddle.seed(1) - >>> q = paddle.rand((1, 128, 2, 16)) + >>> paddle.seed(2023) + >>> q = paddle.rand((2, 128, 8, 16), dtype='float16') + >>> cu = paddle.arange(0, 384, 128, dtype='int32') + >>> qq = paddle.reshape(q, [256, 8, 16]) + >>> output = paddle.nn.functional.flash_attention.flash_attn_unpadded(qq, qq, qq, cu, cu, 128, 128, 0.25, 0.0, False, False) - >>> output = paddle.nn.functional.flash_attention.flash_attn_unpadded(q, q, q, 0.9, False, False) - >>> print(output) """ if in_dynamic_mode(): ( @@ -461,7 +491,7 @@ def scaled_dot_product_attention( Examples: .. code-block:: python - >>> # doctest: +SKIP() + >>> # doctest: +SKIP('bfloat need V100 compile') >>> import paddle >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16) >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e74e67d83f88eb..f9bb57d616a5df 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -713,7 +713,7 @@ def binary_cross_entropy_with_logits( logit, label, weight=None, reduction='mean', pos_weight=None, name=None ): r""" - Combine the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer. + Combine the sigmoid layer and the :ref:`api_paddle_nn_BCELoss` layer. This measures the element-wise probability error in classification tasks in which each class is independent. @@ -1337,13 +1337,13 @@ def l1_loss(input, label, reduction='mean', name=None): check_variable_and_dtype( input, 'input', - ['float32', 'float64', 'int32', 'int64'], + ['float32', 'float64', 'int32', 'int64', 'float16'], 'l1_loss', ) check_variable_and_dtype( label, 'label', - ['float32', 'float64', 'int32', 'int64'], + ['float32', 'float64', 'int32', 'int64', 'float16'], 'l1_loss', ) diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index ed77c07ffdb457..8218a8c67ca563 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -116,8 +116,8 @@ def batch_norm( x, running_mean, running_var, - weight, - bias, + weight=None, + bias=None, training=False, momentum=0.9, epsilon=1e-05, @@ -134,8 +134,8 @@ def batch_norm( x(Tesnor): input value. It's data type should be float32, float64. running_mean(Tensor): running mean. running_var(Tensor): running variance. - weight(Tensor): The weight tensor of batch_norm, can not be None. - bias(Tensor): The bias tensor of batch_norm can not be None. + weight(Tensor, optional): The weight tensor of batch_norm. Default: None. + bias(Tensor, optional): The bias tensor of batch_norm. Default: None. epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Default False. momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. @@ -229,12 +229,14 @@ def batch_norm( inputs = { "X": [x], - "Scale": [weight], - "Bias": [bias], "Mean": [running_mean], "Variance": [running_var], } + if weight: + inputs['Scale'] = [weight] + if bias: + inputs['Bias'] = [bias] helper = LayerHelper('batch_norm', **locals()) from paddle.base.data_feeder import convert_dtype diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 6e53daa02cddb8..3e78f585881b7c 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -189,7 +189,7 @@ def avg_pool1d( ): """ This API implements average pooling 1d operation, - See more details in :ref:`api_nn_pooling_AvgPool1d` . + See more details in :ref:`api_paddle_nn_AvgPool1d` . Args: x (Tensor): The input tensor of pooling operator which is a 3-D tensor with @@ -312,7 +312,7 @@ def avg_pool2d( ): """ This API implements average pooling 2d operation. - See more details in :ref:`api_nn_pooling_AvgPool2d` . + See more details in :ref:`api_paddle_nn_AvgPool2d` . Args: x (Tensor): The input tensor of pooling operator which is a 4-D tensor with @@ -445,7 +445,7 @@ def avg_pool3d( ): """ This API implements average pooling 3d operation. - See more details in :ref:`api_nn_pooling_AvgPool3d` . + See more details in :ref:`api_paddle_nn_AvgPool3d` . Args: x (Tensor): The input tensor of pooling operator, which is a 5-D tensor with @@ -572,7 +572,7 @@ def max_pool1d( ): """ This API implements max pooling 1d opereation. - See more details in :ref:`api_nn_pooling_MaxPool1d` . + See more details in :ref:`api_paddle_nn_MaxPool1d` . Args: x (Tensor): The input tensor of pooling operator which is a 3-D tensor with @@ -1184,7 +1184,7 @@ def max_pool2d( ): """ This API implements max pooling 2d operation. - See more details in :ref:`api_nn_pooling_MaxPool2d` . + See more details in :ref:`api_paddle_nn_MaxPool2d` . Args: x (Tensor): The input tensor of pooling operator which is a 4-D tensor with @@ -1484,7 +1484,7 @@ def adaptive_avg_pool1d(x, output_size, name=None): Adaptive average pooling 1d operation on :attr:`x` according to :attr:`output_size`. Notes: - See more details in :ref:`api_nn_pooling_AdaptiveAvgPool1d` . + See more details in :ref:`api_paddle_nn_AdaptiveAvgPool1d` . Args: x (Tensor): The input Tensor of pooling, which is a 3-D tensor with shape :math:`[N, C, L]`, where :math:`N` is batch size, :math:`C` is the number of channels and :math:`L` is the length of the feature. The data type is float32 or float64. @@ -1825,7 +1825,7 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None): def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): """ This API implements adaptive max pooling 1d operation. - See more details in :ref:`api_nn_pooling_AdaptiveMaxPool1d` . + See more details in :ref:`api_paddle_nn_AdaptiveMaxPool1d` . Args: x (Tensor): The input tensor of pooling operator, which is a 3-D tensor @@ -1921,7 +1921,7 @@ def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): """ This operation applies a 2D adaptive max pooling on input tensor. - See more details in :ref:`api_nn_pooling_AdaptiveMaxPool2d` . + See more details in :ref:`api_paddle_nn_AdaptiveMaxPool2d` . Args: x (Tensor): The input tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type can be float16, float32, float64, int32 or int64. @@ -2007,7 +2007,7 @@ def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): """ This operation applies a 3D adaptive max pooling on input tensor. - See more details in :ref:`api_nn_pooling_AdaptiveMaxPool3d` . + See more details in :ref:`api_paddle_nn_AdaptiveMaxPool3d` . Args: x (Tensor): The input tensor of adaptive max pool3d operator, which is a 5-D tensor. The data type can be float32, float64. diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index a690daab0ef211..2f40653d193bcd 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -727,46 +727,25 @@ def __init__( param_shape = [num_features] # create parameter - if weight_attr is False: - self.weight = self.create_parameter( - attr=None, - shape=param_shape, - dtype=self._dtype, - default_initializer=Constant(1.0), - ) - self.weight.stop_gradient = True - else: + if weight_attr is not False: self.weight = self.create_parameter( attr=self._weight_attr, shape=param_shape, dtype=self._dtype, default_initializer=Constant(1.0), ) - self.weight.stop_gradient = ( - self._weight_attr is not None - and self._weight_attr.learning_rate == 0.0 - ) - if bias_attr is False: - self.bias = self.create_parameter( - attr=None, - shape=param_shape, - dtype=self._dtype, - default_initializer=Constant(0.0), - is_bias=True, - ) - self.bias.stop_gradient = True else: + self.weight = None + if bias_attr is not False: self.bias = self.create_parameter( attr=self._bias_attr, shape=param_shape, dtype=self._dtype, is_bias=True, ) - self.bias.stop_gradient = ( - self._bias_attr is not None - and self._bias_attr.learning_rate == 0.0 - ) + else: + self.bias = None moving_mean_name = None moving_variance_name = None @@ -992,10 +971,6 @@ def __init__( self._act = act self._use_mkldnn = _global_flags()["FLAGS_use_mkldnn"] - assert ( - bias_attr is not False - ), "bias_attr should not be False in batch_norm." - if dtype == "float16": self._dtype = "float32" else: @@ -1004,25 +979,24 @@ def __init__( param_shape = [num_channels] # create parameter - self.weight = self.create_parameter( - attr=self._param_attr, - shape=param_shape, - dtype=self._dtype, - default_initializer=Constant(1.0), - ) - self.weight.stop_gradient = ( - use_global_stats and self._param_attr.learning_rate == 0.0 - ) - - self.bias = self.create_parameter( - attr=self._bias_attr, - shape=param_shape, - dtype=self._dtype, - is_bias=True, - ) - self.bias.stop_gradient = ( - use_global_stats and self._param_attr.learning_rate == 0.0 - ) + if param_attr is not False: + self.weight = self.create_parameter( + attr=self._param_attr, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(1.0), + ) + else: + self.weight = None + if bias_attr is not False: + self.bias = self.create_parameter( + attr=self._bias_attr, + shape=param_shape, + dtype=self._dtype, + is_bias=True, + ) + else: + self.bias = None self._mean = self.create_parameter( attr=ParamAttr( @@ -1610,6 +1584,24 @@ def __init__( None, name, ) + param_shape = [num_features] + if weight_attr is False: + self.weight = self.create_parameter( + attr=None, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(1.0), + ) + self.weight.stop_gradient = True + if bias_attr is False: + self.bias = self.create_parameter( + attr=None, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(0.0), + is_bias=True, + ) + self.bias.stop_gradient = True def _check_data_format(self): if self._data_format in ['NCHW', 'NCDHW', 'NC', 'NCL']: diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index ad7ecbc4a1cd24..6f0acfaedbbbb1 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -368,6 +368,28 @@ def __impl__(self, other_var): '__rtruediv__', _binary_creator_('__rtruediv__', paddle.tensor.divide, True, None), ), + ( + '__pow__', + _binary_creator_('__pow__', paddle.tensor.pow, False, None), + ), + ( + '__rpow__', + _binary_creator_('__rpow__', paddle.tensor.pow, True, None), + ), + ( + '__floordiv__', + _binary_creator_( + '__floordiv__', paddle.tensor.floor_divide, False, None + ), + ), + ( + '__mod__', + _binary_creator_('__mod__', paddle.tensor.remainder, False, None), + ), + ( + '__matmul__', + _binary_creator_('__matmul__', paddle.tensor.matmul, False, None), + ), ] global _already_patch_opresult diff --git a/python/paddle/quantization/config.py b/python/paddle/quantization/config.py index 28feb8c6b087ff..bafc24488f0898 100644 --- a/python/paddle/quantization/config.py +++ b/python/paddle/quantization/config.py @@ -127,7 +127,7 @@ def add_layer_config( >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) >>> q_config = QuantConfig(activation=None, weight=None) >>> q_config.add_layer_config([model.fc], activation=quanter, weight=quanter) - >>> # doctest: +SKIP + >>> # doctest: +SKIP('random memory address') >>> print(q_config) Global config: None @@ -176,7 +176,7 @@ def add_name_config( >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) >>> q_config = QuantConfig(activation=None, weight=None) >>> q_config.add_name_config([model.fc.full_name()], activation=quanter, weight=quanter) - >>> # doctest: +SKIP + >>> # doctest: +SKIP('random memory address') >>> print(q_config) Global config: None @@ -226,7 +226,7 @@ def add_type_config( >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) >>> q_config = QuantConfig(activation=None, weight=None) >>> q_config.add_type_config([Linear], activation=quanter, weight=quanter) - >>> # doctest: +SKIP + >>> # doctest: +SKIP('random memory address') >>> print(q_config) Global config: None diff --git a/python/paddle/quantization/factory.py b/python/paddle/quantization/factory.py index b0ef9062201864..eb8916460975c8 100644 --- a/python/paddle/quantization/factory.py +++ b/python/paddle/quantization/factory.py @@ -83,7 +83,7 @@ def quanter(class_name): Examples: .. code-block:: python - >>> # doctest: +SKIP + >>> # doctest: +SKIP('need 2 file to run example') >>> # Given codes in ./customized_quanter.py >>> from paddle.quantization import quanter >>> from paddle.quantization import BaseQuanter diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 0d64440f2426ff..45f6a7c4bdb7fc 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -1224,9 +1224,6 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): with true_cond_block.block(): origin_true_output = true_fn() if origin_true_output is not None: - origin_true_output = map_structure( - create_undefined_var_in_subblock, origin_true_output - ) true_output = map_structure( copy_to_parent_func, origin_true_output ) @@ -1243,9 +1240,6 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): with false_cond_block.block(): origin_false_output = false_fn() if origin_false_output is not None: - origin_false_output = map_structure( - create_undefined_var_in_subblock, origin_false_output - ) false_output = map_structure( copy_to_parent_func, origin_false_output ) @@ -1363,18 +1357,6 @@ def merge_every_var_list(false_vars, true_vars, name): return merged_output -def create_undefined_var_in_subblock(var): - # to make sure the undefined var created in subblock. - from paddle.jit.dy2static.utils import ( - UndefinedVar, - create_undefined_variable_local, - ) - - if isinstance(var, UndefinedVar): - var = create_undefined_variable_local() - return var - - def copy_var_to_parent_block(var, layer_helper): if not isinstance(var, Variable): return var diff --git a/python/paddle/static/nn/metric.py b/python/paddle/static/nn/metric.py index f9941c47447231..672bc80ece9267 100644 --- a/python/paddle/static/nn/metric.py +++ b/python/paddle/static/nn/metric.py @@ -17,9 +17,14 @@ import numpy as np import paddle -from paddle import _legacy_C_ops +from paddle import _C_ops, _legacy_C_ops from paddle.base.data_feeder import check_variable_and_dtype -from paddle.base.framework import Variable, _create_tensor, in_dygraph_mode +from paddle.base.framework import ( + Variable, + _create_tensor, + in_dygraph_mode, + in_pir_mode, +) from paddle.base.layer_helper import LayerHelper from paddle.nn.initializer import ConstantInitializer @@ -88,6 +93,10 @@ def accuracy(input, label, k=1, correct=None, total=None): topk_out, topk_indices, label, correct, total ) return _acc + elif in_pir_mode(): + topk_out, topk_indices = paddle.topk(input, k=k, sorted=False) + _acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label) + return _acc helper = LayerHelper("accuracy", **locals()) check_variable_and_dtype( diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index ce4cfc8ee883ba..c8bfe99f91e6b3 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -24,6 +24,7 @@ from .creation import to_tensor # noqa: F401 from .creation import diag # noqa: F401 from .creation import diagflat # noqa: F401 +from .creation import diag_embed # noqa: F401 from .creation import eye # noqa: F401 from .creation import linspace # noqa: F401 from .creation import fill_constant # noqa: F401 @@ -694,6 +695,7 @@ 'i1e', 'polygamma', 'polygamma_', + 'diag_embed', 'atan2', 'diagflat', 'multinomial', diff --git a/python/paddle/tensor/attribute.py b/python/paddle/tensor/attribute.py index f3dcaf06cd9bf0..8bc7cff200b344 100644 --- a/python/paddle/tensor/attribute.py +++ b/python/paddle/tensor/attribute.py @@ -20,11 +20,7 @@ from paddle import _C_ops from ..base.data_feeder import check_type, check_variable_and_dtype -from ..base.framework import ( - in_dygraph_mode, - in_dynamic_or_pir_mode, - in_pir_mode, -) +from ..base.framework import in_dynamic_or_pir_mode, in_pir_mode from ..common_ops_import import Variable from ..framework import LayerHelper, core from .creation import _complex_to_real_dtype, assign @@ -300,7 +296,7 @@ def real(x, name=None): [[1., 2., 3.], [4., 5., 6.]]) """ - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.real(x) else: check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'real') @@ -348,7 +344,7 @@ def imag(x, name=None): [[6., 5., 4.], [3., 2., 1.]]) """ - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.imag(x) else: check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'imag') diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index e4f0ea824e3a41..5e99753516b0bb 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1674,6 +1674,125 @@ def meshgrid(*args, **kwargs): return out +def diag_embed(input, offset=0, dim1=-2, dim2=-1): + """ + Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) + are filled by ``input``. By default, a 2D plane formed by the last two dimensions + of the returned tensor will be selected. + + The argument ``offset`` determines which diagonal is generated: + + - If offset = 0, it is the main diagonal. + - If offset > 0, it is above the main diagonal. + - If offset < 0, it is below the main diagonal. + + Args: + input(Tensor|numpy.ndarray): The input tensor. Must be at least 1-dimensional. The input data type should be float32, float64, int32, int64. + offset(int, optional): Which diagonal to consider. Default: 0 (main diagonal). + dim1(int, optional): The first dimension with respect to which to take diagonal. Default: -2. + dim2(int, optional): The second dimension with respect to which to take diagonal. Default: -1. + + Returns: + Tensor, the output data type is the same as input data type. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> diag_embed_input = paddle.arange(6) + + >>> diag_embed_output1 = paddle.diag_embed(diag_embed_input) + >>> print(diag_embed_output1) + Tensor(shape=[6, 6], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 2, 0, 0, 0], + [0, 0, 0, 3, 0, 0], + [0, 0, 0, 0, 4, 0], + [0, 0, 0, 0, 0, 5]]) + + >>> diag_embed_output2 = paddle.diag_embed(diag_embed_input, offset=-1, dim1=0,dim2=1 ) + >>> print(diag_embed_output2) + Tensor(shape=[7, 7], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 2, 0, 0, 0, 0], + [0, 0, 0, 3, 0, 0, 0], + [0, 0, 0, 0, 4, 0, 0], + [0, 0, 0, 0, 0, 5, 0]]) + + >>> diag_embed_input_2dim = paddle.reshape(diag_embed_input,[2,3]) + >>> print(diag_embed_input_2dim) + Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 1, 2], + [3, 4, 5]]) + >>> diag_embed_output3 = paddle.diag_embed(diag_embed_input_2dim,offset= 0, dim1=0, dim2=2 ) + >>> print(diag_embed_output3) + Tensor(shape=[3, 2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[[0, 0, 0], + [3, 0, 0]], + [[0, 1, 0], + [0, 4, 0]], + [[0, 0, 2], + [0, 0, 5]]]) + """ + if not isinstance(input, Variable): + input = assign(input) + + if in_dynamic_mode(): + return _C_ops.diag_embed(input, offset, dim1, dim2) + + inputs = {'Input': [input]} + attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} + + def __check_input(input, offset, dim1, dim2): + check_dtype( + input.dtype, + 'Input', + ['int32', 'int64', 'float16', 'float32', 'float64'], + 'diag_embed', + ) + + input_shape = list(input.shape) + assert len(input_shape) >= 1, ( + "Input must be at least 1-dimensional, " + "But received Input's dimensional: %s.\n" % len(input_shape) + ) + + assert np.abs(dim1) <= len(input_shape), ( + "Dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" + % (-(len(input_shape) + 1), len(input_shape), dim1) + ) + + assert np.abs(dim2) <= len(input_shape), ( + "Dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" + % (-(len(input_shape) + 1), len(input_shape), dim2) + ) + + dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1 + 1 + dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2 + 1 + assert dim1_ != dim2_, ( + "dim1 and dim2 cannot be the same dimension." + "But received dim1 = %d, dim2 = %d\n" % (dim1, dim2) + ) + + __check_input(input, offset, dim1, dim2) + helper = LayerHelper("diag_embed", **locals()) + + out = helper.create_variable_for_type_inference(dtype=input.dtype) + + helper.append_op( + type='diag_embed', + inputs={'Input': [input]}, + attrs={'offset': offset, 'dim1': dim1, 'dim2': dim2}, + outputs={'Out': [out]}, + ) + out.stop_gradient = True + return out + + def diagflat(x, offset=0, name=None): """ If ``x`` is a vector (1-D tensor), a 2-D square tensor with the elements of ``x`` as the diagonal is returned. diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 71016a2208c154..f16cc3ad8d3a05 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -382,7 +382,7 @@ def frobenius_norm(input, dim=None, keepdim=False, name=None): "The dim of frobenius norm op should be None or two elements list!" ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if dim is None: return _C_ops.frobenius_norm(input, [], keepdim, True) return _C_ops.frobenius_norm(input, dim, keepdim, False) @@ -1375,7 +1375,7 @@ def t(input, name=None): "length of Input(input) is %s. Perhaps you can use paddle." "tensor.transpose() instead." % len(input.shape) ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if len(input.shape) <= 1: return input # 2-D tensor @@ -1539,7 +1539,7 @@ def cholesky(x, upper=False, name=None): [1.06467664, 0.17859250, 0. ], [1.30602181, 0.08326444, 0.22790681]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cholesky(x, upper) else: check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cholesky') @@ -1958,14 +1958,12 @@ def slogdet(x, name=None): >>> import paddle >>> paddle.seed(2023) - >>> x = paddle.randn([3,3,3]) + >>> x = paddle.randn([3, 3, 3]) >>> A = paddle.linalg.slogdet(x) >>> print(A) - >>> # doctest: +SKIP Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, [[-1. , 1. , 1. ], [ 0.25681755, -0.25061053, -0.10809582]]) - >>> # doctest: -SKIP """ if in_dynamic_mode(): @@ -2801,10 +2799,12 @@ def eigh(x, UPLO='L', name=None): property. For more information, please refer to :ref:`api_guide_Name`. Returns: - - out_value(Tensor): A Tensor with shape [*, N] and data type of float32 and float64. - The eigenvalues of eigh op. - - out_vector(Tensor): A Tensor with shape [*, N, N] and data type of float32,float64, - complex64 and complex128. The eigenvectors of eigh op. + 2-element tuple containing + + - out_value(Tensor): A Tensor with shape :math:`[*, N]` and data type of float32 and float64. + The eigenvalues of eigh op. + - out_vector(Tensor): A Tensor with shape :math:`[*, N, N]` and data type of float32, float64, + complex64 and complex128. The eigenvectors of eigh op. Examples: .. code-block:: python @@ -3283,7 +3283,7 @@ def cholesky_solve(x, y, upper=False, name=None): [-7. ], [ 9.50000000]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cholesky_solve(x, y, upper) else: helper = LayerHelper("cholesky_solve", **locals()) diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 9b50993b891667..865b2a62d4ca21 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -139,7 +139,7 @@ def logical_and(x, y, out=None, name=None): [True , False, True , False]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logical_and(x, y) return _logical_op( @@ -413,7 +413,7 @@ def equal_all(x, y, name=None): Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True, False) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.equal_all(x, y) else: helper = LayerHelper("equal_all", **locals()) @@ -1213,7 +1213,7 @@ def bitwise_or(x, y, out=None, name=None): Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [-1, -1, -3]) """ - if in_dynamic_mode() and out is None: + if in_dynamic_or_pir_mode() and out is None: return _C_ops.bitwise_or(x, y) return _bitwise_op( @@ -1272,7 +1272,7 @@ def bitwise_xor(x, y, out=None, name=None): Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [-1, -3, -4]) """ - if in_dynamic_mode() and out is None: + if in_dynamic_or_pir_mode() and out is None: return _C_ops.bitwise_xor(x, y) return _bitwise_op( op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True @@ -1328,7 +1328,7 @@ def bitwise_not(x, out=None, name=None): Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [ 4, 0, -2]) """ - if in_dynamic_mode() and out is None: + if in_dynamic_or_pir_mode() and out is None: return _C_ops.bitwise_not(x) return _bitwise_op( diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index ae61880c997bed..aa790e1fd69960 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1413,7 +1413,7 @@ def flip(x, axis, name=None): if isinstance(axis, int): axis = [axis] - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.flip(x, axis) else: helper = LayerHelper("flip", **locals()) @@ -2866,7 +2866,7 @@ def gather(x, index, axis=None, name=None): if axis is None: axis = 0 - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.gather(x, index, axis) else: check_variable_and_dtype( @@ -3430,7 +3430,7 @@ def expand_as(x, y, name=None): [[1, 2, 3], [1, 2, 3]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.expand_as(x, None, y.shape) else: check_variable_and_dtype( @@ -4461,7 +4461,7 @@ def as_complex(x, name=None): [[1j , (2+3j) , (4+5j) ], [(6+7j) , (8+9j) , (10+11j)]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.as_complex(x) else: check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'as_complex') @@ -4512,7 +4512,7 @@ def as_real(x, name=None): [8. , 9. ], [10., 11.]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.as_real(x) else: check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'as_real') @@ -4572,6 +4572,8 @@ def repeat_interleave(x, repeats, axis=None, name=None): [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]) """ + if isinstance(repeats, Variable) and not repeats.shape: + repeats = paddle.reshape(repeats, [1]) if axis is None: x = paddle.flatten(x) axis = 0 @@ -4688,7 +4690,7 @@ def moveaxis(x, source, destination, name=None): for i in range(len(src_dims)): perm[dst_dims[i]] = src_dims[i] - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.transpose(x, perm) return out else: diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 4312a6aa55c3ca..c5ed7ff655c5cd 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -20,7 +20,7 @@ import paddle from paddle import _C_ops, _legacy_C_ops -from paddle.common_ops_import import VarDesc, dygraph_only, dygraph_utils +from paddle.common_ops_import import VarDesc, dygraph_utils from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ..base.data_feeder import ( @@ -130,7 +130,7 @@ def _get_reduce_axis(axis, x): def _get_reduce_axis_with_tensor(axis, x): - if isinstance(axis, Variable): + if isinstance(axis, (Variable, paddle.pir.OpResult)): if axis.shape[0] == len(x.shape): reduce_all = True else: @@ -941,7 +941,7 @@ def floor_divide(x, y, name=None): [2, 0, 2, 2]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.floor_divide(x, y) else: return _elementwise_op(LayerHelper('elementwise_floordiv', **locals())) @@ -1105,13 +1105,10 @@ def multiply_(x, y, name=None): return _C_ops.multiply_(x, y) -@dygraph_only -def _elementwise_op_with_axis_in_dygraph( - x, y, axis=-1, name=None, op_type="Undifined" -): +def _elementwise_op_with_axis(x, y, axis=-1, name=None, op_type="Undifined"): assert ( - in_dynamic_mode() - ), "You can only call `_elementwise_op_with_axis_in_dygraph` function within in_dynamic_mode" + in_dynamic_or_pir_mode() + ), "You can only call `_elementwise_op_with_axis` function within in_dynamic_or_pir_mode" assert op_type in ["add", "subtract", "multiply", "divide"], ( "op_name input error! _elementwise_op_with_axis is an inner function to replace elementwise_add/sub/mul/div. Input op_name=%s, Expect op_name=[add|subtract|multiply|divide]\n" % op_type @@ -1132,8 +1129,8 @@ def _elementwise_op_with_axis_in_dygraph( def _add_with_axis(x, y, axis=-1, name=None): # opt performance, only dynamic mode needs reshape - if in_dynamic_mode(): - return _elementwise_op_with_axis_in_dygraph(x, y, axis, name, "add") + if in_dynamic_or_pir_mode(): + return _elementwise_op_with_axis(x, y, axis, name, "add") else: op_type = 'elementwise_add' return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1142,9 +1139,7 @@ def _add_with_axis(x, y, axis=-1, name=None): def _subtract_with_axis(x, y, axis=-1, name=None): # opt performance, only dynamic mode needs reshape if in_dynamic_mode(): - return _elementwise_op_with_axis_in_dygraph( - x, y, axis, name, "subtract" - ) + return _elementwise_op_with_axis(x, y, axis, name, "subtract") else: op_type = 'elementwise_sub' return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1153,9 +1148,7 @@ def _subtract_with_axis(x, y, axis=-1, name=None): def _multiply_with_axis(x, y, axis=-1, name=None): # opt performance, only dynamic mode needs reshape if in_dynamic_mode(): - return _elementwise_op_with_axis_in_dygraph( - x, y, axis, name, "multiply" - ) + return _elementwise_op_with_axis(x, y, axis, name, "multiply") else: op_type = 'elementwise_mul' return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1164,7 +1157,7 @@ def _multiply_with_axis(x, y, axis=-1, name=None): def _divide_with_axis(x, y, axis=-1, name=None): # opt performance, only dynamic mode needs reshape if in_dynamic_mode(): - return _elementwise_op_with_axis_in_dygraph(x, y, axis, name, "divide") + return _elementwise_op_with_axis(x, y, axis, name, "divide") else: op_type = 'elementwise_div' return _elementwise_op(LayerHelper(op_type, **locals())) @@ -3429,7 +3422,7 @@ def log10(x, name=None): Tensor(shape=[1], dtype=float64, place=Place(cpu), stop_gradient=True, [1.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.log10(x) else: check_variable_and_dtype( @@ -3670,7 +3663,7 @@ def __check_input(x, offset, axis1, axis2): "But received axis1 = %d, axis2 = %d\n" % (axis1, axis2) ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.trace(x, offset, axis1, axis2) else: __check_input(x, offset, axis1, axis2) @@ -3925,7 +3918,7 @@ def cumsum(x, axis=None, dtype=None, name=None): if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): x = cast(x, dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if axis is None: axis = -1 return _C_ops.cumsum(x, axis, flatten, False, False) @@ -4437,7 +4430,7 @@ def isnan(x, name=None): Tensor(shape=[7], dtype=bool, place=Place(cpu), stop_gradient=True, [False, False, False, False, False, True , True ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.isnan(x) else: helper = LayerHelper("isnan_v2", **locals()) @@ -4538,7 +4531,7 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): x = cast(x, dtype) reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.prod(x, axis, keepdim, reduce_all) else: helper = LayerHelper('reduce_prod', **locals()) @@ -4681,7 +4674,7 @@ def increment(x, value=1.0, name=None): [1.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.increment_(x, value) else: check_variable_and_dtype( @@ -5209,7 +5202,7 @@ def logit(x, eps=None, name=None): """ if eps is None: eps = 0.0 - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logit(x, eps) else: check_variable_and_dtype( diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 8fd2473231f931..51f09119ef2e48 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -459,7 +459,7 @@ def nonzero(x, as_tuple=False): shape = x.shape rank = len(shape) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): outs = _C_ops.nonzero(x) else: check_variable_and_dtype( @@ -1005,7 +1005,7 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None): """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if axis is None: axis = -1 out, indices = _C_ops.topk(x, k, axis, largest, sorted) diff --git a/python/paddle/utils/environments.py b/python/paddle/utils/environments.py new file mode 100644 index 00000000000000..1d20d4fb79a420 --- /dev/null +++ b/python/paddle/utils/environments.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class EnvironmentVariable(Generic[T]): + name: str + default: T + + def __init__(self, name: str, default: T): + self.name = name + self.default = default + + def get(self) -> T: + raise NotImplementedError() + + def set(self, value: T) -> None: + raise NotImplementedError() + + def delete(self) -> None: + del os.environ[self.name] + + +class StringEnvironmentVariable(EnvironmentVariable[str]): + def __init__(self, name: str, default: str): + super().__init__(name, default) + assert isinstance(default, str), "default must be a string" + + def get(self) -> str: + return os.getenv(self.name, self.default) + + def set(self, value: str) -> None: + os.environ[self.name] = value + + +class BooleanEnvironmentVariable(EnvironmentVariable[bool]): + BOOLEAN_IS_SET = ("y", "yes", "t", "true", "on", "1") + + def __init__(self, name: str, default: bool): + super().__init__(name, default) + assert isinstance(default, bool), "default must be a boolean" + + def get(self) -> bool: + default = str(self.default).lower() + env_str = os.getenv(self.name, default).lower() + return env_str in BooleanEnvironmentVariable.BOOLEAN_IS_SET + + def set(self, value: bool) -> None: + os.environ[self.name] = str(value).lower() + + +class IntegerEnvironmentVariable(EnvironmentVariable[int]): + def __init__(self, name: str, default: int): + super().__init__(name, default) + assert isinstance(default, int), "default must be an integer" + + def get(self) -> int: + try: + return int(os.getenv(self.name, str(self.default))) + except ValueError: + return self.default + + def set(self, value: int) -> None: + os.environ[self.name] = str(value) + + +class EnvironmentVariableGuard(Generic[T]): + variable: EnvironmentVariable[T] + original_value: T + + def __init__(self, variable: EnvironmentVariable[T], value: T): + self.variable = variable + self.original_value = variable.get() + self.variable.set(value) + + def __enter__(self) -> EnvironmentVariableGuard: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.variable.set(self.original_value) diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 8700ab2e070744..a5513adafb6cc9 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -10,6 +10,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) test_auto_parallel_relaunch) set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + py_test_modules(test_mp_allreduce_matmul_grad_overlapping MODULES + test_mp_allreduce_matmul_grad_overlapping) + set_tests_properties(test_mp_allreduce_matmul_grad_overlapping + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner) set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) diff --git a/test/auto_parallel/mp_allreduce_matmul_grad_overlapping_unittest.py b/test/auto_parallel/mp_allreduce_matmul_grad_overlapping_unittest.py new file mode 100644 index 00000000000000..2945dd1b311518 --- /dev/null +++ b/test/auto_parallel/mp_allreduce_matmul_grad_overlapping_unittest.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np +from get_gpt_model import FakeDataset, generate_model + +import paddle +from paddle.distributed.fleet import auto + +paddle.enable_static() + + +def reset_prog(): + paddle.base.framework.switch_main_program(paddle.static.Program()) + paddle.base.framework.switch_startup_program(paddle.static.Program()) + + +class TestMPAllreduceMatmulGradOverlapping(unittest.TestCase): + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 1 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2023) + np.random.seed(2023) + random.seed(2023) + place = paddle.base.CUDAPlace(paddle.distributed.ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_mp_engine(self, allreduce_matmul_grad_overlapping): + reset_prog() + + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + strategy.mp_optimization.allreduce_matmul_grad_overlapping = ( + allreduce_matmul_grad_overlapping + ) + + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("mp") + + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def run_mp(self, allreduce_matmul_grad_overlapping): + mp_engine = self.get_mp_engine(allreduce_matmul_grad_overlapping) + history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + return np.array(history.history["loss"]) + + def check_results(self, ref_losses, check_losses, rtol=None, atol=None): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=rtol or self.rtol, + atol=atol or self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses + ), + ) + + def test_mp_allreduce_matmul_grad_overlapping(self): + losses_with_allreduce_matmul_grad_overlapping = self.run_mp(True) + losses_without_allreduce_matmul_grad_overlapping = self.run_mp(False) + + np.testing.assert_equal( + losses_with_allreduce_matmul_grad_overlapping, + losses_without_allreduce_matmul_grad_overlapping, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/reshard_api.py b/test/auto_parallel/reshard_api.py index c77cb9b773cac1..5ad046080fa8ff 100644 --- a/test/auto_parallel/reshard_api.py +++ b/test/auto_parallel/reshard_api.py @@ -18,6 +18,7 @@ import paddle import paddle.distributed as dist +from paddle import nn class TestReshardAPI: @@ -33,6 +34,8 @@ def run_test_cases(self): if self._backend == "cpu": paddle.set_device("cpu") self.test_case_p_to_r() + self.test_case_r_to_s() + self.test_case_forward_and_backward() def test_case_p_to_r(self): a = paddle.ones(self._shape) @@ -82,6 +85,68 @@ def test_case_r_to_s(self): assert np.equal(output_tensor.shape, input_tensor.shape).all() assert np.equal(output_tensor._local_shape, out_shape).all() + def test_case_forward_and_backward(self): + if self._backend == "cpu": + return + + np.random.seed(1901) + input_numpy = np.random.random(self._shape).astype("float32") + label_numpy = np.random.random(self._shape).astype('float32') + + in_shard_specs = [None for i in range(len(self._shape))] + out_shard_specs = [None for i in range(len(self._shape))] + out_shard_specs[self._shard] = "x" + + in_dist_attr = dist.DistAttr( + mesh=dist.ProcessMesh([0, 1], dim_names=["x"]), + sharding_specs=in_shard_specs, + ) + + out_dist_attr = dist.DistAttr( + mesh=dist.ProcessMesh([0, 1], dim_names=["x"]), + sharding_specs=out_shard_specs, + ) + + local_input = paddle.to_tensor(input_numpy) + dist_input = dist.shard_tensor( + paddle.to_tensor(input_numpy), dist_attr=in_dist_attr + ) + + local_input.stop_gradient = False + dist_input.stop_gradient = False + + local_output = local_input + paddle.ones(self._shape) + dist_output = dist_input + dist.shard_tensor( + paddle.ones(self._shape), dist_attr=in_dist_attr + ) + dist_output.stop_gradient = False + + dist_output = dist.reshard(dist_output, dist_attr=out_dist_attr) + + local_label = paddle.to_tensor(label_numpy) + dist_label = dist.shard_tensor( + paddle.to_tensor(label_numpy), dist_attr=out_dist_attr + ) + + local_loss_fn = nn.MSELoss() + dist_loss_fn = nn.MSELoss() + + local_loss = local_loss_fn(local_output, local_label) + dist_loss = dist_loss_fn(dist_output, dist_label) + + np.testing.assert_allclose( + local_loss.numpy(), dist_loss.numpy(), rtol=1e-5, atol=1e-5 + ) + + local_loss.backward() + dist_loss.backward() + np.testing.assert_allclose( + local_input.grad.numpy(), + dist_input.grad.numpy(), + rtol=1e-5, + atol=1e-5, + ) + if __name__ == '__main__': TestReshardAPI().run_test_cases() diff --git a/test/auto_parallel/semi_auto_parallel_clear_gradient.py b/test/auto_parallel/semi_auto_parallel_clear_gradient.py new file mode 100644 index 00000000000000..4c755fa70a76b4 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_clear_gradient.py @@ -0,0 +1,56 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +from semi_auto_parallel_simple_net import MPDemoNet + +import paddle +import paddle.distributed as dist +from paddle import nn + +BATCH_SIZE = 16 +BATCH_NUM = 4 +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +def run_dynamic(layer, image, label): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(image) + image.stop_gradient = False + out = layer(image) + + label = paddle.to_tensor(label) + loss = loss_fn(out, label) + + loss.backward() + layer.w0.clear_gradient() + layer.w1.clear_gradient(False) + + +class TestSemiAutoParallelClearGradient: + def test_clear_gradient(): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype('float32') + label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32') + w0 = np.random.random([IMAGE_SIZE, IMAGE_SIZE]).astype('float32') + w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32') + run_dynamic(layer=MPDemoNet(w0, w1, mesh), image=image, label=label) + + +if __name__ == "__main__": + TestSemiAutoParallelClearGradient.test_clear_gradient() diff --git a/test/auto_parallel/semi_auto_parallel_for_reduction.py b/test/auto_parallel/semi_auto_parallel_for_reduction.py new file mode 100644 index 00000000000000..4b2e7d4bb026b2 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_reduction.py @@ -0,0 +1,111 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestReductionApiForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def test_body(self, x_shape, out_shape, x_specs, axis, keepdim, op_func): + paddle.seed(self._seed) + np.random.seed(self._seed) + + x = paddle.randn(x_shape, self._dtype) + x.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_x.stop_gradient = False + + dist_out = op_func(dist_x, axis=axis, keepdim=keepdim) + out = op_func(x, axis=axis, keepdim=keepdim) + self.check_tensor_eq(out, dist_out) + np.testing.assert_equal(dist_out.shape, out_shape, verbose=True) + + dist_out.backward() + out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + + def test_sum_x_shard(self): + self.test_body( + x_shape=[4, 8, 6], + out_shape=[4, 6], + x_specs=['x', None, None], + axis=1, + keepdim=False, + op_func=paddle.sum, + ) + + def test_sum_x_shard_on_axis(self): + self.test_body( + x_shape=[4, 8, 6], + out_shape=[4], + x_specs=[None, 'x', None], + axis=[1, 2], + keepdim=False, + op_func=paddle.sum, + ) + + def test_sum_x_shard_on_axis_keepdim(self): + self.test_body( + x_shape=[4, 8, 6], + out_shape=[4, 1, 6], + x_specs=[None, 'x', None], + axis=1, + keepdim=True, + op_func=paddle.sum, + ) + + def test_mean_x_shard(self): + self.test_body( + x_shape=[4, 8, 6], + out_shape=[8, 6], + x_specs=['x', None, None], + axis=-3, + keepdim=False, + op_func=paddle.mean, + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_sum_x_shard() + self.test_sum_x_shard_on_axis() + self.test_sum_x_shard_on_axis_keepdim() + self.test_mean_x_shard() + + +if __name__ == '__main__': + TestReductionApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py index 3ca9baac5b5082..1c526874093362 100644 --- a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py +++ b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py @@ -67,6 +67,109 @@ def test_unbind(self): dist_out.backward() self.check_tensor_eq(local_in.grad, dist_in.grad) + # input: paddle::optional + # output: phi::Tensor + def test_expand_as(self): + x1 = np.random.random(size=[2, 8]).astype("float32") + x2 = np.random.random(size=[2, 2, 8]).astype("float32") + local_in1, dist_in1 = self.create_local_and_dist_tensor_pair( + x1, ['x', None] + ) + local_in2, dist_in2 = self.create_local_and_dist_tensor_pair( + x2, [None, None, None] + ) + local_out = paddle.expand_as(local_in1, local_in2) + dist_out = paddle.expand_as(dist_in1, dist_in2) + self.check_tensor_eq(local_out, dist_out) + + local_out.backward() + dist_out.backward() + self.check_tensor_eq(local_in1.grad, dist_in1.grad) + + # input: phi::Tensor + # output: inplace paddle::optional + def test_adamax(self): + dtype = np.float32 + mp_dtype = np.float32 + shape = [120, 320] + + beta1 = 0.78 + beta2 = 0.899 + epsilon = 1e-5 + param = np.random.random(shape).astype(dtype) + grad = np.random.random(shape).astype(dtype) + moment = np.random.random(shape).astype(dtype) + inf_norm = np.random.random(shape).astype(dtype) + master_param = param.astype(mp_dtype) + + lr = np.array([0.002]).astype("float32") + beta1_pow = np.array([beta1**10]).astype("float32") + + local_param, dist_param = self.create_local_and_dist_tensor_pair( + param, ['x', None] + ) + local_grad, dist_grad = self.create_local_and_dist_tensor_pair( + grad, ['x', None] + ) + local_lr, dist_lr = self.create_local_and_dist_tensor_pair(lr, [None]) + ( + local_beta1_pow, + dist_beta1_pow, + ) = self.create_local_and_dist_tensor_pair(beta1_pow, [None]) + local_moment, dist_moment = self.create_local_and_dist_tensor_pair( + moment, ['x', None] + ) + local_inf_norm, dist_inf_norm = self.create_local_and_dist_tensor_pair( + inf_norm, ['x', None] + ) + ( + local_master_param, + dist_master_param, + ) = self.create_local_and_dist_tensor_pair(master_param, [None, None]) + + ( + local_param_out, + local_moment_out, + local_inf_norm_out, + local_master_param_out, + ) = paddle._C_ops.adamax_( + local_param, + local_grad, + local_lr, + local_moment, + local_inf_norm, + local_beta1_pow, + local_master_param, + beta1, + beta2, + epsilon, + True, + ) + + ( + dist_param_out, + dist_moment_out, + dist_inf_norm_out, + dist_master_param_out, + ) = paddle._C_ops.adamax_( + dist_param, + dist_grad, + dist_lr, + dist_moment, + dist_inf_norm, + dist_beta1_pow, + dist_master_param, + beta1, + beta2, + epsilon, + True, + ) + + self.check_tensor_eq(local_param_out, dist_param_out) + self.check_tensor_eq(local_moment_out, dist_moment_out) + self.check_tensor_eq(local_inf_norm_out, dist_inf_norm_out) + self.check_tensor_eq(local_master_param_out, dist_master_param_out) + # mutiple operators def test_mse_loss(self): x = np.random.random(size=[4, 4]).astype(self._dtype) @@ -78,9 +181,9 @@ def test_mse_loss(self): y, [None] ) - mes_loss = paddle.nn.loss.MSELoss() - local_out = mes_loss(local_in, local_label) - dist_out = mes_loss(dist_in, dist_label) + mse_loss = paddle.nn.loss.MSELoss() + local_out = mse_loss(local_in, local_label) + dist_out = mse_loss(dist_in, dist_label) self.check_tensor_eq(local_out, dist_out) # test backward @@ -100,8 +203,10 @@ def run_test_case(self): else: raise ValueError("Only support cpu or gpu backend.") - self.test_mse_loss() self.test_unbind() + self.test_expand_as() + self.test_adamax() + self.test_mse_loss() if __name__ == '__main__': diff --git a/test/auto_parallel/semi_auto_parallel_recompute.py b/test/auto_parallel/semi_auto_parallel_recompute.py new file mode 100644 index 00000000000000..7329a1f4d0bafb --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_recompute.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +from semi_auto_parallel_simple_net import MPDemoNetRecompute + +import paddle +import paddle.distributed as dist +from paddle import nn + +BATCH_SIZE = 16 +BATCH_NUM = 4 +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +def run_dynamic(layer, image, label): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(image) + image.stop_gradient = False + out = layer(image) + + label = paddle.to_tensor(label) + loss = loss_fn(out, label) + + loss.backward() + return loss, layer.w0.grad, layer.w1.grad + + +class TestSemiAutoParallelRecompute: + def test_recompute(): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype('float32') + label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32') + w0 = np.random.random([IMAGE_SIZE, IMAGE_SIZE]).astype('float32') + w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32') + run_dynamic( + layer=MPDemoNetRecompute(w0, w1, mesh), image=image, label=label + ) + + +if __name__ == "__main__": + TestSemiAutoParallelRecompute.test_recompute() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net.py b/test/auto_parallel/semi_auto_parallel_simple_net.py index fb7d0b4406697d..ea71bf13617889 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net.py @@ -18,7 +18,9 @@ import paddle import paddle.distributed as dist +import paddle.nn.functional as F from paddle import nn +from paddle.distributed.fleet.utils import recompute BATCH_SIZE = 16 BATCH_NUM = 4 @@ -26,7 +28,6 @@ CLASS_NUM = 10 -# TODO(chenweihang): update to MLP Layer later class DemoNet(nn.Layer): def __init__(self, np_w0, np_w1, param_suffix=""): super().__init__() @@ -46,9 +47,11 @@ def __init__(self, np_w0, np_w1, param_suffix=""): ) def forward(self, x): - y = paddle.matmul(x, self.w0) - z = paddle.matmul(y, self.w1) - return z + out = F.linear(x, self.w0) + out = F.relu(out) + out = F.linear(out, self.w1) + + return out class DPDemoNet(nn.Layer): @@ -71,7 +74,7 @@ def __init__(self, np_w0, np_w1, mesh, param_suffix=""): ) def forward(self, x): - y = paddle.matmul( + out = F.linear( dist.shard_tensor( x, dist_attr=dist.DistAttr( @@ -80,8 +83,10 @@ def forward(self, x): ), self.w0, ) - z = paddle.matmul(y, self.w1) - return z + out = F.relu(out) + out = F.linear(out, self.w1) + + return out class MPDemoNet(nn.Layer): @@ -109,13 +114,49 @@ def __init__(self, np_w0, np_w1, mesh, param_suffix=""): ) def forward(self, x): + out = F.linear(x, self.w0) + out = F.relu(out) + out = F.linear(out, self.w1) + + return out + + +class MPDemoNetRecompute(nn.Layer): + def __init__(self, np_w0, np_w1, mesh, param_suffix=""): + super().__init__() + self.w0 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, IMAGE_SIZE], + attr=paddle.framework.ParamAttr( + name="mp_demo_weight_1" + param_suffix, + initializer=paddle.nn.initializer.Assign(np_w0), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']), + ) + self.w1 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, CLASS_NUM], + attr=paddle.framework.ParamAttr( + name="mp_nemo_weight_2" + param_suffix, + initializer=paddle.nn.initializer.Assign(np_w1), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['x', None]), + ) + + def _inner_forward_fn(self, x): y = paddle.matmul(x, self.w0) z = paddle.matmul(y, self.w1) return z + def forward(self, x): + z = recompute(self._inner_forward_fn, x) + return z + class PPDemoNet(nn.Layer): - def __init__(self, np_w0, np_w1, mesh0, mesh1): + def __init__(self, np_w0, np_w1, mesh0, mesh1, param_suffix=""): super().__init__() self.replicate_dist_attr0 = dist.DistAttr( mesh=mesh0, sharding_specs=[None, None] @@ -127,7 +168,7 @@ def __init__(self, np_w0, np_w1, mesh0, mesh1): self.create_parameter( shape=[IMAGE_SIZE, IMAGE_SIZE], attr=paddle.framework.ParamAttr( - name="pp_demo_weight_0", + name="pp_demo_weight_0" + param_suffix, initializer=paddle.nn.initializer.Assign(np_w0), ), ), @@ -137,7 +178,7 @@ def __init__(self, np_w0, np_w1, mesh0, mesh1): self.create_parameter( shape=[IMAGE_SIZE, CLASS_NUM], attr=paddle.framework.ParamAttr( - name="pp_nemo_weight_1", + name="pp_nemo_weight_1" + param_suffix, initializer=paddle.nn.initializer.Assign(np_w1), ), ), @@ -145,10 +186,11 @@ def __init__(self, np_w0, np_w1, mesh0, mesh1): ) def forward(self, x): - y = paddle.matmul(x, self.w0) - y = dist.reshard(y, dist_attr=self.replicate_dist_attr1) - z = paddle.matmul(y, self.w1) - return z + out = F.linear(x, self.w0) + out = F.relu(out) + out = dist.reshard(out, dist_attr=self.replicate_dist_attr1) + out = F.linear(out, self.w1) + return out class TestSimpleNetForSemiAutoParallel: @@ -177,7 +219,6 @@ def init_input_data(self): self.w0 = np.random.random([IMAGE_SIZE, IMAGE_SIZE]).astype('float32') self.w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32') - # TODO(chenweihang): optimizer cannot run auto-parallel now # TODO(GhostScreaming): support pp backward later. def run_dynamic(self, layer, is_pp=False): # create loss @@ -193,10 +234,14 @@ def run_dynamic(self, layer, is_pp=False): return loss, None, None else: loss.backward() - return loss, layer.w0.grad, layer.w1.grad + opt = paddle.optimizer.SGD( + learning_rate=0.1, parameters=[layer.w0, layer.w1] + ) + opt.step() + return loss, layer.w0, layer.w1 def init_single_card_net_result(self): - self.base_loss, self.base_w0_grad, self.base_w1_grad = self.run_dynamic( + self.base_loss, self.base_w0, self.base_w1 = self.run_dynamic( DemoNet(self.w0, self.w1) ) @@ -206,20 +251,24 @@ def check_tensor_eq(self, a, b): np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) def test_dp_demo_net(self): - self.dp_loss, self.dp_w0_grad, self.dp_w1_grad = self.run_dynamic( + self.dp_loss, self.dp_w0, self.dp_w1 = self.run_dynamic( DPDemoNet(self.w0, self.w1, self._mesh) ) self.check_tensor_eq(self.dp_loss, self.base_loss) - self.check_tensor_eq(self.dp_w0_grad, self.base_w0_grad) - self.check_tensor_eq(self.dp_w1_grad, self.base_w1_grad) + self.check_tensor_eq(self.dp_w0.grad, self.base_w0.grad) + self.check_tensor_eq(self.dp_w1.grad, self.base_w1.grad) + self.check_tensor_eq(self.dp_w0, self.base_w0) + self.check_tensor_eq(self.dp_w1, self.base_w1) def test_mp_demo_net(self): - self.mp_loss, self.mp_w0_grad, self.mp_w1_grad = self.run_dynamic( + self.mp_loss, self.mp_w0, self.mp_w1 = self.run_dynamic( MPDemoNet(self.w0, self.w1, self._mesh) ) self.check_tensor_eq(self.mp_loss, self.base_loss) - self.check_tensor_eq(self.mp_w0_grad, self.base_w0_grad) - self.check_tensor_eq(self.mp_w1_grad, self.base_w1_grad) + self.check_tensor_eq(self.mp_w0.grad, self.base_w0.grad) + self.check_tensor_eq(self.mp_w1.grad, self.base_w1.grad) + self.check_tensor_eq(self.mp_w0, self.base_w0) + self.check_tensor_eq(self.mp_w1, self.base_w1) # TODO(GhostScreaming): support pp backward later. def test_pp_demo_net(self): diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_gradient_hook.py b/test/auto_parallel/semi_auto_parallel_simple_net_gradient_hook.py new file mode 100644 index 00000000000000..f8d3209fd3cb85 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_gradient_hook.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from semi_auto_parallel_simple_net import ( + MPDemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + +hook_triggered = False + + +def backward_hook(): + def trigger_hook(grad): + global hook_triggered + hook_triggered = True + + return trigger_hook + + +class TestSimpleNetWithGradientHookForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + + def run_dynamic(self, layer): + loss_fn = nn.MSELoss() + image = paddle.to_tensor(self.image) + + out = layer(image) + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + loss.backward() + + def test_register_grad_hook(self): + model = MPDemoNet( + self.w0, self.w1, self._mesh, param_suffix="register_grad_hook" + ) + model.w0._register_grad_hook(backward_hook()) + self.run_dynamic(model) + global hook_triggered + assert hook_triggered + hook_triggered = False + + def test_register_hook(self): + model = MPDemoNet( + self.w0, self.w1, self._mesh, param_suffix="register_hook" + ) + model.w0.register_hook(backward_hook()) + self.run_dynamic(model) + global hook_triggered + assert hook_triggered + hook_triggered = False + + def run_test_case(self): + self.test_register_grad_hook() + self.test_register_hook() + + +if __name__ == '__main__': + TestSimpleNetWithGradientHookForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_gradient_merge.py b/test/auto_parallel/semi_auto_parallel_simple_net_gradient_merge.py new file mode 100644 index 00000000000000..a358ec4b485795 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_gradient_merge.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from semi_auto_parallel_simple_net import ( + DemoNet, + DPDemoNet, + MPDemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class TestSimpleNetWithGradientMergeForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + paddle.set_device(self._backend) + self.init_input_data() + self.init_single_card_net_result() + + def run_dynamic_gradient_merge(self, layer): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(self.image) + + for i in range(2): + out = layer(image) + label = paddle.to_tensor(self.label) + loss = loss_fn(out, label) + loss.backward() + + return loss, layer.w0.grad, layer.w1.grad + + def init_single_card_net_result(self): + ( + self.base_loss, + self.base_w0_grad, + self.base_w1_grad, + ) = self.run_dynamic_gradient_merge(DemoNet(self.w0, self.w1)) + + def test_dp_demo_net(self): + ( + self.dp_loss, + self.dp_w0_grad, + self.dp_w1_grad, + ) = self.run_dynamic_gradient_merge( + DPDemoNet(self.w0, self.w1, self._mesh) + ) + self.check_tensor_eq(self.dp_loss, self.base_loss) + self.check_tensor_eq(self.dp_w0_grad, self.base_w0_grad) + self.check_tensor_eq(self.dp_w1_grad, self.base_w1_grad) + + def test_mp_demo_net(self): + ( + self.mp_loss, + self.mp_w0_grad, + self.mp_w1_grad, + ) = self.run_dynamic_gradient_merge( + MPDemoNet(self.w0, self.w1, self._mesh) + ) + self.check_tensor_eq(self.mp_loss, self.base_loss) + self.check_tensor_eq(self.mp_w0_grad, self.base_w0_grad) + self.check_tensor_eq(self.mp_w1_grad, self.base_w1_grad) + + def run_test_case(self): + self.test_dp_demo_net() + self.test_mp_demo_net() + + +if __name__ == '__main__': + TestSimpleNetWithGradientMergeForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py b/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py index 90532a647812ad..031d4dfb9e326a 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py @@ -22,6 +22,7 @@ import paddle import paddle.distributed as dist +import paddle.nn.functional as F from paddle import nn @@ -51,7 +52,7 @@ def __init__(self, np_w0, np_w1, mesh): ) def forward(self, x): - y = paddle.matmul( + out = F.linear( dist.shard_tensor( x, dist_attr=dist.DistAttr( @@ -60,8 +61,10 @@ def forward(self, x): ), self.w0, ) - z = paddle.matmul(y, self.w1) - return z + out = F.relu(out) + out = F.linear(out, self.w1) + + return out class TestSimpleNetHybridStrategyForSemiAutoParallel( @@ -81,12 +84,14 @@ def __init__(self): def test_dp_mp_demo_net(self): ( self.dp_mp_loss, - self.dp_mp_w0_grad, - self.dp_mp_w1_grad, + self.dp_mp_w0, + self.dp_mp_w1, ) = self.run_dynamic(DPAndMPDemoNet(self.w0, self.w1, self._mesh)) self.check_tensor_eq(self.dp_mp_loss, self.base_loss) - self.check_tensor_eq(self.dp_mp_w0_grad, self.base_w0_grad) - self.check_tensor_eq(self.dp_mp_w1_grad, self.base_w1_grad) + self.check_tensor_eq(self.dp_mp_w0, self.base_w0) + self.check_tensor_eq(self.dp_mp_w1, self.base_w1) + self.check_tensor_eq(self.dp_mp_w0.grad, self.base_w0.grad) + self.check_tensor_eq(self.dp_mp_w1.grad, self.base_w1.grad) def run_test_case(self): self.test_dp_mp_demo_net() diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index e3f18abeca1932..20965b19263364 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -19,6 +19,7 @@ if(WITH_DISTRIBUTE) test_default_data_parallel_rule) py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule) py_test_modules(test_squeeze_rule MODULES test_squeeze_rule) + py_test_modules(test_slice_rule MODULES test_slice_rule) py_test_modules(test_flatten_rule MODULES test_flatten_rule) # End of unittests WITH single card WITHOUT timeout diff --git a/test/auto_parallel/spmd_rules/test_reshape_rule.py b/test/auto_parallel/spmd_rules/test_reshape_rule.py index a370580682d8cb..8268c7e768276a 100644 --- a/test/auto_parallel/spmd_rules/test_reshape_rule.py +++ b/test/auto_parallel/spmd_rules/test_reshape_rule.py @@ -243,10 +243,58 @@ def test_reshape_infer_forward(self): infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] ) + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] + # dims_mapping: [0, 1, -1] --> [0, 1, -1], [0, 1, -1, -1] + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.attrs["shape"] = [0, 0, -1, 192] + self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['shape'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] + # dims_mapping: [0, -1, 1] --> [0, -1, -1], [0, -1, -1, -1] + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.attrs["shape"] = [0, 0, -1, 192] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1]) + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['shape'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] + # dims_mapping: [1, -1, 0] --> [1, -1, 0], [1, -1, 0, -1] + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.attrs["shape"] = [0, 0, -1, 192] + self.x_dist_tensor_spec.set_dims_mapping([1, -1, 0]) + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['shape'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1] + ) + # shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1] # raise error self.attrs["shape"] = [3, 24, 6, -1, -1] - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): self.rule.infer_forward( self.x_dist_tensor_spec, self.attrs['shape'] ) @@ -454,6 +502,63 @@ def test_reshape_infer_backward(self): infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, 0] ) + # shape: [8, 1024, 3072] --> [0, 0, -1, 192] (input --> output) + # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1], [0, 1, -1, -1] (output --> input, output) + self.x_dist_tensor_spec.shape = [8, 1024, 3072] + self.output_dist_tensor_spec.shape = [0, 0, -1, 192] + self.attrs["shape"] = [0, 0, -1, 192] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['shape'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] (input --> output) + # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1], [0, 1, -1, -1] (output --> input, output) + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.output_dist_tensor_spec.shape = [0, 0, -1, 192] + self.attrs["shape"] = [0, 0, -1, 192] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['shape'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + + # shape: [-1, -1, 3072] --> [0, 0, -1, 192] (input --> output) + # dims_mapping: [0, -1, 1, -1] --> [0, -1, 1], [0, -1, 1, -1] (output --> input, output) + self.x_dist_tensor_spec.shape = [-1, -1, 3072] + self.output_dist_tensor_spec.shape = [0, 0, -1, 192] + self.attrs["shape"] = [0, 0, -1, 192] + self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1, -1]) + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['shape'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, 1, -1] + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/spmd_rules/test_slice_rule.py b/test/auto_parallel/spmd_rules/test_slice_rule.py new file mode 100644 index 00000000000000..e5bad8ff9b87e7 --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_slice_rule.py @@ -0,0 +1,303 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from collections import OrderedDict + +from paddle.distributed.auto_parallel.static.dist_attribute import ( + DistTensorSpec, + TensorDistAttr, +) +from paddle.distributed.fleet import auto +from paddle.framework import core + + +class TestSliceSPMDRule(unittest.TestCase): + def setUp(self): + self.rule = core.get_phi_spmd_rule("slice") + + x_shape = [8, 8, 16, 16] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + self.attrs = OrderedDict() + self.attrs['infer_flags'] = [0] + self.attrs['decrease_axis'] = [0] + + def test_slice_infer_forward(self): + # axes: [-1] + # dims_mapping: [-1, 0, 1, -1] --> [-1, 0, 1, -1] [-1, 0, 1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1]) + self.attrs['axes'] = [-1] + self.attrs['starts'] = [4] + self.attrs['ends'] = [8] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1, -1] + ) + + # axes: [-1] + # dims_mapping: [-1, -1, 1, 0] --> [-1, -1, 1, -1] [-1, -1, 1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, 0]) + self.attrs['axes'] = [-1] + self.attrs['starts'] = [4] + self.attrs['ends'] = [-1] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, 1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1] + ) + + # axes: [1, 2] + # dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, 1] [0, -1, -1, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['axes'] = [1, 2] + self.attrs['starts'] = [4, 4] + self.attrs['ends'] = [-1, 32] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + + # axes: [1, 2] + # dims_mapping: [-1, 1, 0, -1] --> [-1, -1, -1, -1] [-1, -1, -1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 1, 0, -1]) + self.attrs['axes'] = [1, 2] + self.attrs['starts'] = [4, 4] + self.attrs['ends'] = [-1, 32] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + + # axes: [0, 1, 2, 3] + # dims_mapping: [0, 1, -1, -1] --> [-1, -1, -1, -1] [-1, -1, -1, -1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + self.attrs['axes'] = [0, 1, 2, 3] + self.attrs['starts'] = [0, 0, 4, 4] + self.attrs['ends'] = [4, 4, -1, 32] + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + + def test_slice_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + self.output_dist_tensor_spec = DistTensorSpec( + [8, 8, 16, 16], output_tensor_dist_attr + ) + + # axes: [-1] + # dims_mapping: [-1, -1, 0, 1] --> [-1, -1, 0, -1], [-1, -1, 0, -1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1]) + self.attrs['axes'] = [-1] + self.attrs['starts'] = [4] + self.attrs['ends'] = [8] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0, -1] + ) + + # axes: [-1] + # dims_mapping: [-1, 1, 0, -1] --> [-1, 1, 0, -1], [-1, 1, 0, -1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0, -1]) + self.attrs['axes'] = [-1] + self.attrs['starts'] = [4] + self.attrs['ends'] = [-1] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 1, 0, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0, -1] + ) + + # axes: [1, 2] + # dims_mapping: [-1, 1, 0, -1] --> [-1, -1, -1, -1], [-1, -1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, 0, -1]) + self.attrs['axes'] = [1, 2] + self.attrs['starts'] = [4, 4] + self.attrs['ends'] = [-1, 32] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + + # axes: [1, 2] + # dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, 1], [0, -1, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['axes'] = [1, 2] + self.attrs['starts'] = [4, 4] + self.attrs['ends'] = [-1, 32] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + + # axes: [0, 1, 2, 3] + # dims_mapping: [0, 1, -1, -1] --> [-1, -1, -1, -1] [-1, -1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + self.attrs['axes'] = [0, 1, 2, 3] + self.attrs['starts'] = [0, 0, 4, 4] + self.attrs['ends'] = [4, 4, -1, 32] + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['axes'], + self.attrs['starts'], + self.attrs['ends'], + self.attrs['infer_flags'], + self.attrs['decrease_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_api_dist_branch.py b/test/auto_parallel/test_api_dist_branch.py index 99a5d5be878d9a..fa2adc046422e7 100644 --- a/test/auto_parallel/test_api_dist_branch.py +++ b/test/auto_parallel/test_api_dist_branch.py @@ -137,24 +137,6 @@ def test_concat_for_dist_tensor(self): # self.check_tensor_eq(local_in1.grad, dist_in1.grad) # self.check_tensor_eq(local_in2.grad, dist_in2.grad) - # input: paddle::optional - # output: phi::Tensor - def test_expand_as_for_dist_tensor(self): - x1 = np.random.random(size=[2, 8]).astype("float32") - x2 = np.random.random(size=[2, 2, 8]).astype("float32") - local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1) - local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2) - local_out = paddle.expand_as(local_in1, local_in2) - dist_out = paddle.expand_as(dist_in1, dist_in2) - self.check_tensor_eq(local_out, dist_out) - - # TODO(chenweihang): expand_as is a special case, the forward contains - # optional input, but backward not, open this case after dist support - # optional input - # local_out.backward() - # dist_out.backward() - # self.check_tensor_eq(local_in1.grad, dist_in1.grad) - # input: paddle::optional # output: phi::Tensor def test_bincount_api_for_dist_tensor(self): @@ -256,86 +238,6 @@ def test_check_finite_and_unscale_for_dist_tensor(self): self.check_tensor_eq(local_x, dist_x) self.check_tensor_eq(local_found_inf, dist_found_inf) - # input: phi::Tensor - # output: inplace paddle::optional - def test_adamax_for_dist_tensor(self): - dtype = np.float32 - mp_dtype = np.float32 - shape = [123, 321] - - beta1 = 0.78 - beta2 = 0.899 - epsilon = 1e-5 - param = np.random.random(shape).astype(dtype) - grad = np.random.random(shape).astype(dtype) - moment = np.random.random(shape).astype(dtype) - inf_norm = np.random.random(shape).astype(dtype) - master_param = param.astype(mp_dtype) - - lr = np.array([0.002]).astype("float32") - beta1_pow = np.array([beta1**10]).astype("float32") - - local_param, dist_param = self.create_local_and_dist_tensor_pair(param) - local_grad, dist_grad = self.create_local_and_dist_tensor_pair(grad) - local_lr, dist_lr = self.create_local_and_dist_tensor_pair(lr) - ( - local_beta1_pow, - dist_beta1_pow, - ) = self.create_local_and_dist_tensor_pair(beta1_pow) - local_moment, dist_moment = self.create_local_and_dist_tensor_pair( - moment - ) - local_inf_norm, dist_inf_norm = self.create_local_and_dist_tensor_pair( - inf_norm - ) - ( - local_master_param, - dist_master_param, - ) = self.create_local_and_dist_tensor_pair(master_param) - - ( - local_param_out, - local_moment_out, - local_inf_norm_out, - local_master_param_out, - ) = paddle._C_ops.adamax_( - local_param, - local_grad, - local_lr, - local_moment, - local_inf_norm, - local_beta1_pow, - local_master_param, - beta1, - beta2, - epsilon, - True, - ) - - ( - dist_param_out, - dist_moment_out, - dist_inf_norm_out, - dist_master_param_out, - ) = paddle._C_ops.adamax_( - dist_param, - dist_grad, - dist_lr, - dist_moment, - dist_inf_norm, - dist_beta1_pow, - dist_master_param, - beta1, - beta2, - epsilon, - True, - ) - - self.check_tensor_eq(local_param_out, dist_param_out) - self.check_tensor_eq(local_moment_out, dist_moment_out) - self.check_tensor_eq(local_inf_norm_out, dist_inf_norm_out) - self.check_tensor_eq(local_master_param_out, dist_master_param_out) - # multi kernel functions def test_adagrad_for_dist_tensor(self): dtype = np.float16 diff --git a/test/auto_parallel/test_mp_allreduce_matmul_grad_overlapping.py b/test/auto_parallel/test_mp_allreduce_matmul_grad_overlapping.py new file mode 100644 index 00000000000000..168836b263f5cf --- /dev/null +++ b/test/auto_parallel/test_mp_allreduce_matmul_grad_overlapping.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys +import tempfile +import unittest + + +class TestMPAllreduceMatmulGradOverlapping(unittest.TestCase): + def test_mp_allreduce_matmul_grad_overlapping(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join( + file_dir, "mp_allreduce_matmul_grad_overlapping_unittest.py" + ) + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = ( + [sys.executable, "-u"] + + coverage_args + + [ + "-m", + "paddle.distributed.launch", + "--devices", + "0,1", + "--log_dir", + tmp_dir.name, + launch_model_path, + ] + ) + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 8040b97d43ac94..047b769f12f75f 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -46,6 +46,16 @@ def test_elementwise_api(self): user_defined_envs=envs, ) + def test_reduction_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_reduction.py", + user_defined_envs=envs, + ) + def test_several_replicated_spmd_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py index 03b31f70a9e9b3..370debfbad41eb 100644 --- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py +++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py @@ -50,6 +50,48 @@ def test_simple_net_single_strategy_with_amp(self): user_defined_envs=envs, ) + def test_simple_net_single_strategy_with_gradient_merge(self): + self._changeable_envs = {"backend": ["gpu"]} + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_gradient_merge.py", + user_defined_envs=envs, + ) + + def test_simple_net_recompute(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_recompute.py", + user_defined_envs=envs, + ) + + def test_simple_net_single_strategy_with_gradient_hook(self): + self._changeable_envs = {"backend": ["gpu"]} + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_gradient_hook.py", + user_defined_envs=envs, + ) + + def test_simple_net_clear_gradient(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_clear_gradient.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/fleet/CMakeLists.txt b/test/collective/fleet/CMakeLists.txt index b1b57cb6cf4f5e..5a0e2c0d859ec1 100644 --- a/test/collective/fleet/CMakeLists.txt +++ b/test/collective/fleet/CMakeLists.txt @@ -340,6 +340,21 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) set_tests_properties(test_dygraph_sharding_stage2_bf16 PROPERTIES TIMEOUT "200") endif() +if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) + bash_test_modules( + test_dygraph_sharding_stage1_bf16 + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "200" + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=22024;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_dygraph_sharding_stage1_bf16 PROPERTIES TIMEOUT + "200") +endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) bash_test_modules( test_dygraph_sharding_stage1_fp16 diff --git a/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py b/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py new file mode 100644 index 00000000000000..9a69976b830cc6 --- /dev/null +++ b/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py @@ -0,0 +1,279 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.nn import Linear, ReLU + +seed = 2022 +epoch = 2 +linear_size = 1000 + +np.random.seed(seed) +paddle.seed(seed) + + +class MLP(paddle.nn.Layer): + def __init__(self, linear_size=1000): + super().__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + self._relu = ReLU() + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + y = self._relu(y) + return y + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples=200, linear_size=1000): + self.num_samples = num_samples + self.linear_size = linear_size + + def __getitem__(self, idx): + img = np.random.rand(self.linear_size).astype('float32') + return img + + def __len__(self): + return self.num_samples + + +def optimizer_setting(model, use_pure_bf16, use_main_grad): + if use_main_grad: + assert use_pure_bf16 + model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.00001, + weight_decay=0.00001, + grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), + multi_precision=use_pure_bf16, + ) + if use_main_grad: + optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) + + return optimizer + + +def train_mlp( + model, + sharding_stage, + use_pure_bf16=False, + accumulate_grad=False, + use_main_grad=False, + test_scaler=False, +): + # bf16 not support dynamic loss scaling + # disable dynamic_loss_scaling to coverage distributed_scaler + dynamic_loss_scaling = False + scaler = None + scale_loss = 1024 + if test_scaler: + assert sharding_stage == 1 + assert not accumulate_grad + scaler = paddle.amp.GradScaler( + init_loss_scaling=scale_loss, + use_dynamic_loss_scaling=dynamic_loss_scaling, + ) + scaler = fleet.distributed_scaler(scaler) + optimizer = optimizer_setting( + model=model, use_pure_bf16=use_pure_bf16, use_main_grad=use_main_grad + ) + + strategy = fleet.DistributedStrategy() + if use_pure_bf16: + level = 'O2' + custom_white_list = None + + amp_configs = { + "init_loss_scaling": scale_loss, + "use_pure_bf16": True, + "use_dynamic_loss_scaling": dynamic_loss_scaling, + } + strategy.amp = True + strategy.amp_configs = amp_configs + else: + level = 'O1' + custom_white_list = [ + "matmul_v2", + "elementwise_add", + "relu", + "reduce_mean", + ] + + if sharding_stage == 1: + hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": 2, + } + strategy.hybrid_configs = hybrid_configs + + fleet.init(is_collective=True, strategy=strategy) + model = fleet.distributed_model(model) + + if sharding_stage == 1: + optimizer = fleet.distributed_optimizer(optimizer) + + paddle.seed(2023) + np.random.seed(2023) + train_loader = paddle.io.DataLoader( + RandomDataset(), + batch_size=100, + shuffle=False, + drop_last=True, + num_workers=0, + ) + + if sharding_stage == 1: + model.to(device="gpu") + + if not use_pure_bf16: + for param in model.parameters(): + t = paddle.cast( + paddle.cast(param, dtype='bfloat16'), dtype='float32' + ) + param.set_value(t) + + losses = [] + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + data.stop_gradient = True + + with paddle.amp.auto_cast( + True, + level=level, + dtype="bfloat16", + custom_white_list=custom_white_list, + ): + out = model(data) + loss = paddle.mean(out) + + losses.append(loss) + + if test_scaler: + assert scaler is not None + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.clear_grad() + else: + loss.backward() + if not accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + if accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + return losses + + +def test_stage1_bf16(): + if not paddle.amp.is_bfloat16_supported(): + return + paddle.distributed.init_parallel_env() + + mlp = MLP() + state_dict = mlp.state_dict() + + # stage1 bf16 O1 vs stage1 bf16 O2 main_grad + mlp1 = MLP() + mlp2 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + o1_losses = train_mlp( + mlp1, + sharding_stage=1, + use_pure_bf16=False, + ) + o2_losses = train_mlp( + mlp2, + sharding_stage=1, + use_pure_bf16=True, + use_main_grad=True, + ) + for i in range(len(o1_losses)): + o1_32_loss = paddle.cast(o1_losses[i], dtype='float32').detach() + o2_32_loss = paddle.cast(o2_losses[i], dtype='float32').detach() + np.testing.assert_array_equal(o1_32_loss, o2_32_loss) + + # stage1 scaler test with main_grad + mlp3 = MLP() + mlp3.set_state_dict(state_dict) + train_mlp( + mlp3, + sharding_stage=1, + use_pure_bf16=True, + use_main_grad=True, + test_scaler=True, + ) + + # stage1 scaler test without main_grad + mlp4 = MLP() + mlp4.set_state_dict(state_dict) + train_mlp( + mlp4, + sharding_stage=1, + use_pure_bf16=True, + use_main_grad=False, + test_scaler=True, + ) + + # grad accumulation test + mlp5 = MLP() + mlp6 = MLP() + mlp5.set_state_dict(state_dict) + mlp6.set_state_dict(state_dict) + o1_losses_grad_acc = train_mlp( + mlp5, + sharding_stage=1, + use_pure_bf16=False, + accumulate_grad=True, + ) + o2_losses_grad_acc = train_mlp( + mlp6, + sharding_stage=1, + use_pure_bf16=True, + use_main_grad=True, + accumulate_grad=True, + ) + for i in range(len(o2_losses_grad_acc)): + o2_loss_grad_acc = paddle.cast( + o2_losses_grad_acc[i], dtype='float32' + ).detach() + o1_loss_grad_acc = paddle.cast( + o1_losses_grad_acc[i], dtype='float32' + ).detach() + np.testing.assert_array_equal(o2_loss_grad_acc, o1_loss_grad_acc) + + return + + +if __name__ == '__main__': + test_stage1_bf16() diff --git a/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py b/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py index 601659e0fb98b9..93e163b9facca6 100644 --- a/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py +++ b/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py @@ -83,9 +83,9 @@ def train_mlp( accumulate_grad=False, use_main_grad=False, test_scaler=False, - scale_loss=1024, ): scaler = None + scale_loss = 1024 if test_scaler: assert sharding_stage == 1 assert not accumulate_grad @@ -94,10 +94,15 @@ def train_mlp( optimizer = optimizer_setting( model=model, use_pure_fp16=use_pure_fp16, use_main_grad=use_main_grad ) + + strategy = fleet.DistributedStrategy() if use_pure_fp16: level = 'O2' custom_white_list = None - model = paddle.amp.decorate(models=model, dtype="float16", level=level) + + amp_configs = {"init_loss_scaling": scale_loss, "use_pure_fp16": True} + strategy.amp_configs = amp_configs + strategy.amp = True else: level = 'O1' custom_white_list = [ @@ -108,11 +113,19 @@ def train_mlp( ] if sharding_stage == 1: - optimizer = fleet.distributed_optimizer(optimizer) + hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": 2, + } + strategy.hybrid_configs = hybrid_configs - model = fleet.distributed_model(model) - else: - model = paddle.DataParallel(model) + fleet.init(is_collective=True, strategy=strategy) + model = fleet.distributed_model(model) + + if sharding_stage == 1: + optimizer = fleet.distributed_optimizer(optimizer) paddle.seed(2023) np.random.seed(2023) @@ -176,19 +189,6 @@ def test_stage1_fp16(): return paddle.distributed.init_parallel_env() - strategy = fleet.DistributedStrategy() - hybrid_configs = { - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": 1, - "sharding_degree": 2, - } - scale_loss = 1024 - amp_configs = {"init_loss_scaling": scale_loss, "use_pure_fp16": True} - strategy.hybrid_configs = hybrid_configs - strategy.amp_configs = amp_configs - - fleet.init(is_collective=True, strategy=strategy) mlp = MLP() state_dict = mlp.state_dict() @@ -201,14 +201,12 @@ def test_stage1_fp16(): mlp1, sharding_stage=1, use_pure_fp16=False, - scale_loss=scale_loss, ) o2_losses = train_mlp( mlp2, sharding_stage=1, use_pure_fp16=True, use_main_grad=True, - scale_loss=scale_loss, ) for i in range(len(o1_losses)): o1_32_loss = paddle.cast(o1_losses[i], dtype='float32').detach() @@ -224,7 +222,6 @@ def test_stage1_fp16(): use_pure_fp16=True, use_main_grad=True, test_scaler=True, - scale_loss=scale_loss, ) # grad accumulation test @@ -237,7 +234,6 @@ def test_stage1_fp16(): sharding_stage=1, use_pure_fp16=False, accumulate_grad=True, - scale_loss=scale_loss, ) o2_losses_grad_acc = train_mlp( mlp6, @@ -245,7 +241,6 @@ def test_stage1_fp16(): use_pure_fp16=True, use_main_grad=True, accumulate_grad=True, - scale_loss=scale_loss, ) for i in range(len(o2_losses_grad_acc)): o2_loss_grad_acc = paddle.cast( diff --git a/test/collective/fleet/test_dygraph_sharding_stage1_bf16.py b/test/collective/fleet/test_dygraph_sharding_stage1_bf16.py new file mode 100644 index 00000000000000..bd15963edd2634 --- /dev/null +++ b/test/collective/fleet/test_dygraph_sharding_stage1_bf16.py @@ -0,0 +1,27 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphShardingStage1(TestMultipleGpus): + # check sharding logic as well as the accuracy with single mode + def test_dygraph_sharding_stage1_bf16(self): + self.run_mnist_2gpu('dygraph_group_sharded_stage1_bf16.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/testslist.csv b/test/collective/fleet/testslist.csv index b9df9ace687cf4..8b7a3b7a2f4c4d 100644 --- a/test/collective/fleet/testslist.csv +++ b/test/collective/fleet/testslist.csv @@ -26,6 +26,7 @@ test_parallel_dygraph_no_sync,,GPU,300,DIST,../../legacy_test/dist_test.sh,2,,ht test_dygraph_dataparallel_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage2,,,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage2_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_sharding_stage1_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage1_fp16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_control_flow,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_lars_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., diff --git a/test/collective/test_collective_alltoall_api.py b/test/collective/test_collective_alltoall_api.py index 01864126a96e95..21d01075aa7299 100644 --- a/test/collective/test_collective_alltoall_api.py +++ b/test/collective/test_collective_alltoall_api.py @@ -57,7 +57,7 @@ def test_alltoall_nccl_with_new_comm(self): "alltoall", "nccl", dtype=dtype, - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_alltoall_nccl_dygraph(self): diff --git a/test/collective/test_collective_barrier_api.py b/test/collective/test_collective_barrier_api.py index 74e5cebc873c15..75b0e809053654 100644 --- a/test/collective/test_collective_barrier_api.py +++ b/test/collective/test_collective_barrier_api.py @@ -33,7 +33,7 @@ def test_barrier_nccl_with_new_comm(self): "collective_barrier_api.py", "barrier", "nccl", - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_barrier_gloo(self): diff --git a/test/collective/test_collective_global_gather.py b/test/collective/test_collective_global_gather.py index c4c2e42c0b561d..c5110b65198012 100644 --- a/test/collective/test_collective_global_gather.py +++ b/test/collective/test_collective_global_gather.py @@ -44,7 +44,7 @@ def test_global_gather_nccl_new_comm(self): "collective_global_gather.py", "global_gather", "nccl", - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) diff --git a/test/collective/test_collective_global_scatter.py b/test/collective/test_collective_global_scatter.py index 7eb34abe6cf5af..26a267a98d349b 100644 --- a/test/collective/test_collective_global_scatter.py +++ b/test/collective/test_collective_global_scatter.py @@ -43,7 +43,7 @@ def test_global_scatter_nccl_new_comm(self): "collective_global_scatter.py", "global_scatter", "nccl", - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) diff --git a/test/collective/test_collective_reduce_api.py b/test/collective/test_collective_reduce_api.py index 9759b500288356..aafda45aea9762 100644 --- a/test/collective/test_collective_reduce_api.py +++ b/test/collective/test_collective_reduce_api.py @@ -78,7 +78,7 @@ def test_reduce_nccl_with_new_comm(self): "nccl", dtype=dtype, reduce_type=red_type, - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_reduce_bkcl(self): diff --git a/test/collective/test_collective_reduce_scatter_api.py b/test/collective/test_collective_reduce_scatter_api.py index 4ec909e8d2b448..bd3dd14df88df0 100644 --- a/test/collective/test_collective_reduce_scatter_api.py +++ b/test/collective/test_collective_reduce_scatter_api.py @@ -59,7 +59,7 @@ def test_reduce_scatter_nccl_with_new_comm(self): "reduce_scatter", "nccl", dtype=dtype, - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_reduce_scatter_nccl_dygraph(self): diff --git a/test/collective/test_collective_scatter_api.py b/test/collective/test_collective_scatter_api.py index b21e06c6c75d04..7ac51e99b55937 100644 --- a/test/collective/test_collective_scatter_api.py +++ b/test/collective/test_collective_scatter_api.py @@ -47,7 +47,7 @@ def test_scatter_nccl_with_new_comm(self): "scatter", "nccl", dtype=dtype, - need_envs={"FLAGS_dynamic_static_unified_comm": "1"}, + need_envs={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_scatter_nccl_dygraph(self): diff --git a/test/cpp/fluid/CMakeLists.txt b/test/cpp/fluid/CMakeLists.txt index ca62b7c1c7c03f..324043b0746fe4 100644 --- a/test/cpp/fluid/CMakeLists.txt +++ b/test/cpp/fluid/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(benchmark) +add_subdirectory(framework) if(WITH_CINN) add_subdirectory(cinn) endif() diff --git a/test/cpp/fluid/cinn/CMakeLists.txt b/test/cpp/fluid/cinn/CMakeLists.txt index 0feb905a83902f..96c38feb32ba7a 100644 --- a/test/cpp/fluid/cinn/CMakeLists.txt +++ b/test/cpp/fluid/cinn/CMakeLists.txt @@ -46,7 +46,13 @@ if(WITH_TESTING) elementwise_add_op paddle_flags) target_link_libraries(cinn_instruction_run_op_test ${PYTHON_LIBRARIES}) - set_tests_properties( - cinn_instruction_run_op_test PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT - "${CINN_RUN_ENVIRONMENT}") + + get_property( + env + TEST cinn_instruction_run_op_test + PROPERTY ENVIRONMENT) + set_property(TEST cinn_instruction_run_op_test + PROPERTY ENVIRONMENT "${CINN_RUN_ENVIRONMENT}" ${env}) + set_tests_properties(cinn_instruction_run_op_test PROPERTIES LABELS + "RUN_TYPE=CINN") endif() diff --git a/test/cpp/fluid/framework/CMakeLists.txt b/test/cpp/fluid/framework/CMakeLists.txt new file mode 100644 index 00000000000000..663dae547625b5 --- /dev/null +++ b/test/cpp/fluid/framework/CMakeLists.txt @@ -0,0 +1,289 @@ +# add_subdirectory(details) + +cc_test( + data_type_test + SRCS data_type_test.cc + DEPS data_type place tensor) + +cc_test( + tensor_test + SRCS tensor_test.cc + DEPS tensor isfinite_op) +if(WITH_GPU) + nv_test( + tensor_util_test + SRCS tensor_util_test.cc tensor_util_test.cu + DEPS tensor dlpack_tensor isfinite_op) +elseif(WITH_ROCM) + hip_test( + tensor_util_test + SRCS tensor_util_test.cc tensor_util_test.cu + DEPS tensor dlpack_tensor isfinite_op) +else() + cc_test( + tensor_util_test + SRCS tensor_util_test.cc + DEPS tensor dlpack_tensor isfinite_op) +endif() + +cc_test( + copy_same_tensor_test + SRCS copy_same_tensor_test.cc + DEPS tensor) + +cc_test( + eigen_test + SRCS eigen_test.cc + DEPS tensor) + +cc_test( + lod_tensor_test + SRCS lod_tensor_test.cc + DEPS phi lod_tensor memory) + +if(WITH_GPU) + nv_test( + lod_tensor_gpu_test + SRCS lod_tensor_test.cu + DEPS lod_tensor) +elseif(WITH_ROCM) + hip_test( + lod_tensor_gpu_test + SRCS lod_tensor_test.cu + DEPS lod_tensor) +endif() + +cc_test( + reader_test + SRCS reader_test.cc + DEPS reader) + +cc_test( + threadpool_test + SRCS threadpool_test.cc + DEPS phi) + +cc_test( + var_type_traits_test + SRCS var_type_traits_test.cc + DEPS var_type_traits) + +cc_test( + device_worker_test + SRCS device_worker_test.cc + DEPS device_worker) + +cc_test( + scope_test + SRCS scope_test.cc + DEPS scope) + +cc_test( + variable_test + SRCS variable_test.cc + DEPS tensor var_type_traits) + +if(WITH_GPU) + nv_test( + data_device_transform_test + SRCS data_device_transform_test.cu + DEPS operator op_registry device_context phi scope) +elseif(WITH_ROCM) + hip_test( + data_device_transform_test + SRCS data_device_transform_test.cu + DEPS operator op_registry device_context phi scope) +endif() + +if(WITH_GPU) + nv_test( + data_type_transform_test + SRCS data_type_transform_test.cc data_type_transform_test.cu + DEPS data_type_transform) +elseif(WITH_ROCM) + hip_test( + data_type_transform_test + SRCS data_type_transform_test.cc data_type_transform_test.cu + DEPS data_type_transform) +elseif(WITH_XPU) + cc_test( + data_type_transform_test + SRCS data_type_transform_test.cc + DEPS data_type_transform) +else() + cc_test( + data_type_transform_test + SRCS data_type_transform_test.cc + DEPS data_type_transform) +endif() + +cc_test( + data_layout_transform_test + SRCS data_layout_transform_test.cc + DEPS data_layout_transform) + +cc_test( + attribute_test + SRCS attribute_test.cc + DEPS attribute framework_proto proto_desc) + +cc_test( + program_desc_test + SRCS program_desc_test.cc + DEPS proto_desc device_context) + +cc_test( + op_desc_test + SRCS op_desc_test.cc + DEPS proto_desc) + +cc_test( + op_version_registry_test + SRCS op_version_registry_test.cc + DEPS op_version_registry) + +cc_test( + op_proto_maker_test + SRCS op_proto_maker_test.cc + DEPS op_proto_maker) + +cc_test( + no_need_buffer_vars_inference_test + SRCS no_need_buffer_vars_inference_test.cc + DEPS no_need_buffer_vars_inference layer) + +cc_test( + operator_test + SRCS operator_test.cc + DEPS operator op_registry device_context) +cc_test( + operator_exception_test + SRCS operator_exception_test.cc + DEPS operator op_registry device_context) + +cc_test( + version_test + SRCS version_test.cc + DEPS version) + +cc_test( + op_call_stack_test + SRCS op_call_stack_test.cc + DEPS op_call_stack) + +cc_test( + program_utils_test + SRCS program_utils_test.cc + DEPS proto_desc program_utils) + +if(WITH_GPU) + nv_test( + op_registry_test + SRCS op_registry_test.cc + DEPS op_registry) +elseif(WITH_ROCM) + hip_test( + op_registry_test + SRCS op_registry_test.cc + DEPS op_registry) +endif() + +if(WITH_PSCORE) + get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + if(WITH_HETERPS) + cc_test( + dist_multi_trainer_test + SRCS dist_multi_trainer_test.cc + DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS} + graph_gpu_wrapper) + cc_test( + heter_pipeline_trainer_test + SRCS heter_pipeline_trainer_test.cc + DEPS conditional_block_op + generated_op + heter_listen_and_serv_op + executor + heter_server + gloo_wrapper + phi + ${RPC_DEPS} + graph_gpu_wrapper) + else() + cc_test( + dist_multi_trainer_test + SRCS dist_multi_trainer_test.cc + DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS}) + cc_test( + heter_pipeline_trainer_test + SRCS heter_pipeline_trainer_test.cc + DEPS conditional_block_op + generated_op + heter_listen_and_serv_op + executor + heter_server + gloo_wrapper + phi + ${RPC_DEPS}) + endif() +else() + cc_test( + dist_multi_trainer_test + SRCS dist_multi_trainer_test.cc + DEPS conditional_block_op executor gloo_wrapper) +endif() + +cc_test( + prune_test + SRCS prune_test.cc + DEPS op_info prune recurrent_op device_context) +cc_test( + var_type_inference_test + SRCS var_type_inference_test.cc + DEPS op_registry proto_desc) + +cc_test( + selected_rows_utils_test + SRCS selected_rows_utils_test.cc + DEPS selected_rows_utils) + +cc_test( + op_kernel_type_test + SRCS op_kernel_type_test.cc + DEPS place device_context framework_proto op_kernel_type) +cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) + +cc_test(tuple_test SRCS tuple_test.cc) + +cc_test(inlined_vector_test SRCS inlined_vector_test.cc) + +cc_test( + dlpack_tensor_test + SRCS dlpack_tensor_test.cc + DEPS dlpack_tensor glog) + +cc_test_old( + op_compatible_info_test + SRCS + op_compatible_info_test.cc + DEPS + op_compatible_info + proto_desc + string_helper + glog) + +cc_test( + infershape_utils_test + SRCS infershape_utils_test.cc + DEPS infershape_utils phi) + +if(WITH_TESTING AND TEST selected_rows_utils_test) + set_tests_properties(selected_rows_utils_test PROPERTIES TIMEOUT 120) +endif() + +cc_test(scope_guard_test SRCS scope_guard_test.cc) +cc_test( + phi_utils_test + SRCS phi_utils_test.cc + DEPS phi_utils) + +cc_test(convert_utils_test SRCS convert_utils_test.cc) diff --git a/paddle/fluid/framework/attribute_test.cc b/test/cpp/fluid/framework/attribute_test.cc similarity index 100% rename from paddle/fluid/framework/attribute_test.cc rename to test/cpp/fluid/framework/attribute_test.cc diff --git a/paddle/fluid/framework/convert_utils_test.cc b/test/cpp/fluid/framework/convert_utils_test.cc similarity index 100% rename from paddle/fluid/framework/convert_utils_test.cc rename to test/cpp/fluid/framework/convert_utils_test.cc diff --git a/paddle/fluid/framework/copy_same_tensor_test.cc b/test/cpp/fluid/framework/copy_same_tensor_test.cc similarity index 100% rename from paddle/fluid/framework/copy_same_tensor_test.cc rename to test/cpp/fluid/framework/copy_same_tensor_test.cc diff --git a/paddle/fluid/framework/data_device_transform_test.cu b/test/cpp/fluid/framework/data_device_transform_test.cu similarity index 100% rename from paddle/fluid/framework/data_device_transform_test.cu rename to test/cpp/fluid/framework/data_device_transform_test.cu diff --git a/paddle/fluid/framework/data_feed_test.cc b/test/cpp/fluid/framework/data_feed_test.cc similarity index 100% rename from paddle/fluid/framework/data_feed_test.cc rename to test/cpp/fluid/framework/data_feed_test.cc diff --git a/paddle/fluid/framework/data_layout_transform_test.cc b/test/cpp/fluid/framework/data_layout_transform_test.cc similarity index 100% rename from paddle/fluid/framework/data_layout_transform_test.cc rename to test/cpp/fluid/framework/data_layout_transform_test.cc diff --git a/paddle/fluid/framework/data_type_test.cc b/test/cpp/fluid/framework/data_type_test.cc similarity index 100% rename from paddle/fluid/framework/data_type_test.cc rename to test/cpp/fluid/framework/data_type_test.cc diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/test/cpp/fluid/framework/data_type_transform_test.cc similarity index 100% rename from paddle/fluid/framework/data_type_transform_test.cc rename to test/cpp/fluid/framework/data_type_transform_test.cc diff --git a/paddle/fluid/framework/data_type_transform_test.cu b/test/cpp/fluid/framework/data_type_transform_test.cu similarity index 100% rename from paddle/fluid/framework/data_type_transform_test.cu rename to test/cpp/fluid/framework/data_type_transform_test.cu diff --git a/paddle/fluid/framework/details/cow_ptr_test.cc b/test/cpp/fluid/framework/details/cow_ptr_test.cc similarity index 100% rename from paddle/fluid/framework/details/cow_ptr_test.cc rename to test/cpp/fluid/framework/details/cow_ptr_test.cc diff --git a/paddle/fluid/framework/device_worker_test.cc b/test/cpp/fluid/framework/device_worker_test.cc similarity index 100% rename from paddle/fluid/framework/device_worker_test.cc rename to test/cpp/fluid/framework/device_worker_test.cc diff --git a/paddle/fluid/framework/dist_multi_trainer_test.cc b/test/cpp/fluid/framework/dist_multi_trainer_test.cc similarity index 100% rename from paddle/fluid/framework/dist_multi_trainer_test.cc rename to test/cpp/fluid/framework/dist_multi_trainer_test.cc diff --git a/paddle/fluid/framework/dlpack_tensor_test.cc b/test/cpp/fluid/framework/dlpack_tensor_test.cc similarity index 100% rename from paddle/fluid/framework/dlpack_tensor_test.cc rename to test/cpp/fluid/framework/dlpack_tensor_test.cc diff --git a/paddle/fluid/framework/eigen_test.cc b/test/cpp/fluid/framework/eigen_test.cc similarity index 100% rename from paddle/fluid/framework/eigen_test.cc rename to test/cpp/fluid/framework/eigen_test.cc diff --git a/paddle/fluid/framework/heter_pipeline_trainer_test.cc b/test/cpp/fluid/framework/heter_pipeline_trainer_test.cc similarity index 100% rename from paddle/fluid/framework/heter_pipeline_trainer_test.cc rename to test/cpp/fluid/framework/heter_pipeline_trainer_test.cc diff --git a/paddle/fluid/framework/infershape_utils_test.cc b/test/cpp/fluid/framework/infershape_utils_test.cc similarity index 100% rename from paddle/fluid/framework/infershape_utils_test.cc rename to test/cpp/fluid/framework/infershape_utils_test.cc diff --git a/paddle/fluid/framework/inlined_vector_test.cc b/test/cpp/fluid/framework/inlined_vector_test.cc similarity index 100% rename from paddle/fluid/framework/inlined_vector_test.cc rename to test/cpp/fluid/framework/inlined_vector_test.cc diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/test/cpp/fluid/framework/lod_tensor_test.cc similarity index 100% rename from paddle/fluid/framework/lod_tensor_test.cc rename to test/cpp/fluid/framework/lod_tensor_test.cc diff --git a/paddle/fluid/framework/lod_tensor_test.cu b/test/cpp/fluid/framework/lod_tensor_test.cu similarity index 100% rename from paddle/fluid/framework/lod_tensor_test.cu rename to test/cpp/fluid/framework/lod_tensor_test.cu diff --git a/paddle/fluid/framework/naive_executor_test.cc b/test/cpp/fluid/framework/naive_executor_test.cc similarity index 100% rename from paddle/fluid/framework/naive_executor_test.cc rename to test/cpp/fluid/framework/naive_executor_test.cc diff --git a/paddle/fluid/framework/no_need_buffer_vars_inference_test.cc b/test/cpp/fluid/framework/no_need_buffer_vars_inference_test.cc similarity index 100% rename from paddle/fluid/framework/no_need_buffer_vars_inference_test.cc rename to test/cpp/fluid/framework/no_need_buffer_vars_inference_test.cc diff --git a/paddle/fluid/framework/op_call_stack_test.cc b/test/cpp/fluid/framework/op_call_stack_test.cc similarity index 100% rename from paddle/fluid/framework/op_call_stack_test.cc rename to test/cpp/fluid/framework/op_call_stack_test.cc diff --git a/paddle/fluid/framework/op_compatible_info_test.cc b/test/cpp/fluid/framework/op_compatible_info_test.cc similarity index 100% rename from paddle/fluid/framework/op_compatible_info_test.cc rename to test/cpp/fluid/framework/op_compatible_info_test.cc diff --git a/paddle/fluid/framework/op_desc_test.cc b/test/cpp/fluid/framework/op_desc_test.cc similarity index 100% rename from paddle/fluid/framework/op_desc_test.cc rename to test/cpp/fluid/framework/op_desc_test.cc diff --git a/paddle/fluid/framework/op_kernel_type_test.cc b/test/cpp/fluid/framework/op_kernel_type_test.cc similarity index 100% rename from paddle/fluid/framework/op_kernel_type_test.cc rename to test/cpp/fluid/framework/op_kernel_type_test.cc diff --git a/paddle/fluid/framework/op_proto_maker_test.cc b/test/cpp/fluid/framework/op_proto_maker_test.cc similarity index 100% rename from paddle/fluid/framework/op_proto_maker_test.cc rename to test/cpp/fluid/framework/op_proto_maker_test.cc diff --git a/paddle/fluid/framework/op_registry_test.cc b/test/cpp/fluid/framework/op_registry_test.cc similarity index 100% rename from paddle/fluid/framework/op_registry_test.cc rename to test/cpp/fluid/framework/op_registry_test.cc diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/test/cpp/fluid/framework/op_version_registry_test.cc similarity index 100% rename from paddle/fluid/framework/op_version_registry_test.cc rename to test/cpp/fluid/framework/op_version_registry_test.cc diff --git a/paddle/fluid/framework/operator_exception_test.cc b/test/cpp/fluid/framework/operator_exception_test.cc similarity index 100% rename from paddle/fluid/framework/operator_exception_test.cc rename to test/cpp/fluid/framework/operator_exception_test.cc diff --git a/paddle/fluid/framework/operator_test.cc b/test/cpp/fluid/framework/operator_test.cc similarity index 100% rename from paddle/fluid/framework/operator_test.cc rename to test/cpp/fluid/framework/operator_test.cc diff --git a/paddle/fluid/framework/phi_utils_test.cc b/test/cpp/fluid/framework/phi_utils_test.cc similarity index 100% rename from paddle/fluid/framework/phi_utils_test.cc rename to test/cpp/fluid/framework/phi_utils_test.cc diff --git a/paddle/fluid/framework/program_desc_test.cc b/test/cpp/fluid/framework/program_desc_test.cc similarity index 100% rename from paddle/fluid/framework/program_desc_test.cc rename to test/cpp/fluid/framework/program_desc_test.cc diff --git a/paddle/fluid/framework/program_utils_test.cc b/test/cpp/fluid/framework/program_utils_test.cc similarity index 100% rename from paddle/fluid/framework/program_utils_test.cc rename to test/cpp/fluid/framework/program_utils_test.cc diff --git a/paddle/fluid/framework/prune_test.cc b/test/cpp/fluid/framework/prune_test.cc similarity index 100% rename from paddle/fluid/framework/prune_test.cc rename to test/cpp/fluid/framework/prune_test.cc diff --git a/paddle/fluid/framework/reader_test.cc b/test/cpp/fluid/framework/reader_test.cc similarity index 100% rename from paddle/fluid/framework/reader_test.cc rename to test/cpp/fluid/framework/reader_test.cc diff --git a/paddle/fluid/framework/scope_guard_test.cc b/test/cpp/fluid/framework/scope_guard_test.cc similarity index 100% rename from paddle/fluid/framework/scope_guard_test.cc rename to test/cpp/fluid/framework/scope_guard_test.cc diff --git a/paddle/fluid/framework/scope_test.cc b/test/cpp/fluid/framework/scope_test.cc similarity index 100% rename from paddle/fluid/framework/scope_test.cc rename to test/cpp/fluid/framework/scope_test.cc diff --git a/paddle/fluid/framework/selected_rows_utils_test.cc b/test/cpp/fluid/framework/selected_rows_utils_test.cc similarity index 100% rename from paddle/fluid/framework/selected_rows_utils_test.cc rename to test/cpp/fluid/framework/selected_rows_utils_test.cc diff --git a/paddle/fluid/framework/tensor_test.cc b/test/cpp/fluid/framework/tensor_test.cc similarity index 100% rename from paddle/fluid/framework/tensor_test.cc rename to test/cpp/fluid/framework/tensor_test.cc diff --git a/paddle/fluid/framework/tensor_util_test.cc b/test/cpp/fluid/framework/tensor_util_test.cc similarity index 100% rename from paddle/fluid/framework/tensor_util_test.cc rename to test/cpp/fluid/framework/tensor_util_test.cc diff --git a/paddle/fluid/framework/tensor_util_test.cu b/test/cpp/fluid/framework/tensor_util_test.cu similarity index 100% rename from paddle/fluid/framework/tensor_util_test.cu rename to test/cpp/fluid/framework/tensor_util_test.cu diff --git a/paddle/fluid/framework/threadpool_test.cc b/test/cpp/fluid/framework/threadpool_test.cc similarity index 100% rename from paddle/fluid/framework/threadpool_test.cc rename to test/cpp/fluid/framework/threadpool_test.cc diff --git a/paddle/fluid/framework/trainer_test.cc b/test/cpp/fluid/framework/trainer_test.cc similarity index 100% rename from paddle/fluid/framework/trainer_test.cc rename to test/cpp/fluid/framework/trainer_test.cc diff --git a/paddle/fluid/framework/tuple_test.cc b/test/cpp/fluid/framework/tuple_test.cc similarity index 100% rename from paddle/fluid/framework/tuple_test.cc rename to test/cpp/fluid/framework/tuple_test.cc diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/test/cpp/fluid/framework/var_type_inference_test.cc similarity index 100% rename from paddle/fluid/framework/var_type_inference_test.cc rename to test/cpp/fluid/framework/var_type_inference_test.cc diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/test/cpp/fluid/framework/var_type_traits_test.cc similarity index 100% rename from paddle/fluid/framework/var_type_traits_test.cc rename to test/cpp/fluid/framework/var_type_traits_test.cc diff --git a/paddle/fluid/framework/variable_test.cc b/test/cpp/fluid/framework/variable_test.cc similarity index 100% rename from paddle/fluid/framework/variable_test.cc rename to test/cpp/fluid/framework/variable_test.cc diff --git a/paddle/fluid/framework/version_test.cc b/test/cpp/fluid/framework/version_test.cc similarity index 100% rename from paddle/fluid/framework/version_test.cc rename to test/cpp/fluid/framework/version_test.cc diff --git a/test/cpp/fluid/mkldnn/CMakeLists.txt b/test/cpp/fluid/mkldnn/CMakeLists.txt index 3d5883dabfbf89..f83fd91963be20 100644 --- a/test/cpp/fluid/mkldnn/CMakeLists.txt +++ b/test/cpp/fluid/mkldnn/CMakeLists.txt @@ -83,3 +83,18 @@ else() cc_test_old(test_mkldnn_op_nhwc SRCS test_mkldnn_op_nhwc.cc DEPS ${paddle_lib} python) endif() + +cc_test( + test_mkldnn_pool_adaptive_op + SRCS test_mkldnn_pool_adaptive_op.cc + DEPS fleet_executor + conditional_block_op + standalone_executor + executor + op_registry + generated_static_op + generated_op + phi + scope + device_context + enforce) diff --git a/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc b/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc index ecc5ce726b2d8f..4b6498d07289ec 100644 --- a/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc +++ b/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc @@ -19,7 +19,6 @@ #include "gtest/gtest.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" diff --git a/test/cpp/fluid/mkldnn/test_mkldnn_pool_adaptive_op.cc b/test/cpp/fluid/mkldnn/test_mkldnn_pool_adaptive_op.cc new file mode 100644 index 00000000000000..3e1a9230ec231c --- /dev/null +++ b/test/cpp/fluid/mkldnn/test_mkldnn_pool_adaptive_op.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include + +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +template +void AddVarToScope(const std::string var_name, + paddle::framework::Scope* scope, + const paddle::framework::DDim& dims) { + std::random_device seed; + std::default_random_engine engine(seed()); + std::uniform_real_distribution dist(0, 100); + + phi::DenseTensor tmp_tensor; + auto* tmp_data = + tmp_tensor.mutable_data(dims, paddle::platform::CPUPlace()); + auto* tensor = scope->Var(var_name)->GetMutable(); + tensor->mutable_data(dims, paddle::platform::CPUPlace()); + for (auto i = 0; i < tensor->numel(); ++i) { + tmp_data[i] = static_cast(dist(engine)); + } + paddle::framework::TensorCopySync( + tmp_tensor, paddle::platform::CPUPlace(), tensor); +} +void test_pool2d(bool adaptive, bool ceil_mode, std::string pool_type = "max") { + framework::Scope scope; + paddle::platform::CPUPlace cpu_place; + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("pool2d"); + desc.SetInput("X", {"pool2d-X"}); + desc.SetOutput("Out", {"pool2d-Out"}); + AddVarToScope("pool2d-X", &scope, {1, 3, 9, 12}); + AddVarToScope("pool2d-Out", &scope, {1, 3, 2, 2}); + std::vector ksize({2, 2}); + std::vector strides({1, 1}); + std::vector paddings({0, 0}); + std::string pooling_t = pool_type; + + desc.SetAttr("pooling_type", pooling_t); + desc.SetAttr("ksize", ksize); + desc.SetAttr("strides", strides); + desc.SetAttr("paddings", paddings); + desc.SetAttr("adaptive", adaptive); + desc.SetAttr("ceil_mode", ceil_mode); + desc.SetAttr("use_mkldnn", true); + + auto op = paddle::framework::OpRegistry::CreateOp(desc); + + op->Run(scope, cpu_place); +} + +TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); } +TEST(Pool2dOpConverter, adaptive) { test_pool2d(true, false); } + +TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); } +TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(true, true, "avg"); } + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP_ITSELF(pool2d); +PD_DECLARE_KERNEL(pool2d, OneDNN, ONEDNN); +PD_DECLARE_KERNEL(pool2d, CPU, ALL_LAYOUT); diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc index 28a425dbd4ebe9..29cc5373dbb509 100644 --- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc +++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc @@ -22,7 +22,7 @@ #include "paddle/phi/core/kernel_registry.h" -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -65,20 +65,19 @@ TEST(StandaloneExecutor, run) { paddle::dialect::FullOp op2 = builder.Build( std::vector{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); - builder.Build(op1->result(0), op2->result(0)); + auto add_op = + builder.Build(op1->result(0), op2->result(0)); + + std::string out_name = "add_out"; + builder.Build(add_op->result(0), out_name); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_2"; test_core.SetSkipGcVars({out_name}); test_core.Run({}); @@ -136,8 +135,10 @@ TEST(StandaloneExecutor, run_feed_tensor) { pir::Operation::Create({}, attr_map2, {dense_tensor_dtype}, feed_op_info); program.block()->push_back(feed_op2); - builder.Build(feed_op1->result(0), - feed_op2->result(0)); + auto add_op = builder.Build(feed_op1->result(0), + feed_op2->result(0)); + std::string out_name = "add_out"; + builder.Build(add_op->result(0), out_name); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); @@ -145,10 +146,6 @@ TEST(StandaloneExecutor, run_feed_tensor) { Scope scope; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_2"; test_core.SetSkipGcVars({out_name}); phi::DenseTensorMeta meta( @@ -191,16 +188,15 @@ TEST(StandaloneExecutor, run_inplace_sqrt) { builder.Build(full->result(0)); + std::string out_name = "full_out"; + builder.Build(full->result(0), out_name); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_0"; test_core.SetSkipGcVars({out_name}); test_core.Run({}); @@ -254,16 +250,16 @@ TEST(StandaloneExecutor, if_op) { std::vector{3}, true, phi::DataType::BOOL); builder.Build(std::vector{full_op_2.out()}); + std::string out_name = "if_out"; + builder.SetInsertionPointToEnd(block); + builder.Build(if_op->result(0), out_name); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_1"; test_core.SetSkipGcVars({out_name}); test_core.Run({}); @@ -325,16 +321,15 @@ TEST(StandaloneExecutor, while_op) { builder.SetInsertionPointAfter(while_op); + std::string out_name = "while_out"; + builder.Build(while_op->result(0), out_name); + auto kernel_program = PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string out_name = os.str() + "_inner_var_3"; test_core.SetSkipGcVars({out_name}); test_core.Run({}); diff --git a/test/cpp/pir/cinn/CMakeLists.txt b/test/cpp/pir/cinn/CMakeLists.txt index 2cafafb9e41cc7..30aa78bc67fb64 100644 --- a/test/cpp/pir/cinn/CMakeLists.txt +++ b/test/cpp/pir/cinn/CMakeLists.txt @@ -1,20 +1,24 @@ if(WITH_TESTING AND WITH_CINN) + paddle_test(test_pir_compiler SRCS pir_compiler_test.cc DEPS pir_compiler + cinn_runtime_dialect) + set_tests_properties(test_pir_compiler PROPERTIES LABELS "RUN_TYPE=CINN") + + paddle_test(test_jit_instruction SRCS jit_instruction_test.cc DEPS + cinn_runtime_dialect pir_compiler) + set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN") + cc_test_old( - test_new_ir_compiler + dialect_convert_test SRCS - new_ir_compiler_test.cc + dialect_convert_test.cc DEPS - new_ir_compiler - cinn_runtime_dialect - pir - phi + drr gtest - glog) - set_tests_properties(test_new_ir_compiler PROPERTIES LABELS "RUN_TYPE=CINN") - - cc_test_old(test_jit_instruction SRCS jit_instruction_test.cc DEPS - interpreter new_ir_compiler) - set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN") + pd_to_cinn_pass + pd_op_dialect + cinn_op_dialect + pir) + set_tests_properties(dialect_convert_test PROPERTIES LABELS "RUN_TYPE=CINN") cc_test_old( ir_op_fusion_test @@ -27,8 +31,10 @@ if(WITH_TESTING AND WITH_CINN) pir gtest glog) + set_tests_properties(ir_op_fusion_test PROPERTIES LABELS "RUN_TYPE=CINN") - paddle_test(test_group_op SRCS group_op_test.cc DEPS cinn_op_dialect) + paddle_test(test_group_op SRCS group_op_test.cc DEPS op_with_group_merge_pass + cinn_op_dialect) set_tests_properties(test_group_op PROPERTIES LABELS "RUN_TYPE=CINN") paddle_test(test_pir_build_cinn_pass SRCS build_cinn_pass_test.cc DEPS diff --git a/test/cpp/pir/cinn/dialect_convert_test.cc b/test/cpp/pir/cinn/dialect_convert_test.cc new file mode 100644 index 00000000000000..b92af287c7cc8a --- /dev/null +++ b/test/cpp/pir/cinn/dialect_convert_test.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/pir/transforms/dead_code_elimination_pass.h" + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + auto sum_op = + builder.Build(full_input_op.result(0), + std::vector({-1}), + phi::DataType::FLOAT32, + true); + auto relu_op = builder.Build(sum_op.result(0)); + auto exp_op = builder.Build(sum_op.result(0)); +} + +void BuildProgramMax(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op = + builder.Build(std::vector{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + auto max_op = builder.Build( + full_input_op.result(0), std::vector({-1}), true); + auto relu_op = builder.Build(max_op.result(0)); + auto exp_op = builder.Build(max_op.result(0)); +} + +TEST(DrrTest, reduce_sum) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + cinn::dialect::ir::PdOp2CinnOpConverter(&program); + + auto it = program.block()->begin(); + + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); +} + +TEST(DrrTest, reduce_max) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgramMax(builder); + + cinn::dialect::ir::PdOp2CinnOpConverter(&program); + + auto it = program.block()->begin(); + + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); + it++; + CHECK_EQ((*it)->isa(), true); +} diff --git a/test/cpp/pir/cinn/group_op_test.cc b/test/cpp/pir/cinn/group_op_test.cc index c252c06a3cccdf..049cd88ea1a5d1 100644 --- a/test/cpp/pir/cinn/group_op_test.cc +++ b/test/cpp/pir/cinn/group_op_test.cc @@ -19,14 +19,19 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h" +#include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } + std::vector<::pir::Type> CreateDenseTensorTypes(const phi::DDim& dims) { ::pir::IrContext* ctx = ::pir::IrContext::Instance(); ::pir::Type fp32_dtype = ::pir::Float32Type::get(ctx); @@ -144,3 +149,93 @@ TEST(GroupOp, TestBuildByBlock) { ++i; } } + +std::shared_ptr<::pir::Program> BuildGroupProgramForLowering() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); + + auto program = std::make_shared<::pir::Program>(ctx); + const std::vector shape = {2, 2}; + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + const float value = 0.5; + auto full_x = builder.Build( + shape, value, phi::DataType::FLOAT32, phi::GPUPlace()); + + auto full_y = builder.Build( + shape, value, phi::DataType::FLOAT32, phi::GPUPlace()); + + auto group_op1 = builder.Build( + CreateDenseTensorTypes(phi::make_ddim(shape))); + pir::Block* block1 = group_op1.block(); + builder.SetInsertionPointToEnd(block1); + auto sin = builder.Build(full_x->result(0)); + + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{ + sin.out(), + }); + + builder.SetInsertionPointToEnd(program->block()); + auto group_op2 = builder.Build( + CreateDenseTensorTypes(phi::make_ddim(shape))); + pir::Block* block2 = group_op2.block(); + builder.SetInsertionPointToEnd(block2); + auto cos_op = builder.Build(full_y->result(0)); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{cos_op.out()}); + + builder.SetInsertionPointToEnd(program->block()); + auto group_op3 = builder.Build( + CreateDenseTensorTypes(phi::make_ddim(shape))); + pir::Block* block3 = group_op3.block(); + builder.SetInsertionPointToEnd(block3); + auto add = builder.Build(group_op1->result(0), + group_op2->result(0)); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{add.out()}); + + builder.SetInsertionPointToEnd(program->block()); + auto exp = builder.Build(group_op3->result(0)); + + builder.Build(exp.out(), "out", 0); + return program; +} + +TEST(GroupOp, CINNLowering) { + // Step 1: Construct pir::Program + std::shared_ptr<::pir::Program> program = BuildGroupProgramForLowering(); + + auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get()); + + paddle::platform::Place place = paddle::platform::CUDAPlace(0); + + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(res.get(), place); + + paddle::framework::Scope exe_scope; + + paddle::framework::InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + std::set out_names; + out_names.insert("out@fetch"); + auto local_names = exe_scope.LocalVarNames(); + for (size_t i = 0; i < local_names.size(); ++i) { + out_names.insert(local_names[i]); + } + + executor.SetSkipGcVars(out_names); + executor.Run({}, true); + + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get(); + + bool res0 = simple_cmp(out_tensor.data()[0], 3.88455); + bool res1 = simple_cmp(out_tensor.data()[1], 3.88455); + bool res2 = simple_cmp(out_tensor.data()[2], 3.88455); + bool res3 = simple_cmp(out_tensor.data()[3], 3.88455); + + EXPECT_EQ(res0, true); + EXPECT_EQ(res1, true); + EXPECT_EQ(res2, true); + EXPECT_EQ(res3, true); +} diff --git a/test/cpp/pir/cinn/ir_op_fusion_test.cc b/test/cpp/pir/cinn/ir_op_fusion_test.cc index a392373358b2af..57abf07498de5f 100644 --- a/test/cpp/pir/cinn/ir_op_fusion_test.cc +++ b/test/cpp/pir/cinn/ir_op_fusion_test.cc @@ -52,7 +52,9 @@ TEST(IROpFusionPass, demo) { auto add = builder.Build(inputs[0], inputs[1]); builder.Build(add.result(0)); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); ASSERT_EQ(res.size(), 1u); } @@ -75,10 +77,11 @@ TEST(IROpFusionPass, ElementWise_Fusion_0) { auto f = builder.Build(e, inputs[2]).result(0); builder.Build(f, inputs[2]); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(res.size(), 1u); } @@ -107,10 +110,11 @@ TEST(IROpFusionPass, Broadcast_Test_0) { builder.Build(e, axes, out_shape).result(0); builder.Build(e1, f); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); // ASSERT_EQ(res.size(), 1u); } @@ -138,10 +142,11 @@ TEST(IROpFusionPass, Broadcast_Test_1) { builder.Build(e, axes, out_shape).result(0); builder.Build(inputs[3], e1); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 2u); } @@ -170,10 +175,11 @@ TEST(IROpFusionPass, Broadcast_Test_2) { builder.Build(inputs[3], f1); builder.Build(inputs[4], f1); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 2u); } @@ -199,10 +205,11 @@ TEST(IROpFusionPass, reduce_test_0) { builder.Build(c, axes, true).result(0); builder.Build(c, axes, true).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); } @@ -228,10 +235,11 @@ TEST(IROpFusionPass, reduce_test_1) { builder.Build(c, axes, true).result(0); builder.Build(c, axes1, true).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 2u); } @@ -259,10 +267,11 @@ TEST(IROpFusionPass, reduce_test_2) { builder.Build(inputs[2], e).result(0); builder.Build(inputs[2], f).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 2u); } @@ -294,10 +303,11 @@ TEST(IROpFusionPass, reduce_test_3) { builder.Build(f, axes1, out_shape).result(0); builder.Build(inputs[2], f1).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); } @@ -332,10 +342,11 @@ TEST(IROpFusionPass, reduce_test_4) { builder.Build(f, axes1, out_shape).result(0); builder.Build(inputs[3], f2).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); } @@ -362,10 +373,11 @@ TEST(IROpFusionPass, reduce_test_5) { builder.Build(inputs[1], axes, false).result(0); builder.Build(c, axes, false).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); } @@ -435,10 +447,11 @@ TEST(IROpFusionPass, layer_norm) { auto t5 = builder.Build(t3, scale).result(0); builder.Build(t5, bias).result(0); - auto res = cinn::dialect::ir::OpFusionPassInternal(program); + auto res = + cinn::dialect::ir::OpFusionPassInternal(std::vector( + program.block()->begin(), program.block()->end())); - auto new_group = - cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + auto new_group = cinn::dialect::ir::GeneralFusionMergePassInternal(res); ASSERT_EQ(new_group.size(), 1u); } diff --git a/test/cpp/pir/cinn/jit_instruction_test.cc b/test/cpp/pir/cinn/jit_instruction_test.cc index 5e80cd8021a3fa..8fdffa86de6677 100644 --- a/test/cpp/pir/cinn/jit_instruction_test.cc +++ b/test/cpp/pir/cinn/jit_instruction_test.cc @@ -31,7 +31,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/new_ir_compiler.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" #include "paddle/cinn/utils/data_util.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" @@ -84,7 +84,7 @@ TEST(CinnJitInstruction, Run) { auto target = cinn::common::DefaultNVGPUTarget(); auto scope = cinn::hlir::framework::BuildScope(target, *program); - std::vector compiler_list; + std::vector compiler_list; std::set checking_cinn_ops = {"pd_op.sin", "pd_op.cos"}; @@ -101,10 +101,10 @@ TEST(CinnJitInstruction, Run) { ++it) { if (checking_cinn_ops.count((*it)->name())) { auto ir_compiler = - new cinn::hlir::framework::NewIRCompiler(*program, target, scope); + new cinn::hlir::framework::PIRCompiler(*program, target, scope); std::vector<::pir::Operation*> ops = {*it}; - auto group = std::make_shared(ops); + auto group = std::make_shared(ops); auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group}); compiler_list.push_back(ir_compiler); std::unordered_map op_attrs{ diff --git a/test/cpp/pir/cinn/new_ir_compiler_test.cc b/test/cpp/pir/cinn/pir_compiler_test.cc similarity index 64% rename from test/cpp/pir/cinn/new_ir_compiler_test.cc rename to test/cpp/pir/cinn/pir_compiler_test.cc index c75df1959ceada..8f1c883bc37341 100644 --- a/test/cpp/pir/cinn/new_ir_compiler_test.cc +++ b/test/cpp/pir/cinn/pir_compiler_test.cc @@ -22,15 +22,17 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/new_ir_compiler.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" #include "paddle/cinn/utils/data_util.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" -using cinn::hlir::framework::newir::Group; -using cinn::hlir::framework::newir::GroupPtr; +using cinn::hlir::framework::pir::Group; +using cinn::hlir::framework::pir::GroupPtr; using ProgramInfo = std::tuple, std::vector>; @@ -74,7 +76,73 @@ ProgramInfo BuildProgram() { return {program, groups}; } -TEST(NewIRCompier, CompilerAndRun) { +ProgramInfo BuildSoftmax() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + auto program = std::make_shared<::pir::Program>(ctx); + paddle::dialect::APIBuilder::Instance().SetProgram(program.get()); + + auto x = paddle::dialect::full(std::vector{64, 128}, + 1.0, + phi::DataType::FLOAT32, + phi::GPUPlace()); + auto max_tmp = paddle::dialect::max(x, std::vector{1}, true); + auto sub_tmp = paddle::dialect::subtract(x, max_tmp); + auto exp_tmp = paddle::dialect::exp(sub_tmp); + // sum need to be decomposed in Program pass, but not implemented currently. + auto sum_tmp = paddle::dialect::sum( + exp_tmp, std::vector{1}, phi::DataType::FLOAT32, true); + auto out = paddle::dialect::divide(exp_tmp, sum_tmp); + + std::vector groups; + groups.emplace_back(std::make_shared( + std::initializer_list<::pir::Operation*>({x.owner()}))); + groups.emplace_back( + std::make_shared(std::initializer_list<::pir::Operation*>({ + max_tmp.owner(), + sub_tmp.owner(), + exp_tmp.owner(), + sum_tmp.owner(), + out.owner(), + }))); + + return {program, groups}; +} + +TEST(PIRCompier, CompileSoftmax) { + // Step 1: Construct pir::Program + auto prog_info = BuildSoftmax(); + std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); + std::vector groups = std::get<1>(prog_info); + EXPECT_EQ(program->block()->size(), 8u); + LOG(INFO) << program->block()->size(); + + std::stringstream ss; + program->Print(ss); + LOG(INFO) << ss.str(); + + // Step 2: Compiler New pir::Program into Runtime Program + auto target = cinn::common::DefaultNVGPUTarget(); + auto scope = cinn::hlir::framework::BuildScope(target, *program); + LOG(INFO) << scope->var_names().size(); + ASSERT_EQ(scope->var_names().size(), 8); + + cinn::hlir::framework::PIRCompiler ir_compiler(*program, target, scope); + auto runtime_program = ir_compiler.Build(groups); + + // Step 3: Execute Runtime Instruction and check Scope. + ASSERT_NO_THROW(runtime_program->Execute()); + for (auto& var_name : scope->var_names()) { + std::string name = {var_name.begin(), var_name.end()}; + std::vector data = + cinn::GetTensorData(scope->GetTensor(name), target); + for (int i = 0; i < 1; ++i) { + LOG_FIRST_N(INFO, 10) << "data: " << data[i]; + } + } +} + +TEST(PIRCompier, CompilerAndRun) { // Step 1: Construct pir::Program auto prog_info = BuildProgram(); std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); @@ -90,7 +158,7 @@ TEST(NewIRCompier, CompilerAndRun) { auto scope = cinn::hlir::framework::BuildScope(target, *program); ASSERT_EQ(scope->var_names().size(), 6); - cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); + cinn::hlir::framework::PIRCompiler ir_compiler(*program, target, scope); auto runtime_program = ir_compiler.Build(); // Step 3: Execute Runtime Instruction and check Scope. @@ -105,7 +173,7 @@ TEST(NewIRCompier, CompilerAndRun) { } } -TEST(NewIRCompier, CompileGroupOps) { +TEST(PIRCompier, CompileGroupOps) { // Step 1: Construct pir::Program auto prog_info = BuildProgram(); std::shared_ptr<::pir::Program> program = std::get<0>(prog_info); @@ -122,7 +190,7 @@ TEST(NewIRCompier, CompileGroupOps) { auto scope = cinn::hlir::framework::BuildScope(target, *program); ASSERT_EQ(scope->var_names().size(), 6); - cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); + cinn::hlir::framework::PIRCompiler ir_compiler(*program, target, scope); auto runtime_program = ir_compiler.Build(groups); // Step 3: Execute Runtime Instruction and check Scope. @@ -148,6 +216,6 @@ TEST(RuntimeDialect, CompilerAndRun) { auto scope = cinn::hlir::framework::BuildScope(target, *program); ASSERT_EQ(scope->var_names().size(), 6u); - cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); + cinn::hlir::framework::PIRCompiler ir_compiler(*program, target, scope); auto runtime_program = ir_compiler.Build(); } diff --git a/test/cpp/pir/core/ir_builder_test.cc b/test/cpp/pir/core/ir_builder_test.cc index e3705d08c7ef93..84e7d271bce47c 100644 --- a/test/cpp/pir/core/ir_builder_test.cc +++ b/test/cpp/pir/core/ir_builder_test.cc @@ -47,6 +47,7 @@ TEST(builder_test, attribute_api) { EXPECT_EQ(pir::DoubleAttribute::get(&ctx, 2.0), builder.double_attr(2.0)); EXPECT_EQ(pir::Int32Attribute::get(&ctx, 2), builder.int32_attr(2)); EXPECT_EQ(pir::Int64Attribute::get(&ctx, 2), builder.int64_attr(2)); + EXPECT_EQ(pir::IndexAttribute::get(&ctx, 2), builder.index_attr(2)); EXPECT_EQ(pir::ArrayAttribute::get(&ctx, std::vector()), builder.array_attr({})); EXPECT_EQ(pir::PointerAttribute::get(&ctx, nullptr), diff --git a/test/cpp/pir/core/ir_op_test.cc b/test/cpp/pir/core/ir_op_test.cc index 596519ba57d4cc..1631c8198d3e56 100644 --- a/test/cpp/pir/core/ir_op_test.cc +++ b/test/cpp/pir/core/ir_op_test.cc @@ -21,7 +21,6 @@ #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/dialect.h" #include "paddle/pir/core/enforce.h" #include "paddle/pir/core/ir_context.h" @@ -32,39 +31,7 @@ #include "test/cpp/pir/tools/test_dialect.h" #include "test/cpp/pir/tools/test_op.h" -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} - -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes, - const pir::Type &dtype = - pir::Float32Type::get(pir::IrContext::Instance())) { - std::vector op_inputs = {}; - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - pir::DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; -} +#include "test/cpp/pir/tools/test_pir_utils.h" TEST(op_test, region_test) { // (1) Register Dialect, Operation1, Operation2 into IrContext. @@ -76,12 +43,12 @@ TEST(op_test, region_test) { pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(test::Operation1::name()); pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(test::Operation2::name()); - pir::Operation *op1 = - pir::Operation::Create({}, - CreateAttributeMap({"op1_attr1", "op1_attr2"}, - {"op1_attr1", "op1_attr2"}), - {pir::Float32Type::get(ctx)}, - op1_info); + pir::Operation *op1 = pir::Operation::Create( + {}, + test::CreateAttributeMap({"op1_attr1", "op1_attr2"}, + {"op1_attr1", "op1_attr2"}), + {pir::Float32Type::get(ctx)}, + op1_info); pir::Operation *op_2 = pir::Operation::Create({}, {}, {pir::Float32Type::get(ctx)}, op2_info); @@ -169,9 +136,9 @@ TEST(op_test, op_traits_test) { pir::DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); auto op3 = builder.Build( op1->result(0), op2->result(0), dense_tensor_dtype); @@ -220,9 +187,9 @@ TEST(op_test, same_operands_shape_trait_test2) { pir::DenseTensorType::get(ctx, dtype1, dims1, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0), dense_tensor_dtype), @@ -255,9 +222,9 @@ TEST(op_test, same_operands_and_result_shape_trait_test2) { phi::DDim dims = {2, 2, 2}; pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0)), @@ -287,9 +254,9 @@ TEST(op_test, same_operands_and_result_shape_trait_test3) { pir::DenseTensorType::get(ctx, dtype1, dims1, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0), dense_tensor_dtype), @@ -330,9 +297,9 @@ TEST(op_test, same_operands_element_type_trait_test2) { pir::DenseTensorType::get(ctx, dtype1, dims, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype1); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype2); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0), dense_tensor_dtype), @@ -365,9 +332,9 @@ TEST(op_test, same_operands_and_result_element_type_trait_test2) { phi::DDim dims = {2, 2}; pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0)), @@ -399,9 +366,9 @@ TEST(op_test, same_operands_and_result_element_type_trait_test3) { pir::DenseTensorType::get(ctx, dtype2, dims2, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); EXPECT_THROW(builder.Build( op1->result(0), @@ -443,9 +410,9 @@ TEST(op_test, same_operands_and_result_type_trait_test2) { phi::DDim dims = {2, 2}; pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + test::CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); EXPECT_THROW(builder.Build( op1->result(0), op2->result(0)), @@ -481,9 +448,9 @@ TEST(op_test, same_operands_and_result_type_trait_test3) { pir::DenseTensorType::get(ctx, dtype1, dims2, data_layout, lod, offset); pir::Operation *op1 = - CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype2); + test::CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype2); pir::Operation *op2 = - CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype1); + test::CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype1); EXPECT_THROW(builder.Build( op1->result(0), diff --git a/test/cpp/pir/core/ir_value_test.cc b/test/cpp/pir/core/ir_value_test.cc index d4a7d14322a66c..dba46b72c08a08 100644 --- a/test/cpp/pir/core/ir_value_test.cc +++ b/test/cpp/pir/core/ir_value_test.cc @@ -16,55 +16,47 @@ #include "paddle/pir/core/attribute.h" #include "paddle/pir/core/builtin_attribute.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/operation.h" +#include "test/cpp/pir/tools/test_pir_utils.h" + // This unittest is used to test the construction interfaces of value class and // operation. The constructed test scenario is: a = OP1(); b = OP2(); c = OP3(a, // b); d, e, f, g, h, i, j = OP4(a, c); -pir::AttributeMap CreateAttributeMap(std::string attribute_name, - std::string attribute) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attribute); - pir::AttributeMap attr_map; - attr_map.insert( - std::pair(attribute_name, attr_value)); - return attr_map; -} TEST(value_test, value_test) { pir::IrContext *ctx = pir::IrContext::Instance(); // 1. Construct OP1: a = OP1() std::vector op1_inputs = {}; std::vector op1_output_types = {pir::Float32Type::get(ctx)}; - pir::Operation *op1 = - pir::Operation::Create(op1_inputs, - CreateAttributeMap("op1_name", "op1_attr"), - op1_output_types, - pir::OpInfo()); + pir::Operation *op1 = pir::Operation::Create( + op1_inputs, + test::CreateAttributeMap({"op1_name"}, {"op1_attr"}), + op1_output_types, + pir::OpInfo()); op1->Print(std::cout); pir::OpResult a = op1->result(0); EXPECT_TRUE(a.use_empty()); // 2. Construct OP2: b = OP2(); std::vector op2_inputs = {}; std::vector op2_output_types = {pir::Float32Type::get(ctx)}; - pir::Operation *op2 = - pir::Operation::Create(op2_inputs, - CreateAttributeMap("op2_name", "op2_attr"), - op2_output_types, - pir::OpInfo()); + pir::Operation *op2 = pir::Operation::Create( + op2_inputs, + test::CreateAttributeMap({"op2_name"}, {"op2_attr"}), + op2_output_types, + pir::OpInfo()); op2->Print(std::cout); pir::OpResult b = op2->result(0); EXPECT_TRUE(b.use_empty()); // 3. Construct OP3: c = OP3(a, b); std::vector op3_inputs{a, b}; std::vector op3_output_types = {pir::Float32Type::get(ctx)}; - pir::Operation *op3 = - pir::Operation::Create(op3_inputs, - CreateAttributeMap("op3_name", "op3_attr"), - op3_output_types, - pir::OpInfo()); + pir::Operation *op3 = pir::Operation::Create( + op3_inputs, + test::CreateAttributeMap({"op3_name"}, {"op3_attr"}), + op3_output_types, + pir::OpInfo()); EXPECT_TRUE(op1->result(0).HasOneUse()); EXPECT_TRUE(op2->result(0).HasOneUse()); @@ -76,11 +68,11 @@ TEST(value_test, value_test) { for (size_t i = 0; i < 7; i++) { op4_output_types.push_back(pir::Float32Type::get(ctx)); } - pir::Operation *op4 = - pir::Operation::Create(op4_inputs, - CreateAttributeMap("op4_name", "op4_attr"), - op4_output_types, - pir::OpInfo()); + pir::Operation *op4 = pir::Operation::Create( + op4_inputs, + test::CreateAttributeMap({"op4_name"}, {"op4_attr"}), + op4_output_types, + pir::OpInfo()); op4->Print(std::cout); // Test 1: diff --git a/test/cpp/pir/core/scalar_attribute_test.cc b/test/cpp/pir/core/scalar_attribute_test.cc index e15ebfad84585b..5d547c58c3a925 100644 --- a/test/cpp/pir/core/scalar_attribute_test.cc +++ b/test/cpp/pir/core/scalar_attribute_test.cc @@ -50,6 +50,9 @@ TEST(ScalarTest, test_classof) { pir::Attribute int32_scalar = pir::Int32Attribute::get(ctx, 1); EXPECT_TRUE(int32_scalar.isa()); + pir::Attribute index_scalar = pir::IndexAttribute::get(ctx, 1l); + EXPECT_TRUE(index_scalar.isa()); + pir::Attribute int64_scalar = pir::Int64Attribute::get(ctx, 1l); EXPECT_TRUE(int64_scalar.isa()); } diff --git a/test/cpp/pir/core/type_test.cc b/test/cpp/pir/core/type_test.cc index 0f3581732784fe..2ec503dd20a95b 100644 --- a/test/cpp/pir/core/type_test.cc +++ b/test/cpp/pir/core/type_test.cc @@ -95,6 +95,7 @@ TEST(type_test, built_in_type) { pir::Type index_1 = pir::IndexType::get(ctx); pir::Type index_2 = pir::IndexType::get(ctx); + EXPECT_TRUE(index_1.IsIndex()); EXPECT_EQ(index_1, index_2); EXPECT_EQ(index_1.type_id(), index_2.type_id()); EXPECT_EQ(&index_1.abstract_type(), diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index 3282fe5893abba..70dfb193b3c45a 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -10,45 +10,30 @@ endif() cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS ${PATTERN_REWRITE_TEST_DEPS}) -cc_test_old( +cc_test( drr_test - SRCS - drr_test.cc - DEPS - drr - gtest - pd_op_dialect - pir) -cc_test_old( + SRCS drr_test.cc + DEPS drr) + +cc_test( drr_same_type_binding_test - SRCS - drr_same_type_binding_test.cc - DEPS - drr - gtest - pd_op_dialect - pir) - -cc_test_old( + SRCS drr_same_type_binding_test.cc + DEPS drr gtest pd_op_dialect pir) + +cc_test( drr_fuse_linear_test - SRCS - drr_fuse_linear_test.cc - DEPS - fusion_passes - drr - gtest - pd_op_dialect - pir) -cc_test_old( + SRCS drr_fuse_linear_test.cc + DEPS fusion_passes drr gtest pd_op_dialect pir) + +cc_test( + drr_fuse_linear_param_grad_add_test + SRCS drr_fuse_linear_param_grad_add_test.cc + DEPS fusion_passes drr gtest pd_op_dialect pir) + +cc_test( drr_attention_fuse_test - SRCS - drr_attention_fuse_test.cc - DEPS - fusion_passes - drr - gtest - pd_op_dialect - pir) + SRCS drr_attention_fuse_test.cc + DEPS fusion_passes drr gtest pd_op_dialect pir) set_tests_properties( pattern_rewrite_test PROPERTIES ENVIRONMENT diff --git a/test/cpp/pir/pattern_rewrite/drr_fuse_linear_param_grad_add_test.cc b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_param_grad_add_test.cc new file mode 100644 index 00000000000000..a898a8ff777259 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_param_grad_add_test.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +void BuildProgram0(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_bias_op1 = + builder.Build(std::vector{32}, 1.0); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + paddle::dialect::AddOp add_op1 = builder.Build( + matmul_op1.out(), full_bias_op1.out()); + + paddle::dialect::FullOp full_d_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::FullOp full_d_out_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::AddGradOp add_grad_op1 = + builder.Build( + matmul_op1.out(), full_bias_op1.out(), full_d_out_op1.out()); + + paddle::dialect::MatmulGradOp matmul_grad_op1 = + builder.Build( + full_input_op1.out(), full_weight_op1.out(), add_grad_op1.x_grad()); + + paddle::dialect::Add_Op add__op1 = builder.Build( + full_d_weight_op1.out(), matmul_grad_op1.y_grad()); + + builder.Build(add_op1.out(), "out", 0); + builder.Build(add_grad_op1.y_grad(), "dbias", 1); + builder.Build(add__op1.out(), "dweight", 2); +} + +void BuildProgram1(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + + paddle::dialect::FullOp full_d_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::FullOp full_d_out_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::MatmulGradOp matmul_grad_op1 = + builder.Build( + full_input_op1.out(), full_weight_op1.out(), full_d_out_op1.out()); + + paddle::dialect::Add_Op add__op1 = builder.Build( + full_d_weight_op1.out(), matmul_grad_op1.y_grad()); + + builder.Build(matmul_op1.out(), "out", 0); + builder.Build(add__op1.out(), "dweight", 1); +} + +void BuildProgram2(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + + paddle::dialect::FullOp full_d_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::FullOp full_d_out_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::MatmulOp matmul_op2 = + builder.Build( + full_input_op1.out(), full_d_out_op1.out(), true, false); + + paddle::dialect::Add_Op add__op1 = builder.Build( + full_d_weight_op1.out(), matmul_op2.out()); + + builder.Build(matmul_op1.out(), "out", 0); + builder.Build(add__op1.out(), "dweight", 1); +} + +void BuildProgram3(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + paddle::dialect::FullOp full_bias_op1 = + builder.Build(std::vector{32}, 1.0); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(full_input_op1.out(), + full_weight_op1.out()); + paddle::dialect::AddOp add_op1 = builder.Build( + matmul_op1.out(), full_bias_op1.out()); + + paddle::dialect::FullOp full_d_weight_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::FullOp full_d_out_op1 = + builder.Build(std::vector{32, 32}, 1.5); + + paddle::dialect::AddGradOp add_grad_op1 = + builder.Build( + matmul_op1.out(), full_bias_op1.out(), full_d_out_op1.out()); + + paddle::dialect::MatmulOp matmul_op2 = + builder.Build( + add_grad_op1.x_grad(), full_weight_op1.out(), false, true); + + paddle::dialect::MatmulOp matmul_op3 = + builder.Build( + full_input_op1.out(), add_grad_op1.x_grad(), true, false); + + paddle::dialect::Add_Op add__op1 = builder.Build( + full_d_weight_op1.out(), matmul_op3.out()); + + builder.Build(add_op1.out(), "out", 0); + builder.Build(add_grad_op1.y_grad(), "dbias", 1); + builder.Build(add__op1.out(), "dweight", 2); + builder.Build(matmul_op2.out(), "dx", 3); +} + +bool verify_pass(const pir::Program &program) { + for (auto op : *(program.block())) { + if (op->name() == paddle::dialect::FusedLinearParamGradAddOp::name()) { + return true; + } + } + return false; +} + +TEST(DrrTest, FusedLinearParamGradAdd0) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram0(builder); + + EXPECT_EQ(program.block()->size(), 13u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedLinearParamGradAddPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(verify_pass(program), true); +} + +TEST(DrrTest, FusedLinearParamGradAdd1) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram1(builder); + + EXPECT_EQ(program.block()->size(), 9u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedLinearParamGradAddPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(verify_pass(program), true); +} + +TEST(DrrTest, FusedLinearParamGradAdd2) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram2(builder); + + EXPECT_EQ(program.block()->size(), 9u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedLinearParamGradAddPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(verify_pass(program), true); +} + +TEST(DrrTest, FusedLinearParamGradAdd3) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram3(builder); + + EXPECT_EQ(program.block()->size(), 15u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedLinearParamGradAddPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(verify_pass(program), true); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc index f607fa5a083260..83dd94556cb60d 100644 --- a/test/cpp/pir/pattern_rewrite/drr_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -67,8 +67,15 @@ class FoldExpandToConstantPattern // Result patterns pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &new_perm_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> phi::IntArray { + auto shape = + match_ctx.Attr>("expand_shape_value"); + + return phi::IntArray(shape); + }); const auto &full2 = res.Op("pd_op.full", - {{"shape", pat.Attr("expand_shape_value")}, + {{"shape", new_perm_attr}, {"value", pat.Attr("value_1")}, {"dtype", pat.Attr("dtype_1")}, {"place", pat.Attr("place_1")}}); diff --git a/test/cpp/pir/shape_dialect/constraint_pass_test.cc b/test/cpp/pir/shape_dialect/constraint_pass_test.cc index 860bf34a69ac4f..4b5e660cf6f3b1 100644 --- a/test/cpp/pir/shape_dialect/constraint_pass_test.cc +++ b/test/cpp/pir/shape_dialect/constraint_pass_test.cc @@ -21,13 +21,11 @@ #include #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/builtin_op.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/cast_utils.h" #include "paddle/pir/core/dialect.h" @@ -40,94 +38,72 @@ #include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/dialect/shape/transforms/passes.h" -#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} +#include "test/cpp/pir/tools/test_pir_utils.h" -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes, - const pir::Type &dtype = - pir::Float32Type::get(pir::IrContext::Instance())) { - std::vector op_inputs = {}; - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - paddle::dialect::DenseTensorType::get( - ctx, dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; -} - -TEST(constraint_pass, materialize_and_build_shape) { +TEST(shape_constraint_pass, materialize_and_build_shape) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - pir::PassManager pm(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - pir::Operation *op0 = CreateDenseTensorOp( - ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op0_attr"}, {"op0_name"}); - program.block()->push_back(op0); + + pir::Operation *op0 = + test::CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, 2}, + {"op0_attr"}, + {"create_dense_tensor_op0"}); pir::Operation *op1 = - CreateDenseTensorOp(ctx, - {pir::ShapedTypeInterface::kDynamic, 2, 2}, - {"op1_attr"}, - {"op1_name"}); + test::CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, 2, 2}, + {"op1_attr"}, + {"create_dense_tensor_op1"}); + program.block()->push_back(op0); program.block()->push_back(op1); - EXPECT_EQ(program.block()->size(), static_cast(2)); + EXPECT_EQ(program.block()->size(), 2u); + + std::stringstream ss1; + program.Print(ss1); + LOG(INFO) << " ================================================ Before Add " + "and Run Pass ================================================ "; + LOG(INFO) << ss1.str(); + + pir::PassManager pm(ctx); pm.AddPass(pir::CreateShapeOptimizationPass()); EXPECT_TRUE(pm.Run(&program)); // 5 ConstantOp + 5 TensorDim + 2 TieShape + op0 + op1 + 1 funcOp == 15 Ops. - EXPECT_EQ(program.block()->size(), static_cast(15)); - - std::stringstream ss; - program.Print(ss); + EXPECT_EQ(program.block()->size(), 15u); - LOG(INFO) << ss.str(); + std::stringstream ss2; + program.Print(ss2); + LOG(INFO) << " ================================================ After Add " + "and Run Pass ================================================ "; + LOG(INFO) << ss2.str(); } -TEST(constraint_pass, shape_computation_run) { +TEST(shape_constraint_pass, shape_computation_run) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - pir::PassManager pm(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - builder.Build(); - pir::Operation *op0 = - CreateDenseTensorOp(ctx, - {2}, - {"op0_attr"}, - {"op0_name"}, - pir::Int64Type::get(pir::IrContext::Instance())); + pir::Builder builder = ::pir::Builder(ctx, program.block()); + builder.Build(); + pir::Operation *op0 = test::CreateDenseTensorOp( + ctx, + {2}, + {"op0_attr"}, + {"op0_name"}, + pir::Int64Type::get(pir::IrContext::Instance())); program.block()->push_back(op0); - pir::Operation *op1 = CreateDenseTensorOp( + pir::Operation *op1 = test::CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op1_attr"}, {"op1_name"}); program.block()->push_back(op1); + pir::PassManager pm(ctx); pm.AddPass(pir::CreateShapeOptimizationPass()); EXPECT_TRUE(pm.Run(&program)); @@ -135,3 +111,5 @@ TEST(constraint_pass, shape_computation_run) { EXPECT_TRUE(mgr.Load()); EXPECT_TRUE(mgr.Save()); } + +// TODO(zhangbopd): ExpandShapeOfOpPattern etc. diff --git a/test/cpp/pir/shape_dialect/shape_op_test.cc b/test/cpp/pir/shape_dialect/shape_op_test.cc index 9d71e721fe72df..89a728beed9b79 100644 --- a/test/cpp/pir/shape_dialect/shape_op_test.cc +++ b/test/cpp/pir/shape_dialect/shape_op_test.cc @@ -16,119 +16,121 @@ #include #include #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/core/block.h" -#include "paddle/pir/core/builder.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/dialect.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/shape/ir/shape_dialect.h" -#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/dialect/shape/utils/symbol_table.h" +#include "test/cpp/pir/tools/test_pir_utils.h" -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { +TEST(shape_op, symbolic_dim_op) { pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::shape::SymbolicDimOp sym_dim_op1 = + builder.Build( + "S0", 10, false, false, false, false); + pir::shape::SymbolicDimOp sym_dim_op2 = + builder.Build( + "S1", 10, false, false, false, false); + + EXPECT_EQ(sym_dim_op1.GetDimSize(), 10); + EXPECT_EQ(sym_dim_op1.GetSymName(), "S0"); + EXPECT_FALSE(sym_dim_op1.GetKnownNegativeOne()); + EXPECT_FALSE(sym_dim_op1.GetKnownNonSizeOne()); + EXPECT_FALSE(sym_dim_op1.GetKnownNonSizeZero()); + EXPECT_FALSE(sym_dim_op1.GetKnownNonNegative()); + + EXPECT_FALSE(sym_dim_op1.IsDynamic()); + EXPECT_TRUE(sym_dim_op1.Merge(sym_dim_op2)); -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes) { - std::vector op_inputs = {}; - pir::Type fp32_dtype = pir::Float32Type::get(ctx); - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - paddle::dialect::DenseTensorType::get( - ctx, fp32_dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; + sym_dim_op1.SetDimSize(20); + sym_dim_op1.SetSymName("S2"); + sym_dim_op1.UpdateKnownNegativeOne(true); + sym_dim_op1.UpdateKnownNonSizeOne(true); + sym_dim_op1.UpdateKnownNonSizeZero(true); + sym_dim_op1.UpdateKnownNonNegative(true); + + EXPECT_FALSE(sym_dim_op1.Merge(sym_dim_op2)); + + EXPECT_EQ(sym_dim_op1.GetDimSize(), 20); + EXPECT_EQ(sym_dim_op1.GetSymName(), "S2"); + EXPECT_TRUE(sym_dim_op1.GetKnownNegativeOne()); + EXPECT_TRUE(sym_dim_op1.GetKnownNonSizeOne()); + EXPECT_TRUE(sym_dim_op1.GetKnownNonSizeZero()); + EXPECT_TRUE(sym_dim_op1.GetKnownNonNegative()); } -TEST(shape_op, dim) { +TEST(shape_op, dim_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::dialect::DimOp dim_op = builder.Build("S0"); + pir::shape::DimOp dim_op = builder.Build("S0"); pir::OpResult res = dim_op.out(); - EXPECT_EQ(dim_op.getName(), "S0"); - dim_op.setName("S1"); - EXPECT_EQ(dim_op.getName(), "S1"); + EXPECT_EQ(dim_op.GetName(), "S0"); + dim_op.SetName("S1"); + EXPECT_EQ(dim_op.GetName(), "S1"); EXPECT_EQ(res.owner(), dim_op.operation()); EXPECT_EQ(res.type(), pir::IndexType::get(ctx)); } -TEST(shape_op, tie_product_equal) { +TEST(shape_op, tie_product_equal_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); pir::SymbolTable symbolt_table(program.module_op()); - pir::OpResult dim_op0 = builder.Build("S0").out(); - pir::OpResult dim_op1 = builder.Build("S1").out(); - pir::OpResult dim_op2 = builder.Build("S2").out(); - pir::OpResult dim_op3 = builder.Build("S3").out(); - pir::OpResult dim_op4 = builder.Build("S4").out(); + pir::OpResult dim_op0 = builder.Build("S0").out(); + pir::OpResult dim_op1 = builder.Build("S1").out(); + pir::OpResult dim_op2 = builder.Build("S2").out(); + pir::OpResult dim_op3 = builder.Build("S3").out(); + pir::OpResult dim_op4 = builder.Build("S4").out(); - pir::dialect::TieProductEqualOp tie_product_equal = - builder.Build( + pir::shape::TieProductEqualOp tie_product_equal_op = + builder.Build( 2, 3, std::vector{dim_op0, dim_op1, dim_op2, dim_op3, dim_op4}); - std::vector lhs = tie_product_equal.lhs(); - std::vector rhs = tie_product_equal.rhs(); + std::vector lhs = tie_product_equal_op.lhs(); + std::vector rhs = tie_product_equal_op.rhs(); std::vector lhs_ref{dim_op0, dim_op1}; std::vector rhs_ref{dim_op2, dim_op3, dim_op4}; - EXPECT_EQ(symbolt_table.insert(tie_product_equal), "tie_product_equal"); + EXPECT_EQ(symbolt_table.insert(tie_product_equal_op), "tie_product_equal"); EXPECT_EQ( - symbolt_table.Lookup("tie_product_equal") + symbolt_table.Lookup("tie_product_equal") .size(), static_cast(1)); - EXPECT_EQ(symbolt_table.Lookup( + EXPECT_EQ(symbolt_table.Lookup( "tie_product_equal")[0], - tie_product_equal); + tie_product_equal_op); EXPECT_EQ(lhs, lhs_ref); EXPECT_EQ(rhs, rhs_ref); } -TEST(shape_op, tie_shape) { +TEST(shape_op, tie_shape_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - auto op = CreateDenseTensorOp( + auto op = test::CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::OpResult res = op->result(0); - pir::dialect::TieShapeOp tie_shape_op = - builder.Build(res); - pir::Value tie_shape_op_value = tie_shape_op.value(); + pir::shape::TieShapeOp tie_shape_op = + builder.Build(res); + pir::Value tie_shape_op_input = tie_shape_op.input(); pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); @@ -137,28 +139,28 @@ TEST(shape_op, tie_shape) { auto array_attr = pir::ArrayAttribute::get(ctx, new_attrs); tie_shape_op->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr); std::vector arr_attr_vec = tie_shape_op ->attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName()) + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName()) .AsVector(); - EXPECT_EQ(tie_shape_op_value, res); + EXPECT_EQ(tie_shape_op_input, res); EXPECT_EQ(arr_attr_vec.size(), static_cast(2)); EXPECT_EQ(arr_attr_vec[0].dyn_cast(), attr_s0); EXPECT_EQ(arr_attr_vec[1].dyn_cast(), attr_s1); EXPECT_TRUE(tie_shape_op->HasAttribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName())); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName())); } TEST(shape_op, func_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::dialect::FuncOp func_op = builder.Build(); + pir::shape::FuncOp func_op = builder.Build(); auto func_block = func_op.block(); builder.SetInsertionPointToStart(func_block); builder.Build(pir::Int32Attribute::get(ctx, 2), @@ -168,19 +170,20 @@ TEST(shape_op, func_op) { EXPECT_EQ(func_block->size(), static_cast(1)); } -TEST(shape_op, tensor_dim) { +TEST(shape_op, tensor_dim_op) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::Operation *op = CreateDenseTensorOp( + pir::Operation *op = test::CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::OpResult res_dense_tensor_value = op->result(0); - pir::dialect::TensorDimOp tensor_dim_op0 = - builder.Build(res_dense_tensor_value, 0); + pir::shape::TensorDimOp tensor_dim_op0 = + builder.Build(res_dense_tensor_value, 0); pir::OpResult res0 = tensor_dim_op0.out(); + std::optional index0 = tensor_dim_op0.GetConstantIndex(); pir::OpResult index_value = builder @@ -188,14 +191,117 @@ TEST(shape_op, tensor_dim) { pir::Int64Attribute::get(pir::IrContext::Instance(), 1), pir::IndexType::get(pir::IrContext::Instance())) ->result(0); - pir::dialect::TensorDimOp tensor_dim_op1 = - builder.Build(res_dense_tensor_value, - index_value); + pir::shape::TensorDimOp tensor_dim_op1 = + builder.Build(res_dense_tensor_value, + index_value); pir::OpResult res1 = tensor_dim_op1.out(); EXPECT_EQ(res0.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(*index0, static_cast(0)); EXPECT_EQ(res1.type(), pir::IndexType::get(ctx)); EXPECT_EQ(tensor_dim_op0.source(), res_dense_tensor_value); EXPECT_EQ(tensor_dim_op1.source(), res_dense_tensor_value); EXPECT_EQ(tensor_dim_op1.index(), index_value); } + +TEST(shape_op, shape_of_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + auto op = test::CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult res = op->result(0); + + pir::shape::ShapeOfOp shape_of_op = builder.Build(res); + pir::Value shape_of_op_input = shape_of_op.input(); + EXPECT_EQ(shape_of_op_input, res); +} + +TEST(shape_op, from_elements_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::Int32Attribute int32_attr0 = builder.int32_attr(0); + pir::Int32Attribute int32_attr1 = builder.int32_attr(1); + pir::Int32Attribute int32_attr2 = builder.int32_attr(2); + pir::Int32Type int32_type = builder.int32_type(); + + pir::OpResult element0 = + builder.Build(int32_attr0, int32_type).out(); + pir::OpResult element1 = + builder.Build(int32_attr1, int32_type).out(); + pir::OpResult element2 = + builder.Build(int32_attr2, int32_type).out(); + + std::vector elements_in = {element0, element1, element2}; + + pir::shape::FromElementsOp from_elements_op = + builder.Build(elements_in); + + std::vector elements_out = from_elements_op.elements(); + for (size_t i = 0; i < elements_in.size(); i++) { + EXPECT_EQ(elements_in[i], elements_out[i]); + } +} + +TEST(shape_op, extract_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + auto op = test::CreateDenseTensorOp(ctx, {3, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult res = op->result(0); + + pir::Int32Attribute int32_attr = builder.int32_attr(1); + pir::Int32Type int32_type = builder.int32_type(); + pir::OpResult indice = + builder.Build(int32_attr, int32_type).out(); + std::vector indice_in = {indice, indice}; + + pir::shape::ExtractOp extract_op = + builder.Build(res, indice_in); + pir::Value input = extract_op.tensor(); + std::vector indice_out = extract_op.indices(); + + EXPECT_EQ(input, res); + for (size_t i = 0; i < indice_in.size(); i++) { + EXPECT_EQ(indice_in[i], indice_out[i]); + } +} + +TEST(shape_op, constant_index_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::shape::ConstantIndexOp constant_index_op = + builder.Build(1); + + EXPECT_EQ( + constant_index_op.value().dyn_cast().data() == 1, + true); +} + +TEST(shape_op, index_cast_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::IndexAttribute index_attr = builder.index_attr(1); + pir::IndexType index_type = builder.index_type(); + pir::OpResult in = + builder.Build(index_attr, index_type).out(); + + pir::shape::IndexCastOp index_cast_op = + builder.Build(builder.int32_type(), in); + pir::Value index_cast_op_input = index_cast_op.in(); + + EXPECT_EQ(index_cast_op_input, in); +} diff --git a/test/cpp/pir/shape_dialect/shape_struct_test.cc b/test/cpp/pir/shape_dialect/shape_struct_test.cc index 64b58a399a1508..a9020f5e31ad97 100644 --- a/test/cpp/pir/shape_dialect/shape_struct_test.cc +++ b/test/cpp/pir/shape_dialect/shape_struct_test.cc @@ -15,97 +15,24 @@ #include #include #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/builder.h" -#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/dialect.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" -#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/dialect/shape/utils/symbol_table.h" -pir::AttributeMap CreateAttributeMap( - const std::vector &attribute_names, - const std::vector &attributes) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::AttributeMap attr_map; - for (size_t i = 0; i < attribute_names.size(); i++) { - pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); - attr_map.insert( - std::pair(attribute_names[i], attr_value)); - } - return attr_map; -} - -pir::Operation *CreateDenseTensorOp( - pir::IrContext *ctx, - const phi::DDim &dims, - const std::vector &attribute_names, - const std::vector &attributes) { - std::vector op_inputs = {}; - pir::Type fp32_dtype = pir::Float32Type::get(ctx); - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{0, 1, 2}}; - size_t offset = 0; - std::vector op_output_types = { - paddle::dialect::DenseTensorType::get( - ctx, fp32_dtype, dims, data_layout, lod, offset)}; - pir::Operation *op = - pir::Operation::Create(op_inputs, - CreateAttributeMap(attribute_names, attributes), - op_output_types, - pir::OpInfo()); - return op; -} - -TEST(shape_struct_test, symbolic_dim) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - pir::Builder builder = pir::Builder(ctx, program.block()); - - pir::dialect::SymbolicDim sym_dim1 = builder.Build( - "S0", 10, false, false, false, false); - pir::dialect::SymbolicDim sym_dim2 = builder.Build( - "S1", 10, false, false, false, false); - - EXPECT_EQ(sym_dim1.GetDimSize(), 10); - EXPECT_EQ(sym_dim1.GetSymName(), "S0"); - EXPECT_FALSE(sym_dim1.GetKnownNegativeOne()); - EXPECT_FALSE(sym_dim1.GetKnownNonSizeOne()); - EXPECT_FALSE(sym_dim1.GetKnownNonSizeZero()); - EXPECT_FALSE(sym_dim1.GetKnownNonNegative()); - - EXPECT_FALSE(sym_dim1.IsDynamic()); - EXPECT_TRUE(sym_dim1.Merge(sym_dim2)); - - sym_dim1.SetDimSize(20); - sym_dim1.SetSymName("S2"); - sym_dim1.UpdateKnownNegativeOne(true); - sym_dim1.UpdateKnownNonSizeOne(true); - sym_dim1.UpdateKnownNonSizeZero(true); - sym_dim1.UpdateKnownNonNegative(true); - - EXPECT_FALSE(sym_dim1.Merge(sym_dim2)); - - EXPECT_EQ(sym_dim1.GetDimSize(), 20); - EXPECT_EQ(sym_dim1.GetSymName(), "S2"); - EXPECT_TRUE(sym_dim1.GetKnownNegativeOne()); - EXPECT_TRUE(sym_dim1.GetKnownNonSizeOne()); - EXPECT_TRUE(sym_dim1.GetKnownNonSizeZero()); - EXPECT_TRUE(sym_dim1.GetKnownNonNegative()); -} +#include "test/cpp/pir/tools/test_pir_utils.h" TEST(shape_struct_test, symbolic_dim_product) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::dialect::SymbolicDim sym_dim = builder.Build( + pir::shape::SymbolicDimOp sym_dim = builder.Build( "S0", pir::ShapedTypeInterface::kDynamic, false, false, false, false); pir::SymbolicDimProduct sym_dim_product1; pir::SymbolicDimProduct sym_dim_product2; @@ -119,39 +46,39 @@ TEST(shape_struct_test, symbolic_dim_product) { TEST(shape_struct_test, symbolic_dim_table) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::dialect::SymbolicDim sym_dim = builder.Build( + pir::shape::SymbolicDimOp sym_dim = builder.Build( "S0", 10, false, false, false, false); pir::SymbolTable symbol_table(program.module_op()); EXPECT_EQ(symbol_table.insert(sym_dim), "S0"); - EXPECT_EQ(symbol_table.Lookup("S0"), sym_dim); + EXPECT_EQ(symbol_table.Lookup("S0"), sym_dim); EXPECT_EQ(symbol_table.getOp(), program.module_op()); - EXPECT_FALSE(symbol_table.Lookup("S1")); + EXPECT_FALSE(symbol_table.Lookup("S1")); } TEST(shape_struct_test, symbolic_dim_mgr_simple) { /******************************************************/ - /* Mgr simple version, only SymbolicDim related func. */ + /* Mgr simple version, only SymbolicDimOp related func. */ /******************************************************/ pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); pir::SymbolicDimMgr sym_dim_mgr(program.module_op()); - pir::dialect::SymbolicDim sym_dim_s0 = sym_dim_mgr.NewSymbolicDim(); - pir::dialect::SymbolicDim sym_dim_s1 = sym_dim_mgr.NewSymbolicDim(); - pir::dialect::SymbolicDim sym_dim_c10 = + pir::shape::SymbolicDimOp sym_dim_s0 = sym_dim_mgr.NewSymbolicDim(); + pir::shape::SymbolicDimOp sym_dim_s1 = sym_dim_mgr.NewSymbolicDim(); + pir::shape::SymbolicDimOp sym_dim_c10 = sym_dim_mgr.NewConstantSymbolicDim(10); sym_dim_mgr.MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1); - auto op = CreateDenseTensorOp( + auto op = test::CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::Value res = op->result(0); - std::vector sym_dim_vec = + std::vector sym_dim_vec = sym_dim_mgr.CreateSymbolicDimsForRankedValue(res); EXPECT_EQ(sym_dim_s0.GetSymName(), "S0"); @@ -161,9 +88,9 @@ TEST(shape_struct_test, symbolic_dim_mgr_simple) { EXPECT_EQ(sym_dim_c10.GetDimSize(), 10); EXPECT_EQ(sym_dim_vec[0].GetSymName(), "S2"); EXPECT_EQ(sym_dim_vec[1].GetSymName(), "C2"); - EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("S0"), + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("S0"), sym_dim_s0); - EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C10"), + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C10"), sym_dim_c10); EXPECT_EQ(sym_dim_mgr.GetRootSymbolicDim(sym_dim_s1), sym_dim_s0); EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s1)); @@ -176,47 +103,47 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { /***************************************************************/ pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); pir::SymbolicDimMgr sym_dim_mgr(program.module_op()); auto func_op = - sym_dim_mgr.symbolTable().getOp()->dyn_cast(); + sym_dim_mgr.symbolTable().getOp()->dyn_cast(); pir::Builder builder = pir::Builder(ctx, func_op.block()); - pir::dialect::SymbolicDim sym_dim_s0 = sym_dim_mgr.NewSymbolicDim("S0"); - pir::dialect::SymbolicDim sym_dim_s1 = sym_dim_mgr.NewSymbolicDim("S1"); - pir::dialect::SymbolicDim sym_dim_s2 = sym_dim_mgr.NewSymbolicDim("S2"); - pir::dialect::SymbolicDim sym_dim_s3 = sym_dim_mgr.NewSymbolicDim("S3"); - pir::dialect::SymbolicDim sym_dim_s4 = sym_dim_mgr.NewSymbolicDim("S4"); - pir::dialect::SymbolicDim sym_dim_s5 = sym_dim_mgr.NewSymbolicDim("S5"); - pir::dialect::SymbolicDim sym_dim_s6 = sym_dim_mgr.NewSymbolicDim("S6"); - pir::dialect::SymbolicDim sym_dim_s7 = sym_dim_mgr.NewSymbolicDim("S7"); - pir::dialect::SymbolicDim sym_dim_s8 = sym_dim_mgr.NewSymbolicDim("S8"); - pir::dialect::SymbolicDim sym_dim_s9 = sym_dim_mgr.NewSymbolicDim("S9"); - pir::dialect::SymbolicDim sym_dim_s10 = sym_dim_mgr.NewSymbolicDim("S10"); - pir::dialect::SymbolicDim sym_dim_s11 = sym_dim_mgr.NewSymbolicDim("S11"); - pir::dialect::SymbolicDim sym_dim_s12 = sym_dim_mgr.NewSymbolicDim("S12"); - pir::dialect::SymbolicDim sym_dim_c10 = + pir::shape::SymbolicDimOp sym_dim_s0 = sym_dim_mgr.NewSymbolicDim("S0"); + pir::shape::SymbolicDimOp sym_dim_s1 = sym_dim_mgr.NewSymbolicDim("S1"); + pir::shape::SymbolicDimOp sym_dim_s2 = sym_dim_mgr.NewSymbolicDim("S2"); + pir::shape::SymbolicDimOp sym_dim_s3 = sym_dim_mgr.NewSymbolicDim("S3"); + pir::shape::SymbolicDimOp sym_dim_s4 = sym_dim_mgr.NewSymbolicDim("S4"); + pir::shape::SymbolicDimOp sym_dim_s5 = sym_dim_mgr.NewSymbolicDim("S5"); + pir::shape::SymbolicDimOp sym_dim_s6 = sym_dim_mgr.NewSymbolicDim("S6"); + pir::shape::SymbolicDimOp sym_dim_s7 = sym_dim_mgr.NewSymbolicDim("S7"); + pir::shape::SymbolicDimOp sym_dim_s8 = sym_dim_mgr.NewSymbolicDim("S8"); + pir::shape::SymbolicDimOp sym_dim_s9 = sym_dim_mgr.NewSymbolicDim("S9"); + pir::shape::SymbolicDimOp sym_dim_s10 = sym_dim_mgr.NewSymbolicDim("S10"); + pir::shape::SymbolicDimOp sym_dim_s11 = sym_dim_mgr.NewSymbolicDim("S11"); + pir::shape::SymbolicDimOp sym_dim_s12 = sym_dim_mgr.NewSymbolicDim("S12"); + pir::shape::SymbolicDimOp sym_dim_c10 = sym_dim_mgr.NewConstantSymbolicDim(10); - pir::dialect::SymbolicDim sym_dim_c20 = + pir::shape::SymbolicDimOp sym_dim_c20 = sym_dim_mgr.NewConstantSymbolicDim(20); - pir::OpResult dim_op_s0 = builder.Build("S0").out(); - pir::OpResult dim_op_s1 = builder.Build("S1").out(); - pir::OpResult dim_op_s2 = builder.Build("S2").out(); - pir::OpResult dim_op_s3 = builder.Build("S3").out(); - pir::OpResult dim_op_s4 = builder.Build("S4").out(); - pir::OpResult dim_op_s5 = builder.Build("S5").out(); - pir::OpResult dim_op_s6 = builder.Build("S6").out(); - pir::OpResult dim_op_s7 = builder.Build("S7").out(); - pir::OpResult dim_op_s8 = builder.Build("S8").out(); - pir::OpResult dim_op_s9 = builder.Build("S9").out(); - pir::OpResult dim_op_s10 = builder.Build("S10").out(); - pir::OpResult dim_op_s11 = builder.Build("S11").out(); - pir::OpResult dim_op_c10 = builder.Build("C10").out(); - pir::OpResult dim_op_c20 = builder.Build("C20").out(); + pir::OpResult dim_op_s0 = builder.Build("S0").out(); + pir::OpResult dim_op_s1 = builder.Build("S1").out(); + pir::OpResult dim_op_s2 = builder.Build("S2").out(); + pir::OpResult dim_op_s3 = builder.Build("S3").out(); + pir::OpResult dim_op_s4 = builder.Build("S4").out(); + pir::OpResult dim_op_s5 = builder.Build("S5").out(); + pir::OpResult dim_op_s6 = builder.Build("S6").out(); + pir::OpResult dim_op_s7 = builder.Build("S7").out(); + pir::OpResult dim_op_s8 = builder.Build("S8").out(); + pir::OpResult dim_op_s9 = builder.Build("S9").out(); + pir::OpResult dim_op_s10 = builder.Build("S10").out(); + pir::OpResult dim_op_s11 = builder.Build("S11").out(); + pir::OpResult dim_op_c10 = builder.Build("C10").out(); + pir::OpResult dim_op_c20 = builder.Build("C20").out(); pir::OpResult constant = builder .Build(pir::Int32Attribute::get(ctx, 2), @@ -224,62 +151,62 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { ->result(0); // Mark S1 == S2. - builder.Build( + builder.Build( 2, 2, std::vector{constant, dim_op_s1, dim_op_s2, constant}); // Mark S0 * S1 == S2 * S3, For check S0 == S3. - builder.Build( + builder.Build( 2, 2, std::vector{dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3}); // Mark S4 * S0 * S1 == S2 * S3 * S5, For check S4 == S5. - builder.Build( + builder.Build( 3, 3, std::vector{ dim_op_s4, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s5}); // For check S6 == C10 * C20. - builder.Build( + builder.Build( 1, 2, std::vector{dim_op_s6, dim_op_c10, dim_op_c20}); // Mark C10 * S0 * S1 == S2 * S3 * S7, for check C10 == S7. - builder.Build( + builder.Build( 3, 3, std::vector{ dim_op_c10, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s7}); // For unsimplify product case: S8 * S9 == S10 * S11 - builder.Build( + builder.Build( 2, 2, std::vector{dim_op_s8, dim_op_s9, dim_op_s10, dim_op_s11}); - auto op = CreateDenseTensorOp(ctx, - {pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic}, - {"op0_attr"}, - {"op0_name"}); - auto op_ = CreateDenseTensorOp(ctx, - {pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - pir::ShapedTypeInterface::kDynamic, - 10, - 20}, - {"op1_attr"}, - {"op1_name"}); + auto op = test::CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic}, + {"op0_attr"}, + {"op0_name"}); + auto op_ = test::CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + pir::ShapedTypeInterface::kDynamic, + 10, + 20}, + {"op1_attr"}, + {"op1_name"}); pir::OpResult res = op->result(0); pir::OpResult res_ = op_->result(0); builder.SetInsertionPointToEnd(program.block()); - pir::dialect::TieShapeOp tie_shape_op1 = - builder.Build(res); - pir::dialect::TieShapeOp tie_shape_op2 = - builder.Build(res_); + pir::shape::TieShapeOp tie_shape_op1 = + builder.Build(res); + pir::shape::TieShapeOp tie_shape_op2 = + builder.Build(res_); pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); @@ -314,9 +241,9 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { auto array_attr_ref = pir::ArrayAttribute::get(ctx, new_attrs_ref); tie_shape_op1->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr1); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr1); tie_shape_op2->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), array_attr2); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr2); EXPECT_TRUE(sym_dim_mgr.Load()); @@ -380,7 +307,7 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s3)); EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s4, sym_dim_s5)); EXPECT_EQ(sym_dim_s6.GetDimSize(), 200); - EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C20"), + EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C20"), sym_dim_c20); EXPECT_EQ(sym_dim_s7.GetDimSize(), sym_dim_c10.GetDimSize()); EXPECT_EQ(simplified_product_s7.factor, 10); @@ -402,11 +329,11 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { EXPECT_TRUE(sym_dim_mgr_new.Load()); auto attrs = tie_shape_op1.attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName()); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName()); EXPECT_FALSE( - sym_dim_mgr_new.symbolTable().Lookup("S7")); + sym_dim_mgr_new.symbolTable().Lookup("S7")); EXPECT_EQ(sym_dim_mgr_new.symbolTable() - .Lookup("tie_product_equal") + .Lookup("tie_product_equal") .size(), static_cast(1)); @@ -416,52 +343,56 @@ TEST(shape_struct_test, symbolic_dim_mgr_complex) { TEST(shape_struct_test, shape_analysis) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::dialect::FuncOp func_op = builder.Build(); + pir::shape::FuncOp func_op = builder.Build(); phi::DDim dims_D_2 = {pir::ShapedTypeInterface::kDynamic, 2}; phi::DDim dims_2_2 = {2, 2}; phi::DDim dims_D = {pir::ShapedTypeInterface::kDynamic}; // same shape with dynamic: value1 == value2 - auto op1 = CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); - auto op2 = CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); + auto op1 = + test::CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); + auto op2 = + test::CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); pir::OpResult value1 = op1->result(0); pir::OpResult value2 = op2->result(0); // same shape with static: value3 == value4 - auto op3 = CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); - auto op4 = CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); + auto op3 = + test::CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); + auto op4 = + test::CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); pir::OpResult value3 = op3->result(0); pir::OpResult value4 = op4->result(0); // one dimension with dynamic: value5 != value1 != value3 - auto op5 = CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); + auto op5 = test::CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); pir::OpResult value5 = op5->result(0); - pir::dialect::TieShapeOp tie_shape_op1 = - builder.Build(value1); - pir::dialect::TieShapeOp tie_shape_op2 = - builder.Build(value2); - pir::dialect::TieShapeOp tie_shape_op3 = - builder.Build(value3); - pir::dialect::TieShapeOp tie_shape_op4 = - builder.Build(value4); - pir::dialect::TieShapeOp tie_shape_op5 = - builder.Build(value5); + pir::shape::TieShapeOp tie_shape_op1 = + builder.Build(value1); + pir::shape::TieShapeOp tie_shape_op2 = + builder.Build(value2); + pir::shape::TieShapeOp tie_shape_op3 = + builder.Build(value3); + pir::shape::TieShapeOp tie_shape_op4 = + builder.Build(value4); + pir::shape::TieShapeOp tie_shape_op5 = + builder.Build(value5); builder.SetInsertionPointToEnd(func_op.block()); - builder.Build("C2", 2, true, false, true, true); - pir::dialect::SymbolicDim sym_dim_s0 = - builder.Build( + builder.Build("C2", 2, true, false, true, true); + pir::shape::SymbolicDimOp sym_dim_s0 = + builder.Build( "S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::dialect::SymbolicDim sym_dim_s1 = - builder.Build( + pir::shape::SymbolicDimOp sym_dim_s1 = + builder.Build( "S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::dialect::SymbolicDim sym_dim_s2 = - builder.Build( + pir::shape::SymbolicDimOp sym_dim_s2 = + builder.Build( "S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true); pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); @@ -476,15 +407,15 @@ TEST(shape_struct_test, shape_analysis) { auto attr_op5 = pir::ArrayAttribute::get(ctx, {attr_s2}); tie_shape_op1->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op1); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op1); tie_shape_op2->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op2); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op2); tie_shape_op3->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op3); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op3); tie_shape_op4->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op4); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op4); tie_shape_op5->set_attribute( - pir::dialect::SymbolicDim::GetSymbolicDimAttrName(), attr_op5); + pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op5); pir::ShapeConstraintIRAnalysis shape_analysis(program.module_op()); EXPECT_TRUE(shape_analysis.IsShapeEqual(value3, value4)); diff --git a/test/cpp/pir/tools/test_pir_utils.h b/test/cpp/pir/tools/test_pir_utils.h new file mode 100644 index 00000000000000..d71ddb0d2ea954 --- /dev/null +++ b/test/cpp/pir/tools/test_pir_utils.h @@ -0,0 +1,59 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" + +namespace test { + +pir::AttributeMap CreateAttributeMap( + const std::vector &attribute_names, + const std::vector &attributes) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::AttributeMap attr_map; + for (size_t i = 0; i < attribute_names.size(); i++) { + pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); + attr_map.insert( + std::pair(attribute_names[i], attr_value)); + } + return attr_map; +} + +pir::Operation *CreateDenseTensorOp( + pir::IrContext *ctx, + const phi::DDim &dims, + const std::vector &attribute_names, + const std::vector &attributes, + const pir::Type &dtype = + pir::Float32Type::get(pir::IrContext::Instance())) { + std::vector op_inputs = {}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + paddle::dialect::DenseTensorType::get( + ctx, dtype, dims, data_layout, lod, offset)}; + + pir::Builder builder = pir::Builder(ctx); + pir::Operation *op = + builder.Build(op_inputs, + CreateAttributeMap(attribute_names, attributes), + op_output_types, + pir::OpInfo()); + return op; +} + +} // namespace test diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index a28e412a9ebbaa..7e420635ad210d 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -14,7 +14,7 @@ #include -#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/new_executor/standalone_executor.h" #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" @@ -42,6 +42,16 @@ PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT); namespace paddle { namespace framework { +pir::Operation* GetOpFromProgram(const std::string& op_name, + const pir::Program& program) { + for (auto op : *(program.block())) { + if (op->name() == op_name) { + return op; + } + } + return nullptr; +} + TEST(VJP, TanhBackwardTest) { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); @@ -59,38 +69,38 @@ TEST(VJP, TanhBackwardTest) { std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}}; + std::vector> inputs{{op1.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh"); auto tanh_vjp_interface_impl = op2_info.GetInterfaceImpl(); - tanh_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + tanh_vjp_interface_impl->vjp_( + op2.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op2->result(0), "tanh_out"); + builder->Build( + GetOpFromProgram("pd_op.tanh_grad", program)->result(0), "tanh_grad_out"); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars( - {prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"}); + test_core.SetSkipGcVars({"tanh_out", "tanh_grad_out"}); test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_1")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_1") - ->Get(); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("tanh_out")->Get() + : test_core.local_scope() + ->FindVar("tanh_out") + ->Get(); auto grad_out_tensor = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_3")->Get() + ? scope.FindVar("tanh_grad_out")->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_3") + ->FindVar("tanh_grad_out") ->Get(); ASSERT_NEAR(out_tensor.data()[0], 0.76159, 1e-5); @@ -114,38 +124,39 @@ TEST(VJP, Tanh_BackwardTest) { std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}}; + std::vector> inputs{{op1.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.tanh_"); auto tanh_vjp_interface_impl = op2_info.GetInterfaceImpl(); - tanh_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + tanh_vjp_interface_impl->vjp_( + op2.operation(), inputs, outputs, out_grads, stop_gradients); + + std::string tanh_out = "tanh_out"; + std::string tanh_grad_out = "tanh_grad_out"; + builder->Build(op2->result(0), tanh_out); + builder->Build( + GetOpFromProgram("pd_op.tanh_grad", program)->result(0), tanh_grad_out); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars( - {prefix_str + "_inner_var_0", prefix_str + "_inner_var_2"}); + test_core.SetSkipGcVars({tanh_out, tanh_grad_out}); test_core.Run({}); auto out_tensor = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_0")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_0") - ->Get(); + ? scope.FindVar(tanh_out)->Get() + : test_core.local_scope()->FindVar(tanh_out)->Get(); auto grad_out_tensor = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_2")->Get() + ? scope.FindVar(tanh_grad_out)->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_2") + ->FindVar(tanh_grad_out) ->Get(); ASSERT_NEAR(out_tensor.data()[0], 0.76159, 1e-5); @@ -169,12 +180,19 @@ TEST(VJP, MeanBackwardTest) { std::vector{}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}}; + std::vector> inputs{{op1.out()}}; + std::vector> outputs{{op2.out()}}; std::vector> out_grads{{op3.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.mean"); auto mean_vjp_interface_impl = op2_info.GetInterfaceImpl(); - mean_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + mean_vjp_interface_impl->vjp_( + op2.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op2->result(0), "mean_out"); + builder->Build( + GetOpFromProgram("pd_op.mean_grad", program)->result(0), "mean_grad_out"); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); @@ -183,24 +201,18 @@ TEST(VJP, MeanBackwardTest) { ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars( - {prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"}); + test_core.SetSkipGcVars({"mean_out", "mean_grad_out"}); test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_1")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_1") - ->Get(); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("mean_out")->Get() + : test_core.local_scope() + ->FindVar("mean_out") + ->Get(); auto grad_out_tensor = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_3")->Get() + ? scope.FindVar("mean_grad_out")->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_3") + ->FindVar("mean_grad_out") ->Get(); ASSERT_EQ(out_tensor.data()[0], 2.0); ASSERT_EQ(grad_out_tensor.data()[0], 0.25); @@ -227,11 +239,22 @@ TEST(VJP, ConcatBackwardTest) { paddle::dialect::FullOp op4 = builder->Build( std::vector{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false, false}}; + std::vector> inputs{{op1.out(), op1.out()}, + {op3.axis()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.concat"); auto concat_vjp_interface_impl = op2_info.GetInterfaceImpl(); - concat_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients); + concat_vjp_interface_impl->vjp_( + op3.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op3->result(0), "concat_out"); + builder->Build( + GetOpFromProgram("builtin.split", program)->result(0), "split_out_0"); + builder->Build( + GetOpFromProgram("builtin.split", program)->result(1), "split_out_1"); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); @@ -239,31 +262,24 @@ TEST(VJP, ConcatBackwardTest) { ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars({prefix_str + "_inner_var_3", - prefix_str + "_inner_var_7", - prefix_str + "_inner_var_8"}); + test_core.SetSkipGcVars({"concat_out", "split_out_0", "split_out_1"}); test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_3")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_3") - ->Get(); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("concat_out")->Get() + : test_core.local_scope() + ->FindVar("concat_out") + ->Get(); auto grad_out_tensor_0 = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_7")->Get() + ? scope.FindVar("split_out_0")->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_7") + ->FindVar("split_out_0") ->Get(); auto grad_out_tensor_1 = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_8")->Get() + ? scope.FindVar("split_out_1")->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_8") + ->FindVar("split_out_1") ->Get(); ASSERT_EQ(out_tensor.data()[0], 2.0); ASSERT_EQ(grad_out_tensor_0.data()[0], 1.0); @@ -291,12 +307,21 @@ TEST(VJP, AddBackwardTest) { std::vector{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}, {false}}; + std::vector> inputs{{op1.out()}, {op2.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add"); auto add_vjp_interface_impl = op3_info.GetInterfaceImpl(); - add_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients); + add_vjp_interface_impl->vjp_( + op3.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op3->result(0), "add_out"); + builder->Build( + GetOpFromProgram("pd_op.add_grad", program)->result(0), "add_grad_out_0"); + builder->Build( + GetOpFromProgram("pd_op.add_grad", program)->result(1), "add_grad_out_1"); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); @@ -305,33 +330,24 @@ TEST(VJP, AddBackwardTest) { ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars({prefix_str + "_inner_var_2", - prefix_str + "_inner_var_4", - prefix_str + "_inner_var_5"}); + test_core.SetSkipGcVars({"add_out", "add_grad_out_0", "add_grad_out_1"}); test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_2")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_2") - ->Get(); - auto dx = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_4")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_4") - ->Get(); - - auto dy = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_5")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_5") - ->Get(); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("add_out")->Get() + : test_core.local_scope() + ->FindVar("add_out") + ->Get(); + auto dx = test_core.local_scope() == nullptr + ? scope.FindVar("add_grad_out_0")->Get() + : test_core.local_scope() + ->FindVar("add_grad_out_0") + ->Get(); + + auto dy = test_core.local_scope() == nullptr + ? scope.FindVar("add_grad_out_1")->Get() + : test_core.local_scope() + ->FindVar("add_grad_out_1") + ->Get(); ASSERT_EQ(out_tensor.data()[0], 4.0); ASSERT_EQ(dx.data()[0], 1.0); ASSERT_EQ(dy.data()[0], 1.0); @@ -356,13 +372,21 @@ TEST(VJP, Add_BackwardTest) { std::vector{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}, {false}}; + std::vector> inputs{{op1.out()}, {op2.out()}}; + std::vector> outputs{{op3.out()}}; std::vector> out_grads{{op4.out()}}; pir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd_op.add_"); auto add_inplace_vjp_interface_impl = op3_info.GetInterfaceImpl(); add_inplace_vjp_interface_impl->vjp_( - op3.operation(), out_grads, stop_gradients); + op3.operation(), inputs, outputs, out_grads, stop_gradients); + + builder->Build(op1->result(0), "full_op1_out"); + builder->Build( + GetOpFromProgram("pd_op.add_grad", program)->result(0), "add_grad_out_0"); + builder->Build( + GetOpFromProgram("pd_op.add_grad", program)->result(1), "add_grad_out_1"); auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); @@ -371,33 +395,25 @@ TEST(VJP, Add_BackwardTest) { ProgramDesc prog_desc; InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars({prefix_str + "_inner_var_0", - prefix_str + "_inner_var_3", - prefix_str + "_inner_var_4"}); - test_core.Run({}); - auto out_tensor = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_0")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_0") - ->Get(); - auto dx = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_3")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_3") - ->Get(); - auto dy = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_4")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_4") - ->Get(); + test_core.SetSkipGcVars({"full_op1_out", "add_grad_out_0", "add_grad_out_1"}); + test_core.Run({}); + auto out_tensor = test_core.local_scope() == nullptr + ? scope.FindVar("full_op1_out")->Get() + : test_core.local_scope() + ->FindVar("full_op1_out") + ->Get(); + auto dx = test_core.local_scope() == nullptr + ? scope.FindVar("add_grad_out_0")->Get() + : test_core.local_scope() + ->FindVar("add_grad_out_0") + ->Get(); + + auto dy = test_core.local_scope() == nullptr + ? scope.FindVar("add_grad_out_1")->Get() + : test_core.local_scope() + ->FindVar("add_grad_out_1") + ->Get(); ASSERT_EQ(out_tensor.data()[0], 4.0); ASSERT_EQ(dx.data()[0], 1.0); ASSERT_EQ(dy.data()[0], 1.0); @@ -405,6 +421,7 @@ TEST(VJP, Add_BackwardTest) { TEST(VJP, SplitBackwardTest) { pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); pir::Program program((ctx)); paddle::dialect::APIBuilder::Instance().SetProgram(&program); @@ -422,44 +439,51 @@ TEST(VJP, SplitBackwardTest) { std::vector{1, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); std::vector> stop_gradients{{false}}; + std::vector> inputs{ + {op2.x()}, {op2.sections()}, {op2.axis()}}; + std::vector> outputs{{op3.outputs()}}; std::vector> out_grads{{op3.result(0), op4.out()}}; pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.split"); auto concat_vjp_interface_impl = op2_info.GetInterfaceImpl(); - concat_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients); + concat_vjp_interface_impl->vjp_( + op2.operation(), inputs, outputs, out_grads, stop_gradients); + + std::string split_out1 = "split_out1"; + std::string split_out2 = "split_out2"; + std::string concat_out = "concat_out"; + + builder->Build(op3->result(0), split_out1); + builder->Build(op3->result(1), split_out2); + builder->Build( + GetOpFromProgram("pd_op.concat", program)->result(0), concat_out); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); auto place = platform::CPUPlace(); Scope scope; - ProgramDesc prog_desc; + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); - std::stringstream os; - os << reinterpret_cast( - const_cast(test_core.Impl())); - std::string prefix_str = os.str(); - test_core.SetSkipGcVars({prefix_str + "_inner_var_4", - prefix_str + "_inner_var_5", - prefix_str + "_inner_var_8"}); + + test_core.SetSkipGcVars({split_out1, split_out2, concat_out}); test_core.Run({}); - auto out_tensor_0 = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_4")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_4") - ->Get(); - auto out_tensor_1 = - test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_5")->Get() - : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_5") - ->Get(); + auto out_tensor_0 = test_core.local_scope() == nullptr + ? scope.FindVar(split_out1)->Get() + : test_core.local_scope() + ->FindVar(split_out1) + ->Get(); + auto out_tensor_1 = test_core.local_scope() == nullptr + ? scope.FindVar(split_out2)->Get() + : test_core.local_scope() + ->FindVar(split_out2) + ->Get(); auto grad_out_tensor_0 = test_core.local_scope() == nullptr - ? scope.FindVar(prefix_str + "_inner_var_8")->Get() + ? scope.FindVar(concat_out)->Get() : test_core.local_scope() - ->FindVar(prefix_str + "_inner_var_8") + ->FindVar(concat_out) ->Get(); ASSERT_EQ(out_tensor_0.data()[0], 2.0); ASSERT_EQ(out_tensor_0.data()[1], 2.0); diff --git a/test/dygraph_to_static/dygraph_to_static_utils_new.py b/test/dygraph_to_static/dygraph_to_static_utils_new.py index de74552e3248d1..80b0d233763692 100644 --- a/test/dygraph_to_static/dygraph_to_static_utils_new.py +++ b/test/dygraph_to_static/dygraph_to_static_utils_new.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import inspect import logging import os @@ -24,6 +23,7 @@ from paddle import set_flags, static from paddle.base import core +from paddle.jit.api import sot_mode_guard """ # Usage: @@ -69,21 +69,6 @@ def lower_case_name(self): DEFAULT_IR_MODE = IrMode.LEGACY_PROGRAM -def in_sot_mode(): - return os.getenv("ENABLE_FALL_BACK", "False") == "True" - - -@contextlib.contextmanager -def enable_fallback_guard(enable): - flag = os.environ.get("ENABLE_FALL_BACK", None) - os.environ["ENABLE_FALL_BACK"] = enable - yield - if flag is not None: - os.environ["ENABLE_FALL_BACK"] = flag - else: - del os.environ["ENABLE_FALL_BACK"] - - def to_legacy_ast_test(fn): """ convert run fall_back to ast @@ -92,7 +77,7 @@ def to_legacy_ast_test(fn): @wraps(fn) def impl(*args, **kwargs): logger.info("[AST] running AST") - with enable_fallback_guard("False"): + with sot_mode_guard(False): fn(*args, **kwargs) return impl @@ -106,7 +91,7 @@ def to_sot_test(fn): @wraps(fn) def impl(*args, **kwargs): logger.info("[SOT] running SOT") - with enable_fallback_guard("True"): + with sot_mode_guard(True): fn(*args, **kwargs) return impl @@ -263,22 +248,27 @@ def decorator(fn): # Suger decorators # These decorators can be simply composed by base decorators -def ast_only_test(fn): +def test_ast_only(fn): fn = set_to_static_mode(ToStaticMode.LEGACY_AST)(fn) return fn -def sot_only_test(fn): +def test_sot_only(fn): fn = set_to_static_mode(ToStaticMode.SOT)(fn) return fn -def test_with_new_ir(fn): +def test_pir_only(fn): fn = set_ir_mode(IrMode.PIR)(fn) return fn -def _test_and_compare_with_new_ir(fn): +def test_legacy_and_pir(fn): + fn = set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)(fn) + return fn + + +def compare_legacy_with_pir(fn): @wraps(fn) def impl(*args, **kwargs): outs = fn(*args, **kwargs) @@ -297,17 +287,6 @@ def impl(*args, **kwargs): return impl -def test_and_compare_with_new_ir(need_check_output: bool = True): - def decorator(fn): - fn = set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)(fn) - if need_check_output: - logger.info(f"[need_check_output] {fn.__name__}") - fn = _test_and_compare_with_new_ir(fn) - return fn - - return decorator - - # For debug def show_all_test_cases(test_class): logger.info(f"[showing {test_class.__name__}]") diff --git a/test/dygraph_to_static/test_assert.py b/test/dygraph_to_static/test_assert.py index 210e904454fd93..2e5066b801e523 100644 --- a/test/dygraph_to_static/test_assert.py +++ b/test/dygraph_to_static/test_assert.py @@ -17,8 +17,8 @@ import numpy from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -37,12 +37,11 @@ def dyfunc_assert_non_variable(x=True): assert x -# @dy2static_unittest class TestAssertVariable(Dy2StTestBase): def _run(self, func, x, with_exception, to_static): paddle.jit.enable_to_static(to_static) if with_exception: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 with base.dygraph.guard(): func(x) else: @@ -53,8 +52,8 @@ def _run_dy_static(self, func, x, with_exception): self._run(func, x, with_exception, True) self._run(func, x, with_exception, False) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_non_variable(self): self._run_dy_static( dyfunc_assert_non_variable, x=False, with_exception=True @@ -63,8 +62,8 @@ def test_non_variable(self): dyfunc_assert_non_variable, x=True, with_exception=False ) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_bool_variable(self): self._run_dy_static( dyfunc_assert_variable, x=numpy.array([False]), with_exception=True @@ -73,8 +72,8 @@ def test_bool_variable(self): dyfunc_assert_variable, x=numpy.array([True]), with_exception=False ) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_int_variable(self): self._run_dy_static( dyfunc_assert_variable, x=numpy.array([0]), with_exception=True diff --git a/test/dygraph_to_static/test_ast_util.py b/test/dygraph_to_static/test_ast_util.py index c2468765e34387..a6421e4cc60ba5 100644 --- a/test/dygraph_to_static/test_ast_util.py +++ b/test/dygraph_to_static/test_ast_util.py @@ -19,8 +19,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) from ifelse_simple_func import ( dyfunc_with_if_else, @@ -35,7 +35,6 @@ from paddle.utils import gast -# @dy2static_unittest class TestAST2Func(Dy2StTestBase): """ TestCase for the transformation from ast.AST into python callable function. @@ -48,7 +47,7 @@ def _ast2func(self, func): transformed_func, _ = ast_to_func(ast_root, func) return transformed_func - @ast_only_test + @test_ast_only def test_ast2func(self): def func(x, y): return x + y @@ -56,7 +55,7 @@ def func(x, y): x, y = 10, 20 self.assertEqual(func(x, y), self._ast2func(func)(x, y)) - @ast_only_test + @test_ast_only def test_ast2func_dygraph(self): paddle.disable_static() funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else] @@ -68,8 +67,8 @@ def test_ast2func_dygraph(self): test_ret = self._ast2func(func)(x_v).numpy() self.assertTrue((true_ret == test_ret).all()) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_ast2func_static(self): paddle.enable_static() @@ -88,7 +87,7 @@ def func(x): ret = exe.run(main_program, fetch_list=[true_ret, test_ret]) self.assertTrue((ret[0] == ret[1]).all()) - @ast_only_test + @test_ast_only def test_ast2func_error(self): with self.assertRaises(Exception) as e: self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo')) diff --git a/test/dygraph_to_static/test_backward_without_params.py b/test/dygraph_to_static/test_backward_without_params.py index 336d96f2399b53..e11ee387ec69cf 100644 --- a/test/dygraph_to_static/test_backward_without_params.py +++ b/test/dygraph_to_static/test_backward_without_params.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -32,9 +29,8 @@ def forward(self, x): return out -# @dy2static_unittest class TestBackwardWithoutParams(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_run(self): net = paddle.jit.to_static(Net()) @@ -57,9 +53,8 @@ def forward(self, x): return y, out -# @dy2static_unittest class TestZeroSizeNet(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_run(self): net = paddle.jit.to_static(ZeroSizeNet()) x = paddle.ones([2, 2]) diff --git a/test/dygraph_to_static/test_basic_api_transformation.py b/test/dygraph_to_static/test_basic_api_transformation.py index e0998b8fe1e67f..51ddbe6e11a1cb 100644 --- a/test/dygraph_to_static/test_basic_api_transformation.py +++ b/test/dygraph_to_static/test_basic_api_transformation.py @@ -16,10 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import base, to_tensor @@ -72,8 +69,7 @@ def dyfunc_bool_to_tensor(x): return paddle.to_tensor(True) -@dy2static_unittest -class TestDygraphBasicApi_ToVariable(unittest.TestCase): +class TestDygraphBasicApi_ToVariable(Dy2StTestBase): def setUp(self): self.input = np.ones(5).astype("int32") self.test_funcs = [ @@ -96,7 +92,7 @@ def get_dygraph_output(self): res = self.dygraph_func(self.input).numpy() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): main_program = base.Program() main_program.random_seed = SEED @@ -234,8 +230,7 @@ def dyfunc_Prelu(input): return res -@dy2static_unittest -class TestDygraphBasicApi(unittest.TestCase): +class TestDygraphBasicApi(Dy2StTestBase): # Compare results of dynamic graph and transformed static graph function which only # includes basic Api. @@ -252,7 +247,7 @@ def get_dygraph_output(self): return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -286,7 +281,7 @@ def get_dygraph_output(self): res = self.dygraph_func(self.input1, self.input2).numpy() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -401,8 +396,7 @@ def dyfunc_PolynomialDecay(): return paddle.to_tensor(lr) -@dy2static_unittest -class TestDygraphBasicApi_CosineDecay(unittest.TestCase): +class TestDygraphBasicApi_CosineDecay(Dy2StTestBase): def setUp(self): self.dygraph_func = dyfunc_CosineDecay @@ -413,7 +407,7 @@ def get_dygraph_output(self): res = self.dygraph_func().numpy() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -444,7 +438,7 @@ def get_dygraph_output(self): res = self.dygraph_func() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -471,7 +465,7 @@ def get_dygraph_output(self): res = self.dygraph_func() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -498,7 +492,7 @@ def get_dygraph_output(self): res = self.dygraph_func() return res - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): startup_program = base.Program() startup_program.random_seed = SEED @@ -545,8 +539,7 @@ def _dygraph_fn(): np.random.random(1) -@dy2static_unittest -class TestDygraphApiRecognition(unittest.TestCase): +class TestDygraphApiRecognition(Dy2StTestBase): def setUp(self): self.src = inspect.getsource(_dygraph_fn) self.root = gast.parse(self.src) diff --git a/test/dygraph_to_static/test_bert.py b/test/dygraph_to_static/test_bert.py index ba8e2350794aad..7c6a2c1b4d42a4 100644 --- a/test/dygraph_to_static/test_bert.py +++ b/test/dygraph_to_static/test_bert.py @@ -20,10 +20,10 @@ import numpy as np from bert_dygraph_model import PretrainModelLayer from bert_utils import get_bert_config, get_feed_data_reader -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_pir_only, ) from predictor_utils import PredictorTools @@ -78,8 +78,7 @@ def __len__(self): return len(self.src_ids) -@dy2static_unittest -class TestBert(unittest.TestCase): +class TestBert(Dy2StTestBase): def setUp(self): self.bert_config = get_bert_config() self.data_reader = get_feed_data_reader(self.bert_config) @@ -266,7 +265,7 @@ def predict_analysis_inference(self, data): out = output() return out - @test_with_new_ir + @test_pir_only def test_train_new_ir(self): static_loss, static_ppl = self.train_static( self.bert_config, self.data_reader @@ -277,7 +276,7 @@ def test_train_new_ir(self): np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05) np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05) - @ast_only_test + @test_ast_only def test_train(self): static_loss, static_ppl = self.train_static( self.bert_config, self.data_reader diff --git a/test/dygraph_to_static/test_bmn.py b/test/dygraph_to_static/test_bmn.py index f5f8d357598695..11afe6100d79f2 100644 --- a/test/dygraph_to_static/test_bmn.py +++ b/test/dygraph_to_static/test_bmn.py @@ -18,7 +18,7 @@ import unittest import numpy as np -from dygraph_to_static_util import dy2static_unittest, test_with_new_ir +from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only from predictor_utils import PredictorTools import paddle @@ -637,8 +637,7 @@ def val_bmn(model, args): return loss_data -@dy2static_unittest -class TestTrain(unittest.TestCase): +class TestTrain(Dy2StTestBase): def setUp(self): self.args = Args() self.place = ( @@ -751,7 +750,7 @@ def train_bmn(self, args, place, to_static): break return np.array(loss_data) - @test_with_new_ir + @test_pir_only def test_train_new_ir(self): static_res = self.train_bmn(self.args, self.place, to_static=True) dygraph_res = self.train_bmn(self.args, self.place, to_static=False) diff --git a/test/dygraph_to_static/test_break_continue.py b/test/dygraph_to_static/test_break_continue.py index a803c1d4bf49ed..e1df868435e8fa 100644 --- a/test/dygraph_to_static/test_break_continue.py +++ b/test/dygraph_to_static/test_break_continue.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle import base @@ -26,14 +26,13 @@ np.random.seed(SEED) -@dy2static_unittest -class TestDy2staticException(unittest.TestCase): +class TestDy2staticException(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = None self.error = "Your if/else have different number of return value." - @ast_only_test + @test_ast_only def test_error(self): if self.dyfunc: with self.assertRaisesRegex(Dygraph2StaticException, self.error): @@ -205,8 +204,7 @@ def test_optim_break_in_while(x): return x -@dy2static_unittest -class TestContinueInFor(unittest.TestCase): +class TestContinueInFor(Dy2StTestBase): def setUp(self): self.input = np.zeros(1).astype('int64') self.place = ( diff --git a/test/dygraph_to_static/test_build_strategy.py b/test/dygraph_to_static/test_build_strategy.py index 85e934afb020bb..ee19dad5842f9c 100644 --- a/test/dygraph_to_static/test_build_strategy.py +++ b/test/dygraph_to_static/test_build_strategy.py @@ -15,14 +15,13 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only from test_resnet import ResNetHelper import paddle -@dy2static_unittest -class TestResnetWithPass(unittest.TestCase): +class TestResnetWithPass(Dy2StTestBase): def setUp(self): self.build_strategy = paddle.static.BuildStrategy() self.build_strategy.fuse_elewise_add_act_ops = True @@ -62,7 +61,7 @@ def verify_predict(self): err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.', ) - @ast_only_test + @test_ast_only def test_resnet(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) @@ -74,7 +73,7 @@ def test_resnet(self): ) self.verify_predict() - @ast_only_test + @test_ast_only def test_in_static_mode_mkldnn(self): paddle.base.set_flags({'FLAGS_use_mkldnn': True}) try: @@ -84,8 +83,7 @@ def test_in_static_mode_mkldnn(self): paddle.base.set_flags({'FLAGS_use_mkldnn': False}) -@dy2static_unittest -class TestError(unittest.TestCase): +class TestError(Dy2StTestBase): def test_type_error(self): def foo(x): out = x + 1 diff --git a/test/dygraph_to_static/test_cache_program.py b/test/dygraph_to_static/test_cache_program.py index 199c3e980e20c9..9683afb05bdda0 100644 --- a/test/dygraph_to_static/test_cache_program.py +++ b/test/dygraph_to_static/test_cache_program.py @@ -16,7 +16,7 @@ from collections import Counter import numpy as np -from dygraph_to_static_util import dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase from test_fetch_feed import Linear, Pool2D import paddle @@ -25,8 +25,7 @@ from paddle.jit.dy2static import convert_to_static -@dy2static_unittest -class TestCacheProgram(unittest.TestCase): +class TestCacheProgram(Dy2StTestBase): def setUp(self): self.batch_num = 5 self.dygraph_class = Pool2D @@ -76,8 +75,7 @@ def setUp(self): self.data = np.random.random((4, 10)).astype('float32') -@dy2static_unittest -class TestCacheProgramWithOptimizer(unittest.TestCase): +class TestCacheProgramWithOptimizer(Dy2StTestBase): def setUp(self): self.dygraph_class = Linear self.data = np.random.random((4, 10)).astype('float32') @@ -126,8 +124,7 @@ def simple_func(x): return mean -@dy2static_unittest -class TestConvertWithCache(unittest.TestCase): +class TestConvertWithCache(Dy2StTestBase): def test_cache(self): static_func = convert_to_static(simple_func) # Get transformed function from cache. @@ -157,8 +154,7 @@ def sum_under_while(limit): return ret_sum -@dy2static_unittest -class TestToOutputWithCache(unittest.TestCase): +class TestToOutputWithCache(Dy2StTestBase): def test_output(self): with base.dygraph.guard(): ret = sum_even_until_limit(80, 10) diff --git a/test/dygraph_to_static/test_cast.py b/test/dygraph_to_static/test_cast.py index 8c0a4bf0a1318a..48564e2776395e 100644 --- a/test/dygraph_to_static/test_cast.py +++ b/test/dygraph_to_static/test_cast.py @@ -17,8 +17,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) from paddle import base @@ -60,7 +60,6 @@ def test_mix_cast(x): return x -# @dy2static_unittest class TestCastBase(Dy2StTestBase): def setUp(self): self.place = ( @@ -89,9 +88,8 @@ def do_test(self): res = self.func(self.input) return res - @ast_only_test # TODO: add new symbolic only test. - @test_and_compare_with_new_ir(False) - # @set_to_static_mode(ToStaticMode.LEGACY_AST) + @test_ast_only # TODO: add new sot only test. + @test_legacy_and_pir def test_cast_result(self): res = self.do_test().numpy() self.assertTrue( @@ -156,8 +154,8 @@ def prepare(self): def set_func(self): self.func = to_static(full_graph=True)(test_mix_cast) - @ast_only_test # TODO: add new symbolic only test. - @test_and_compare_with_new_ir(False) + @test_ast_only # TODO: add new symbolic only test. + @test_legacy_and_pir def test_cast_result(self): res = self.do_test().numpy() self.assertTrue( @@ -188,8 +186,8 @@ def prepare(self): def set_func(self): self.func = to_static(full_graph=True)(test_not_var_cast) - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_cast_result(self): # breakpoint() # print("run once!!!") diff --git a/test/dygraph_to_static/test_cinn.py b/test/dygraph_to_static/test_cinn.py index 84e619149c8009..0f8f5c962934cb 100644 --- a/test/dygraph_to_static/test_cinn.py +++ b/test/dygraph_to_static/test_cinn.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -45,8 +42,7 @@ def apply_to_static(net, use_cinn): return paddle.jit.to_static(net, build_strategy=build_strategy) -@dy2static_unittest -class TestCINN(unittest.TestCase): +class TestCINN(Dy2StTestBase): def setUp(self): self.x = paddle.randn([2, 4]) self.x.stop_gradient = False @@ -83,7 +79,7 @@ def train(self, use_cinn): return res - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_cinn(self): dy_res = self.train(use_cinn=False) cinn_res = self.train(use_cinn=True) diff --git a/test/dygraph_to_static/test_cinn_prim.py b/test/dygraph_to_static/test_cinn_prim.py index 2ed5326f7b9d00..95df5d498c6fb9 100644 --- a/test/dygraph_to_static/test_cinn_prim.py +++ b/test/dygraph_to_static/test_cinn_prim.py @@ -15,10 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -43,8 +43,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestPrimForward(unittest.TestCase): +class TestPrimForward(Dy2StTestBase): """ This case only tests prim_forward + to_static + cinn. Thus we need to set this flag as False to avoid prim_backward. @@ -94,7 +93,7 @@ def check_prim(self, net, use_prim): # Ensure that softmax is splitted into small ops self.assertTrue('softmax' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim_forward(self): dy_res = self.train(use_prim=False) cinn_res = self.train(use_prim=True) @@ -105,8 +104,7 @@ def test_cinn_prim_forward(self): ) -@dy2static_unittest -class TestPrimForwardAndBackward(unittest.TestCase): +class TestPrimForwardAndBackward(Dy2StTestBase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph """ @@ -161,7 +159,7 @@ def check_prim(self, net, use_prim): if op != "matmul_v2_grad": self.assertTrue("_grad" not in op) - @ast_only_test + @test_ast_only def test_cinn_prim(self): dy_res = self.train(use_prim=False) cinn_res = self.train(use_prim=True) @@ -172,9 +170,8 @@ def test_cinn_prim(self): ) -@dy2static_unittest -class TestBackend(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestBackend(Dy2StTestBase): + @test_legacy_and_pir def test_backend(self): x = paddle.randn([2, 4]) out1 = self.forward(x, 'CINN') diff --git a/test/dygraph_to_static/test_cinn_prim_gelu.py b/test/dygraph_to_static/test_cinn_prim_gelu.py index be2e8f67c1e988..ab9b3697eba620 100644 --- a/test/dygraph_to_static/test_cinn_prim_gelu.py +++ b/test/dygraph_to_static/test_cinn_prim_gelu.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle import paddle.nn.functional as F @@ -53,8 +53,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestPrimForwardAndBackward(unittest.TestCase): +class TestPrimForwardAndBackward(Dy2StTestBase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph """ @@ -106,7 +105,7 @@ def check_prim(self, net, use_prim): # Ensure that gelu is splitted into small ops self.assertTrue('gelu' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim(self): for shape in self.shapes: for dtype in self.dtypes: diff --git a/test/dygraph_to_static/test_cinn_prim_layer_norm.py b/test/dygraph_to_static/test_cinn_prim_layer_norm.py index 42bf36d731eca6..94186bb1bff39b 100644 --- a/test/dygraph_to_static/test_cinn_prim_layer_norm.py +++ b/test/dygraph_to_static/test_cinn_prim_layer_norm.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle import paddle.nn.functional as F @@ -52,8 +52,7 @@ def forward(self, x, w, b): return out[0] -@dy2static_unittest -class TestPrimForward(unittest.TestCase): +class TestPrimForward(Dy2StTestBase): """ This case only tests prim_forward + to_static + cinn. Thus we need to set this flag as False to avoid prim_backward. @@ -103,7 +102,7 @@ def check_prim(self, net, use_prim): # Ensure that layer_norm is splitted into small ops self.assertTrue('layer_norm' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim_forward(self): for dtype in self.dtypes: if paddle.device.get_device() == "cpu": @@ -125,8 +124,7 @@ def test_cinn_prim_forward(self): ) -@dy2static_unittest -class TestPrimForwardAndBackward(unittest.TestCase): +class TestPrimForwardAndBackward(Dy2StTestBase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph """ @@ -172,7 +170,7 @@ def check_prim(self, net, use_prim): # Ensure that layer_norm is splitted into small ops self.assertTrue('layer_norm' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim(self): for dtype in self.dtypes: if paddle.device.get_device() == "cpu": diff --git a/test/dygraph_to_static/test_cinn_prim_mean.py b/test/dygraph_to_static/test_cinn_prim_mean.py index cb32f5b466035e..fe82e9cfe0a5b3 100644 --- a/test/dygraph_to_static/test_cinn_prim_mean.py +++ b/test/dygraph_to_static/test_cinn_prim_mean.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle from paddle import tensor @@ -55,8 +55,7 @@ def forward(self, x): return out -@dy2static_unittest -class TestPrimForward(unittest.TestCase): +class TestPrimForward(Dy2StTestBase): """ This case only tests prim_forward + to_static + cinn. Thus we need to set this flag as False to avoid prim_backward. @@ -112,7 +111,7 @@ def check_prim(self, net, use_prim): # Ensure that reduce_mean is splitted into small ops self.assertTrue('reduce_mean' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim_forward(self): for shape in self.shapes: for dtype in self.dtypes: @@ -134,8 +133,7 @@ def test_cinn_prim_forward(self): ) -@dy2static_unittest -class TestPrimForwardAndBackward(unittest.TestCase): +class TestPrimForwardAndBackward(Dy2StTestBase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph """ @@ -187,7 +185,7 @@ def check_prim(self, net, use_prim): # Ensure that reduce_mean is splitted into small ops self.assertTrue('reduce_mean' not in fwd_ops) - @ast_only_test + @test_ast_only def test_cinn_prim(self): for shape in self.shapes: for dtype in self.dtypes: diff --git a/test/dygraph_to_static/test_closure_analysis.py b/test/dygraph_to_static/test_closure_analysis.py index de1d1e12d6502a..fe390108ed7d5a 100644 --- a/test/dygraph_to_static/test_closure_analysis.py +++ b/test/dygraph_to_static/test_closure_analysis.py @@ -15,10 +15,7 @@ import inspect import unittest -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from numpy import append import paddle @@ -263,7 +260,7 @@ def init_dygraph_func(self): class TestPushPopTrans(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test(self): def vlist_of_dict(x): ma = {'a': []} @@ -274,7 +271,7 @@ def vlist_of_dict(x): x = paddle.to_tensor([3]) print(paddle.jit.to_static(vlist_of_dict)(x)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test2(self): import numpy as np @@ -287,7 +284,7 @@ def vlist_of_dict(x): x = paddle.to_tensor([3]) print(paddle.jit.to_static(vlist_of_dict)(x)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test3(self): import numpy as np @@ -300,7 +297,7 @@ def vlist_of_dict(x): x = paddle.to_tensor([3]) print(paddle.jit.to_static(vlist_of_dict)(x)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test4(self): import numpy as np @@ -313,7 +310,7 @@ def vlist_of_dict(x): x = paddle.to_tensor([3]) print(paddle.jit.to_static(vlist_of_dict)(x)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test5(self): import numpy as np diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py index 723d3f910debdd..bd21698579d93b 100644 --- a/test/dygraph_to_static/test_convert_call.py +++ b/test/dygraph_to_static/test_convert_call.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle import paddle.jit.dy2static as _jst @@ -77,8 +77,7 @@ def dyfunc_with_staticmethod(x_v): return a.add(x_v, x_v) -@dy2static_unittest -class TestRecursiveCall1(unittest.TestCase): +class TestRecursiveCall1(Dy2StTestBase): def setUp(self): self.input = np.random.random([10, 16]).astype('float32') self.place = ( @@ -169,8 +168,7 @@ def forward(self, inputs): return self.act(out) -@dy2static_unittest -class TestRecursiveCall2(unittest.TestCase): +class TestRecursiveCall2(Dy2StTestBase): def setUp(self): self.input = np.random.random((1, 3, 3, 5)).astype('float32') self.place = ( @@ -253,7 +251,6 @@ def test_code(self): ) -@dy2static_unittest class TestNotToConvert2(TestRecursiveCall2): def set_func(self): self.net = NotToStaticHelper() @@ -266,7 +263,7 @@ def test_conversion_options(self): self.assertIsNotNone(options) self.assertTrue(options.not_convert) - @ast_only_test + @test_ast_only def test_code(self): self.dygraph_func = paddle.jit.to_static(self.net.sum) # check 'if statement' is not converted @@ -281,23 +278,22 @@ def forward(self, x): return x -@dy2static_unittest -class TestConvertPaddleAPI(unittest.TestCase): - @ast_only_test +class TestConvertPaddleAPI(Dy2StTestBase): + @test_ast_only def test_functional_api(self): func = paddle.nn.functional.relu func = paddle.jit.to_static(func) self.assertNotIn("_jst.IfElse", func.code) self.assertIn("if in_dynamic_or_pir_mode()", func.code) - @ast_only_test + @test_ast_only def test_class_api(self): bn = paddle.nn.SyncBatchNorm(2) paddle.jit.to_static(bn) self.assertNotIn("_jst.IfElse", bn.forward.code) self.assertIn("if in_dynamic_mode()", bn.forward.code) - @ast_only_test + @test_ast_only def test_class_patch_api(self): paddle.nn.SyncBatchNorm.forward = forward bn = paddle.nn.SyncBatchNorm(2) diff --git a/test/dygraph_to_static/test_convert_call_generator.py b/test/dygraph_to_static/test_convert_call_generator.py index dd9d93c907c552..b3793fa22d289c 100644 --- a/test/dygraph_to_static/test_convert_call_generator.py +++ b/test/dygraph_to_static/test_convert_call_generator.py @@ -14,10 +14,10 @@ import unittest -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -36,11 +36,10 @@ def main_func(): print(i) -@dy2static_unittest -class TestConvertGenerator(unittest.TestCase): +class TestConvertGenerator(Dy2StTestBase): # fallback will ok. - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_raise_error(self): translator_logger.verbosity_level = 1 with self.assertLogs( diff --git a/test/dygraph_to_static/test_convert_operators.py b/test/dygraph_to_static/test_convert_operators.py index 02d0c09a70857c..05a6d4de9c7d9f 100644 --- a/test/dygraph_to_static/test_convert_operators.py +++ b/test/dygraph_to_static/test_convert_operators.py @@ -15,10 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -44,10 +44,9 @@ def forward(self): net.forward = "A string so that convert forward will fail" -@dy2static_unittest -class TestConvertCall(unittest.TestCase): +class TestConvertCall(Dy2StTestBase): # fallback mode will raise a InnerError, it's ok. - @ast_only_test + @test_ast_only def test_class_exception(self): @paddle.jit.to_static def call_not_exist(): @@ -73,8 +72,7 @@ def callable_list(x, y): self.assertEqual(callable_list(1, 2), 3) -@dy2static_unittest -class TestConvertShapeCompare(unittest.TestCase): +class TestConvertShapeCompare(Dy2StTestBase): def test_non_variable(self): self.assertEqual( paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True @@ -136,7 +134,7 @@ def error_func(): False, ) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_variable(self): paddle.enable_static() with paddle.static.program_guard( @@ -210,9 +208,8 @@ def forward(self, x): return out -@dy2static_unittest -class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestChooseShapeAttrOrApiWithLayer(Dy2StTestBase): + @test_legacy_and_pir def test_tensor_shape(self): x = paddle.zeros(shape=[4, 1], dtype='float32') net = ShapeLayer() @@ -221,9 +218,8 @@ def test_tensor_shape(self): np.testing.assert_array_equal(out.numpy(), x.numpy()) -@dy2static_unittest -class TestIfElseNoValue(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestIfElseNoValue(Dy2StTestBase): + @test_legacy_and_pir def test_else_ret_none(self): input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) @@ -253,7 +249,7 @@ def without_common_value(x, use_cache=False): out = without_common_value(input_x, False) self.assertIsNone(out) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_else_ret_c(self): input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) @@ -286,7 +282,7 @@ def without_common_value(x, use_cache=False): self.assertListEqual(paddle.tolist(y), paddle.tolist(input_x + 1)) self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x + 2)) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_else_ret_cz(self): input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) diff --git a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py index b6e55b8900c1e8..1d199dc8138df1 100644 --- a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py +++ b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py @@ -15,18 +15,16 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - ast_only_test, - dy2static_unittest, - sot_only_test, - test_and_compare_with_new_ir, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pir, ) import paddle -@dy2static_unittest -class TestCpuCuda(unittest.TestCase): +class TestCpuCuda(Dy2StTestBase): def test_cpu_cuda(self): def func(x): x = paddle.to_tensor([1, 2, 3, 4]) @@ -39,9 +37,8 @@ def func(x): # print(paddle.jit.to_static(func)(x)) -@dy2static_unittest -class TestToTensor(unittest.TestCase): - @test_and_compare_with_new_ir(False) +class TestToTensor(Dy2StTestBase): + @test_legacy_and_pir def test_to_tensor_with_variable_list(self): def func(x): ones = paddle.to_tensor(1) @@ -58,10 +55,9 @@ def func(x): ) -@dy2static_unittest -class TestToTensor1(unittest.TestCase): - @ast_only_test - @test_and_compare_with_new_ir(False) +class TestToTensor1(Dy2StTestBase): + @test_ast_only + @test_legacy_and_pir def test_to_tensor_with_variable_list(self): def func(x): ones = paddle.to_tensor([1]) @@ -79,8 +75,8 @@ def func(x): rtol=1e-05, ) - @sot_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_to_tensor_with_variable_list_sot(self): def func(x): ones = paddle.to_tensor([1]) @@ -99,10 +95,9 @@ def func(x): ) -@dy2static_unittest -class TestToTensor2(unittest.TestCase): - @ast_only_test - @test_and_compare_with_new_ir(False) +class TestToTensor2(Dy2StTestBase): + @test_ast_only + @test_legacy_and_pir def test_to_tensor_with_variable_list(self): def func(x): x = paddle.to_tensor([[1], [2], [3], [4]]) @@ -115,8 +110,8 @@ def func(x): rtol=1e-05, ) - @sot_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_to_tensor_with_variable_list_sot(self): def func(x): x = paddle.to_tensor([[1], [2], [3], [4]]) diff --git a/test/dygraph_to_static/test_cycle_gan.py b/test/dygraph_to_static/test_cycle_gan.py index fb06a52407ec61..d069a630b73fe1 100644 --- a/test/dygraph_to_static/test_cycle_gan.py +++ b/test/dygraph_to_static/test_cycle_gan.py @@ -26,10 +26,7 @@ # Use GPU:0 to elimate the influence of other tasks. os.environ["CUDA_VISIBLE_DEVICES"] = "1" -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle from paddle.base.dygraph import to_variable @@ -679,8 +676,7 @@ def train(args, to_static): return np.array(loss_data) -@dy2static_unittest -class TestCycleGANModel(unittest.TestCase): +class TestCycleGANModel(Dy2StTestBase): def setUp(self): self.args = Args() @@ -688,7 +684,7 @@ def train(self, to_static): out = train(self.args, to_static) return out - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_train(self): st_out = self.train(to_static=True) dy_out = self.train(to_static=False) diff --git a/test/dygraph_to_static/test_declarative.py b/test/dygraph_to_static/test_declarative.py index f1599a8b907c30..7c6eac567641fa 100644 --- a/test/dygraph_to_static/test_declarative.py +++ b/test/dygraph_to_static/test_declarative.py @@ -19,8 +19,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) from test_basic_api_transformation import dyfunc_to_variable @@ -124,8 +124,8 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_with_input_spec(self): with base.dygraph.guard(base.CPUPlace()): x = to_variable(np.ones([4, 10]).astype('float32')) @@ -186,7 +186,7 @@ def test_with_error(self): ) net.add_func(x, y) - @ast_only_test + @test_ast_only def test_concrete_program(self): with base.dygraph.guard(base.CPUPlace()): x = to_variable(np.ones([4, 10]).astype('float32')) @@ -226,8 +226,8 @@ class TestDifferentInputSpecCacheProgram(Dy2StTestBase): def setUp(self): paddle.jit.enable_to_static(True) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_with_different_input(self): with base.dygraph.guard(base.CPUPlace()): x_data = np.ones([16, 10]).astype('float32') @@ -273,7 +273,7 @@ def test_with_different_input(self): recent_program = foo.program_cache.last() self.assertTrue(first_program == recent_program) - @ast_only_test + @test_ast_only def test_get_concrete_program(self): foo = to_static(foo_func) @@ -314,8 +314,8 @@ def test_get_concrete_program(self): InputSpec([10]), InputSpec([10]), e=4 ) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_concrete_program(self): with base.dygraph.guard(base.CPUPlace()): # usage 1 @@ -364,7 +364,7 @@ def test_nest_input(self): class TestDeclarativeAPI(Dy2StTestBase): - @ast_only_test + @test_ast_only def test_error(self): func = to_static(dyfunc_to_variable) @@ -388,15 +388,15 @@ def setUp(self): paddle.jit.enable_to_static(True) self.x = to_variable(np.ones([4, 10]).astype('float32')) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_fake_input(self): net = SimpleNet() net = to_static(net) y = net(self.x) self.assertTrue(len(net.forward.program_cache) == 1) - @ast_only_test + @test_ast_only def test_input_spec(self): net = SimpleNet() net = to_static(net, input_spec=[InputSpec([None, 8, 10])]) @@ -454,7 +454,7 @@ def func(self): class TestCallNonForwardFunc(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_call_non_forward(self): paddle.disable_static() net = CallNonForwardFuncNet() @@ -494,7 +494,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_set_buffers1(self): paddle.disable_static() net = SetBuffersNet1() @@ -503,7 +503,7 @@ def test_set_buffers1(self): paddle.jit.save(net, self.model_path) paddle.enable_static() - @ast_only_test + @test_ast_only def test_set_buffers2(self): paddle.disable_static() net = SetBuffersNet2() diff --git a/test/dygraph_to_static/test_decorator_transform.py b/test/dygraph_to_static/test_decorator_transform.py index 4f4096d607dc8a..4ab416cceaa105 100644 --- a/test/dygraph_to_static/test_decorator_transform.py +++ b/test/dygraph_to_static/test_decorator_transform.py @@ -21,8 +21,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -186,7 +186,7 @@ def deco_with_paddle_api(): class TestDecoratorTransform(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_deco_transform(self): outs = paddle.jit.to_static(forward)() np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05) @@ -198,7 +198,7 @@ def test_deco_transform(self): np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05) np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05) - @ast_only_test + @test_ast_only def test_contextmanager_warning(self): paddle.disable_static() with warnings.catch_warnings(record=True) as w: @@ -215,7 +215,7 @@ def test_contextmanager_warning(self): break self.assertTrue(flag) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_deco_with_paddle_api(self): self.assertTrue(deco_with_paddle_api()) diff --git a/test/dygraph_to_static/test_deepcopy.py b/test/dygraph_to_static/test_deepcopy.py index 82ffeaf9f2290c..5d281ba8ea213a 100644 --- a/test/dygraph_to_static/test_deepcopy.py +++ b/test/dygraph_to_static/test_deepcopy.py @@ -16,19 +16,15 @@ from copy import deepcopy import numpy as np -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir from test_rollback import Net, foo import paddle from paddle.jit.dy2static.program_translator import StaticFunction -# @dy2static_unittest class TestDeepCopy(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_net(self): net = Net() net = paddle.jit.to_static(net) @@ -44,7 +40,7 @@ def test_net(self): self.assertTrue(id(copy_net), id(copy_net.forward.__self__)) np.testing.assert_array_equal(src_out.numpy(), copy_out.numpy()) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_func(self): st_foo = paddle.jit.to_static(foo) x = paddle.randn([3, 4]) diff --git a/test/dygraph_to_static/test_dict.py b/test/dygraph_to_static/test_dict.py index 99364c1343a7d6..c88496fd86b3e1 100644 --- a/test/dygraph_to_static/test_dict.py +++ b/test/dygraph_to_static/test_dict.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import base @@ -119,8 +116,7 @@ def update_cache(cache): return cache -@dy2static_unittest -class TestNetWithDict(unittest.TestCase): +class TestNetWithDict(Dy2StTestBase): """ TestCase for the transformation from control flow `if/else` dependent on tensor in Dygraph into Static `base.layers.cond`. @@ -130,7 +126,7 @@ def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.batch_size = self.x.shape[0] - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def _run_static(self): return self.train(to_static=True) @@ -173,8 +169,7 @@ def test_dic_pop_2(x): return out -@dy2static_unittest -class TestDictPop(unittest.TestCase): +class TestDictPop(Dy2StTestBase): def setUp(self): self.input = np.random.random(3).astype('int32') self.place = ( @@ -187,7 +182,7 @@ def setUp(self): def _set_test_func(self): self.dygraph_func = test_dic_pop - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def _run_static(self): return self._run(to_static=True) @@ -254,8 +249,7 @@ def test_ast_to_func(self): ) -@dy2static_unittest -class TestDictCmpInFor(unittest.TestCase): +class TestDictCmpInFor(Dy2StTestBase): def test_with_for(self): def func(): pos = [1, 3] diff --git a/test/dygraph_to_static/test_drop_path.py b/test/dygraph_to_static/test_drop_path.py index aad752007ceb0c..d559ce7f55ac29 100644 --- a/test/dygraph_to_static/test_drop_path.py +++ b/test/dygraph_to_static/test_drop_path.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -39,15 +36,14 @@ def forward(self, x): return drop_path(x, self.training) -@dy2static_unittest -class TestTrainEval(unittest.TestCase): +class TestTrainEval(Dy2StTestBase): def setUp(self): self.model = DropPath() def tearDown(self): pass - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_train_and_eval(self): x = paddle.to_tensor([1, 2, 3]).astype("int64") eval_out = x.numpy() diff --git a/test/dygraph_to_static/test_duplicate_output.py b/test/dygraph_to_static/test_duplicate_output.py index c7f1e21b3552ab..70637729671f0b 100644 --- a/test/dygraph_to_static/test_duplicate_output.py +++ b/test/dygraph_to_static/test_duplicate_output.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ( - dy2static_unittest, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -41,8 +38,7 @@ def forward(self, x): return x, x -@dy2static_unittest -class TestDuplicateOutput(unittest.TestCase): +class TestDuplicateOutput(Dy2StTestBase): """ TestCase for the transformation from control flow `if/else` dependent on tensor in Dygraph into Static `base.layers.cond`. @@ -52,7 +48,7 @@ def setUp(self): self.net = paddle.jit.to_static(SimpleNet()) self.x = paddle.to_tensor([1.0]) - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def _run_static(self): param = self.net.parameters() param[0].clear_grad() diff --git a/test/dygraph_to_static/test_grid_generator.py b/test/dygraph_to_static/test_grid_generator.py index 7c1a9189366e0e..586302f385574e 100644 --- a/test/dygraph_to_static/test_grid_generator.py +++ b/test/dygraph_to_static/test_grid_generator.py @@ -15,10 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir import paddle from paddle import ParamAttr, nn @@ -133,7 +130,7 @@ class TestGridGenerator(Dy2StTestBase): def setUp(self): self.x = paddle.uniform(shape=[1, 20, 2], dtype='float32') - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def _run(self, to_static): paddle.jit.enable_to_static(to_static) diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index 12db665b8c822a..e1840560b13440 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -147,7 +147,7 @@ def _run_dygraph(self, to_static=False): ret = self.dyfunc(x_v) return ret.numpy() - # TODO(zhangbo): open pir test (sub block cannot find var in parent block) + @test_and_compare_with_new_ir() def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -274,7 +274,7 @@ def _run_dygraph(self, to_static=False): ret = self.dyfunc(x_v) return ret.numpy() - # TODO(zhangbo): open pir test (abnormal insertion of fill constant op after conditional block op) + @test_and_compare_with_new_ir() def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -305,7 +305,7 @@ def _run(self, to_static=False): ret = net(x_v) return ret.numpy() - # TODO(zhangbo): open pir test (sub block cannot find var in parent block) + @test_and_compare_with_new_ir() def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -503,8 +503,8 @@ def setUp(self): self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3) self.out = self.get_dy2stat_out() - # TODO(zhangbo): open pir test (abnormal insertion of fill constant op after conditional block op) @ast_only_test + @test_and_compare_with_new_ir() def test_ast_to_func(self): self.setUp() self.assertIsInstance(self.out, (paddle.Tensor, core.eager.Tensor)) diff --git a/test/dygraph_to_static/test_load_transformer.py b/test/dygraph_to_static/test_load_transformer.py index 81a45fb91cc4ef..65f16a8bdcb2d4 100644 --- a/test/dygraph_to_static/test_load_transformer.py +++ b/test/dygraph_to_static/test_load_transformer.py @@ -16,10 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_utils_new import ( - Dy2StTestBase, - test_and_compare_with_new_ir, -) +from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir import paddle @@ -48,7 +45,7 @@ class TestFallback(Dy2StTestBase): def setUp(self): self.x = paddle.to_tensor(1.0).astype('int') - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_name_load(self): net_dy = Net() net_st = Net() @@ -58,7 +55,7 @@ def test_name_load(self): class TestLoad2(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_name_load_nograd(self): @paddle.no_grad() def func(x): diff --git a/test/dygraph_to_static/test_partial_program.py b/test/dygraph_to_static/test_partial_program.py index a521b113373454..cc3c5678c48431 100644 --- a/test/dygraph_to_static/test_partial_program.py +++ b/test/dygraph_to_static/test_partial_program.py @@ -17,8 +17,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) from test_fetch_feed import Linear @@ -89,7 +89,7 @@ def _run(self, to_static): return out.numpy() - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_nest(self): dygraph_res = self._run(to_static=False) static_res = self._run(to_static=True) @@ -116,7 +116,7 @@ def _run(self, to_static): return out - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_nest(self): dygraph_res = self._run(to_static=False) dygraph_res = paddle.utils.flatten(dygraph_res) @@ -136,8 +136,8 @@ def test_nest(self): class TestWithTrainAndEval(Dy2StTestBase): - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_switch_eval_and_train(self): with base.dygraph.guard(): linear_net = Linear() @@ -169,8 +169,8 @@ def test_switch_eval_and_train(self): class TestWithNoGrad(Dy2StTestBase): - @ast_only_test - @test_and_compare_with_new_ir(False) + @test_ast_only + @test_legacy_and_pir def test_with_no_grad(self): with base.dygraph.guard(): linear_net = Linear() @@ -205,7 +205,7 @@ def forward(self, x): class TestPruneUnusedParamInProgram(Dy2StTestBase): - @test_and_compare_with_new_ir(False) + @test_legacy_and_pir def test_prune(self): input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32") diff --git a/test/dygraph_to_static/test_partial_program_hook.py b/test/dygraph_to_static/test_partial_program_hook.py index c10194f6187adf..950fb570e635a2 100644 --- a/test/dygraph_to_static/test_partial_program_hook.py +++ b/test/dygraph_to_static/test_partial_program_hook.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest from dygraph_to_static_util import dy2static_unittest import paddle from paddle.base import core +from paddle.jit.api import ENV_ENABLE_SOT from paddle.jit.dy2static import partial_program, program_translator @dy2static_unittest class TestPartiaProgramLayerHook(unittest.TestCase): def setUp(self): - os.environ["ENABLE_FALL_BACK"] = "False" + ENV_ENABLE_SOT.set(False) self._hook = partial_program.PartialProgramLayerHook() def test_before_append_backward(self): @@ -41,7 +41,7 @@ def test_after_infer(self): @dy2static_unittest class TestPrimHook(unittest.TestCase): def setUp(self): - os.environ["ENABLE_FALL_BACK"] = "False" + ENV_ENABLE_SOT.set(False) core._set_prim_all_enabled(False) def f(): diff --git a/test/dygraph_to_static/test_set_dynamic_shape.py b/test/dygraph_to_static/test_set_dynamic_shape.py index 0f6859f49e92e0..3a3843846a9a4a 100644 --- a/test/dygraph_to_static/test_set_dynamic_shape.py +++ b/test/dygraph_to_static/test_set_dynamic_shape.py @@ -14,13 +14,13 @@ import unittest -from dygraph_to_static_utils_new import Dy2StTestBase, ast_only_test +from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only import paddle class TestSetDynamicShape(Dy2StTestBase): - @ast_only_test + @test_ast_only def test_start(self): def dygraph_func(loop_number): mask = paddle.randn([2, 2]) diff --git a/test/dygraph_to_static/test_spec_names.py b/test/dygraph_to_static/test_spec_names.py index 72ffdc845134a8..7f2f9683e0951b 100644 --- a/test/dygraph_to_static/test_spec_names.py +++ b/test/dygraph_to_static/test_spec_names.py @@ -16,8 +16,8 @@ from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + test_ast_only, + test_legacy_and_pir, ) import paddle @@ -48,8 +48,8 @@ def read_from_dataset(self): self.m = paddle.randn([4, 2, 8]) self.n = paddle.randn([4, 2, 8]) - @test_and_compare_with_new_ir(False) - @ast_only_test + @test_legacy_and_pir + @test_ast_only def test_spec_name_hash(self): net = Net() net = paddle.jit.to_static(net) diff --git a/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py b/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py index a8e955be9e8634..09f90837c6086b 100644 --- a/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py +++ b/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import test_and_compare_with_new_ir import paddle @@ -47,6 +48,7 @@ def _run(self, to_static): x2 = tensor_copy_to_cpu(x1) return x1.place, x2.place, x2.numpy() + @test_and_compare_with_new_ir(False) def test_tensor_cpu_on_default_cpu(self): paddle.base.framework._set_expected_place(paddle.CPUPlace()) dygraph_x1_place, dygraph_place, dygraph_res = self._run( @@ -67,6 +69,7 @@ def _run(self, to_static): x2 = tensor_copy_to_cuda(x1) return x1.place, x2.place, x2.numpy() + @test_and_compare_with_new_ir(False) def test_tensor_cuda_on_default_cpu(self): if not paddle.base.is_compiled_with_cuda(): return diff --git a/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py b/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py index 30e8e556119596..230844e6573823 100644 --- a/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py +++ b/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from dygraph_to_static_util import test_and_compare_with_new_ir import paddle @@ -48,6 +49,7 @@ def _run(self, to_static): x2 = tensor_copy_to_cpu(x1) return x1.place, x2.place, x2.numpy() + @test_and_compare_with_new_ir(False) def test_tensor_cpu_on_default_gpu(self): if paddle.base.is_compiled_with_cuda(): place = paddle.CUDAPlace( @@ -74,6 +76,7 @@ def _run(self, to_static): x2 = tensor_copy_to_cuda(x1) return x1.place, x2.place, x2.numpy() + @test_and_compare_with_new_ir(False) def test_tensor_cuda_on_default_gpu(self): if paddle.base.is_compiled_with_cuda(): place = paddle.CUDAPlace( diff --git a/test/dygraph_to_static/test_tensor_shape.py b/test/dygraph_to_static/test_tensor_shape.py index d8c13cff351931..23dccb0f610938 100644 --- a/test/dygraph_to_static/test_tensor_shape.py +++ b/test/dygraph_to_static/test_tensor_shape.py @@ -17,8 +17,8 @@ import numpy as np from dygraph_to_static_utils_new import ( Dy2StTestBase, - ast_only_test, - test_and_compare_with_new_ir, + compare_legacy_with_pir, + test_ast_only, ) import paddle @@ -266,7 +266,7 @@ def _run(self, to_static): def get_dygraph_output(self): return self._run(to_static=False) - @test_and_compare_with_new_ir(True) + @compare_legacy_with_pir def get_static_output(self): return self._run(to_static=True) @@ -293,7 +293,7 @@ def _compute_op_num(self, program): [op for op in block.ops if op.type == "slice"] ) - @ast_only_test + @test_ast_only def test_op_num(self): static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec) program = static_layer.main_program @@ -526,7 +526,7 @@ def _compute_op_num(self, program): [op for op in block.ops if op.type == "slice"] ) - @ast_only_test + @test_ast_only def test_op_num(self): static_layer = paddle.jit.to_static(self.dygraph_func, self.input_spec) program = static_layer.main_program @@ -617,7 +617,7 @@ def dyfunc_with_static_convert_var_shape(x): class TestFindStatiConvertVarShapeSuffixVar(Dy2StTestBase): - @ast_only_test + @test_ast_only def test(self): x_spec = paddle.static.InputSpec(shape=[None, 10]) func = paddle.jit.to_static(dyfunc_with_if_2, input_spec=[x_spec]) diff --git a/test/dygraph_to_static/test_typehint.py b/test/dygraph_to_static/test_typehint.py index 563db1d7a1df04..d126b8e9316e9e 100644 --- a/test/dygraph_to_static/test_typehint.py +++ b/test/dygraph_to_static/test_typehint.py @@ -37,7 +37,7 @@ def function(x: A) -> A: @dy2static_unittest -class TestTransformWhileLoop(unittest.TestCase): +class TestTypeHint(unittest.TestCase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -77,10 +77,5 @@ def test_ast_to_func(self): np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05) -class TestTypeHint(TestTransformWhileLoop): - def _init_dyfunc(self): - self.dyfunc = function - - if __name__ == '__main__': unittest.main() diff --git a/test/ir/new_ir/CMakeLists.txt b/test/ir/new_ir/CMakeLists.txt index cad2633fb1aa49..e93b2ced83922e 100644 --- a/test/ir/new_ir/CMakeLists.txt +++ b/test/ir/new_ir/CMakeLists.txt @@ -19,3 +19,5 @@ foreach(target ${TEST_IR_SYSTEM_CASES}) endforeach() set_tests_properties(test_pd_inplace_pass PROPERTIES TIMEOUT 60) + +add_subdirectory(fused_pass) diff --git a/test/ir/new_ir/fused_pass/CMakeLists.txt b/test/ir/new_ir/fused_pass/CMakeLists.txt new file mode 100644 index 00000000000000..8876db2d4b7942 --- /dev/null +++ b/test/ir/new_ir/fused_pass/CMakeLists.txt @@ -0,0 +1,9 @@ +file( + GLOB TEST_INTERP_CASES + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") + +foreach(target ${TEST_INTERP_CASES}) + py_test_modules(${target} MODULES ${target}) +endforeach() diff --git a/test/ir/new_ir/fused_pass/test_fused_dropout_add_pass.py b/test/ir/new_ir/fused_pass/test_fused_dropout_add_pass.py new file mode 100644 index 00000000000000..af413d0e2096e7 --- /dev/null +++ b/test/ir/new_ir/fused_pass/test_fused_dropout_add_pass.py @@ -0,0 +1,118 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.autograd.ir_backward import grad +from paddle.base import core + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not complied with CUDA", +) +class TestFusedDropoutAdd(unittest.TestCase): + def _test_fused_dropout_add(self): + with paddle.pir_utils.IrGuard(): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data(name="x", shape=[3, 2], dtype="float32") + y = paddle.static.data(name="y", shape=[3, 2], dtype="float32") + res1 = paddle.nn.functional.dropout(x=x, p=0.5, training=True) + res2 = paddle.add(res1, y) + res3 = paddle.sum(res2) + + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue('pd_op.dropout' in op_names) + self.assertTrue('pd_op.add' in op_names) + pm = paddle.pir.PassManager() + pm.add_pass( + 'fused_dropout_add_pass' + ) # apply pass to elimitate dead code + pm.run(main_program) + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue('pd_op.fused_dropout_add' in op_names) + self.assertTrue('pd_op.dropout' not in op_names) + + x_np = np.ones([3, 2]).astype("float32") + y_np = x_np + + exe = paddle.base.Executor(paddle.base.CUDAPlace(0)) + fetches = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[res3], + ) + + def test_fused_dropout_add_grad(self): + with paddle.pir_utils.IrGuard(): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data(name="x", shape=[3, 2], dtype="float32") + x.stop_gradient = False + y = paddle.static.data(name="y", shape=[3, 2], dtype="float32") + y.stop_gradient = False + dout = paddle.static.data( + name="dout", shape=[3, 2], dtype="float32" + ) + res0 = paddle.assign(x) + res1 = paddle.nn.functional.dropout( + x=res0, p=0.5, training=True + ) + res2 = paddle.add(res1, y) + res3 = paddle.sum(res2) + + # res4 = paddle.incubate.nn.functional.fused_dropout_add( x, y, p=0.5, training=True) + # res5 = paddle.sum(res4) + dx = grad(res3, x) + + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue( + 'pd_op.dropout' in op_names and 'pd_op.add' in op_names + ) + self.assertTrue( + 'pd_op.add_grad' in op_names + and 'pd_op.dropout_grad' in op_names + ) + pm = paddle.pir.PassManager() + pm.add_pass( + 'fused_dropout_add_pass' + ) # apply pass to elimitate dead code + pm.run(main_program) + op_names = [op.name() for op in main_program.global_block().ops] + self.assertTrue( + 'pd_op.fused_dropout_add' in op_names + and 'pd_op.fused_dropout_add_grad' in op_names + ) + self.assertTrue( + 'pd_op.dropout' not in op_names + and 'pd_op.dropout_grad' not in op_names + ) + + x_np = np.ones([3, 2]).astype("float32") + y_np = x_np + + exe = paddle.base.Executor(paddle.base.CUDAPlace(0)) + fetches = exe.run( + main_program, + feed={"x": x_np, "y": y_np, "dout": y_np}, + fetch_list=[dx], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/new_ir/test_ir_vjp.py b/test/ir/new_ir/test_ir_vjp.py index d0e630fccff723..01f63a272d8f3f 100644 --- a/test/ir/new_ir/test_ir_vjp.py +++ b/test/ir/new_ir/test_ir_vjp.py @@ -43,7 +43,13 @@ def test_tanh_vjp1(self): out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[False]] with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(tanh_op, out_grads, stop_gradients) + grad_outs = call_vjp( + tanh_op, + [[tanh_op.operand_source(0)]], + [[tanh_op.result(0)]], + out_grads, + stop_gradients, + ) self.assertEqual( grad_outs[0][0].get_defining_op().name(), "pd_op.tanh_grad" ) @@ -74,7 +80,13 @@ def test_tanh_vjp2(self): out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[True]] with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(tanh_op, out_grads, stop_gradients) + grad_outs = call_vjp( + tanh_op, + [[tanh_op.operand_source(0)]], + [[tanh_op.result(0)]], + out_grads, + stop_gradients, + ) self.assertEqual(grad_outs[0][0], None) @@ -95,7 +107,13 @@ def test_mean_vjp1(self): out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[False]] with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(mean_op, out_grads, stop_gradients) + grad_outs = call_vjp( + mean_op, + [[mean_op.operand_source(0)]], + [[mean_op.result(0)]], + out_grads, + stop_gradients, + ) self.assertEqual( grad_outs[0][0].get_defining_op().name(), "pd_op.mean_grad" ) @@ -135,7 +153,13 @@ def test_mean_vjp2(self): out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[True]] with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(mean_op, out_grads, stop_gradients) + grad_outs = call_vjp( + mean_op, + [[mean_op.operand_source(0)]], + [[mean_op.result(0)]], + out_grads, + stop_gradients, + ) self.assertEqual(grad_outs[0][0], None) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 96a15b04ab8a2e..8ae246e042af22 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -1392,6 +1392,17 @@ foreach(IR_OP_TEST ${NEW_IR_OP_TESTS}) endif() endforeach() +file(STRINGS "${CMAKE_SOURCE_DIR}/test/white_list/new_ir_op_test_no_check_list" + NEW_IR_OP_NO_CHECK_TESTS) +foreach(IR_OP_TEST ${NEW_IR_OP_NO_CHECK_TESTS}) + if(TEST ${IR_OP_TEST}) + set_tests_properties(${IR_OP_TEST} PROPERTIES ENVIRONMENT + "FLAGS_NEW_IR_NO_CHECK=True") + else() + message(STATUS "NewIR OpTest: not found ${IR_OP_TEST} in legacy_test") + endif() +endforeach() + file(STRINGS "${CMAKE_SOURCE_DIR}/test/white_list/new_ir_op_test_precision_white_list" NEW_IR_OP_RELAXED_TESTS) diff --git a/test/legacy_test/distributed_fused_lamb_test_base.py b/test/legacy_test/distributed_fused_lamb_test_base.py index ea011becc9090c..348191e66d7d55 100644 --- a/test/legacy_test/distributed_fused_lamb_test_base.py +++ b/test/legacy_test/distributed_fused_lamb_test_base.py @@ -270,7 +270,10 @@ def setUpClass(cls): paddle.enable_static() paddle.set_flags({'FLAGS_cudnn_deterministic': True}) _clip_by_global_norm_using_mp_type(True) - if os.environ.get("FLAGS_dynamic_static_unified_comm") == "1": + if ( + os.environ.get("FLAGS_dynamic_static_unified_comm", "false").lower() + == "true" + ): paddle.distributed.collective._init_parallel_env("nccl") else: fleet.init(role_maker=get_role_maker()) diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py index 77ca5512d2b4fa..351af794e1e015 100644 --- a/test/legacy_test/op_test.py +++ b/test/legacy_test/op_test.py @@ -1435,10 +1435,12 @@ def _check_ir_output(self, place, program, feed_map, fetch_list, outs): ), "Fetch result should have same length when executed in pir" check_method = np.testing.assert_array_equal - if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None): + if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None) == "True": check_method = lambda x, y, z: np.testing.assert_allclose( x, y, err_msg=z, atol=1e-6, rtol=1e-6 ) + if os.getenv("FLAGS_NEW_IR_NO_CHECK", None) == "True": + check_method = lambda x, y, err_msg: None for i in range(len(outs)): check_method( @@ -3368,11 +3370,14 @@ def _check_ir_grad_output( ) check_method = np.testing.assert_array_equal - if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None): + if os.getenv("FLAGS_NEW_IR_OPTEST_RELAX_CHECK", None) == "True": check_method = lambda x, y, z: np.testing.assert_allclose( x, y, err_msg=z, atol=1e-6, rtol=1e-6 ) + if os.getenv("FLAGS_NEW_IR_NO_CHECK", None) == "True": + check_method = lambda x, y, err_msg: None + for i in range(len(new_gradients)): check_method( gradients[i], diff --git a/test/legacy_test/test_accuracy_op.py b/test/legacy_test/test_accuracy_op.py index 2acb9aa121e18f..44c4cfa7c49ac3 100755 --- a/test/legacy_test/test_accuracy_op.py +++ b/test/legacy_test/test_accuracy_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api def accuracy_wrapper(infer, indices, label): @@ -53,7 +54,7 @@ def init_dtype(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAccuracyOpFp16(TestAccuracyOp): @@ -61,7 +62,7 @@ def init_dtype(self): self.dtype = np.float16 def test_check_output(self): - self.check_output(atol=1e-3) + self.check_output(atol=1e-3, check_pir=True) @unittest.skipIf( @@ -103,7 +104,7 @@ def init_dtype(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-2) + self.check_output_with_place(place, atol=1e-2, check_pir=True) class TestAccuracyOpError(unittest.TestCase): @@ -142,35 +143,38 @@ def test_value_errors(self): class TestAccuracyAPI1(unittest.TestCase): - def setUp(self): + def run_api(self, accuracy_api): with paddle_static_guard(): - self.predictions = paddle.static.data( - shape=[2, 5], name="predictions", dtype="float32" - ) - self.label = paddle.static.data( - shape=[2, 1], name="labels", dtype="int64" - ) - self.result = paddle.static.accuracy( - input=self.predictions, label=self.label, k=1 - ) - self.input_predictions = np.array( - [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]], - dtype="float32", - ) - self.input_labels = np.array([[2], [0]], dtype="int64") - self.expect_value = np.array([0.5], dtype='float32') + with paddle.static.program_guard(paddle.static.Program()): + self.predictions = paddle.static.data( + shape=[2, 5], name="predictions", dtype="float32" + ) + self.label = paddle.static.data( + shape=[2, 1], name="labels", dtype="int64" + ) + self.result = accuracy_api( + input=self.predictions, label=self.label, k=1 + ) + self.input_predictions = np.array( + [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]], + dtype="float32", + ) + self.input_labels = np.array([[2], [0]], dtype="int64") + self.expect_value = np.array([0.5], dtype='float32') + exe = paddle.static.Executor() + (result,) = exe.run( + feed={ + "predictions": self.input_predictions, + 'labels': self.input_labels, + }, + fetch_list=[self.result], + ) + self.assertEqual((result == self.expect_value).all(), True) + @test_with_pir_api def test_api(self): - with paddle_static_guard(): - exe = paddle.static.Executor() - (result,) = exe.run( - feed={ - "predictions": self.input_predictions, - 'labels': self.input_labels, - }, - fetch_list=[self.result.name], - ) - self.assertEqual((result == self.expect_value).all(), True) + self.run_api(accuracy_api=paddle.static.accuracy) + self.run_api(accuracy_api=paddle.metric.accuracy) class TestAccuracyAPI2(unittest.TestCase): diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index d1da7d941a679e..5cfdcca9983d02 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -205,9 +205,10 @@ def init_dtype(self): class Test_Exp_Op_Fp16(unittest.TestCase): + @test_with_pir_api def test_api_fp16(self): with static_guard(): - with static.program_guard( + with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() ): np_x = np.array([[2, 3, 4], [7, 8, 9]]) @@ -559,6 +560,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -781,6 +783,7 @@ def setUp(self): def executed_api(self): self.tanh = F.tanh + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -2466,12 +2469,12 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_pir=True) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) class TestLeakyReluAlpha1(TestLeakyRelu): @@ -2508,6 +2511,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -3490,10 +3494,13 @@ def setUp(self): self.outputs = {'Out': out} self.convert_input_output() + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestLog10_ZeroDim(TestLog10): @@ -3512,21 +3519,23 @@ def test_api_int(self): np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3) paddle.enable_static() + @test_with_pir_api def test_api_bf16(self): - with static_guard(): - with static.program_guard( - paddle.static.Program(), paddle.static.Program() - ): - x = [[2, 3, 4], [7, 8, 9]] - x = paddle.to_tensor(x, dtype='bfloat16') - out = paddle.log10(x) - if core.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - (res,) = exe.run(fetch_list=[out]) + paddle.enable_static() + with static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = [[2, 3, 4], [7, 8, 9]] + x = paddle.to_tensor(x, dtype='bfloat16') + out = paddle.log10(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + (res,) = exe.run(fetch_list=[out]) class TestLog10API(unittest.TestCase): + @test_with_pir_api def test_api(self): with static_guard(): with paddle.static.program_guard( @@ -4737,7 +4746,7 @@ def test_check_grad(self): create_test_act_fp16_class(TestLog2) else: create_test_act_fp16_class(TestLog2) -create_test_act_fp16_class(TestLog10) +create_test_act_fp16_class(TestLog10, check_pir=True) create_test_act_fp16_class(TestLog1p) create_test_act_fp16_class(TestSquare, check_pir=True) create_test_act_fp16_class(TestPow, check_prim=True, check_prim_pir=True) @@ -4750,7 +4759,9 @@ def test_check_grad(self): create_test_act_fp16_class(TestSwish) create_test_act_fp16_class(TestHardSwish, check_prim=True) create_test_act_fp16_class(TestMish) -create_test_act_fp16_class(TestLeakyRelu, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestLeakyRelu, check_prim=True, enable_cinn=True, check_pir=True +) create_test_act_fp16_class( TestLeakyReluAlpha1, check_prim=True, enable_cinn=True ) @@ -4885,7 +4896,7 @@ def test_check_grad(self): create_test_act_bf16_class(TestLog2) else: create_test_act_bf16_class(TestLog2) -create_test_act_bf16_class(TestLog10) +create_test_act_bf16_class(TestLog10, check_pir=True) create_test_act_bf16_class(TestLog1p) create_test_act_bf16_class(TestSquare, check_pir=True) create_test_act_bf16_class(TestPow, check_prim=True) @@ -4898,7 +4909,7 @@ def test_check_grad(self): create_test_act_bf16_class(TestSwish) create_test_act_bf16_class(TestHardSwish, check_prim=True) create_test_act_bf16_class(TestMish) -create_test_act_bf16_class(TestLeakyRelu, check_prim=True) +create_test_act_bf16_class(TestLeakyRelu, check_prim=True, check_pir=True) create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True) create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True) create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True) diff --git a/test/legacy_test/test_batch_norm_op.py b/test/legacy_test/test_batch_norm_op.py index 284826d7b4e530..06b0682a2db009 100644 --- a/test/legacy_test/test_batch_norm_op.py +++ b/test/legacy_test/test_batch_norm_op.py @@ -375,6 +375,77 @@ def test_check_output(self): def init_kernel_type(self): pass + def check_without_scale_and_bias(self, place, data_layout, dtype, shape): + epsilon = 0.00001 + if len(shape) == 2: + x_shape = shape + c = x_shape[1] + else: + n, h, w, c = shape[0], shape[1], shape[2], shape[3] + if data_layout == "NHWC": + x_shape = [n, h, w, c] + elif data_layout == "NCHW": + x_shape = [n, c, h, w] + else: + raise ValueError("Unknown data layout.") + scale_shape = [c] + + if dtype == np.uint16: + x_val = np.random.random_sample(x_shape).astype(np.float32) + else: + x_val = np.random.random_sample(x_shape).astype(dtype) + # generate some negative values to test case with relu fused + x_val = x_val - 0.5 + scale_val = np.ones(scale_shape).astype(np.float32) + bias_val = np.zeros(scale_shape).astype(np.float32) + + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) + + if dtype == np.uint16: + y_out = _reference_testing( + x_val, scale_val, bias_val, mean, variance, epsilon, data_layout + ).astype(np.float32) + y_out = convert_float_to_uint16(y_out) + else: + y_out = _reference_testing( + x_val, scale_val, bias_val, mean, variance, epsilon, data_layout + ).astype(dtype) + if self.fuse_with_relu: + y_out = np.maximum(y_out, 0) + + if dtype == np.uint16: + x_val = convert_float_to_uint16(x_val) + + y_tensor, _, _, _, _, _ = paddle.nn.functional.batch_norm( + paddle.to_tensor(x_val), + paddle.to_tensor(mean), + paddle.to_tensor(variance), + None, + None, + False, + ) + + # check inference result + atol = 1e-3 + if dtype == np.uint16: + y_tensor = convert_uint16_to_float(y_tensor) + y_out = convert_uint16_to_float(y_out) + atol = 1e-2 + self.__assert_close( + y_tensor, + y_out, + "inference output are different at " + + str(place) + + ", " + + data_layout + + ", " + + str(np.dtype(dtype)) + + str(np.array(y_tensor)) + + str(y_out), + atol=atol, + ) + class TestFP16BatchNormOpInference(TestBatchNormOpInference): def setUp(self): diff --git a/test/legacy_test/test_batch_norm_op_v2.py b/test/legacy_test/test_batch_norm_op_v2.py index 639011460b102c..4ae4c609ea1de7 100644 --- a/test/legacy_test/test_batch_norm_op_v2.py +++ b/test/legacy_test/test_batch_norm_op_v2.py @@ -192,24 +192,91 @@ def compute_v3(x, is_test, trainable_statistics): ), trainable_statistics=trainable_statistics, ) - y = bn(paddle.to_tensor(x)) - return y.numpy() + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_v3_1(x, is_test, trainable_statistics): + with base.dygraph.guard(p): + bn = paddle.nn.BatchNorm( + shape[1], + is_test=is_test, + param_attr=False, + bias_attr=False, + trainable_statistics=trainable_statistics, + ) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_v3_2(x, is_test, trainable_statistics): + with base.dygraph.guard(p): + bn = paddle.nn.BatchNorm( + shape[1], + is_test=is_test, + param_attr=False, + bias_attr=base.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0), + trainable=False, + ), + trainable_statistics=trainable_statistics, + ) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_v3_3(x, is_test, trainable_statistics): + with base.dygraph.guard(p): + bn = paddle.nn.BatchNorm( + shape[1], + is_test=is_test, + param_attr=base.ParamAttr( + initializer=paddle.nn.initializer.Constant(1.0), + trainable=False, + ), + bias_attr=False, + trainable_statistics=trainable_statistics, + ) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() def compute_v4(x): with base.dygraph.guard(p): bn = paddle.nn.BatchNorm2D( shape[1], weight_attr=False, bias_attr=False ) - y = bn(paddle.to_tensor(x)) - return y.numpy() + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() x = np.random.randn(*shape).astype("float32") y1 = compute_v1(x, False, False) y2 = compute_v2(x) - y3 = compute_v3(x, False, False) - y4 = compute_v4(x) + y3, g3 = compute_v3(x, False, False) + y3_1, g3_1 = compute_v3_1(x, False, False) + y3_2, g3_2 = compute_v3_2(x, False, False) + y3_3, g3_3 = compute_v3_3(x, False, False) + y4, g4 = compute_v4(x) np.testing.assert_allclose(y1, y2, rtol=1e-05) np.testing.assert_allclose(y3, y4, rtol=1e-05) + np.testing.assert_allclose(y3_1, y4, rtol=1e-05) + np.testing.assert_allclose(y3_2, y4, rtol=1e-05) + np.testing.assert_allclose(y3_3, y4, rtol=1e-05) + np.testing.assert_allclose(g3, g4, rtol=1e-05) + np.testing.assert_allclose(g3_1, g4, rtol=1e-05) + np.testing.assert_allclose(g3_2, g4, rtol=1e-05) + np.testing.assert_allclose(g3_3, g4, rtol=1e-05) @test_with_pir_api def test_static(self): diff --git a/test/legacy_test/test_bitwise_op.py b/test/legacy_test/test_bitwise_op.py index 21a7abe812ad7a..eb3ec980f05fdb 100644 --- a/test/legacy_test/test_bitwise_op.py +++ b/test/legacy_test/test_bitwise_op.py @@ -150,7 +150,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): pass @@ -258,7 +258,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): pass @@ -363,7 +363,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): pass diff --git a/test/legacy_test/test_cholesky_op.py b/test/legacy_test/test_cholesky_op.py index 034cbb87366fa4..832941395213a6 100644 --- a/test/legacy_test/test_cholesky_op.py +++ b/test/legacy_test/test_cholesky_op.py @@ -58,7 +58,7 @@ def setUp(self): self.outputs = {"Out": output_data} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): places = [base.CPUPlace()] diff --git a/test/legacy_test/test_cholesky_solve_op.py b/test/legacy_test/test_cholesky_solve_op.py index c1c9e4e7400bce..76f3e2e2a64ebe 100644 --- a/test/legacy_test/test_cholesky_solve_op.py +++ b/test/legacy_test/test_cholesky_solve_op.py @@ -139,7 +139,7 @@ def setUp(self): # check Op forward result def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) # check Op grad def test_check_grad_normal(self): diff --git a/test/legacy_test/test_collective_api_base.py b/test/legacy_test/test_collective_api_base.py index a431d77cdfe713..8f6a382297a1f1 100644 --- a/test/legacy_test/test_collective_api_base.py +++ b/test/legacy_test/test_collective_api_base.py @@ -189,7 +189,8 @@ def runtime_main(test_class, col_type): args["reduce_type"] = os.getenv("REDUCE_TYPE") args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0"))) args["dynamic_static_unified_comm"] = bool( - int(os.getenv("FLAGS_dynamic_static_unified_comm", "0")) + os.getenv("FLAGS_dynamic_static_unified_comm", "false").lower() + == "true" ) model.run_trainer(args) diff --git a/test/legacy_test/test_compare_reduce_op.py b/test/legacy_test/test_compare_reduce_op.py index e281407c242b01..fdd08b2990cfe7 100644 --- a/test/legacy_test/test_compare_reduce_op.py +++ b/test/legacy_test/test_compare_reduce_op.py @@ -32,7 +32,7 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output() + self.check_output(check_pir=True) cls_name = "{}_{}_{}".format(op_type, typename, 'not_equal_all') Cls.__name__ = cls_name @@ -51,7 +51,7 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output() + self.check_output(check_pir=True) cls_name = "{}_{}_{}".format(op_type, typename, 'not_shape_equal_all') Cls.__name__ = cls_name @@ -69,7 +69,7 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output() + self.check_output(check_pir=True) cls_name = "{}_{}_{}".format(op_type, typename, 'equal_all') Cls.__name__ = cls_name @@ -89,7 +89,7 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output() + self.check_output(check_pir=True) cls_name = "{}_{}_{}".format(op_type, typename, 'equal_all') Cls.__name__ = cls_name diff --git a/test/legacy_test/test_complex_view_op.py b/test/legacy_test/test_complex_view_op.py index b747804ca65c5c..c529e3950a9fbd 100644 --- a/test/legacy_test/test_complex_view_op.py +++ b/test/legacy_test/test_complex_view_op.py @@ -20,6 +20,7 @@ import paddle from paddle import static from paddle.base import dygraph +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -43,7 +44,7 @@ def setUp(self): self.outputs = {'Out': out_ref} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -64,7 +65,7 @@ def setUp(self): self.python_api = paddle.as_real def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -84,6 +85,7 @@ def test_dygraph(self): out_np = paddle.as_complex(x).numpy() np.testing.assert_allclose(self.out, out_np, rtol=1e-05) + @test_with_pir_api def test_static(self): mp, sp = static.Program(), static.Program() with static.program_guard(mp, sp): @@ -107,6 +109,7 @@ def test_dygraph(self): out_np = paddle.as_real(x).numpy() np.testing.assert_allclose(self.out, out_np, rtol=1e-05) + @test_with_pir_api def test_static(self): mp, sp = static.Program(), static.Program() with static.program_guard(mp, sp): diff --git a/test/legacy_test/test_conv1d_layer.py b/test/legacy_test/test_conv1d_layer.py index e284c25568abff..8c2264b1604b17 100644 --- a/test/legacy_test/test_conv1d_layer.py +++ b/test/legacy_test/test_conv1d_layer.py @@ -20,6 +20,7 @@ import paddle.base.dygraph as dg import paddle.nn.functional as F from paddle import base, nn +from paddle.pir_utils import test_with_pir_api class Conv1DTestCase(unittest.TestCase): @@ -99,13 +100,16 @@ def functional(self, place): w_var = paddle.static.data( "weight", self.weight_shape, dtype=self.dtype ) - b_var = paddle.static.data( - "bias", (self.num_filters,), dtype=self.dtype - ) + if not self.no_bias: + b_var = paddle.static.data( + "bias", (self.num_filters,), dtype=self.dtype + ) + else: + b_var = None y_var = F.conv1d( x_var, w_var, - b_var if not self.no_bias else None, + b_var, padding=self.padding, stride=self.stride, dilation=self.dilation, @@ -117,6 +121,7 @@ def functional(self, place): feed_dict["bias"] = self.bias exe = base.Executor(place) exe.run(start) + # breakpoint() (y_np,) = exe.run(main, feed=feed_dict, fetch_list=[y_var]) return y_np @@ -140,6 +145,7 @@ def paddle_nn_layer(self): y_np = y_var.numpy() return y_np + @test_with_pir_api def _test_equivalence(self, place): result1 = self.functional(place) with dg.guard(place): diff --git a/test/legacy_test/test_cuda_max_memory_allocated.py b/test/legacy_test/test_cuda_max_memory_allocated.py index 90e016921f8a21..969489fa8f925e 100644 --- a/test/legacy_test/test_cuda_max_memory_allocated.py +++ b/test/legacy_test/test_cuda_max_memory_allocated.py @@ -61,10 +61,10 @@ def test_max_memory_allocated_exception(self): "gpu1", ] for device in wrong_device: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 max_memory_allocated(device) else: - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): max_memory_allocated() diff --git a/test/legacy_test/test_cuda_max_memory_reserved.py b/test/legacy_test/test_cuda_max_memory_reserved.py index ac3b2b712e2ff7..7f0a3f4da388fc 100644 --- a/test/legacy_test/test_cuda_max_memory_reserved.py +++ b/test/legacy_test/test_cuda_max_memory_reserved.py @@ -61,10 +61,10 @@ def test_max_memory_reserved_exception(self): "gpu1", ] for device in wrong_device: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 max_memory_reserved(device) else: - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): max_memory_reserved() diff --git a/test/legacy_test/test_cuda_memory_allocated.py b/test/legacy_test/test_cuda_memory_allocated.py index 3e4c2589406590..192126c092a4bb 100644 --- a/test/legacy_test/test_cuda_memory_allocated.py +++ b/test/legacy_test/test_cuda_memory_allocated.py @@ -46,10 +46,10 @@ def test_memory_allocated_exception(self): "gpu1", ] for device in wrong_device: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 memory_allocated(device) else: - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): memory_allocated() diff --git a/test/legacy_test/test_cuda_memory_reserved.py b/test/legacy_test/test_cuda_memory_reserved.py index d639eab054ff52..8a02834f8fd3a3 100644 --- a/test/legacy_test/test_cuda_memory_reserved.py +++ b/test/legacy_test/test_cuda_memory_reserved.py @@ -46,10 +46,10 @@ def test_memory_reserved_exception(self): "gpu1", ] for device in wrong_device: - with self.assertRaises(BaseException): + with self.assertRaises(BaseException): # noqa: B017 memory_reserved(device) else: - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): memory_reserved() diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index be733d989a93fa..ee853bd553eb0d 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -23,6 +23,7 @@ import paddle.inference as paddle_infer from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestCumsumOp(unittest.TestCase): @@ -53,7 +54,7 @@ def run_cases(self): np.testing.assert_array_equal(z, y.numpy()) def run_static(self, use_gpu=False): - with base.program_guard(base.Program()): + with paddle.static.program_guard(paddle.static.Program()): data_np = np.random.random((100, 100)).astype(np.float32) x = paddle.static.data('X', [100, 100]) y = paddle.cumsum(x) @@ -65,16 +66,16 @@ def run_static(self, use_gpu=False): place = base.CUDAPlace(0) if use_gpu else base.CPUPlace() exe = base.Executor(place) - exe.run(base.default_startup_program()) + exe.run(paddle.static.default_startup_program()) out = exe.run( feed={'X': data_np}, fetch_list=[ - y.name, - y2.name, - y3.name, - y4.name, - y5.name, - y6.name, + y, + y2, + y3, + y4, + y5, + y6, ], ) @@ -89,20 +90,26 @@ def run_static(self, use_gpu=False): z = np.cumsum(data_np, axis=-2) np.testing.assert_allclose(z, out[5], rtol=1e-05) - def test_cpu(self): + def test_cpu_dygraph(self): paddle.disable_static(paddle.base.CPUPlace()) self.run_cases() paddle.enable_static() + @test_with_pir_api + def test_cpu_static(self): self.run_static() - def test_gpu(self): + def test_gpu_dygraph(self): if not base.core.is_compiled_with_cuda(): return paddle.disable_static(paddle.base.CUDAPlace(0)) self.run_cases() paddle.enable_static() + @test_with_pir_api + def test_gpu_static(self): + if not base.core.is_compiled_with_cuda(): + return self.run_static(use_gpu=True) def test_name(self): @@ -133,10 +140,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def init_dtype(self): self.dtype = self.dtype_ = np.float64 @@ -242,10 +249,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def init_dtype(self): self.dtype = self.dtype_ = np.float64 @@ -341,10 +348,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def init_dtype(self): self.dtype = np.float16 @@ -380,10 +387,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def init_dtype(self): self.dtype = self.dtype_ = np.float64 @@ -401,14 +408,10 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad( - ['X'], - 'Out', - check_prim=True, - ) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) cls_name = "{}_{}".format(parent.__name__, "Fp16") TestCumsumFP16Op.__name__ = cls_name @@ -445,12 +448,17 @@ def if_enable_cinn(self): def test_check_output(self): place = paddle.CUDAPlace(0) - self.check_output_with_place(place, check_prim=True) + self.check_output_with_place(place, check_prim=True, check_pir=True) def test_check_grad(self): place = paddle.CUDAPlace(0) self.check_grad_with_place( - place, ["X"], "Out", check_prim=True, numeric_grad_delta=0.05 + place, + ["X"], + "Out", + check_prim=True, + numeric_grad_delta=0.05, + check_pir=True, ) cls_name = "{}_{}".format(parent.__name__, "BF16") @@ -552,6 +560,7 @@ def test_static_and_infer(self): class TestCumSumOpFp16(unittest.TestCase): + @test_with_pir_api def test_fp16(self): paddle.enable_static() x_np = np.random.random((100, 100)).astype('float16') diff --git a/test/legacy_test/test_diag_embed.py b/test/legacy_test/test_diag_embed.py index 2f3869713f0e39..ab2955f9d44056 100644 --- a/test/legacy_test/test_diag_embed.py +++ b/test/legacy_test/test_diag_embed.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import unittest import numpy as np from op_test import OpTest, paddle_static_guard import paddle -import paddle.nn.functional as F from paddle import base from paddle.base import core @@ -26,7 +26,7 @@ class TestDiagEmbedOp(OpTest): def setUp(self): self.op_type = "diag_embed" - self.python_api = F.diag_embed + self.python_api = paddle.diag_embed self.init_config() self.outputs = {'Out': self.target} @@ -57,8 +57,8 @@ def test_case1(self): data1 = paddle.static.data( name='data1', shape=[2, 3, 4], dtype='float32' ) - out1 = F.diag_embed(data1) - out2 = F.diag_embed(data1, offset=1, dim1=-2, dim2=3) + out1 = paddle.diag_embed(data1) + out2 = paddle.diag_embed(data1, offset=1, dim1=-2, dim2=3) place = core.CPUPlace() exe = base.Executor(place) @@ -77,6 +77,11 @@ def test_case1(self): np.testing.assert_allclose(results[0], target1, rtol=1e-05) np.testing.assert_allclose(results[1], target2, rtol=1e-05) + def test_tensor_method(self): + paddle.disable_static() + x = paddle.arange(15).reshape((3, 5)).astype('float64') + self.assertTrue(inspect.ismethod(x.diag_embed)) + if __name__ == "__main__": unittest.main() diff --git a/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py b/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py index 62a94832d1ae9e..9133577bddb2e2 100644 --- a/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py +++ b/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py @@ -96,14 +96,14 @@ def test_1_new_comm(self): run_test( clip_after_allreduce=True, max_global_norm=0.01, - need_env={"FLAGS_dynamic_static_unified_comm": "1"}, + need_env={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_2_new_comm(self): run_test( clip_after_allreduce=False, max_global_norm=0.01, - need_env={"FLAGS_dynamic_static_unified_comm": "1"}, + need_env={"FLAGS_dynamic_static_unified_comm": "true"}, ) diff --git a/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py b/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py index f236be3a8d1507..279c2dd1016317 100644 --- a/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py +++ b/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py @@ -38,7 +38,7 @@ def test_gm_new_comm(self): clip_after_allreduce=True, max_global_norm=-1.0, gradient_merge_steps=2, - need_env={"FLAGS_dynamic_static_unified_comm": "1"}, + need_env={"FLAGS_dynamic_static_unified_comm": "true"}, ) def test_gm_with_fp16_acc_grad_new_comm(self): @@ -47,7 +47,7 @@ def test_gm_with_fp16_acc_grad_new_comm(self): max_global_norm=-1.0, gradient_merge_steps=2, use_master_acc_grad=False, - need_env={"FLAGS_dynamic_static_unified_comm": "1"}, + need_env={"FLAGS_dynamic_static_unified_comm": "true"}, ) diff --git a/test/legacy_test/test_eigvals_op.py b/test/legacy_test/test_eigvals_op.py index 379603234d5afe..c54a4070be3a44 100644 --- a/test/legacy_test/test_eigvals_op.py +++ b/test/legacy_test/test_eigvals_op.py @@ -327,13 +327,13 @@ def test_cases(self): def test_error(self): paddle.disable_static() x = paddle.to_tensor([1]) - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): paddle.linalg.eigvals(x) self.input_dims = [1, 2, 3, 4] self.set_input_data() x = paddle.to_tensor(self.input_data) - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): paddle.linalg.eigvals(x) diff --git a/test/legacy_test/test_expand_as_v2_op.py b/test/legacy_test/test_expand_as_v2_op.py index 13aa6863b9bd6b..6b11c2f8dee99e 100755 --- a/test/legacy_test/test_expand_as_v2_op.py +++ b/test/legacy_test/test_expand_as_v2_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestExpandAsBasic(OpTest): @@ -48,10 +49,10 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) class TestExpandAs_ZeroDim1(TestExpandAsBasic): @@ -104,11 +105,11 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): - self.check_output_with_place(place=paddle.CUDAPlace(0)) + self.check_output_with_place(place=paddle.CUDAPlace(0), check_pir=True) def test_check_grad(self): self.check_grad_with_place( - paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True + paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True, check_pir=True ) @@ -242,7 +243,7 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(output)} def test_check_output(self): - self.check_output_with_place(place=paddle.CUDAPlace(0)) + self.check_output_with_place(place=paddle.CUDAPlace(0), check_pir=True) def test_check_grad(self): pass @@ -261,26 +262,28 @@ def test_errors(self): # Test python API class TestExpandAsV2API(unittest.TestCase): + @test_with_pir_api def test_api(self): - input1 = np.random.random([12, 14]).astype("float32") - input2 = np.random.random([2, 12, 14]).astype("float32") - x = paddle.static.data(name='x', shape=[12, 14], dtype="float32") - - y = paddle.static.data( - name='target_tensor', - shape=[2, 12, 14], - dtype="float32", - ) - - out_1 = paddle.expand_as(x, y=y) - - exe = base.Executor(place=base.CPUPlace()) - res_1 = exe.run( - base.default_main_program(), - feed={"x": input1, "target_tensor": input2}, - fetch_list=[out_1], - ) - np.testing.assert_array_equal(res_1[0], np.tile(input1, (2, 1, 1))) + with paddle.static.program_guard(paddle.static.Program()): + input1 = np.random.random([12, 14]).astype("float32") + input2 = np.random.random([2, 12, 14]).astype("float32") + x = paddle.static.data(name='x', shape=[12, 14], dtype="float32") + + y = paddle.static.data( + name='target_tensor', + shape=[2, 12, 14], + dtype="float32", + ) + + out_1 = paddle.expand_as(x, y=y) + + exe = base.Executor(place=base.CPUPlace()) + res_1 = exe.run( + paddle.static.default_main_program(), + feed={"x": input1, "target_tensor": input2}, + fetch_list=[out_1], + ) + np.testing.assert_array_equal(res_1[0], np.tile(input1, (2, 1, 1))) if __name__ == "__main__": diff --git a/test/legacy_test/test_flip.py b/test/legacy_test/test_flip.py index 4e5cc58ad33121..e4f729ded8234f 100644 --- a/test/legacy_test/test_flip.py +++ b/test/legacy_test/test_flip.py @@ -100,10 +100,10 @@ def init_attrs(self): self.attrs = {"axis": self.axis} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(["X"], "Out", check_cinn=True) + self.check_grad(["X"], "Out", check_cinn=True, check_pir=True) def init_test_case(self): self.in_shape = (6, 4, 2, 3) @@ -167,12 +167,16 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, check_cinn=True) + self.check_output_with_place( + place, check_cinn=True, check_pir=True + ) def test_check_grad(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_grad_with_place(place, ["X"], "Out", check_cinn=True) + self.check_grad_with_place( + place, ["X"], "Out", check_cinn=True, check_pir=True + ) cls_name = "{}_{}".format(parent.__name__, "FP16OP") TestFlipFP16.__name__ = cls_name @@ -202,12 +206,12 @@ def init_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) if core.is_bfloat16_supported(place): - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) if core.is_bfloat16_supported(place): - self.check_grad_with_place(place, ["X"], "Out") + self.check_grad_with_place(place, ["X"], "Out", check_pir=True) cls_name = "{}_{}".format(parent.__name__, "BF16OP") TestFlipBF16.__name__ = cls_name diff --git a/test/legacy_test/test_full_like_op.py b/test/legacy_test/test_full_like_op.py index 5cbcc3f5c78aa1..2126f532d9149f 100644 --- a/test/legacy_test/test_full_like_op.py +++ b/test/legacy_test/test_full_like_op.py @@ -22,6 +22,7 @@ from paddle.base import core from paddle.base.framework import convert_np_dtype_to_dtype_ from paddle.framework import in_pir_mode +from paddle.pir_utils import test_with_pir_api from paddle.static import Program, program_guard @@ -45,11 +46,12 @@ def fill_any_like_wrapper(x, value, out_dtype=None, name=None): class TestFullOp(unittest.TestCase): """Test fill_any_like op(whose API is full_like) for attr out.""" + @test_with_pir_api def test_attr_tensor_API(self): paddle.enable_static() - startup_program = Program() - train_program = Program() - with program_guard(train_program, startup_program): + startup_program = paddle.static.Program() + train_program = paddle.static.Program() + with paddle.static.program_guard(train_program, startup_program): fill_value = 2.0 input = paddle.static.data( name='input', dtype='float32', shape=[2, 3] diff --git a/test/legacy_test/test_full_op.py b/test/legacy_test/test_full_op.py index 74e928e58a52a9..0281d41252a274 100644 --- a/test/legacy_test/test_full_op.py +++ b/test/legacy_test/test_full_op.py @@ -19,60 +19,63 @@ import paddle from paddle import base from paddle.base import Program, program_guard +from paddle.pir_utils import test_with_pir_api # Test python API class TestFullAPI(unittest.TestCase): + @test_with_pir_api def test_api(self): - positive_2_int32 = paddle.tensor.fill_constant([1], "int32", 2) + with paddle.static.program_guard(paddle.static.Program()): + positive_2_int32 = paddle.tensor.fill_constant([1], "int32", 2) - positive_2_int64 = paddle.tensor.fill_constant([1], "int64", 2) - shape_tensor_int32 = paddle.static.data( - name="shape_tensor_int32", shape=[2], dtype="int32" - ) + positive_2_int64 = paddle.tensor.fill_constant([1], "int64", 2) + shape_tensor_int32 = paddle.static.data( + name="shape_tensor_int32", shape=[2], dtype="int32" + ) - shape_tensor_int64 = paddle.static.data( - name="shape_tensor_int64", shape=[2], dtype="int64" - ) + shape_tensor_int64 = paddle.static.data( + name="shape_tensor_int64", shape=[2], dtype="int64" + ) - out_1 = paddle.full(shape=[1, 2], dtype="float32", fill_value=1.1) + out_1 = paddle.full(shape=[1, 2], dtype="float32", fill_value=1.1) - out_2 = paddle.full( - shape=[1, positive_2_int32], dtype="float32", fill_value=1.1 - ) + out_2 = paddle.full( + shape=[1, positive_2_int32], dtype="float32", fill_value=1.1 + ) - out_3 = paddle.full( - shape=[1, positive_2_int64], dtype="float32", fill_value=1.1 - ) + out_3 = paddle.full( + shape=[1, positive_2_int64], dtype="float32", fill_value=1.1 + ) - out_4 = paddle.full( - shape=shape_tensor_int32, dtype="float32", fill_value=1.2 - ) + out_4 = paddle.full( + shape=shape_tensor_int32, dtype="float32", fill_value=1.2 + ) - out_5 = paddle.full( - shape=shape_tensor_int64, dtype="float32", fill_value=1.1 - ) + out_5 = paddle.full( + shape=shape_tensor_int64, dtype="float32", fill_value=1.1 + ) - out_6 = paddle.full( - shape=shape_tensor_int64, dtype=np.float32, fill_value=1.1 - ) + out_6 = paddle.full( + shape=shape_tensor_int64, dtype=np.float32, fill_value=1.1 + ) - val = paddle.tensor.fill_constant( - shape=[1], dtype=np.float32, value=1.1 - ) - out_7 = paddle.full( - shape=shape_tensor_int64, dtype=np.float32, fill_value=val - ) + val = paddle.tensor.fill_constant( + shape=[1], dtype=np.float32, value=1.1 + ) + out_7 = paddle.full( + shape=shape_tensor_int64, dtype=np.float32, fill_value=val + ) - exe = base.Executor(place=base.CPUPlace()) - res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( - base.default_main_program(), - feed={ - "shape_tensor_int32": np.array([1, 2]).astype("int32"), - "shape_tensor_int64": np.array([1, 2]).astype("int64"), - }, - fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7], - ) + exe = base.Executor(place=base.CPUPlace()) + res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( + paddle.static.default_main_program(), + feed={ + "shape_tensor_int32": np.array([1, 2]).astype("int32"), + "shape_tensor_int64": np.array([1, 2]).astype("int64"), + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7], + ) np.testing.assert_array_equal( res_1, np.full([1, 2], 1.1, dtype="float32") diff --git a/test/legacy_test/test_fusion_seqconv_eltadd_relu_op.py b/test/legacy_test/test_fusion_seqconv_eltadd_relu_op.py index 7cfc0da1ebe47e..b4b2471d95da9b 100644 --- a/test/legacy_test/test_fusion_seqconv_eltadd_relu_op.py +++ b/test/legacy_test/test_fusion_seqconv_eltadd_relu_op.py @@ -56,7 +56,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_dygraph=False) class TestSeqConvEltAddReluBS1(TestSeqConvEltAddRelu): diff --git a/test/legacy_test/test_gather_nd_op.py b/test/legacy_test/test_gather_nd_op.py index 3a27faf99cb6b8..7d1dea17e20ebc 100644 --- a/test/legacy_test/test_gather_nd_op.py +++ b/test/legacy_test/test_gather_nd_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestGatherNdOpWithEmptyIndex(OpTest): @@ -561,6 +562,7 @@ def test_check_grad(self): # Test Python API class TestGatherNdOpAPI(unittest.TestCase): + @test_with_pir_api def test_case1(self): x1 = paddle.static.data( name='x1', shape=[-1, 30, 40, 50, 60], dtype='float32' @@ -570,6 +572,7 @@ def test_case1(self): ) output1 = paddle.gather_nd(x1, index1) + @test_with_pir_api def test_case2(self): x2 = paddle.static.data( name='x2', shape=[-1, 30, 40, 50], dtype='float32' @@ -579,6 +582,7 @@ def test_case2(self): ) output2 = paddle.gather_nd(x2, index2) + @test_with_pir_api def test_case3(self): x3 = paddle.static.data(name='x3', shape=[-1, 3, 4, 5], dtype='float32') index3 = paddle.static.data( @@ -589,6 +593,7 @@ def test_case3(self): # Test Raise Index Error class TestGatherNdOpRaise(unittest.TestCase): + @test_with_pir_api def test_check_raise(self): def check_raise_is_test(): try: @@ -638,16 +643,15 @@ def test_index_dtype(): class TestGatherNdAPI2(unittest.TestCase): + @test_with_pir_api def test_static(self): with base.program_guard(base.Program(), base.Program()): data1 = paddle.static.data('data1', shape=[-1, 2], dtype='float64') - data1.desc.set_need_check_feed(False) index = paddle.static.data('index', shape=[-1, 1], dtype='int32') - index.desc.set_need_check_feed(False) out = paddle.gather_nd(data1, index) place = base.CPUPlace() exe = base.Executor(place) - input = np.array([[1, 2], [3, 4], [5, 6]]) + input = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64') index_1 = np.array([[1]]).astype('int32') (result,) = exe.run( feed={"data1": input, "index": index_1}, fetch_list=[out] @@ -655,6 +659,7 @@ def test_static(self): expected_output = np.array([[3, 4]]) np.testing.assert_allclose(result, expected_output, rtol=1e-05) + @test_with_pir_api def test_static_fp16_with_gpu(self): if paddle.base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) @@ -671,11 +676,9 @@ def test_static_fp16_with_gpu(self): x = paddle.static.data( name="x", shape=[2, 3, 2], dtype="float16" ) - x.desc.set_need_check_feed(False) idx = paddle.static.data( name="index", shape=[1, 2], dtype="int32" ) - idx.desc.set_need_check_feed(False) y = paddle.gather_nd(x, idx) diff --git a/test/legacy_test/test_gather_op.py b/test/legacy_test/test_gather_op.py index e845875394be68..7a317e724a5c9a 100644 --- a/test/legacy_test/test_gather_op.py +++ b/test/legacy_test/test_gather_op.py @@ -21,6 +21,7 @@ from paddle import base from paddle.base.dygraph.base import switch_to_static_graph from paddle.framework import core +from paddle.pir_utils import test_with_pir_api def gather_numpy(x, index, axis): @@ -418,23 +419,23 @@ def config_dtype(self): class API_TestGather(unittest.TestCase): + @test_with_pir_api def test_out1(self): with base.program_guard(base.Program(), base.Program()): data1 = paddle.static.data('data1', shape=[-1, 2], dtype='float64') - data1.desc.set_need_check_feed(False) - index = paddle.static.data('index', shape=[-1, 1], dtype='int32') - index.desc.set_need_check_feed(False) + index = paddle.static.data('index', shape=[-1, 1], dtype='int64') out = paddle.gather(data1, index) place = base.CPUPlace() exe = base.Executor(place) - input = np.array([[1, 2], [3, 4], [5, 6]]) - index_1 = np.array([1, 2]) + input = np.array([[1, 2], [3, 4], [5, 6]]).astype('float64') + index_1 = np.array([1, 2]).astype('int64') (result,) = exe.run( feed={"data1": input, "index": index_1}, fetch_list=[out] ) expected_output = np.array([[3, 4], [5, 6]]) np.testing.assert_allclose(result, expected_output, rtol=1e-05) + @test_with_pir_api def test_out2(self): with paddle.static.program_guard( paddle.static.Program(), paddle.static.Program() @@ -608,6 +609,13 @@ def test_out_type(self): out = paddle.gather(data, index) self.assertTrue(out.dtype == core.VarDesc.VarType.INT64) + def test_pir_out_type(self): + with paddle.pir_utils.IrGuard(): + data = paddle.static.data(shape=[16, 10], dtype='int64', name='x') + index = paddle.static.data(shape=[4], dtype='int64', name='index') + out = paddle.gather(data, index) + self.assertTrue(out.dtype == core.DataType.INT64) + if __name__ == "__main__": paddle.enable_static() diff --git a/test/legacy_test/test_increment.py b/test/legacy_test/test_increment.py index 4887564e9b9bb2..3055ffe1bdcf3a 100755 --- a/test/legacy_test/test_increment.py +++ b/test/legacy_test/test_increment.py @@ -18,9 +18,11 @@ import paddle from paddle import base +from paddle.pir_utils import test_with_pir_api class TestIncrement(unittest.TestCase): + @test_with_pir_api def test_api(self): with base.program_guard(base.Program(), base.Program()): input = paddle.tensor.fill_constant( @@ -41,6 +43,7 @@ def test_api(self): class TestInplaceApiWithDataTransform(unittest.TestCase): + @test_with_pir_api def test_increment(self): if base.core.is_compiled_with_cuda(): paddle.enable_static() diff --git a/test/legacy_test/test_logical_op.py b/test/legacy_test/test_logical_op.py index 98e15878cdfb68..81dec36e2f698e 100755 --- a/test/legacy_test/test_logical_op.py +++ b/test/legacy_test/test_logical_op.py @@ -67,6 +67,7 @@ } +# @test_with_pir_api def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True): paddle.enable_static() startup_program = Program() diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py index b2f2e21af25eec..641fc68e1832dc 100644 --- a/test/legacy_test/test_logit_op.py +++ b/test/legacy_test/test_logit_op.py @@ -58,10 +58,12 @@ def set_attrs(self): self.eps = 1e-8 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True + ) class TestLogitOpFp32(TestLogitOp): @@ -71,10 +73,12 @@ def set_attrs(self): self.eps = 1e-8 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True + ) class TestLogitOpFp16(TestLogitOp): @@ -84,10 +88,12 @@ def set_attrs(self): self.eps = 1e-8 def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True + ) @unittest.skipIf( @@ -115,7 +121,7 @@ def set_attrs(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): if core.is_compiled_with_cuda(): @@ -125,6 +131,7 @@ def test_check_grad(self): ['X'], ['Out'], user_defined_grads=[self.x_grad], + check_pir=True, ) diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index 1a0254b66df52b..9a7ab29ea4451e 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -16,12 +16,153 @@ import unittest import warnings +import numpy as np + import paddle +from paddle import base paddle.enable_static() +paddle.device.set_device("cpu") + + +def new_program(): + # TODO(gouzil): Optimize program code + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + place = base.CPUPlace() + exe = base.Executor(place) + return ( + main_program, + exe, + paddle.static.program_guard( + main_program=main_program, startup_program=startup_program + ), + ) class TestMathOpPatchesPir(unittest.TestCase): + def test_pow(self): + # Calculate results in dynamic graphs + paddle.disable_static() + x_np = np.random.random([10, 1024]).astype('float32') + y_np = np.random.random([10, 1024]).astype('float32') + res_np_b = x_np**y_np + res_np_c = paddle.pow(paddle.to_tensor(x_np), 2) + # TODO(gouzil): solve paddle.fill_constant problem + # res_np_d = x_np.__pow__(2) + # res_np_e = x_np.__rpow__(2) + paddle.enable_static() + # Calculate results under pir + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='float32' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='float32' + ) + b = x**y + c = x.pow(2) + # d = x.__pow__(2) + # e = x.__rpow__(2) + # TODO(gouzil): Why not use `paddle.static.default_main_program()`? + # Because different case do not isolate parameters (This is a known problem) + (b_np, c_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c], + ) + np.testing.assert_allclose(res_np_b, b_np, rtol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, rtol=1e-05) + # np.testing.assert_allclose(res_np_d, d_np, rtol=1e-05) + # np.testing.assert_allclose(res_np_e, e_np, rtol=1e-05) + + def test_mod(self): + paddle.disable_static() + x_np = np.random.randint(1, 100, size=[10, 1024], dtype=np.int64) + y_np = np.random.randint(1, 100, size=[10, 1024], dtype=np.int64) + res_np_b = x_np % y_np + res_np_c = paddle.mod(paddle.to_tensor(x_np), paddle.to_tensor(y_np)) + res_np_d = x_np.__mod__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='int64' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='int64' + ) + b = x % y + c = x.mod(y) + d = x.__mod__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + + def test_matmul(self): + paddle.disable_static() + x_np = np.random.uniform(-1, 1, [2, 3]).astype('float32') + y_np = np.random.uniform(-1, 1, [3, 5]).astype('float32') + res_np_b = x_np @ y_np # __matmul__ + res_np_c = paddle.matmul(paddle.to_tensor(x_np), paddle.to_tensor(y_np)) + res_np_d = x_np.__matmul__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data(name='x', shape=[2, 3], dtype='float32') + y = paddle.static.data(name='y', shape=[3, 5], dtype='float32') + b = x @ y + c = x.matmul(y) + d = x.__matmul__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + + def test_floordiv(self): + paddle.disable_static() + x_np = np.full([10, 1024], 10, np.int64) + y_np = np.full([10, 1024], 2, np.int64) + res_np_b = x_np // y_np + res_np_c = paddle.floor_divide( + paddle.to_tensor(x_np), paddle.to_tensor(y_np) + ) + res_np_d = x_np.__floordiv__(y_np) + paddle.enable_static() + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.static.data( + name='x', shape=[10, 1024], dtype='int64' + ) + y = paddle.static.data( + name='y', shape=[10, 1024], dtype='int64' + ) + b = x // y + c = x.floor_divide(y) + d = x.__floordiv__(y) + (b_np, c_np, d_np) = exe.run( + main_program, + feed={"x": x_np, "y": y_np}, + fetch_list=[b, c, d], + ) + np.testing.assert_allclose(res_np_b, b_np, atol=1e-05) + np.testing.assert_allclose(res_np_c, c_np, atol=1e-05) + np.testing.assert_allclose(res_np_d, d_np, atol=1e-05) + def test_item(self): with paddle.pir_utils.IrGuard(): x = paddle.static.data(name='x', shape=[3, 2, 1]) diff --git a/test/legacy_test/test_mse_loss.py b/test/legacy_test/test_mse_loss.py index 688895240a374c..ab2e9deaef488c 100644 --- a/test/legacy_test/test_mse_loss.py +++ b/test/legacy_test/test_mse_loss.py @@ -20,9 +20,11 @@ from paddle import base from paddle.base import core from paddle.base.executor import Executor +from paddle.pir_utils import test_with_pir_api class TestMseLoss(unittest.TestCase): + @test_with_pir_api def test_mse_loss(self): input_val = np.random.uniform(0.1, 0.5, (2, 3)).astype("float32") label_val = np.random.uniform(0.1, 0.5, (2, 3)).astype("float32") @@ -30,29 +32,35 @@ def test_mse_loss(self): sub = input_val - label_val np_result = np.mean(sub * sub) - input_var = paddle.static.data( - name="input", shape=[-1, 3], dtype="float32" - ) - label_var = paddle.static.data( - name="label", shape=[-1, 3], dtype="float32" - ) - - output = paddle.nn.functional.mse_loss(input=input_var, label=label_var) - for use_cuda in ( - [False, True] if core.is_compiled_with_cuda() else [False] - ): - place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() - exe = Executor(place) - (result,) = exe.run( - base.default_main_program(), - feed={"input": input_val, "label": label_val}, - fetch_list=[output], + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + input_var = paddle.static.data( + name="input", shape=[-1, 3], dtype="float32" ) + label_var = paddle.static.data( + name="label", shape=[-1, 3], dtype="float32" + ) + + output = paddle.nn.functional.mse_loss( + input=input_var, label=label_var + ) + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() + exe = Executor(place) + (result,) = exe.run( + main, + feed={"input": input_val, "label": label_val}, + fetch_list=[output], + ) - np.testing.assert_allclose(np_result, result, rtol=1e-05) + np.testing.assert_allclose(np_result, result, rtol=1e-05) class TestMseInvalidInput(unittest.TestCase): + @test_with_pir_api def test_error(self): def test_invalid_input(): input = [256, 3] @@ -74,6 +82,7 @@ def test_invalid_label(): class TestNNMseLoss(unittest.TestCase): + @test_with_pir_api def test_NNMseLoss_mean(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -88,13 +97,11 @@ def test_NNMseLoss_mean(self): ) with base.program_guard(prog, startup_prog): input = paddle.static.data( - name='input', shape=[-1] + dim, dtype='float32' + name='input', shape=dim, dtype='float32' ) - input.desc.set_need_check_feed(False) label = paddle.static.data( - name='label', shape=[-1] + dim, dtype='float32' + name='label', shape=dim, dtype='float32' ) - label.desc.set_need_check_feed(False) mse_loss = paddle.nn.loss.MSELoss() ret = mse_loss(input, label) @@ -120,6 +127,7 @@ def test_NNMseLoss_mean(self): np.testing.assert_allclose(dy_result, expected, rtol=1e-05) self.assertEqual(dy_result.shape, ()) + @test_with_pir_api def test_NNMseLoss_sum(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -134,13 +142,11 @@ def test_NNMseLoss_sum(self): ) with base.program_guard(prog, startup_prog): input = paddle.static.data( - name='input', shape=[-1] + dim, dtype='float32' + name='input', shape=dim, dtype='float32' ) - input.desc.set_need_check_feed(False) label = paddle.static.data( - name='label', shape=[-1] + dim, dtype='float32' + name='label', shape=dim, dtype='float32' ) - label.desc.set_need_check_feed(False) mse_loss = paddle.nn.loss.MSELoss(reduction='sum') ret = mse_loss(input, label) @@ -166,6 +172,7 @@ def test_NNMseLoss_sum(self): np.testing.assert_allclose(dy_result, expected, rtol=1e-05) self.assertEqual(dy_result.shape, ()) + @test_with_pir_api def test_NNMseLoss_none(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -180,13 +187,11 @@ def test_NNMseLoss_none(self): ) with base.program_guard(prog, startup_prog): input = paddle.static.data( - name='input', shape=[-1] + dim, dtype='float32' + name='input', shape=dim, dtype='float32' ) - input.desc.set_need_check_feed(False) label = paddle.static.data( - name='label', shape=[-1] + dim, dtype='float32' + name='label', shape=dim, dtype='float32' ) - label.desc.set_need_check_feed(False) mse_loss = paddle.nn.loss.MSELoss(reduction='none') ret = mse_loss(input, label) @@ -214,6 +219,7 @@ def test_NNMseLoss_none(self): class TestNNFunctionalMseLoss(unittest.TestCase): + @test_with_pir_api def test_NNFunctionalMseLoss_mean(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -256,6 +262,7 @@ def test_NNFunctionalMseLoss_mean(self): np.testing.assert_allclose(dy_result, expected, rtol=1e-05) self.assertEqual(dy_result.shape, ()) + @test_with_pir_api def test_NNFunctionalMseLoss_sum(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") @@ -298,6 +305,7 @@ def test_NNFunctionalMseLoss_sum(self): np.testing.assert_allclose(dy_result, expected, rtol=1e-05) self.assertEqual(dy_result.shape, ()) + @test_with_pir_api def test_NNFunctionalMseLoss_none(self): for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: input_np = np.random.uniform(0.1, 0.5, dim).astype("float32") diff --git a/test/legacy_test/test_nonzero_api.py b/test/legacy_test/test_nonzero_api.py index a57e1d9803c224..a14c72a22a149e 100644 --- a/test/legacy_test/test_nonzero_api.py +++ b/test/legacy_test/test_nonzero_api.py @@ -29,6 +29,7 @@ def call_nonzero(x): class TestNonZeroAPI(unittest.TestCase): def test_nonzero_api_as_tuple(self): + paddle.enable_static() data = np.array([[True, False], [False, True]]) with program_guard(Program(), Program()): x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32') @@ -61,6 +62,7 @@ def test_nonzero_api_as_tuple(self): np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) def test_nonzero_api(self): + paddle.enable_static() data = np.array([[True, False], [False, True]]) with program_guard(Program(), Program()): x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32') @@ -108,7 +110,7 @@ def setUp(self): self.outputs = self.return_outputs() def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def init_shape(self): self.shape = [8, 8] @@ -156,7 +158,7 @@ def setUp(self): self.outputs = self.return_outputs() def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def init_shape(self): self.shape = [12, 9] diff --git a/test/legacy_test/test_norm_all.py b/test/legacy_test/test_norm_all.py index 58be6779757422..86eea3a4c8eb02 100644 --- a/test/legacy_test/test_norm_all.py +++ b/test/legacy_test/test_norm_all.py @@ -102,10 +102,10 @@ def setUp(self): self.outputs = {'Out': norm} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) def init_test_case(self): self.shape = [2, 3, 4, 5] @@ -126,7 +126,7 @@ def init_dtype(self): self.dtype = "float32" def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestPnormOp(OpTest): diff --git a/test/legacy_test/test_pca_lowrank.py b/test/legacy_test/test_pca_lowrank.py index 107c76b442af9b..68f0005b368230 100644 --- a/test/legacy_test/test_pca_lowrank.py +++ b/test/legacy_test/test_pca_lowrank.py @@ -62,9 +62,7 @@ def run_subtest( self.assertEqual(v.shape[-1], guess_rank) self.assertEqual(v.shape[-2], columns) - A1 = u.matmul(paddle.nn.functional.diag_embed(s)).matmul( - self.transpose(v) - ) + A1 = u.matmul(paddle.diag_embed(s)).matmul(self.transpose(v)) ones_m1 = paddle.ones(batches + (rows, 1), dtype=a.dtype) c = a.sum(axis=-2) / rows c = c.reshape(batches + (1, columns)) diff --git a/test/legacy_test/test_prod_op.py b/test/legacy_test/test_prod_op.py index 2a0b06d76f849d..7a69a840c393dd 100644 --- a/test/legacy_test/test_prod_op.py +++ b/test/legacy_test/test_prod_op.py @@ -18,6 +18,7 @@ from test_sum_op import TestReduceOPTensorAxisBase import paddle +from paddle.pir_utils import test_with_pir_api class TestProdOp(unittest.TestCase): @@ -70,33 +71,35 @@ def run_imperative(self): dy_result.numpy(), expected_result, rtol=1e-05 ) + @test_with_pir_api def run_static(self, use_gpu=False): - input = paddle.static.data( - name='input', shape=[10, 10, 5], dtype='float32' - ) - result0 = paddle.prod(input) - result1 = paddle.prod(input, axis=1) - result2 = paddle.prod(input, axis=-1) - result3 = paddle.prod(input, axis=[0, 1]) - result4 = paddle.prod(input, axis=1, keepdim=True) - result5 = paddle.prod(input, axis=1, dtype='int64') - result6 = paddle.prod(input, axis=1, keepdim=True, dtype='int64') - - place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace() - exe = paddle.static.Executor(place) - exe.run(paddle.static.default_startup_program()) - static_result = exe.run( - feed={"input": self.input}, - fetch_list=[ - result0, - result1, - result2, - result3, - result4, - result5, - result6, - ], - ) + with paddle.static.program_guard(paddle.static.Program()): + input = paddle.static.data( + name='input', shape=[10, 10, 5], dtype='float32' + ) + result0 = paddle.prod(input) + result1 = paddle.prod(input, axis=1) + result2 = paddle.prod(input, axis=-1) + result3 = paddle.prod(input, axis=[0, 1]) + result4 = paddle.prod(input, axis=1, keepdim=True) + result5 = paddle.prod(input, axis=1, dtype='int64') + result6 = paddle.prod(input, axis=1, keepdim=True, dtype='int64') + + place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + static_result = exe.run( + feed={"input": self.input}, + fetch_list=[ + result0, + result1, + result2, + result3, + result4, + result5, + result6, + ], + ) expected_result = np.prod(self.input) np.testing.assert_allclose( @@ -134,8 +137,7 @@ def test_cpu(self): self.run_imperative() paddle.enable_static() - with paddle.static.program_guard(paddle.static.Program()): - self.run_static() + self.run_static() def test_gpu(self): if not paddle.base.core.is_compiled_with_cuda(): @@ -145,8 +147,7 @@ def test_gpu(self): self.run_imperative() paddle.enable_static() - with paddle.static.program_guard(paddle.static.Program()): - self.run_static(use_gpu=True) + self.run_static(use_gpu=True) class TestProdOpError(unittest.TestCase): diff --git a/test/legacy_test/test_real_imag_op.py b/test/legacy_test/test_real_imag_op.py index f714cef69e6d4d..cfc9ea2112c65a 100644 --- a/test/legacy_test/test_real_imag_op.py +++ b/test/legacy_test/test_real_imag_op.py @@ -19,6 +19,7 @@ import paddle from paddle import base, static +from paddle.pir_utils import test_with_pir_api numpy_apis = { "real": np.real, @@ -57,7 +58,7 @@ def init_grad_input_output(self): ) def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -99,6 +100,7 @@ def setUp(self): self.places.append(paddle.CUDAPlace(0)) self._shape = [2, 20, 2, 3] + @test_with_pir_api def test_in_static_mode(self): def init_input_output(dtype): input = np.random.random(self._shape).astype( @@ -114,7 +116,7 @@ def init_input_output(dtype): out = paddle_apis[self.api](x) exe = static.Executor(place) - out_value = exe.run(feed=input_dict, fetch_list=[out.name]) + out_value = exe.run(feed=input_dict, fetch_list=[out]) np.testing.assert_array_equal(np_res, out_value[0]) def test_in_dynamic_mode(self): diff --git a/test/legacy_test/test_repeat_interleave_op.py b/test/legacy_test/test_repeat_interleave_op.py index ec6649039dd45f..764b2a84d7b634 100644 --- a/test/legacy_test/test_repeat_interleave_op.py +++ b/test/legacy_test/test_repeat_interleave_op.py @@ -112,6 +112,7 @@ def input_data(self): [9.0, 10.0, 11.0, 12.0], ] ) + self.data_zero_dim_index = np.array(2) self.data_index = np.array([0, 1, 2, 1]).astype('int32') def test_repeat_interleave_api(self): @@ -267,6 +268,17 @@ def test_dygraph_api(self): expect_out = np.repeat(self.data_zero_dim_x, index, axis=None) np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + # case 4 zero_dim_index + with base.dygraph.guard(): + x = base.dygraph.to_variable(self.data_zero_dim_x) + index = base.dygraph.to_variable(self.data_zero_dim_index) + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat( + self.data_zero_dim_x, self.data_zero_dim_index, axis=None + ) + np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_rms_norm_op.py b/test/legacy_test/test_rms_norm_op.py index 79e20e906d92ce..dc9061ad95924e 100644 --- a/test/legacy_test/test_rms_norm_op.py +++ b/test/legacy_test/test_rms_norm_op.py @@ -18,6 +18,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def quant_helper( @@ -342,49 +343,6 @@ def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): ) return out_s[0], paddle_naive_rmsnorm_out - def test_rmsnorm_pir(self): - paddle.disable_static() - x = paddle.to_tensor(self.x_np.astype("float32")) - gamma = paddle.to_tensor(self.norm_weight_np.astype("float32")) - beta = paddle.to_tensor(self.norm_bias_np.astype("float32")) - - paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) - paddle.enable_static() - - with paddle.pir_utils.IrGuard(): - x_static = paddle.static.data( - name="x_static", shape=[self.batch, self.cols], dtype="float32" - ) - gamma_static = paddle.static.data( - name="gamma_static", shape=[self.cols], dtype="float32" - ) - beta_static = paddle.static.data( - name="beta_static", shape=[self.cols], dtype="float32" - ) - out, _ = paddle.incubate.nn.functional.fused_rms_norm( - x_static, - gamma_static, - beta_static, - self.epsilon, - begin_norm_axis=1, - ) - exe = base.Executor(self.place) - out_s = exe.run( - feed={ - "x_static": self.x_np.astype("float32"), - "gamma_static": self.norm_weight_np.astype("float32"), - "beta_static": self.norm_bias_np.astype("float32"), - }, - fetch_list=[out], - ) - - np.testing.assert_allclose( - out_s[0], - paddle_naive_rmsnorm_out.numpy(), - rtol=1e-3, - atol=1e-3, - ) - def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype): paddle.disable_static() x = paddle.to_tensor(x_np.astype(dtype)) @@ -491,6 +449,7 @@ def check_residual_bias_rmsnorm( ) return out_s[0], paddle_naive_rmsnorm_out + @test_with_pir_api def test_rmsnorm_fp16(self): if not paddle.is_compiled_with_cuda(): return @@ -505,6 +464,7 @@ def test_rmsnorm_fp16(self): atol=1e-3, ) + @test_with_pir_api def test_residual_bias_add_rmsnorm_fp16(self): if not paddle.is_compiled_with_cuda(): return @@ -524,6 +484,7 @@ def test_residual_bias_add_rmsnorm_fp16(self): atol=1e-3, ) + @test_with_pir_api def test_rmsnorm_int8(self): if not paddle.is_compiled_with_cuda(): return diff --git a/test/legacy_test/test_softmax_op.py b/test/legacy_test/test_softmax_op.py index ae98b434766192..74b685333d9252 100644 --- a/test/legacy_test/test_softmax_op.py +++ b/test/legacy_test/test_softmax_op.py @@ -22,6 +22,7 @@ import paddle.nn.functional as F from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api np.random.seed(10) @@ -512,6 +513,7 @@ def setUp(self): def executed_api(self): self.softmax = F.softmax + @test_with_pir_api def test_static_check(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -590,6 +592,7 @@ def test_dygraph(self): paddle.enable_static() + @test_with_pir_api def test_static(self): with static_guard(): main_prog = base.Program() @@ -597,18 +600,17 @@ def test_static(self): x = paddle.rand([]) x.stop_gradient = False out = paddle.nn.functional.softmax(x) - base.backward.append_backward(out) # Test compile shape - self.assertEqual(x.shape, ()) - self.assertEqual(out.shape, ()) + self.assertEqual(tuple(x.shape), ()) + self.assertEqual(tuple(out.shape), ()) exe = base.Executor() result = exe.run(main_prog, fetch_list=[x, out]) # Test runtime shape - self.assertEqual(result[0].shape, ()) - self.assertEqual(result[1].shape, ()) + self.assertEqual(tuple(result[0].shape), ()) + self.assertEqual(tuple(result[1].shape), ()) class TestSoftmaxInplaceAPI(TestSoftmaxAPI): diff --git a/test/legacy_test/test_tensor_fill_diagonal_tensor.py b/test/legacy_test/test_tensor_fill_diagonal_tensor.py index 7409cdae1f0072..cf3493d6039764 100644 --- a/test/legacy_test/test_tensor_fill_diagonal_tensor.py +++ b/test/legacy_test/test_tensor_fill_diagonal_tensor.py @@ -17,7 +17,6 @@ import numpy as np import paddle -import paddle.nn.functional as F from paddle import base @@ -202,9 +201,9 @@ def test_largedim(self): loss.backward() expected_pred = v - 2 - expected_pred = F.diag_embed(expected_pred) + 2 + expected_pred = paddle.diag_embed(expected_pred) + 2 expected_grad = paddle.ones(v.shape, dtype=dtype) - 2 - expected_grad = F.diag_embed(expected_grad) + 1 + expected_grad = paddle.diag_embed(expected_grad) + 1 self.assertEqual((ny == expected_pred).all(), True) self.assertEqual((y.grad == expected_grad).all(), True) diff --git a/test/legacy_test/test_tensor_fill_diagonal_tensor_.py b/test/legacy_test/test_tensor_fill_diagonal_tensor_.py index 482f3e542f6fc3..7966470e4e8fbf 100644 --- a/test/legacy_test/test_tensor_fill_diagonal_tensor_.py +++ b/test/legacy_test/test_tensor_fill_diagonal_tensor_.py @@ -17,7 +17,6 @@ import numpy as np import paddle -import paddle.nn.functional as F from paddle import base @@ -203,9 +202,9 @@ def test_largedim(self): loss.backward() expected_pred = v - 2 - expected_pred = F.diag_embed(expected_pred) + 2 + expected_pred = paddle.diag_embed(expected_pred) + 2 expected_grad = paddle.ones(v.shape, dtype=dtype) - 2 - expected_grad = F.diag_embed(expected_grad) + 1 + expected_grad = paddle.diag_embed(expected_grad) + 1 self.assertEqual((y == expected_pred).all(), True) self.assertEqual((y.grad == expected_grad).all(), True) diff --git a/test/legacy_test/test_tensordot.py b/test/legacy_test/test_tensordot.py index 16d2015573d10f..0e41772abd6cb5 100644 --- a/test/legacy_test/test_tensordot.py +++ b/test/legacy_test/test_tensordot.py @@ -342,9 +342,21 @@ def test_error(self): paddle.disable_static() x = paddle.to_tensor(self.x) y = paddle.to_tensor(self.y) - for axes in self.all_axes: - with self.assertRaises(BaseException): - paddle.tensordot(x, y, axes) + + with self.assertRaises(TypeError): + paddle.tensordot(x, y, axes=self.all_axes[0]) + with self.assertRaises(TypeError): + paddle.tensordot(x, y, axes=self.all_axes[1]) + with self.assertRaises(AssertionError): + paddle.tensordot(x, y, axes=self.all_axes[2]) + with self.assertRaises(IndexError): + paddle.tensordot(x, y, axes=self.all_axes[3]) + with self.assertRaises(ValueError): + paddle.tensordot(x, y, axes=self.all_axes[4]) + with self.assertRaises(AssertionError): + paddle.tensordot(x, y, axes=self.all_axes[5]) + with self.assertRaises(AssertionError): + paddle.tensordot(x, y, axes=self.all_axes[6]) class TestTensordotAPIAxesTypeFloat64(TestTensordotAPIAxesType): diff --git a/test/legacy_test/test_top_k_v2_op.py b/test/legacy_test/test_top_k_v2_op.py index 9ff5d03473afcd..41d021c9085ad7 100644 --- a/test/legacy_test/test_top_k_v2_op.py +++ b/test/legacy_test/test_top_k_v2_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api def numpy_topk(x, k=1, axis=-1, largest=True): @@ -63,10 +64,10 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) class TestTopkOp_ZeroDim(TestTopkOp): @@ -270,11 +271,13 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', check_prim=True) + self.check_grad_with_place( + place, ['X'], 'Out', check_prim=True, check_pir=True + ) class TestTopKAPI(unittest.TestCase): @@ -377,8 +380,8 @@ def run_static(self, place): result1 = paddle.topk(input_tensor, k=2) result2 = paddle.topk(input_tensor, k=2, axis=-1) result3 = paddle.topk(input_tensor, k=k_tensor, axis=1) - self.assertEqual(result3[0].shape, (6, -1, 8)) - self.assertEqual(result3[1].shape, (6, -1, 8)) + self.assertEqual(tuple(result3[0].shape), (6, -1, 8)) + self.assertEqual(tuple(result3[1].shape), (6, -1, 8)) result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False) result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False) result6 = paddle.topk(large_input_tensor, k=1, axis=-1) @@ -461,21 +464,28 @@ def run_static(self, place): sort_paddle[0], numpy_result[0], rtol=1e-05 ) - def test_cases(self): + def test_dygraph_cases(self): places = [core.CPUPlace()] if core.is_compiled_with_cuda(): places.append(core.CUDAPlace(0)) for place in places: self.run_dygraph(place) + + @test_with_pir_api + def test_static_cases(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: self.run_static(place) def test_errors(self): with paddle.base.dygraph.guard(): x = paddle.to_tensor([1, 2, 3]) - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): paddle.topk(x, k=-1) - with self.assertRaises(BaseException): + with self.assertRaises(ValueError): paddle.topk(x, k=0) diff --git a/test/legacy_test/test_trace_op.py b/test/legacy_test/test_trace_op.py index 1d53c1180b8367..a62c9e7f9aa8a5 100644 --- a/test/legacy_test/test_trace_op.py +++ b/test/legacy_test/test_trace_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base, tensor from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestTraceOp(OpTest): @@ -30,10 +31,10 @@ def setUp(self): self.outputs = {'Out': self.target} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['Input'], 'Out') + self.check_grad(['Input'], 'Out', check_pir=True) def init_config(self): self.case = np.random.randn(20, 6).astype('float64') @@ -108,11 +109,15 @@ def setUp(self): self.place = core.CUDAPlace(0) def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) def test_check_grad(self): self.check_grad_with_place( - self.place, ['Input'], 'Out', numeric_grad_delta=0.02 + self.place, + ['Input'], + 'Out', + numeric_grad_delta=0.02, + check_pir=True, ) def init_config(self): @@ -145,22 +150,24 @@ def init_config(self): class TestTraceAPICase(unittest.TestCase): + @test_with_pir_api def test_case1(self): - case = np.random.randn(2, 20, 2, 3).astype('float32') - data1 = paddle.static.data( - name='data1', shape=[2, 20, 2, 3], dtype='float32' - ) - out1 = tensor.trace(data1) - out2 = tensor.trace(data1, offset=-5, axis1=1, axis2=-1) - - place = core.CPUPlace() - exe = base.Executor(place) - results = exe.run( - base.default_main_program(), - feed={"data1": case}, - fetch_list=[out1, out2], - return_numpy=True, - ) + with paddle.static.program_guard(paddle.static.Program()): + case = np.random.randn(2, 20, 2, 3).astype('float32') + data1 = paddle.static.data( + name='data1', shape=[2, 20, 2, 3], dtype='float32' + ) + out1 = tensor.trace(data1) + out2 = tensor.trace(data1, offset=-5, axis1=1, axis2=-1) + + place = core.CPUPlace() + exe = base.Executor(place) + results = exe.run( + paddle.static.default_main_program(), + feed={"data1": case}, + fetch_list=[out1, out2], + return_numpy=True, + ) target1 = np.trace(case) target2 = np.trace(case, offset=-5, axis1=1, axis2=-1) np.testing.assert_allclose(results[0], target1, rtol=1e-05) diff --git a/test/legacy_test/test_transpose_op.py b/test/legacy_test/test_transpose_op.py index 32f071eafb472b..98774942ce65d6 100644 --- a/test/legacy_test/test_transpose_op.py +++ b/test/legacy_test/test_transpose_op.py @@ -22,6 +22,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -541,6 +542,7 @@ def test_each_elem_value_check(): class TestTransposeApi(unittest.TestCase): + @test_with_pir_api def test_static_out(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -578,7 +580,8 @@ def test_dygraph_out(self): class TestTAPI(unittest.TestCase): - def test_out(self): + @test_with_pir_api + def test_static_out(self): with base.program_guard(base.Program()): data = paddle.static.data(shape=[10], dtype="float64", name="data") data_t = paddle.t(data) @@ -613,6 +616,7 @@ def test_out(self): expected_result = np.transpose(data_np) self.assertEqual((result == expected_result).all(), True) + def test_dygraph_out(self): with base.dygraph.guard(): np_x = np.random.random([10]).astype("float64") data = base.dygraph.to_variable(np_x) @@ -637,6 +641,7 @@ def test_out(self): z_expected = np.array(np.transpose(np_x)) self.assertEqual((np_z == z_expected).all(), True) + @test_with_pir_api def test_errors(self): with base.program_guard(base.Program()): x = paddle.static.data(name='x', shape=[10, 5, 3], dtype='float64') @@ -648,7 +653,8 @@ def test_x_dimension_check(): class TestMoveAxis(unittest.TestCase): - def test_moveaxis1(self): + @test_with_pir_api + def test_static_moveaxis1(self): x_np = np.random.randn(2, 3, 4, 5, 7) expected = np.moveaxis(x_np, [0, 4, 3, 2], [1, 3, 2, 0]) paddle.enable_static() @@ -661,6 +667,9 @@ def test_moveaxis1(self): np.testing.assert_array_equal(out_np, expected) + def test_dygraph_moveaxis1(self): + x_np = np.random.randn(2, 3, 4, 5, 7) + expected = np.moveaxis(x_np, [0, 4, 3, 2], [1, 3, 2, 0]) paddle.disable_static() x = paddle.to_tensor(x_np) out = paddle.moveaxis(x, [0, 4, 3, 2], [1, 3, 2, 0]) @@ -668,7 +677,8 @@ def test_moveaxis1(self): np.testing.assert_array_equal(out.numpy(), expected) paddle.enable_static() - def test_moveaxis2(self): + @test_with_pir_api + def test_static_moveaxis2(self): x_np = np.random.randn(2, 3, 5) expected = np.moveaxis(x_np, -2, -1) paddle.enable_static() @@ -681,6 +691,9 @@ def test_moveaxis2(self): np.testing.assert_array_equal(out_np, expected) + def test_dygraph_moveaxis2(self): + x_np = np.random.randn(2, 3, 5) + expected = np.moveaxis(x_np, -2, -1) paddle.disable_static() x = paddle.to_tensor(x_np) out = x.moveaxis(-2, -1) diff --git a/test/legacy_test/test_var_base.py b/test/legacy_test/test_var_base.py index 748ac4ca608ab8..6b388e2e7e4b1e 100644 --- a/test/legacy_test/test_var_base.py +++ b/test/legacy_test/test_var_base.py @@ -87,6 +87,10 @@ def check_with_place(place): self.assertEqual(y.place.__repr__(), "Place(gpu:0)") y = x.cuda(blocking=True) self.assertEqual(y.place.__repr__(), "Place(gpu:0)") + y = x.cuda(device_id=0, blocking=True) + self.assertEqual(y.place.__repr__(), "Place(gpu:0)") + y = x.cuda(device_id=0, blocking=False) + self.assertEqual(y.place.__repr__(), "Place(gpu:0)") with self.assertRaises(ValueError): y = x.cuda("test") diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index 3f12fa397a3a8c..bf4924815a4743 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import unittest import numpy @@ -23,6 +24,9 @@ from paddle.base.executor import Executor from paddle.incubate.layers.nn import shuffle_batch +sys.path.append("../dygraph_to_static") +from dygraph_to_static_util import test_and_compare_with_new_ir + paddle.enable_static() @@ -63,7 +67,6 @@ def simple_net(self): i = paddle.increment(x=i) paddle.tensor.array_write(result, i=i, array=mem_array) - paddle.assign(paddle.less_than(x=i, y=array_len), cond) with while_op2.block(): d2 = paddle.tensor.array_read(array=data_array, i=j) @@ -73,10 +76,13 @@ def simple_net(self): j = paddle.increment(x=j) paddle.tensor.array_write(result2, i=j, array=mem_array) paddle.assign(paddle.less_than(x=j, y=array_len2), cond2) + + paddle.assign(paddle.less_than(x=i, y=array_len), cond) sum_result = paddle.tensor.array_read(array=mem_array, i=j) loss = paddle.mean(sum_result) return loss, sum_result + # TODO(zhangbo): Support pir test(support write_to_array and read_from_array, support while_grad). def test_simple_net(self): main_program = base.Program() startup_program = base.Program() @@ -98,13 +104,13 @@ def test_simple_net(self): ) self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01) + # TODO(zhangbo): Support pir test(support write_to_array and read_from_array) def test_simple_net_forward(self): main_program = base.Program() startup_program = base.Program() with base.program_guard(main_program, startup_program): self.simple_net() binary = base.compiler.CompiledProgram(main_program) - cpu = core.CPUPlace() exe = Executor(cpu) d = [] @@ -115,6 +121,7 @@ def test_simple_net_forward(self): for _ in range(2): exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) + @test_and_compare_with_new_ir() def test_exceptions(self): i = paddle.zeros(shape=[2], dtype='int64') array_len = paddle.tensor.fill_constant( @@ -129,6 +136,7 @@ def test_exceptions(self): class BadInputTest(unittest.TestCase): + @test_and_compare_with_new_ir() def test_error(self): with base.program_guard(base.Program()): @@ -184,6 +192,7 @@ def body_func(i, ten, batch_info, origin_seq): class TestOutputsMustExistsInputs(unittest.TestCase): + @test_and_compare_with_new_ir() def test_outputs_exists_inputs(self): """ We guarantee that the output tensor must be in the input tensor, so that the output and input can correspond to each other, but the input can be greater than the number of outputs. It's required in paddle2onnx. diff --git a/test/prim/pir_prim/test_vjp_prim.py b/test/prim/pir_prim/test_vjp_prim.py index 2755f2854487f9..86fbbfcd508ac2 100644 --- a/test/prim/pir_prim/test_vjp_prim.py +++ b/test/prim/pir_prim/test_vjp_prim.py @@ -71,7 +71,13 @@ def test_divide_grad_prim_case1(self): stop_gradients = [[False], [False]] divide_op = newir_program.global_block().ops[-1] with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(divide_op, out_grads, stop_gradients) + grad_outs = call_vjp( + divide_op, + [[value] for value in divide_op.operands_source()], + [[value] for value in divide_op.results()], + out_grads, + stop_gradients, + ) reshape_op2 = newir_program.global_block().ops[-1] reshape_op1 = newir_program.global_block().ops[-8] self.assertEqual(len(grad_outs), 2) @@ -113,7 +119,13 @@ def test_divide_grad_no_prim(self): stop_gradients = [[False], [False]] divide_op = newir_program.global_block().ops[-1] with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(divide_op, out_grads, stop_gradients) + grad_outs = call_vjp( + divide_op, + [[value] for value in divide_op.operands_source()], + [[value] for value in divide_op.results()], + out_grads, + stop_gradients, + ) self.assertEqual(len(grad_outs), 2) self.assertEqual( grad_outs[0][0].get_defining_op().name(), "pd_op.divide_grad" @@ -132,7 +144,13 @@ def test_sum_grad_prim(self): stop_gradients = [[False]] sum_op = newir_program.global_block().ops[-1] with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(sum_op, out_grads, stop_gradients) + grad_outs = call_vjp( + sum_op, + [[value] for value in sum_op.operands_source()], + [[value] for value in sum_op.results()], + out_grads, + stop_gradients, + ) expand_op = newir_program.global_block().ops[-1] self.assertEqual(len(grad_outs), 1) self.assertEqual(len(newir_program.global_block().ops), 8) @@ -159,7 +177,13 @@ def test_sum_grad_no_prim(self): stop_gradients = [[False]] sum_op = newir_program.global_block().ops[-1] with paddle.pir.core.program_guard(newir_program): - grad_outs = call_vjp(sum_op, out_grads, stop_gradients) + grad_outs = call_vjp( + sum_op, + [[value] for value in sum_op.operands_source()], + [[value] for value in sum_op.results()], + out_grads, + stop_gradients, + ) self.assertEqual(len(grad_outs), 1) self.assertEqual( grad_outs[0][0].get_defining_op().name(), "pd_op.sum_grad" diff --git a/test/sot/test_01_basic.py b/test/sot/test_01_basic.py index 8a03ea9fd3ae5a..4a76cc2a2bdb53 100644 --- a/test/sot/test_01_basic.py +++ b/test/sot/test_01_basic.py @@ -14,9 +14,10 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle +from paddle.jit.sot.utils import strict_mode_guard def foo(x: int, y: paddle.Tensor): @@ -34,7 +35,7 @@ def numpy_add(x, y): class TestNumpyAdd(TestCaseBase): - @strict_mode_guard(0) + @strict_mode_guard(False) def test_numpy_add(self): x = paddle.to_tensor([2]) y = paddle.to_tensor([3]) diff --git a/test/sot/test_12_for_loop.py b/test/sot/test_12_for_loop.py index 63e3fedace4bfd..015ba340a1b357 100644 --- a/test/sot/test_12_for_loop.py +++ b/test/sot/test_12_for_loop.py @@ -19,7 +19,7 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle from paddle.jit import sot @@ -27,6 +27,7 @@ from paddle.jit.sot.opcode_translator.executor.executor_cache import ( OpcodeExecutorCache, ) +from paddle.jit.sot.utils import strict_mode_guard def gener(): @@ -294,5 +295,5 @@ def test_undefined_var_case_1(self): if __name__ == "__main__": - with strict_mode_guard(0): + with strict_mode_guard(False): unittest.main() diff --git a/test/sot/test_19_closure.py b/test/sot/test_19_closure.py index 6191141e07f390..ddfd36e2a60962 100644 --- a/test/sot/test_19_closure.py +++ b/test/sot/test_19_closure.py @@ -15,9 +15,10 @@ import inspect import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle +from paddle.jit.sot.utils import strict_mode_guard def foo(x: int, y: paddle.Tensor): @@ -180,7 +181,7 @@ def test_closure(self): self.assert_results(foo5, paddle.to_tensor(2)) self.assert_results(foo6, paddle.to_tensor(2)) self.assert_results(numpy_sum, paddle.to_tensor(1)) - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results( lambda_closure, paddle.to_tensor(2), paddle.to_tensor(1) ) diff --git a/test/sot/test_case_base.py b/test/sot/test_case_base.py index 03ce3c98227e8a..f5a57f66c186bc 100644 --- a/test/sot/test_case_base.py +++ b/test/sot/test_case_base.py @@ -136,23 +136,3 @@ def copy_fn(fn): sym_copied_fn.__globals__[key], paddle_fn.__globals__[key] ) self.assert_nest_match(sym_output, paddle_output) - - -@contextlib.contextmanager -def strict_mode_guard(value): - if "STRICT_MODE" not in os.environ: - os.environ["STRICT_MODE"] = "0" - old_value = os.environ["STRICT_MODE"] - os.environ["STRICT_MODE"] = str(value) - yield - os.environ["STRICT_MODE"] = old_value - - -@contextlib.contextmanager -def cost_model_guard(value): - if "COST_MODEL" not in os.environ: - os.environ["COST_MODEL"] = "True" - old_value = os.environ["COST_MODEL"] - os.environ["COST_MODEL"] = str(value) - yield - os.environ["COST_MODEL"] = old_value diff --git a/test/sot/test_code_status.py b/test/sot/test_code_status.py index 9fec5712c2293a..a873c919af3536 100644 --- a/test/sot/test_code_status.py +++ b/test/sot/test_code_status.py @@ -14,11 +14,12 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle from paddle.jit import sot from paddle.jit.sot.opcode_translator.skip_files import skip_function +from paddle.jit.sot.utils import strict_mode_guard from paddle.jit.sot.utils.code_status import CodeState, CodeStatus @@ -85,7 +86,7 @@ def test_case_1(self): ) def test_case_2(self): - with strict_mode_guard(0): + with strict_mode_guard(False): CodeStatus().clear() net = SimpleNet2() inp = paddle.rand((10, 10)) diff --git a/test/sot/test_cost_model.py b/test/sot/test_cost_model.py index 07899a03efbfd6..a3acec5942005e 100644 --- a/test/sot/test_cost_model.py +++ b/test/sot/test_cost_model.py @@ -15,11 +15,11 @@ import time import unittest -from test_case_base import TestCaseBase, cost_model_guard +from test_case_base import TestCaseBase import paddle from paddle.jit.sot import psdb, symbolic_translate -from paddle.jit.sot.utils import StepInfoManager, StepState +from paddle.jit.sot.utils import StepInfoManager, StepState, cost_model_guard def dyn_fast(x, net, iter_): @@ -58,7 +58,7 @@ def forward(self, x): class TestCostModel(TestCaseBase): - @cost_model_guard("True") + @cost_model_guard(True) def test_dyn_fast(self): x = paddle.rand([10]) net = paddle.nn.Linear(10, 10) @@ -69,7 +69,7 @@ def test_dyn_fast(self): state = StepInfoManager().step_record[dyn_fast.__code__].state assert state == StepState.RUN_DYN - @cost_model_guard("True") + @cost_model_guard(True) def test_sot_fast_with_multi_graph(self): x = paddle.rand([10]) net = paddle.nn.Linear(10, 10) @@ -84,7 +84,7 @@ def test_sot_fast_with_multi_graph(self): ) assert state == StepState.RUN_SOT - @cost_model_guard("True") + @cost_model_guard(True) def test_sot_fast_with_single_graph(self): x = paddle.rand([10]) net = paddle.nn.Linear(10, 10) @@ -98,7 +98,7 @@ def test_sot_fast_with_single_graph(self): ) assert state == StepState.RUN_SOT - @cost_model_guard("True") + @cost_model_guard(True) def test_net(self): x = paddle.rand([10]) net = Net() diff --git a/test/sot/test_enumerate.py b/test/sot/test_enumerate.py index f81a451da55c99..236eece7560d20 100644 --- a/test/sot/test_enumerate.py +++ b/test/sot/test_enumerate.py @@ -14,9 +14,10 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle +from paddle.jit.sot.utils import strict_mode_guard def test_enumerate_1(x: int, y: int): @@ -100,13 +101,13 @@ def test_cases(self): self.assert_results(test_enumerate_4, ty) # TODO(zmh): support range for tensor - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results(test_enumerate_5, paddle.to_tensor([1, 2, 3])) self.assert_results(test_enumerate_6, paddle.to_tensor([1, 2, 3])) self.assert_results(test_enumerate_7, ty) # TODO(zmh): support -1 - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results(test_enumerate_8, ty) self.assert_results(test_enumerate_10, layer_list, paddle.randn((10,))) diff --git a/test/sot/test_error_handling.py b/test/sot/test_error_handling.py index c74436f0d44f4f..4e5000cd0c50db 100644 --- a/test/sot/test_error_handling.py +++ b/test/sot/test_error_handling.py @@ -14,9 +14,10 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase from paddle.jit import sot +from paddle.jit.sot.utils import strict_mode_guard def fn_with_try_except(): @@ -30,7 +31,7 @@ def fn_with_try_except(): class TestErrorHandling(TestCaseBase): - @strict_mode_guard(0) + @strict_mode_guard(False) def test_fn_with_try_except(self): self.assert_results(fn_with_try_except) diff --git a/test/sot/test_map.py b/test/sot/test_map.py index 812ab36673be42..f005ec10cdbe4b 100644 --- a/test/sot/test_map.py +++ b/test/sot/test_map.py @@ -17,10 +17,11 @@ import unittest from typing import Iterable -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase from paddle.jit import sot from paddle.jit.sot.psdb import check_no_breakgraph +from paddle.jit.sot.utils import strict_mode_guard def double_num(num: float | int): @@ -110,7 +111,7 @@ def test_map_comprehension(self): ) def test_map_with_breakgraph(self): - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results(test_map_list_with_breakgraph, [1, 2, 3, 4]) def test_map_unpack(self): diff --git a/test/sot/test_numpy.py b/test/sot/test_numpy.py index 3600d4df7cc455..eb47e86b03b20d 100644 --- a/test/sot/test_numpy.py +++ b/test/sot/test_numpy.py @@ -15,9 +15,10 @@ import unittest import numpy as np -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle +from paddle.jit.sot.utils import strict_mode_guard def foo(x, y): @@ -32,7 +33,7 @@ def test_tensor_add_numpy_number(self): self.assert_results(foo, x, y) self.assert_results(foo, y, x) - @strict_mode_guard(0) + @strict_mode_guard(False) def test_tensor_add_numpy_array(self): x = paddle.to_tensor([1.0]) y = np.array(2.0) diff --git a/test/sot/test_numpy_var_if.py b/test/sot/test_numpy_var_if.py index 9d7c4a7048e251..6e098df70d3be0 100644 --- a/test/sot/test_numpy_var_if.py +++ b/test/sot/test_numpy_var_if.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import numpy as np @@ -20,8 +19,9 @@ import paddle from paddle.jit.sot.psdb import check_no_breakgraph, check_no_fallback +from paddle.jit.sot.utils import ENV_MIN_GRAPH_SIZE -os.environ['MIN_GRAPH_SIZE'] = '-1' +ENV_MIN_GRAPH_SIZE.set(-1) @check_no_breakgraph diff --git a/test/sot/test_side_effects.py b/test/sot/test_side_effects.py index 46bed6e8d3c4e3..96ec9a7c5f6a78 100644 --- a/test/sot/test_side_effects.py +++ b/test/sot/test_side_effects.py @@ -16,12 +16,12 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle from paddle.jit import sot from paddle.jit.sot import symbolic_translate -from paddle.jit.sot.utils import InnerError +from paddle.jit.sot.utils import InnerError, strict_mode_guard def dict_setitem(x): @@ -275,7 +275,7 @@ def test_list_reverse(self): def test_slice_in_for_loop(self): x = 2 - with strict_mode_guard(0): + with strict_mode_guard(False): self.assert_results_with_side_effects(slice_in_for_loop, x) def test_list_nested(self): diff --git a/test/white_list/new_ir_op_test_no_check_list b/test/white_list/new_ir_op_test_no_check_list new file mode 100644 index 00000000000000..31c6df29fe12cb --- /dev/null +++ b/test/white_list/new_ir_op_test_no_check_list @@ -0,0 +1 @@ +test_exponential_op diff --git a/test/white_list/new_ir_op_test_white_list b/test/white_list/new_ir_op_test_white_list index dea0398f9d5fac..626d6d80272d5b 100644 --- a/test/white_list/new_ir_op_test_white_list +++ b/test/white_list/new_ir_op_test_white_list @@ -79,6 +79,7 @@ test_elementwise_mul_op test_elementwise_pow_op test_erfinv_op test_expand_v2_op +test_exponential_op test_eye_op test_fill_any_op test_fill_constant_batch_size_like @@ -136,6 +137,7 @@ test_matrix_nms_op test_matrix_power_op test_maxout_op test_mean_op +test_memcpy_op test_mode_op test_multi_dot_op test_multiplex_op diff --git a/tools/codestyle/copyright.hook b/tools/codestyle/copyright.hook index 8985e3882cdd67..e007af33ce3cb9 100644 --- a/tools/codestyle/copyright.hook +++ b/tools/codestyle/copyright.hook @@ -36,7 +36,7 @@ def _generate_copyright(comment_mark): copyright=COPYRIGHT.split(os.linesep) header = copyright[0].rstrip() - p = re.search('(\d{4})', header).group(0) + p = re.search(r'(\d{4})', header).group(0) now = datetime.datetime.now() header = header.replace(p,str(now.year)) diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index d1cb054771535b..822f0a11fec21f 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -31,6 +31,7 @@ serial_list="^test_conv2d_op$|\ ^test_conv2d_transpose_op$|\ ^test_dygraph_dataparallel_bf16$|\ ^test_dygraph_sharding_stage1_fp16$|\ +^test_dygraph_sharding_stage1_bf16$|\ ^test_dygraph_sharding_stage2_bf16$|\ ^test_dygraph_sharding_stage3_bf16$|\ ^test_conv3d_op$"