Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#96 from YeelandX/FIXCONTINUOUS3
Browse files Browse the repository at this point in the history
fix: rollback && check_continuous_memory for share_embedding series
  • Loading branch information
YeelandX authored Jul 9, 2024
2 parents 21c3691 + 37fa6de commit 8ae906a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions paddle/fluid/framework/fleet/box_wrapper_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -879,11 +879,13 @@ void BoxWrapper::PullSparseCaseXPU(const paddle::platform::Place& place,
}

if (is_xpu_continuous_memory_pull_ == -1) {
int check_expand_dim = expand_only ? expand_embed_dim : expand_embed_dim + hidden_size;
if (pull_info_.expand_size < 0) check_expand_dim = -1;
is_xpu_continuous_memory_pull_ = check_continuous_memory_pull(device_id,
values,
slot_lengths,
hidden_size,
expand_only ? pull_info_.expand_size : expand_embed_dim + hidden_size,
check_expand_dim,
total_length);
}
box_wrapper_kernel_->CopyForPull(place, xpu_keys, (float**)values.data(), total_values_xpu,
Expand Down Expand Up @@ -1400,11 +1402,13 @@ void BoxWrapper::PushSparseGradCaseXPU(const paddle::platform::Place& place,
total_length * sizeof(int));

if (is_xpu_continuous_memory_push_ == -1) {
int check_expand_dim = expand_only ? expand_embed_dim : expand_embed_dim + hidden_size;
if (pull_info_.expand_size < 0) check_expand_dim = -1;
is_xpu_continuous_memory_push_ = check_continuous_memory_push(device_id,
grad_values,
slot_lengths,
hidden_size,
expand_only ? pull_info_.expand_size : expand_embed_dim + hidden_size);
check_expand_dim);
}

box_wrapper_kernel_->CopyForPush(place, xpu_values, total_grad_values_xpu,
Expand Down

0 comments on commit 8ae906a

Please sign in to comment.