Skip to content

Move to MarkovKernels.jl for Gaussians #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MarkovKernels = "202a2b00-fae3-41a1-a054-d2db40c1e3ea"
MatrixEquations = "99c1a7ee-ab34-5fd5-8076-27c950a045f4"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Expand Down
2 changes: 1 addition & 1 deletion src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,6 @@ import ..ProbNumDiffEq: dalton_data_loglik, filtering_data_loglik, fenrir_data_l
export dalton_data_loglik, filtering_data_loglik, fenrir_data_loglik
end

include("precompile.jl")
# include("precompile.jl")

end
2 changes: 1 addition & 1 deletion src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ function OrdinaryDiffEq.alg_cache(
H = factorized_similar(FAC, d, D)
v = similar(Array{uElType}, d)
S = factorized_zeros(FAC, d, d)
measurement = Gaussian(v, S)
measurement = Gaussian{uElType}(v, S)

# Caches
du = is_secondorder_ode ? similar(u.x[2]) : similar(u)
Expand Down
23 changes: 16 additions & 7 deletions src/filtering/markov_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,18 @@ copy!(dst::AffineNormalKernel, src::AffineNormalKernel) = begin
return nothing
end

RecursiveArrayTools.recursivecopy(K::AffineNormalKernel) = copy(K)
RecursiveArrayTools.recursivecopy!(
dst::AffineNormalKernel, src::AffineNormalKernel) = copy!(dst, src)
# Return a new `AffineNormalKernel` built from recursive copies of the fields `A`, `b`
# and `C` of `K`. A `missing` field `b` is passed through as `missing`.
function RecursiveArrayTools.recursivecopy(K::AffineNormalKernel)
    A_copy = RecursiveArrayTools.recursivecopy(K.A)
    b_copy = if ismissing(K.b)
        missing
    else
        RecursiveArrayTools.recursivecopy(K.b)
    end
    C_copy = RecursiveArrayTools.recursivecopy(K.C)
    return AffineNormalKernel(A_copy, b_copy, C_copy)
end
# Recursively copy the fields `A`, `b` and `C` of `src` into the corresponding fields of
# `dst`, in place, and return `dst`.
#
# The out-of-place `recursivecopy` method for `AffineNormalKernel` allows `b` to be
# `missing`, so the same case is handled here: when both kernels have `b === missing`
# there is nothing to copy and the field is skipped. If only one side is `missing`, the
# call is still forwarded and fails as before, since `missing` cannot be written into or
# read from in place.
function RecursiveArrayTools.recursivecopy!(
    dst::AffineNormalKernel, src::AffineNormalKernel)
    RecursiveArrayTools.recursivecopy!(dst.A, src.A)
    if !(ismissing(dst.b) && ismissing(src.b))
        RecursiveArrayTools.recursivecopy!(dst.b, src.b)
    end
    RecursiveArrayTools.recursivecopy!(dst.C, src.C)
    return dst
end

isapprox(K1::AffineNormalKernel, K2::AffineNormalKernel; kwargs...) =
isapprox(K1.A, K2.A; kwargs...) &&
Expand Down Expand Up @@ -257,8 +266,8 @@ function compute_backward_kernel!(
Q = D ÷ d # n_derivatives_dim
_Kout =
AffineNormalKernel(Kout.A.B, reshape_no_alloc(Kout.b, d, Q)', PSDMatrix(Kout.C.R.B))
_x_pred = Gaussian(reshape_no_alloc(xpred.μ, d, Q)', PSDMatrix(xpred.Σ.R.B))
_x = Gaussian(reshape_no_alloc(x.μ, d, Q)', PSDMatrix(x.Σ.R.B))
_x_pred = Gaussian{T}(reshape_no_alloc(xpred.μ, d, Q)', PSDMatrix(xpred.Σ.R.B))
_x = Gaussian{T}(reshape_no_alloc(x.μ, d, Q)', PSDMatrix(x.Σ.R.B))
_K = AffineNormalKernel(K.A.B, reshape_no_alloc(K.b, d, Q)', PSDMatrix(K.C.R.B))
_C_DxD = C_DxD.B
_diffusion =
Expand Down Expand Up @@ -298,11 +307,11 @@ function compute_backward_kernel!(
Kout.b[i:d:end],
PSDMatrix(Kout.C.R.blocks[i]),
)
_xpred = Gaussian(
_xpred = Gaussian{T}(
xpred.μ[i:d:end],
PSDMatrix(xpred.Σ.R.blocks[i]),
)
_x = Gaussian(
_x = Gaussian{T}(
x.μ[i:d:end],
PSDMatrix(x.Σ.R.blocks[i]),
)
Expand Down
14 changes: 8 additions & 6 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ function update!(
x_out::SRGaussian{T,<:IsometricKroneckerProduct},
x_pred::SRGaussian{T,<:IsometricKroneckerProduct},
measurement::Gaussian{
T,
<:AbstractVector,
<:Union{<:PSDMatrix{T,<:IsometricKroneckerProduct},<:IsometricKroneckerProduct},
},
Expand All @@ -165,9 +166,9 @@ function update!(
D = length(x_out.μ) # full_state_dim
d = H.rdim # ode_dimension_dim
Q = D ÷ d # n_derivatives_dim
_x_out = Gaussian(reshape_no_alloc(x_out.μ, d, Q)', PSDMatrix(x_out.Σ.R.B))
_x_pred = Gaussian(reshape_no_alloc(x_pred.μ, d, Q)', PSDMatrix(x_pred.Σ.R.B))
_measurement = Gaussian(
_x_out = Gaussian{T}(reshape_no_alloc(x_out.μ, d, Q)', PSDMatrix(x_out.Σ.R.B))
_x_pred = Gaussian{T}(reshape_no_alloc(x_pred.μ, d, Q)', PSDMatrix(x_pred.Σ.R.B))
_measurement = Gaussian{T}(
reshape_no_alloc(measurement.μ, d, 1)',
measurement.Σ isa PSDMatrix ? PSDMatrix(measurement.Σ.R.B) : measurement.Σ.B,
)
Expand Down Expand Up @@ -197,6 +198,7 @@ function update!(
x_out::SRGaussian{T,<:BlocksOfDiagonals},
x_pred::SRGaussian{T,<:BlocksOfDiagonals},
measurement::Gaussian{
T,
<:AbstractVector,
<:Union{<:PSDMatrix{T,<:BlocksOfDiagonals},<:BlocksOfDiagonals},
},
Expand All @@ -214,11 +216,11 @@ function update!(
ll = zero(eltype(x_out.μ))
@views for i in eachindex(blocks(x_out.Σ.R))
_, _ll = update!(
Gaussian(x_out.μ[i:d:end],
Gaussian{T}(x_out.μ[i:d:end],
PSDMatrix(x_out.Σ.R.blocks[i])),
Gaussian(x_pred.μ[i:d:end],
Gaussian{T}(x_pred.μ[i:d:end],
PSDMatrix(x_pred.Σ.R.blocks[i])),
Gaussian(measurement.μ[i:d:end],
Gaussian{T}(measurement.μ[i:d:end],
if measurement.Σ isa PSDMatrix
PSDMatrix(measurement.Σ.R.blocks[i])
else
Expand Down
125 changes: 6 additions & 119 deletions src/gaussians.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,131 +2,18 @@
# `Gaussian` distributions
# Based on @mschauer's GaussianDistributions.jl
############################################################################################
"""
Gaussian(μ, Σ) -> P
import MarkovKernels: Normal
const Gaussian = Normal

Gaussian distribution with mean `μ` and covariance `Σ`. Defines `rand(P)` and `(log-)pdf(P, x)`.
Designed to work with `Number`s, `UniformScaling`s, `StaticArrays` and `PSD`-matrices.

Implementation details: On `Σ` the functions `logdet`, `whiten` and `unwhiten`
(or `cholesky` as fallback for the latter two) are called.
"""
struct Gaussian{T,S}
μ::T  # mean, of type `T`
Σ::S  # covariance, of type `S`
# Inner constructor: both type parameters are taken from the argument types.
Gaussian(μ::T, Σ::S) where {T,S} = new{T,S}(μ, Σ)
end
Base.convert(::Type{Gaussian{T,S}}, g::Gaussian) where {T,S} =
Gaussian(convert(T, g.μ), convert(S, g.Σ))

# Base
Base.:(==)(g1::Gaussian, g2::Gaussian) = g1.μ == g2.μ && g1.Σ == g2.Σ
Base.isapprox(g1::Gaussian, g2::Gaussian; kwargs...) =
isapprox(g1.μ, g2.μ; kwargs...) && isapprox(g1.Σ, g2.Σ; kwargs...)
copy(P::Gaussian) = Gaussian(copy(P.μ), copy(P.Σ))
similar(P::Gaussian) = Gaussian(similar(P.μ), similar(P.Σ))
Base.copyto!(P::AbstractArray{<:Gaussian}, idx::Integer, el::Gaussian) =
(P[idx] = copy(el); P)
# Overwrite the mean and the covariance of `dst` in place with those of `src`, then
# return `dst`.
Base.copy!(dst::Gaussian, src::Gaussian) =
    (copy!(dst.μ, src.μ); copy!(dst.Σ, src.Σ); dst)
length(P::Gaussian) = length(mean(P))
size(g::Gaussian) = size(g.μ)
eltype(::Type{G}) where {G<:Gaussian} = G
Base.@propagate_inbounds Base.getindex(P::Gaussian, i::Integer) =
Gaussian(P.μ[i], diag(P.Σ)[i])

# Statistics
mean(P::Gaussian) = P.μ
cov(P::Gaussian) = P.Σ
var(P::Gaussian{<:Number}) = P.Σ
std(P::Gaussian{<:Number}) = sqrt(var(P))
var(g::Gaussian) = diag(g.Σ)
std(g::Gaussian) = sqrt.(diag(g.Σ))

dim(P::Gaussian) = length(P.μ)

# `whiten(Σ, z)`: left-divide `z` by a square-root factor of `Σ`.
# whiten(Σ::PSD, z) = Σ.σ\z
# Generic case: the factor is the adjoint of the upper Cholesky factor of `Σ`.
whiten(Σ, z) = cholesky(Σ).U' \ z
# Scalar `Σ`: the factor is `sqrt(Σ)`.
whiten(Σ::Number, z) = sqrt(Σ) \ z
# `UniformScaling` `Σ`: the factor is `sqrt(Σ.λ)`.
whiten(Σ::UniformScaling, z) = sqrt(Σ.λ) \ z

# `unwhiten(Σ, z)`: multiply `z` by the same factor that `whiten` divides by.
# unwhiten(Σ::PSD, z) = Σ.σ*z
unwhiten(Σ, z) = cholesky(Σ).U' * z
unwhiten(Σ::Number, z) = sqrt(Σ) * z
unwhiten(Σ::UniformScaling, z) = sqrt(Σ.λ) * z

sqmahal(P::Gaussian, x) = norm_sqr(whiten(P.Σ, x - P.μ))

rand(P::Gaussian) = rand(GLOBAL_RNG, P)
rand(RNG::AbstractRNG, P::Gaussian) = P.μ + unwhiten(P.Σ, randn(RNG, typeof(P.μ)))
rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}) where {T} =
P.μ + unwhiten(P.Σ, randn(RNG, T, length(P.μ)))
rand(RNG::AbstractRNG, P::Gaussian{<:Number}) =
P.μ + sqrt(P.Σ) * randn(RNG, typeof(one(P.μ)))

# Log-determinant helper; the dimension `d` is only used by the `UniformScaling` method.
_logdet(Σ, d) = logdet(Σ)
# For a `UniformScaling` with factor `λ`, the result is `d * log(λ)`.
_logdet(J::UniformScaling, d) = log(J.λ) * d
# Log-density at `x`: -(sqmahal(P, x) + log-determinant of Σ + dim(P) * log(2π)) / 2.
logpdf(P::Gaussian, x) = -(sqmahal(P, x) + _logdet(P.Σ, dim(P)) + dim(P) * log(2pi)) / 2
# Density at `x`, obtained by exponentiating `logpdf`.
pdf(P::Gaussian, x) = exp(logpdf(P::Gaussian, x))

Base.:+(g::Gaussian, vec) = Gaussian(g.μ + vec, g.Σ)
Base.:+(vec, g::Gaussian) = g + vec
Base.:-(g::Gaussian, vec) = g + (-vec)
Base.:*(M, g::Gaussian) = Gaussian(M * g.μ, X_A_Xt(g.Σ, M))

# Draw one sample `rand(RNG, P)` per entry of an array created by `zeros(T, dims)`,
# where `T` is the first type parameter of `P`, and return the filled array.
function rand_scalar(RNG::AbstractRNG, P::Gaussian{T}, dims) where {T}
    X = zeros(T, dims)
    # `eachindex` visits every entry of `X` without assuming how it is indexed.
    for i in eachindex(X)
        X[i] = rand(RNG, P)
    end
    return X
end

# Draw `prod(dims)` samples from `P` and return them in an array of size
# `(dim(P), dims...)`, one sample per slice along the first dimension.
#
# The samples are written through a two-dimensional reshape of the output. Indexing the
# N-dimensional array directly as `X[:, i]` with `i` running up to `prod(dims)` is out of
# bounds as soon as `dims` has more than one entry, because omitted trailing indices are
# only allowed for trailing dimensions of length one.
function rand_vector(
    RNG::AbstractRNG,
    P::Gaussian{Vector{T}},
    dims::Union{Integer,NTuple},
) where {T}
    X = zeros(T, dim(P), dims...)
    # `reshape` shares memory with `X`, so filling the columns fills `X` itself.
    Xcols = reshape(X, dim(P), prod(dims))
    for j in axes(Xcols, 2)
        Xcols[:, j] = rand(RNG, P)
    end
    return X
end
rand(RNG::AbstractRNG, P::Gaussian, dim::Integer) = rand_scalar(RNG, P, dim)
rand(RNG::AbstractRNG, P::Gaussian, dims::Tuple{Vararg{Int64,N}} where {N}) =
rand_scalar(RNG, P, dims)

rand(RNG::AbstractRNG, P::Gaussian{Vector{T}}, dim::Integer) where {T} =
rand_vector(RNG, P, dim)
rand(
RNG::AbstractRNG,
P::Gaussian{Vector{T}},
dims::Tuple{Vararg{Int64,N}} where {N},
) where {T} = rand_vector(RNG, P, dims)
rand(P::Gaussian, dims::Tuple{Vararg{Int64,N}} where {N}) = rand(GLOBAL_RNG, P, dims)
rand(P::Gaussian, dim::Integer) = rand(GLOBAL_RNG, P, dim)

# RecursiveArrayTools
RecursiveArrayTools.recursivecopy(P::Gaussian) = copy(P)
RecursiveArrayTools.recursivecopy!(dst::Gaussian, src::Gaussian) = copy!(dst, src)

# Print
show(io::IO, g::Gaussian) = print(io, "Gaussian(μ=$(g.μ), Σ=$(g.Σ))")
show(io::IO, ::MIME"text/plain", g::Gaussian{T,S}) where {T,S} = begin
println(io, "Gaussian{$T,$S}(")
println(io, " μ=$(g.μ),")
println(io, " Σ=$(g.Σ)")
print(io, ")")
# Build a `Normal` from a vector mean and a `PSDMatrix` covariance, converting both to
# the promoted element type of the two arguments.
function Normal(μ::AbstractVector, Σ::PSDMatrix)
    elT = promote_type(eltype(μ), eltype(Σ))
    mean_converted = convert(AbstractVector{elT}, μ)
    cov_converted = convert(PSDMatrix{elT}, Σ)
    return Normal{elT}(mean_converted, cov_converted)
end

############################################################################################
# `SRGaussian`: Gaussians with PSDMatrix covariances
############################################################################################
const SRGaussian{T,S} = Gaussian{VM,PSDMatrix{T,S}} where {VM<:AbstractVecOrMat{T}}
const SRGaussian{T,S} = Gaussian{A,VM,PSDMatrix{T,S}} where {A,VM<:AbstractVecOrMat{T}}
function _gaussian_mul!(g_out::SRGaussian, M::AbstractMatrix, g_in::SRGaussian)
_matmul!(g_out.μ, M, g_in.μ)
fast_X_A_Xt!(g_out.Σ, g_in.Σ, M)
Expand Down
4 changes: 2 additions & 2 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ function (interp::ODEFilterPosterior)(
t, interp.ts, interp.x_filt, interp.x_smooth, interp.diffusions, interp.cache;
smoothed=interp.smooth)
u = proj * x
P = zeros(Bool, length(idxs), length(u))
P = zeros(Bool, length(idxs), length(u))
for (i, idx) in enumerate(idxs)
P[i, idx] = 1
end
Expand Down Expand Up @@ -362,7 +362,7 @@ function interpolate(
return goal_pred
end

@assert length(x_filt) == length(x_smooth)
@assert length(x_filt) == length(x_smooth)
next_t = t[idx+1]
next_smoothed = x_smooth[idx+1]

Expand Down
Loading