
Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into type_hint_ci
megemini committed May 20, 2024
2 parents e7be07d + e474970 commit 618c3b9
Showing 1,456 changed files with 52,606 additions and 28,038 deletions.
5 changes: 1 addition & 4 deletions .editorconfig
@@ -15,15 +15,12 @@ insert_final_newline = true
[*.{c,cc,cxx,cpp,cu,cuh,h,hpp,hxx,kps}]
indent_size = 2

[*.{py,java,r}]
[*.{py,pyi,java,r,toml}]
indent_size = 4

[Dockerfile.*]
indent_size = 4

[.flake8]
indent_size = 4

[*.go]
indent_style = tab
indent_size = 4
75 changes: 75 additions & 0 deletions .github/CODEOWNERS
@@ -1 +1,76 @@
# This file is migrated from the CI script; it is an effort to modernize our dev infra.
# Code owners are expected to take responsibility for reviewing patches to their respective files.

/CMakeLists.txt @wanghuancoder @Aurelius84 @XiaoguangHu01 @qili93
paddle/fluid/eager/autograd_meta.cc @JiabinYang @phlrain
paddle/fluid/eager/autograd_meta.h @JiabinYang @phlrain
paddle/fluid/eager/backward.cc @JiabinYang @phlrain
paddle/fluid/eager/backward.h @JiabinYang @phlrain
paddle/fluid/eager/grad_node_info.cc @JiabinYang @phlrain
paddle/fluid/eager/grad_node_info.h @JiabinYang @phlrain
paddle/fluid/eager/grad_tensor_holder.cc @JiabinYang @phlrain
paddle/fluid/eager/grad_tensor_holder.h @JiabinYang @phlrain
paddle/fluid/eager/tensor_wrapper.h @JiabinYang @phlrain
paddle/fluid/framework/block_desc.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/details/op_registry.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/framework.proto @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/grad_op_desc_maker.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/ir/graph.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/ir/node.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/lod_tensor.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/op_desc.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/operator.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/scope.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/selected_rows.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/tensor.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/framework/unused_var_check.cc @zhiqiu @phlrain
paddle/fluid/framework/var_desc.h @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
paddle/fluid/operators/distributed/send_recv.proto.in @gongweibao @seiriosPlus
paddle/fluid/prim/api/api.yaml @cxxly @xiaoguoguo626807 @Charles-hit @cyber-pioneer @JiabinYang
paddle/fluid/prim/api/composite_backward/composite_backward_api.h @cxxly @xiaoguoguo626807 @Charles-hit @cyber-pioneer @JiabinYang
paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @cxxly @xiaoguoguo626807 @Charles-hit @cyber-pioneer @JiabinYang
paddle/fluid/prim/api/manual_prim/prim_manual_api.h @cxxly @xiaoguoguo626807 @Charles-hit @cyber-pioneer @JiabinYang
paddle/phi/api/include/tensor.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/attribute.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/dense_tensor.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/device_context.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/infermeta_utils.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/kernel_context.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/kernel_factory.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/kernel_registry.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/kernel_utils.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/meta_tensor.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/tensor_base.h @phlrain @zyfncg @YuanRisheng
paddle/phi/core/tensor_meta.h @phlrain @zyfncg @YuanRisheng
paddle/phi/infermeta/spmd_rules @LiYuRio @ForFishes @zhiqiu
paddle/scripts/paddle_build.bat @zhwesky2010 @wanghuancoder @Aurelius84
paddle/scripts/paddle_build.sh @risemeup1 @zhangbo9674 @XieYunshen
pyproject.toml @SigureMo @gouzil
python/paddle/autograd/backward_utils.py @Aurelius84 @cxxly @xiaoguoguo626807 @changeyoung98
python/paddle/autograd/ir_backward.py @Aurelius84 @cxxly @xiaoguoguo626807 @changeyoung98
python/paddle/base/backward.py @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
python/paddle/base/compiler.py @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
python/paddle/base/dygraph/layers.py @JiabinYang @phlrain
python/paddle/base/framework.py @XiaoguangHu01 @zhiqiu @Xreki @qili93 @Aurelius84
python/paddle/base/__init__.py @phlrain @Aurelius84 @qili93
python/paddle/base/parallel_executor.py @Xreki @zhhsplendid @Aurelius84
python/paddle/base/tests/unittests/white_list/check_op_sequence_batch_1_input_white_list.py @Aurelius84 @phlrain
python/paddle/base/tests/unittests/white_list/check_op_sequence_instance_0_input_white_list.py @Aurelius84 @phlrain
python/paddle/base/tests/unittests/white_list/check_shape_white_list.py @hong19860320 @Aurelius84 @phlrain
python/paddle/base/tests/unittests/white_list/compile_vs_runtime_white_list.py @Aurelius84 @phlrain
python/paddle/base/tests/unittests/white_list/no_check_set_white_list.py @Aurelius84 @phlrain
python/paddle/base/tests/unittests/white_list/no_grad_set_white_list.py @Aurelius84 @phlrain
python/paddle/base/tests/unittests/white_list/op_accuracy_white_list.py @juncaipeng @zhangting2020 @Aurelius84
python/paddle/base/tests/unittests/white_list/op_threshold_white_list.py @juncaipeng @zhangting2020 @Aurelius84
python/paddle/distributed/fleet/__init__.py @sneaxiy @raindrops2sea
python/paddle/distributed/fleet/launch.py @sneaxiy @raindrops2sea
python/paddle/distributed/__init__.py @sneaxiy @raindrops2sea
python/paddle/incubate/autograd/composite_rules.py @cyber-pioneer @xiaoguoguo626807 @Charles-hit @JiabinYang
python/paddle/incubate/autograd/primitives.py @cyber-pioneer @xiaoguoguo626807 @Charles-hit @JiabinYang
python/paddle/_typing @SigureMo @zrr1999 @gouzil
python/requirements.txt @phlrain @jzhang533 @kolinwei
test/dygraph_to_static @SigureMo @Aurelius84 @gouzil
test/sot @SigureMo @Aurelius84 @gouzil
tools/parallel_UT_rule.py @zhwesky2010 @wanghuancoder @Aurelius84
tools/windows/run_unittests.sh @zhwesky2010 @wanghuancoder @Aurelius84
.pre-commit-config.yaml @SigureMo @gouzil
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -44,7 +44,7 @@ repos:
name: copyright_checker
entry: python ./tools/codestyle/copyright.py
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|pyi|sh)$
exclude: |
(?x)^(
paddle/utils/.*|
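
As a hedged illustration (not part of the commit), the short C++ sketch below exercises the widened copyright_checker `files` pattern from the hunk above; the sample paths are hypothetical and only show that .pyi stubs are now covered.

#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
  // The `files` pattern from .pre-commit-config.yaml after this change,
  // now including the `pyi` extension.
  const std::regex files(
      R"(\.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|pyi|sh)$)");
  // Hypothetical paths, used only for demonstration.
  const std::vector<std::string> paths = {
      "python/paddle/tensor/math.pyi",
      "python/paddle/tensor/math.py",
      "README.md"};
  for (const auto& p : paths) {
    std::cout << p
              << (std::regex_search(p, files) ? " -> checked" : " -> skipped")
              << "\n";
  }
  return 0;
}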
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -68,6 +68,7 @@ option(WITH_PIP_CUDA_LIBRARIES
"Paddle uses the CUDA library provided by NVIDIA" OFF)
option(WITH_NIGHTLY_BUILD
"Compile nightly paddle whl package of the develop branch" OFF)
option(WITH_CPP_TEST "Compile PaddlePaddle with cpp test" ON)
find_package(Git REQUIRED)

# config GIT_URL with github mirrors to speed up dependent repos clone
13 changes: 13 additions & 0 deletions cmake/cuda.cmake
@@ -173,6 +173,8 @@ function(select_nvcc_arch_flags out_variable out_arch_bin)
elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
set(cuda_arch_bin "75")
elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere")
message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE")
add_definitions("-DCUDA_BFLOAT16_AVALIABLE")
if(WITH_NV_JETSON)
set(cuda_arch_bin "87")
else()
@@ -183,6 +185,8 @@ function(select_nvcc_arch_flags out_variable out_arch_bin)
endif()
endif()
elseif(${CUDA_ARCH_NAME} STREQUAL "Hopper")
message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE")
add_definitions("-DCUDA_BFLOAT16_AVALIABLE")
set(cuda_arch_bin "90")
elseif(${CUDA_ARCH_NAME} STREQUAL "All")
set(cuda_arch_bin ${paddle_known_gpu_archs})
@@ -196,8 +200,17 @@ function(select_nvcc_arch_flags out_variable out_arch_bin)
to get a full wheel package to resolve this warning.
However, this version will still work on the local GPU architecture.")
detect_installed_gpus(cuda_arch_bin)
if(${cuda_arch_bin} MATCHES "[ ]*(8\.0|8\.6|8\.9|9\.0)[ ]*")
message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE")
add_definitions("-DCUDA_BFLOAT16_AVALIABLE")
endif()
else() # (${CUDA_ARCH_NAME} STREQUAL "Manual")
set(cuda_arch_bin ${CUDA_ARCH_BIN})

if(${CUDA_ARCH_BIN} MATCHES "[ ]*(80|86|89|90)[ ]*")
message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE")
add_definitions("-DCUDA_BFLOAT16_AVALIABLE")
endif()
endif()

if(NEW_RELEASE_JIT)
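
A minimal sketch, assuming only that CUDA_BFLOAT16_AVALIABLE is the define added by the hunks above: it shows how code might guard bfloat16-specific paths on that macro. The messages are illustrative, and actual usage sites are not part of this diff.

#include <cstdio>

int main() {
#ifdef CUDA_BFLOAT16_AVALIABLE
  // Defined by cmake/cuda.cmake when the target architecture is Ampere,
  // Hopper, or a detected/manual arch in {80, 86, 89, 90}.
  std::printf("bfloat16 kernels enabled for this architecture\n");
#else
  std::printf("bfloat16 kernels disabled; falling back to float paths\n");
#endif
  return 0;
}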
9 changes: 7 additions & 2 deletions cmake/external/eigen.cmake
@@ -57,8 +57,13 @@ endif()
set(EIGEN_INCLUDE_DIR ${SOURCE_DIR})
include_directories(${EIGEN_INCLUDE_DIR})
if(NOT WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=maybe-uninitialized")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-error=maybe-uninitialized")
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=maybe-uninitialized")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-error=maybe-uninitialized")
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=uninitialized")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-error=uninitialized")
endif()
endif()
ExternalProject_Add(
extern_eigen3
18 changes: 18 additions & 0 deletions cmake/external/pybind11.cmake
@@ -22,6 +22,23 @@ set(SOURCE_INCLUDE_DIR ${SOURCE_DIR}/include)

include_directories(${PYBIND_INCLUDE_DIR})

# It can be safely removed in gcc9.1+
set(PYBIND_PATCH_COMMAND "")
if(LINUX
AND (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9)
set(PYBIND_TAG v2.12.0)
file(TO_NATIVE_PATH
${PADDLE_SOURCE_DIR}/patches/pybind/detail/internals.h.patch native_dst)
# Note: [Why call some `git` commands before `patch`?]
# Paddle's CI uses a cache to accelerate the make process. However, patching the code can fail in two scenarios:
# 1. Patching the wrong version: the tag of CI's cache falls behind PYBIND_TAG; `git checkout ${PYBIND_TAG}` solves this.
# 2. Patching twice: the tag of the cache equals PYBIND_TAG, but the patch has already been applied to the cache.
set(PYBIND_PATCH_COMMAND
git checkout -- . && git checkout ${PYBIND_TAG} && patch -Nd
${SOURCE_INCLUDE_DIR}/pybind11/detail < ${native_dst})
endif()

ExternalProject_Add(
extern_pybind
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
@@ -33,6 +50,7 @@ ExternalProject_Add(
# third-party library version changes cannot be incorporated.
# reference: https://cmake.org/cmake/help/latest/module/ExternalProject.html
UPDATE_COMMAND ""
PATCH_COMMAND ${PYBIND_PATCH_COMMAND}
CONFIGURE_COMMAND ""
# I intentionally preserved an extern_pybind/include/pybind11 directory
# to site-packages, so that you could discern that you intended to
12 changes: 11 additions & 1 deletion cmake/external/warpctc.cmake
@@ -49,6 +49,13 @@ if(NOT WIN32 AND WITH_GPU)
endif()
endif()

if(WITH_ROCM)
set(WARPCTC_PATHCH_ROCM_COMMAND
patch -p1 <
${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.rocm.patch && patch
-p1 < ${PADDLE_SOURCE_DIR}/patches/warpctc/devicetypes.cuh.patch)
endif()

set(WARPCTC_INCLUDE_DIR
"${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE)
@@ -100,7 +107,10 @@ ExternalProject_Add(
SOURCE_DIR ${SOURCE_DIR}
PREFIX ${WARPCTC_PREFIX_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND ${WARPCTC_PATCH_COMMAND} ${WARPCTC_PATCH_CUDA_COMMAND}
PATCH_COMMAND
COMMAND ${WARPCTC_PATCH_COMMAND}
COMMAND ${WARPCTC_PATCH_CUDA_COMMAND}
COMMAND ${WARPCTC_PATHCH_ROCM_COMMAND}
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
10 changes: 8 additions & 2 deletions cmake/external/warprnnt.cmake
@@ -35,7 +35,11 @@ else()
${SOURCE_DIR} <
${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.cuda.patch)
endif()

if(WITH_ROCM)
set(WARPRNNT_PATCH_ROCM_COMMAND
patch -p1 <
${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.rocm.patch)
endif()
if(NOT WIN32 AND WITH_GPU)
if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0 AND ${CMAKE_CXX_COMPILER_VERSION}
VERSION_GREATER 12.0)
@@ -99,7 +103,9 @@ ExternalProject_Add(
SOURCE_DIR ${SOURCE_DIR}
PREFIX ${WARPRNNT_PREFIX_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND ${WARPCTC_PATCH_CUDA_COMMAND}
PATCH_COMMAND
COMMAND ${WARPCTC_PATCH_CUDA_COMMAND}
COMMAND ${WARPRNNT_PATCH_ROCM_COMMAND}
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -32,7 +32,7 @@ if(NOT DEFINED XPU_XDNN_BASE_DATE)
set(XPU_XDNN_BASE_DATE "20240327")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "20240422")
set(XPU_XHPC_BASE_DATE "20240515")
endif()
set(XPU_XCCL_BASE_VERSION "1.2.0.5")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
10 changes: 8 additions & 2 deletions cmake/flags.cmake
@@ -92,8 +92,13 @@ macro(safe_set_nvflag flag_name)
check_c_compiler_flag(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name})
set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name})
if(${safe_name})
set(SAFE_GPU_COMMON_FLAGS
"${SAFE_GPU_COMMON_FLAGS} -Xcompiler=\"${flag_name}\"")
if(WITH_ROCM)
set(SAFE_GPU_COMMON_FLAGS
"${SAFE_GPU_COMMON_FLAGS} -Xcompiler \"${flag_name}\"")
else()
set(SAFE_GPU_COMMON_FLAGS
"${SAFE_GPU_COMMON_FLAGS} -Xcompiler=\"${flag_name}\"")
endif()
endif()
endmacro()

@@ -279,6 +284,7 @@ endif()

# Disable -Werror, otherwise the compile will fail for rocblas_gemm_ex
if(WITH_ROCM)
string(REPLACE "-Werror" "-Wno-error" HIP_HIPCC_FLAGS ${HIP_HIPCC_FLAGS})
string(REPLACE "-Werror" "-Wno-error" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE "-Werror" "-Wno-error" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
endif()
12 changes: 12 additions & 0 deletions cmake/hip.cmake
@@ -24,6 +24,7 @@ else()
CACHE PATH "Path to which clang has been installed")
endif()
set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH})
set(CMAKE_PREFIX_PATH "${ROCM_PATH}" ${CMAKE_PREFIX_PATH})

find_package(HIP REQUIRED)
include_directories(${ROCM_PATH}/include)
@@ -123,6 +124,17 @@ list(APPEND HIP_CXX_FLAGS -Wno-switch)
list(APPEND HIP_CXX_FLAGS -Wno-literal-conversion)
list(APPEND HIP_CXX_FLAGS -Wno-constant-conversion)
list(APPEND HIP_CXX_FLAGS -Wno-defaulted-function-deleted)
list(APPEND HIP_CXX_FLAGS -Wno-sign-compare)
list(APPEND HIP_CXX_FLAGS -Wno-bitwise-instead-of-logical)
list(APPEND HIP_CXX_FLAGS -Wno-unknown-warning-option)
list(APPEND HIP_CXX_FLAGS -Wno-unused-lambda-capture)
list(APPEND HIP_CXX_FLAGS -Wno-unused-variable)
list(APPEND HIP_CXX_FLAGS -Wno-unused-but-set-variable)
list(APPEND HIP_CXX_FLAGS -Wno-reorder-ctor)
list(APPEND HIP_CXX_FLAGS -Wno-deprecated-copy-with-user-provided-copy)
list(APPEND HIP_CXX_FLAGS -Wno-unused-local-typedef)
list(APPEND HIP_CXX_FLAGS -Wno-missing-braces)
list(APPEND HIP_CXX_FLAGS -Wno-sometimes-uninitialized)

if(WITH_CINN)
list(APPEND HIP_CXX_FLAGS -std=c++14)
8 changes: 8 additions & 0 deletions cmake/inference_lib.cmake
@@ -182,6 +182,14 @@ function(copy_part_of_third_party TARGET DST)
SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib)

if(WITH_FLASHATTN)
set(dst_dir "${DST}/third_party/install/flashattn")
copy(
${TARGET}
SRCS ${FLASHATTN_INCLUDE_DIR} ${FLASHATTN_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib)
endif()

if(NOT PROTOBUF_FOUND OR WIN32)
set(dst_dir "${DST}/third_party/install/protobuf")
copy(
37 changes: 33 additions & 4 deletions paddle/cinn/adt/anchor_sd_equation_context.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/cinn/adt/anchor_sd_equation_context.h"
#include "paddle/common/enforce.h"

namespace cinn::adt::config {

@@ -27,7 +28,14 @@ void GenerateScheduleMeshEquationsImpl(const List<ScheduleDim>& sched_dims,
const List<Iterator>& input_iterators,
const List<Iterator>& output_iterators,
Equations* equations) {
CHECK_EQ(input_iterators->size(), output_iterators->size());
PADDLE_ENFORCE_EQ(
input_iterators->size() == output_iterators->size(),
true,
phi::errors::InvalidArgument(
"The size of input iterators and output iterators should be equal, "
"but got input iterators size = %d, output iterators size = %d.",
input_iterators->size(),
output_iterators->size()));
for (std::size_t i = 0; i < output_iterators->size(); ++i) {
Equal(input_iterators->at(i), output_iterators->at(i), equations);
}
@@ -42,7 +50,14 @@ void GenerateScheduleMeshEquationsImpl(
List<Iterator> middle_iterators =
MakeIterators(GetOutputRank(middle_sched_mesh));
List<DimExpr> middle_dims = GetOutputDimValues(middle_sched_mesh);
CHECK_EQ(shape.value()->size(), output_iterators->size());
PADDLE_ENFORCE_EQ(
shape.value()->size() == output_iterators->size(),
true,
phi::errors::InvalidArgument(
"The size of shape and output iterators should be equal, but got "
"shape size = %d, output iterators size = %d.",
shape.value()->size(),
output_iterators->size()));
List<DimExpr> output_dims = GetOutputDimValues(ScheduleMesh{sched_reshape});
const auto& middle_index = MakeDot(middle_iterators, middle_dims, equations);
const auto& output_index = MakeDot(output_iterators, output_dims, equations);
@@ -58,7 +73,14 @@ void GenerateScheduleMeshEquationsImpl(
const List<Iterator>& output_iterators,
Equations* equations) {
const auto& [sched_mesh, perm] = sched_transpose.tuple();
CHECK_EQ(GetOutputRank(sched_mesh), output_iterators->size());
PADDLE_ENFORCE_EQ(GetOutputRank(sched_mesh) == output_iterators->size(),
true,
phi::errors::InvalidArgument(
"The size of output iterators should be equal to the "
"rank of the schedule mesh, but got output iterators "
"size = %d, rank of the schedule mesh = %d.",
output_iterators->size(),
GetOutputRank(sched_mesh)));
List<Iterator> middle_iterators = MakeIterators(output_iterators->size());
for (std::size_t i = 0; i < perm.value()->size(); ++i) {
Equal(middle_iterators->at(perm.value()->at(i)),
@@ -75,7 +97,14 @@ void GenerateScheduleMeshEquationsImpl(
const List<Iterator>& output_iterators,
Equations* equations) {
const auto& [sched_mesh, _] = sched_padding.tuple();
CHECK_EQ(GetOutputRank(sched_mesh), output_iterators->size());
PADDLE_ENFORCE_EQ(GetOutputRank(sched_mesh) == output_iterators->size(),
true,
phi::errors::InvalidArgument(
"The size of output iterators should be equal to the "
"rank of the schedule mesh, but got output iterators "
"size = %d, rank of the schedule mesh = %d.",
output_iterators->size(),
GetOutputRank(sched_mesh)));
List<Iterator> middle_iterators = MakeIterators(output_iterators->size());
for (std::size_t i = 0; i < output_iterators->size(); ++i) {
Equal(middle_iterators->at(i), output_iterators->at(i), equations);
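
The hunks above all apply the same CHECK_EQ to PADDLE_ENFORCE_EQ conversion. The standalone sketch below mirrors that calling convention under a simplified stand-in macro; Paddle's real PADDLE_ENFORCE_EQ and phi::errors::InvalidArgument (from paddle/common/enforce.h) take a formatted error object rather than a plain string, so this is an approximation for illustration only.

#include <cstdio>
#include <cstdlib>
#include <vector>

// Simplified stand-in so the sketch compiles outside Paddle; Paddle's
// PADDLE_ENFORCE_EQ reports a phi::errors object instead of a C string.
#define PADDLE_ENFORCE_EQ(actual, expected, msg) \
  do {                                           \
    if ((actual) != (expected)) {                \
      std::fprintf(stderr, "%s\n", (msg));       \
      std::abort();                               \
    }                                            \
  } while (0)

int main() {
  std::vector<int> input{1, 2, 3};
  std::vector<int> output{1, 2, 3};
  // Before: CHECK_EQ(input.size(), output.size());
  // After: enforce the equality and explain both operands in the message.
  PADDLE_ENFORCE_EQ(input.size() == output.size(), true,
                    "The size of input iterators and output iterators should be equal.");
  std::printf("sizes match: %zu\n", input.size());
  return 0;
}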
(Diff truncated; the remaining changed files are not shown here.)
