Skip to content

Commit

Permalink
Remove ActorMsg::user_data_ (#9762)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujuncheng authored Jan 28, 2023
1 parent 9ad538c commit 2050000
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 47 deletions.
28 changes: 12 additions & 16 deletions oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,30 +79,26 @@ IBVerbsCommNet::~IBVerbsCommNet() {
}

void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
ActorMsg new_msg = msg;
IBVerbsActorMsgWrapper msg_wrapper;
msg_wrapper.msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
CHECK_EQ(msg.user_data_size(), 0);
auto* mem_desc = reinterpret_cast<IBVerbsMemDesc*>(msg.regst()->comm_net_token());
CHECK(mem_desc != nullptr);
IBVerbsCommNetRMADesc rma_desc{};
rma_desc.mem_ptr = reinterpret_cast<uint64_t>(mem_desc->mem_ptr());
rma_desc.mem_size = mem_desc->mem_size();
rma_desc.mr_rkey = mem_desc->mr()->rkey;
static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, "");
new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc);
msg_wrapper.rma_desc.mem_ptr = reinterpret_cast<uint64_t>(mem_desc->mem_ptr());
msg_wrapper.rma_desc.mem_size = mem_desc->mem_size();
msg_wrapper.rma_desc.mr_rkey = mem_desc->mr()->rkey;
}
qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg);
qp_vec_.at(dst_machine_id)->PostSendRequest(msg_wrapper);
}

void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
ActorMsg new_msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
void IBVerbsCommNet::RecvActorMsg(const IBVerbsActorMsgWrapper& msg_wrapper) {
ActorMsg new_msg = msg_wrapper.msg;
if (msg_wrapper.msg.IsDataRegstMsgToConsumer()) {
std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);
auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(),
reinterpret_cast<uint64_t>(msg.regst()))];
auto& desc = remote_regst2rma_desc_[std::make_pair(
msg_wrapper.msg.src_actor_id(), reinterpret_cast<uint64_t>(msg_wrapper.msg.regst()))];
if (!desc) { desc.reset(new IBVerbsCommNetRMADesc); }
CHECK_EQ(msg.user_data_size(), sizeof(IBVerbsCommNetRMADesc));
std::memcpy(desc.get(), msg.user_data(), sizeof(IBVerbsCommNetRMADesc));
*desc = msg_wrapper.rma_desc;
new_msg.set_comm_net_token(desc.get());
}
Singleton<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);
Expand Down
8 changes: 1 addition & 7 deletions oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,13 @@ limitations under the License.

namespace oneflow {

struct IBVerbsCommNetRMADesc {
uint64_t mem_ptr;
uint64_t mem_size;
uint32_t mr_rkey;
};

class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
public:
OF_DISALLOW_COPY_AND_MOVE(IBVerbsCommNet);
~IBVerbsCommNet();

void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void RecvActorMsg(const ActorMsg& msg);
void RecvActorMsg(const IBVerbsActorMsgWrapper& msg_wrapper);

private:
friend class Singleton<IBVerbsCommNet>;
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem,
}
}

void IBVerbsQP::PostSendRequest(const ActorMsg& msg) {
void IBVerbsQP::PostSendRequest(const IBVerbsActorMsgWrapper& msg_wrapper) {
ActorMsgMR* msg_mr = GetOneSendMsgMRFromBuf();
msg_mr->set_msg(msg);
msg_mr->set_msg(msg_wrapper);
WorkRequestId* wr_id = NewWorkRequestId();
wr_id->msg_mr = msg_mr;
ibv_send_wr wr{};
Expand Down
19 changes: 15 additions & 4 deletions oneflow/core/comm_network/ibverbs/ibverbs_qp.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,30 @@ limitations under the License.

namespace oneflow {

struct IBVerbsCommNetRMADesc {
uint64_t mem_ptr;
uint64_t mem_size;
uint32_t mr_rkey;
};

struct IBVerbsActorMsgWrapper final {
ActorMsg msg;
IBVerbsCommNetRMADesc rma_desc;
};

class ActorMsgMR final {
public:
OF_DISALLOW_COPY_AND_MOVE(ActorMsgMR);
ActorMsgMR() = delete;
ActorMsgMR(ibv_pd* pd) { mem_desc_.reset(new IBVerbsMemDesc(pd, &msg_, sizeof(msg_))); }
~ActorMsgMR() { mem_desc_.reset(); }

const ActorMsg& msg() const { return msg_; }
void set_msg(const ActorMsg& val) { msg_ = val; }
const IBVerbsActorMsgWrapper& msg() const { return msg_; }
void set_msg(const IBVerbsActorMsgWrapper& val) { msg_ = val; }
const IBVerbsMemDesc& mem_desc() const { return *mem_desc_; }

private:
ActorMsg msg_;
IBVerbsActorMsgWrapper msg_;
std::unique_ptr<IBVerbsMemDesc> mem_desc_;
};

Expand Down Expand Up @@ -64,7 +75,7 @@ class IBVerbsQP final {

void PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem, const IBVerbsMemDesc& local_mem,
void* read_id);
void PostSendRequest(const ActorMsg& msg);
void PostSendRequest(const IBVerbsActorMsgWrapper& msg_wrapper);

void ReadDone(WorkRequestId*);
void SendDone(WorkRequestId*);
Expand Down
11 changes: 0 additions & 11 deletions oneflow/core/lazy/actor/actor_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,6 @@ int64_t ActorMsg::eord_regst_desc_id() const {
return eord_regst_desc_id_;
}

void ActorMsg::AddUserData(uint8_t size, const void* data) {
CHECK_EQ(user_data_size_, 0);
CHECK_LE(size, kActorMsgUserDataMaxSize);
user_data_size_ = size;
std::memcpy(user_data_, data, size);
}

uint8_t ActorMsg::user_data_size() const { return user_data_size_; }

const void* ActorMsg::user_data() const { return user_data_; }

bool ActorMsg::IsDataRegstMsgToConsumer() const {
return msg_type_ == ActorMsgType::kRegstMsg && regst_wrapper_.is_data_regst_to_consumer;
}
Expand Down
7 changes: 0 additions & 7 deletions oneflow/core/lazy/actor/actor_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ enum class ActorCmd {

enum class ActorMsgType : int8_t { kRegstMsg = 0, kEordMsg, kCmdMsg };

constexpr uint8_t kActorMsgUserDataMaxSize = 32;

class ActorMsg final {
public:
ActorMsg() = default;
Expand All @@ -54,9 +52,6 @@ class ActorMsg final {
void set_comm_net_token(void* token);
bool has_sole_empty_blob() const;
int64_t eord_regst_desc_id() const;
void AddUserData(uint8_t size, const void* data);
uint8_t user_data_size() const;
const void* user_data() const;
bool IsDataRegstMsgToConsumer() const;
int64_t comm_net_sequence_number() const;
void set_comm_net_sequence_number(int64_t sequence_number);
Expand Down Expand Up @@ -91,8 +86,6 @@ class ActorMsg final {
int64_t eord_regst_desc_id_;
};
ActorMsgType msg_type_;
uint8_t user_data_size_;
unsigned char user_data_[kActorMsgUserDataMaxSize];
};

template<typename StreamT>
Expand Down

0 comments on commit 2050000

Please sign in to comment.