Optimize graphsage data process (PaddlePaddle#96)
* add sample update, update reindex

* temp commit

* fix offset calc

* optimize reindex

* delete unnecessary code

* add for loop speed up

* add speed optimize

* change common_graph VLOG type

* add kernel3 update for review

* temp commit

* delete unused code
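
The "fix offset calc" item above refers to the neighbor-offset computation in SampleNeighbors: an exclusive scan over the per-slot sample counts is replaced by an inclusive scan written one slot past a leading zero, so a single buffer yields both the packed write offset of every slot and, at indices (i + 1) * len, the cumulative per-edge-type boundaries copied into edges_split_num. A minimal standalone sketch of that pattern, with hypothetical toy counts rather than code from this commit:

#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <cstdio>

int main() {
  // Hypothetical per-slot sample counts: 2 edge types x 3 dst nodes (len = 3).
  int h_counts[6] = {2, 0, 3, 1, 4, 2};
  thrust::device_vector<int> counts(h_counts, h_counts + 6);

  // Inclusive scan shifted right by one: cumsum[0] stays 0, cumsum[i] is the
  // packed write offset of slot i, and cumsum[6] is the grand total.
  thrust::device_vector<int> cumsum(6 + 1, 0);
  thrust::inclusive_scan(counts.begin(), counts.end(), cumsum.begin() + 1);

  int total = cumsum[6];              // 12 sampled neighbors overall
  int split_after_type0 = cumsum[3];  // 5 = cumulative count after edge type 0
  std::printf("total=%d split=%d\n", total, split_after_type0);
  return 0;
}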
DesmonDay authored Sep 5, 2022
1 parent d56e8b7 commit fe1653c
Showing 4 changed files with 144 additions and 111 deletions.
146 changes: 72 additions & 74 deletions paddle/fluid/framework/data_feed.cu
@@ -87,16 +87,20 @@ __global__ void FillSlotValueOffsetKernel(const int ins_num,

__global__ void fill_actual_neighbors(int64_t* vals,
int64_t* actual_vals,
int64_t* actual_vals_dst,
int* actual_sample_size,
int* cumsum_actual_sample_size,
int sample_size,
int len) {
int len,
int mod) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
int offset1 = cumsum_actual_sample_size[i];
int offset2 = sample_size * i;
int dst_id = i % mod;
for (int j = 0; j < actual_sample_size[i]; j++) {
actual_vals[offset1 + j] = vals[offset2 + j];
actual_vals_dst[offset1 + j] = dst_id;
}
}
}
@@ -226,13 +230,6 @@ __global__ void CopyDuplicateKeys(int64_t *dist_tensor,
}
}

template <typename T>
__global__ void ResetReindexTable(T* tensor, int64_t len) {
CUDA_KERNEL_LOOP(idx, len) {
tensor[idx] = -1;
}
}

int GraphDataGenerator::AcquireInstance(BufState *state) {
//
if (state->GetNextStep()) {
@@ -490,15 +487,15 @@ int GraphDataGenerator::FillInsBuf() {

std::vector<std::shared_ptr<phi::Allocation>> GraphDataGenerator::SampleNeighbors(
int64_t* uniq_nodes, int len, int sample_size,
std::vector<int64_t>& edges_split_num, int64_t* neighbor_len) {
std::vector<int>& edges_split_num, int64_t* neighbor_len) {

auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
auto edge_to_id = gpu_graph_ptr->edge_to_id;
int64_t all_sample_size = 0;

std::vector<std::shared_ptr<phi::Allocation>> concat_sample_val;
std::vector<std::shared_ptr<phi::Allocation>> concat_sample_count;

NeighborSampleQuery q;

for (auto& iter : edge_to_id) {
int edge_idx = iter.second;
q.initialize(gpuid_, edge_idx, (uint64_t)(uniq_nodes), sample_size, len);
@@ -507,20 +504,14 @@ std::vector<std::shared_ptr<phi::Allocation>> GraphDataGenerator::SampleNeighbor
concat_sample_val.emplace_back(sample_val_mem);
auto sample_count_mem = sample_res.actual_sample_size_mem;
concat_sample_count.emplace_back(sample_count_mem);
edges_split_num.emplace_back(sample_res.total_sample_size);
all_sample_size += sample_res.total_sample_size;
}

auto all_sample_val =
memory::AllocShared(place_, len * sample_size * edge_to_id_len_ * sizeof(int64_t));
auto final_sample_val =
memory::AllocShared(place_, all_sample_size * sizeof(int64_t));
auto all_sample_count =
memory::AllocShared(place_, edge_to_id_len_ * len * sizeof(int));
int64_t* all_sample_val_ptr =
reinterpret_cast<int64_t* >(all_sample_val->ptr());
int64_t* final_sample_val_ptr =
reinterpret_cast<int64_t* >(final_sample_val->ptr());
int* all_sample_count_ptr =
reinterpret_cast<int* >(all_sample_count->ptr());

@@ -536,38 +527,56 @@ std::vector<std::shared_ptr<phi::Allocation>> GraphDataGenerator::SampleNeighbor
cudaMemcpyAsync(all_sample_count_ptr + i * len, tmp_sample_count,
sizeof(int) * len, cudaMemcpyDeviceToDevice, stream_);
}

cudaStreamSynchronize(stream_);

thrust::device_vector<int> cumsum_actual_sample_size(len * edge_to_id_len_);
thrust::exclusive_scan(thrust::device_pointer_cast(all_sample_count_ptr),
thrust::device_vector<int> cumsum_actual_sample_size(len * edge_to_id_len_ + 1, 0);
thrust::inclusive_scan(thrust::device_pointer_cast(all_sample_count_ptr),
thrust::device_pointer_cast(all_sample_count_ptr) + len * edge_to_id_len_,
cumsum_actual_sample_size.begin(),
0);
cumsum_actual_sample_size.begin() + 1);
edges_split_num.resize(edge_to_id_len_);
for (int i = 0; i < edge_to_id_len_; i++) {
cudaMemcpyAsync(
edges_split_num.data() + i,
thrust::raw_pointer_cast(cumsum_actual_sample_size.data()) + (i + 1) * len,
sizeof(int),
cudaMemcpyDeviceToHost,
stream_);
}
cudaStreamSynchronize(stream_);

int all_sample_size = edges_split_num[edge_to_id_len_ - 1];
auto final_sample_val =
memory::AllocShared(place_, all_sample_size * sizeof(int64_t));
auto final_sample_val_dst =
memory::AllocShared(place_, all_sample_size * sizeof(int64_t));
int64_t* final_sample_val_ptr =
reinterpret_cast<int64_t* >(final_sample_val->ptr());
int64_t* final_sample_val_dst_ptr =
reinterpret_cast<int64_t* >(final_sample_val_dst->ptr());
fill_actual_neighbors<<<GET_BLOCKS(len * edge_to_id_len_),
CUDA_NUM_THREADS,
0,
stream_>>>(all_sample_val_ptr,
final_sample_val_ptr,
final_sample_val_dst_ptr,
all_sample_count_ptr,
thrust::raw_pointer_cast(cumsum_actual_sample_size.data()),
sample_size,
len * edge_to_id_len_);
len * edge_to_id_len_,
len);

*neighbor_len = all_sample_size;
cudaStreamSynchronize(stream_);

std::vector<std::shared_ptr<phi::Allocation>> sample_and_count;
sample_and_count.emplace_back(final_sample_val);
sample_and_count.emplace_back(all_sample_count);
return sample_and_count;
std::vector<std::shared_ptr<phi::Allocation>> sample_results;
sample_results.emplace_back(final_sample_val);
sample_results.emplace_back(final_sample_val_dst);
return sample_results;
}

std::shared_ptr<phi::Allocation> GraphDataGenerator::GetReindexResult(
int64_t* reindex_src_data, int64_t* reindex_dst_data,
const int* count_data, const int64_t* center_nodes,
int* final_nodes_len, int node_len, int64_t neighbor_len) {
int64_t* reindex_src_data, const int64_t* center_nodes, int* final_nodes_len,
int node_len, int64_t neighbor_len) {

VLOG(2) << gpuid_ << ": Enter GetReindexResult Function";
const phi::GPUContext& dev_ctx_ =
@@ -584,15 +593,12 @@ std::shared_ptr<phi::Allocation> GraphDataGenerator::GetReindexResult(

VLOG(2) << gpuid_ << ": ResetReindexTable With -1";
// Fill table with -1.
ResetReindexTable<int64_t><<<
GET_BLOCKS(reindex_table_size_), CUDA_NUM_THREADS, 0, stream_>>>(
d_reindex_table_key_ptr, reindex_table_size_);
ResetReindexTable<int><<<
GET_BLOCKS(reindex_table_size_), CUDA_NUM_THREADS, 0, stream_>>>(
d_reindex_table_value_ptr, reindex_table_size_);
ResetReindexTable<int><<<
GET_BLOCKS(reindex_table_size_), CUDA_NUM_THREADS, 0, stream_>>>(
d_reindex_table_index_ptr, reindex_table_size_);
cudaMemsetAsync(d_reindex_table_key_ptr, -1,
reindex_table_size_ * sizeof(int64_t), stream_);
cudaMemsetAsync(d_reindex_table_value_ptr, -1,
reindex_table_size_ * sizeof(int), stream_);
cudaMemsetAsync(d_reindex_table_index_ptr, -1,
reindex_table_size_ * sizeof(int), stream_);

VLOG(2) << gpuid_ << ": Alloc all_nodes";
auto all_nodes =
@@ -605,6 +611,7 @@ std::shared_ptr<phi::Allocation> GraphDataGenerator::GetReindexResult(
cudaMemcpy(all_nodes_data + node_len, reindex_src_data, sizeof(int64_t) * neighbor_len,
cudaMemcpyDeviceToDevice);

cudaStreamSynchronize(stream_);
VLOG(2) << gpuid_ << ": Run phi::FillHashTable";
auto final_nodes =
phi::FillHashTable<int64_t, phi::GPUContext>(dev_ctx_, all_nodes_data,
@@ -621,13 +628,6 @@ std::shared_ptr<phi::Allocation> GraphDataGenerator::GetReindexResult(
reindex_table_size_,
d_reindex_table_key_ptr,
d_reindex_table_value_ptr);
VLOG(2) << gpuid_ << ": Run phi::ReindexDst";
thrust::device_vector<int> scan_dst(node_len);
thrust::sequence(scan_dst.begin(), scan_dst.end());
phi::ReindexDst<int64_t, phi::GPUContext>(dev_ctx_, reindex_dst_data,
thrust::raw_pointer_cast(scan_dst.data()),
count_data, edge_to_id_len_, node_len);
return final_nodes;
}

@@ -652,81 +652,77 @@ std::shared_ptr<phi::Allocation> GraphDataGenerator::GenerateSampleGraph(
int uniq_len = uniq_nodes.numel();
int len_samples = samples_.size();

int64_t *num_nodes_tensor_ptr_[len_samples];
int64_t *next_num_nodes_tensor_ptr_[len_samples];
int *num_nodes_tensor_ptr_[len_samples];
int *next_num_nodes_tensor_ptr_[len_samples];
int64_t *edges_src_tensor_ptr_[len_samples];
int64_t *edges_dst_tensor_ptr_[len_samples];
int64_t *edges_split_tensor_ptr_[len_samples];
int *edges_split_tensor_ptr_[len_samples];

VLOG(2) << "Sample Neighbors and Reindex";
std::vector<int64_t> edges_split_num;
std::vector<int> edges_split_num;
std::vector<std::shared_ptr<phi::Allocation>> final_nodes_vec;
std::vector<int64_t> final_nodes_len_vec;
std::vector<int> final_nodes_len_vec;

for (int i = 0; i < len_samples; i++) {

edges_split_num.clear();
std::shared_ptr<phi::Allocation> neighbors, count;
std::shared_ptr<phi::Allocation> neighbors, reindex_dst;
int64_t neighbors_len = 0;
if (i == 0) {
auto sample_and_count =
auto sample_results =
SampleNeighbors(uniq_nodes_data, uniq_len, samples_[i], edges_split_num,
&neighbors_len);
neighbors = sample_and_count[0];
count = sample_and_count[1];
neighbors = sample_results[0];
reindex_dst = sample_results[1];
edges_split_num.push_back(uniq_len);
} else {
int64_t* final_nodes_data =
reinterpret_cast<int64_t* >(final_nodes_vec[i - 1]->ptr());
auto sample_and_count =
auto sample_results =
SampleNeighbors(final_nodes_data, final_nodes_len_vec[i - 1],
samples_[i], edges_split_num, &neighbors_len);
neighbors = sample_and_count[0];
count = sample_and_count[1];
neighbors = sample_results[0];
reindex_dst = sample_results[1];
edges_split_num.push_back(final_nodes_len_vec[i - 1]);
}
auto reindex_dst =
memory::AllocShared(place_, sizeof(int64_t) * neighbors_len);

int64_t* reindex_src_data = reinterpret_cast<int64_t* >(neighbors->ptr());
int64_t* reindex_dst_data = reinterpret_cast<int64_t* >(reindex_dst->ptr());
int* count_data = reinterpret_cast<int* >(count->ptr());
int final_nodes_len = 0;
if (i == 0) {
auto tmp_final_nodes =
GetReindexResult(reindex_src_data, reindex_dst_data, count_data,
uniq_nodes_data, &final_nodes_len, uniq_len,
neighbors_len);
GetReindexResult(reindex_src_data, uniq_nodes_data, &final_nodes_len,
uniq_len, neighbors_len);
final_nodes_vec.emplace_back(tmp_final_nodes);
final_nodes_len_vec.emplace_back(final_nodes_len);
} else {
int64_t* final_nodes_data =
reinterpret_cast<int64_t* >(final_nodes_vec[i - 1]->ptr());
auto tmp_final_nodes =
GetReindexResult(reindex_src_data, reindex_dst_data, count_data,
final_nodes_data, &final_nodes_len,
GetReindexResult(reindex_src_data, final_nodes_data, &final_nodes_len,
final_nodes_len_vec[i - 1], neighbors_len);
final_nodes_vec.emplace_back(tmp_final_nodes);
final_nodes_len_vec.emplace_back(final_nodes_len);
}
int offset = 3 + 2 * slot_num_ + 5 * i;
num_nodes_tensor_ptr_[i] =
feed_vec_[offset]->mutable_data<int64_t>({1}, this->place_);
feed_vec_[offset]->mutable_data<int>({1}, this->place_);
next_num_nodes_tensor_ptr_[i] =
feed_vec_[offset + 1]->mutable_data<int64_t>({1}, this->place_);
feed_vec_[offset + 1]->mutable_data<int>({1}, this->place_);
edges_src_tensor_ptr_[i] =
feed_vec_[offset + 2]->mutable_data<int64_t>({neighbors_len, 1}, this->place_);
edges_dst_tensor_ptr_[i] =
feed_vec_[offset + 3]->mutable_data<int64_t>({neighbors_len, 1}, this->place_);
edges_split_tensor_ptr_[i] =
feed_vec_[offset + 4]->mutable_data<int64_t>({edge_to_id_len_}, this->place_);
feed_vec_[offset + 4]->mutable_data<int>({edge_to_id_len_}, this->place_);

cudaMemcpyAsync(num_nodes_tensor_ptr_[i], final_nodes_len_vec.data() + i,
sizeof(int64_t), cudaMemcpyHostToDevice, stream_);
sizeof(int), cudaMemcpyHostToDevice, stream_);
cudaMemcpyAsync(next_num_nodes_tensor_ptr_[i], edges_split_num.data() + edge_to_id_len_,
sizeof(int64_t), cudaMemcpyHostToDevice, stream_);
sizeof(int), cudaMemcpyHostToDevice, stream_);
cudaMemcpyAsync(edges_split_tensor_ptr_[i], edges_split_num.data(),
sizeof(int64_t) * edge_to_id_len_, cudaMemcpyHostToDevice, stream_);
sizeof(int) * edge_to_id_len_, cudaMemcpyHostToDevice, stream_);
cudaMemcpyAsync(edges_src_tensor_ptr_[i], reindex_src_data,
sizeof(int64_t) * neighbors_len, cudaMemcpyDeviceToDevice, stream_);
cudaMemcpyAsync(edges_dst_tensor_ptr_[i], reindex_dst_data,
@@ -782,6 +778,7 @@ int GraphDataGenerator::GenerateBatch() {
0,
stream_>>>(clk_tensor_ptr_, total_instance);
} else {

auto node_buf = memory::AllocShared(
place_, total_instance * sizeof(uint64_t));
int64_t* node_buf_ptr = reinterpret_cast<int64_t* >(node_buf->ptr());
@@ -929,6 +926,7 @@ int GraphDataGenerator::GenerateBatch() {
CUDA_NUM_THREADS,
0,
stream_>>>(clk_tensor_ptr_, uniq_instance_);

}
} else {
ins_cursor = (uint64_t *)id_tensor_ptr_;
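
The reworked fill_actual_neighbors kernel in this file packs the fixed-stride sample buffer into a dense edge list and, new in this commit, also records the destination node of every kept edge via dst_id = i % mod (mod is the dst-node count len, since the buffer is laid out edge-type-major). A self-contained sketch of that compaction pattern, simplified to one edge type with hypothetical toy data:

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Pack per-node samples (fixed stride sample_size) into a dense array and
// record the destination node id of every kept edge.
__global__ void pack_neighbors(const int64_t* vals, int64_t* actual_vals,
                               int64_t* actual_vals_dst,
                               const int* actual_sample_size,
                               const int* cumsum_actual_sample_size,
                               int sample_size, int len, int mod) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    int offset1 = cumsum_actual_sample_size[i];  // dense write offset
    int offset2 = sample_size * i;               // strided read offset
    int dst_id = i % mod;                        // recover the dst node index
    for (int j = 0; j < actual_sample_size[i]; j++) {
      actual_vals[offset1 + j] = vals[offset2 + j];
      actual_vals_dst[offset1 + j] = dst_id;
    }
  }
}

int main() {
  // 2 dst nodes, up to 3 samples each; node 0 kept 2 neighbors, node 1 kept 1.
  const int len = 2, sample_size = 3;
  int64_t h_vals[6] = {10, 11, -1, 20, -1, -1};
  int h_count[2] = {2, 1};
  int h_cumsum[2] = {0, 2};  // exclusive prefix sum of h_count

  int64_t *d_vals, *d_out, *d_dst;
  int *d_count, *d_cumsum;
  cudaMalloc(&d_vals, sizeof(h_vals));
  cudaMalloc(&d_out, 3 * sizeof(int64_t));
  cudaMalloc(&d_dst, 3 * sizeof(int64_t));
  cudaMalloc(&d_count, sizeof(h_count));
  cudaMalloc(&d_cumsum, sizeof(h_cumsum));
  cudaMemcpy(d_vals, h_vals, sizeof(h_vals), cudaMemcpyHostToDevice);
  cudaMemcpy(d_count, h_count, sizeof(h_count), cudaMemcpyHostToDevice);
  cudaMemcpy(d_cumsum, h_cumsum, sizeof(h_cumsum), cudaMemcpyHostToDevice);

  pack_neighbors<<<1, 32>>>(d_vals, d_out, d_dst, d_count, d_cumsum,
                            sample_size, len, len);
  int64_t h_out[3], h_dst[3];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
  for (int k = 0; k < 3; k++)  // prints (10,0) (11,0) (20,1)
    std::printf("edge %d: src=%lld dst=%lld\n", k,
                (long long)h_out[k], (long long)h_dst[k]);
  return 0;
}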
5 changes: 2 additions & 3 deletions paddle/fluid/framework/data_feed.h
@@ -917,11 +917,10 @@ class GraphDataGenerator {
}
std::vector<std::shared_ptr<phi::Allocation>> SampleNeighbors(
int64_t* uniq_nodes, int len, int sample_size,
std::vector<int64_t>& edges_split_num, int64_t* neighbor_len);
std::vector<int>& edges_split_num, int64_t* neighbor_len);

std::shared_ptr<phi::Allocation> GetReindexResult(
int64_t* reindex_src_data, int64_t* reindex_dst_data,
const int* count, const int64_t* center_nodes,
int64_t* reindex_src_data, const int64_t* center_nodes,
int* final_nodes_len, int node_len, int64_t neighbor_len);

std::shared_ptr<phi::Allocation> GenerateSampleGraph(
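
GetReindexResult, whose trimmed signature appears in the header diff above, now clears its reindex hash tables with cudaMemsetAsync in place of the deleted ResetReindexTable kernels. cudaMemset is byte-wise, so the fill value is only correct for -1: a byte of 0xFF repeated across the word is the two's-complement encoding of -1 for both int and int64_t (the same trick would fail for, say, 1). A small self-contained check of that assumption, using synchronous cudaMemset and a hypothetical table for brevity:

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical 4-entry table; the commit memsets int and int64_t tables.
  int64_t* d_table;
  cudaMalloc(&d_table, 4 * sizeof(int64_t));

  // Byte-wise fill: every byte becomes 0xFF, so every int64_t reads as -1.
  cudaMemset(d_table, -1, 4 * sizeof(int64_t));

  int64_t h_table[4];
  cudaMemcpy(h_table, d_table, sizeof(h_table), cudaMemcpyDeviceToHost);
  std::printf("%lld %lld\n", (long long)h_table[0],
              (long long)h_table[3]);  // prints: -1 -1
  cudaFree(d_table);
  return 0;
}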
1 change: 0 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -139,7 +139,6 @@ class GpuPsGraphTable
int len,
bool cpu_query_switch,
bool compress);

int get_feature_of_nodes(
int gpu_id, uint64_t *d_walk, uint64_t *d_offset, int size, int slot_num);
