Skip to content

Commit 6f0bd7d

Browse files
committed
update CUDA support to latest version
1 parent da466f7 commit 6f0bd7d

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

src/TensorOperations.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ function __init__()
148148
const CuArray = CUDA.CuArray
149149
const CublasFloat = CUDA.CUBLAS.CublasFloat
150150
const CublasReal = CUDA.CUBLAS.CublasReal
151-
for s in (:handle, :CuDefaultStream, :CuTensorDescriptor, :cudaDataType,
151+
for s in (:handle, :CuTensorDescriptor, :cudaDataType,
152152
:cutensorContractionDescriptor_t, :cutensorContractionFind_t,
153153
:cutensorContractionPlan_t,
154154
:CUTENSOR_OP_IDENTITY, :CUTENSOR_OP_CONJ, :CUTENSOR_OP_ADD,
@@ -160,6 +160,11 @@ function __init__()
160160
:cutensorInitContractionPlan, :cutensorContraction)
161161
eval(:(const $s = CUDA.CUTENSOR.$s))
162162
end
163+
if isdefined(CUDA, :default_stream)
164+
const default_stream = CUDA.default_stream
165+
else
166+
const default_stream = CUDA.CuDefaultStream
167+
end
163168
include("implementation/cuarray.jl")
164169
@nospecialize
165170
include("indexnotation/cutensormacros.jl")

src/implementation/cuarray.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function add!(α, A::CuArray{<:Any, N}, CA::Symbol,
3333
typeCompute = convert(cudaDataType, T)
3434
modeA = collect(Cint, 1:N)
3535
modeC = collect(Cint, indCinA)
36-
stream = CuDefaultStream()
36+
stream = default_stream()
3737
if β == zero(β)
3838
cutensorPermutation(handle(), T[α], A, descA, modeA, C, descC, modeC,
3939
typeCompute, stream)
@@ -71,7 +71,7 @@ end
7171
# typeCompute = cudaDataType(T)
7272
# modeA = collect(Cint, 1:N)
7373
# modeC = collect(Cint, indCinA)
74-
# stream = CuDefaultStream()
74+
# stream = default_stream()
7575
# cutensorElementwiseBinary(handle(), T[real(α)], A, descA, modeA, T[1], Cr, descCr,
7676
# modeC, Cr, descCr, modeC, opAC, typeCompute, stream)
7777
# if imag(α) != 0
@@ -110,7 +110,7 @@ function trace!(α, A::CuArray, CA::Symbol, β, C::CuArray,
110110
typeCompute = cutensorComputeType(T)
111111
modeA = collect(Cint, 1:NA)
112112
modeC = collect(Cint, 1:NC)
113-
stream = CuDefaultStream()
113+
stream = default_stream()
114114
function workspacesize()
115115
out = Ref{UInt64}(C_NULL)
116116
cutensorReductionGetWorkspace(handle(),
@@ -207,7 +207,7 @@ function contract!(α, A::CuArray, CA::Symbol,
207207
modeC = collect(Cint, indCinoAB)
208208

209209
algo = CUTENSOR_ALGO_DEFAULT
210-
stream = CuDefaultStream()
210+
stream = default_stream()
211211
pref = CUTENSOR_WORKSPACE_RECOMMENDED
212212

213213
alignmentRequirementA = Ref{UInt32}(C_NULL)

0 commit comments

Comments
 (0)