Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: use omp to speed up memcpy #152

Merged
merged 1 commit into from
Nov 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions byteps/common/cpu_reducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,19 @@ int CpuReducer::_sum_float16(void* dst, void* src1, void* src2, size_t len) {
return 0;
}

int CpuReducer::copy(void* dst, void* src, size_t len) {
auto in = (float*)src;
auto out = (float*)dst;
#pragma omp parallel for simd num_threads(_num_threads)
for (size_t i = 0; i < len / 4; ++i) {
out[i] = in[i];
}
if (len % 4) {
memcpy(out + len / 4, in + len / 4, len % 4);
}
return 0;
}


} // namespace common
} // namespace byteps
1 change: 1 addition & 0 deletions byteps/common/cpu_reducer.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class CpuReducer {

int sum(void* dst, void* src, size_t len, DataType dtype);
int sum(void* dst, void* src1, void* src2, size_t len, DataType dtype);
// Copies `len` bytes from `src` to `dst` using OpenMP threads; always returns 0.
int copy(void* dst, void* src, size_t len);

#ifndef BYTEPS_BUILDING_SERVER
bool isRoot();
Expand Down
10 changes: 5 additions & 5 deletions byteps/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void BytePSServerEngineThread(int i) {
<< "dst_addr: " << DEBUG_PRINT_TENSOR_ADDRESS(msg.dst) << "\t"
<< "src_addr: " << DEBUG_PRINT_TENSOR_ADDRESS(msg.src) << "\t";
}
memcpy(msg.dst, msg.src, msg.len);
bps_reducer_->copy(msg.dst, msg.src, msg.len);
if (is_debug) {
std::lock_guard<std::mutex> lock(debug_mu_);
LOG(INFO) << "stage: ENGINE_COPY_RECV_AFTER \t"
Expand All @@ -105,7 +105,7 @@ void BytePSServerEngineThread(int i) {
<< "dst_addr: " << DEBUG_PRINT_TENSOR_ADDRESS(msg.dst) << "\t"
<< "src_addr: " << DEBUG_PRINT_TENSOR_ADDRESS(msg.src) << "\t";
}
memcpy(msg.dst, msg.src, msg.len);
bps_reducer_->copy(msg.dst, msg.src, msg.len);
if (is_debug) {
std::lock_guard<std::mutex> lock(debug_mu_);
LOG(INFO) << "stage: ENGINE_COPY_MERGED_TO_STORE_AFTER \t"
Expand Down Expand Up @@ -207,7 +207,7 @@ void BytePSHandler(const ps::KVMeta& req_meta,
stored.tensor = (char*) malloc(len);
stored.len = len;
stored.dtype = type.dtype;
memcpy(stored.tensor, recved, len); // we may not need this copy
bps_reducer_->copy(stored.tensor, recved, len); // we may not need this copy
for (const auto& req : updates.request) {
SendPushResponse(key, req, server);
}
Expand All @@ -223,7 +223,7 @@ void BytePSHandler(const ps::KVMeta& req_meta,
if (updates.request.empty()) { // from the first incoming worker
if (sync_mode_) {
if (is_engine_blocking_) {
memcpy(updates.merged.tensor, recved, len);
bps_reducer_->copy(updates.merged.tensor, recved, len);
} else { // non-blocking
if (debug_mode_ && (debug_key_ == key)) {
std::lock_guard<std::mutex> lock(debug_mu_);
Expand Down Expand Up @@ -291,7 +291,7 @@ void BytePSHandler(const ps::KVMeta& req_meta,
auto& stored = store_[key];
auto& update = updates.merged;
if (is_engine_blocking_) {
memcpy(stored.tensor, updates.merged.tensor, len);
bps_reducer_->copy(stored.tensor, updates.merged.tensor, len);
} else {
if (debug_mode_ && (debug_key_ == key)) {
std::lock_guard<std::mutex> lock(debug_mu_);
Expand Down