STORMM Source Documentation
Loading...
Searching...
No Matches
ptx_macros.h
1// -*-c++-*-
2#ifndef STORMM_PTX_MACROS_H
3#define STORMM_PTX_MACROS_H
4
5#include "copyright.h"
6
7#ifdef STORMM_USE_CUDA
// Warp-level wrappers for CUDA.  Each shuffle wrapper passes a participation mask naming all 32
// lanes (0xffffffff) and an explicit width of 32, leaving the caller to supply the value to
// exchange (a) and the lane offset, lane mask, or source lane (b).
# define SHFL(a, b)          __shfl_sync(0xffffffff, a, b, 32)
# define SHFL_UP(a, b)       __shfl_up_sync(0xffffffff, a, b, 32)
# define SHFL_DOWN(a, b)     __shfl_down_sync(0xffffffff, a, b, 32)
# define SHFL_XOR(a, b)      __shfl_xor_sync(0xffffffff, a, b, 32)

// Collect the predicate from all 32 lanes of the warp.
# define BALLOT(predicate)   __ballot_sync(0xffffffff, predicate)

// Warp-level synchronization point.  NOTE(review): this expands to a call of syncWarp(), which is
// not declared in the visible part of this header -- confirm that the including code provides it.
# define SYNCWARP            syncWarp()
14
// Transfer memory between two arrays without caching in either the read or write operations.
// One element, orig[orig_idx], is read with __ldcv() and written to dest[dest_idx] with __stwt().
//
// Arguments:
//   dest:      array (or pointer expression) to write into
//   orig:      array (or pointer expression) to read from
//   dest_idx:  index of the element of dest to write
//   orig_idx:  index of the element of orig to read
//
// Every argument is parenthesized in the expansion so that pointer expressions such as
// (base + offset) are indexed as a whole rather than being split by operator precedence.
# define NON_CACHING_XFER(dest, orig, dest_idx, orig_idx) { \
  __stwt(&(dest)[(dest_idx)], __ldcv(&(orig)[(orig_idx)])); \
}
19
// Sum var over the 32 lanes of the warp.  Each stage adds the value held by the lane a fixed
// distance above (16, then 8, 4, 2, and 1), halving the span that still needs to be combined.
// The complete sum is found in lane 0 when the macro finishes.  Built on SHFL_DOWN, which
// supplies the full-warp mask and the width of 32.
# define WARP_REDUCE_DOWN(var) {    \
  var += SHFL_DOWN(var, 16);        \
  var += SHFL_DOWN(var,  8);        \
  var += SHFL_DOWN(var,  4);        \
  var += SHFL_DOWN(var,  2);        \
  var += SHFL_DOWN(var,  1);        \
}
28
// Compute an inclusive prefix sum over all elements of the warp.  An example of this operation:
// [ 5 3 9 1 0 5 ] -> [ 5 8 17 18 18 23 ]
//
// Arguments:
//   var:  modifiable lvalue holding this lane's contribution on entry and the inclusive prefix
//         sum through this lane on exit
//   tgx:  index of this lane within the warp (0 to 31)
//
// The first five steps accumulate totals over aligned groups of 2, 4, 8, 16, and 32 lanes in the
// last lane of each group.  The last four steps hand those group totals on to the lanes that
// still lack them.  Each update is switched on or off by multiplying the shuffled value with a
// 0 / 1 predicate, so all lanes issue every shuffle.  Both arguments are expanded several times
// and must be free of side effects; they are parenthesized so that expressions such as
// (threadIdx.x & 31) can be passed for tgx.
# define INCLUSIVE_WARP_PREFIXSUM(var, tgx) { \
  (var) += (((tgx) & 1) == 1) * __shfl_up_sync(0xffffffff, (var), 1, 32); \
  (var) += (((tgx) & 3) == 3) * __shfl_up_sync(0xffffffff, (var), 2, 32); \
  (var) += (((tgx) & 7) == 7) * __shfl_up_sync(0xffffffff, (var), 4, 32); \
  (var) += (((tgx) & 15) == 15) * __shfl_up_sync(0xffffffff, (var), 8, 32); \
  (var) += ((tgx) == 31) * __shfl_up_sync(0xffffffff, (var), 16, 32); \
  (var) += (((tgx) & 15) == 7 && (tgx) > 16) * __shfl_up_sync(0xffffffff, (var), 8, 32); \
  (var) += (((tgx) & 7) == 3 && (tgx) > 8) * __shfl_up_sync(0xffffffff, (var), 4, 32); \
  (var) += (((tgx) & 3) == 1 && (tgx) > 4) * __shfl_up_sync(0xffffffff, (var), 2, 32); \
  (var) += (((tgx) & 1) == 0 && (tgx) >= 2) * __shfl_up_sync(0xffffffff, (var), 1, 32); \
}
42
// Compute an exclusive prefix sum over all elements of the warp.  If the total of the prefix sum
// is required, i.e. the upper limit of the final bin, use EXCLUSIVE_WARP_PREFIXSUM_SAVETOTAL.  An
// example of this operation: [ 5 3 9 1 0 5 ] -> [ 0 5 8 17 18 18 ] (the total, 23, is not kept)
//
// Arguments:
//   var:  modifiable lvalue holding this lane's contribution on entry and the sum of all lower
//         lanes' contributions on exit
//   tgx:  index of this lane within the warp (0 to 31)
//
// The first nine steps form the inclusive prefix sum (group totals are built in the last lane of
// aligned groups of 2, 4, 8, 16, and 32 lanes, then handed on to the remaining lanes).  The result
// is then moved up by one lane and lane 0 is set to zero.  Both arguments are expanded several
// times and must be free of side effects; they are parenthesized so that expressions such as
// (threadIdx.x & 31) can be passed for tgx.
# define EXCLUSIVE_WARP_PREFIXSUM(var, tgx) { \
  (var) += (((tgx) & 1) == 1) * __shfl_up_sync(0xffffffff, (var), 1, 32); \
  (var) += (((tgx) & 3) == 3) * __shfl_up_sync(0xffffffff, (var), 2, 32); \
  (var) += (((tgx) & 7) == 7) * __shfl_up_sync(0xffffffff, (var), 4, 32); \
  (var) += (((tgx) & 15) == 15) * __shfl_up_sync(0xffffffff, (var), 8, 32); \
  (var) += ((tgx) == 31) * __shfl_up_sync(0xffffffff, (var), 16, 32); \
  (var) += (((tgx) & 15) == 7 && (tgx) > 16) * __shfl_up_sync(0xffffffff, (var), 8, 32); \
  (var) += (((tgx) & 7) == 3 && (tgx) > 8) * __shfl_up_sync(0xffffffff, (var), 4, 32); \
  (var) += (((tgx) & 3) == 1 && (tgx) > 4) * __shfl_up_sync(0xffffffff, (var), 2, 32); \
  (var) += (((tgx) & 1) == 0 && (tgx) >= 2) * __shfl_up_sync(0xffffffff, (var), 1, 32); \
  (var) = __shfl_up_sync(0xffffffff, (var), 1, 32); \
  if ((tgx) == 0) { \
    (var) = 0; \
  } \
}
61
// Compute an exclusive prefix sum over all elements of the warp and retain the total in an
// auxiliary variable which is then broadcast to all threads.
//
// Arguments:
//   var:     modifiable lvalue holding this lane's contribution on entry and the sum of all
//            lower lanes' contributions on exit
//   tgx:     index of this lane within the warp (0 to 31)
//   result:  modifiable lvalue which receives the sum over the whole warp in every lane
//
// The inclusive prefix sum is formed first.  Its value in lane 31, the warp total, is broadcast
// into result before the sums are moved up by one lane and lane 0 is set to zero.  All arguments
// are parenthesized so that expressions such as (threadIdx.x & 31) can be passed for tgx; var and
// tgx are expanded several times and must be free of side effects.
# define EXCLUSIVE_WARP_PREFIXSUM_SAVETOTAL(var, tgx, result) { \
  (var) += (((tgx) & 1) == 1) * __shfl_up_sync(0xffffffff, (var), 1, 32); \
  (var) += (((tgx) & 3) == 3) * __shfl_up_sync(0xffffffff, (var), 2, 32); \
  (var) += (((tgx) & 7) == 7) * __shfl_up_sync(0xffffffff, (var), 4, 32); \
  (var) += (((tgx) & 15) == 15) * __shfl_up_sync(0xffffffff, (var), 8, 32); \
  (var) += ((tgx) == 31) * __shfl_up_sync(0xffffffff, (var), 16, 32); \
  (var) += (((tgx) & 15) == 7 && (tgx) > 16) * __shfl_up_sync(0xffffffff, (var), 8, 32); \
  (var) += (((tgx) & 7) == 3 && (tgx) > 8) * __shfl_up_sync(0xffffffff, (var), 4, 32); \
  (var) += (((tgx) & 3) == 1 && (tgx) > 4) * __shfl_up_sync(0xffffffff, (var), 2, 32); \
  (var) += (((tgx) & 1) == 0 && (tgx) >= 2) * __shfl_up_sync(0xffffffff, (var), 1, 32); \
  (result) = __shfl_sync(0xffffffff, (var), 31, 32); \
  (var) = __shfl_up_sync(0xffffffff, (var), 1, 32); \
  if ((tgx) == 0) { \
    (var) = 0; \
  } \
}
80#endif // STORMM_USE_CUDA
81
82#ifdef STORMM_USE_HIP
// Warp-level wrappers for HIP.  The HIP intrinsics __shfl(), __shfl_up(), __shfl_down(),
// __shfl_xor(), and __ballot() take no lane mask argument (unlike the CUDA *_sync variants), so
// only the value to exchange (a), the lane offset, lane mask, or source lane (b), and a width of
// 64 are passed.  This matches the way the reduction and prefix sum macros for HIP in this file
// call the same intrinsics.  SYNCWARP expands to nothing in HIP builds.
# define SHFL_DOWN(a, b) __shfl_down(a, b, 64)
# define SHFL_XOR(a, b) __shfl_xor(a, b, 64)
# define SHFL_UP(a, b) __shfl_up(a, b, 64)
# define SHFL(a, b) __shfl(a, b, 64)
# define SYNCWARP
# define BALLOT(predicate) __ballot(predicate)
89
// Sum var over the lanes of the warp.  Each stage adds the value held by the lane a fixed
// distance above (32, then 16, 8, 4, 2, and 1), halving the span that still needs to be combined.
// The complete sum is found in lane 0 when the macro finishes.  The shuffles are issued without
// an explicit width argument.
# define WARP_REDUCE_DOWN(var) {    \
  var += __shfl_down(var, 32);      \
  var += __shfl_down(var, 16);      \
  var += __shfl_down(var,  8);      \
  var += __shfl_down(var,  4);      \
  var += __shfl_down(var,  2);      \
  var += __shfl_down(var,  1);      \
}
99
// Compute an inclusive prefix sum over all elements of the warp.  An example of this operation:
// [ 5 3 9 1 0 5 ] -> [ 5 8 17 18 18 23 ]
//
// Arguments:
//   var:  modifiable lvalue holding this lane's contribution on entry and the inclusive prefix
//         sum through this lane on exit
//   tgx:  index of this lane within the warp (0 to 63)
//
// The first six steps accumulate totals over aligned groups of 2, 4, 8, 16, 32, and 64 lanes in
// the last lane of each group.  The last five steps hand those group totals on to the lanes that
// still lack them.  Each update is switched on or off by multiplying the shuffled value with a
// 0 / 1 predicate, so all lanes issue every shuffle.  Both arguments are expanded several times
// and must be free of side effects; they are parenthesized so that expressions such as
// (threadIdx.x & 63) can be passed for tgx.
# define INCLUSIVE_WARP_PREFIXSUM(var, tgx) { \
  (var) += (((tgx) & 1) == 1) * __shfl_up((var), 1); \
  (var) += (((tgx) & 3) == 3) * __shfl_up((var), 2); \
  (var) += (((tgx) & 7) == 7) * __shfl_up((var), 4); \
  (var) += (((tgx) & 15) == 15) * __shfl_up((var), 8); \
  (var) += (((tgx) & 31) == 31) * __shfl_up((var), 16); \
  (var) += ((tgx) == 63) * __shfl_up((var), 32); \
  (var) += (((tgx) & 31) == 15 && (tgx) > 32) * __shfl_up((var), 16); \
  (var) += (((tgx) & 15) == 7 && (tgx) > 16) * __shfl_up((var), 8); \
  (var) += (((tgx) & 7) == 3 && (tgx) > 8) * __shfl_up((var), 4); \
  (var) += (((tgx) & 3) == 1 && (tgx) > 4) * __shfl_up((var), 2); \
  (var) += (((tgx) & 1) == 0 && (tgx) >= 2) * __shfl_up((var), 1); \
}
115
// Compute an exclusive prefix sum over all elements of the warp.  If the total of the prefix sum
// is required, i.e. the upper limit of the final bin, use EXCLUSIVE_WARP_PREFIXSUM_SAVETOTAL.  An
// example of this operation: [ 5 3 9 1 0 5 ] -> [ 0 5 8 17 18 18 ] (the total, 23, is not kept)
//
// Arguments:
//   var:  modifiable lvalue holding this lane's contribution on entry and the sum of all lower
//         lanes' contributions on exit
//   tgx:  index of this lane within the warp (0 to 63)
//
// The first eleven steps form the inclusive prefix sum (group totals are built in the last lane
// of aligned groups of 2, 4, 8, 16, 32, and 64 lanes, then handed on to the remaining lanes).  The
// result is then moved up by one lane and lane 0 is set to zero.  Both arguments are expanded
// several times and must be free of side effects; they are parenthesized so that expressions
// such as (threadIdx.x & 63) can be passed for tgx.
# define EXCLUSIVE_WARP_PREFIXSUM(var, tgx) { \
  (var) += (((tgx) & 1) == 1) * __shfl_up((var), 1); \
  (var) += (((tgx) & 3) == 3) * __shfl_up((var), 2); \
  (var) += (((tgx) & 7) == 7) * __shfl_up((var), 4); \
  (var) += (((tgx) & 15) == 15) * __shfl_up((var), 8); \
  (var) += (((tgx) & 31) == 31) * __shfl_up((var), 16); \
  (var) += ((tgx) == 63) * __shfl_up((var), 32); \
  (var) += (((tgx) & 31) == 15 && (tgx) > 32) * __shfl_up((var), 16); \
  (var) += (((tgx) & 15) == 7 && (tgx) > 16) * __shfl_up((var), 8); \
  (var) += (((tgx) & 7) == 3 && (tgx) > 8) * __shfl_up((var), 4); \
  (var) += (((tgx) & 3) == 1 && (tgx) > 4) * __shfl_up((var), 2); \
  (var) += (((tgx) & 1) == 0 && (tgx) >= 2) * __shfl_up((var), 1); \
  (var) = __shfl_up((var), 1); \
  if ((tgx) == 0) { \
    (var) = 0; \
  } \
}
136
// Compute an exclusive prefix sum over all elements of the warp and retain the total in an
// auxiliary variable which is then broadcast to all threads.
//
// Arguments:
//   var:     modifiable lvalue holding this lane's contribution on entry and the sum of all
//            lower lanes' contributions on exit
//   tgx:     index of this lane within the warp (0 to 63)
//   result:  modifiable lvalue which receives the sum over the whole warp in every lane
//
// The inclusive prefix sum is formed first.  Its value in lane 63, the warp total, is broadcast
// into result before the sums are moved up by one lane and lane 0 is set to zero.  All arguments
// are parenthesized so that expressions such as (threadIdx.x & 63) can be passed for tgx; var and
// tgx are expanded several times and must be free of side effects.
# define EXCLUSIVE_WARP_PREFIXSUM_SAVETOTAL(var, tgx, result) { \
  (var) += (((tgx) & 1) == 1) * __shfl_up((var), 1); \
  (var) += (((tgx) & 3) == 3) * __shfl_up((var), 2); \
  (var) += (((tgx) & 7) == 7) * __shfl_up((var), 4); \
  (var) += (((tgx) & 15) == 15) * __shfl_up((var), 8); \
  (var) += (((tgx) & 31) == 31) * __shfl_up((var), 16); \
  (var) += ((tgx) == 63) * __shfl_up((var), 32); \
  (var) += (((tgx) & 31) == 15 && (tgx) > 32) * __shfl_up((var), 16); \
  (var) += (((tgx) & 15) == 7 && (tgx) > 16) * __shfl_up((var), 8); \
  (var) += (((tgx) & 7) == 3 && (tgx) > 8) * __shfl_up((var), 4); \
  (var) += (((tgx) & 3) == 1 && (tgx) > 4) * __shfl_up((var), 2); \
  (var) += (((tgx) & 1) == 0 && (tgx) >= 2) * __shfl_up((var), 1); \
  (result) = __shfl((var), 63); \
  (var) = __shfl_up((var), 1); \
  if ((tgx) == 0) { \
    (var) = 0; \
  } \
}
157#endif // STORMM_USE_HIP
158
159#endif