3
3
using Printf
4
4
using Logging
5
5
using TimerOutputs
6
+ using DataStructures
6
7
7
8
include (" pool/utils.jl" )
8
9
using . PoolUtils
64
65
const usage_limit = PerDevice {Int} () do dev
65
66
if haskey (ENV , " JULIA_CUDA_MEMORY_LIMIT" )
66
67
parse (Int, ENV [" JULIA_CUDA_MEMORY_LIMIT" ])
67
- elseif haskey (ENV , " CUARRAYS_MEMORY_LIMIT" )
68
- Base. depwarn (" The CUARRAYS_MEMORY_LIMIT environment flag is deprecated, please use JULIA_CUDA_MEMORY_LIMIT instead." , :__init_pool__ )
69
- parse (Int, ENV [" CUARRAYS_MEMORY_LIMIT" ])
70
68
else
71
69
typemax (Int)
72
70
end
@@ -116,7 +114,8 @@ function hard_limit(dev::CuDevice)
116
114
usage_limit[dev]
117
115
end
118
116
119
- function actual_alloc (dev:: CuDevice , bytes:: Integer , last_resort:: Bool = false )
117
+ function actual_alloc (dev:: CuDevice , bytes:: Integer , last_resort:: Bool = false ;
118
+ stream_ordered:: Bool = false )
120
119
buf = @device! dev begin
121
120
# check the memory allocation limit
122
121
if usage[dev][] + bytes > (last_resort ? hard_limit (dev) : soft_limit (dev))
@@ -127,7 +126,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
127
126
try
128
127
time = Base. @elapsed begin
129
128
@timeit_debug alloc_to " alloc" begin
130
- buf = Mem. alloc (Mem. Device, bytes; async= true )
129
+ buf = Mem. alloc (Mem. Device, bytes; async= true , stream_ordered )
131
130
end
132
131
end
133
132
@@ -146,7 +145,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
146
145
return Block (buf, bytes; state= AVAILABLE)
147
146
end
148
147
149
- function actual_free (dev:: CuDevice , block:: Block )
148
+ function actual_free (dev:: CuDevice , block:: Block ; stream_ordered :: Bool = false )
150
149
@assert iswhole (block) " Cannot free $block : block is not whole"
151
150
@assert block. off == 0
152
151
@assert block. state == AVAILABLE " Cannot free $block : block is not available"
@@ -155,7 +154,7 @@ function actual_free(dev::CuDevice, block::Block)
155
154
# free the memory
156
155
@timeit_debug alloc_to " free" begin
157
156
time = Base. @elapsed begin
158
- Mem. free (block. buf; async= true )
157
+ Mem. free (block. buf; async= true , stream_ordered )
159
158
end
160
159
block. state = INVALID
161
160
@@ -181,41 +180,49 @@ Show the timings of the currently active memory pool. Assumes
181
180
pool_timings () = (show (PoolUtils. to; allocations= false , sortby= :name ); println ())
182
181
183
182
# pool API:
184
- # - init()
185
- # - alloc(::CuDevice , sz)::Block
186
- # - free(::CuDevice , ::Block)
187
- # - reclaim(::CuDevice , nb::Int=typemax(Int))::Int
188
- # - cached_memory()
183
+ # - constructor taking a CuDevice
184
+ # - alloc(::AbstractPool , sz)::Block
185
+ # - free(::AbstractPool , ::Block)
186
+ # - reclaim(::AbstractPool , nb::Int=typemax(Int))::Int
187
+ # - cached_memory(::AbstractPool )
189
188
190
189
module Pool
191
190
@enum MemoryPool None Simple Binned Split
192
191
end
193
- const active_pool = Ref {Pool.MemoryPool} ()
194
- const async_alloc = Ref {Bool} ()
195
-
196
- macro pooled (ex)
197
- @assert Meta. isexpr (ex, :call )
198
- f, args... = ex. args
199
- quote
200
- if active_pool[] == Pool. None
201
- NoPool.$ (f)($ (map (esc, args)... ))
202
- elseif active_pool[] == Pool. Simple
203
- SimplePool.$ (f)($ (map (esc, args)... ))
204
- elseif active_pool[] == Pool. Binned
205
- BinnedPool.$ (f)($ (map (esc, args)... ))
206
- elseif active_pool[] == Pool. Split
207
- SplitPool.$ (f)($ (map (esc, args)... ))
208
- else
209
- error (" unreachable" )
210
- end
211
- end
212
- end
213
192
193
+ abstract type AbstractPool end
214
194
include (" pool/none.jl" )
215
195
include (" pool/simple.jl" )
216
196
include (" pool/binned.jl" )
217
197
include (" pool/split.jl" )
218
198
199
+ const pools = PerDevice {AbstractPool} (dev-> begin
200
+ default_pool = if version () >= v " 11.2" &&
201
+ attribute (dev, CUDA. DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED) == 1
202
+ " cuda"
203
+ else
204
+ " binned"
205
+ end
206
+ pool_name = get (ENV , " JULIA_CUDA_MEMORY_POOL" , default_pool)
207
+ pool = if pool_name == " none"
208
+ NoPool (; dev, stream_ordered= false )
209
+ elseif pool_name == " simple"
210
+ SimplePool (; dev, stream_ordered= false )
211
+ elseif pool_name == " binned"
212
+ BinnedPool (; dev, stream_ordered= false )
213
+ elseif pool_name == " split"
214
+ SplitPool (; dev, stream_ordered= false )
215
+ elseif pool_name == " cuda"
216
+ @assert version () >= v " 11.2" " The CUDA memory pool is only supported on CUDA 11.2+"
217
+ @assert (attribute (dev, CUDA. DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED) == 1 ,
218
+ " Your device $(name (dev)) does not support the CUDA memory pool" )
219
+ NoPool (; dev, stream_ordered= true )
220
+ else
221
+ error (" Invalid memory pool '$pool_name '" )
222
+ end
223
+ pool
224
+ end )
225
+
219
226
220
227
# # interface
221
228
@@ -263,11 +270,11 @@ a [`OutOfGPUMemoryError`](@ref) if the allocation request cannot be satisfied.
263
270
sz == 0 && return CU_NULL
264
271
265
272
dev = device ()
273
+ pool = pools[dev]
266
274
267
275
time = Base. @elapsed begin
268
- @pool_timeit " pooled alloc" block = @pooled alloc (dev , sz)
276
+ @pool_timeit " pooled alloc" block = alloc (pool , sz):: Union{Nothing,Block}
269
277
end
270
- block:: Union{Nothing,Block}
271
278
block === nothing && throw (OutOfGPUMemoryError (sz))
272
279
273
280
# record the memory block
@@ -328,6 +335,7 @@ Releases a buffer pointed to by `ptr` to the memory pool.
328
335
ptr == CU_NULL && return
329
336
330
337
dev = device ()
338
+ pool = pools[dev]
331
339
last_use[dev] = time ()
332
340
333
341
if MEMDEBUG && ptr == CuPtr {Cvoid} (0xbbbbbbbbbbbbbbbb )
@@ -359,7 +367,7 @@ Releases a buffer pointed to by `ptr` to the memory pool.
359
367
end
360
368
361
369
time = Base. @elapsed begin
362
- @pool_timeit " pooled free" @pooled free (dev , block)
370
+ @pool_timeit " pooled free" free (pool , block)
363
371
end
364
372
365
373
alloc_stats. pool_time += time
@@ -382,7 +390,8 @@ actually reclaimed.
382
390
"""
383
391
function reclaim (sz:: Int = typemax (Int))
384
392
dev = device ()
385
- @pooled reclaim (dev, sz)
393
+ pool = pools[dev]
394
+ reclaim (pool, sz)
386
395
end
387
396
388
397
"""
@@ -403,6 +412,9 @@ macro retry_reclaim(isfailed, ex)
403
412
ret = $ (esc (ex))
404
413
$ (esc (isfailed))(ret) || break
405
414
415
+ dev = device ()
416
+ pool = pools[dev]
417
+
406
418
# incrementally more costly reclaim of cached memory
407
419
if phase == 1
408
420
reclaim ()
@@ -412,11 +424,10 @@ macro retry_reclaim(isfailed, ex)
412
424
elseif phase == 3
413
425
GC. gc (true )
414
426
reclaim ()
415
- elseif phase == 4 && async_alloc[]
427
+ elseif phase == 4 && pool . stream_ordered
416
428
# this phase is unique to retry_reclaim, as regular allocations come from the pool
417
429
# so are assumed to never need to trim its contents.
418
- pool = memory_pool (device ())
419
- trim (pool)
430
+ trim (memory_pool (device ()))
420
431
end
421
432
end
422
433
ret
@@ -445,7 +456,8 @@ function pool_cleanup()
445
456
446
457
if t1- t0 > 300
447
458
# the pool hasn't been used for a while, so reclaim unused buffers
448
- @pooled reclaim (dev)
459
+ pool = pools[dev]
460
+ reclaim (pool)
449
461
end
450
462
end
451
463
@@ -561,7 +573,10 @@ macro timed(ex)
561
573
end
562
574
end
563
575
564
- cached_memory () = @pooled cached_memory ()
576
+ function cached_memory (dev:: CuDevice = device ())
577
+ pool = pools[dev]
578
+ cached_memory (pool)
579
+ end
565
580
566
581
"""
567
582
memory_status([io=stdout])
@@ -584,10 +599,11 @@ function memory_status(io::IO=stdout)
584
599
end
585
600
println (io)
586
601
587
- alloc_used_bytes = used_memory ()
588
- alloc_cached_bytes = cached_memory ()
602
+ pool = pools[dev]
603
+ alloc_used_bytes = used_memory (dev)
604
+ alloc_cached_bytes = cached_memory (pool)
589
605
alloc_total_bytes = alloc_used_bytes + alloc_cached_bytes
590
- @printf (io, " Memory pool '%s' usage: %s (%s allocated, %s cached)\n " , string (active_pool[] ),
606
+ @printf (io, " Memory pool '%s' usage: %s (%s allocated, %s cached)\n " , string (pool ),
591
607
Base. format_bytes (alloc_total_bytes), Base. format_bytes (alloc_used_bytes),
592
608
Base. format_bytes (alloc_cached_bytes))
593
609
@@ -627,24 +643,8 @@ function __init_pool__()
627
643
initialize! (allocated, ndevices ())
628
644
initialize! (requested, ndevices ())
629
645
630
- # memory pool configuration
631
- default_pool = version () >= v " 11.2" ? " cuda" : " binned"
632
- pool_name = get (ENV , " JULIA_CUDA_MEMORY_POOL" , default_pool)
633
- active_pool[], async_alloc[] = if pool_name == " none"
634
- Pool. None, false
635
- elseif pool_name == " simple"
636
- Pool. Simple, false
637
- elseif pool_name == " binned"
638
- Pool. Binned, false
639
- elseif pool_name == " split"
640
- Pool. Split, false
641
- elseif pool_name == " cuda"
642
- @assert version () >= v " 11.2" " The CUDA memory pool is only supported on CUDA 11.2+"
643
- Pool. None, true
644
- else
645
- error (" Invalid memory pool '$pool_name '" )
646
- end
647
- @pooled init ()
646
+ # memory pools
647
+ initialize! (pools, ndevices ())
648
648
649
649
TimerOutputs. reset_timer! (alloc_to)
650
650
TimerOutputs. reset_timer! (PoolUtils. to)
@@ -660,6 +660,6 @@ function __init_pool__()
660
660
end
661
661
662
662
if isinteractive ()
663
- @async @pooled pool_cleanup ()
663
+ @async pool_cleanup ()
664
664
end
665
665
end
0 commit comments