Skip to content

Commit 427403e

Browse files
authored
resolve conjugation flags in strided implementations (#190)
* resolve conjugation flags in strided implementations * more type stability and formatting * add comment and bump version [skip ci]
1 parent c1e37ec commit 427403e

File tree

6 files changed

+174
-139
lines changed

6 files changed

+174
-139
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
33
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
4-
version = "5.0.1"
4+
version = "5.0.2"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/TensorOperations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ include("indexnotation/tensormacros.jl")
5858
include("implementation/functions.jl")
5959
include("implementation/ncon.jl")
6060
include("implementation/abstractarray.jl")
61-
include("implementation/diagonal.jl")
6261
include("implementation/strided.jl")
62+
include("implementation/diagonal.jl")
6363
include("implementation/base.jl")
6464
include("implementation/indices.jl")
6565
include("implementation/allocator.jl")

src/implementation/abstractarray.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,3 @@ function dimcheck_tensorcontract(C::AbstractArray,
160160
throw(DimensionMismatch("non-matching sizes in uncontracted dimensions"))
161161
return nothing
162162
end
163-
164-
#-------------------------------------------------------------------------------------------
165-
# Utility functions
166-
#-------------------------------------------------------------------------------------------
167-
# Translate a conjugation flag into the function to apply to an operand:
# `conj` when `flag` is `true`, `identity` otherwise.
function flag2op(flag::Bool)
    if flag
        return conj
    else
        return identity
    end
end

src/implementation/diagonal.jl

Lines changed: 86 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@ function tensorcontract!(C::AbstractArray,
1010
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
1111
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
1212

13-
_diagtensorcontract!(StridedView(C),
14-
StridedView(A), pA, conjA,
15-
StridedView(B.diag), pB, conjB,
16-
pAB, α, β)
13+
if conjA && conjB
14+
_diagtensorcontract!(SV(C), conj(SV(A)), pA, conj(SV(B.diag)), pB, pAB, α, β)
15+
elseif conjA
16+
_diagtensorcontract!(SV(C), conj(SV(A)), pA, SV(B.diag), pB, pAB, α, β)
17+
elseif conjB
18+
_diagtensorcontract!(SV(C), SV(A), pA, conj(SV(B.diag)), pB, pAB, α, β)
19+
else
20+
_diagtensorcontract!(SV(C), SV(A), pA, SV(B.diag), pB, pAB, α, β)
21+
end
1722
return C
1823
end
1924

@@ -35,10 +40,15 @@ function tensorcontract!(C::AbstractArray,
3540
rpAB = (TupleTools.getindices(indCinoBA, tpAB[1]),
3641
TupleTools.getindices(indCinoBA, tpAB[2]))
3742

38-
_diagtensorcontract!(StridedView(C),
39-
StridedView(B), rpB, conjB,
40-
StridedView(A.diag), rpA, conjA,
41-
rpAB, α, β)
43+
if conjA && conjB
44+
_diagtensorcontract!(SV(C), conj(SV(B)), rpB, conj(SV(A.diag)), rpA, rpAB, α, β)
45+
elseif conjA
46+
_diagtensorcontract!(SV(C), SV(B), rpB, conj(SV(A.diag)), rpA, rpAB, α, β)
47+
elseif conjB
48+
_diagtensorcontract!(SV(C), conj(SV(B)), rpB, SV(A.diag), rpA, rpAB, α, β)
49+
else
50+
_diagtensorcontract!(SV(C), SV(B), rpB, SV(A.diag), rpA, rpAB, α, β)
51+
end
4252
return C
4353
end
4454

@@ -50,40 +60,16 @@ function tensorcontract!(C::AbstractArray,
5060
::StridedNative, allocator=DefaultAllocator())
5161
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
5262
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
53-
if numin(pA) == 1 # matrix multiplication
54-
scale!(C, β)
55-
β = one(β)
56-
57-
A2 = sreshape(flag2op(conjA)(StridedView(A.diag)), (length(A.diag), 1))
58-
B2 = sreshape(flag2op(conjB)(StridedView(B.diag)), (length(B.diag), 1))
59-
# take a view of the diagonal elements of C, having strides 1 + length(diag)
60-
totsize = (length(A.diag),)
61-
C2 = StridedView(C, totsize, (sum(strides(C)),))
62-
63-
elseif numin(pA) == 2 # trace
64-
A2 = flag2op(conjA)(StridedView(A.diag, (length(A.diag),)))
65-
B2 = flag2op(conjB)(StridedView(B.diag, (length(B.diag),)))
66-
totsize = (length(A.diag),)
67-
C2 = sreshape(StridedView(C), (1,))
68-
69-
else # outer product
70-
scale!(C, β)
71-
β = one(β)
7263

73-
A2 = sreshape(StridedView(A.diag), (length(A.diag), 1))
74-
B2 = sreshape(StridedView(B.diag), (1, length(A.diag)))
75-
76-
C3 = permutedims(StridedView(C), invperm(linearize(pAB)))
77-
strC = strides(C3)
78-
newstrides = (strC[1] + strC[2], strC[3] + strC[4])
79-
totsize = (length(A2), length(B2))
80-
C2 = StridedView(C3.parent, totsize, newstrides, C3.offset, C3.op)
64+
if conjA && conjB
65+
_diagdiagcontract!(SV(C), conj(SV(A.diag)), pA, conj(SV(B.diag)), pB, pAB, α, β)
66+
elseif conjA
67+
_diagdiagcontract!(SV(C), conj(SV(A.diag)), pA, SV(B.diag), pB, pAB, α, β)
68+
elseif conjB
69+
_diagdiagcontract!(SV(C), SV(A.diag), pA, conj(SV(B.diag)), pB, pAB, α, β)
70+
else
71+
_diagdiagcontract!(SV(C), SV(A.diag), pA, SV(B.diag), pB, pAB, α, β)
8172
end
82-
83-
op1 = Base.Fix2(scale, α) *
84-
op2 = Base.Fix2(scale, β)
85-
Strided._mapreducedim!(op1, +, op2, totsize, (C2, A2, B2))
86-
8773
return C
8874
end
8975

@@ -96,41 +82,49 @@ function tensorcontract!(C::Diagonal,
9682
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
9783
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
9884

99-
A2 = flag2op(conjA)(StridedView(A.diag))
100-
B2 = flag2op(conjB)(StridedView(B.diag))
85+
A2 = StridedView(A.diag)
86+
B2 = StridedView(B.diag)
10187
C2 = StridedView(C.diag)
10288

103-
C2 .= C2 .* β .+ A2 .* B2 .* α
89+
if conjA && conjB
90+
C2 .= C2 .* β .+ conj.(A2 .* B2) .* α
91+
elseif conjA
92+
C2 .= C2 .* β .+ conj.(A2) .* B2 .* α
93+
elseif conjB
94+
C2 .= C2 .* β .+ A2 .* conj.(B2) .* α
95+
else
96+
C2 .= C2 .* β .+ A2 .* B2 .* α
97+
end
10498
return C
10599
end
106100

107101
function _diagtensorcontract!(C::StridedView,
108-
A::StridedView, pA::Index2Tuple, conjA::Bool,
109-
Bdiag::StridedView, pB::Index2Tuple, conjB::Bool,
102+
A::StridedView, pA::Index2Tuple,
103+
Bdiag::StridedView, pB::Index2Tuple,
110104
pAB::Index2Tuple, α::Number, β::Number)
111105
sizeA = i -> size(A, i)
112106
csizeA = sizeA.(pA[2])
113107
osizeA = sizeA.(pA[1])
114108

115109
if numin(pB) == 1 # => numin(A) == numout(B) == 1
116110
totsize = (osizeA..., csizeA...)
117-
A2 = flag2op(conjA)(permutedims(A, linearize(pA)))
118-
B2 = flag2op(conjB)(sreshape(Bdiag, ((one.(osizeA))..., csizeA...)))
111+
A2 = permutedims(A, linearize(pA))
112+
B2 = sreshape(Bdiag, ((one.(osizeA))..., csizeA...))
119113
C2 = permutedims(C, invperm(linearize(pAB)))
120114

121115
elseif numin(pB) == 0
122116
strideA = i -> stride(A, i)
123117
newstrides = (strideA.(pA[1])..., strideA(pA[2][1]) + strideA(pA[2][2]))
124118
totsize = (osizeA..., csizeA[1])
125-
A2 = flag2op(conjA)(StridedView(A.parent, totsize, newstrides, A.offset, A.op))
126-
B2 = flag2op(conjB)(sreshape(Bdiag, ((one.(osizeA))..., csizeA[1])))
119+
A2 = StridedView(A.parent, totsize, newstrides, A.offset, A.op)
120+
B2 = sreshape(Bdiag, ((one.(osizeA))..., csizeA[1]))
127121
C2 = permutedims(C, invperm(linearize(pAB)))
128122

129123
else # numout(pB) == 2 # direct product
130124
scale!(C, β)
131125
β = one(β)
132-
A2 = flag2op(conjA)(sreshape(permutedims(A, linearize(pA)), (osizeA..., 1)))
133-
B2 = flag2op(conjB)(sreshape(Bdiag, ((one.(osizeA))..., length(Bdiag))))
126+
A2 = sreshape(permutedims(A, linearize(pA)), (osizeA..., 1))
127+
B2 = sreshape(Bdiag, ((one.(osizeA))..., length(Bdiag)))
134128

135129
C3 = permutedims(C, invperm(linearize(pAB)))
136130
sC = strides(C3)
@@ -145,3 +139,44 @@ function _diagtensorcontract!(C::StridedView,
145139

146140
return C
147141
end
142+
143+
"""
    _diagdiagcontract!(C, Adiag, pA, Bdiag, pB, pAB, α, β)

Contract two diagonal operands, given as strided views `Adiag` and `Bdiag` of their
diagonals, into the strided view `C`, and return `C`.

The elementwise products of `Adiag` and `Bdiag` are scaled with `α` and accumulated into
`C`, whose existing contents are scaled with `β`. This method takes no conjugation flags:
any conjugation has to be applied to the views before they are passed in.

The branch is selected by `numin(pA)`:
- `1`: matrix multiplication, only the diagonal of `C` is accumulated into;
- `2`: trace, everything is reduced into the single element of `C`;
- otherwise: outer product, `C` is addressed through a two-dimensional view with merged
  strides.
"""
function _diagdiagcontract!(C::StridedView,
                            Adiag::StridedView, pA::Index2Tuple,
                            Bdiag::StridedView, pB::Index2Tuple,
                            pAB::Index2Tuple, α::Number, β::Number)
    nA = length(Adiag)
    nB = length(Bdiag)

    if numin(pA) == 1 # matrix multiplication
        # Scale all of C up front; the reduction below only touches the diagonal view
        # C2, so β is reset to one to avoid scaling those entries twice.
        scale!(C, β)
        β = one(β)

        A2 = sreshape(Adiag, (nA, 1))
        B2 = sreshape(Bdiag, (nB, 1))
        # View of the diagonal elements of C: one step along the diagonal advances every
        # index at once, hence the single stride is the sum of the strides of C.
        # Offset and op of C are forwarded so that the view addresses the same memory
        # as C itself (same convention as in the outer-product branch).
        totsize = (nA,)
        C2 = StridedView(C.parent, totsize, (sum(strides(C)),), C.offset, C.op)

    elseif numin(pA) == 2 # trace
        A2 = Adiag
        B2 = Bdiag
        totsize = (nA,)
        C2 = sreshape(C, (1,))

    else # outer product
        scale!(C, β)
        β = one(β)

        # Each diagonal gets its own axis; B is reshaped with its own length so that
        # diagonals of different sizes are supported.
        A2 = sreshape(Adiag, (nA, 1))
        B2 = sreshape(Bdiag, (1, nB))

        # Bring C into the natural index order, then merge the strides pairwise to
        # obtain a two-dimensional view with one axis per diagonal.
        C3 = permutedims(C, invperm(linearize(pAB)))
        strC = strides(C3)
        newstrides = (strC[1] + strC[2], strC[3] + strC[4])
        totsize = (nA, nB)
        C2 = StridedView(C3.parent, totsize, newstrides, C3.offset, C3.op)
    end

    # op1 multiplies the operand entries and scales the product with α;
    # op2 scales the existing destination entries with β.
    op1 = Base.Fix2(scale, α) ∘ *
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, totsize, (C2, A2, B2))

    return C
end

0 commit comments

Comments
 (0)