@@ -104,7 +104,7 @@ function _similarstructure_from_indices(T, poA::IndexTuple, poB::IndexTuple,
104
104
return sz
105
105
end
106
106
107
- scalar (C:: AbstractArray ) = ndims (C)== 0 ? C[1 ] : throw (DimensionMismatch ())
107
+ scalar (C:: AbstractArray ) = ndims (C)== 0 ? C[] : throw (DimensionMismatch ())
108
108
109
109
function add! (α, A:: AbstractArray{<:Any, N} , CA:: Symbol ,
110
110
β, C:: AbstractArray{<:Any, N} , indCinA) where {N}
@@ -254,21 +254,18 @@ function contract!(α, A::AbstractArray, CA::Symbol, B::AbstractArray, CB::Symbo
254
254
(ndims (C) == length (indCinoAB) && isperm (indCinoAB)) ||
255
255
throw (IndexError (" invalid permutation of length $(ndims (C)) : $indCinoAB " ))
256
256
257
- sizeA = i -> size (A, i )
258
- sizeB = i -> size (B, i )
259
- sizeC = i -> size (C, i )
257
+ sizeA = size (A)
258
+ sizeB = size (B)
259
+ sizeC = size (C)
260
260
261
- csizeA = sizeA .( cindA)
262
- csizeB = sizeB .( cindB)
263
- osizeA = sizeA .( oindA)
264
- osizeB = sizeB .( oindB)
261
+ csizeA = TupleTools . getindices (sizeA, cindA)
262
+ csizeB = TupleTools . getindices (sizeB, cindB)
263
+ osizeA = TupleTools . getindices (sizeA, oindA)
264
+ osizeB = TupleTools . getindices (sizeB, oindB)
265
265
266
266
csizeA == csizeB ||
267
267
throw (DimensionMismatch (" non-matching sizes in contracted dimensions" ))
268
- sizeAB = let osize = (osizeA... , osizeB... )
269
- i-> osize[i]
270
- end
271
- sizeAB .(indCinoAB) == size (C) ||
268
+ TupleTools. getindices ((osizeA... , osizeB... ), indCinoAB) == size (C) ||
272
269
throw (DimensionMismatch (" non-matching sizes in uncontracted dimensions" ))
273
270
274
271
if use_blas () && TC <: BlasFloat
@@ -306,7 +303,8 @@ function contract!(α, A::AbstractArray, CA::Symbol, B::AbstractArray, CB::Symbo
306
303
if isblascontractable (C, oindAinC, oindBinC, :D )
307
304
C2 = C
308
305
_blas_contract! (α, A2, CA2, B2, CB2, β, C2,
309
- oindA, cindA, oindB, cindB, oindAinC, oindBinC)
306
+ oindA, cindA, oindB, cindB, oindAinC, oindBinC,
307
+ osizeA, csizeA, osizeB, csizeB)
310
308
else
311
309
if syms === nothing
312
310
C2 = similar_from_indices (TC, oindAinC, oindBinC, C, :N )
@@ -315,31 +313,38 @@ function contract!(α, A::AbstractArray, CA::Symbol, B::AbstractArray, CB::Symbo
315
313
end
316
314
_blas_contract! (1 , A2, CA2, B2, CB2, 0 , C2,
317
315
oindA, cindA, oindB, cindB,
318
- _trivtuple (oindA), length (oindA) .+ _trivtuple (oindB))
316
+ _trivtuple (oindA), length (oindA) .+ _trivtuple (oindB),
317
+ osizeA, csizeA, osizeB, csizeB)
318
+
319
319
add! (α, C2, :N , β, C, indCinoAB, ())
320
320
end
321
321
else
322
- _native_contract! (α, A, CA, B, CB, β, C, oindA, cindA, oindB, cindB, indCinoAB)
322
+ _native_contract! (α, A, CA, B, CB, β, C, oindA, cindA, oindB, cindB, indCinoAB,
323
+ osizeA, csizeA, osizeB, csizeB)
323
324
end
324
325
return C
325
326
end
326
327
327
- function isblascontractable (A:: AbstractArray{T,N} , p1:: IndexTuple , p2:: IndexTuple ,
328
- C:: Symbol ) where {T,N}
328
+ function isblascontractable (A:: AbstractArray , p1:: IndexTuple , p2:: IndexTuple ,
329
+ C:: Symbol )
329
330
330
- T <: LinearAlgebra.BlasFloat || return false
331
+ eltype (A) <: LinearAlgebra.BlasFloat || return false
331
332
@unsafe_strided A isblascontractable (A, p1, p2, C)
332
333
end
333
334
334
- function isblascontractable (A:: AbstractStridedView{T,N} , p1:: IndexTuple , p2:: IndexTuple ,
335
- C:: Symbol ) where {T,N}
335
+ function isblascontractable (A:: AbstractStridedView , p1:: IndexTuple , p2:: IndexTuple ,
336
+ C:: Symbol )
336
337
337
- T <: LinearAlgebra.BlasFloat || return false
338
- strideA = i-> stride (A, i)
339
- sizeA = i-> size (A,i)
338
+ eltype (A) <: LinearAlgebra.BlasFloat || return false
339
+ sizeA = size (A)
340
+ stridesA = strides (A)
341
+ sizeA1 = TupleTools. getindices (sizeA, p1)
342
+ sizeA2 = TupleTools. getindices (sizeA, p2)
343
+ stridesA1 = TupleTools. getindices (stridesA, p1)
344
+ stridesA2 = TupleTools. getindices (stridesA, p2)
340
345
341
- canfuse1, d1, s1 = _canfuse (sizeA .(p1), strideA .(p1) )
342
- canfuse2, d2, s2 = _canfuse (sizeA .(p2), strideA .(p2) )
346
+ canfuse1, d1, s1 = _canfuse (sizeA1, stridesA1 )
347
+ canfuse2, d2, s2 = _canfuse (sizeA2, stridesA2 )
343
348
344
349
if C == :D # destination
345
350
return A. op == identity && canfuse1 && canfuse2 && s1 == 1
@@ -369,18 +374,9 @@ function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N}
369
374
end
370
375
_trivtuple (t:: NTuple{N} ) where {N} = ntuple (identity, Val (N))
371
376
372
- function _blas_contract! (α, A:: AbstractArray{T} , CA, B:: AbstractArray{T} , CB,
373
- β, C:: AbstractArray{T} , oindA, cindA, oindB, cindB, oindAinC, oindBinC) where
374
- {T<: LinearAlgebra.BlasFloat }
375
-
376
- sizeA = i-> size (A, i)
377
- sizeB = i-> size (B, i)
378
- sizeC = i-> size (C, i)
379
-
380
- csizeA = sizeA .(cindA)
381
- csizeB = sizeB .(cindB)
382
- osizeA = sizeA .(oindA)
383
- osizeB = sizeB .(oindB)
377
+ function _blas_contract! (α, A:: AbstractArray , CA, B:: AbstractArray , CB,
378
+ β, C:: AbstractArray , oindA, cindA, oindB, cindB, oindAinC, oindBinC,
379
+ osizeA, csizeA, osizeB, csizeB)
384
380
385
381
@unsafe_strided A B C begin
386
382
A2 = sreshape (permutedims (A, (oindA... , cindA... )), (prod (osizeA), prod (csizeA)))
@@ -403,16 +399,8 @@ function _blas_contract!(α, A::AbstractArray{T}, CA, B::AbstractArray{T}, CB,
403
399
end
404
400
405
401
function _native_contract! (α, A:: AbstractArray , CA:: Symbol , B:: AbstractArray , CB:: Symbol ,
406
- β, C:: AbstractArray , oindA, cindA, oindB, cindB, indCinoAB)
407
-
408
- sizeA = i-> size (A, i)
409
- sizeB = i-> size (B, i)
410
- sizeC = i-> size (C, i)
411
-
412
- csizeA = sizeA .(cindA)
413
- csizeB = sizeB .(cindB)
414
- osizeA = sizeA .(oindA)
415
- osizeB = sizeB .(oindB)
402
+ β, C:: AbstractArray , oindA, cindA, oindB, cindB, indCinoAB,
403
+ osizeA, csizeA, osizeB, csizeB)
416
404
417
405
ipC = TupleTools. invperm (indCinoAB)
418
406
if CA == :N
0 commit comments