Merge pull request #59 from MaJun-cn/async_norm
add norm async update in BoxPSAsynDenseTable
qingshui authored Oct 27, 2022
2 parents c26eeb8 + fe6f5d8 commit f0e0944
Showing 7 changed files with 120 additions and 41 deletions.
11 changes: 10 additions & 1 deletion paddle/fluid/framework/boxps_trainer.cc
@@ -168,6 +168,14 @@ void BoxPSTrainer::InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) {
PADDLE_ENFORCE(root_scope_, "Null root_scope pointer");
for (auto& var : main_program.Block(0).AllVars()) {
if (async_mode_) {
std::string cur_var_name = var->Name();
size_t tag_pos = cur_var_name.find("@GRAD");
if (tag_pos != std::string::npos && tag_pos == cur_var_name.size() - 5) {
VLOG(3) << "BoxPSTrainer async_grad_name_ insert : " << cur_var_name;
async_grad_name_.insert(cur_var_name);
}
}
if (var->Persistable()) {
persistable_vars_.push_back(var->Name());
}
@@ -176,7 +184,8 @@ void BoxPSTrainer::InitTrainerEnv(const ProgramDesc& main_program,
std::set<std::string> async_param_name;
if (async_mode_) {
async_param_name = dense_table_->Init(*root_scope_, *param_need_sync_.get(),
persistable_vars_);
persistable_vars_,
async_grad_name_);
}
for (int i = 0; i < thread_num_; ++i) {
auto this_worker =
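The check added above keeps a variable only when its name ends with the literal "@GRAD" suffix (tag_pos == cur_var_name.size() - 5), not when it merely contains the tag. A minimal Python sketch of the same filter; the helper name and the example variable names are hypothetical, not part of the patch:

GRAD_SUFFIX = "@GRAD"

def collect_async_grad_names(var_names):
    # keep only names that end with "@GRAD"; names that merely contain the
    # tag somewhere in the middle (e.g. renamed temporaries) are skipped
    return {name for name in var_names if name.endswith(GRAD_SUFFIX)}

# collect_async_grad_names(["fc_0.w_0", "fc_0.w_0@GRAD"]) -> {"fc_0.w_0@GRAD"}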
106 changes: 83 additions & 23 deletions paddle/fluid/framework/boxps_worker.cc
@@ -54,35 +54,59 @@ BoxPSAsynDenseTable::~BoxPSAsynDenseTable() {}

std::set<std::string> BoxPSAsynDenseTable::Init(
const Scope& root_scope, const std::vector<std::string>& param_need_sync,
const std::vector<std::string>& persistable_vars) {
const std::vector<std::string>& persistable_vars,
const std::set<std::string>& async_grad_name) {
std::set<std::string> async_param_name;
root_scope_ = const_cast<paddle::framework::Scope*>(&root_scope);
VLOG(1) << "Begin Init For Aysnc Optimize";
for (const auto& e : param_need_sync) {
std::string grad_name = e + "@GRAD";
if (async_grad_name.find(grad_name) == async_grad_name.end()) continue;
if (e.find("param") != std::string::npos &&
e.find("pow_acc") == std::string::npos) {
VLOG(3) << "async mode choose " << e << " to update";
VLOG(3) << "async mode choose adam param " << e << " to update";
async_param_list_.push_back(e);
async_param_list_.push_back(e + "_moment1_0");
async_param_list_.push_back(e + "_moment2_0");
async_param_name.insert(e);
async_param_name.insert(e + "@GRAD");
}
if (e.find("summary") != std::string::npos &&
e.find("batch_s") != std::string::npos) {
VLOG(3) << "async mode choose norm param " << e << " to update";
async_norm_param_list_.push_back(e);
async_param_name.insert(e);
async_param_name.insert(e + "@GRAD");
}
}
original_ps_.resize(async_param_list_.size());
// adam param
const size_t adam_param_list_size = async_param_list_.size();
std::sort(
async_param_list_.begin(),
async_param_list_
.end()); // xx_param.b_0, xx_param_moment1_0, xx_param_moment2_0
for (size_t i = 0; i < async_param_list_.size(); i += 3) {
const LoDTensor& root_tensor =
root_scope.FindVar(async_param_list_[i])->Get<LoDTensor>();
adam_param_len_ += root_tensor.numel();
}
// norm param
std::sort(
async_norm_param_list_.begin(),
async_norm_param_list_
.end()); // xxsummary.batch_size, xxsummary.batch_square_sum, xxsummary.batch_sum
for (size_t i = 0; i < async_norm_param_list_.size(); i += 1) {
const LoDTensor& root_tensor =
root_scope.FindVar(async_norm_param_list_[i])->Get<LoDTensor>();
total_param_len_ += root_tensor.numel();
async_param_list_.push_back(async_norm_param_list_[i]);
}
total_param_len_ += adam_param_len_;
original_ps_.resize(async_param_list_.size());

ps_.mutable_data<float>({total_param_len_, 1}, platform::CPUPlace());
mom1_.mutable_data<float>({total_param_len_, 1}, platform::CPUPlace());
mom2_.mutable_data<float>({total_param_len_, 1}, platform::CPUPlace());
mom1_.mutable_data<float>({adam_param_len_, 1}, platform::CPUPlace());
mom2_.mutable_data<float>({adam_param_len_, 1}, platform::CPUPlace());
for (size_t i = 0; i < device_grads_.size(); ++i) {
device_grads_[i].mutable_data<float>(
{static_cast<int64_t>(total_param_len_), 1}, platform::CPUPlace());
@@ -95,23 +119,32 @@ std::set<std::string> BoxPSAsynDenseTable::Init(
root_scope.FindVar(async_param_list_[i])->Get<LoDTensor>();
auto dim = root_tensor.dims();
size_t len = root_tensor.numel();
if (i % 3 == 0) {
original_ps_[i]
.ShareDataWith(ps_.Slice(offset, offset + len))
.Resize(dim);
} else if (i % 3 == 1) {
original_ps_[i]
.ShareDataWith(mom1_.Slice(offset, offset + len))
.Resize(dim);
if (i < adam_param_list_size) {
if (i % 3 == 0) {
original_ps_[i]
.ShareDataWith(ps_.Slice(offset, offset + len))
.Resize(dim);
} else if (i % 3 == 1) {
original_ps_[i]
.ShareDataWith(mom1_.Slice(offset, offset + len))
.Resize(dim);
} else {
original_ps_[i]
.ShareDataWith(mom2_.Slice(offset, offset + len))
.Resize(dim);
offset += len;
}
} else {
VLOG(3) << "original_ps_ ShareDataWith norml name:" << async_param_list_[i] << " , i:" << i << " offset:" << offset;
original_ps_[i]
.ShareDataWith(mom2_.Slice(offset, offset + len))
.ShareDataWith(ps_.Slice(offset, offset + len))
.Resize(dim);
offset += len;
}
TensorCopy(*static_cast<const Tensor*>(&root_tensor), platform::CPUPlace(),
static_cast<Tensor*>(&(original_ps_[i])));
}
VLOG(3) << "after original_ps_ ShareDataWith offset:" << offset;

// Copy global lr for async mode
for (const auto& e : persistable_vars) {
@@ -132,14 +165,17 @@ std::set<std::string> BoxPSAsynDenseTable::Init(
}
}
VLOG(0) << "Aysnc alloc dense table param size: " << async_param_list_.size()
<< ", total length:" << total_param_len_ << ", base_lr=" << base_lr_;
<< ", adam param size: " << adam_param_list_size
<< ", total length:" << total_param_len_
<< ", adam length: " << adam_param_len_
<< ", base_lr=" << base_lr_;

ps_buffer_.reset(new PSBufferQueue(device_num_ * 3)); // magic number
all_lr_.resize(total_param_len_);
all_lr_.resize(adam_param_len_);
auto box_ptr = BoxWrapper::GetInstance();
std::map<std::string, float> lr_map = box_ptr->GetLRMap();
int lr_index = 0;
for (size_t i = 0; i < async_param_list_.size() / 3; ++i) {
for (size_t i = 0; i < adam_param_list_size / 3; ++i) {
float learning_rate = base_lr_;
if (lr_map.find(async_param_list_[i * 3]) != lr_map.end()) {
learning_rate = lr_map[async_param_list_[i * 3]];
@@ -219,12 +255,16 @@ void BoxPSAsynDenseTable::ThreadUpdate(int thread_id,
4;
}
}

VLOG(3) << "ThreadUpdate[" << thread_id << "] start: " << start << ", end: " << end << ", adam_param_len_: " << (size_t)adam_param_len_;
for (size_t j = start; j < end; ++j) {
mom1_data[j] =
0.99 * mom1_data[j] + 0.01 * grad_data[j]; // magic beta and episilon
mom2_data[j] = 0.9999 * mom2_data[j] + 0.0001 * grad_data[j] * grad_data[j];
param_data[j] -= all_lr_[j] * (mom1_data[j] / (sqrt(mom2_data[j]) + 1e-8));
if (j < (size_t)adam_param_len_) {  // adam update
mom1_data[j] =
0.99 * mom1_data[j] + 0.01 * grad_data[j];  // magic beta and epsilon
mom2_data[j] = 0.9999 * mom2_data[j] + 0.0001 * grad_data[j] * grad_data[j];
param_data[j] -= all_lr_[j] * (mom1_data[j] / (sqrt(mom2_data[j]) + 1e-8));
} else {  // norm (summary) update
param_data[j] = param_data[j] * 0.9999999 + grad_data[j];
}
}
return;
}
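For reference, the per-element update that ThreadUpdate applies in the loop above can be written out explicitly (the constants are the hard-coded values from the code, g is the gradient element, lr the per-parameter learning rate from all_lr_). For an element inside the Adam block (j < adam_param_len_):

\[
\begin{aligned}
m_1 &\leftarrow 0.99\, m_1 + 0.01\, g, \\
m_2 &\leftarrow 0.9999\, m_2 + 0.0001\, g^2, \\
\theta &\leftarrow \theta - \mathrm{lr}\,\frac{m_1}{\sqrt{m_2} + 10^{-8}},
\end{aligned}
\]

and for a norm (summary) element appended after the Adam block:

\[
\theta \leftarrow 0.9999999\,\theta + g.
\]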
@@ -443,15 +483,28 @@ void BoxPSWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
std::vector<VarDesc*> sorted_var = block.AllVars();
std::sort(sorted_var.begin(), sorted_var.end(),
[](const VarDesc* var1, const VarDesc* var2) {
return var1->Name() < var2->Name();
std::string var1_name = var1->Name();
std::string var2_name = var2->Name();
if (var1_name.find("param") != std::string::npos &&
var2_name.find("param") == std::string::npos) {
return true;
} else if (var1_name.find("param") == std::string::npos &&
var2_name.find("param") != std::string::npos) {
return false;
} else {
return var1->Name() < var2->Name();
}
});
// init var and copy persistable
int grad_var_num = 0;
int var_num = 0;
for (auto& var : sorted_var) {
std::string name = var->Name();
if (!var->Persistable()) {
if (dense_table_ &&
async_param_name_.find(name) != async_param_name_.end()) {
// param@GRAD is not present in root_scope_, so use the param tensor's length instead
VLOG(3) << "device[" << device_id_ << "] grad var name " << name;
const LoDTensor& root_tensor =
root_scope_->FindVar(name.substr(0, name.length() - 5))
->Get<LoDTensor>();
@@ -463,6 +516,7 @@ void BoxPSWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
->ShareDataWith(grad_async_.Slice(grad_offset, grad_offset + len))
.Resize(dim);
grad_offset += len;
grad_var_num += 1;
} else {
auto* ptr = thread_scope_->Var(name);
InitializeVariable(ptr, var->GetType());
@@ -481,11 +535,13 @@ void BoxPSWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
}
} else if (dense_table_) {
if (async_param_name_.find(name) != async_param_name_.end()) {
VLOG(3) << "device[" << device_id_ << "] Persistable var name " << name;
auto dim = root_tensor.dims();
size_t len = root_tensor.numel();
gpu_tensor->ShareDataWith(param_async_.Slice(offset, offset + len))
.Resize(dim);
offset += len;
var_num += 1;
}
}
TensorCopy(*static_cast<const Tensor*>(&root_tensor), place_,
@@ -495,6 +551,10 @@ void BoxPSWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
if (sync_mode_ > 0) {
CHECK(offset <= (param_sync_.numel() - pad_len));
} else if (dense_table_) {
VLOG(3) << "device[" << device_id_ << "]CreateDeviceResource param_async_ offset:" << offset
<< " grad_offset: " << grad_offset
<< " var_num: " << var_num
<< " grad_var_num: " << grad_var_num;
CHECK(offset <= param_async_.numel());
CHECK(grad_offset <= grad_async_.numel());
}
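Putting the BoxPSAsynDenseTable::Init() changes above together: every Adam-updated parameter contributes a (param, moment1, moment2) triple that shares one offset range across ps_, mom1_ and mom2_, and the norm (summary) parameters are then appended to ps_ starting at adam_param_len_, so that total_param_len_ equals adam_param_len_ plus the summed length of the norm params. A small Python sketch of that flat layout; the helper name and arguments are hypothetical, not part of the patch:

def flat_layout(adam_param_list, norm_param_list, numel_of):
    # Map each variable name to (buffer, start, numel) in the flat CPU tensors.
    layout, offset = {}, 0
    for i, name in enumerate(sorted(adam_param_list)):
        # sorted so names line up as (xx_param, xx_param_moment1_0, xx_param_moment2_0) triples
        buf = ("ps_", "mom1_", "mom2_")[i % 3]
        layout[name] = (buf, offset, numel_of(name))
        if i % 3 == 2:
            offset += numel_of(name)  # advance once per triple
    adam_param_len = offset
    for name in sorted(norm_param_list):
        # summary params live only in ps_, after the Adam block
        layout[name] = ("ps_", offset, numel_of(name))
        offset += numel_of(name)
    return layout, adam_param_len, offset  # offset == total_param_len

The param-first comparator added to BoxPSWorker::CreateDeviceResource() above appears intended to keep the per-device param_async_/grad_async_ slices consistent with this packing (Adam-updated params first, then the summary params).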
5 changes: 4 additions & 1 deletion paddle/fluid/framework/device_worker.h
@@ -812,7 +812,8 @@ class BoxPSAsynDenseTable {

std::set<std::string> Init(const Scope& root_scope,
const std::vector<std::string>& param_need_sync,
const std::vector<std::string>& persistable_vars);
const std::vector<std::string>& persistable_vars,
const std::set<std::string>& async_grad_name);
void Finalize(void);
void PullDense(const platform::Place& place, Tensor* tensor);
void PushDense(const platform::Place& place, Tensor* tensor);
@@ -826,6 +827,7 @@ class BoxPSAsynDenseTable {
int device_num_ = 0;
std::vector<LoDTensor> device_grads_;
std::vector<std::string> async_param_list_;
std::vector<std::string> async_norm_param_list_;
std::vector<LoDTensor> original_ps_;
LoDTensor ps_;
LoDTensor mom1_;
@@ -836,6 +838,7 @@ class BoxPSAsynDenseTable {
std::shared_ptr<PSBufferQueue> ps_buffer_ = nullptr;
Scope* root_scope_ = nullptr;
int64_t total_param_len_ = 0;
int64_t adam_param_len_ = 0;
std::vector<size_t> thread_start_index_;
std::vector<size_t> thread_end_index_;
std::shared_ptr<paddle::framework::ThreadPool> thread_pool = nullptr;
1 change: 1 addition & 0 deletions paddle/fluid/framework/trainer.h
@@ -415,6 +415,7 @@ class BoxPSTrainer : public TrainerBase {

std::shared_ptr<std::vector<std::string>> param_need_sync_;
std::vector<std::string> persistable_vars_;
std::set<std::string> async_grad_name_;

bool async_mode_ = false;
std::shared_ptr<BoxPSAsynDenseTable> dense_table_ = nullptr;
2 changes: 2 additions & 0 deletions paddle/fluid/operators/data_norm_op.cc
@@ -251,6 +251,8 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
AddAttr<bool>("sync_stats", "(bool, default false) only used in multi-GPU")
.SetDefault(false);
AddAttr<bool>("update_norm", "(bool, default true) used in update_norm")
.SetDefault(true);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
34 changes: 18 additions & 16 deletions paddle/fluid/operators/data_norm_op.cu
@@ -154,6 +154,7 @@ class DataNormGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
const float epsilon = ctx.Attr<float>("epsilon");
const float dr = ctx.Attr<float>("summary_decay_rate");
const bool need_sync_stats = ctx.Attr<bool>("sync_stats");
const bool update_norm = ctx.Attr<bool>("update_norm");

const auto &x_dims = x->dims();
// Align with CPU version, but should we add this restriction?
@@ -233,22 +234,23 @@ class DataNormGradKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
"supported on windows now."));
#endif
}

T *batch_size_data =
ctx.Output<Tensor>("BatchSize")->mutable_data<T>(ctx.GetPlace());
T *batch_sum_data =
ctx.Output<Tensor>("BatchSum")->mutable_data<T>(ctx.GetPlace());
T *batch_square_sum_data =
ctx.Output<Tensor>("BatchSquareSum")->mutable_data<T>(ctx.GetPlace());
KernelUpdateParam<<<GET_BLOCKS(C), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
C,
d_batch_size,
d_batch_sum,
d_batch_square_sum,
batch_size_data,
batch_sum_data,
batch_square_sum_data,
dr);
if (update_norm) {
T *batch_size_data =
ctx.Output<Tensor>("BatchSize")->mutable_data<T>(ctx.GetPlace());
T *batch_sum_data =
ctx.Output<Tensor>("BatchSum")->mutable_data<T>(ctx.GetPlace());
T *batch_square_sum_data =
ctx.Output<Tensor>("BatchSquareSum")->mutable_data<T>(ctx.GetPlace());
KernelUpdateParam<<<GET_BLOCKS(C), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
C,
d_batch_size,
d_batch_sum,
d_batch_square_sum,
batch_size_data,
batch_sum_data,
batch_square_sum_data,
dr);
}  // if update_norm is false, the norm summary is updated by BoxPSAsynDenseTable instead
}
};
} // namespace operators
2 changes: 2 additions & 0 deletions python/paddle/fluid/layers/nn.py
@@ -3497,6 +3497,7 @@ def data_norm(input,
do_model_average_for_mean_and_var=True,
slot_dim=-1,
sync_stats=False,
update_norm=True,
summary_decay_rate=0.9999999,
enable_scale_and_shift=False):
r"""
@@ -3645,6 +3646,7 @@ def data_norm(input,
"epsilon": epsilon,
"data_layout": data_layout,
"sync_stats": sync_stats,
"update_norm": update_norm,
"summary_decay_rate": summary_decay_rate,
}
if slot_dim > 0:
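A minimal usage sketch of the new flag in the static-graph fluid API (the input shape is illustrative): with update_norm=False the data_norm grad kernel leaves BatchSize/BatchSum/BatchSquareSum untouched, on the assumption that BoxPSAsynDenseTable applies the norm update asynchronously on the CPU side instead.

import paddle.fluid as fluid

x = fluid.data(name="x", shape=[-1, 32], dtype="float32")
out = fluid.layers.data_norm(
    input=x,
    sync_stats=True,     # aggregate the batch stats across GPUs as before
    update_norm=False)   # skip KernelUpdateParam; the async dense table owns the summary update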
