Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

common: delay ps initialization for better RDMA compatibility #91

Merged
merged 5 commits into from
Sep 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/ps-lite
Submodule ps-lite updated 3 files
+4 −4 Makefile
+6 −1 src/customer.cc
+53 −96 src/rdma_van.h
29 changes: 17 additions & 12 deletions byteps/common/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,6 @@ void BytePSGlobal::Init() {

_shm_obj = std::make_shared<BytePSSharedMemory>(); // share memory obj

if (IsDistributed() &&
_my_role ==
BytePSRole::LOCAL_ROOT) { // only the root need to do networking
// init low-level ps implementation
_ps = new ps::KVWorker<char>(0, 0);
ps::StartAsync(0, "byteps\0");
if (!ps::Postoffice::Get()->is_recovery()) {
ps::Postoffice::Get()->Barrier(
0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
}
}

// Set to associated GPU
CUDA_CALL(cudaSetDevice(_local_rank));

Expand Down Expand Up @@ -200,6 +188,23 @@ void BytePSGlobal::Init() {
return;
}

// Lazily creates the low-level ps-lite KVWorker on first use and returns it.
// Initialization happens only when running distributed AND this process is
// the local root (the only role that performs networking). Guarded by
// _init_mutex — reused here because BytePS itself must already be inited.
// Returns nullptr when no KVWorker is needed for this role.
ps::KVWorker<char>* BytePSGlobal::GetOrInitPS() {
  std::lock_guard<std::mutex> lock(_init_mutex);
  const bool needs_init =
      (_ps == nullptr) && IsDistributed() &&
      (_my_role == BytePSRole::LOCAL_ROOT);
  if (needs_init) {
    // Bring up the ps-lite worker lazily (delayed for RDMA compatibility).
    _ps = new ps::KVWorker<char>(0, 0);
    ps::StartAsync(0, "byteps\0");
    if (!ps::Postoffice::Get()->is_recovery()) {
      // Rendezvous with every node group so all peers are connected
      // before the first push/pull is issued.
      ps::Postoffice::Get()->Barrier(
          0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
    }
  }
  return _ps;
}

void BytePSGlobal::Start(const std::vector<LoopFunction>& func) {
// Start background threads
for (size_t i = 0; i < func.size(); i++) {
Expand Down
1 change: 1 addition & 0 deletions byteps/common/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class BytePSGlobal {
static BytePSScheduledQueue* GetScheduledQueue(QueueType queueType);
static void CreateScheduledQueue(QueueType queueType);
// Returns the KVWorker as-is; may be nullptr if not yet initialized.
static ps::KVWorker<char>* GetPS() { return _ps; }
// Like GetPS(), but lazily creates the KVWorker under _init_mutex on first
// call (local root in distributed mode only); others still get nullptr.
static ps::KVWorker<char>* GetOrInitPS();

static bool IsTensorDeclared(const std::string& name);
static ps::Key GetKeyFromName(const std::string& name);
Expand Down
6 changes: 3 additions & 3 deletions byteps/common/operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,15 @@ void InitTensor(BPSContext &context, size_t size, int dtype, void *cpubuff) {
int len = ((size - accumulated) > bound) ? bound : (size - accumulated);

if (BytePSGlobal::IsDistributed() && BytePSGlobal::IsRootDevice()) {
// lazily create the KVWorker on first use (root device only)
auto ps = BytePSGlobal::GetOrInitPS();
// encode the key for pskv scattering
auto &pskv = BytePSGlobal::EncodeDefaultKey(key, len);
// false means not to delete data when SArray is deleted
ps::SArray<char> vals(data + accumulated, len, false);
// cmd type
int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
// blocking push, also as a global barrier
ps->Wait(ps->ZPush(pskv.keys, vals, pskv.lens, cmd));
}

accumulated += len;
Expand Down