Skip to content

Commit 41ce067

Browse files
committed
some further sparsearray improvements (thanks @maartenvd)
1 parent b6f645f commit 41ce067

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

src/implementation/sparsearray.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ struct SparseArray{T,N} <: AbstractArray{T,N}
1414
end
1515
end
1616

17+
memsize(A::SparseArray) = memsize(A.data)
1718
@inline function Base.getindex(A::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
1819
@boundscheck checkbounds(A, I...)
1920
return get(A.data, I, zero(T))
@@ -28,25 +29,44 @@ end
2829
return v
2930
end
3031

32+
Array(a::SparseArray{T,N}) where {T,N} = Array{T,N}(a)
33+
function Array{T,N}(a::SparseArray) where {T,N}
34+
d = fill(zero(T), size(a))
35+
for (k,v) in a.data
36+
d[k...] = v
37+
end
38+
d
39+
end
40+
41+
SparseArray(a::AbstractArray{T,N}) where {T,N} = SparseArray{T,N}(a)
42+
function SparseArray{T,N}(a::AbstractArray) where {T,N}
43+
d = SparseArray{T}(undef, size(a))
44+
for I in CartesianIndices(a)
45+
a[I] == zero(T) && continue
46+
d[I] = a[I]
47+
end
48+
return d
49+
end
50+
3151
Base.copy(A::SparseArray) = SparseArray(A)
3252

3353
Base.size(A::SparseArray) = A.dims
3454

3555
Base.similar(A::SparseArray, ::Type{S}, dims::Dims{N}) where {S,N} =
36-
SparseArray{N,S}(Dict{NTuple{N,Int64},S}(), dims)
56+
SparseArray{S,N}(undef, dims)
3757

3858
# TODO: Basic arithmitic
3959

4060
# Vector space functions
4161
#------------------------
4262
function LinearAlgebra.lmul!(a::Number, d::SparseArray)
43-
lmul!(a, d.vals)
63+
lmul!(a, d.data.vals)
4464
# typical occupation in a dict is about 30% from experimental testing
4565
# the benefits of scaling all values (e.g. SIMD) largely outweight the extra work
4666
return d
4767
end
4868
function LinearAlgebra.rmul!(d::SparseArray, a::Number)
49-
rmul!(d.vals, a)
69+
rmul!(d.data.vals, a)
5070
return d
5171
end
5272
function LinearAlgebra.axpby!(α, x::SparseArray, β, y::SparseArray)
@@ -172,11 +192,11 @@ function contract!(α, A::SparseArray, CA::Symbol, B::SparseArray, CB::Symbol,
172192

173193
kBo = TupleTools.getindices(kB, oindB)
174194

175-
kABo = (kAo..., kB...)
195+
kABo = (kAo..., kBo...)
176196

177197
kC = TupleTools.getindices(kABo, indCinoAB)
178198

179-
C[kC...] += α * (conjA == :C ? conj(vA) : vA) * (conjB == :C ? conj(vB) : vB)
199+
C[kC...] += α * (CA == :C ? conj(vA) : vA) * (CB == :C ? conj(vB) : vB)
180200
end
181201
end
182202
C

0 commit comments

Comments
 (0)