Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit a4a56bd

Browse files
authored
Merge pull request #464 from JuliaGPU/tb/shfl
Improvements to shfl
2 parents d9d9da9 + c95ad5e commit a4a56bd

File tree

2 files changed

+109
-67
lines changed

2 files changed

+109
-67
lines changed

src/device/cuda/warp_shuffle.jl

Lines changed: 60 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
# TODO: does not work on sub-word types (e.g. Int16) or types whose size is not word-divisible
44

5-
# TODO: should shfl_idx conform to 1-based indexing?
6-
75
# TODO: these functions should dispatch based on the actual warp size
86
const ws = Int32(32)
97

@@ -14,52 +12,48 @@ const ws = Int32(32)
1412

1513
# "two packed values specifying a mask for logically splitting warps into sub-segments
1614
# and an upper bound for clamping the source lane index"
17-
@inline pack(width::UInt32, mask::UInt32)::UInt32 = (convert(UInt32, ws - width) << 8) | mask
15+
@inline pack(width, mask) = (convert(UInt32, ws - width) << 8) | convert(UInt32, mask)
1816

1917
# NOTE: CUDA C disagrees with PTX on how shuffles are called
20-
for (name, mode, mask) in (("_up", :up, UInt32(0x00)),
21-
("_down", :down, UInt32(0x1f)),
22-
("_xor", :bfly, UInt32(0x1f)),
23-
("", :idx, UInt32(0x1f)))
18+
for (name, mode, mask, offset) in (("_up", :up, UInt32(0x00), src->src),
19+
("_down", :down, UInt32(0x1f), src->src),
20+
("_xor", :bfly, UInt32(0x1f), src->src),
21+
("", :idx, UInt32(0x1f), src->:($src-1)))
2422
fname = Symbol("shfl$name")
23+
@eval export $fname
2524

2625
if cuda_driver_version >= v"9.0" && v"6.0" in ptx_support
27-
instruction = Symbol("shfl.sync.$mode.b32")
28-
fname_sync = Symbol("$(fname)_sync")
29-
30-
# TODO: implement using LLVM intrinsics when we have D38090
26+
# newer hardware/CUDA versions use synchronizing intrinsics, which take an extra
27+
# mask argument indicating which threads in the warp should be synchronized
28+
intrinsic = "llvm.nvvm.shfl.sync.$mode.i32"
3129

30+
fname_sync = Symbol("$(fname)_sync")
31+
__fname_sync = Symbol("__$(fname)_sync")
3232
@eval begin
33-
export $fname_sync, $fname
34-
35-
@inline $fname_sync(val::UInt32, src::UInt32, width::UInt32=$ws,
36-
threadmask::UInt32=0xffffffff) =
37-
@asmcall($"$instruction \$0, \$1, \$2, \$3, \$4;", "=r,r,r,r,r", true,
38-
UInt32, NTuple{4,UInt32},
39-
val, src, pack(width, $mask), threadmask)
40-
41-
# FIXME: replace this with a checked conversion once we have exceptions
42-
@inline $fname_sync(val::UInt32, src::Integer, width::Integer=$ws,
43-
threadmask::UInt32=0xffffffff) =
44-
$fname_sync(val, unsafe_trunc(UInt32, src), unsafe_trunc(UInt32, width),
45-
threadmask)
46-
47-
@inline $fname(val::UInt32, src::Integer, width::Integer=$ws) =
48-
$fname_sync(val, src, width)
33+
export $fname_sync
34+
35+
# HACK: recurse_value_invocation and friends split the first argument of a call,
36+
# so swap mask and val for these tools to work.
37+
@inline $fname_sync(mask, val, src, width=$ws) =
38+
$__fname_sync(val, mask, src, width)
39+
@inline $__fname_sync(val::UInt32, mask, src, width) =
40+
ccall($intrinsic, llvmcall, UInt32,
41+
(UInt32, UInt32, UInt32, UInt32),
42+
mask, val, $(offset(:src)), pack(width, $mask))
43+
44+
# for backwards compatibility, have the non-synchronizing intrinsic dispatch
45+
# to the synchronizing one (with a full-warp default value for the mask)
46+
@inline $fname(val::UInt32, src, width=$ws, mask::UInt32=0xffffffff) =
47+
$fname_sync(mask, val, src, width)
4948
end
5049
else
51-
intrinsic = Symbol("llvm.nvvm.shfl.$mode.i32")
50+
intrinsic = "llvm.nvvm.shfl.$mode.i32"
5251

5352
@eval begin
54-
export $fname
55-
@inline $fname(val::UInt32, src::UInt32, width::UInt32=$ws) =
56-
ccall($"$intrinsic", llvmcall, UInt32,
53+
@inline $fname(val::UInt32, src, width=$ws) =
54+
ccall($intrinsic, llvmcall, UInt32,
5755
(UInt32, UInt32, UInt32),
58-
val, src, pack(width, $mask))
59-
60-
# FIXME: replace this with a checked conversion once we have exceptions
61-
@inline $fname(val::UInt32, src::Integer, width::Integer=$ws) =
62-
$fname(val, unsafe_trunc(UInt32, src), unsafe_trunc(UInt32, width))
56+
val, $(offset(:src)), pack(width, $mask))
6357
end
6458
end
6559
end
@@ -71,62 +65,70 @@ for name in ["_up", "_down", "_xor", ""]
7165
fname = Symbol("shfl$name")
7266
@eval @inline $fname(src, args...) = recurse_value_invocation($fname, src, args...)
7367

74-
fname_sync = Symbol("$(fname)_sync")
75-
@eval @inline $fname_sync(src, args...) = recurse_value_invocation($fname, src, args...)
68+
fname_sync = Symbol("__$(fname)_sync")
69+
@eval @inline $fname_sync(src, args...) = recurse_value_invocation($fname_sync, src, args...)
7670
end
7771

7872

7973
# documentation
8074

8175
@doc """
82-
shfl(val, lane::Integer, width::Integer=32)
76+
shfl(val, lane::Integer, width::Integer=32, threadmask::UInt32=0xffffffff)
8377
84-
Shuffle a value from a directly indexed lane `lane`.
78+
Shuffle a value from a directly indexed lane `lane`. The argument `threadmask` for selecting
79+
which threads to synchronize is only available on recent hardware, and defaults to all
80+
threads in the warp.
8581
""" shfl
8682

8783
@doc """
88-
shfl_up(val, delta::Integer, width::Integer=32)
84+
shfl_up(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff)
8985
90-
Shuffle a value from a lane with lower ID relative to caller.
86+
Shuffle a value from a lane with lower ID relative to caller. The argument `threadmask` for
87+
selecting which threads to synchronize is only available on recent hardware, and defaults to
88+
all threads in the warp.
9189
""" shfl_up
9290

9391
@doc """
94-
shfl_down(val, delta::Integer, width::Integer=32)
92+
shfl_down(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff)
9593
96-
Shuffle a value from a lane with higher ID relative to caller.
94+
Shuffle a value from a lane with higher ID relative to caller. The argument `threadmask` for
95+
selecting which threads to synchronize is only available on recent hardware, and defaults to
96+
all threads in the warp.
9797
""" shfl_down
9898

9999
@doc """
100-
shfl_xor(val, mask::Integer, width::Integer=32)
100+
shfl_xor(val, lanemask::Integer, width::Integer=32, threadmask::UInt32=0xffffffff)
101101
102-
Shuffle a value from a lane based on bitwise XOR of own lane ID with `mask`.
102+
Shuffle a value from a lane based on bitwise XOR of own lane ID with `lanemask`. The
103+
argument `threadmask` for selecting which threads to synchronize is only available on recent
104+
hardware, and defaults to all threads in the warp.
103105
""" shfl_xor
104106

105107

106108
@doc """
107-
shfl_sync(val, lane::Integer, width::Integer=32, threadmask::UInt32=0xffffffff)
109+
shfl_sync(threadmask::UInt32, val, lane::Integer, width::Integer=32)
108110
109-
Shuffle a value from a directly indexed lane `lane`. The default value for `threadmask`
110-
performs the shuffle on all threads in the warp.
111+
Shuffle a value from a directly indexed lane `lane`, and synchronize threads according to
112+
`threadmask`.
111113
""" shfl_sync
112114

113115
@doc """
114-
shfl_up_sync(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff)
116+
shfl_up_sync(threadmask::UInt32, val, delta::Integer, width::Integer=32)
115117
116-
Shuffle a value from a lane with lower ID relative to caller. The default value for
117-
`threadmask` performs the shuffle on all threads in the warp.
118+
Shuffle a value from a lane with lower ID relative to caller, and synchronize threads
119+
according to `threadmask`.
118120
""" shfl_up_sync
119121

120122
@doc """
121-
shfl_down_sync(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff)
123+
shfl_down_sync(threadmask::UInt32, val, delta::Integer, width::Integer=32)
122124
123-
Shuffle a value from a lane with higher ID relative to caller. The default value for
124-
`threadmask` performs the shuffle on all threads in the warp.
125+
Shuffle a value from a lane with higher ID relative to caller, and synchronize threads
126+
according to `threadmask`.
125127
""" shfl_down_sync
126128

127129
@doc """
128-
shfl_xor_sync(val, mask::Integer, width::Integer=32, threadmask::UInt32=0xffffffff)
130+
shfl_xor_sync(threadmask::UInt32, val, mask::Integer, width::Integer=32)
129131
130-
Shuffle a value from a lane based on bitwise XOR of own lane ID with `mask`. The default
131-
value for `threadmask` performs the shuffle on all threads in the warp.
132+
Shuffle a value from a lane based on bitwise XOR of own lane ID with `mask`, and synchronize
133+
threads according to `threadmask`.
132134
""" shfl_xor_sync

test/device/cuda.jl

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,25 @@ end
528528
@testset "data movement and conversion" begin
529529

530530
if capability(dev) >= v"3.0"
531-
@testset "shuffle down" begin
532531

532+
@testset "shuffle idx" begin
533+
function kernel(d)
534+
i = threadIdx().x
535+
j = 32 - i + 1
536+
537+
d[i] = shfl(d[i], j)
538+
539+
return
540+
end
541+
542+
warpsize = CUDAdrv.warpsize(device())
543+
544+
a = CuTestArray([i for i in 1:warpsize])
545+
@cuda threads=warpsize kernel(a)
546+
@test Array(a) == [i for i in warpsize:-1:1]
547+
end
548+
549+
@testset "shuffle down" begin
533550
@eval struct AddableTuple
534551
x::Int32
535552
y::Int64
@@ -539,15 +556,38 @@ if capability(dev) >= v"3.0"
539556

540557
n = 14
541558

542-
@testset for T in [Int32, Int64, Float32, Float64, AddableTuple]
543-
function kernel(d::CuDeviceArray{T}, n) where {T}
544-
t = threadIdx().x
545-
if t <= n
546-
d[t] += shfl_down(d[t], n÷2)
547-
end
548-
return
559+
function kernel1(d::CuDeviceArray{T}, n) where {T}
560+
t = threadIdx().x
561+
if t <= n
562+
d[t] += shfl_down(d[t], n÷2)
563+
end
564+
return
565+
end
566+
567+
function kernel2(d::CuDeviceArray{T}, n) where {T}
568+
t = threadIdx().x
569+
if t <= n
570+
d[t] += shfl_down(d[t], n÷2, 32, 0xffffffff)
549571
end
572+
return
573+
end
550574

575+
function kernel3(d::CuDeviceArray{T}, n) where {T}
576+
t = threadIdx().x
577+
if t <= n
578+
d[t] += shfl_down_sync(0xffffffff, d[t], n÷2, 32)
579+
end
580+
return
581+
end
582+
583+
kernels = try
584+
getfield(CUDAnative, :shfl_sync)
585+
(kernel1, kernel2, kernel3)
586+
catch
587+
(kernel1,)
588+
end
589+
590+
@testset for T in [Int32, Int64, Float32, Float64, AddableTuple], kernel in kernels
551591
a = T[T(i) for i in 1:n]
552592
d_a = CuArray(a)
553593

@@ -557,8 +597,8 @@ if capability(dev) >= v"3.0"
557597
a[1:n÷2] += a[n÷2+1:end]
558598
@test a == Array(d_a)
559599
end
560-
561600
end
601+
562602
end
563603

564604
end

0 commit comments

Comments
 (0)