Skip to content

Commit b6f645f

Browse files
committed
some changes, add sparsearray type
1 parent d60138e commit b6f645f

File tree

4 files changed

+227
-48
lines changed

4 files changed

+227
-48
lines changed

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33

44
[compat]
5-
Documenter = "0.24"
5+
Documenter = "0.25"

src/TensorOperations.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ include("implementation/indices.jl")
5252
include("implementation/tensorcache.jl")
5353
include("implementation/stridedarray.jl")
5454
include("implementation/diagonal.jl")
55+
include("implementation/sparsearray.jl")
5556

5657
# Functions
5758
#-----------
@@ -172,6 +173,9 @@ end
172173
#----------------------------
173174
function _precompile_()
174175
AVector = Vector{Any}
176+
for N = 1:8
177+
@assert precompile(Tuple{typeof(isperm), NTuple{N,Int}})
178+
end
175179
@assert precompile(Tuple{typeof(_intersect), Base.BitArray{1}, Base.BitArray{1}})
176180
@assert precompile(Tuple{typeof(_intersect), Base.BitSet, Base.BitSet})
177181
@assert precompile(Tuple{typeof(_intersect), UInt128, UInt128})
@@ -265,7 +269,11 @@ function _precompile_()
265269
@assert precompile(Tuple{typeof(storeset), Type{UInt64}, Array{Int64, 1}, Int64})
266270
@assert precompile(Tuple{typeof(storeset), Type{UInt64}, Base.Set{Int64}, Int64})
267271
@assert precompile(Tuple{typeof(tensorify), Expr})
272+
@assert precompile(Tuple{typeof(extracttensorobjects), Any})
273+
@assert precompile(Tuple{typeof(_flatten), Expr})
274+
# @assert precompile(Tuple{typeof(processcontractions), Any, Any, Any})
268275
@assert precompile(Tuple{typeof(defaultparser), Expr})
276+
@assert precompile(Tuple{typeof(defaultparser), Any})
269277
@assert precompile(Tuple{typeof(unique2), AVector})
270278
@assert precompile(Tuple{typeof(unique2), Array{Int64, 1}})
271279
@assert precompile(Tuple{typeof(use_blas)})

src/implementation/sparsearray.jl

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
struct SparseArray{T,N} <: AbstractArray{T,N}
2+
data::Dict{NTuple{N,Int64}, T}
3+
dims::NTuple{N,Int64}
4+
function SparseArray{T,N}(::UndefInitializer, dims::NTuple{N,Int}) where {T,N}
5+
data = Dict{NTuple{N,Int64}, T}()
6+
return new{T,N}(data, dims)
7+
end
8+
function SparseArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T,N}
9+
data = Dict{NTuple{N,Int64}, T}()
10+
return new{T,N}(data, dims)
11+
end
12+
function SparseArray(A::SparseArray{T,N}) where {T,N}
13+
new{T,N}(copy(A.data), A.dims)
14+
end
15+
end
16+
17+
@inline function Base.getindex(A::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
18+
@boundscheck checkbounds(A, I...)
19+
return get(A.data, I, zero(T))
20+
end
21+
@inline function Base.setindex!(A::SparseArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
22+
@boundscheck checkbounds(A, I...)
23+
if v != zero(v)
24+
A.data[I] = v
25+
else
26+
delete!(A.data, I) # does not do anything if there was no key I
27+
end
28+
return v
29+
end
30+
31+
Base.copy(A::SparseArray) = SparseArray(A)
32+
33+
Base.size(A::SparseArray) = A.dims
34+
35+
Base.similar(A::SparseArray, ::Type{S}, dims::Dims{N}) where {S,N} =
36+
SparseArray{N,S}(Dict{NTuple{N,Int64},S}(), dims)
37+
38+
# TODO: Basic arithmitic
39+
40+
# Vector space functions
41+
#------------------------
42+
function LinearAlgebra.lmul!(a::Number, d::SparseArray)
43+
lmul!(a, d.vals)
44+
# typical occupation in a dict is about 30% from experimental testing
45+
# the benefits of scaling all values (e.g. SIMD) largely outweight the extra work
46+
return d
47+
end
48+
function LinearAlgebra.rmul!(d::SparseArray, a::Number)
49+
rmul!(d.vals, a)
50+
return d
51+
end
52+
function LinearAlgebra.axpby!(α, x::SparseArray, β, y::SparseArray)
53+
lmul!(y, β)
54+
for (k, v) in x
55+
y[k] += α*v
56+
end
57+
return y
58+
end
59+
function LinearAlgebra.axpy!(α, x::SparseArray, y::SparseArray)
60+
for (k, v) in x
61+
y[k] += α*v
62+
end
63+
return y
64+
end
65+
66+
function LinearAlgebra.norm(x::SparseArray, p::Real = 2)
67+
norm(Base.Generator(last, A.data), p)
68+
end
69+
70+
function LinearAlgebra.dot(x::SparseArray, y::SparseArray)
71+
size(x) == size(y) || throw(DimensionMismatch("dot arguments have different size"))
72+
s = dot(zero(eltype(x)), zero(eltype(y)))
73+
if length(x.data) >= length(y.data)
74+
iter = keys(x.data)
75+
else
76+
iter = keys(y.data)
77+
end
78+
@inbounds for I in iter
79+
s += dot(x[I...], y[I...])
80+
end
81+
return s
82+
end
83+
84+
# TensorOperations compatiblity
85+
#-------------------------------
86+
function add!(α, A::SparseArray{<:Any, N}, CA::Symbol,
87+
β, C::SparseArray{<:Any, N}, indCinA) where {N}
88+
89+
(N == length(indCinA) && TupleTools.isperm(indCinA)) ||
90+
throw(IndexError("Invalid permutation of length $N: $indCinA"))
91+
size(C) == TupleTools.getindices(size(A), indCinA) ||
92+
throw(DimensionMismatch("non-matching sizes while adding arrays"))
93+
94+
β == one(β) || LinearAlgebra.lmul!(β, C);
95+
for (kA, vA) in A.data
96+
kC = TupleTools.getindices(kA, indCinA)
97+
C[kC...] += α* (conjA == :C ? conj(vA) : vA)
98+
end
99+
C
100+
end
101+
102+
function trace!(α, A::SparseArray{<:Any, NA}, CA::Symbol, β, C::SparseArray{<:Any, NC},
103+
indCinA, cindA1, cindA2) where {NA,NC}
104+
105+
NC == length(indCinA) ||
106+
throw(IndexError("Invalid selection of $NC out of $NA: $indCinA"))
107+
NA-NC == 2*length(cindA1) == 2*length(cindA2) ||
108+
throw(IndexError("invalid number of trace dimension"))
109+
pA = (indCinA..., cindA1..., cindA2...)
110+
TupleTools.isperm(pA) ||
111+
throw(IndexError("invalid permutation of length $(ndims(A)): $pA"))
112+
113+
sizeA = size(A)
114+
sizeC = size(C)
115+
116+
TupleTools.getindices(sizeA, cindA1) == TupleTools.getindices(sizeA, cindA2) ||
117+
throw(DimensionMismatch("non-matching trace sizes"))
118+
sizeC == TupleTools.getindices(sizeA, indCinA) ||
119+
throw(DimensionMismatch("non-matching sizes"))
120+
121+
β == one(β) || LinearAlgebra.lmul!(β, C);
122+
123+
for (kA, v) in A.data
124+
kAc1 = TupleTools.getindices(kA, cindA1)
125+
kAc2 = TupleTools.getindices(kA, cindA2)
126+
kAc1 == kAc2 || continue
127+
128+
kC = TupleTools.getindices(kC, indCinA)
129+
C[kC...] += α * (conjA == :C ? conj(v) : v)
130+
end
131+
return C
132+
end
133+
134+
function contract!(α, A::SparseArray, CA::Symbol, B::SparseArray, CB::Symbol,
135+
β, C::SparseArray,
136+
oindA::IndexTuple, cindA::IndexTuple, oindB::IndexTuple, cindB::IndexTuple,
137+
indCinoAB::IndexTuple, syms::Union{Nothing, NTuple{3,Symbol}} = nothing)
138+
139+
pA = (oindA...,cindA...)
140+
(length(pA) == ndims(A) && TupleTools.isperm(pA)) ||
141+
throw(IndexError("invalid permutation of length $(ndims(A)): $pA"))
142+
pB = (oindB...,cindB...)
143+
(length(pB) == ndims(B) && TupleTools.isperm(pB)) ||
144+
throw(IndexError("invalid permutation of length $(ndims(B)): $pB"))
145+
(length(oindA) + length(oindB) == ndims(C)) ||
146+
throw(IndexError("non-matching output indices in contraction"))
147+
(ndims(C) == length(indCinoAB) && isperm(indCinoAB)) ||
148+
throw(IndexError("invalid permutation of length $(ndims(C)): $indCinoAB"))
149+
150+
sizeA = size(A)
151+
sizeB = size(B)
152+
sizeC = size(C)
153+
154+
csizeA = TupleTools.getindices(sizeA, cindA)
155+
csizeB = TupleTools.getindices(sizeB, cindB)
156+
osizeA = TupleTools.getindices(sizeA, oindA)
157+
osizeB = TupleTools.getindices(sizeB, oindB)
158+
159+
csizeA == csizeB ||
160+
throw(DimensionMismatch("non-matching sizes in contracted dimensions"))
161+
TupleTools.getindices((osizeA..., osizeB...), indCinoAB) == size(C) ||
162+
throw(DimensionMismatch("non-matching sizes in uncontracted dimensions"))
163+
164+
β == one(β) || LinearAlgebra.lmul!(β, C);
165+
166+
for (kA, vA) in A.data
167+
kAc = TupleTools.getindices(kA, cindA)
168+
kAo = TupleTools.getindices(kA, oindA)
169+
for (kB, vB) in B.data
170+
kBc = TupleTools.getindices(kB, cindB)
171+
kAc == kBc || continue
172+
173+
kBo = TupleTools.getindices(kB, oindB)
174+
175+
kABo = (kAo..., kB...)
176+
177+
kC = TupleTools.getindices(kABo, indCinoAB)
178+
179+
C[kC...] += α * (conjA == :C ? conj(vA) : vA) * (conjB == :C ? conj(vB) : vB)
180+
end
181+
end
182+
C
183+
end

src/implementation/stridedarray.jl

Lines changed: 35 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function _similarstructure_from_indices(T, poA::IndexTuple, poB::IndexTuple,
104104
return sz
105105
end
106106

107-
scalar(C::AbstractArray) = ndims(C)==0 ? C[1] : throw(DimensionMismatch())
107+
scalar(C::AbstractArray) = ndims(C)==0 ? C[] : throw(DimensionMismatch())
108108

109109
function add!(α, A::AbstractArray{<:Any, N}, CA::Symbol,
110110
β, C::AbstractArray{<:Any, N}, indCinA) where {N}
@@ -254,21 +254,18 @@ function contract!(α, A::AbstractArray, CA::Symbol, B::AbstractArray, CB::Symbo
254254
(ndims(C) == length(indCinoAB) && isperm(indCinoAB)) ||
255255
throw(IndexError("invalid permutation of length $(ndims(C)): $indCinoAB"))
256256

257-
sizeA = i->size(A, i)
258-
sizeB = i->size(B, i)
259-
sizeC = i->size(C, i)
257+
sizeA = size(A)
258+
sizeB = size(B)
259+
sizeC = size(C)
260260

261-
csizeA = sizeA.(cindA)
262-
csizeB = sizeB.(cindB)
263-
osizeA = sizeA.(oindA)
264-
osizeB = sizeB.(oindB)
261+
csizeA = TupleTools.getindices(sizeA, cindA)
262+
csizeB = TupleTools.getindices(sizeB, cindB)
263+
osizeA = TupleTools.getindices(sizeA, oindA)
264+
osizeB = TupleTools.getindices(sizeB, oindB)
265265

266266
csizeA == csizeB ||
267267
throw(DimensionMismatch("non-matching sizes in contracted dimensions"))
268-
sizeAB = let osize = (osizeA..., osizeB...)
269-
i->osize[i]
270-
end
271-
sizeAB.(indCinoAB) == size(C) ||
268+
TupleTools.getindices((osizeA..., osizeB...), indCinoAB) == size(C) ||
272269
throw(DimensionMismatch("non-matching sizes in uncontracted dimensions"))
273270

274271
if use_blas() && TC <: BlasFloat
@@ -306,7 +303,8 @@ function contract!(α, A::AbstractArray, CA::Symbol, B::AbstractArray, CB::Symbo
306303
if isblascontractable(C, oindAinC, oindBinC, :D)
307304
C2 = C
308305
_blas_contract!(α, A2, CA2, B2, CB2, β, C2,
309-
oindA, cindA, oindB, cindB, oindAinC, oindBinC)
306+
oindA, cindA, oindB, cindB, oindAinC, oindBinC,
307+
osizeA, csizeA, osizeB, csizeB)
310308
else
311309
if syms === nothing
312310
C2 = similar_from_indices(TC, oindAinC, oindBinC, C, :N)
@@ -315,31 +313,38 @@ function contract!(α, A::AbstractArray, CA::Symbol, B::AbstractArray, CB::Symbo
315313
end
316314
_blas_contract!(1, A2, CA2, B2, CB2, 0, C2,
317315
oindA, cindA, oindB, cindB,
318-
_trivtuple(oindA), length(oindA) .+ _trivtuple(oindB))
316+
_trivtuple(oindA), length(oindA) .+ _trivtuple(oindB),
317+
osizeA, csizeA, osizeB, csizeB)
318+
319319
add!(α, C2, :N, β, C, indCinoAB, ())
320320
end
321321
else
322-
_native_contract!(α, A, CA, B, CB, β, C, oindA, cindA, oindB, cindB, indCinoAB)
322+
_native_contract!(α, A, CA, B, CB, β, C, oindA, cindA, oindB, cindB, indCinoAB,
323+
osizeA, csizeA, osizeB, csizeB)
323324
end
324325
return C
325326
end
326327

327-
function isblascontractable(A::AbstractArray{T,N}, p1::IndexTuple, p2::IndexTuple,
328-
C::Symbol) where {T,N}
328+
function isblascontractable(A::AbstractArray, p1::IndexTuple, p2::IndexTuple,
329+
C::Symbol)
329330

330-
T <: LinearAlgebra.BlasFloat || return false
331+
eltype(A) <: LinearAlgebra.BlasFloat || return false
331332
@unsafe_strided A isblascontractable(A, p1, p2, C)
332333
end
333334

334-
function isblascontractable(A::AbstractStridedView{T,N}, p1::IndexTuple, p2::IndexTuple,
335-
C::Symbol) where {T,N}
335+
function isblascontractable(A::AbstractStridedView, p1::IndexTuple, p2::IndexTuple,
336+
C::Symbol)
336337

337-
T <: LinearAlgebra.BlasFloat || return false
338-
strideA = i->stride(A, i)
339-
sizeA = i->size(A,i)
338+
eltype(A) <: LinearAlgebra.BlasFloat || return false
339+
sizeA = size(A)
340+
stridesA = strides(A)
341+
sizeA1 = TupleTools.getindices(sizeA, p1)
342+
sizeA2 = TupleTools.getindices(sizeA, p2)
343+
stridesA1 = TupleTools.getindices(stridesA, p1)
344+
stridesA2 = TupleTools.getindices(stridesA, p2)
340345

341-
canfuse1, d1, s1 = _canfuse(sizeA.(p1), strideA.(p1))
342-
canfuse2, d2, s2 = _canfuse(sizeA.(p2), strideA.(p2))
346+
canfuse1, d1, s1 = _canfuse(sizeA1, stridesA1)
347+
canfuse2, d2, s2 = _canfuse(sizeA2, stridesA2)
343348

344349
if C == :D # destination
345350
return A.op == identity && canfuse1 && canfuse2 && s1 == 1
@@ -369,18 +374,9 @@ function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N}
369374
end
370375
_trivtuple(t::NTuple{N}) where {N} = ntuple(identity, Val(N))
371376

372-
function _blas_contract!(α, A::AbstractArray{T}, CA, B::AbstractArray{T}, CB,
373-
β, C::AbstractArray{T}, oindA, cindA, oindB, cindB, oindAinC, oindBinC) where
374-
{T<:LinearAlgebra.BlasFloat}
375-
376-
sizeA = i->size(A, i)
377-
sizeB = i->size(B, i)
378-
sizeC = i->size(C, i)
379-
380-
csizeA = sizeA.(cindA)
381-
csizeB = sizeB.(cindB)
382-
osizeA = sizeA.(oindA)
383-
osizeB = sizeB.(oindB)
377+
function _blas_contract!(α, A::AbstractArray, CA, B::AbstractArray, CB,
378+
β, C::AbstractArray, oindA, cindA, oindB, cindB, oindAinC, oindBinC,
379+
osizeA, csizeA, osizeB, csizeB)
384380

385381
@unsafe_strided A B C begin
386382
A2 = sreshape(permutedims(A, (oindA..., cindA...)), (prod(osizeA), prod(csizeA)))
@@ -403,16 +399,8 @@ function _blas_contract!(α, A::AbstractArray{T}, CA, B::AbstractArray{T}, CB,
403399
end
404400

405401
function _native_contract!(α, A::AbstractArray, CA::Symbol, B::AbstractArray, CB::Symbol,
406-
β, C::AbstractArray, oindA, cindA, oindB, cindB, indCinoAB)
407-
408-
sizeA = i->size(A, i)
409-
sizeB = i->size(B, i)
410-
sizeC = i->size(C, i)
411-
412-
csizeA = sizeA.(cindA)
413-
csizeB = sizeB.(cindB)
414-
osizeA = sizeA.(oindA)
415-
osizeB = sizeB.(oindB)
402+
β, C::AbstractArray, oindA, cindA, oindB, cindB, indCinoAB,
403+
osizeA, csizeA, osizeB, csizeB)
416404

417405
ipC = TupleTools.invperm(indCinoAB)
418406
if CA == :N

0 commit comments

Comments
 (0)