Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#77 from tiancaitzp/paddlebox
Browse files Browse the repository at this point in the history
Remove redundant device context creation in the BKCL init flow.
  • Loading branch information
tiancaitzp authored Jun 5, 2024
2 parents 83e62b7 + 462ca00 commit 4acc69b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
32 changes: 16 additions & 16 deletions paddle/fluid/platform/collective_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,20 +285,25 @@ class BKCLCommImpl : public BKCLComm {
BKCLContext_t comm() const override { return comm_; }

XPUStream stream() const override {
return dev_ctx_->x_context()->xpu_stream;
return stream_;
}

void set_dev_ctx(std::unique_ptr<XPUDeviceContext>&& dev_ctx) {
dev_ctx_ = std::move(dev_ctx);
void set_dev_ctx(XPUDeviceContext* dev_ctx) {
dev_ctx_ = dev_ctx;
}

// Stores `stream` in stream_ for this communicator.
// NOTE(review): nothing visible in this class releases the stored stream;
// confirm which component is responsible for destroying it.
void set_stream(XPUStream stream) {
stream_ = stream;
}
XPUDeviceContext* dev_context() const override { return dev_ctx_.get(); }
XPUDeviceContext* dev_context() const override { return dev_ctx_; }

private:
int ring_id_;
int nranks_;
int rank_;
BKCLContext_t comm_;
std::unique_ptr<XPUDeviceContext> dev_ctx_;
XPUDeviceContext* dev_ctx_;
XPUStream stream_;
};

BKCLComm* BKCLCommContext::CreateComm(
Expand Down Expand Up @@ -408,21 +413,23 @@ void BKCLCommContext::CreateBKCLCommMultiTrainer(

BKCLComm* BKCLCommContext::AssignBKCLComm(
BKCLContext_t comm, int nranks, int rank, int dev_id, int ring_id) {
std::unique_ptr<XPUDeviceContext> dev_ctx(
new XPUDeviceContext(XPUPlace(dev_id)));

auto dev_ctx =
static_cast<platform::XPUDeviceContext*>(platform::DeviceContextPool::Instance().Get(platform::XPUPlace(dev_id)));
dev_ctx->SetBkclContext(comm);
// used in BKCL as comm_stream, for every dev_id there is
// a comm_stream at each ring. this stream is passed as input var
// when calling collective comm commands like bkcl_all_reduce
XPUStream comm_stream;
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_stream));
dev_ctx->SetXPUStream(comm_stream);

BKCLCommImpl* c = new BKCLCommImpl;
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_comm(comm);
c->set_dev_ctx(std::move(dev_ctx));
c->set_dev_ctx(dev_ctx);
c->set_stream(comm_stream);

comm_map_mutex_.lock();
if (comm_map_.count(ring_id) == 0) {
Expand All @@ -433,13 +440,6 @@ BKCLComm* BKCLCommContext::AssignBKCLComm(
dev2comm.emplace(dev_id, std::unique_ptr<BKCLComm>(c));
comm_map_mutex_.unlock();

if (ring_id == 0) {
auto* dev_ctx = static_cast<platform::XPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(
platform::XPUPlace(dev_id)));
dev_ctx->SetBkclContext(comm);
}

return comm_map_[ring_id][dev_id].get();
}

Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/platform/device/xpu/bkcl_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ struct BKCLContext {
: ctx_(new platform::XPUDeviceContext(XPUPlace(dev_id))),
comm_{nullptr} {}

// Builds a context for `place` around the XPUDeviceContext obtained from
// DeviceContextPool::Instance().Get(place) rather than allocating a new one;
// comm_ starts out as nullptr.
// NOTE(review): the int-based constructor initializes ctx_ from `new`. If ctx_
// is an owning smart pointer, it would also try to delete the context handed
// out by the pool here — confirm the member's type and intended ownership.
explicit BKCLContext(platform::Place place)
: ctx_(static_cast<platform::XPUDeviceContext*>(platform::DeviceContextPool::Instance().Get(place))),
comm_{nullptr} {}

BKCLContext_t comm() const { return comm_; }

int device_id() const { return ctx_->GetPlace().device; }
Expand Down Expand Up @@ -107,7 +111,7 @@ struct BKCLContextMap {
for (auto &p : places_) {
int dev_id = p.device;
order_.emplace_back(dev_id);
contexts_.emplace(dev_id, BKCLContext(dev_id));
contexts_.emplace(dev_id, BKCLContext(p));
}
PADDLE_ENFORCE_EQ(
order_.size(),
Expand Down

0 comments on commit 4acc69b

Please sign in to comment.