Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update fluid device_context #44418

Merged
merged 1 commit into from
Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,8 @@ void AnalysisPredictor::InitPlace() {
place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (config_.thread_local_stream_enabled()) {
auto *ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
VLOG(3) << "The prediction process will be completed using a separate "
"normal-priority stream on each thread.";
ctx->ResetThreadContext(platform::stream::Priority::kNormal);
LOG_FIRST_N(WARNING, 1) << "We will remove this interface in the future. "
"Please use config.SetExecStream instead.";
}
#endif
} else if (config_.use_xpu()) {
Expand Down Expand Up @@ -1621,14 +1618,8 @@ bool AnalysisPredictor::ZeroCopyRun() {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
if (stream != nullptr) {
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
auto gpu_place = place_;
auto *dev_ctx = reinterpret_cast<paddle::platform::CUDADeviceContext *>(
pool.Get(gpu_place));
dev_ctx->SetThreadLocalStream(stream);
}
LOG_FIRST_N(WARNING, 1) << "We will remove this interface in the future. "
"Please use config.SetExecStream instead.";
return ZeroCopyRun();
}
#endif
Expand Down
185 changes: 0 additions & 185 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -534,198 +534,13 @@ void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}

thread_local std::unordered_map<const CUDADeviceContext*,
std::shared_ptr<CUDAContext>>
CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

void CUDAContext::InitEigenContext() {
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&RawStream(), place_);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDAContext::CUDAContext(const CUDAPlace& place,
const stream::Priority& priority,
const stream::StreamFlag& flag) {
place_ = place;
CUDADeviceGuard guard(place_.device);
stream_.reset(new stream::CUDAStream(place, priority, flag));
InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
InitCuSparseContext();
InitCuSolverContext();
#endif
}

void CUDAContext::SetStream(gpuStream_t stream) {
if (stream_->raw_stream() != stream) {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
DestoryCuBlasLtContext();
#endif
DestoryCuSolverContext();
#endif

stream_->SetStream(stream);

InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
InitCuSolverContext();
#endif
}
}

CUDAContext::~CUDAContext() {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
InitCuBlasLtContext();
#endif
DestoryCuSparseContext();
DestoryCuSolverContext();
#endif
}

CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
phi::GPUContext::PartialInitWithoutAllocator();
cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
}

CUDADeviceContext::~CUDADeviceContext() = default;

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
if (thread_ctx_.count(this)) {
return context()->EigenDevice().get();
}
return phi::GPUContext::eigen_device();
}

void CUDADeviceContext::Wait() const {
VLOG(4) << "CUDA context(" << this << ") Wait";
if (thread_ctx_.count(this)) {
context()->Stream()->Wait();
return;
}
phi::GPUContext::Wait();
}

#ifdef PADDLE_WITH_HIP
miopenHandle_t CUDADeviceContext::cudnn_handle() const {
#else
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
#endif
if (thread_ctx_.count(this)) {
return context()->CudnnHandle();
}
return phi::GPUContext::cudnn_handle();
}

#ifdef PADDLE_WITH_HIP
rocblas_handle CUDADeviceContext::cublas_handle() const {
if (thread_ctx_.count(this)) {
return context()->CublasHandle()->GetCublasHandle();
}
return phi::GPUContext::cublas_handle();
}
#else
cublasHandle_t CUDADeviceContext::cublas_handle() const {
if (thread_ctx_.count(this)) {
return context()->CublasHandle()->GetCublasHandle();
}
return phi::GPUContext::cublas_handle();
}
#if CUDA_VERSION >= 11060
cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
if (thread_ctx_.count(this)) {
return context()->CublasLtHandle()->GetCublasLtHandle();
}
return phi::GPUContext::cublaslt_handle();
}
#endif
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
if (thread_ctx_.count(this)) {
return context()->CusparseHandle()->GetCusparseHandle();
}
return phi::GPUContext::cusparse_handle();
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
if (thread_ctx_.count(this)) {
return context()->CusolverDnHandle();
}
return phi::GPUContext::cusolver_dn_handle();
}
#endif

void CUDADeviceContext::RecordEvent(
gpuEvent_t ev, const std::function<void()>& callback) const {
if (thread_ctx_.count(this)) {
context()->Stream()->RecordEvent(ev, callback);
return;
}
phi::GPUContext::RecordEvent(ev, callback);
}

void CUDADeviceContext::AddStreamCallback(
const std::function<void()>& callback) const {
if (thread_ctx_.count(this)) {
context()->Stream()->AddCallback(callback);
return;
}
phi::GPUContext::AddStreamCallback(callback);
}

void CUDADeviceContext::WaitStreamCallback() const {
if (thread_ctx_.count(this)) {
context()->Stream()->WaitCallback();
return;
}
phi::GPUContext::WaitStreamCallback();
}

phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
if (thread_ctx_.count(this)) {
// return workspace_.get();
return phi::DnnWorkspaceHandle(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(GetPlace())
.get(),
stream());
}
return phi::GPUContext::cudnn_workspace_handle();
}

gpuStream_t CUDADeviceContext::stream() const {
if (thread_ctx_.count(this)) {
return context()->RawStream();
}
return phi::GPUContext::stream();
}

std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
if (!thread_ctx_.count(this)) {
PADDLE_THROW(platform::errors::PermissionDenied(
"CUDADeviceContext call context() failed, make sure in the "
"thread_local semantic."));
}
return thread_ctx_.at(this);
}

stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
return cuda_stream_.get();
}
Expand Down
Loading