2
2
3
3
# TODO : does not work on sub-word (e.g. Int16) or non-word divisible sized types
4
4
5
- # TODO : should shfl_idx conform to 1-based indexing?
6
-
7
5
# TODO : these functions should dispatch based on the actual warp size
8
6
const ws = Int32 (32 )
9
7
@@ -14,52 +12,48 @@ const ws = Int32(32)
14
12
15
13
# "two packed values specifying a mask for logically splitting warps into sub-segments
16
14
# and an upper bound for clamping the source lane index"
17
- @inline pack (width:: UInt32 , mask:: UInt32 ) :: UInt32 = (convert (UInt32, ws - width) << 8 ) | mask
15
+ @inline pack (width, mask) = (convert (UInt32, ws - width) << 8 ) | convert (UInt32, mask)
18
16
19
17
# NOTE: CUDA C disagrees with PTX on how shuffles are called
20
- for (name, mode, mask) in ((" _up" , :up , UInt32 (0x00 )),
21
- (" _down" , :down , UInt32 (0x1f )),
22
- (" _xor" , :bfly , UInt32 (0x1f )),
23
- (" " , :idx , UInt32 (0x1f )))
18
+ for (name, mode, mask, offset ) in ((" _up" , :up , UInt32 (0x00 ), src -> src ),
19
+ (" _down" , :down , UInt32 (0x1f ), src -> src ),
20
+ (" _xor" , :bfly , UInt32 (0x1f ), src -> src ),
21
+ (" " , :idx , UInt32 (0x1f ), src -> :( $ src - 1 )))
24
22
fname = Symbol (" shfl$name " )
23
+ @eval export $ fname
25
24
26
25
if cuda_driver_version >= v " 9.0" && v " 6.0" in ptx_support
27
- instruction = Symbol (" shfl.sync.$mode .b32" )
28
- fname_sync = Symbol (" $(fname) _sync" )
29
-
30
- # TODO : implement using LLVM intrinsics when we have D38090
26
+ # newer hardware/CUDA versions use synchronizing intrinsics, which take an extra
27
+ # mask argument indicating which threads in the warp should be synchronized
28
+ intrinsic = " llvm.nvvm.shfl.sync.$mode .i32"
31
29
30
+ fname_sync = Symbol (" $(fname) _sync" )
31
+ __fname_sync = Symbol (" __$(fname) _sync" )
32
32
@eval begin
33
- export $ fname_sync, $ fname
34
-
35
- @inline $ fname_sync (val:: UInt32 , src:: UInt32 , width:: UInt32 = $ ws,
36
- threadmask:: UInt32 = 0xffffffff ) =
37
- @asmcall ($ " $instruction \$ 0, \$ 1, \$ 2, \$ 3, \$ 4;" , " =r,r,r,r,r" , true ,
38
- UInt32, NTuple{4 ,UInt32},
39
- val, src, pack (width, $ mask), threadmask)
40
-
41
- # FIXME : replace this with a checked conversion once we have exceptions
42
- @inline $ fname_sync (val:: UInt32 , src:: Integer , width:: Integer = $ ws,
43
- threadmask:: UInt32 = 0xffffffff ) =
44
- $ fname_sync (val, unsafe_trunc (UInt32, src), unsafe_trunc (UInt32, width),
45
- threadmask)
46
-
47
- @inline $ fname (val:: UInt32 , src:: Integer , width:: Integer = $ ws) =
48
- $ fname_sync (val, src, width)
33
+ export $ fname_sync
34
+
35
+ # HACK: recurse_value_invocation and friends split the first argument of a call,
36
+ # so swap mask and val for these tools to work.
37
+ @inline $ fname_sync (mask, val, src, width= $ ws) =
38
+ $ __fname_sync (val, mask, src, width)
39
+ @inline $ __fname_sync (val:: UInt32 , mask, src, width) =
40
+ ccall ($ intrinsic, llvmcall, UInt32,
41
+ (UInt32, UInt32, UInt32, UInt32),
42
+ mask, val, $ (offset (:src )), pack (width, $ mask))
43
+
44
+ # for backwards compatibility, have the non-synchronizing intrinsic dispatch
45
+ # to the synchronizing one (with a full-warp default value for the mask)
46
+ @inline $ fname (val:: UInt32 , src, width= $ ws, mask:: UInt32 = 0xffffffff ) =
47
+ $ fname_sync (mask, val, src, width)
49
48
end
50
49
else
51
- intrinsic = Symbol ( " llvm.nvvm.shfl.$mode .i32" )
50
+ intrinsic = " llvm.nvvm.shfl.$mode .i32"
52
51
53
52
@eval begin
54
- export $ fname
55
- @inline $ fname (val:: UInt32 , src:: UInt32 , width:: UInt32 = $ ws) =
56
- ccall ($ " $intrinsic " , llvmcall, UInt32,
53
+ @inline $ fname (val:: UInt32 , src, width= $ ws) =
54
+ ccall ($ intrinsic, llvmcall, UInt32,
57
55
(UInt32, UInt32, UInt32),
58
- val, src, pack (width, $ mask))
59
-
60
- # FIXME : replace this with a checked conversion once we have exceptions
61
- @inline $ fname (val:: UInt32 , src:: Integer , width:: Integer = $ ws) =
62
- $ fname (val, unsafe_trunc (UInt32, src), unsafe_trunc (UInt32, width))
56
+ val, $ (offset (:src )), pack (width, $ mask))
63
57
end
64
58
end
65
59
end
@@ -71,62 +65,70 @@ for name in ["_up", "_down", "_xor", ""]
71
65
fname = Symbol (" shfl$name " )
72
66
@eval @inline $ fname (src, args... ) = recurse_value_invocation ($ fname, src, args... )
73
67
74
- fname_sync = Symbol (" $(fname) _sync" )
75
- @eval @inline $ fname_sync (src, args... ) = recurse_value_invocation ($ fname , src, args... )
68
+ fname_sync = Symbol (" __ $(fname) _sync" )
69
+ @eval @inline $ fname_sync (src, args... ) = recurse_value_invocation ($ fname_sync , src, args... )
76
70
end
77
71
78
72
79
73
# documentation
80
74
81
75
@doc """
82
- shfl(val, lane::Integer, width::Integer=32)
76
+ shfl(val, lane::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
83
77
84
- Shuffle a value from a directly indexed lane `lane`.
78
+ Shuffle a value from a directly indexed lane `lane`. The argument `threadmask` for selecting
79
+ which threads to synchronize is only available on recent hardware, and defaults to all
80
+ threads in the warp.
85
81
""" shfl
86
82
87
83
@doc """
88
- shfl_up(val, delta::Integer, width::Integer=32)
84
+ shfl_up(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
89
85
90
- Shuffle a value from a lane with lower ID relative to caller.
86
+ Shuffle a value from a lane with lower ID relative to caller. The argument `threadmask` for
87
+ selecting which threads to synchronize is only available on recent hardware, and defaults to
88
+ all threads in the warp.
91
89
""" shfl_up
92
90
93
91
@doc """
94
- shfl_down(val, delta::Integer, width::Integer=32)
92
+ shfl_down(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
95
93
96
- Shuffle a value from a lane with higher ID relative to caller.
94
+ Shuffle a value from a lane with higher ID relative to caller. The argument `threadmask` for
95
+ selecting which threads to synchronize is only available on recent hardware, and defaults to
96
+ all threads in the warp.
97
97
""" shfl_down
98
98
99
99
@doc """
100
- shfl_xor(val, mask ::Integer, width::Integer=32)
100
+ shfl_xor(val, lanemask ::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
101
101
102
- Shuffle a value from a lane based on bitwise XOR of own lane ID with `mask`.
102
+ Shuffle a value from a lane based on bitwise XOR of own lane ID with `lanemask`. The
103
+ argument `threadmask` for selecting which threads to synchronize is only available on recent
104
+ hardware, and defaults to all threads in the warp.
103
105
""" shfl_xor
104
106
105
107
106
108
@doc """
107
- shfl_sync(val, lane::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
109
+ shfl_sync(threadmask::UInt32, val, lane::Integer, width::Integer=32)
108
110
109
- Shuffle a value from a directly indexed lane `lane`. The default value for `threadmask`
110
- performs the shuffle on all threads in the warp .
111
+ Shuffle a value from a directly indexed lane `lane`, and synchronize threads according to
112
+ `threadmask` .
111
113
""" shfl_sync
112
114
113
115
@doc """
114
- shfl_up_sync(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
116
+ shfl_up_sync(threadmask::UInt32, val, delta::Integer, width::Integer=32)
115
117
116
- Shuffle a value from a lane with lower ID relative to caller. The default value for
117
- `threadmask` performs the shuffle on all threads in the warp .
118
+ Shuffle a value from a lane with lower ID relative to caller, and synchronize threads
119
+ according to `threadmask`.
118
120
""" shfl_up_sync
119
121
120
122
@doc """
121
- shfl_down_sync(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
123
+ shfl_down_sync(threadmask::UInt32, val, delta::Integer, width::Integer=32)
122
124
123
- Shuffle a value from a lane with higher ID relative to caller. The default value for
124
- `threadmask` performs the shuffle on all threads in the warp .
125
+ Shuffle a value from a lane with higher ID relative to caller, and synchronize threads
126
+ according to `threadmask`.
125
127
""" shfl_down_sync
126
128
127
129
@doc """
128
- shfl_xor_sync(val, mask::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
130
+ shfl_xor_sync(threadmask::UInt32, val, lanemask::Integer, width::Integer=32)
129
131
130
- Shuffle a value from a lane based on bitwise XOR of own lane ID with `mask`. The default
131
- value for `threadmask` performs the shuffle on all threads in the warp .
132
+ Shuffle a value from a lane based on bitwise XOR of own lane ID with `lanemask`, and synchronize
133
+ threads according to `threadmask`.
132
134
""" shfl_xor_sync
0 commit comments