Skip to content

Commit 6ec21d3

Browse files
authored
Merge pull request #746 from JuliaGPU/tb/pool_per_device
Use a memory pool per device.
2 parents 9fe3015 + afc145a commit 6ec21d3

File tree

16 files changed

+294
-326
lines changed

16 files changed

+294
-326
lines changed

lib/cudadrv/memory.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,12 @@ for access on the CPU.
6464
"""
6565
function alloc(::Type{DeviceBuffer}, bytesize::Integer;
6666
async::Bool=false, stream::CuStream=stream(),
67-
pool::Union{Nothing,CuMemoryPool}=nothing)
67+
pool::Union{Nothing,CuMemoryPool}=nothing,
68+
stream_ordered::Bool=CUDA.version() >= v"11.2")
6869
bytesize == 0 && return DeviceBuffer(CU_NULL, 0)
6970

7071
ptr_ref = Ref{CUDA.CUdeviceptr}()
71-
if CUDA.async_alloc[]
72+
if stream_ordered
7273
if pool !== nothing
7374
CUDA.cuMemAllocFromPoolAsync(ptr_ref, bytesize, pool, stream)
7475
else
@@ -83,10 +84,11 @@ function alloc(::Type{DeviceBuffer}, bytesize::Integer;
8384
end
8485

8586

86-
function free(buf::DeviceBuffer; async::Bool=false, stream::CuStream=stream())
87+
function free(buf::DeviceBuffer; async::Bool=false, stream::CuStream=stream(),
88+
stream_ordered::Bool=CUDA.version() >= v"11.2")
8789
pointer(buf) == CU_NULL && return
8890

89-
if CUDA.async_alloc[]
91+
if stream_ordered
9092
CUDA.cuMemFreeAsync(buf, stream)
9193
async || synchronize(stream)
9294
else

lib/cudadrv/pool.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function memory_pool(dev::CuDevice)
5050
handle_ref = Ref{CUmemoryPool}()
5151
cuDeviceGetMemPool(handle_ref, dev)
5252

53-
ctx = CuCurrentContext()
53+
ctx = CuCurrentContext()::CuContext
5454
CuMemoryPool(handle_ref[], ctx)
5555
end
5656

lib/cudadrv/stream.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function CuStream(; flags::CUstream_flags=STREAM_DEFAULT,
3232
cuStreamCreateWithPriority(handle_ref, flags, priority)
3333
end
3434

35-
ctx = CuCurrentContext()
35+
ctx = CuCurrentContext()::CuContext
3636
obj = CuStream(handle_ref[], ctx)
3737
finalizer(unsafe_destroy!, obj)
3838
return obj

lib/cudnn/CUDNN.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,17 @@ function __init__()
118118
end
119119
end
120120

121-
function log_message(sev, udata, dbg_ptr, cstr)
121+
function log_message(sev, udata, dbg_ptr, ptr)
122122
# "Each line of this message is terminated by \0, and the end of the message is
123123
# terminated by \0\0"
124124
len = 0
125125
while true
126-
if unsafe_load(cstr, len+1) == '\0' && unsafe_load(cstr, len+2)
126+
if unsafe_load(ptr, len+1) == '\0' && unsafe_load(ptr, len+2) == '\0'
127127
break
128128
end
129129
len += 1
130130
end
131-
str = unsafe_string(cstr, len)
131+
str = unsafe_string(ptr, len)
132132
lines = split(str, '\0')
133133
msg = join(str, '\n')
134134

@@ -143,7 +143,7 @@ function __runtime_init__()
143143
# FIXME: this doesn't work, and the mask remains 0 (as observed with cudnnGetCallback)
144144
if isdebug(:init, CUDNN)
145145
callback = @cfunction(log_message, Nothing,
146-
(cudnnSeverity_t, Ptr{Cvoid}, Ptr{cudnnDebug_t}, Cstring))
146+
(cudnnSeverity_t, Ptr{Cvoid}, Ptr{cudnnDebug_t}, Ptr{UInt8}))
147147
cudnnSetCallback(typemax(UInt32), C_NULL, callback)
148148
end
149149
end

src/pool.jl

Lines changed: 63 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Printf
44
using Logging
55
using TimerOutputs
6+
using DataStructures
67

78
include("pool/utils.jl")
89
using .PoolUtils
@@ -64,9 +65,6 @@ end
6465
const usage_limit = PerDevice{Int}() do dev
6566
if haskey(ENV, "JULIA_CUDA_MEMORY_LIMIT")
6667
parse(Int, ENV["JULIA_CUDA_MEMORY_LIMIT"])
67-
elseif haskey(ENV, "CUARRAYS_MEMORY_LIMIT")
68-
Base.depwarn("The CUARRAYS_MEMORY_LIMIT environment flag is deprecated, please use JULIA_CUDA_MEMORY_LIMIT instead.", :__init_pool__)
69-
parse(Int, ENV["CUARRAYS_MEMORY_LIMIT"])
7068
else
7169
typemax(Int)
7270
end
@@ -116,7 +114,8 @@ function hard_limit(dev::CuDevice)
116114
usage_limit[dev]
117115
end
118116

119-
function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
117+
function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false;
118+
stream_ordered::Bool=false)
120119
buf = @device! dev begin
121120
# check the memory allocation limit
122121
if usage[dev][] + bytes > (last_resort ? hard_limit(dev) : soft_limit(dev))
@@ -127,7 +126,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
127126
try
128127
time = Base.@elapsed begin
129128
@timeit_debug alloc_to "alloc" begin
130-
buf = Mem.alloc(Mem.Device, bytes; async=true)
129+
buf = Mem.alloc(Mem.Device, bytes; async=true, stream_ordered)
131130
end
132131
end
133132

@@ -146,7 +145,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
146145
return Block(buf, bytes; state=AVAILABLE)
147146
end
148147

149-
function actual_free(dev::CuDevice, block::Block)
148+
function actual_free(dev::CuDevice, block::Block; stream_ordered::Bool=false)
150149
@assert iswhole(block) "Cannot free $block: block is not whole"
151150
@assert block.off == 0
152151
@assert block.state == AVAILABLE "Cannot free $block: block is not available"
@@ -155,7 +154,7 @@ function actual_free(dev::CuDevice, block::Block)
155154
# free the memory
156155
@timeit_debug alloc_to "free" begin
157156
time = Base.@elapsed begin
158-
Mem.free(block.buf; async=true)
157+
Mem.free(block.buf; async=true, stream_ordered)
159158
end
160159
block.state = INVALID
161160

@@ -181,41 +180,49 @@ Show the timings of the currently active memory pool. Assumes
181180
function pool_timings()
    # dump the pool timer, ordered by section name and without allocation counters
    show(PoolUtils.to; allocations=false, sortby=:name)
    # finish the report with a newline; its result is also what this function returns
    println()
end
182181

183182
# pool API:
184-
# - init()
185-
# - alloc(::CuDevice, sz)::Block
186-
# - free(::CuDevice, ::Block)
187-
# - reclaim(::CuDevice, nb::Int=typemax(Int))::Int
188-
# - cached_memory()
183+
# - constructor taking a CuDevice
184+
# - alloc(::AbstractPool, sz)::Block
185+
# - free(::AbstractPool, ::Block)
186+
# - reclaim(::AbstractPool, nb::Int=typemax(Int))::Int
187+
# - cached_memory(::AbstractPool)
189188

190189
# Namespace for the enumeration of memory pool kinds, so that the member names
# (`None`, `Simple`, ...) do not end up in the enclosing module.
module Pool
    @enum MemoryPool begin
        None
        Simple
        Binned
        Split
    end
end
193-
const active_pool = Ref{Pool.MemoryPool}()
194-
const async_alloc = Ref{Bool}()
195-
196-
macro pooled(ex)
197-
@assert Meta.isexpr(ex, :call)
198-
f, args... = ex.args
199-
quote
200-
if active_pool[] == Pool.None
201-
NoPool.$(f)($(map(esc, args)...))
202-
elseif active_pool[] == Pool.Simple
203-
SimplePool.$(f)($(map(esc, args)...))
204-
elseif active_pool[] == Pool.Binned
205-
BinnedPool.$(f)($(map(esc, args)...))
206-
elseif active_pool[] == Pool.Split
207-
SplitPool.$(f)($(map(esc, args)...))
208-
else
209-
error("unreachable")
210-
end
211-
end
212-
end
213192

193+
# Abstract type used as the element type of the per-device pool table; the pool API
# (`alloc`, `free`, `reclaim`, `cached_memory`) takes a value of this type.
# NOTE(review): the concrete pools are presumably defined in the `pool/*.jl` includes
# that follow -- confirm they subtype `AbstractPool`.
abstract type AbstractPool end
214194
include("pool/none.jl")
215195
include("pool/simple.jl")
216196
include("pool/binned.jl")
217197
include("pool/split.jl")
218198

199+
# Per-device table of memory pools. The initializer below builds the pool for a device.
#
# The pool kind is read from the `JULIA_CUDA_MEMORY_POOL` environment variable, which
# accepts "none", "simple", "binned", "split" or "cuda". When the variable is unset, the
# default is "cuda" if CUDA is at least 11.2 and the device reports
# `DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED`, and "binned" otherwise.
#
# Throws an error for an unknown pool name, or when "cuda" is requested but the CUDA
# version or the device does not support it.
const pools = PerDevice{AbstractPool}() do dev
    # probe support once; the device attribute is only queried on CUDA 11.2+
    has_cuda_11_2 = version() >= v"11.2"
    supports_mempool = has_cuda_11_2 &&
        attribute(dev, CUDA.DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED) == 1

    default_pool = supports_mempool ? "cuda" : "binned"
    pool_name = get(ENV, "JULIA_CUDA_MEMORY_POOL", default_pool)

    if pool_name == "none"
        NoPool(; dev, stream_ordered=false)
    elseif pool_name == "simple"
        SimplePool(; dev, stream_ordered=false)
    elseif pool_name == "binned"
        BinnedPool(; dev, stream_ordered=false)
    elseif pool_name == "split"
        SplitPool(; dev, stream_ordered=false)
    elseif pool_name == "cuda"
        # This validates user configuration, so raise real errors instead of relying on
        # `@assert`, which is a debugging aid and may be disabled.
        has_cuda_11_2 ||
            error("The CUDA memory pool is only supported on CUDA 11.2+")
        supports_mempool ||
            error("Your device $(name(dev)) does not support the CUDA memory pool")
        # the "cuda" pool is `NoPool` with stream-ordered allocation switched on
        NoPool(; dev, stream_ordered=true)
    else
        error("Invalid memory pool '$pool_name'")
    end
end
225+
219226

220227
## interface
221228

@@ -263,11 +270,11 @@ a [`OutOfGPUMemoryError`](@ref) if the allocation request cannot be satisfied.
263270
sz == 0 && return CU_NULL
264271

265272
dev = device()
273+
pool = pools[dev]
266274

267275
time = Base.@elapsed begin
268-
@pool_timeit "pooled alloc" block = @pooled alloc(dev, sz)
276+
@pool_timeit "pooled alloc" block = alloc(pool, sz)::Union{Nothing,Block}
269277
end
270-
block::Union{Nothing,Block}
271278
block === nothing && throw(OutOfGPUMemoryError(sz))
272279

273280
# record the memory block
@@ -328,6 +335,7 @@ Releases a buffer pointed to by `ptr` to the memory pool.
328335
ptr == CU_NULL && return
329336

330337
dev = device()
338+
pool = pools[dev]
331339
last_use[dev] = time()
332340

333341
if MEMDEBUG && ptr == CuPtr{Cvoid}(0xbbbbbbbbbbbbbbbb)
@@ -359,7 +367,7 @@ Releases a buffer pointed to by `ptr` to the memory pool.
359367
end
360368

361369
time = Base.@elapsed begin
362-
@pool_timeit "pooled free" @pooled free(dev, block)
370+
@pool_timeit "pooled free" free(pool, block)
363371
end
364372

365373
alloc_stats.pool_time += time
@@ -382,7 +390,8 @@ actually reclaimed.
382390
"""
383391
function reclaim(sz::Int=typemax(Int))
384392
dev = device()
385-
@pooled reclaim(dev, sz)
393+
pool = pools[dev]
394+
reclaim(pool, sz)
386395
end
387396

388397
"""
@@ -403,6 +412,9 @@ macro retry_reclaim(isfailed, ex)
403412
ret = $(esc(ex))
404413
$(esc(isfailed))(ret) || break
405414

415+
dev = device()
416+
pool = pools[dev]
417+
406418
# incrementally more costly reclaim of cached memory
407419
if phase == 1
408420
reclaim()
@@ -412,11 +424,10 @@ macro retry_reclaim(isfailed, ex)
412424
elseif phase == 3
413425
GC.gc(true)
414426
reclaim()
415-
elseif phase == 4 && async_alloc[]
427+
elseif phase == 4 && pool.stream_ordered
416428
# this phase is unique to retry_reclaim, as regular allocations come from the pool
417429
# so are assumed to never need to trim its contents.
418-
pool = memory_pool(device())
419-
trim(pool)
430+
trim(memory_pool(device()))
420431
end
421432
end
422433
ret
@@ -445,7 +456,8 @@ function pool_cleanup()
445456

446457
if t1-t0 > 300
447458
# the pool hasn't been used for a while, so reclaim unused buffers
448-
@pooled reclaim(dev)
459+
pool = pools[dev]
460+
reclaim(pool)
449461
end
450462
end
451463

@@ -561,7 +573,10 @@ macro timed(ex)
561573
end
562574
end
563575

564-
cached_memory() = @pooled cached_memory()
576+
# Forward to `cached_memory` of the pool registered for `dev`; `dev` defaults to the
# device returned by `device()`.
cached_memory(dev::CuDevice=device()) = cached_memory(pools[dev])
565580

566581
"""
567582
memory_status([io=stdout])
@@ -584,10 +599,11 @@ function memory_status(io::IO=stdout)
584599
end
585600
println(io)
586601

587-
alloc_used_bytes = used_memory()
588-
alloc_cached_bytes = cached_memory()
602+
pool = pools[dev]
603+
alloc_used_bytes = used_memory(dev)
604+
alloc_cached_bytes = cached_memory(pool)
589605
alloc_total_bytes = alloc_used_bytes + alloc_cached_bytes
590-
@printf(io, "Memory pool '%s' usage: %s (%s allocated, %s cached)\n", string(active_pool[]),
606+
@printf(io, "Memory pool '%s' usage: %s (%s allocated, %s cached)\n", string(pool),
591607
Base.format_bytes(alloc_total_bytes), Base.format_bytes(alloc_used_bytes),
592608
Base.format_bytes(alloc_cached_bytes))
593609

@@ -627,24 +643,8 @@ function __init_pool__()
627643
initialize!(allocated, ndevices())
628644
initialize!(requested, ndevices())
629645

630-
# memory pool configuration
631-
default_pool = version() >= v"11.2" ? "cuda" : "binned"
632-
pool_name = get(ENV, "JULIA_CUDA_MEMORY_POOL", default_pool)
633-
active_pool[], async_alloc[] = if pool_name == "none"
634-
Pool.None, false
635-
elseif pool_name == "simple"
636-
Pool.Simple, false
637-
elseif pool_name == "binned"
638-
Pool.Binned, false
639-
elseif pool_name == "split"
640-
Pool.Split, false
641-
elseif pool_name == "cuda"
642-
@assert version() >= v"11.2" "The CUDA memory pool is only supported on CUDA 11.2+"
643-
Pool.None, true
644-
else
645-
error("Invalid memory pool '$pool_name'")
646-
end
647-
@pooled init()
646+
# memory pools
647+
initialize!(pools, ndevices())
648648

649649
TimerOutputs.reset_timer!(alloc_to)
650650
TimerOutputs.reset_timer!(PoolUtils.to)
@@ -660,6 +660,6 @@ function __init_pool__()
660660
end
661661

662662
if isinteractive()
663-
@async @pooled pool_cleanup()
663+
@async pool_cleanup()
664664
end
665665
end

0 commit comments

Comments
 (0)