@@ -10,10 +10,15 @@ function tensorcontract!(C::AbstractArray,
10
10
argcheck_tensorcontract (C, A, pA, B, pB, pAB)
11
11
dimcheck_tensorcontract (C, A, pA, B, pB, pAB)
12
12
13
- _diagtensorcontract! (StridedView (C),
14
- StridedView (A), pA, conjA,
15
- StridedView (B. diag), pB, conjB,
16
- pAB, α, β)
13
+ if conjA && conjB
14
+ _diagtensorcontract! (SV (C), conj (SV (A)), pA, conj (SV (B. diag)), pB, pAB, α, β)
15
+ elseif conjA
16
+ _diagtensorcontract! (SV (C), conj (SV (A)), pA, SV (B. diag), pB, pAB, α, β)
17
+ elseif conjB
18
+ _diagtensorcontract! (SV (C), SV (A), pA, conj (SV (B. diag)), pB, pAB, α, β)
19
+ else
20
+ _diagtensorcontract! (SV (C), SV (A), pA, SV (B. diag), pB, pAB, α, β)
21
+ end
17
22
return C
18
23
end
19
24
@@ -35,10 +40,15 @@ function tensorcontract!(C::AbstractArray,
35
40
rpAB = (TupleTools. getindices (indCinoBA, tpAB[1 ]),
36
41
TupleTools. getindices (indCinoBA, tpAB[2 ]))
37
42
38
- _diagtensorcontract! (StridedView (C),
39
- StridedView (B), rpB, conjB,
40
- StridedView (A. diag), rpA, conjA,
41
- rpAB, α, β)
43
+ if conjA && conjB
44
+ _diagtensorcontract! (SV (C), conj (SV (B)), rpB, conj (SV (A. diag)), rpA, rpAB, α, β)
45
+ elseif conjA
46
+ _diagtensorcontract! (SV (C), SV (B), rpB, conj (SV (A. diag)), rpA, rpAB, α, β)
47
+ elseif conjB
48
+ _diagtensorcontract! (SV (C), conj (SV (B)), rpB, SV (A. diag), rpA, rpAB, α, β)
49
+ else
50
+ _diagtensorcontract! (SV (C), SV (B), rpB, SV (A. diag), rpA, rpAB, α, β)
51
+ end
42
52
return C
43
53
end
44
54
@@ -50,40 +60,16 @@ function tensorcontract!(C::AbstractArray,
50
60
:: StridedNative , allocator= DefaultAllocator ())
51
61
argcheck_tensorcontract (C, A, pA, B, pB, pAB)
52
62
dimcheck_tensorcontract (C, A, pA, B, pB, pAB)
53
- if numin (pA) == 1 # matrix multiplication
54
- scale! (C, β)
55
- β = one (β)
56
-
57
- A2 = sreshape (flag2op (conjA)(StridedView (A. diag)), (length (A. diag), 1 ))
58
- B2 = sreshape (flag2op (conjB)(StridedView (B. diag)), (length (B. diag), 1 ))
59
- # take a view of the diagonal elements of C, having strides 1 + length(diag)
60
- totsize = (length (A. diag),)
61
- C2 = StridedView (C, totsize, (sum (strides (C)),))
62
-
63
- elseif numin (pA) == 2 # trace
64
- A2 = flag2op (conjA)(StridedView (A. diag, (length (A. diag),)))
65
- B2 = flag2op (conjB)(StridedView (B. diag, (length (B. diag),)))
66
- totsize = (length (A. diag),)
67
- C2 = sreshape (StridedView (C), (1 ,))
68
-
69
- else # outer product
70
- scale! (C, β)
71
- β = one (β)
72
63
73
- A2 = sreshape ( StridedView (A . diag), ( length (A . diag), 1 ))
74
- B2 = sreshape ( StridedView (B . diag), ( 1 , length (A . diag)))
75
-
76
- C3 = permutedims ( StridedView (C), invperm ( linearize (pAB)) )
77
- strC = strides (C3)
78
- newstrides = (strC[ 1 ] + strC[ 2 ], strC[ 3 ] + strC[ 4 ] )
79
- totsize = ( length (A2), length (B2))
80
- C2 = StridedView (C3 . parent, totsize, newstrides, C3 . offset, C3 . op )
64
+ if conjA && conjB
65
+ _diagdiagcontract! ( SV (C), conj ( SV (A . diag)), pA, conj ( SV (B . diag)), pB, pAB, α, β )
66
+ elseif conjA
67
+ _diagdiagcontract! ( SV (C), conj ( SV (A . diag)), pA, SV (B . diag), pB, pAB, α, β )
68
+ elseif conjB
69
+ _diagdiagcontract! ( SV (C), SV (A . diag), pA, conj ( SV (B . diag)), pB, pAB, α, β )
70
+ else
71
+ _diagdiagcontract! ( SV (C), SV (A . diag), pA, SV (B . diag), pB, pAB, α, β )
81
72
end
82
-
83
- op1 = Base. Fix2 (scale, α) ∘ *
84
- op2 = Base. Fix2 (scale, β)
85
- Strided. _mapreducedim! (op1, + , op2, totsize, (C2, A2, B2))
86
-
87
73
return C
88
74
end
89
75
@@ -96,41 +82,49 @@ function tensorcontract!(C::Diagonal,
96
82
argcheck_tensorcontract (C, A, pA, B, pB, pAB)
97
83
dimcheck_tensorcontract (C, A, pA, B, pB, pAB)
98
84
99
- A2 = flag2op (conjA)( StridedView (A. diag) )
100
- B2 = flag2op (conjB)( StridedView (B. diag) )
85
+ A2 = StridedView (A. diag)
86
+ B2 = StridedView (B. diag)
101
87
C2 = StridedView (C. diag)
102
88
103
- C2 .= C2 .* β .+ A2 .* B2 .* α
89
+ if conjA && conjB
90
+ C2 .= C2 .* β .+ conj .(A2 .* B2) .* α
91
+ elseif conjA
92
+ C2 .= C2 .* β .+ conj .(A2) .* B2 .* α
93
+ elseif conjB
94
+ C2 .= C2 .* β .+ A2 .* conj .(B2) .* α
95
+ else
96
+ C2 .= C2 .* β .+ A2 .* B2 .* α
97
+ end
104
98
return C
105
99
end
106
100
107
101
function _diagtensorcontract! (C:: StridedView ,
108
- A:: StridedView , pA:: Index2Tuple , conjA :: Bool ,
109
- Bdiag:: StridedView , pB:: Index2Tuple , conjB :: Bool ,
102
+ A:: StridedView , pA:: Index2Tuple ,
103
+ Bdiag:: StridedView , pB:: Index2Tuple ,
110
104
pAB:: Index2Tuple , α:: Number , β:: Number )
111
105
sizeA = i -> size (A, i)
112
106
csizeA = sizeA .(pA[2 ])
113
107
osizeA = sizeA .(pA[1 ])
114
108
115
109
if numin (pB) == 1 # => numin(A) == numout(B) == 1
116
110
totsize = (osizeA... , csizeA... )
117
- A2 = flag2op (conjA)( permutedims (A, linearize (pA) ))
118
- B2 = flag2op (conjB)( sreshape (Bdiag, ((one .(osizeA)). .. , csizeA... ) ))
111
+ A2 = permutedims (A, linearize (pA))
112
+ B2 = sreshape (Bdiag, ((one .(osizeA)). .. , csizeA... ))
119
113
C2 = permutedims (C, invperm (linearize (pAB)))
120
114
121
115
elseif numin (pB) == 0
122
116
strideA = i -> stride (A, i)
123
117
newstrides = (strideA .(pA[1 ])... , strideA (pA[2 ][1 ]) + strideA (pA[2 ][2 ]))
124
118
totsize = (osizeA... , csizeA[1 ])
125
- A2 = flag2op (conjA)( StridedView (A. parent, totsize, newstrides, A. offset, A. op) )
126
- B2 = flag2op (conjB)( sreshape (Bdiag, ((one .(osizeA)). .. , csizeA[1 ]) ))
119
+ A2 = StridedView (A. parent, totsize, newstrides, A. offset, A. op)
120
+ B2 = sreshape (Bdiag, ((one .(osizeA)). .. , csizeA[1 ]))
127
121
C2 = permutedims (C, invperm (linearize (pAB)))
128
122
129
123
else # numout(pB) == 2 # direct product
130
124
scale! (C, β)
131
125
β = one (β)
132
- A2 = flag2op (conjA)( sreshape (permutedims (A, linearize (pA)), (osizeA... , 1 ) ))
133
- B2 = flag2op (conjB)( sreshape (Bdiag, ((one .(osizeA)). .. , length (Bdiag) )))
126
+ A2 = sreshape (permutedims (A, linearize (pA)), (osizeA... , 1 ))
127
+ B2 = sreshape (Bdiag, ((one .(osizeA)). .. , length (Bdiag)))
134
128
135
129
C3 = permutedims (C, invperm (linearize (pAB)))
136
130
sC = strides (C3)
@@ -145,3 +139,44 @@ function _diagtensorcontract!(C::StridedView,
145
139
146
140
return C
147
141
end
142
+
143
+ function _diagdiagcontract! (C:: StridedView ,
144
+ Adiag:: StridedView , pA:: Index2Tuple ,
145
+ Bdiag:: StridedView , pB:: Index2Tuple ,
146
+ pAB:: Index2Tuple , α:: Number , β:: Number )
147
+ if numin (pA) == 1 # matrix multiplication
148
+ scale! (C, β)
149
+ β = one (β)
150
+
151
+ A2 = sreshape (Adiag, (length (Adiag), 1 ))
152
+ B2 = sreshape (Bdiag, (length (Bdiag), 1 ))
153
+ # take a view of the diagonal elements of C, having strides 1 + length(diag)
154
+ totsize = (length (Adiag),)
155
+ C2 = StridedView (C. parent, totsize, (sum (strides (C)),))
156
+
157
+ elseif numin (pA) == 2 # trace
158
+ A2 = Adiag
159
+ B2 = Bdiag
160
+ totsize = (length (Adiag),)
161
+ C2 = sreshape (C, (1 ,))
162
+
163
+ else # outer product
164
+ scale! (C, β)
165
+ β = one (β)
166
+
167
+ A2 = sreshape (Adiag, (length (Adiag), 1 ))
168
+ B2 = sreshape (Bdiag, (1 , length (Adiag)))
169
+
170
+ C3 = permutedims (C, invperm (linearize (pAB)))
171
+ strC = strides (C3)
172
+ newstrides = (strC[1 ] + strC[2 ], strC[3 ] + strC[4 ])
173
+ totsize = (length (A2), length (B2))
174
+ C2 = StridedView (C3. parent, totsize, newstrides, C3. offset, C3. op)
175
+ end
176
+
177
+ op1 = Base. Fix2 (scale, α) ∘ *
178
+ op2 = Base. Fix2 (scale, β)
179
+ Strided. _mapreducedim! (op1, + , op2, totsize, (C2, A2, B2))
180
+
181
+ return C
182
+ end
0 commit comments