Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Sep 18, 2024
1 parent ddd4f78 commit 24d0092
Show file tree
Hide file tree
Showing 18 changed files with 202 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ void PirDependencyBuilder::AddDependencyForCommunicationOp() {
// c_allreduce_sum(b)
// c_allreduce_sum(c)
// c_sync_comm_stream(a)
const std::string kSyncComm = dialect::CSyncCommStreamOp::name();
const std::string kSyncComm = dialect::SyncCommStreamOp::name();
dependence_op_idx = ULLONG_MAX;
for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) {
if (instructions_.at(op_idx)->Name() == kSyncComm) {
Expand Down
16 changes: 16 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3902,6 +3902,20 @@ struct WithXShapeAndAxisGradOpTranscriber : public OpTranscriber {
}
};

// Transcriber that maps the legacy static-graph op `c_sync_comm_stream`
// onto the PIR op `pd_op.sync_comm_stream_` during program translation.
// Only the op-info lookup is customized; argument/attribute translation
// is inherited from OpTranscriber.
struct SyncCommStreamOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    const std::string kTargetOpName = "pd_op.sync_comm_stream_";
    pir::OpInfo target_info = ctx->GetRegisteredOpInfo(kTargetOpName);
    // The target op must be registered by the operator dialect; failing
    // here means the dialect was not loaded before translation started.
    if (!target_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op c_sync_comm_stream should have corresponding "
          "OpInfo pd_op.sync_comm_stream_."));
    }
    return target_info;
  }
};

OpTranslator::OpTranslator() {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
Expand Down Expand Up @@ -4012,6 +4026,8 @@ OpTranslator::OpTranslator() {
WithXShapeAndAxisGradOpTranscriber<dialect::SqueezeGradOp>();
special_handlers["unsqueeze2_grad"] =
WithXShapeAndAxisGradOpTranscriber<dialect::UnsqueezeGradOp>();

special_handlers["c_sync_comm_stream"] = SyncCommStreamOpTranscriber();
}
} // namespace translator
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_sync_comm_stream_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
Expand Down
19 changes: 0 additions & 19 deletions paddle/fluid/operators/collective/c_sync_comm_stream_op.cu.cc

This file was deleted.

113 changes: 0 additions & 113 deletions paddle/fluid/operators/collective/c_sync_comm_stream_op.h

This file was deleted.

43 changes: 0 additions & 43 deletions paddle/fluid/operators/collective/c_sync_comm_stream_op.kps

This file was deleted.

28 changes: 0 additions & 28 deletions paddle/fluid/operators/collective/c_sync_comm_stream_op_xpu.cc

This file was deleted.

27 changes: 27 additions & 0 deletions paddle/fluid/operators/ops_signature/c_sync_comm_stream_sig.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

// Maps the legacy fluid op `c_sync_comm_stream` to the phi kernel
// "sync_comm_stream": input list "X", attribute "ring_id", output list
// "Out". The mapping context is unused because the signature is fixed.
KernelSignature CSyncCommStreamOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("sync_comm_stream", {"X"}, {"ring_id"}, {"Out"});
}

}  // namespace phi

// Register the mapping under the legacy op name so existing programs that
// still contain c_sync_comm_stream resolve to the new kernel.
PD_REGISTER_ARG_MAPPING_FN(c_sync_comm_stream,
                           phi::CSyncCommStreamOpArgumentMapping);
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@
'shadow_feed',
'shadow_feed_tensors',
'sparse_momentum',
'sync_comm_stream',
'sync_comm_stream_',
'soft_relu',
'match_matrix_tensor',
'c_reduce_max',
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ const std::unordered_set<std::string> LegacyOpList = {
CConcatOp::name(),
CBroadcast_Op::name(),
CBroadcastOp::name(),
CSyncCommStream_Op::name(),
DistributedPushSparseOp::name(),
SendV2Op::name(),
RecvV2Op::name(),
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1351,9 +1351,9 @@ phi::KernelKey GetKernelKey(
VLOG(8) << "LoadCombineOp's kernel data type must be FLOAT32";
}

if (op->isa<CSyncCommStream_Op>() || op->isa<CSyncCommStreamOp>()) {
if (op->isa<SyncCommStream_Op>() || op->isa<SyncCommStreamOp>()) {
res.set_dtype(phi::DataType::FLOAT32);
VLOG(8) << "CSyncCommStream_Op/CSyncCommStreamOp's kernel data type must "
VLOG(8) << "SyncCommStream_Op/SyncCommStreamOp's kernel data type must "
"be FLOAT32";
}

Expand Down
18 changes: 18 additions & 0 deletions paddle/phi/kernels/gpu/sync_comm_stream_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/impl/sync_comm_stream_kernel.h"

// Register the GPU kernel for sync_comm_stream. Only float is registered:
// the kernel ignores its tensor arguments and just synchronizes the stream,
// so the dtype is nominal — TODO confirm the kernel key is always pinned to
// FLOAT32 by the pd_op-to-kernel lowering.
PD_REGISTER_KERNEL(
    sync_comm_stream, GPU, ALL_LAYOUT, phi::SyncCommStreamKernel, float) {}
53 changes: 53 additions & 0 deletions paddle/phi/kernels/impl/sync_comm_stream_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/phi/backends/xpu/xpu_info.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/kernel_registry.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#elif defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/phi/core/distributed/bkcl_comm_context.h"
#endif

namespace phi {

#if defined(PADDLE_WITH_XPU_BKCL)
// Blocks the host until all work queued on `stream` has completed
// (XPU builds only), reporting failure through the XDNN error path.
static void XPUStreamSync(XPUStream stream) {
  PADDLE_ENFORCE_XDNN_SUCCESS(xpu_wait(stream), "xpu_wait");
}
#endif

// Synchronizes the device context's stream with the host. All tensor
// arguments (x, out) and ring_id are UNUSED: the op exists purely for its
// ordering side effect — presumably to make previously enqueued
// communication work observable before later host-side logic runs
// (TODO confirm against callers of c_sync_comm_stream).
template <typename T, typename Context>
void SyncCommStreamKernel(const Context &dev_ctx,
                          const std::vector<const DenseTensor *> &x UNUSED,
                          int ring_id UNUSED,
                          std::vector<DenseTensor *> out UNUSED) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  // CUDA/ROCm build: block the host until the GPU stream drains.
  phi::backends::gpu::GpuStreamSync(dev_ctx.stream());
#elif defined(PADDLE_WITH_XPU_BKCL)
  // XPU build: same semantics via xpu_wait.
  XPUStreamSync(dev_ctx.stream());
#else
  // CPU-only build: there is no device stream to synchronize.
  PADDLE_THROW(common::errors::PreconditionNotMet(
      "PaddlePaddle should compile with GPU or XPU."));
#endif
}

}  // namespace phi
Loading

0 comments on commit 24d0092

Please sign in to comment.