Fix test_weight_decay and test_graph_reindex #62707

Merged
6 changes: 6 additions & 0 deletions cmake/external/cccl.cmake
@@ -15,12 +15,18 @@ set(CCCL_INCLUDE_DIR ${CCCL_SOURCE_DIR})
 message("CCCL_INCLUDE_DIR is ${CCCL_INCLUDE_DIR}")
 include_directories(${CCCL_INCLUDE_DIR})
 
+file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/cccl/util_device.cuh.patch
+     native_src)
+set(CCCL_PATCH_COMMAND git checkout -- . && git checkout ${CCCL_TAG} && patch
+    -p1 -Nd ${CCCL_SOURCE_DIR} < ${native_src})
+
 ExternalProject_Add(
   extern_cccl
   ${EXTERNAL_PROJECT_LOG_ARGS}
   SOURCE_DIR ${CCCL_SOURCE_DIR}
   PREFIX ${CCCL_PREFIX_DIR}
   UPDATE_COMMAND ""
+  PATCH_COMMAND ${CCCL_PATCH_COMMAND}
   CONFIGURE_COMMAND ""
   BUILD_COMMAND ""
   INSTALL_COMMAND ""
59 changes: 20 additions & 39 deletions paddle/phi/kernels/gpu/graph_reindex_kernel.cu
@@ -67,53 +67,34 @@ std::shared_ptr<phi::Allocation> FillHashTable(const Context& dev_ctx,
       input, num_input, len_hashtable, keys, key_index);
 
   // Get item index count.
-  auto item_count =
-      phi::memory_utils::Alloc(place, (num_input + 1) * sizeof(int));
-  int* item_count_ptr = reinterpret_cast<int*>(item_count->ptr());
-#ifdef PADDLE_WITH_HIP
-  hipMemset(item_count_ptr, 0, sizeof(int) * (num_input + 1));
-#else
-  cudaMemset(item_count_ptr, 0, sizeof(int) * (num_input + 1));
-#endif
+  thrust::device_vector<int> item_count(num_input + 1, 0);
   GetItemIndexCount<T><<<grid, block, 0, dev_ctx.stream()>>>(
-      input, item_count_ptr, num_input, len_hashtable, keys, key_index);
+      input,
+      thrust::raw_pointer_cast(item_count.data()),
+      num_input,
+      len_hashtable,
+      keys,
+      key_index);
 
-  size_t temp_storage_bytes = 0;
-  cub::DeviceScan::ExclusiveSum(
-      NULL, temp_storage_bytes, item_count_ptr, item_count_ptr, num_input + 1);
-  auto d_temp_storage = phi::memory_utils::Alloc(place, temp_storage_bytes);
-  cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(),
-                                temp_storage_bytes,
-                                item_count_ptr,
-                                item_count_ptr,
-                                num_input + 1);
-  int total_unique_items = 0;
-#ifdef PADDLE_WITH_HIP
-  hipMemcpy(&total_unique_items,
-            item_count_ptr + num_input,
-            sizeof(int),
-            hipMemcpyDeviceToHost);
-#else
-  cudaMemcpy(&total_unique_items,
-            item_count_ptr + num_input,
-            sizeof(int),
-            cudaMemcpyDeviceToHost);
-#endif
+  thrust::exclusive_scan(
+      item_count.begin(), item_count.end(), item_count.begin());
+
+  int total_unique_items = item_count[num_input];
   auto unique_items =
       phi::memory_utils::AllocShared(place, total_unique_items * sizeof(T));
   T* unique_items_data = reinterpret_cast<T*>(unique_items->ptr());
   *final_nodes_len = total_unique_items;
 
   // Get unique items
-  FillUniqueItems<T><<<grid, block, 0, dev_ctx.stream()>>>(input,
-                                                           num_input,
-                                                           len_hashtable,
-                                                           unique_items_data,
-                                                           item_count_ptr,
-                                                           keys,
-                                                           values,
-                                                           key_index);
+  FillUniqueItems<T><<<grid, block, 0, dev_ctx.stream()>>>(
+      input,
+      num_input,
+      len_hashtable,
+      unique_items_data,
+      thrust::raw_pointer_cast(item_count.data()),
+      keys,
+      values,
+      key_index);
   return unique_items;
 }
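
Note: the old implementation zeroed a raw allocation with cudaMemset/hipMemset, called cub::DeviceScan::ExclusiveSum twice (once just to size its temporary storage), and copied the total back with an explicit cudaMemcpy/hipMemcpy; thrust::device_vector plus thrust::exclusive_scan collapses all of that bookkeeping. Below is a minimal standalone sketch of the same count -> exclusive-scan -> read-total pattern; the kernel body and names such as CountFlags are illustrative, not taken from the Paddle source.

#include <cstdio>

#include <thrust/device_vector.h>
#include <thrust/scan.h>

// Marks slot i with 1 if input[i] should produce an output item, else 0.
__global__ void CountFlags(const int* input, int* count, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    count[i] = (input[i] % 2 == 0) ? 1 : 0;  // e.g. keep even values
  }
}

int main() {
  const int kNumItems = 8;
  thrust::device_vector<int> input(kNumItems);
  for (int i = 0; i < kNumItems; ++i) input[i] = i;

  // One extra slot so the scan's last element holds the grand total,
  // mirroring item_count(num_input + 1, 0) in the PR.
  thrust::device_vector<int> count(kNumItems + 1, 0);
  CountFlags<<<1, 32>>>(thrust::raw_pointer_cast(input.data()),
                        thrust::raw_pointer_cast(count.data()),
                        kNumItems);

  // In-place exclusive scan turns per-item counts into output offsets;
  // Thrust manages the temporary storage that cub::DeviceScan makes explicit.
  thrust::exclusive_scan(count.begin(), count.end(), count.begin());

  // operator[] on a device_vector performs the device-to-host copy that the
  // old code spelled out in platform-specific memcpy branches.
  int total_unique = count[kNumItems];
  printf("selected items: %d\n", total_unique);  // prints 4 (values 0,2,4,6)
  return 0;
}

Reading item_count[num_input] through device_vector::operator[] does the device-to-host transfer implicitly, which is what lets the PR delete both #ifdef PADDLE_WITH_HIP memcpy branches.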
31 changes: 31 additions & 0 deletions patches/cccl/util_device.cuh.patch
@@ -0,0 +1,31 @@
diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh
index c7e15cafe..756336914 100644
--- a/cub/cub/util_device.cuh
+++ b/cub/cub/util_device.cuh
@@ -278,7 +278,7 @@ public:
 /**
  * \brief Retrieves the PTX version that will be used on the current device (major * 100 + minor * 10).
  */
-CUB_RUNTIME_FUNCTION inline cudaError_t PtxVersionUncached(int& ptx_version)
+CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersionUncached(int& ptx_version)
 {
   // Instantiate `EmptyKernel<void>` in both host and device code to ensure
   // it can be called.
@@ -375,7 +375,7 @@ __host__ inline cudaError_t PtxVersion(int& ptx_version, int device)
  *
  * \note This function is thread safe.
  */
-CUB_RUNTIME_FUNCTION inline cudaError_t PtxVersion(int &ptx_version)
+CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersion(int &ptx_version)
 {
   cudaError_t result = cudaErrorUnknown;
   NV_IF_TARGET(
@@ -593,7 +593,7 @@ CUB_RUNTIME_FUNCTION inline cudaError_t HasUVA(bool& has_uva)
  *
  */
 template <typename KernelPtr>
-CUB_RUNTIME_FUNCTION inline
+CUB_RUNTIME_FUNCTION __forceinline__
 cudaError_t MaxSmOccupancy(
   int& max_sm_occupancy, ///< [out] maximum number of thread blocks that can reside on a single SM
   KernelPtr kernel_ptr,  ///< [in] Kernel pointer for which to compute SM occupancy
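
Note: this patch is applied at configure time by the CCCL_PATCH_COMMAND added in cmake/external/cccl.cmake above; the git checkout prefix resets the CCCL tree first, and patch -N skips hunks that are already applied, so re-running the CMake configure step is safe. Functionally, the patch swaps inline for __forceinline__ on CUB's host/device helpers: inline merely permits inlining and still allows the compiler to emit an out-of-line definition, while __forceinline__ directs nvcc to inline every call site. The PR does not state the motivation here; forcing inlining typically avoids out-of-line copies of these helpers being emitted per translation unit. A minimal sketch of a host/device helper in the same shape, using a hypothetical QueryPtxVersionSketch that is not CUB code:

#include <cstdio>

#include <cuda_runtime.h>

// __forceinline__ directs the compiler to inline every call site, so no
// standalone (and potentially duplicated) symbol is emitted per TU.
__host__ __device__ __forceinline__ cudaError_t QueryPtxVersionSketch(
    int* ptx_version) {
#if defined(__CUDA_ARCH__)
  // Device path: the architecture is a compile-time constant.
  *ptx_version = __CUDA_ARCH__;
  return cudaSuccess;
#else
  // Host path: derive major * 100 + minor * 10 from the current device.
  int device = 0;
  cudaError_t err = cudaGetDevice(&device);
  if (err != cudaSuccess) return err;
  cudaDeviceProp prop;
  err = cudaGetDeviceProperties(&prop, device);
  if (err != cudaSuccess) return err;
  *ptx_version = prop.major * 100 + prop.minor * 10;
  return cudaSuccess;
#endif
}

int main() {
  int ptx = 0;
  if (QueryPtxVersionSketch(&ptx) == cudaSuccess) {
    printf("ptx version: %d\n", ptx);  // e.g. 860 on an SM 8.6 GPU
  }
  return 0;
}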