|
23 | 23 | #include <raft/util/cuda_utils.cuh>
|
24 | 24 | #include <raft/util/cudart_utils.hpp>
|
25 | 25 |
|
26 |
| -#include <thrust/functional.h> |
| 26 | +#include <cuda/std/functional> |
27 | 27 |
|
28 | 28 | #include <algorithm>
|
29 | 29 | #include <cmath>
|
@@ -81,13 +81,13 @@ struct vec {
|
81 | 81 | __host__ __device__ T operator[](int i) const { return data[i]; }
|
82 | 82 | friend __host__ __device__ vec<N, T> operator+(const vec<N, T>& a, const vec<N, T>& b)
|
83 | 83 | {
|
84 |
| - return vectorized(cub::Sum())(a, b); |
| 84 | + return vectorized(cuda::std::plus<T>{})(a, b); |
85 | 85 | }
|
86 | 86 | friend __host__ __device__ void operator+=(vec<N, T>& a, const vec<N, T>& b) { a = a + b; }
|
87 | 87 | template <typename Vec>
|
88 | 88 | friend __host__ __device__ vec<N, T> operator/(vec<N, T>& a, const Vec& b)
|
89 | 89 | {
|
90 |
| - return vectorized(thrust::divides<T>())(a, vec<N, T>(b)); |
| 90 | + return vectorized(cuda::std::divides<T>())(a, vec<N, T>(b)); |
91 | 91 | }
|
92 | 92 | template <typename Vec>
|
93 | 93 | friend __host__ __device__ void operator/=(vec<N, T>& a, const Vec& b)
|
@@ -295,7 +295,7 @@ struct tree_aggregator_t {
|
295 | 295 | // ensure input columns can be overwritten (no threads traversing trees)
|
296 | 296 | __syncthreads();
|
297 | 297 | if (log2_threads_per_tree == 0) {
|
298 |
| - acc = block_reduce(acc, vectorized(cub::Sum()), tmp_storage); |
| 298 | + acc = block_reduce(acc, vectorized(cuda::std::plus{}), tmp_storage); |
299 | 299 | } else {
|
300 | 300 | auto per_thread = (vec<NITEMS, real_t>*)tmp_storage;
|
301 | 301 | per_thread[threadIdx.x] = acc;
|
@@ -383,7 +383,7 @@ __device__ __forceinline__ void block_softmax(Iterator begin, Iterator end, void
|
383 | 383 | for (Iterator it = begin + threadIdx.x; it < end; it += blockDim.x)
|
384 | 384 | *it = vectorized(shifted_exp())(*it, max);
|
385 | 385 | // sum of exponents
|
386 |
| - value_type soe = allreduce_shmem(begin, end, vectorized(cub::Sum()), tmp_storage); |
| 386 | + value_type soe = allreduce_shmem(begin, end, vectorized(cuda::std::plus{}), tmp_storage); |
387 | 387 | // softmax phase 2: normalization
|
388 | 388 | for (Iterator it = begin + threadIdx.x; it < end; it += blockDim.x)
|
389 | 389 | *it /= soe;
|
|
0 commit comments