-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransform.hxx
40 lines (30 loc) · 1.08 KB
/
transform.hxx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#pragma once
#include "launch.hxx"
BEGIN_MGPU_NAMESPACE
namespace vk {
// Launch a grid and pass (tid, cta).
//
// Compute-shader entry point: each invocation calls func with threadIdx.x as
// the first argument and blockIdx.x as the second.
//
// Template parameters:
//   nt            - forwarded to the local_size shader attribute.
//   subgroup_size - forwarded to the subgroup_size shader attribute; defaults
//                   to -1 (NOTE(review): presumably "no explicit subgroup
//                   size requested" - confirm against the compiler docs).
//   func_t        - callable invoked as func(tid, cta).
//
// No bounds check is made here; func sees every invocation in the grid.
template<int nt, int subgroup_size = -1, typename func_t>
[[using spirv: comp, local_size(nt), subgroup_size(subgroup_size), push]]
void launch_cs(func_t func) {
func(threadIdx.x, blockIdx.x);
}
// Host-side wrapper: launches launch_cs<nt, subgroup_size> with num_blocks
// as the first chevron argument and cmd_buffer as the second, passing func
// through as the shader argument.
//
// nt has no default here, so callers must supply it explicitly.
template<int nt, int subgroup_size = -1, typename func_t>
static void launch(int num_blocks, cmd_buffer_t& cmd_buffer, func_t func) {
launch_cs<nt, subgroup_size><<<num_blocks, cmd_buffer>>>(func);
}
// Compute-shader entry point for an element-wise transform.
//
// Each invocation reads the x component of its global invocation ID and, if
// that index is below count, calls func(index). Invocations whose index is at
// or past count do nothing.
//
// nt and subgroup_size are forwarded to the local_size and subgroup_size
// shader attributes respectively.
template<int nt = 256, int subgroup_size = -1, typename func_t>
[[using spirv: comp, local_size(nt), subgroup_size(subgroup_size), push]]
void transform_cs(int count, func_t func) {
  int index = glcomp_GlobalInvocationID.x;

  // Only in-range invocations have an element to process.
  if(index < count)
    func(index);
}
// Host-side wrapper: runs func(gid) for gid in [0, count) by launching
// transform_cs on cmd_buffer.
//
// The grid size is div_up(count, nt); div_up is defined elsewhere and is
// expected to round count / nt up, so the grid may contain more invocations
// than count. transform_cs filters those out with its own bounds check.
template<int nt = 256, int subgroup_size = -1, typename func_t>
static void transform(int count, cmd_buffer_t& cmd_buffer, func_t func) {
  int grid_size = div_up(count, nt);
  transform_cs<nt, subgroup_size><<<grid_size, cmd_buffer>>>(count, func);
}
} // namespace vk
END_MGPU_NAMESPACE