Skip to content

Added OTFlow Layers #963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@

include("ffjord.jl")
include("neural_de.jl")
include("otflow.jl")
include("collocation.jl")
include("multiple_shooting.jl")

export NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, AugmentedNDELayer,
       NeuralODEMM
export FFJORD, FFJORDDistribution
# NOTE(review): `OTFlowDistribution` is exported here but no definition is
# visible in src/otflow.jl — confirm it exists or remove it from this export.
export OTFlow, OTFlowDistribution

export DimMover

export EpanechnikovKernel, UniformKernel, TriangularKernel, QuarticKernel, TriweightKernel,
Expand Down
96 changes: 96 additions & 0 deletions src/otflow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
    OTFlow(d, m; r = min(10, d))

Lux layer for an OT-Flow-style potential model: a scalar potential Φ built
from a two-layer ResNet plus a low-rank quadratic and a linear term. The
struct stores only the layer sizes; all trainable parameters live in the
NamedTuple produced by `Lux.initialparameters`.
"""
struct OTFlow <: AbstractLuxLayer
    d::Int # input (state) dimension
    m::Int # hidden width of the ResNet inside the potential
    r::Int # rank of the low-rank factor A (A is r × (d + 1))
end

# Convenience constructor with a default rank of min(10, d).
# Validates sizes up front so a bad layer fails loudly at construction time
# instead of during parameter initialization.
function OTFlow(d::Int, m::Int; r::Int = min(10, d))
    d > 0 || throw(ArgumentError("input dimension d must be positive, got $d"))
    m > 0 || throw(ArgumentError("hidden dimension m must be positive, got $m"))
    r > 0 || throw(ArgumentError("rank r must be positive, got $r"))
    return OTFlow(d, m, r)
end

# Initialize trainable parameters for an `OTFlow` layer.
#
# Returns a NamedTuple with:
#   w      — m-vector weighting the ResNet output in Φ
#   A      — r×(d+1) low-rank factor (the quadratic term uses A'A)
#   b, c   — linear coefficient (length d+1) and scalar offset of Φ
#   K0, b0 — opening-layer weights/bias, acting on s = [x; t] ∈ ℝ^{d+1}
#   K1, b1 — residual-layer weights/bias (m×m, m)
# Weights use a small Gaussian init (scale 0.01) so the initial potential is
# close to zero; biases start at zero.
function Lux.initialparameters(rng::AbstractRNG, l::OTFlow)
    w = randn(rng, Float32, l.m) .* 0.01f0
    A = randn(rng, Float32, l.r, l.d + 1) .* 0.01f0
    b = zeros(Float32, l.d + 1)
    c = zero(Float32)
    K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01f0
    K1 = randn(rng, Float32, l.m, l.m) .* 0.01f0
    b0 = zeros(Float32, l.m)
    b1 = zeros(Float32, l.m)

    return (; w, A, b, c, K0, K1, b0, b1)
end

# σ(x) = log(exp(x) + exp(-x)) = log(2cosh(x)): the antiderivative of tanh,
# so that σ'(x) = tanh(x) and σ''(x) = 1 - tanh(x)^2.
# Rewritten as |x| + log1p(exp(-2|x|)): mathematically identical, but avoids
# overflow of exp(x) for large |x| (the naive form returns Inf for
# Float32 inputs beyond ≈ 88).
sigma(x) = abs(x) + log1p(exp(-2 * abs(x)))
sigma_prime(x) = tanh(x)
sigma_double_prime(x) = 1 - tanh(x)^2

# Forward pass of the two-layer ResNet N(s) with s = [x; t].
# A single residual step with unit step size (h = 1) is applied.
function resnet_forward(x::AbstractVector, t::Real, ps)
    input = vcat(x, t)
    hidden = sigma.(ps.K0 * input .+ ps.b0)
    return hidden .+ sigma.(ps.K1 * hidden .+ ps.b1)
end

# Scalar potential Φ(x, t) = w'N(s) + ½ s'(A'A)s + b's + c with s = [x; t].
# The quadratic term divides by 2 instead of multiplying by the Float64
# literal 0.5, so Float32 parameters are not silently promoted to Float64.
function potential(x::AbstractVector, t::Real, ps)
    s = vcat(x, t)
    N = resnet_forward(x, t, ps)
    quadratic_term = (s' * (ps.A' * ps.A) * s) / 2
    linear_term = ps.b' * s
    return ps.w' * N + quadratic_term + linear_term + ps.c
end

# Gradient of Φ with respect to x: backpropagate w through the ResNet
# analytically, add the gradient of the quadratic and linear terms, and keep
# only the first d components (the x-part of s = [x; t]).
function gradient(x::AbstractVector, t::Real, ps, d::Int)
    s = vcat(x, t)
    pre0 = ps.K0 * s .+ ps.b0       # opening-layer pre-activation (reused below)
    u0 = sigma.(pre0)
    back1 = ps.w .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w)
    back0 = ps.K0' * (sigma_prime.(pre0) .* back1)
    full_grad = back0 + (ps.A' * ps.A) * s + ps.b
    return full_grad[1:d]
end

# Trace of the Hessian of Φ with respect to x only, computed in closed form
# (no automatic differentiation). The three contributions below come from the
# opening ResNet layer, the residual layer, and the quadratic term.
function trace(x::AbstractVector, t::Real, ps, d::Int)
    s = vcat(x, t)
    u0 = sigma.(ps.K0 * s .+ ps.b0)
    # z1: sensitivity of w'N with respect to the opening-layer output u0.
    z1 = ps.w .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w)

    # Restrict to the x-columns (first d of the d+1 columns of s = [x; t]).
    K0_E = ps.K0[:, 1:d]
    A_E = ps.A[:, 1:d]

    # Opening-layer term: Σ_{i,j} σ''(K0 s + b0)_i · z1_i · K0[i,j]²
    # (the m-vectors broadcast down the columns of the m×d matrix K0_E.^2).
    t0 = sum(sigma_double_prime.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2))
    # J: Jacobian of u0 with respect to x.
    J = Diagonal(sigma_prime.(ps.K0 * s .+ ps.b0)) * K0_E
    # Residual-layer term, propagated through K1.
    t1 = sum(sigma_double_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w .* (ps.K1 * J) .^ 2)
    # Quadratic term contributes tr(A_E' A_E).
    trace_A = tr(A_E' * A_E)

    return t0 + t1 + trace_A
end

# Layer call: given (x, t), return the CNF dynamics (v, tr(∇v)) and the
# unchanged layer state, where v = -∇Φ and tr(∇v) = -tr(∇²Φ).
function (l::OTFlow)(xt::Tuple{AbstractVector, Real}, ps, st)
    x, t = xt
    v = -gradient(x, t, ps, l.d)         # v = -∇Φ
    # Named `trace_val` to avoid shadowing `LinearAlgebra.tr`.
    trace_val = -trace(x, t, ps, l.d)    # tr(∇v) = -tr(∇²Φ)
    return (v, trace_val), st
end

# Toy per-sample objective: ‖v‖²/2 minus the divergence term.
# NOTE(review): this is not a complete OT-Flow training loss (no transport
# cost or HJB penalty visible here) — presumably a smoke-test objective.
function simple_loss(x::AbstractVector, t::Real, l::OTFlow, ps)
    (v, trace_val), _ = l((x, t), ps, NamedTuple())
    return sum(v .^ 2) / 2 - trace_val
end

# Hand-written, partial parameter gradients for the layer.
# NOTE(review): only `w` and `A` receive real (simplified) gradients; every
# other entry is a zero placeholder — explicitly not a full implementation.
function manual_gradient(x::AbstractVector, t::Real, l::OTFlow, ps)
    s = vcat(x, t)
    u0 = sigma.(ps.K0 * s .+ ps.b0)
    u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1)

    # Simplified gradients (not a full implementation).
    grad_w = u1
    grad_A = (ps.A * s) * s'

    # Zero placeholders matching each parameter's element type. The original
    # returned `similar(ps.b)` (uninitialized memory, i.e. garbage values)
    # and Float64 zeros/0.0 that mismatched the Float32 parameters.
    return (w = grad_w, A = grad_A, b = zero(ps.b), c = zero(ps.c),
        K0 = zero(ps.K0), K1 = zero(ps.K1),
        b0 = zero(ps.b0), b1 = zero(ps.b1))
end
102 changes: 51 additions & 51 deletions test/collocation_tests.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,56 @@
# Collocation smoke tests: kernel function values/supports, and recovery of
# smooth ODE data via `collocate_data` with various kernels and bandwidths.
@testitem "Collocation" tags = [:layers] begin
    using OrdinaryDiffEq

    bounded_support_kernels = [EpanechnikovKernel(), UniformKernel(), TriangularKernel(),
        QuarticKernel(), TriweightKernel(), TricubeKernel(), CosineKernel()]

    unbounded_support_kernels = [
        GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()]

    @testset "Kernel Functions" begin
        ts = collect(-5.0:0.1:5.0)
        @testset "Kernels with support from -1 to 1" begin
            minus_one_index = findfirst(x -> ==(x, -1.0), ts)
            plus_one_index = findfirst(x -> ==(x, 1.0), ts)
            @testset "$kernel" for (kernel, x0) in zip(bounded_support_kernels,
                [0.75, 0.50, 1.0, 15.0 / 16.0, 35.0 / 32.0, 70.0 / 81.0, pi / 4.0])
                ws = DiffEqFlux.calckernel.((kernel,), ts)
                # t < -1
                @test all(ws[1:(minus_one_index - 1)] .== 0.0)
                # t > 1
                @test all(ws[(plus_one_index + 1):end] .== 0.0)
                # -1 < t < 1
                @test all(ws[(minus_one_index + 1):(plus_one_index - 1)] .> 0.0)
                # t = 0
                @test DiffEqFlux.calckernel(kernel, 0.0) == x0
            end
        end
        @testset "Kernels with unbounded support" begin
            @testset "$kernel" for (kernel, x0) in zip(unbounded_support_kernels,
                [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))])
                # t = 0
                @test DiffEqFlux.calckernel(kernel, 0.0) == x0
            end
        end
    end

    @testset "Collocation of data" begin
        f(u, p, t) = p .* u
        rc = 2
        ps = repeat([-0.001], rc)
        tspan = (0.0, 50.0)
        u0 = 3.4 .+ ones(rc)
        t = collect(range(minimum(tspan); stop = maximum(tspan), length = 1000))
        prob = ODEProblem(f, u0, tspan, ps)
        data = Array(solve(prob, Tsit5(); saveat = t, abstol = 1e-12, reltol = 1e-12))
        @testset "$kernel" for kernel in [
            bounded_support_kernels..., unbounded_support_kernels...]
            u′, u = collocate_data(data, t, kernel, 0.003)
            @test sum(abs2, u - data) < 1e-8
        end
        @testset "$kernel" for kernel in [bounded_support_kernels...]
            # Errors out as the bandwidth is too low
            @test_throws ErrorException collocate_data(data, t, kernel, 0.001)
        end
    end
end
2 changes: 1 addition & 1 deletion test/neural_dae_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@
optprob = Optimization.OptimizationProblem(optfunc, p)
res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001))
end
end
end
Loading
Loading