@@ -14,6 +14,7 @@ struct SparseArray{T,N} <: AbstractArray{T,N}
14
14
end
15
15
end
16
16
17
+ memsize (A:: SparseArray ) = memsize (A. data)
17
18
@inline function Base. getindex (A:: SparseArray{T,N} , I:: Vararg{Int,N} ) where {T,N}
18
19
@boundscheck checkbounds (A, I... )
19
20
return get (A. data, I, zero (T))
28
29
return v
29
30
end
30
31
32
+ Array (a:: SparseArray{T,N} ) where {T,N} = Array {T,N} (a)
33
+ function Array {T,N} (a:: SparseArray ) where {T,N}
34
+ d = fill (zero (T), size (a))
35
+ for (k,v) in a. data
36
+ d[k... ] = v
37
+ end
38
+ d
39
+ end
40
+
41
+ SparseArray (a:: AbstractArray{T,N} ) where {T,N} = SparseArray {T,N} (a)
42
+ function SparseArray {T,N} (a:: AbstractArray ) where {T,N}
43
+ d = SparseArray {T} (undef, size (a))
44
+ for I in CartesianIndices (a)
45
+ a[I] == zero (T) && continue
46
+ d[I] = a[I]
47
+ end
48
+ return d
49
+ end
50
+
31
51
Base. copy (A:: SparseArray ) = SparseArray (A)
32
52
33
53
Base. size (A:: SparseArray ) = A. dims
34
54
35
55
Base. similar (A:: SparseArray , :: Type{S} , dims:: Dims{N} ) where {S,N} =
36
- SparseArray {N,S} ( Dict {NTuple{N,Int64},S} () , dims)
56
+ SparseArray {S,N} (undef , dims)
37
57
38
58
# TODO : Basic arithmitic
39
59
40
60
# Vector space functions
41
61
# ------------------------
42
62
function LinearAlgebra. lmul! (a:: Number , d:: SparseArray )
43
- lmul! (a, d. vals)
63
+ lmul! (a, d. data . vals)
44
64
# typical occupation in a dict is about 30% from experimental testing
45
65
# the benefits of scaling all values (e.g. SIMD) largely outweight the extra work
46
66
return d
47
67
end
48
68
function LinearAlgebra. rmul! (d:: SparseArray , a:: Number )
49
- rmul! (d. vals, a)
69
+ rmul! (d. data . vals, a)
50
70
return d
51
71
end
52
72
function LinearAlgebra. axpby! (α, x:: SparseArray , β, y:: SparseArray )
@@ -172,11 +192,11 @@ function contract!(α, A::SparseArray, CA::Symbol, B::SparseArray, CB::Symbol,
172
192
173
193
kBo = TupleTools. getindices (kB, oindB)
174
194
175
- kABo = (kAo... , kB ... )
195
+ kABo = (kAo... , kBo ... )
176
196
177
197
kC = TupleTools. getindices (kABo, indCinoAB)
178
198
179
- C[kC... ] += α * (conjA == :C ? conj (vA) : vA) * (conjB == :C ? conj (vB) : vB)
199
+ C[kC... ] += α * (CA == :C ? conj (vA) : vA) * (CB == :C ? conj (vB) : vB)
180
200
end
181
201
end
182
202
C
0 commit comments