Commit

Merge pull request PaddlePaddle#69 from mthreads/static_graph
Static graph
yaowang-mt authored and mt-robot committed Sep 6, 2023
2 parents 70dd030 + 5e4d91e commit 447ea98
Showing 16 changed files with 266 additions and 38 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/CMakeLists.txt
@@ -93,7 +93,7 @@ endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_utils static_prim_api get_expected_kernel_func)

if (WITH_MUSA)
-register_operators(EXCLUDES stft_op cross_entropy_op gru_op batch_norm_op inplace_abn_op gaussian_random_batch_size_like_op top_k_op py_func_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op quantize_linear_op recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op activation_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} processgroup_comm_utils)
+register_operators(EXCLUDES stft_op cross_entropy_op gru_op inplace_abn_op gaussian_random_batch_size_like_op top_k_op py_func_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op quantize_linear_op recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op activation_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} processgroup_comm_utils)
else()
register_operators(EXCLUDES py_func_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op activation_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} processgroup_comm_utils)
2 changes: 0 additions & 2 deletions paddle/fluid/platform/dynload/musartc.h
@@ -14,8 +14,6 @@ limitations under the License. */

#pragma once

-#include <mtrtc.h>
-
#include <mutex> // NOLINT

#include "paddle/phi/backends/dynload/musartc.h"
9 changes: 9 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -275,6 +275,14 @@ bool IsCompiledWithROCM() {
#endif
}

+bool IsCompiledWithMUSA() {
+#ifndef PADDLE_WITH_MUSA
+  return false;
+#else
+  return true;
+#endif
+}
+
bool IsCompiledWithXPU() {
#ifndef PADDLE_WITH_XPU
return false;
@@ -1935,6 +1943,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("is_compiled_with_avx", IsCompiledWithAVX);
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_rocm", IsCompiledWithROCM);
+m.def("is_compiled_with_musa", IsCompiledWithMUSA);
m.def("is_compiled_with_custom_device", IsCompiledWithCustomDevice);
m.def("is_compiled_with_ipu", IsCompiledWithIPU);
m.def("is_compiled_with_xpu", IsCompiledWithXPU);
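The new binding follows the same pattern as the existing is_compiled_with_cuda / is_compiled_with_rocm flags: a C++ predicate whose value is fixed by a preprocessor guard at build time, exported to Python through pybind11. A minimal standalone sketch of that pattern (the module name demo_core and the build setup are assumptions for illustration, not Paddle's actual build):

```cpp
// Minimal pybind11 sketch of the compile-time feature-flag pattern above;
// the module name demo_core is hypothetical, not Paddle's build target.
#include <pybind11/pybind11.h>

static bool IsCompiledWithMUSA() {
#ifndef PADDLE_WITH_MUSA
  return false;  // default when built without -DPADDLE_WITH_MUSA
#else
  return true;
#endif
}

PYBIND11_MODULE(demo_core, m) {
  // Python: demo_core.is_compiled_with_musa() -> bool, fixed at build time
  m.def("is_compiled_with_musa", &IsCompiledWithMUSA);
}
```

In Paddle itself this presumably surfaces as paddle.is_compiled_with_musa(), mirroring the existing flags; that Python-side wiring is outside this diff.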
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/op_compat.yaml
@@ -679,7 +679,7 @@
outputs :
out : Output
extra :
-attrs : [bool is_test = false, bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
+attrs : [bool is_test = false, bool use_cudnn = true, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
90 changes: 87 additions & 3 deletions paddle/phi/backends/dynload/musartc.h
@@ -14,18 +14,102 @@ limitations under the License. */

#pragma once

-#include <mtrtc.h>
+// #include <mtrtc.h>

#include <mutex> // NOLINT

#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"

+// TODO(MTAI): The following MUSA runtime compilation functions are not
+// supported yet, so empty implementations are provided here temporarily. Once
+// the MCC compiler supports these functions, the stubs will be replaced.
+typedef struct _mtrtcProgram *mtrtcProgram;
+
+typedef enum {
+  MTRTC_SUCCESS = 0,
+  MTRTC_ERROR_OUT_OF_MEMORY = 1,
+  MTRTC_ERROR_PROGRAM_CREATION_FAILURE = 2,
+  MTRTC_ERROR_INVALID_INPUT = 3,
+  MTRTC_ERROR_INVALID_PROGRAM = 4,
+  MTRTC_ERROR_INVALID_OPTION = 5,
+  MTRTC_ERROR_COMPILATION = 6,
+  MTRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7,
+  MTRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8,
+  MTRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9,
+  MTRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10,
+  MTRTC_ERROR_INTERNAL_ERROR = 11
+} mtrtcResult;
+
+inline mtrtcResult mtrtcVersion(int *major, int *minor) {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("mtrtcVersion is not supported on MUSA now!"));
+  return mtrtcResult::MTRTC_ERROR_INTERNAL_ERROR;
+}
+
+inline const char *mtrtcGetErrorString(mtrtcResult result) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "mtrtcGetErrorString is not supported on MUSA now!"));
+  return "mtrtcGetErrorString is not supported on MUSA now!";
+}
+
+inline mtrtcResult mtrtcCompileProgram(mtrtcProgram prog,
+                                       int numOptions,
+                                       const char *const *options) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "mtrtcCompileProgram is not supported on MUSA now!"));
+  return mtrtcResult::MTRTC_ERROR_INTERNAL_ERROR;
+}
+
+inline mtrtcResult mtrtcCreateProgram(mtrtcProgram *prog,
+                                      const char *src,
+                                      const char *name,
+                                      int numHeaders,
+                                      const char *const *headers,
+                                      const char *const *includeNames) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "mtrtcCreateProgram is not supported on MUSA now!"));
+  return mtrtcResult::MTRTC_ERROR_INTERNAL_ERROR;
+}
+
+inline mtrtcResult mtrtcDestroyProgram(mtrtcProgram *prog) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "mtrtcDestroyProgram is not supported on MUSA now!"));
+  return mtrtcResult::MTRTC_ERROR_INTERNAL_ERROR;
+}
+
+inline mtrtcResult mtrtcGetMUSA(mtrtcProgram prog, char *musa) {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("mtrtcGetMUSA is not supported on MUSA now!"));
+  return mtrtcResult::MTRTC_ERROR_INTERNAL_ERROR;
+}
+
+inline mtrtcResult mtrtcGetMUSASize(mtrtcProgram prog, size_t *musaSizeRet) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "mtrtcGetMUSASize is not supported on MUSA now!"));
+  return mtrtcResult::MTRTC_ERROR_INTERNAL_ERROR;
+}
+
+inline mtrtcResult mtrtcGetProgramLog(mtrtcProgram prog, char *log) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "mtrtcGetProgramLog is not supported on MUSA now!"));
+  return mtrtcResult::MTRTC_ERROR_INTERNAL_ERROR;
+}
+
+inline mtrtcResult mtrtcGetProgramLogSize(mtrtcProgram prog,
+                                          size_t *logSizeRet) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "mtrtcGetProgramLogSize is not supported on MUSA now!"));
+  return mtrtcResult::MTRTC_ERROR_INTERNAL_ERROR;
+}

namespace phi {
namespace dynload {

extern std::once_flag musartc_dso_flag;
-extern void* musartc_dso_handle;
+extern void *musartc_dso_handle;
extern bool HasNVRTC();

@@ -36,7 +120,7 @@ extern bool HasNVRTC();
#define DECLARE_DYNAMIC_LOAD_NVRTC_WRAP(__name) \
std::call_once(musartc_dso_flag, []() { \
musartc_dso_handle = phi::dynload::GetNVRTCDsoHandle(); \
}); \
-static void* p_##__name = dlsym(musartc_dso_handle, #__name); \
+static void *p_##__name = dlsym(musartc_dso_handle, #__name); \
return reinterpret_cast<musartc_func>(p_##__name)(args...); \
} \
}; \
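The second hunk only reformats the cached-pointer line, but that line is the heart of DECLARE_DYNAMIC_LOAD_NVRTC_WRAP: open the DSO once under std::call_once, resolve each symbol once with dlsym, then forward the caller's arguments through the cached pointer. A minimal self-contained sketch of the same pattern, using glibc's libm.so.6 and its cos symbol purely as stand-ins (Paddle resolves mtrtc* symbols from the handle returned by GetNVRTCDsoHandle instead):

```cpp
// Sketch of the lazy dlopen/dlsym wrapper pattern; the library name, symbol,
// and macro name here are illustrative stand-ins, not Paddle code.
// Build: g++ demo.cc -ldl
#include <dlfcn.h>

#include <mutex>

extern "C" double cos(double);  // stand-in symbol so decltype(&::cos) exists

static std::once_flag dso_flag;
static void *dso_handle = nullptr;

#define DECLARE_DYNAMIC_WRAP(__name)                         \
  struct Dynload_##__name {                                  \
    template <typename... Args>                              \
    auto operator()(Args... args) {                          \
      using func_t = decltype(&::__name);                    \
      /* open the shared library exactly once */             \
      std::call_once(dso_flag, [] {                          \
        dso_handle = dlopen("libm.so.6", RTLD_LAZY);         \
      });                                                    \
      /* resolve the symbol once, then reuse the pointer */  \
      static void *p_func = dlsym(dso_handle, #__name);      \
      return reinterpret_cast<func_t>(p_func)(args...);      \
    }                                                        \
  };                                                         \
  static Dynload_##__name dyn_##__name

DECLARE_DYNAMIC_WRAP(cos);

int main() {
  // First call triggers dlopen + dlsym; later calls reuse the cache.
  return dyn_cos(0.0) == 1.0 ? 0 : 1;
}
```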
1 change: 0 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
@@ -60,7 +60,6 @@ if(WITH_MUSA)
    "gpu/depthwise_conv_grad_kernel.cu"
    "gpu/depthwise_conv_kernel.cu"
    "gpu/dist_kernel.cu"
-   "gpu/elementwise_divide_grad_kernel.cu"
"gpu/elementwise_grad_kernel.cu"
"gpu/erfinv_kernel.cu"
"gpu/exponential_kernel.cu"
8 changes: 8 additions & 0 deletions paddle/phi/kernels/batch_norm_kernel.cc
@@ -105,6 +105,14 @@ PD_REGISTER_KERNEL(batch_norm_infer,
float,
phi::dtype::float16) {}
#endif
+#ifdef PADDLE_WITH_MUSA
+PD_REGISTER_KERNEL(batch_norm_infer,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BatchNormInferKernel,
+                   float,
+                   phi::dtype::float16) {}
+#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(batch_norm_infer,
XPU,
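For reference, PD_REGISTER_KERNEL blocks like the one added here rely on static registration: a file-scope object whose constructor records the kernel in a global registry before main() runs, with one #ifdef-guarded block per backend so each build registers exactly one implementation. A toy sketch of that mechanism, assuming a hypothetical Registry/Registrar pair (Paddle's real macro also encodes backend, layout, and dtype metadata):

```cpp
// Toy static-registration sketch; Registry and Registrar are hypothetical
// stand-ins for what PD_REGISTER_KERNEL generates, not Paddle internals.
#include <functional>
#include <iostream>
#include <map>
#include <string>

static std::map<std::string, std::function<void()>> &Registry() {
  static std::map<std::string, std::function<void()>> registry;
  return registry;
}

struct Registrar {
  Registrar(const std::string &key, std::function<void()> kernel) {
    Registry()[key] = std::move(kernel);  // runs during static initialization
  }
};

// One guarded block per backend, as in batch_norm_kernel.cc: a MUSA build
// compiles exactly this registration and skips the CUDA/XPU variants.
#ifdef PADDLE_WITH_MUSA
static Registrar musa_batch_norm("batch_norm_infer/GPU/ALL_LAYOUT", [] {
  std::cout << "running MUSA batch_norm_infer\n";
});
#endif

int main() {
  for (const auto &entry : Registry()) entry.second();  // dispatch everything
}
```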
(Diffs for the remaining changed files are not shown.)
