Skip to content

Commit

Permalink
【Error Message BUAA No. 32】 channel.h,archive.h,fleet/heter_ps/heter_…
Browse files Browse the repository at this point in the history
…comm_inl.h (#66719)

* change CHECK to PADDLE_ENFORCE_xx in archive.h+channel.h+heter_comm_inl.h

* add #include 'paddle/fluid/platform/enforce.h' in channel.h

* change 'should not be' to 'should be'

* 析构函数里不能抛出异常,不用修改CHECK (destructors must not throw exceptions, so the CHECK calls inside destructors are left unchanged)
  • Loading branch information
tlxd authored Aug 1, 2024
1 parent fc53458 commit 5ea2b18
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 24 deletions.
85 changes: 74 additions & 11 deletions paddle/fluid/framework/archive.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,14 @@ class ArchiveBase {
size_t length,
size_t capacity,
std::function<void(char*)>&& deleter) {
CHECK(length <= capacity);
PADDLE_ENFORCE_LE(
length,
capacity,
phi::errors::InvalidArgument(
"Param length should be less than or equal to param capacity, but "
"the length got %d, the capacity got %d.",
length,
capacity));
FreeBuffer();
buffer_ = buffer;
cursor_ = buffer_;
Expand All @@ -119,24 +126,54 @@ class ArchiveBase {
// Returns the current value of cursor_.
char* Cursor() { return cursor_; }

// Sets cursor_ to `cursor`.
//
// `cursor` must satisfy buffer_ <= cursor <= finish_. When it does not,
// phi::errors::InvalidArgument is reported through PADDLE_ENFORCE_EQ and
// cursor_ is not modified.
void SetCursor(char* cursor) {
  // The pointers are formatted with %p through const void* so that the
  // message shows addresses and the formatter never interprets a char* as a
  // NUL-terminated string.
  PADDLE_ENFORCE_EQ(
      cursor >= buffer_ && cursor <= finish_,
      true,
      phi::errors::InvalidArgument(
          "Param cursor should be greater than or equal to buffer, and "
          "should be less than or equal to finish, but the cursor got %p, "
          "the buffer got %p, the finish got %p.",
          static_cast<const void*>(cursor),
          static_cast<const void*>(buffer_),
          static_cast<const void*>(finish_)));
  cursor_ = cursor;
}

// Advances cursor_ by `offset` bytes.
//
// `offset` must not exceed the distance from cursor_ to finish_. When it
// does, phi::errors::InvalidArgument is reported through PADDLE_ENFORCE_LE
// and cursor_ is not moved.
void AdvanceCursor(size_t offset) {
  // Computed once so the check and the error message use the same value.
  const size_t remaining = static_cast<size_t>(finish_ - cursor_);
  PADDLE_ENFORCE_LE(
      offset,
      remaining,
      phi::errors::InvalidArgument(
          "Param offset should be less than or equal to %d, but got %d.",
          remaining,
          offset));
  cursor_ += offset;
}

// Returns the current value of finish_.
char* Finish() { return finish_; }

// Sets finish_ to `finish`.
//
// `finish` must satisfy cursor_ <= finish <= limit_. When it does not,
// phi::errors::InvalidArgument is reported through PADDLE_ENFORCE_EQ and
// finish_ is not modified.
void SetFinish(char* finish) {
  // The pointers are formatted with %p through const void* so that the
  // message shows addresses and the formatter never interprets a char* as a
  // NUL-terminated string.
  PADDLE_ENFORCE_EQ(
      finish >= cursor_ && finish <= limit_,
      true,
      phi::errors::InvalidArgument(
          "Param finish should be greater than or equal to cursor, and "
          "should be less than or equal to limit, but the finish got %p, "
          "the cursor got %p, the limit got %p.",
          static_cast<const void*>(finish),
          static_cast<const void*>(cursor_),
          static_cast<const void*>(limit_)));
  finish_ = finish;
}

// Advances finish_ by `offset` bytes.
//
// `offset` must not exceed the distance from finish_ to limit_. When it
// does, phi::errors::InvalidArgument is reported through PADDLE_ENFORCE_LE
// and finish_ is not moved.
void AdvanceFinish(size_t offset) {
  // Computed once so the check and the error message use the same value.
  const size_t writable = static_cast<size_t>(limit_ - finish_);
  PADDLE_ENFORCE_LE(
      offset,
      writable,
      phi::errors::InvalidArgument(
          "Param offset should be less than or equal to %d, but got %d.",
          writable,
          offset));
  finish_ += offset;
}

Expand Down Expand Up @@ -188,7 +225,10 @@ class ArchiveBase {
if (newcap > Capacity()) {
char* newbuf = NULL;
newbuf = new char[newcap];
CHECK(newbuf != nullptr) << "Reserve failed, out of memory";
PADDLE_ENFORCE_NE(
newbuf,
nullptr,
phi::errors::InvalidArgument("Reserve failed, out of memory."));
if (Length() > 0) {
memcpy(newbuf, buffer_, Length());
}
Expand All @@ -207,7 +247,13 @@ class ArchiveBase {
#else
if (!(size <= size_t(finish_ - cursor_))) {
#endif
CHECK(size <= size_t(finish_ - cursor_));
PADDLE_ENFORCE_LE(
size,
size_t(finish_ - cursor_),
phi::errors::InvalidArgument(
"Param size should be less than or equal to %d, but got %d.",
size_t(finish_ - cursor_),
size));
}
}

Expand All @@ -231,7 +277,13 @@ class ArchiveBase {

void ReadBack(void* data, size_t size) {
if (size > 0) {
CHECK(size <= size_t(finish_ - cursor_));
PADDLE_ENFORCE_LE(
size,
size_t(finish_ - cursor_),
phi::errors::InvalidArgument(
"Param size should be less than or equal to %d, but got %d.",
size_t(finish_ - cursor_),
size));
memcpy(data, finish_ - size, size);
finish_ -= size;
}
Expand Down Expand Up @@ -326,11 +378,22 @@ class Archive<BinaryArchiveType> : public ArchiveBase {
// Formats `fmt` with `args` (snprintf semantics) into the space starting at
// Finish(), then advances the finish position by the formatted length. The
// terminating NUL written by snprintf is not counted in that advance.
//
// If the text does not fit in the space between Finish() and Limit(),
// PrepareWrite(len + 1) is called and the formatting is repeated exactly once.
// phi::errors::InvalidArgument is reported when snprintf returns a negative
// length, or when the second pass yields a different length than the first.
void Printf(const char* fmt, ARGS&&... args) {
  const size_t available = Limit() - Finish();
  // `args` is deliberately not forwarded: it may be consumed a second time
  // below.
  const int len = snprintf(Finish(), available, fmt, args...);
  PADDLE_ENFORCE_GE(
      len,
      0,
      phi::errors::InvalidArgument(
          "Param len should be greater than or equal to 0, but got %d.",
          len));  // NOLINT
  if (static_cast<size_t>(len) >= available) {
    // Output was truncated; make room for the text plus its trailing NUL.
    PrepareWrite(len + 1);
    // Run the retry once and keep its result, so the value that is checked
    // is the same value that is reported on failure.
    const int written =
        snprintf(Finish(), static_cast<size_t>(len) + 1, fmt, args...);
    PADDLE_ENFORCE_EQ(
        written,
        len,
        phi::errors::InvalidArgument(
            "The snprintf(Finish(), static_cast<size_t>(len) + 1, fmt, "
            "args...) should be equal to %d, but got %d.",
            len,
            written));
  }
  AdvanceFinish(len);
}
Expand Down
37 changes: 30 additions & 7 deletions paddle/fluid/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <utility>
#include <vector>

#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/expect.h"

namespace paddle {
Expand Down Expand Up @@ -67,7 +68,12 @@ class ChannelObject {
}

// Stores `x` in block_size_ while holding mutex_.
//
// `x` must be at least 1. When it is not, phi::errors::InvalidArgument is
// reported through PADDLE_ENFORCE_GE before the mutex is taken and
// block_size_ is left unchanged.
void SetBlockSize(size_t x) {
  // Compare against a size_t constant so both operands share one unsigned
  // type.
  PADDLE_ENFORCE_GE(
      x,
      static_cast<size_t>(1),
      phi::errors::InvalidArgument(
          "The block size must be greater than or equal to 1, but got %d.",
          x));
  std::lock_guard<std::mutex> lock(mutex_);
  block_size_ = x;
}
Expand Down Expand Up @@ -260,7 +266,13 @@ class ChannelObject {
std::unique_lock<std::mutex>& lock, // NOLINT
bool once = false) { // NOLINT
size_t finished = 0;
CHECK(n <= MaxCapacity() - reading_count_);
PADDLE_ENFORCE_LE(
n,
MaxCapacity() - reading_count_,
phi::errors::InvalidArgument(
"Param n should be less than or equal to %d, but got %d.",
MaxCapacity() - reading_count_,
n));
reading_count_ += n;
while (finished < n && WaitForRead(lock)) {
size_t m = (std::min)(n - finished, data_.size());
Expand Down Expand Up @@ -316,7 +328,10 @@ Channel<T> MakeChannel(size_t capacity = (std::numeric_limits<size_t>::max)()) {

template <class T, class U>
Channel<T> MakeChannel(const Channel<U>& other) {
CHECK(other != nullptr) << "channel can not be NULL";
PADDLE_ENFORCE_NE(
other,
nullptr,
phi::errors::InvalidArgument("The channel can not be NULL!"));
Channel<T> chan = std::make_shared<ChannelObject<T>>();
chan->InheritFrom(other);
return chan;
Expand All @@ -338,7 +353,10 @@ class ChannelReader {
ChannelObject<T>* channel() { return channel_; }

void Reset(ChannelObject<T>* channel) {
CHECK(channel != nullptr) << "Channel can not be nullptr";
PADDLE_ENFORCE_NE(
channel,
nullptr,
phi::errors::InvalidArgument("Channel can not be nullptr"));
channel_ = channel;
cursor_ = 0;
failed_ = !channel;
Expand Down Expand Up @@ -390,8 +408,10 @@ class ChannelWriter {
ChannelObject<T>* channel() { return channel_; }

void Reset(ChannelObject<T>* channel) {
CHECK(buffer_.empty()) << "Forgot to flush";
// CHECK(channel != nullptr) << "Channel can not be nullptr";
PADDLE_ENFORCE_EQ(buffer_.empty(),
true,
phi::errors::InvalidArgument(
"The buffer should be empty! Forgot to flush."));
channel_ = channel;
buffer_.clear();
failed_ = !channel;
Expand Down Expand Up @@ -446,7 +466,10 @@ struct ChannelIterator {
T data_;

void operator++() {
CHECK(reader_ != nullptr) << "reader can not be NULL";
PADDLE_ENFORCE_NE(
reader_,
nullptr,
phi::errors::InvalidArgument("The reader can not be NULL."));
if (!(*reader_ >> data_)) {
reader_ = nullptr;
}
Expand Down
47 changes: 41 additions & 6 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2254,7 +2254,12 @@ int HeterComm<KeyType, ValType, GradType, GPUAccessor>::dedup_keys_and_fillidx(
stream = resource_->local_stream(gpu_id, 0);
}

CHECK_GT(total_fea_num, 0);
PADDLE_ENFORCE_GT(
total_fea_num,
0,
phi::errors::InvalidArgument(
"Param total feature num should be greater than 0, but got %d.",
total_fea_num));
size_t merged_size = 0;
size_t byte_size = sizeof(uint32_t) * (total_fea_num + 1);

Expand Down Expand Up @@ -2688,7 +2693,13 @@ HeterComm<KeyType, ValType, GradType, GPUAccessor>::gather_inner_keys_by_copy(
max_part_size = res.h_part_sizes[i];
}
}
CHECK_EQ(shard_send_offset, static_cast<size_t>(fea_size));
PADDLE_ENFORCE_EQ(
shard_send_offset,
static_cast<size_t>(fea_size),
phi::errors::InvalidArgument(
"Param shard_send_offset should be equal to %d, but got %d.",
static_cast<size_t>(fea_size),
shard_send_offset));

size_t trans_need_size =
std::max(shard_recv_offset, static_cast<size_t>(fea_size));
Expand Down Expand Up @@ -2868,7 +2879,12 @@ size_t HeterComm<KeyType, ValType, GradType, GPUAccessor>::send_data_by_all2all(
cudaMemcpyDeviceToDevice,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
CHECK_EQ(send_size, h_recv_part_sizes[nccl_rank_id]);
PADDLE_ENFORCE_EQ(send_size,
h_recv_part_sizes[nccl_rank_id],
phi::errors::InvalidArgument(
"Param send_size should be equal to %d, but got %d.",
h_recv_part_sizes[nccl_rank_id],
send_size));

auto &loc = storage_[gpu_id];
auto nccl_stream = resource_->comm_stream(gpu_id, 0);
Expand Down Expand Up @@ -2950,7 +2966,12 @@ size_t HeterComm<KeyType, ValType, GradType, GPUAccessor>::
cache.remote_keys_ += h_local_part_sizes[i];
}
}
CHECK_EQ(fea_size, h_local_part_offsets[node_size_]);
PADDLE_ENFORCE_EQ(fea_size,
h_local_part_offsets[node_size_],
phi::errors::InvalidArgument(
"Param fea_size should be equal to %d, but got %d.",
h_local_part_offsets[node_size_],
fea_size));

PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&res.d_node_size_ptr[rank_offset],
&h_push_fea_sizes[rank_offset],
Expand Down Expand Up @@ -3892,7 +3913,14 @@ HeterComm<KeyType, ValType, GradType, GPUAccessor>::send_keys_by_all2all_trans(
const size_t &recv_size =
my_cache.shard_res.h_remote_part_offsets[nccl_node_size];
size_t need_len = std::max(fea_size, recv_size);
CHECK(trans.trans_keys_buff->size() >= need_len * sizeof(KeyType) * 2);
PADDLE_ENFORCE_EQ(
trans.trans_keys_buff->size() >= need_len * sizeof(KeyType) * 2,
true,
phi::errors::InvalidArgument(
"The size of trnas keys buffer should be greater than or equal to "
"%d, but got %d.",
need_len * sizeof(KeyType) * 2,
trans.trans_keys_buff->size()));

// p2p copy
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyPeerAsync(trans.d_merged_trans_keys,
Expand Down Expand Up @@ -4080,7 +4108,14 @@ size_t HeterComm<KeyType, ValType, GradType, GPUAccessor>::
const size_t &recv_total_size =
my_cache.shard_res.h_remote_part_offsets[nccl_node_size];
size_t need_len = std::max(fea_size, recv_total_size);
CHECK(trans.trans_keys_buff->size() >= need_len * sizeof(KeyType) * 2);
PADDLE_ENFORCE_EQ(
trans.trans_keys_buff->size() >= need_len * sizeof(KeyType) * 2,
true,
phi::errors::InvalidArgument(
"The size of trans keys buffer should be greater than or equal to "
"%d, but got %d.",
need_len * sizeof(KeyType) * 2,
trans.trans_keys_buff->size()));

// p2p copy
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyPeerAsync(trans.d_merged_trans_keys,
Expand Down

0 comments on commit 5ea2b18

Please sign in to comment.