Skip to content

Commit f54637c

Browse files
Jutho HaegemanJutho Haegeman
authored andcommitted
updates to sparsearray
1 parent 41ce067 commit f54637c

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

src/implementation/sparsearray.jl

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,54 @@ struct SparseArray{T,N} <: AbstractArray{T,N}
1313
new{T,N}(copy(A.data), A.dims)
1414
end
1515
end
16+
SparseArray{T}(::UndefInitializer, dims...) where {T} = SparseArray{T}(undef, dims)
17+
18+
nonzeros(A::SparseArray{<:Any,1}) = (first(k)=>v for (k,v) in A.data)
19+
nonzeros(A::SparseArray) = (CartesianIndex(k)=>v for (k,v) in A.data)
20+
21+
function SparseArray(A::Adjoint{T,<:SparseArray{T,2}}) where T
22+
B = SparseArray{T}(undef, size(A))
23+
for (I, v) in nonzeros(parent(A))
24+
B[I[2], I[1]] = v
25+
end
26+
return B
27+
end
1628

1729
memsize(A::SparseArray) = memsize(A.data)
1830
@inline function Base.getindex(A::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
1931
@boundscheck checkbounds(A, I...)
2032
return get(A.data, I, zero(T))
2133
end
34+
35+
_newindex(i::Int, range::Int) = i == range ? () : nothing
36+
function _newindex(i::Int, range::AbstractVector{Int})
37+
k = findfirst(==(i), range)
38+
k === nothing ? nothing : (k,)
39+
end
40+
_newindices(I::Tuple{}, indices::Tuple{}) = ()
41+
function _newindices(I::Tuple, indices::Tuple)
42+
i = _newindex(I[1], indices[1])
43+
Itail = _newindices(Base.tail(I), Base.tail(indices))
44+
(i === nothing || Itail === nothing) && return nothing
45+
return (i..., Itail...)
46+
end
47+
48+
_findfirstvalue(v, r) = findfirst(==(v), r)
49+
# slicing should produce SparseArray
50+
function Base._unsafe_getindex(::IndexCartesian, A::SparseArray{T,N},
51+
I::Vararg{<:Union{Int,AbstractVector{Int}},N}) where {T,N}
52+
@boundscheck checkbounds(A, I...)
53+
indices = Base.to_indices(A, I)
54+
B = SparseArray{T}(undef, length.(Base.index_shape(indices...)))
55+
for (k, v) in A.data
56+
newI = _newindices(k, indices)
57+
if newI !== nothing
58+
B[newI...] = v
59+
end
60+
end
61+
return B
62+
end
63+
2264
@inline function Base.setindex!(A::SparseArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
2365
@boundscheck checkbounds(A, I...)
2466
if v != zero(v)
@@ -59,32 +101,32 @@ Base.similar(A::SparseArray, ::Type{S}, dims::Dims{N}) where {S,N} =
59101

60102
# Vector space functions
61103
#------------------------
62-
function LinearAlgebra.lmul!(a::Number, d::SparseArray)
63-
lmul!(a, d.data.vals)
104+
function LinearAlgebra.lmul!(a::Number, x::SparseArray)
105+
lmul!(a, x.data.vals)
64106
# typical occupation in a dict is about 30% from experimental testing
65107
# the benefits of scaling all values (e.g. SIMD) largely outweight the extra work
66-
return d
108+
return x
67109
end
68-
function LinearAlgebra.rmul!(d::SparseArray, a::Number)
69-
rmul!(d.data.vals, a)
70-
return d
110+
function LinearAlgebra.rmul!(x::SparseArray, a::Number)
111+
rmul!(x.data.vals, a)
112+
return x
71113
end
72114
function LinearAlgebra.axpby!(α, x::SparseArray, β, y::SparseArray)
73115
lmul!(y, β)
74-
for (k, v) in x
116+
for (k, v) in nonzeros(x)
75117
y[k] += α*v
76118
end
77119
return y
78120
end
79121
function LinearAlgebra.axpy!(α, x::SparseArray, y::SparseArray)
80-
for (k, v) in x
122+
for (k, v) in nonzeros(x)
81123
y[k] += α*v
82124
end
83125
return y
84126
end
85127

86128
function LinearAlgebra.norm(x::SparseArray, p::Real = 2)
87-
norm(Base.Generator(last, A.data), p)
129+
norm(Base.Generator(last, x.data), p)
88130
end
89131

90132
function LinearAlgebra.dot(x::SparseArray, y::SparseArray)

0 commit comments

Comments
 (0)