From 3b8de3c71ad730b1d4a0ff6e828669b7cc2e8777 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Sun, 26 Jan 2025 08:36:11 +0530 Subject: [PATCH 01/12] Added OTFlow Layers --- src/DiffEqFlux.jl | 3 +- src/otflow.jl | 152 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 src/otflow.jl diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl index dc1b9d2a2..5056fd519 100644 --- a/src/DiffEqFlux.jl +++ b/src/DiffEqFlux.jl @@ -32,13 +32,14 @@ fixed_state_type(::Layers.HamiltonianNN{False}) = false include("ffjord.jl") include("neural_de.jl") - +include("otflow.jl") include("collocation.jl") include("multiple_shooting.jl") export NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, AugmentedNDELayer, NeuralODEMM export FFJORD, FFJORDDistribution +export OTFlow, OTFlowDistribution export DimMover export EpanechnikovKernel, UniformKernel, TriangularKernel, QuarticKernel, TriweightKernel, diff --git a/src/otflow.jl b/src/otflow.jl new file mode 100644 index 000000000..addbf5ba7 --- /dev/null +++ b/src/otflow.jl @@ -0,0 +1,152 @@ +# Abstract type for CNF layers +abstract type CNFLayer <: AbstractLuxWrapperLayer{:model} end + +""" + OTFlow(model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...) + +Constructs a continuous-time neural network based on optimal transport (OT) theory, using +a potential function to define the dynamics and exact trace computation for the Jacobian. +This is a continuous normalizing flow (CNF) model specialized for density estimation. + +Arguments: + - `model`: A `Lux.AbstractLuxLayer` neural network that defines the potential function Φ. + - `basedist`: Distribution of the base variable. Set to the unit normal by default. + - `input_dims`: Input dimensions of the model. + - `tspan`: The timespan to be solved on. + - `args`: Additional arguments splatted to the ODE solver. + - `ad`: The automatic differentiation method to use for the internal Jacobian trace. + - `kwargs`: Additional arguments splatted to the ODE solver. +""" +@concrete struct OTFlow <: CNFLayer + model <: AbstractLuxLayer + basedist <: Union{Nothing, Distribution} + ad + input_dims + tspan + args + kwargs +end + +function LuxCore.initialstates(rng::AbstractRNG, n::OTFlow) + # Initialize the model's state and other parameters + model_st = LuxCore.initialstates(rng, n.model) + return (; model = model_st, regularize = false) +end + +function OTFlow( + model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs... +) + !(model isa AbstractLuxLayer) && (model = FromFluxAdaptor()(model)) + return OTFlow(model, basedist, ad, input_dims, tspan, args, kwargs) +end + +# Dynamics function for OTFlow +function __otflow_dynamics(model::StatefulLuxLayer, u::AbstractArray{T, N}, p, ad = nothing) where {T, N} + L = size(u, N - 1) + z = selectdim(u, N - 1, 1:(L - 1)) # Extract the state variables + @set! 
model.ps = p + + # Compute the potential function Φ(z) + Φ = model(z, p) + + # Compute the gradient of Φ(z) to get the dynamics v(z) = -∇Φ(z) + ∇Φ = gradient(z -> sum(model(z, p)), z)[1] + v = -∇Φ + + # Compute the trace of the Jacobian of the dynamics (∇v) + H = Zygote.hessian(z -> sum(model(z, p)), z) + trace_jac = tr(H) + + # Return the dynamics and the trace term + return cat(v, -reshape(trace_jac, ntuple(i -> 1, N - 1)..., :); dims = Val(N - 1)) +end + +# Forward pass for OTFlow +function (n::OTFlow)(x, ps, st) + return __forward_otflow(n, x, ps, st) +end + +function __forward_otflow(n::OTFlow, x::AbstractArray{T, N}, ps, st) where {T, N} + S = size(x) + (; regularize) = st + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) + + model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st.model) + + otflow_dynamics(u, p, t) = __otflow_dynamics(model, u, p, n.ad) + + _z = ChainRulesCore.@ignore_derivatives fill!( + similar(x, S[1:(N - 2)]..., 1, S[N]), zero(T)) + + prob = ODEProblem{false}(otflow_dynamics, cat(x, _z; dims = Val(N - 1)), n.tspan, ps) + sol = solve(prob, n.args...; sensealg, n.kwargs..., + save_everystep = false, save_start = false, save_end = true) + pred = __get_pred(sol) + L = size(pred, N - 1) + + z = selectdim(pred, N - 1, 1:(L - 1)) + delta_logp = selectdim(pred, N - 1, L:L) + + if n.basedist === nothing + logpz = -sum(abs2, z; dims = 1:(N - 1)) / T(2) .- + T(prod(S[1:(N - 1)]) / 2 * log(2π)) + else + logpz = logpdf(n.basedist, z) + end + logpx = reshape(logpz, 1, S[N]) .- delta_logp + return (logpx,), (; model = model.st, regularize) +end + +# Backward pass for OTFlow +function __backward_otflow(::Type{T1}, n::OTFlow, n_samples::Int, ps, st, rng) where {T1} + px = n.basedist + + if px === nothing + x = rng === nothing ? randn(T1, (n.input_dims..., n_samples)) : + randn(rng, T1, (n.input_dims..., n_samples)) + else + x = rng === nothing ? 
rand(px, n_samples) : rand(rng, px, n_samples) + end + + N, S, T = ndims(x), size(x), eltype(x) + (; regularize) = st + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) + + model = StatefulLuxLayer{true}(n.model, nothing, st.model) + + otflow_dynamics(u, p, t) = __otflow_dynamics(model, u, p, n.ad) + + _z = ChainRulesCore.@ignore_derivatives fill!( + similar(x, S[1:(N - 2)]..., 1, S[N]), zero(T)) + + prob = ODEProblem{false}(otflow_dynamics, cat(x, _z; dims = Val(N - 1)), reverse(n.tspan), ps) + sol = solve(prob, n.args...; sensealg, n.kwargs..., + save_everystep = false, save_start = false, save_end = true) + pred = __get_pred(sol) + L = size(pred, N - 1) + + return selectdim(pred, N - 1, 1:(L - 1)) +end + +# OTFlow can be used as a distribution +@concrete struct OTFlowDistribution <: ContinuousMultivariateDistribution + model <: OTFlow + ps + st +end + +Base.length(d::OTFlowDistribution) = prod(d.model.input_dims) +Base.eltype(d::OTFlowDistribution) = Lux.recursive_eltype(d.ps) + +function Distributions._logpdf(d::OTFlowDistribution, x::AbstractVector) + return first(first(__forward_otflow(d.model, reshape(x, :, 1), d.ps, d.st))) +end +function Distributions._logpdf(d::OTFlowDistribution, x::AbstractArray) + return first(first(__forward_otflow(d.model, x, d.ps, d.st))) +end +function Distributions._rand!( + rng::AbstractRNG, d::OTFlowDistribution, x::AbstractArray{<:Real} +) + copyto!(x, __backward_otflow(eltype(d), d.model, size(x, ndims(x)), d.ps, d.st, rng)) + return x +end \ No newline at end of file From 0db3b5b292259b6d00f09517f195705150ec7888 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Sun, 26 Jan 2025 08:46:17 +0530 Subject: [PATCH 02/12] Formatted document --- src/otflow.jl | 174 +++++++++++++++++++++++++------------------------- 1 file changed, 87 insertions(+), 87 deletions(-) diff --git a/src/otflow.jl b/src/otflow.jl index addbf5ba7..68d571d6e 100644 --- a/src/otflow.jl +++ b/src/otflow.jl @@ -1,8 +1,8 @@ # Abstract type for CNF layers -abstract type CNFLayer <: AbstractLuxWrapperLayer{:model} end +abstract type OTFlowLayer <: AbstractLuxWrapperLayer{:model} end """ - OTFlow(model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...) + OTFlow(model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...) Constructs a continuous-time neural network based on optimal transport (OT) theory, using a potential function to define the dynamics and exact trace computation for the Jacobian. @@ -17,136 +17,136 @@ Arguments: - `ad`: The automatic differentiation method to use for the internal Jacobian trace. - `kwargs`: Additional arguments splatted to the ODE solver. """ -@concrete struct OTFlow <: CNFLayer - model <: AbstractLuxLayer - basedist <: Union{Nothing, Distribution} - ad - input_dims - tspan - args - kwargs +@concrete struct OTFlow <: OTFlowLayer + model <: AbstractLuxLayer + basedist <: Union{Nothing, Distribution} + ad::Any + input_dims::Any + tspan::Any + args::Any + kwargs::Any end function LuxCore.initialstates(rng::AbstractRNG, n::OTFlow) - # Initialize the model's state and other parameters - model_st = LuxCore.initialstates(rng, n.model) - return (; model = model_st, regularize = false) + # Initialize the model's state and other parameters + model_st = LuxCore.initialstates(rng, n.model) + return (; model = model_st, regularize = false) end function OTFlow( - model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs... 
+ model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs..., ) - !(model isa AbstractLuxLayer) && (model = FromFluxAdaptor()(model)) - return OTFlow(model, basedist, ad, input_dims, tspan, args, kwargs) + !(model isa AbstractLuxLayer) && (model = FromFluxAdaptor()(model)) + return OTFlow(model, basedist, ad, input_dims, tspan, args, kwargs) end # Dynamics function for OTFlow function __otflow_dynamics(model::StatefulLuxLayer, u::AbstractArray{T, N}, p, ad = nothing) where {T, N} - L = size(u, N - 1) - z = selectdim(u, N - 1, 1:(L - 1)) # Extract the state variables - @set! model.ps = p + L = size(u, N - 1) + z = selectdim(u, N - 1, 1:(L-1)) # Extract the state variables + @set! model.ps = p - # Compute the potential function Φ(z) - Φ = model(z, p) + # Compute the potential function Φ(z) + Φ = model(z, p) - # Compute the gradient of Φ(z) to get the dynamics v(z) = -∇Φ(z) - ∇Φ = gradient(z -> sum(model(z, p)), z)[1] - v = -∇Φ + # Compute the gradient of Φ(z) to get the dynamics v(z) = -∇Φ(z) + ∇Φ = gradient(z -> sum(model(z, p)), z)[1] + v = -∇Φ - # Compute the trace of the Jacobian of the dynamics (∇v) - H = Zygote.hessian(z -> sum(model(z, p)), z) - trace_jac = tr(H) + # Compute the trace of the Jacobian of the dynamics (∇v) + H = Zygote.hessian(z -> sum(model(z, p)), z) + trace_jac = tr(H) - # Return the dynamics and the trace term - return cat(v, -reshape(trace_jac, ntuple(i -> 1, N - 1)..., :); dims = Val(N - 1)) + # Return the dynamics and the trace term + return cat(v, -reshape(trace_jac, ntuple(i -> 1, N - 1)..., :); dims = Val(N - 1)) end # Forward pass for OTFlow function (n::OTFlow)(x, ps, st) - return __forward_otflow(n, x, ps, st) + return __forward_otflow(n, x, ps, st) end function __forward_otflow(n::OTFlow, x::AbstractArray{T, N}, ps, st) where {T, N} - S = size(x) - (; regularize) = st - sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) - - model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st.model) - - otflow_dynamics(u, p, t) = __otflow_dynamics(model, u, p, n.ad) - - _z = ChainRulesCore.@ignore_derivatives fill!( - similar(x, S[1:(N - 2)]..., 1, S[N]), zero(T)) - - prob = ODEProblem{false}(otflow_dynamics, cat(x, _z; dims = Val(N - 1)), n.tspan, ps) - sol = solve(prob, n.args...; sensealg, n.kwargs..., - save_everystep = false, save_start = false, save_end = true) - pred = __get_pred(sol) - L = size(pred, N - 1) - - z = selectdim(pred, N - 1, 1:(L - 1)) - delta_logp = selectdim(pred, N - 1, L:L) - - if n.basedist === nothing - logpz = -sum(abs2, z; dims = 1:(N - 1)) / T(2) .- - T(prod(S[1:(N - 1)]) / 2 * log(2π)) - else - logpz = logpdf(n.basedist, z) - end - logpx = reshape(logpz, 1, S[N]) .- delta_logp - return (logpx,), (; model = model.st, regularize) + S = size(x) + (; regularize) = st + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) + + model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st.model) + + otflow_dynamics(u, p, t) = __otflow_dynamics(model, u, p, n.ad) + + _z = ChainRulesCore.@ignore_derivatives fill!( + similar(x, S[1:(N-2)]..., 1, S[N]), zero(T)) + + prob = ODEProblem{false}(otflow_dynamics, cat(x, _z; dims = Val(N - 1)), n.tspan, ps) + sol = solve(prob, n.args...; sensealg, n.kwargs..., + save_everystep = false, save_start = false, save_end = true) + pred = __get_pred(sol) + L = size(pred, N - 1) + + z = selectdim(pred, N - 1, 1:(L-1)) + delta_logp = selectdim(pred, N - 1, L:L) + + if n.basedist === nothing + logpz = -sum(abs2, z; dims = 1:(N-1)) / T(2) .- + 
T(prod(S[1:(N-1)]) / 2 * log(2π)) + else + logpz = logpdf(n.basedist, z) + end + logpx = reshape(logpz, 1, S[N]) .- delta_logp + return (logpx,), (; model = model.st, regularize) end # Backward pass for OTFlow function __backward_otflow(::Type{T1}, n::OTFlow, n_samples::Int, ps, st, rng) where {T1} - px = n.basedist + px = n.basedist - if px === nothing - x = rng === nothing ? randn(T1, (n.input_dims..., n_samples)) : - randn(rng, T1, (n.input_dims..., n_samples)) - else - x = rng === nothing ? rand(px, n_samples) : rand(rng, px, n_samples) - end + if px === nothing + x = rng === nothing ? randn(T1, (n.input_dims..., n_samples)) : + randn(rng, T1, (n.input_dims..., n_samples)) + else + x = rng === nothing ? rand(px, n_samples) : rand(rng, px, n_samples) + end - N, S, T = ndims(x), size(x), eltype(x) - (; regularize) = st - sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) + N, S, T = ndims(x), size(x), eltype(x) + (; regularize) = st + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) - model = StatefulLuxLayer{true}(n.model, nothing, st.model) + model = StatefulLuxLayer{true}(n.model, nothing, st.model) - otflow_dynamics(u, p, t) = __otflow_dynamics(model, u, p, n.ad) + otflow_dynamics(u, p, t) = __otflow_dynamics(model, u, p, n.ad) - _z = ChainRulesCore.@ignore_derivatives fill!( - similar(x, S[1:(N - 2)]..., 1, S[N]), zero(T)) + _z = ChainRulesCore.@ignore_derivatives fill!( + similar(x, S[1:(N-2)]..., 1, S[N]), zero(T)) - prob = ODEProblem{false}(otflow_dynamics, cat(x, _z; dims = Val(N - 1)), reverse(n.tspan), ps) - sol = solve(prob, n.args...; sensealg, n.kwargs..., - save_everystep = false, save_start = false, save_end = true) - pred = __get_pred(sol) - L = size(pred, N - 1) + prob = ODEProblem{false}(otflow_dynamics, cat(x, _z; dims = Val(N - 1)), reverse(n.tspan), ps) + sol = solve(prob, n.args...; sensealg, n.kwargs..., + save_everystep = false, save_start = false, save_end = true) + pred = __get_pred(sol) + L = size(pred, N - 1) - return selectdim(pred, N - 1, 1:(L - 1)) + return selectdim(pred, N - 1, 1:(L-1)) end # OTFlow can be used as a distribution @concrete struct OTFlowDistribution <: ContinuousMultivariateDistribution - model <: OTFlow - ps - st + model <: OTFlow + ps::Any + st::Any end Base.length(d::OTFlowDistribution) = prod(d.model.input_dims) Base.eltype(d::OTFlowDistribution) = Lux.recursive_eltype(d.ps) function Distributions._logpdf(d::OTFlowDistribution, x::AbstractVector) - return first(first(__forward_otflow(d.model, reshape(x, :, 1), d.ps, d.st))) + return first(first(__forward_otflow(d.model, reshape(x, :, 1), d.ps, d.st))) end function Distributions._logpdf(d::OTFlowDistribution, x::AbstractArray) - return first(first(__forward_otflow(d.model, x, d.ps, d.st))) + return first(first(__forward_otflow(d.model, x, d.ps, d.st))) end function Distributions._rand!( - rng::AbstractRNG, d::OTFlowDistribution, x::AbstractArray{<:Real} + rng::AbstractRNG, d::OTFlowDistribution, x::AbstractArray{<:Real}, ) - copyto!(x, __backward_otflow(eltype(d), d.model, size(x, ndims(x)), d.ps, d.st, rng)) - return x -end \ No newline at end of file + copyto!(x, __backward_otflow(eltype(d), d.model, size(x, ndims(x)), d.ps, d.st, rng)) + return x +end From e496df6f8b0906134b58c229edc5f6f8021bcb56 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 13 Feb 2025 07:37:30 +0530 Subject: [PATCH 03/12] Added OTFlow tests --- test/neural_de_tests.jl | 2 +- test/otflow_tests.jl | 73 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 
deletion(-) create mode 100644 test/otflow_tests.jl diff --git a/test/neural_de_tests.jl b/test/neural_de_tests.jl index 8bbd35a23..80871eb6e 100644 --- a/test/neural_de_tests.jl +++ b/test/neural_de_tests.jl @@ -240,7 +240,7 @@ end @test first(layer(r, ps, st))[:, :, :, 1] == r[:, :, 1, :] end -@testitem "Neural DE CUDA" tags=[:cuda] skip=:(using LuxCUDA; !LuxCUDA.functional()) begin +@testset "Neural DE CUDA" tags=[:cuda] skip=:(using LuxCUDA; !LuxCUDA.functional()) begin using LuxCUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random, ComponentArrays import Flux diff --git a/test/otflow_tests.jl b/test/otflow_tests.jl new file mode 100644 index 000000000..e2f1fb8da --- /dev/null +++ b/test/otflow_tests.jl @@ -0,0 +1,73 @@ +@testitem "OTFlow Tests" begin + using DiffEqFlux + using Lux + using Random + using Distributions + using Test + + rng = Random.default_rng() + + # Test basic constructor and initialization + @testset "Constructor" begin + model = Dense(2 => 1) + tspan = (0.0, 1.0) + input_dims = (2,) + + flow = OTFlow(model, tspan, input_dims) + @test flow isa OTFlow + @test flow.tspan == tspan + @test flow.input_dims == input_dims + + # Test with base distribution + base_dist = MvNormal(2, 1.0) + flow_with_dist = OTFlow(model, tspan, input_dims; basedist=base_dist) + @test flow_with_dist.basedist == base_dist + end + + @testset "Forward Pass" begin + model = Dense(2 => 1) + tspan = (0.0, 1.0) + input_dims = (2,) + flow = OTFlow(model, tspan, input_dims) + + ps, st = Lux.setup(rng, flow) + x = randn(rng, 2, 10) # 10 samples of 2D data + + # Test forward pass + output, new_st = flow(x, ps, st) + @test size(output) == (1, 10) # log probabilities + @test new_st isa NamedTuple + end + + @testset "Distribution Interface" begin + model = Dense(2 => 1) + tspan = (0.0, 1.0) + input_dims = (2,) + flow = OTFlow(model, tspan, input_dims) + + ps, st = Lux.setup(rng, flow) + dist = OTFlowDistribution(flow, ps, st) + + # Test sampling + x = rand(dist, 5) + @test size(x) == (2, 5) + + # Test log pdf + logp = logpdf(dist, x[:, 1]) + @test logp isa Real + end + + @testset "Base Distribution" begin + model = Dense(2 => 1) + tspan = (0.0, 1.0) + input_dims = (2,) + base_dist = MvNormal(zeros(2), I) + + flow = OTFlow(model, tspan, input_dims; basedist=base_dist) + ps, st = Lux.setup(rng, flow) + + x = randn(rng, 2, 5) + output, new_st = flow(x, ps, st) + @test size(output) == (1, 5) + end +end \ No newline at end of file From 56ac16be51c15244fce57650726b2396da5eb416 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Sat, 22 Feb 2025 22:03:21 +0530 Subject: [PATCH 04/12] Added OTflow.jl with tests --- src/otflow.jl | 216 ++++++++++++++++--------------------------- test/otflow_tests.jl | 120 +++++++++++------------- 2 files changed, 137 insertions(+), 199 deletions(-) diff --git a/src/otflow.jl b/src/otflow.jl index 68d571d6e..83b5ad865 100644 --- a/src/otflow.jl +++ b/src/otflow.jl @@ -1,152 +1,98 @@ -# Abstract type for CNF layers -abstract type OTFlowLayer <: AbstractLuxWrapperLayer{:model} end - -""" - OTFlow(model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...) - -Constructs a continuous-time neural network based on optimal transport (OT) theory, using -a potential function to define the dynamics and exact trace computation for the Jacobian. -This is a continuous normalizing flow (CNF) model specialized for density estimation. - -Arguments: - - `model`: A `Lux.AbstractLuxLayer` neural network that defines the potential function Φ. 
- - `basedist`: Distribution of the base variable. Set to the unit normal by default. - - `input_dims`: Input dimensions of the model. - - `tspan`: The timespan to be solved on. - - `args`: Additional arguments splatted to the ODE solver. - - `ad`: The automatic differentiation method to use for the internal Jacobian trace. - - `kwargs`: Additional arguments splatted to the ODE solver. -""" -@concrete struct OTFlow <: OTFlowLayer - model <: AbstractLuxLayer - basedist <: Union{Nothing, Distribution} - ad::Any - input_dims::Any - tspan::Any - args::Any - kwargs::Any +struct OTFlow <: AbstractLuxLayer + d::Int # Input dimension + m::Int # Hidden dimension + r::Int # Rank for low-rank approximation end -function LuxCore.initialstates(rng::AbstractRNG, n::OTFlow) - # Initialize the model's state and other parameters - model_st = LuxCore.initialstates(rng, n.model) - return (; model = model_st, regularize = false) +# Constructor with default rank +OTFlow(d::Int, m::Int; r::Int=min(10,d)) = OTFlow(d, m, r) + +# Initialize parameters and states +function Lux.initialparameters(rng::AbstractRNG, l::OTFlow) + w = randn(rng, Float64, l.m) .* 0.01 + A = randn(rng, Float64, l.r, l.d + 1) .* 0.01 + b = zeros(Float64, l.d + 1) + c = zero(Float64) + K0 = randn(rng, Float64, l.m, l.d + 1) .* 0.01 + K1 = randn(rng, Float64, l.m, l.m) .* 0.01 + b0 = zeros(Float64, l.m) + b1 = zeros(Float64, l.m) + + return (w=w, A=A, b=b, c=c, K0=K0, K1=K1, b0=b0, b1=b1) end -function OTFlow( - model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs..., -) - !(model isa AbstractLuxLayer) && (model = FromFluxAdaptor()(model)) - return OTFlow(model, basedist, ad, input_dims, tspan, args, kwargs) -end - -# Dynamics function for OTFlow -function __otflow_dynamics(model::StatefulLuxLayer, u::AbstractArray{T, N}, p, ad = nothing) where {T, N} - L = size(u, N - 1) - z = selectdim(u, N - 1, 1:(L-1)) # Extract the state variables - @set! 
model.ps = p - - # Compute the potential function Φ(z) - Φ = model(z, p) +Lux.initialstates(::AbstractRNG, ::OTFlow) = NamedTuple() - # Compute the gradient of Φ(z) to get the dynamics v(z) = -∇Φ(z) - ∇Φ = gradient(z -> sum(model(z, p)), z)[1] - v = -∇Φ +σ(x) = log(exp(x) + exp(-x)) +σ′(x) = tanh(x) +σ′′(x) = 1 - tanh(x)^2 - # Compute the trace of the Jacobian of the dynamics (∇v) - H = Zygote.hessian(z -> sum(model(z, p)), z) - trace_jac = tr(H) - - # Return the dynamics and the trace term - return cat(v, -reshape(trace_jac, ntuple(i -> 1, N - 1)..., :); dims = Val(N - 1)) +function resnet_forward(x::AbstractVector, t::Real, ps) + s = vcat(x, t) + u0 = σ.(ps.K0 * s .+ ps.b0) + u1 = u0 .+ σ.(ps.K1 * u0 .+ ps.b1) # h=1 as in paper + return u1 end -# Forward pass for OTFlow -function (n::OTFlow)(x, ps, st) - return __forward_otflow(n, x, ps, st) +function potential(x::AbstractVector, t::Real, ps) + s = vcat(x, t) + N = resnet_forward(x, t, ps) + quadratic_term = 0.5 * s' * (ps.A' * ps.A) * s + linear_term = ps.b' * s + return ps.w' * N + quadratic_term + linear_term + ps.c end -function __forward_otflow(n::OTFlow, x::AbstractArray{T, N}, ps, st) where {T, N} - S = size(x) - (; regularize) = st - sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) - - model = StatefulLuxLayer{fixed_state_type(n.model)}(n.model, nothing, st.model) - - otflow_dynamics(u, p, t) = __otflow_dynamics(model, u, p, n.ad) - - _z = ChainRulesCore.@ignore_derivatives fill!( - similar(x, S[1:(N-2)]..., 1, S[N]), zero(T)) - - prob = ODEProblem{false}(otflow_dynamics, cat(x, _z; dims = Val(N - 1)), n.tspan, ps) - sol = solve(prob, n.args...; sensealg, n.kwargs..., - save_everystep = false, save_start = false, save_end = true) - pred = __get_pred(sol) - L = size(pred, N - 1) - - z = selectdim(pred, N - 1, 1:(L-1)) - delta_logp = selectdim(pred, N - 1, L:L) - - if n.basedist === nothing - logpz = -sum(abs2, z; dims = 1:(N-1)) / T(2) .- - T(prod(S[1:(N-1)]) / 2 * log(2π)) - else - logpz = logpdf(n.basedist, z) - end - logpx = reshape(logpz, 1, S[N]) .- delta_logp - return (logpx,), (; model = model.st, regularize) +function gradient(x::AbstractVector, t::Real, ps, d::Int) + s = vcat(x, t) + u0 = σ.(ps.K0 * s .+ ps.b0) + z1 = ps.w .+ ps.K1' * (σ′.(ps.K1 * u0 .+ ps.b1) .* ps.w) + z0 = ps.K0' * (σ′.(ps.K0 * s .+ ps.b0) .* z1) + + grad = z0 + (ps.A' * ps.A) * s + ps.b + return grad[1:d] end -# Backward pass for OTFlow -function __backward_otflow(::Type{T1}, n::OTFlow, n_samples::Int, ps, st, rng) where {T1} - px = n.basedist - - if px === nothing - x = rng === nothing ? randn(T1, (n.input_dims..., n_samples)) : - randn(rng, T1, (n.input_dims..., n_samples)) - else - x = rng === nothing ? 
rand(px, n_samples) : rand(rng, px, n_samples) - end - - N, S, T = ndims(x), size(x), eltype(x) - (; regularize) = st - sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()) - - model = StatefulLuxLayer{true}(n.model, nothing, st.model) - - otflow_dynamics(u, p, t) = __otflow_dynamics(model, u, p, n.ad) - - _z = ChainRulesCore.@ignore_derivatives fill!( - similar(x, S[1:(N-2)]..., 1, S[N]), zero(T)) - - prob = ODEProblem{false}(otflow_dynamics, cat(x, _z; dims = Val(N - 1)), reverse(n.tspan), ps) - sol = solve(prob, n.args...; sensealg, n.kwargs..., - save_everystep = false, save_start = false, save_end = true) - pred = __get_pred(sol) - L = size(pred, N - 1) - - return selectdim(pred, N - 1, 1:(L-1)) +function trace(x::AbstractVector, t::Real, ps, d::Int) + s = vcat(x, t) + u0 = σ.(ps.K0 * s .+ ps.b0) + z1 = ps.w .+ ps.K1' * (σ′.(ps.K1 * u0 .+ ps.b1) .* ps.w) + + K0_E = ps.K0[:, 1:d] + A_E = ps.A[:, 1:d] + + t0 = sum(σ′′.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2)) + J = Diagonal(σ′.(ps.K0 * s .+ ps.b0)) * K0_E + t1 = sum(σ′′.(ps.K1 * u0 .+ ps.b1) .* ps.w .* (ps.K1 * J) .^ 2) + trace_A = tr(A_E' * A_E) + + return t0 + t1 + trace_A end -# OTFlow can be used as a distribution -@concrete struct OTFlowDistribution <: ContinuousMultivariateDistribution - model <: OTFlow - ps::Any - st::Any +function (l::OTFlow)(xt::Tuple{AbstractVector, Real}, ps, st) + x, t = xt + v = -gradient(x, t, ps, l.d) # v = -∇Φ + tr = -trace(x, t, ps, l.d) # tr(∇v) = -tr(∇²Φ) + return (v, tr), st end -Base.length(d::OTFlowDistribution) = prod(d.model.input_dims) -Base.eltype(d::OTFlowDistribution) = Lux.recursive_eltype(d.ps) - -function Distributions._logpdf(d::OTFlowDistribution, x::AbstractVector) - return first(first(__forward_otflow(d.model, reshape(x, :, 1), d.ps, d.st))) -end -function Distributions._logpdf(d::OTFlowDistribution, x::AbstractArray) - return first(first(__forward_otflow(d.model, x, d.ps, d.st))) -end -function Distributions._rand!( - rng::AbstractRNG, d::OTFlowDistribution, x::AbstractArray{<:Real}, -) - copyto!(x, __backward_otflow(eltype(d), d.model, size(x, ndims(x)), d.ps, d.st, rng)) - return x +function simple_loss(x::AbstractVector, t::Real, l::OTFlow, ps) + (v, tr), _ = l((x, t), ps, nothing) + return sum(v.^2) / 2 - tr end + +function manual_gradient(x::AbstractVector, t::Real, l::OTFlow, ps) + s = vcat(x, t) + u0 = σ.(ps.K0 * s .+ ps.b0) + u1 = u0 .+ σ.(ps.K1 * u0 .+ ps.b1) + + v = -gradient(x, t, ps, l.d) + tr = -trace(x, t, ps, l.d) + + # Simplified gradients (not full implementation) + grad_w = u1 + grad_A = (ps.A * s) * s' + + return (w=grad_w, A=grad_A, b=similar(ps.b), c=0.0, + K0=zeros(l.m, l.d+1), K1=zeros(l.m, l.m), + b0=zeros(l.m), b1=zeros(l.m)) +end \ No newline at end of file diff --git a/test/otflow_tests.jl b/test/otflow_tests.jl index e2f1fb8da..cc834c696 100644 --- a/test/otflow_tests.jl +++ b/test/otflow_tests.jl @@ -1,73 +1,65 @@ -@testitem "OTFlow Tests" begin - using DiffEqFlux - using Lux - using Random - using Distributions - using Test +@testset "Tests for OTFlow Layer Functionality" begin + using Lux, LuxCore, Random, LinearAlgebra, Test, ComponentArrays, Flux, DiffEqFlux + rng = Xoshiro(0) + d = 2 + m = 4 + r = 2 + otflow = OTFlow(d, m; r=r) + ps, st = Lux.setup(rng, otflow) + ps = ComponentArray(ps) - rng = Random.default_rng() - - # Test basic constructor and initialization - @testset "Constructor" begin - model = Dense(2 => 1) - tspan = (0.0, 1.0) - input_dims = (2,) - - flow = OTFlow(model, tspan, input_dims) - @test flow isa OTFlow - @test flow.tspan 
== tspan - @test flow.input_dims == input_dims - - # Test with base distribution - base_dist = MvNormal(2, 1.0) - flow_with_dist = OTFlow(model, tspan, input_dims; basedist=base_dist) - @test flow_with_dist.basedist == base_dist - end + x = Float32[1.0, 2.0] + t = 0.5f0 @testset "Forward Pass" begin - model = Dense(2 => 1) - tspan = (0.0, 1.0) - input_dims = (2,) - flow = OTFlow(model, tspan, input_dims) - - ps, st = Lux.setup(rng, flow) - x = randn(rng, 2, 10) # 10 samples of 2D data - - # Test forward pass - output, new_st = flow(x, ps, st) - @test size(output) == (1, 10) # log probabilities - @test new_st isa NamedTuple + (v, tr), st_new = otflow((x, t), ps, st) + @test length(v) == d + @test isa(tr, Float32) + @test st_new == st + end + + @testset "Potential Function" begin + phi = potential(x, t, ps) + @test isa(phi, Float32) + end + + @testset "Gradient Consistency" begin + grad = gradient(x, t, ps, d) + (v, _), _ = otflow((x, t), ps, st) + @test length(grad) == d + @test grad ≈ -v atol=1e-5 # v = -∇Φ + end + + @testset "Trace Consistency" begin + tr_manual = trace(x, t, ps, d) + (_, tr_forward), _ = otflow((x, t), ps, st) + @test tr_manual ≈ -tr_forward atol=1e-5 + end + + @testset "ODE Integration" begin + x0 = Float32[1.0, 1.0] + tspan = (0.0f0, 1.0f0) + x_traj, t_vec = simple_ode_solve(otflow, x0, tspan, ps, st; dt=0.01f0) + @test size(x_traj) == (d, length(t_vec)) + @test all(isfinite, x_traj) + @test x_traj[:, end] != x0 end - @testset "Distribution Interface" begin - model = Dense(2 => 1) - tspan = (0.0, 1.0) - input_dims = (2,) - flow = OTFlow(model, tspan, input_dims) - - ps, st = Lux.setup(rng, flow) - dist = OTFlowDistribution(flow, ps, st) - - # Test sampling - x = rand(dist, 5) - @test size(x) == (2, 5) - - # Test log pdf - logp = logpdf(dist, x[:, 1]) - @test logp isa Real + @testset "Loss Function" begin + loss_val = simple_loss(x, t, otflow, ps) + @test isa(loss_val, Float32) + @test isfinite(loss_val) end - @testset "Base Distribution" begin - model = Dense(2 => 1) - tspan = (0.0, 1.0) - input_dims = (2,) - base_dist = MvNormal(zeros(2), I) - - flow = OTFlow(model, tspan, input_dims; basedist=base_dist) - ps, st = Lux.setup(rng, flow) - - x = randn(rng, 2, 5) - output, new_st = flow(x, ps, st) - @test size(output) == (1, 5) + @testset "Manual Gradient" begin + grads = manual_gradient(x, t, otflow, ps) + @test haskey(grads, :w) && length(grads.w) == m + @test haskey(grads, :A) && size(grads.A) == (r, d+1) + @test haskey(grads, :b) && length(grads.b) == d+1 + @test haskey(grads, :c) && isa(grads.c, Float32) + @test haskey(grads, :K0) && size(grads.K0) == (m, d+1) + @test haskey(grads, :K1) && size(grads.K1) == (m, m) + @test haskey(grads, :b0) && length(grads.b0) == m + @test haskey(grads, :b1) && length(grads.b1) == m end end \ No newline at end of file From de3a09d09030841d044d06f347c82941266dfbfc Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Sat, 22 Feb 2025 23:36:39 +0530 Subject: [PATCH 05/12] Added @testitem blocks --- test/otflow_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/otflow_tests.jl b/test/otflow_tests.jl index cc834c696..513b0f644 100644 --- a/test/otflow_tests.jl +++ b/test/otflow_tests.jl @@ -1,4 +1,4 @@ -@testset "Tests for OTFlow Layer Functionality" begin +@testitem "Tests for OTFlow Layer Functionality" begin using Lux, LuxCore, Random, LinearAlgebra, Test, ComponentArrays, Flux, DiffEqFlux rng = Xoshiro(0) d = 2 From 3265ae9eb3bd23311a404464a6082ea86120d10c Mon Sep 17 00:00:00 2001 From: paramthakkar123 
Date: Sun, 23 Feb 2025 12:10:54 +0530 Subject: [PATCH 06/12] Added testitem blocks --- test/cnf_tests.jl | 2 +- test/collocation_tests.jl | 14 +++++++------- test/neural_de_tests.jl | 28 ++++++++++++++-------------- test/otflow_tests.jl | 14 +++++++------- test/stiff_nested_ad_tests.jl | 2 +- 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/test/cnf_tests.jl b/test/cnf_tests.jl index 322b1647d..cfd57a5bb 100644 --- a/test/cnf_tests.jl +++ b/test/cnf_tests.jl @@ -34,7 +34,7 @@ end return -mean(logpx) end - @testset "ADType: $(adtype)" for adtype in (Optimization.AutoForwardDiff(), + @testitem "ADType: $(adtype)" for adtype in (Optimization.AutoForwardDiff(), Optimization.AutoReverseDiff(), Optimization.AutoTracker(), Optimization.AutoZygote(), Optimization.AutoFiniteDiff()) @testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in ( diff --git a/test/collocation_tests.jl b/test/collocation_tests.jl index 756e77724..c709a3497 100644 --- a/test/collocation_tests.jl +++ b/test/collocation_tests.jl @@ -7,9 +7,9 @@ unbounded_support_kernels = [ GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()] - @testset "Kernel Functions" begin + @testitem "Kernel Functions" begin ts = collect(-5.0:0.1:5.0) - @testset "Kernels with support from -1 to 1" begin + @testitem "Kernels with support from -1 to 1" begin minus_one_index = findfirst(x -> ==(x, -1.0), ts) plus_one_index = findfirst(x -> ==(x, 1.0), ts) @testset "$kernel" for (kernel, x0) in zip(bounded_support_kernels, @@ -25,8 +25,8 @@ @test DiffEqFlux.calckernel(kernel, 0.0) == x0 end end - @testset "Kernels with unbounded support" begin - @testset "$kernel" for (kernel, x0) in zip(unbounded_support_kernels, + @testitem "Kernels with unbounded support" begin + @testitem "$kernel" for (kernel, x0) in zip(unbounded_support_kernels, [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))]) # t = 0 @test DiffEqFlux.calckernel(kernel, 0.0) == x0 @@ -34,7 +34,7 @@ end end - @testset "Collocation of data" begin + @testitem "Collocation of data" begin f(u, p, t) = p .* u rc = 2 ps = repeat([-0.001], rc) @@ -43,12 +43,12 @@ t = collect(range(minimum(tspan); stop = maximum(tspan), length = 1000)) prob = ODEProblem(f, u0, tspan, ps) data = Array(solve(prob, Tsit5(); saveat = t, abstol = 1e-12, reltol = 1e-12)) - @testset "$kernel" for kernel in [ + @testitem "$kernel" for kernel in [ bounded_support_kernels..., unbounded_support_kernels...] u′, u = collocate_data(data, t, kernel, 0.003) @test sum(abs2, u - data) < 1e-8 end - @testset "$kernel" for kernel in [bounded_support_kernels...] + @testitem "$kernel" for kernel in [bounded_support_kernels...] 
# Errors out as the bandwidth is too low @test_throws ErrorException collocate_data(data, t, kernel, 0.001) end diff --git a/test/neural_de_tests.jl b/test/neural_de_tests.jl index 80871eb6e..6fd05dc68 100644 --- a/test/neural_de_tests.jl +++ b/test/neural_de_tests.jl @@ -4,7 +4,7 @@ rng = Xoshiro(0) - @testset "$(nnlib)" for nnlib in ("Flux", "Lux") + @testitem "$(nnlib)" for nnlib in ("Flux", "Lux") mp = Float32[0.1, 0.1] x = Float32[2.0; 0.0] xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) @@ -22,8 +22,8 @@ Chain(Dense(4 => 50, tanh), Dense(50 => 4)) end - @testset "u0: $(typeof(u0))" for u0 in (x, xs) - @testset "kwargs: $(kwargs))" for kwargs in ( + @testitem "u0: $(typeof(u0))" for u0 in (x, xs) + @testitem "kwargs: $(kwargs))" for kwargs in ( (; save_everystep = false, save_start = false), (; abstol = 1e-12, reltol = 1e-12, save_everystep = false, save_start = false), @@ -58,7 +58,7 @@ end rng = Xoshiro(0) - @testset "$(nnlib)" for nnlib in ("Flux", "Lux") + @testitem "$(nnlib)" for nnlib in ("Flux", "Lux") mp = Float32[0.1, 0.1] x = Float32[2.0; 0.0] xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) @@ -122,7 +122,7 @@ end rng = Xoshiro(0) - @testset "$(nnlib)" for nnlib in ("Flux", "Lux") + @testitem "$(nnlib)" for nnlib in ("Flux", "Lux") mp = Float32[0.1, 0.1] x = Float32[2.0; 0.0] xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) @@ -154,7 +154,7 @@ end Chain(Dense(4 => 50, tanh), Dense(50 => 16), x -> reshape(x, 4, 4)) end - @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,), + @testitem "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,), solver in (EulerHeun(), LambaEM()) sode = NeuralSDE(dudt, diffusion_sde, tspan, 2, solver; @@ -187,7 +187,7 @@ end rng = Xoshiro(0) - @testset "$(nnlib)" for nnlib in ("Flux", "Lux") + @testitem "$(nnlib)" for nnlib in ("Flux", "Lux") mp = Float32[0.1, 0.1] x = Float32[2.0; 0.0] xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) @@ -240,7 +240,7 @@ end @test first(layer(r, ps, st))[:, :, :, 1] == r[:, :, 1, :] end -@testset "Neural DE CUDA" tags=[:cuda] skip=:(using LuxCUDA; !LuxCUDA.functional()) begin +@testitem "Neural DE CUDA" tags=[:cuda] begin using LuxCUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random, ComponentArrays import Flux @@ -251,7 +251,7 @@ end const gdev = gpu_device() const cdev = cpu_device() - @testset "Neural DE" begin + @testitem "Neural DE" begin mp = Float32[0.1, 0.1] |> gdev x = Float32[2.0; 0.0] |> gdev xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) |> gdev @@ -260,9 +260,9 @@ end dudt = Chain(Dense(2 => 50, tanh), Dense(50 => 2)) aug_dudt = Chain(Dense(4 => 50, tanh), Dense(50 => 4)) - @testset "Neural ODE" begin - @testset "u0: $(typeof(u0))" for u0 in (x, xs) - @testset "kwargs: $(kwargs))" for kwargs in ( + @testitem "Neural ODE" begin + @testitem "u0: $(typeof(u0))" for u0 in (x, xs) + @testitem "kwargs: $(kwargs))" for kwargs in ( (; save_everystep = false, save_start = false), (; save_everystep = false, save_start = false, sensealg = TrackerAdjoint()), @@ -278,7 +278,7 @@ end st = st |> gdev broken = hasfield(typeof(kwargs), :sensealg) && ndims(u0) == 2 && - kwargs.sensealg isa TrackerAdjoint + kwargs[:sensealg] isa TrackerAdjoint @test begin grads = Zygote.gradient(sum ∘ last ∘ first ∘ node, u0, pd, st) CUDA.@allowscalar begin @@ -305,7 +305,7 @@ end aug_diffusion = Chain(Dense(4 => 50, tanh), Dense(50 => 4)) tspan = (0.0f0, 0.1f0) - @testset "NeuralDSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (xs,), + @testitem "NeuralDSDE u0: $(typeof(u0)), 
solver: $(solver)" for u0 in (xs,), solver in (SOSRI(),) # CuVector seems broken on CI but I can't reproduce the failure locally diff --git a/test/otflow_tests.jl b/test/otflow_tests.jl index 513b0f644..92ddab3f8 100644 --- a/test/otflow_tests.jl +++ b/test/otflow_tests.jl @@ -11,32 +11,32 @@ x = Float32[1.0, 2.0] t = 0.5f0 - @testset "Forward Pass" begin + @testitem "Forward Pass" begin (v, tr), st_new = otflow((x, t), ps, st) @test length(v) == d @test isa(tr, Float32) @test st_new == st end - @testset "Potential Function" begin + @testitem "Potential Function" begin phi = potential(x, t, ps) @test isa(phi, Float32) end - @testset "Gradient Consistency" begin + @testitem "Gradient Consistency" begin grad = gradient(x, t, ps, d) (v, _), _ = otflow((x, t), ps, st) @test length(grad) == d @test grad ≈ -v atol=1e-5 # v = -∇Φ end - @testset "Trace Consistency" begin + @testitem "Trace Consistency" begin tr_manual = trace(x, t, ps, d) (_, tr_forward), _ = otflow((x, t), ps, st) @test tr_manual ≈ -tr_forward atol=1e-5 end - @testset "ODE Integration" begin + @testitem "ODE Integration" begin x0 = Float32[1.0, 1.0] tspan = (0.0f0, 1.0f0) x_traj, t_vec = simple_ode_solve(otflow, x0, tspan, ps, st; dt=0.01f0) @@ -45,13 +45,13 @@ @test x_traj[:, end] != x0 end - @testset "Loss Function" begin + @testitem "Loss Function" begin loss_val = simple_loss(x, t, otflow, ps) @test isa(loss_val, Float32) @test isfinite(loss_val) end - @testset "Manual Gradient" begin + @testitem "Manual Gradient" begin grads = manual_gradient(x, t, otflow, ps) @test haskey(grads, :w) && length(grads.w) == m @test haskey(grads, :A) && size(grads.A) == (r, d+1) diff --git a/test/stiff_nested_ad_tests.jl b/test/stiff_nested_ad_tests.jl index 4742936f4..35de20692 100644 --- a/test/stiff_nested_ad_tests.jl +++ b/test/stiff_nested_ad_tests.jl @@ -27,7 +27,7 @@ end end - @testset "Solver: $(nameof(typeof(solver)))" for solver in ( + @testitem "Solver: $(nameof(typeof(solver)))" for solver in ( KenCarp4(), Rodas5(), RadauIIA5()) neuralde = NeuralODE(model, tspan, solver; saveat = t, reltol = 1e-7, abstol = 1e-9) ps, st = Lux.setup(Xoshiro(0), neuralde) From 0aa228ea2b53044fe9ba054944d226f20ae63999 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Mon, 24 Feb 2025 00:16:30 +0530 Subject: [PATCH 07/12] Changes Made --- src/otflow.jl | 24 ++++++++++++------------ test/collocation_tests.jl | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/otflow.jl b/src/otflow.jl index 83b5ad865..38cdd7578 100644 --- a/src/otflow.jl +++ b/src/otflow.jl @@ -9,20 +9,20 @@ OTFlow(d::Int, m::Int; r::Int=min(10,d)) = OTFlow(d, m, r) # Initialize parameters and states function Lux.initialparameters(rng::AbstractRNG, l::OTFlow) - w = randn(rng, Float64, l.m) .* 0.01 - A = randn(rng, Float64, l.r, l.d + 1) .* 0.01 - b = zeros(Float64, l.d + 1) - c = zero(Float64) - K0 = randn(rng, Float64, l.m, l.d + 1) .* 0.01 - K1 = randn(rng, Float64, l.m, l.m) .* 0.01 - b0 = zeros(Float64, l.m) - b1 = zeros(Float64, l.m) + w = randn(rng, Float32, l.m) .* 0.01 + A = randn(rng, Float32, l.r, l.d + 1) .* 0.01 + b = zeros(Float32, l.d + 1) + c = zero(Float32) + K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01 + K1 = randn(rng, Float32, l.m, l.m) .* 0.01 + b0 = zeros(Float32, l.m) + b1 = zeros(Float32, l.m) - return (w=w, A=A, b=b, c=c, K0=K0, K1=K1, b0=b0, b1=b1) + ps = (w=w, A=A, b=b, c=c, K0=K0, K1=K1, b0=b0, b1=b1) + st = NamedTuple() + return ps, st end -Lux.initialstates(::AbstractRNG, ::OTFlow) = NamedTuple() - σ(x) = log(exp(x) + exp(-x)) σ′(x) 
= tanh(x) σ′′(x) = 1 - tanh(x)^2 @@ -76,7 +76,7 @@ function (l::OTFlow)(xt::Tuple{AbstractVector, Real}, ps, st) end function simple_loss(x::AbstractVector, t::Real, l::OTFlow, ps) - (v, tr), _ = l((x, t), ps, nothing) + (v, tr), _ = l((x, t), ps, NamedTuple()) return sum(v.^2) / 2 - tr end diff --git a/test/collocation_tests.jl b/test/collocation_tests.jl index c709a3497..0fc971512 100644 --- a/test/collocation_tests.jl +++ b/test/collocation_tests.jl @@ -7,7 +7,7 @@ unbounded_support_kernels = [ GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()] - @testitem "Kernel Functions" begin + @testset "Kernel Functions" begin ts = collect(-5.0:0.1:5.0) @testitem "Kernels with support from -1 to 1" begin minus_one_index = findfirst(x -> ==(x, -1.0), ts) From b31270c469d6c623fa34eef627e81df562270b05 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Mon, 24 Feb 2025 00:26:47 +0530 Subject: [PATCH 08/12] Changes --- test/collocation_tests.jl | 14 +++++++------- test/neural_dae_tests.jl | 2 +- test/neural_de_tests.jl | 28 ++++++++++++++-------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/test/collocation_tests.jl b/test/collocation_tests.jl index 0fc971512..1c728a2a6 100644 --- a/test/collocation_tests.jl +++ b/test/collocation_tests.jl @@ -9,7 +9,7 @@ @testset "Kernel Functions" begin ts = collect(-5.0:0.1:5.0) - @testitem "Kernels with support from -1 to 1" begin + @testset "Kernels with support from -1 to 1" begin minus_one_index = findfirst(x -> ==(x, -1.0), ts) plus_one_index = findfirst(x -> ==(x, 1.0), ts) @testset "$kernel" for (kernel, x0) in zip(bounded_support_kernels, @@ -25,8 +25,8 @@ @test DiffEqFlux.calckernel(kernel, 0.0) == x0 end end - @testitem "Kernels with unbounded support" begin - @testitem "$kernel" for (kernel, x0) in zip(unbounded_support_kernels, + @testset "Kernels with unbounded support" begin + @testset "$kernel" for (kernel, x0) in zip(unbounded_support_kernels, [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))]) # t = 0 @test DiffEqFlux.calckernel(kernel, 0.0) == x0 @@ -34,7 +34,7 @@ end end - @testitem "Collocation of data" begin + @testset "Collocation of data" begin f(u, p, t) = p .* u rc = 2 ps = repeat([-0.001], rc) @@ -43,14 +43,14 @@ t = collect(range(minimum(tspan); stop = maximum(tspan), length = 1000)) prob = ODEProblem(f, u0, tspan, ps) data = Array(solve(prob, Tsit5(); saveat = t, abstol = 1e-12, reltol = 1e-12)) - @testitem "$kernel" for kernel in [ + @testset "$kernel" for kernel in [ bounded_support_kernels..., unbounded_support_kernels...] u′, u = collocate_data(data, t, kernel, 0.003) @test sum(abs2, u - data) < 1e-8 end - @testitem "$kernel" for kernel in [bounded_support_kernels...] + @testset "$kernel" for kernel in [bounded_support_kernels...] 
# Errors out as the bandwidth is too low @test_throws ErrorException collocate_data(data, t, kernel, 0.001) end end -end +end \ No newline at end of file diff --git a/test/neural_dae_tests.jl b/test/neural_dae_tests.jl index ffc812a5a..e33645b4f 100644 --- a/test/neural_dae_tests.jl +++ b/test/neural_dae_tests.jl @@ -69,4 +69,4 @@ optprob = Optimization.OptimizationProblem(optfunc, p) res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001)) end -end +end \ No newline at end of file diff --git a/test/neural_de_tests.jl b/test/neural_de_tests.jl index 6fd05dc68..ab59dd382 100644 --- a/test/neural_de_tests.jl +++ b/test/neural_de_tests.jl @@ -4,7 +4,7 @@ rng = Xoshiro(0) - @testitem "$(nnlib)" for nnlib in ("Flux", "Lux") + @testset "$(nnlib)" for nnlib in ("Flux", "Lux") mp = Float32[0.1, 0.1] x = Float32[2.0; 0.0] xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) @@ -22,8 +22,8 @@ Chain(Dense(4 => 50, tanh), Dense(50 => 4)) end - @testitem "u0: $(typeof(u0))" for u0 in (x, xs) - @testitem "kwargs: $(kwargs))" for kwargs in ( + @testset "u0: $(typeof(u0))" for u0 in (x, xs) + @testset "kwargs: $(kwargs))" for kwargs in ( (; save_everystep = false, save_start = false), (; abstol = 1e-12, reltol = 1e-12, save_everystep = false, save_start = false), @@ -58,7 +58,7 @@ end rng = Xoshiro(0) - @testitem "$(nnlib)" for nnlib in ("Flux", "Lux") + @testset "$(nnlib)" for nnlib in ("Flux", "Lux") mp = Float32[0.1, 0.1] x = Float32[2.0; 0.0] xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) @@ -122,7 +122,7 @@ end rng = Xoshiro(0) - @testitem "$(nnlib)" for nnlib in ("Flux", "Lux") + @testset "$(nnlib)" for nnlib in ("Flux", "Lux") mp = Float32[0.1, 0.1] x = Float32[2.0; 0.0] xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) @@ -154,7 +154,7 @@ end Chain(Dense(4 => 50, tanh), Dense(50 => 16), x -> reshape(x, 4, 4)) end - @testitem "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,), + @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,), solver in (EulerHeun(), LambaEM()) sode = NeuralSDE(dudt, diffusion_sde, tspan, 2, solver; @@ -187,7 +187,7 @@ end rng = Xoshiro(0) - @testitem "$(nnlib)" for nnlib in ("Flux", "Lux") + @testset "$(nnlib)" for nnlib in ("Flux", "Lux") mp = Float32[0.1, 0.1] x = Float32[2.0; 0.0] xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) @@ -251,7 +251,7 @@ end const gdev = gpu_device() const cdev = cpu_device() - @testitem "Neural DE" begin + @testset "Neural DE" begin mp = Float32[0.1, 0.1] |> gdev x = Float32[2.0; 0.0] |> gdev xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) |> gdev @@ -260,9 +260,9 @@ end dudt = Chain(Dense(2 => 50, tanh), Dense(50 => 2)) aug_dudt = Chain(Dense(4 => 50, tanh), Dense(50 => 4)) - @testitem "Neural ODE" begin - @testitem "u0: $(typeof(u0))" for u0 in (x, xs) - @testitem "kwargs: $(kwargs))" for kwargs in ( + @testset "Neural ODE" begin + @testset "u0: $(typeof(u0))" for u0 in (x, xs) + @testset "kwargs: $(kwargs))" for kwargs in ( (; save_everystep = false, save_start = false), (; save_everystep = false, save_start = false, sensealg = TrackerAdjoint()), @@ -278,7 +278,7 @@ end st = st |> gdev broken = hasfield(typeof(kwargs), :sensealg) && ndims(u0) == 2 && - kwargs[:sensealg] isa TrackerAdjoint + kwargs.sensealg isa TrackerAdjoint @test begin grads = Zygote.gradient(sum ∘ last ∘ first ∘ node, u0, pd, st) CUDA.@allowscalar begin @@ -305,7 +305,7 @@ end aug_diffusion = Chain(Dense(4 => 50, tanh), Dense(50 => 4)) tspan = (0.0f0, 0.1f0) - @testitem "NeuralDSDE u0: $(typeof(u0)), solver: 
$(solver)" for u0 in (xs,), + @testset "NeuralDSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (xs,), solver in (SOSRI(),) # CuVector seems broken on CI but I can't reproduce the failure locally @@ -323,4 +323,4 @@ end end end end -end +end \ No newline at end of file From 619a567ddfc3590b02296a622cbd884771cf953e Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Tue, 25 Feb 2025 09:30:23 +0530 Subject: [PATCH 09/12] Made required changes --- src/otflow.jl | 132 +++---- test/cnf_tests.jl | 2 +- test/collocation_tests.jl | 104 +++--- test/neural_de_tests.jl | 628 +++++++++++++++++----------------- test/otflow_tests.jl | 112 +++--- test/stiff_nested_ad_tests.jl | 2 +- 6 files changed, 490 insertions(+), 490 deletions(-) diff --git a/src/otflow.jl b/src/otflow.jl index 38cdd7578..2c7639a1c 100644 --- a/src/otflow.jl +++ b/src/otflow.jl @@ -1,26 +1,26 @@ struct OTFlow <: AbstractLuxLayer - d::Int # Input dimension - m::Int # Hidden dimension - r::Int # Rank for low-rank approximation + d::Int # Input dimension + m::Int # Hidden dimension + r::Int # Rank for low-rank approximation end # Constructor with default rank -OTFlow(d::Int, m::Int; r::Int=min(10,d)) = OTFlow(d, m, r) +OTFlow(d::Int, m::Int; r::Int = min(10, d)) = OTFlow(d, m, r) # Initialize parameters and states function Lux.initialparameters(rng::AbstractRNG, l::OTFlow) - w = randn(rng, Float32, l.m) .* 0.01 - A = randn(rng, Float32, l.r, l.d + 1) .* 0.01 - b = zeros(Float32, l.d + 1) - c = zero(Float32) - K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01 - K1 = randn(rng, Float32, l.m, l.m) .* 0.01 - b0 = zeros(Float32, l.m) - b1 = zeros(Float32, l.m) - - ps = (w=w, A=A, b=b, c=c, K0=K0, K1=K1, b0=b0, b1=b1) - st = NamedTuple() - return ps, st + w = randn(rng, Float32, l.m) .* 0.01 + A = randn(rng, Float32, l.r, l.d + 1) .* 0.01 + b = zeros(Float32, l.d + 1) + c = zero(Float32) + K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01 + K1 = randn(rng, Float32, l.m, l.m) .* 0.01 + b0 = zeros(Float32, l.m) + b1 = zeros(Float32, l.m) + + ps = (w = w, A = A, b = b, c = c, K0 = K0, K1 = K1, b0 = b0, b1 = b1) + st = NamedTuple() + return ps, st end σ(x) = log(exp(x) + exp(-x)) @@ -28,71 +28,71 @@ end σ′′(x) = 1 - tanh(x)^2 function resnet_forward(x::AbstractVector, t::Real, ps) - s = vcat(x, t) - u0 = σ.(ps.K0 * s .+ ps.b0) - u1 = u0 .+ σ.(ps.K1 * u0 .+ ps.b1) # h=1 as in paper - return u1 + s = vcat(x, t) + u0 = σ.(ps.K0 * s .+ ps.b0) + u1 = u0 .+ σ.(ps.K1 * u0 .+ ps.b1) # h=1 as in paper + return u1 end function potential(x::AbstractVector, t::Real, ps) - s = vcat(x, t) - N = resnet_forward(x, t, ps) - quadratic_term = 0.5 * s' * (ps.A' * ps.A) * s - linear_term = ps.b' * s - return ps.w' * N + quadratic_term + linear_term + ps.c + s = vcat(x, t) + N = resnet_forward(x, t, ps) + quadratic_term = 0.5 * s' * (ps.A' * ps.A) * s + linear_term = ps.b' * s + return ps.w' * N + quadratic_term + linear_term + ps.c end function gradient(x::AbstractVector, t::Real, ps, d::Int) - s = vcat(x, t) - u0 = σ.(ps.K0 * s .+ ps.b0) - z1 = ps.w .+ ps.K1' * (σ′.(ps.K1 * u0 .+ ps.b1) .* ps.w) - z0 = ps.K0' * (σ′.(ps.K0 * s .+ ps.b0) .* z1) - - grad = z0 + (ps.A' * ps.A) * s + ps.b - return grad[1:d] + s = vcat(x, t) + u0 = σ.(ps.K0 * s .+ ps.b0) + z1 = ps.w .+ ps.K1' * (σ′.(ps.K1 * u0 .+ ps.b1) .* ps.w) + z0 = ps.K0' * (σ′.(ps.K0 * s .+ ps.b0) .* z1) + + grad = z0 + (ps.A' * ps.A) * s + ps.b + return grad[1:d] end function trace(x::AbstractVector, t::Real, ps, d::Int) - s = vcat(x, t) - u0 = σ.(ps.K0 * s .+ ps.b0) - z1 = ps.w .+ ps.K1' * (σ′.(ps.K1 * u0 .+ ps.b1) 
.* ps.w) - - K0_E = ps.K0[:, 1:d] - A_E = ps.A[:, 1:d] - - t0 = sum(σ′′.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2)) - J = Diagonal(σ′.(ps.K0 * s .+ ps.b0)) * K0_E - t1 = sum(σ′′.(ps.K1 * u0 .+ ps.b1) .* ps.w .* (ps.K1 * J) .^ 2) - trace_A = tr(A_E' * A_E) - - return t0 + t1 + trace_A + s = vcat(x, t) + u0 = σ.(ps.K0 * s .+ ps.b0) + z1 = ps.w .+ ps.K1' * (σ′.(ps.K1 * u0 .+ ps.b1) .* ps.w) + + K0_E = ps.K0[:, 1:d] + A_E = ps.A[:, 1:d] + + t0 = sum(σ′′.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2)) + J = Diagonal(σ′.(ps.K0 * s .+ ps.b0)) * K0_E + t1 = sum(σ′′.(ps.K1 * u0 .+ ps.b1) .* ps.w .* (ps.K1 * J) .^ 2) + trace_A = tr(A_E' * A_E) + + return t0 + t1 + trace_A end function (l::OTFlow)(xt::Tuple{AbstractVector, Real}, ps, st) - x, t = xt - v = -gradient(x, t, ps, l.d) # v = -∇Φ - tr = -trace(x, t, ps, l.d) # tr(∇v) = -tr(∇²Φ) - return (v, tr), st + x, t = xt + v = -gradient(x, t, ps, l.d) # v = -∇Φ + tr = -trace(x, t, ps, l.d) # tr(∇v) = -tr(∇²Φ) + return (v, tr), st end function simple_loss(x::AbstractVector, t::Real, l::OTFlow, ps) - (v, tr), _ = l((x, t), ps, NamedTuple()) - return sum(v.^2) / 2 - tr + (v, tr), _ = l((x, t), ps, NamedTuple()) + return sum(v .^ 2) / 2 - tr end function manual_gradient(x::AbstractVector, t::Real, l::OTFlow, ps) - s = vcat(x, t) - u0 = σ.(ps.K0 * s .+ ps.b0) - u1 = u0 .+ σ.(ps.K1 * u0 .+ ps.b1) - - v = -gradient(x, t, ps, l.d) - tr = -trace(x, t, ps, l.d) - - # Simplified gradients (not full implementation) - grad_w = u1 - grad_A = (ps.A * s) * s' - - return (w=grad_w, A=grad_A, b=similar(ps.b), c=0.0, - K0=zeros(l.m, l.d+1), K1=zeros(l.m, l.m), - b0=zeros(l.m), b1=zeros(l.m)) -end \ No newline at end of file + s = vcat(x, t) + u0 = σ.(ps.K0 * s .+ ps.b0) + u1 = u0 .+ σ.(ps.K1 * u0 .+ ps.b1) + + v = -gradient(x, t, ps, l.d) + tr = -trace(x, t, ps, l.d) + + # Simplified gradients (not full implementation) + grad_w = u1 + grad_A = (ps.A * s) * s' + + return (w = grad_w, A = grad_A, b = similar(ps.b), c = 0.0, + K0 = zeros(l.m, l.d + 1), K1 = zeros(l.m, l.m), + b0 = zeros(l.m), b1 = zeros(l.m)) +end diff --git a/test/cnf_tests.jl b/test/cnf_tests.jl index cfd57a5bb..322b1647d 100644 --- a/test/cnf_tests.jl +++ b/test/cnf_tests.jl @@ -34,7 +34,7 @@ end return -mean(logpx) end - @testitem "ADType: $(adtype)" for adtype in (Optimization.AutoForwardDiff(), + @testset "ADType: $(adtype)" for adtype in (Optimization.AutoForwardDiff(), Optimization.AutoReverseDiff(), Optimization.AutoTracker(), Optimization.AutoZygote(), Optimization.AutoFiniteDiff()) @testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in ( diff --git a/test/collocation_tests.jl b/test/collocation_tests.jl index 1c728a2a6..d25e800ba 100644 --- a/test/collocation_tests.jl +++ b/test/collocation_tests.jl @@ -1,56 +1,56 @@ -@testitem "Collocation" tags=[:layers] begin - using OrdinaryDiffEq +@testitem "Collocation" tags = [:layers] begin + using OrdinaryDiffEq - bounded_support_kernels = [EpanechnikovKernel(), UniformKernel(), TriangularKernel(), - QuarticKernel(), TriweightKernel(), TricubeKernel(), CosineKernel()] + bounded_support_kernels = [EpanechnikovKernel(), UniformKernel(), TriangularKernel(), + QuarticKernel(), TriweightKernel(), TricubeKernel(), CosineKernel()] - unbounded_support_kernels = [ - GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()] + unbounded_support_kernels = [ + GaussianKernel(), LogisticKernel(), SigmoidKernel(), SilvermanKernel()] - @testset "Kernel Functions" begin - ts = collect(-5.0:0.1:5.0) - @testset "Kernels with support 
from -1 to 1" begin - minus_one_index = findfirst(x -> ==(x, -1.0), ts) - plus_one_index = findfirst(x -> ==(x, 1.0), ts) - @testset "$kernel" for (kernel, x0) in zip(bounded_support_kernels, - [0.75, 0.50, 1.0, 15.0 / 16.0, 35.0 / 32.0, 70.0 / 81.0, pi / 4.0]) - ws = DiffEqFlux.calckernel.((kernel,), ts) - # t < -1 - @test all(ws[1:(minus_one_index - 1)] .== 0.0) - # t > 1 - @test all(ws[(plus_one_index + 1):end] .== 0.0) - # -1 < t <1 - @test all(ws[(minus_one_index + 1):(plus_one_index - 1)] .> 0.0) - # t = 0 - @test DiffEqFlux.calckernel(kernel, 0.0) == x0 - end - end - @testset "Kernels with unbounded support" begin - @testset "$kernel" for (kernel, x0) in zip(unbounded_support_kernels, - [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))]) - # t = 0 - @test DiffEqFlux.calckernel(kernel, 0.0) == x0 - end - end - end + @testset "Kernel Functions" begin + ts = collect(-5.0:0.1:5.0) + @testset "Kernels with support from -1 to 1" begin + minus_one_index = findfirst(x -> ==(x, -1.0), ts) + plus_one_index = findfirst(x -> ==(x, 1.0), ts) + @testset "$kernel" for (kernel, x0) in zip(bounded_support_kernels, + [0.75, 0.50, 1.0, 15.0 / 16.0, 35.0 / 32.0, 70.0 / 81.0, pi / 4.0]) + ws = DiffEqFlux.calckernel.((kernel,), ts) + # t < -1 + @test all(ws[1:(minus_one_index-1)] .== 0.0) + # t > 1 + @test all(ws[(plus_one_index+1):end] .== 0.0) + # -1 < t <1 + @test all(ws[(minus_one_index+1):(plus_one_index-1)] .> 0.0) + # t = 0 + @test DiffEqFlux.calckernel(kernel, 0.0) == x0 + end + end + @testset "Kernels with unbounded support" begin + @testset "$kernel" for (kernel, x0) in zip(unbounded_support_kernels, + [1 / (sqrt(2 * pi)), 0.25, 1 / pi, 1 / (2 * sqrt(2))]) + # t = 0 + @test DiffEqFlux.calckernel(kernel, 0.0) == x0 + end + end + end - @testset "Collocation of data" begin - f(u, p, t) = p .* u - rc = 2 - ps = repeat([-0.001], rc) - tspan = (0.0, 50.0) - u0 = 3.4 .+ ones(rc) - t = collect(range(minimum(tspan); stop = maximum(tspan), length = 1000)) - prob = ODEProblem(f, u0, tspan, ps) - data = Array(solve(prob, Tsit5(); saveat = t, abstol = 1e-12, reltol = 1e-12)) - @testset "$kernel" for kernel in [ - bounded_support_kernels..., unbounded_support_kernels...] - u′, u = collocate_data(data, t, kernel, 0.003) - @test sum(abs2, u - data) < 1e-8 - end - @testset "$kernel" for kernel in [bounded_support_kernels...] - # Errors out as the bandwidth is too low - @test_throws ErrorException collocate_data(data, t, kernel, 0.001) - end - end -end \ No newline at end of file + @testset "Collocation of data" begin + f(u, p, t) = p .* u + rc = 2 + ps = repeat([-0.001], rc) + tspan = (0.0, 50.0) + u0 = 3.4 .+ ones(rc) + t = collect(range(minimum(tspan); stop = maximum(tspan), length = 1000)) + prob = ODEProblem(f, u0, tspan, ps) + data = Array(solve(prob, Tsit5(); saveat = t, abstol = 1e-12, reltol = 1e-12)) + @testset "$kernel" for kernel in [ + bounded_support_kernels..., unbounded_support_kernels...] + u′, u = collocate_data(data, t, kernel, 0.003) + @test sum(abs2, u - data) < 1e-8 + end + @testset "$kernel" for kernel in [bounded_support_kernels...] 
+ # Errors out as the bandwidth is too low + @test_throws ErrorException collocate_data(data, t, kernel, 0.001) + end + end +end diff --git a/test/neural_de_tests.jl b/test/neural_de_tests.jl index ab59dd382..25ea98c86 100644 --- a/test/neural_de_tests.jl +++ b/test/neural_de_tests.jl @@ -1,326 +1,326 @@ -@testitem "NeuralODE" tags=[:basicneuralde] begin - using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random - import Flux - - rng = Xoshiro(0) - - @testset "$(nnlib)" for nnlib in ("Flux", "Lux") - mp = Float32[0.1, 0.1] - x = Float32[2.0; 0.0] - xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) - tspan = (0.0f0, 1.0f0) - - dudt = if nnlib == "Flux" - Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) - else - Chain(Dense(2 => 50, tanh), Dense(50 => 2)) - end - - aug_dudt = if nnlib == "Flux" - Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) - else - Chain(Dense(4 => 50, tanh), Dense(50 => 4)) - end - - @testset "u0: $(typeof(u0))" for u0 in (x, xs) - @testset "kwargs: $(kwargs))" for kwargs in ( - (; save_everystep = false, save_start = false), - (; abstol = 1e-12, reltol = 1e-12, - save_everystep = false, save_start = false), - (; save_everystep = false, save_start = false, sensealg = TrackerAdjoint()), - (; save_everystep = false, save_start = false, - sensealg = BacksolveAdjoint()), - (; saveat = 0.0f0:0.1f0:1.0f0), - (; saveat = 0.1f0), - (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()), - (; saveat = 0.1f0, sensealg = TrackerAdjoint())) - node = NeuralODE(dudt, tspan, Tsit5(); kwargs...) - pd, st = Lux.setup(rng, node) - pd = ComponentArray(pd) - grads = Zygote.gradient(sum ∘ first ∘ node, u0, pd, st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - anode = AugmentedNDELayer(NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2) - pd, st = Lux.setup(rng, anode) - pd = ComponentArray(pd) - grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - end - end - end +@testitem "NeuralODE" tags = [:basicneuralde] begin + using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random + using Flux: Flux + + rng = Xoshiro(0) + + @testset "$(nnlib)" for nnlib in ("Flux", "Lux") + mp = Float32[0.1, 0.1] + x = Float32[2.0; 0.0] + xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) + tspan = (0.0f0, 1.0f0) + + dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + end + + aug_dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + end + + @testset "u0: $(typeof(u0))" for u0 in (x, xs) + @testset "kwargs: $(kwargs))" for kwargs in ( + (; save_everystep = false, save_start = false), + (; abstol = 1e-12, reltol = 1e-12, + save_everystep = false, save_start = false), + (; save_everystep = false, save_start = false, sensealg = TrackerAdjoint()), + (; save_everystep = false, save_start = false, + sensealg = BacksolveAdjoint()), + (; saveat = 0.0f0:0.1f0:1.0f0), + (; saveat = 0.1f0), + (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()), + (; saveat = 0.1f0, sensealg = TrackerAdjoint())) + node = NeuralODE(dudt, tspan, Tsit5(); kwargs...) 
+ pd, st = Lux.setup(rng, node) + pd = ComponentArray(pd) + grads = Zygote.gradient(sum ∘ first ∘ node, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + + anode = AugmentedNDELayer(NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + end + end + end end -@testitem "NeuralDSDE" tags=[:basicneuralde] begin - using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random - import Flux - - rng = Xoshiro(0) - - @testset "$(nnlib)" for nnlib in ("Flux", "Lux") - mp = Float32[0.1, 0.1] - x = Float32[2.0; 0.0] - xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) - tspan = (0.0f0, 1.0f0) - - dudt = if nnlib == "Flux" - Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) - else - Chain(Dense(2 => 50, tanh), Dense(50 => 2)) - end - - aug_dudt = if nnlib == "Flux" - Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) - else - Chain(Dense(4 => 50, tanh), Dense(50 => 4)) - end - - diffusion = if nnlib == "Flux" - Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) - else - Chain(Dense(2 => 50, tanh), Dense(50 => 2)) - end - - aug_diffusion = if nnlib == "Flux" - Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) - else - Chain(Dense(4 => 50, tanh), Dense(50 => 4)) - end - - tspan = (0.0f0, 0.1f0) - @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x, xs), - solver in (EulerHeun(), LambaEM(), SOSRI()) - - sode = NeuralDSDE( - dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) - pd, st = Lux.setup(rng, sode) - pd = ComponentArray(pd) - - grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - @test !iszero(grads[2][end]) - - sode = NeuralDSDE(aug_dudt, aug_diffusion, tspan, solver; - saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) - anode = AugmentedNDELayer(sode, 2) - pd, st = Lux.setup(rng, anode) - pd = ComponentArray(pd) - - grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - @test !iszero(grads[2][end]) - end - end +@testitem "NeuralDSDE" tags = [:basicneuralde] begin + using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random + using Flux: Flux + + rng = Xoshiro(0) + + @testset "$(nnlib)" for nnlib in ("Flux", "Lux") + mp = Float32[0.1, 0.1] + x = Float32[2.0; 0.0] + xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) + tspan = (0.0f0, 1.0f0) + + dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + end + + aug_dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + end + + diffusion = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + end + + aug_diffusion = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + end + + tspan = (0.0f0, 0.1f0) + @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x, xs), + solver in (EulerHeun(), LambaEM(), SOSRI()) + + sode = NeuralDSDE( + dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) + pd, st = Lux.setup(rng, sode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ sode, u0, 
pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + + sode = NeuralDSDE(aug_dudt, aug_diffusion, tspan, solver; + saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) + anode = AugmentedNDELayer(sode, 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + end + end end -@testitem "NeuralSDE" tags=[:basicneuralde] begin - using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random - import Flux - - rng = Xoshiro(0) - - @testset "$(nnlib)" for nnlib in ("Flux", "Lux") - mp = Float32[0.1, 0.1] - x = Float32[2.0; 0.0] - xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) - tspan = (0.0f0, 1.0f0) - - dudt = if nnlib == "Flux" - Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) - else - Chain(Dense(2 => 50, tanh), Dense(50 => 2)) - end - - aug_dudt = if nnlib == "Flux" - Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) - else - Chain(Dense(4 => 50, tanh), Dense(50 => 4)) - end - - diffusion_sde = if nnlib == "Flux" - Flux.Chain( - Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 4), x -> reshape(x, 2, 2)) - else - Chain(Dense(2 => 50, tanh), Dense(50 => 4), x -> reshape(x, 2, 2)) - end - - aug_diffusion_sde = if nnlib == "Flux" - Flux.Chain( - Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 16), x -> reshape(x, 4, 4)) - else - Chain(Dense(4 => 50, tanh), Dense(50 => 16), x -> reshape(x, 4, 4)) - end - - @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,), - solver in (EulerHeun(), LambaEM()) - - sode = NeuralSDE(dudt, diffusion_sde, tspan, 2, solver; - saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) - pd, st = Lux.setup(rng, sode) - pd = ComponentArray(pd) - - grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - @test !iszero(grads[2][end]) - - sode = NeuralSDE(aug_dudt, aug_diffusion_sde, tspan, 4, solver; - saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) - anode = AugmentedNDELayer(sode, 2) - pd, st = Lux.setup(rng, anode) - pd = ComponentArray(pd) - - grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - @test !iszero(grads[2][end]) - end - end +@testitem "NeuralSDE" tags = [:basicneuralde] begin + using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random + using Flux: Flux + + rng = Xoshiro(0) + + @testset "$(nnlib)" for nnlib in ("Flux", "Lux") + mp = Float32[0.1, 0.1] + x = Float32[2.0; 0.0] + xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) + tspan = (0.0f0, 1.0f0) + + dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + end + + aug_dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + end + + diffusion_sde = if nnlib == "Flux" + Flux.Chain( + Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 4), x -> reshape(x, 2, 2)) + else + Chain(Dense(2 => 50, tanh), Dense(50 => 4), x -> reshape(x, 2, 2)) + end + + aug_diffusion_sde = if nnlib == "Flux" + Flux.Chain( + Flux.Dense(4 => 50, tanh), Flux.Dense(50 => 16), x -> reshape(x, 4, 4)) + else + Chain(Dense(4 => 50, tanh), Dense(50 => 16), x -> reshape(x, 4, 4)) + end + + @testset "u0: $(typeof(u0)), solver: $(solver)" for u0 in (x,), + solver in (EulerHeun(), LambaEM()) + + sode = NeuralSDE(dudt, diffusion_sde, 
tspan, 2, solver; + saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) + pd, st = Lux.setup(rng, sode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ sode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + + sode = NeuralSDE(aug_dudt, aug_diffusion_sde, tspan, 4, solver; + saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) + anode = AugmentedNDELayer(sode, 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + @test !iszero(grads[2][end]) + end + end end -@testitem "NeuralCDDE" tags=[:basicneuralde] begin - using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random - import Flux - - rng = Xoshiro(0) - - @testset "$(nnlib)" for nnlib in ("Flux", "Lux") - mp = Float32[0.1, 0.1] - x = Float32[2.0; 0.0] - xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) - tspan = (0.0f0, 1.0f0) - - dudt = if nnlib == "Flux" - Flux.Chain(Flux.Dense(6 => 50, tanh), Flux.Dense(50 => 2)) - else - Chain(Dense(6 => 50, tanh), Dense(50 => 2)) - end - - aug_dudt = if nnlib == "Flux" - Flux.Chain(Flux.Dense(12 => 50, tanh), Flux.Dense(50 => 4)) - else - Chain(Dense(12 => 50, tanh), Dense(50 => 4)) - end - - @testset "NeuralCDDE u0: $(typeof(u0))" for u0 in (x, xs) - dode = NeuralCDDE(dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0), - MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0) - pd, st = Lux.setup(rng, dode) - pd = ComponentArray(pd) - - grads = Zygote.gradient(sum ∘ first ∘ dode, u0, pd, st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - - dode = NeuralCDDE( - aug_dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0), - MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0) - anode = AugmentedNDELayer(dode, 2) - pd, st = Lux.setup(rng, anode) - pd = ComponentArray(pd) - - grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) - @test !iszero(grads[1]) - @test !iszero(grads[2]) - end - end +@testitem "NeuralCDDE" tags = [:basicneuralde] begin + using ComponentArrays, Zygote, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Random + using Flux: Flux + + rng = Xoshiro(0) + + @testset "$(nnlib)" for nnlib in ("Flux", "Lux") + mp = Float32[0.1, 0.1] + x = Float32[2.0; 0.0] + xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) + tspan = (0.0f0, 1.0f0) + + dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(6 => 50, tanh), Flux.Dense(50 => 2)) + else + Chain(Dense(6 => 50, tanh), Dense(50 => 2)) + end + + aug_dudt = if nnlib == "Flux" + Flux.Chain(Flux.Dense(12 => 50, tanh), Flux.Dense(50 => 4)) + else + Chain(Dense(12 => 50, tanh), Dense(50 => 4)) + end + + @testset "NeuralCDDE u0: $(typeof(u0))" for u0 in (x, xs) + dode = NeuralCDDE(dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0), + MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0) + pd, st = Lux.setup(rng, dode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ dode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + + dode = NeuralCDDE( + aug_dudt, (0.0f0, 2.0f0), (u, p, t) -> zero(u), (0.1f0, 0.2f0), + MethodOfSteps(Tsit5()); saveat = 0.0f0:0.1f0:2.0f0) + anode = AugmentedNDELayer(dode, 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) + + grads = Zygote.gradient(sum ∘ first ∘ anode, u0, pd, st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + end + end end -@testitem "DimMover" tags=[:basicneuralde] begin - using Random +@testitem "DimMover" tags = [:basicneuralde] begin + 
using Random - rng = Xoshiro(0) - r = rand(2, 3, 4, 5) - layer = DimMover() - ps, st = Lux.setup(rng, layer) + rng = Xoshiro(0) + r = rand(2, 3, 4, 5) + layer = DimMover() + ps, st = Lux.setup(rng, layer) - @test first(layer(r, ps, st))[:, :, :, 1] == r[:, :, 1, :] + @test first(layer(r, ps, st))[:, :, :, 1] == r[:, :, 1, :] end -@testitem "Neural DE CUDA" tags=[:cuda] begin - using LuxCUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random, ComponentArrays - import Flux - - CUDA.allowscalar(false) - - rng = Xoshiro(0) - - const gdev = gpu_device() - const cdev = cpu_device() - - @testset "Neural DE" begin - mp = Float32[0.1, 0.1] |> gdev - x = Float32[2.0; 0.0] |> gdev - xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) |> gdev - tspan = (0.0f0, 1.0f0) - - dudt = Chain(Dense(2 => 50, tanh), Dense(50 => 2)) - aug_dudt = Chain(Dense(4 => 50, tanh), Dense(50 => 4)) - - @testset "Neural ODE" begin - @testset "u0: $(typeof(u0))" for u0 in (x, xs) - @testset "kwargs: $(kwargs))" for kwargs in ( - (; save_everystep = false, save_start = false), - (; save_everystep = false, save_start = false, - sensealg = TrackerAdjoint()), - (; save_everystep = false, save_start = false, - sensealg = BacksolveAdjoint()), - (; saveat = 0.0f0:0.1f0:1.0f0), - (; saveat = 0.1f0), - (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()), - (; saveat = 0.1f0, sensealg = TrackerAdjoint())) - node = NeuralODE(dudt, tspan, Tsit5(); kwargs...) - pd, st = Lux.setup(rng, node) - pd = ComponentArray(pd) |> gdev - st = st |> gdev - broken = hasfield(typeof(kwargs), :sensealg) && - ndims(u0) == 2 && - kwargs.sensealg isa TrackerAdjoint - @test begin - grads = Zygote.gradient(sum ∘ last ∘ first ∘ node, u0, pd, st) - CUDA.@allowscalar begin - !iszero(grads[1]) && !iszero(grads[2]) - end - end broken=broken - - anode = AugmentedNDELayer( - NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2) - pd, st = Lux.setup(rng, anode) - pd = ComponentArray(pd) |> gdev - st = st |> gdev - @test begin - grads = Zygote.gradient(sum ∘ last ∘ first ∘ anode, u0, pd, st) - CUDA.@allowscalar begin - !iszero(grads[1]) && !iszero(grads[2]) - end - end broken=broken - end - end - end - - diffusion = Chain(Dense(2 => 50, tanh), Dense(50 => 2)) - aug_diffusion = Chain(Dense(4 => 50, tanh), Dense(50 => 4)) - - tspan = (0.0f0, 0.1f0) - @testset "NeuralDSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (xs,), - solver in (SOSRI(),) - # CuVector seems broken on CI but I can't reproduce the failure locally - - sode = NeuralDSDE( - dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) - pd, st = Lux.setup(rng, sode) - pd = ComponentArray(pd) |> gdev - st = st |> gdev - - @test_broken begin - grads = Zygote.gradient(sum ∘ last ∘ first ∘ sode, u0, pd, st) - CUDA.@allowscalar begin - !iszero(grads[1]) && !iszero(grads[2]) && !iszero(grads[2][end]) - end - end - end - end -end \ No newline at end of file +@testitem "Neural DE CUDA" tags = [:cuda] skip = :(using LuxCUDA; !LuxCUDA.functional()) begin + using LuxCUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random, ComponentArrays + using Flux: Flux + + CUDA.allowscalar(false) + + rng = Xoshiro(0) + + const gdev = gpu_device() + const cdev = cpu_device() + + @testset "Neural DE" begin + mp = Float32[0.1, 0.1] |> gdev + x = Float32[2.0; 0.0] |> gdev + xs = Float32.(hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0])) |> gdev + tspan = (0.0f0, 1.0f0) + + dudt = Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + aug_dudt = Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + + @testset "Neural 
ODE" begin + @testset "u0: $(typeof(u0))" for u0 in (x, xs) + @testset "kwargs: $(kwargs))" for kwargs in ( + (; save_everystep = false, save_start = false), + (; save_everystep = false, save_start = false, + sensealg = TrackerAdjoint()), + (; save_everystep = false, save_start = false, + sensealg = BacksolveAdjoint()), + (; saveat = 0.0f0:0.1f0:1.0f0), + (; saveat = 0.1f0), + (; saveat = 0.0f0:0.1f0:1.0f0, sensealg = TrackerAdjoint()), + (; saveat = 0.1f0, sensealg = TrackerAdjoint())) + node = NeuralODE(dudt, tspan, Tsit5(); kwargs...) + pd, st = Lux.setup(rng, node) + pd = ComponentArray(pd) |> gdev + st = st |> gdev + broken = hasfield(typeof(kwargs), :sensealg) && + ndims(u0) == 2 && + kwargs.sensealg isa TrackerAdjoint + @test begin + grads = Zygote.gradient(sum ∘ last ∘ first ∘ node, u0, pd, st) + CUDA.@allowscalar begin + !iszero(grads[1]) && !iszero(grads[2]) + end + end broken = broken + + anode = AugmentedNDELayer( + NeuralODE(aug_dudt, tspan, Tsit5(); kwargs...), 2) + pd, st = Lux.setup(rng, anode) + pd = ComponentArray(pd) |> gdev + st = st |> gdev + @test begin + grads = Zygote.gradient(sum ∘ last ∘ first ∘ anode, u0, pd, st) + CUDA.@allowscalar begin + !iszero(grads[1]) && !iszero(grads[2]) + end + end broken = broken + end + end + end + + diffusion = Chain(Dense(2 => 50, tanh), Dense(50 => 2)) + aug_diffusion = Chain(Dense(4 => 50, tanh), Dense(50 => 4)) + + tspan = (0.0f0, 0.1f0) + @testset "NeuralDSDE u0: $(typeof(u0)), solver: $(solver)" for u0 in (xs,), + solver in (SOSRI(),) + # CuVector seems broken on CI but I can't reproduce the failure locally + + sode = NeuralDSDE( + dudt, diffusion, tspan, solver; saveat = 0.0f0:0.01f0:0.1f0, dt = 0.01f0) + pd, st = Lux.setup(rng, sode) + pd = ComponentArray(pd) |> gdev + st = st |> gdev + + @test_broken begin + grads = Zygote.gradient(sum ∘ last ∘ first ∘ sode, u0, pd, st) + CUDA.@allowscalar begin + !iszero(grads[1]) && !iszero(grads[2]) && !iszero(grads[2][end]) + end + end + end + end +end diff --git a/test/otflow_tests.jl b/test/otflow_tests.jl index 92ddab3f8..d131941bc 100644 --- a/test/otflow_tests.jl +++ b/test/otflow_tests.jl @@ -1,65 +1,65 @@ @testitem "Tests for OTFlow Layer Functionality" begin - using Lux, LuxCore, Random, LinearAlgebra, Test, ComponentArrays, Flux, DiffEqFlux - rng = Xoshiro(0) - d = 2 - m = 4 - r = 2 - otflow = OTFlow(d, m; r=r) - ps, st = Lux.setup(rng, otflow) - ps = ComponentArray(ps) + using Lux, LuxCore, Random, LinearAlgebra, Test, ComponentArrays, Flux, DiffEqFlux + rng = Xoshiro(0) + d = 2 + m = 4 + r = 2 + otflow = OTFlow(d, m; r = r) + ps, st = Lux.setup(rng, otflow) + ps = ComponentArray(ps) - x = Float32[1.0, 2.0] - t = 0.5f0 + x = Float32[1.0, 2.0] + t = 0.5f0 - @testitem "Forward Pass" begin - (v, tr), st_new = otflow((x, t), ps, st) - @test length(v) == d - @test isa(tr, Float32) - @test st_new == st - end + @testset "Forward Pass" begin + (v, tr), st_new = otflow((x, t), ps, st) + @test length(v) == d + @test isa(tr, Float32) + @test st_new == st + end - @testitem "Potential Function" begin - phi = potential(x, t, ps) - @test isa(phi, Float32) - end + @testset "Potential Function" begin + phi = potential(x, t, ps) + @test isa(phi, Float32) + end - @testitem "Gradient Consistency" begin - grad = gradient(x, t, ps, d) - (v, _), _ = otflow((x, t), ps, st) - @test length(grad) == d - @test grad ≈ -v atol=1e-5 # v = -∇Φ - end + @testset "Gradient Consistency" begin + grad = gradient(x, t, ps, d) + (v, _), _ = otflow((x, t), ps, st) + @test length(grad) == d + @test grad ≈ -v atol = 1e-5 
# v = -∇Φ + end - @testitem "Trace Consistency" begin - tr_manual = trace(x, t, ps, d) - (_, tr_forward), _ = otflow((x, t), ps, st) - @test tr_manual ≈ -tr_forward atol=1e-5 - end + @testset "Trace Consistency" begin + tr_manual = trace(x, t, ps, d) + (_, tr_forward), _ = otflow((x, t), ps, st) + @test tr_manual ≈ -tr_forward atol = 1e-5 + end - @testitem "ODE Integration" begin - x0 = Float32[1.0, 1.0] - tspan = (0.0f0, 1.0f0) - x_traj, t_vec = simple_ode_solve(otflow, x0, tspan, ps, st; dt=0.01f0) - @test size(x_traj) == (d, length(t_vec)) - @test all(isfinite, x_traj) - @test x_traj[:, end] != x0 - end + @testset "ODE Integration" begin + x0 = Float32[1.0, 1.0] + tspan = (0.0f0, 1.0f0) + x_traj, t_vec = simple_ode_solve(otflow, x0, tspan, ps, st; dt = 0.01f0) + @test size(x_traj) == (d, length(t_vec)) + @test all(isfinite, x_traj) + @test x_traj[:, end] != x0 + end - @testitem "Loss Function" begin - loss_val = simple_loss(x, t, otflow, ps) - @test isa(loss_val, Float32) - @test isfinite(loss_val) - end + @testset "Loss Function" begin + loss_val = simple_loss(x, t, otflow, ps) + @test isa(loss_val, Float32) + @test isfinite(loss_val) + end - @testitem "Manual Gradient" begin - grads = manual_gradient(x, t, otflow, ps) - @test haskey(grads, :w) && length(grads.w) == m - @test haskey(grads, :A) && size(grads.A) == (r, d+1) - @test haskey(grads, :b) && length(grads.b) == d+1 - @test haskey(grads, :c) && isa(grads.c, Float32) - @test haskey(grads, :K0) && size(grads.K0) == (m, d+1) - @test haskey(grads, :K1) && size(grads.K1) == (m, m) - @test haskey(grads, :b0) && length(grads.b0) == m - @test haskey(grads, :b1) && length(grads.b1) == m - end -end \ No newline at end of file + @testset "Manual Gradient" begin + grads = manual_gradient(x, t, otflow, ps) + @test haskey(grads, :w) && length(grads.w) == m + @test haskey(grads, :A) && size(grads.A) == (r, d + 1) + @test haskey(grads, :b) && length(grads.b) == d + 1 + @test haskey(grads, :c) && isa(grads.c, Float32) + @test haskey(grads, :K0) && size(grads.K0) == (m, d + 1) + @test haskey(grads, :K1) && size(grads.K1) == (m, m) + @test haskey(grads, :b0) && length(grads.b0) == m + @test haskey(grads, :b1) && length(grads.b1) == m + end +end diff --git a/test/stiff_nested_ad_tests.jl b/test/stiff_nested_ad_tests.jl index 35de20692..4742936f4 100644 --- a/test/stiff_nested_ad_tests.jl +++ b/test/stiff_nested_ad_tests.jl @@ -27,7 +27,7 @@ end end - @testitem "Solver: $(nameof(typeof(solver)))" for solver in ( + @testset "Solver: $(nameof(typeof(solver)))" for solver in ( KenCarp4(), Rodas5(), RadauIIA5()) neuralde = NeuralODE(model, tspan, solver; saveat = t, reltol = 1e-7, abstol = 1e-9) ps, st = Lux.setup(Xoshiro(0), neuralde) From 9e97d171b35f55e7f1633793a4db2c21f23160d4 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Tue, 25 Feb 2025 09:36:01 +0530 Subject: [PATCH 10/12] Implemented Requested changes --- src/otflow.jl | 120 +++++++++++++++++++++++++------------------------- 1 file changed, 59 insertions(+), 61 deletions(-) diff --git a/src/otflow.jl b/src/otflow.jl index 2c7639a1c..9ec45f987 100644 --- a/src/otflow.jl +++ b/src/otflow.jl @@ -1,98 +1,96 @@ struct OTFlow <: AbstractLuxLayer - d::Int # Input dimension - m::Int # Hidden dimension - r::Int # Rank for low-rank approximation + d::Int # Input dimension + m::Int # Hidden dimension + r::Int # Rank for low-rank approximation end # Constructor with default rank -OTFlow(d::Int, m::Int; r::Int = min(10, d)) = OTFlow(d, m, r) +OTFlow(d::Int, m::Int; r::Int=min(10,d)) = OTFlow(d, m, r) 
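# A sketch of the formulation the functions below appear to implement (my reading of
# this hunk, not stated in the patch): with s = (x, t) and N(s) the two-layer ResNet
# built from K0, b0, K1, b1, the scalar potential is
#     Φ(s) = wᵀ N(s) + ½ sᵀ (AᵀA) s + bᵀ s + c,
# the layer output is v = -∇Φ restricted to the x components, and tr(∇v) = -tr(∇²Φ)
# is evaluated in closed form, so no AD or stochastic trace estimator is needed inside
# the ODE right-hand side. The activation sigma(x) = log(eˣ + e⁻ˣ) is the
# antiderivative of tanh, which is why sigma_prime = tanh and
# sigma_double_prime = 1 - tanh² below.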
# Initialize parameters and states function Lux.initialparameters(rng::AbstractRNG, l::OTFlow) - w = randn(rng, Float32, l.m) .* 0.01 - A = randn(rng, Float32, l.r, l.d + 1) .* 0.01 - b = zeros(Float32, l.d + 1) - c = zero(Float32) - K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01 - K1 = randn(rng, Float32, l.m, l.m) .* 0.01 - b0 = zeros(Float32, l.m) - b1 = zeros(Float32, l.m) - - ps = (w = w, A = A, b = b, c = c, K0 = K0, K1 = K1, b0 = b0, b1 = b1) - st = NamedTuple() - return ps, st + w = randn(rng, Float32, l.m) .* 0.01f0 + A = randn(rng, Float32, l.r, l.d + 1) .* 0.01f0 + b = zeros(Float32, l.d + 1) + c = zero(Float32) + K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01f0 + K1 = randn(rng, Float32, l.m, l.m) .* 0.01f0 + b0 = zeros(Float32, l.m) + b1 = zeros(Float32, l.m) + + return (; w, A, b, c, K0, K1, b0, b1) end -σ(x) = log(exp(x) + exp(-x)) -σ′(x) = tanh(x) -σ′′(x) = 1 - tanh(x)^2 +sigma(x) = log(exp(x) + exp(-x)) +sigma_prime(x) = tanh(x) +sigma_double_prime(x) = 1 - tanh(x)^2 function resnet_forward(x::AbstractVector, t::Real, ps) - s = vcat(x, t) - u0 = σ.(ps.K0 * s .+ ps.b0) - u1 = u0 .+ σ.(ps.K1 * u0 .+ ps.b1) # h=1 as in paper - return u1 + s = vcat(x, t) + u0 = sigma.(ps.K0 * s .+ ps.b0) + u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1) # h=1 as in paper + return u1 end function potential(x::AbstractVector, t::Real, ps) - s = vcat(x, t) - N = resnet_forward(x, t, ps) - quadratic_term = 0.5 * s' * (ps.A' * ps.A) * s - linear_term = ps.b' * s - return ps.w' * N + quadratic_term + linear_term + ps.c + s = vcat(x, t) + N = resnet_forward(x, t, ps) + quadratic_term = 0.5 * s' * (ps.A' * ps.A) * s + linear_term = ps.b' * s + return ps.w' * N + quadratic_term + linear_term + ps.c end function gradient(x::AbstractVector, t::Real, ps, d::Int) - s = vcat(x, t) - u0 = σ.(ps.K0 * s .+ ps.b0) - z1 = ps.w .+ ps.K1' * (σ′.(ps.K1 * u0 .+ ps.b1) .* ps.w) - z0 = ps.K0' * (σ′.(ps.K0 * s .+ ps.b0) .* z1) + s = vcat(x, t) + u0 = sigma.(ps.K0 * s .+ ps.b0) + z1 = ps.w .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w) + z0 = ps.K0' * (sigma_prime.(ps.K0 * s .+ ps.b0) .* z1) - grad = z0 + (ps.A' * ps.A) * s + ps.b - return grad[1:d] + grad = z0 + (ps.A' * ps.A) * s + ps.b + return grad[1:d] end function trace(x::AbstractVector, t::Real, ps, d::Int) - s = vcat(x, t) - u0 = σ.(ps.K0 * s .+ ps.b0) - z1 = ps.w .+ ps.K1' * (σ′.(ps.K1 * u0 .+ ps.b1) .* ps.w) + s = vcat(x, t) + u0 = sigma.(ps.K0 * s .+ ps.b0) + z1 = ps.w .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w) - K0_E = ps.K0[:, 1:d] - A_E = ps.A[:, 1:d] + K0_E = ps.K0[:, 1:d] + A_E = ps.A[:, 1:d] - t0 = sum(σ′′.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2)) - J = Diagonal(σ′.(ps.K0 * s .+ ps.b0)) * K0_E - t1 = sum(σ′′.(ps.K1 * u0 .+ ps.b1) .* ps.w .* (ps.K1 * J) .^ 2) - trace_A = tr(A_E' * A_E) + t0 = sum(sigma_double_prime.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2)) + J = Diagonal(sigma_prime.(ps.K0 * s .+ ps.b0)) * K0_E + t1 = sum(sigma_double_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w .* (ps.K1 * J) .^ 2) + trace_A = tr(A_E' * A_E) - return t0 + t1 + trace_A + return t0 + t1 + trace_A end function (l::OTFlow)(xt::Tuple{AbstractVector, Real}, ps, st) - x, t = xt - v = -gradient(x, t, ps, l.d) # v = -∇Φ - tr = -trace(x, t, ps, l.d) # tr(∇v) = -tr(∇²Φ) - return (v, tr), st + x, t = xt + v = -gradient(x, t, ps, l.d) # v = -∇Φ + tr = -trace(x, t, ps, l.d) # tr(∇v) = -tr(∇²Φ) + return (v, tr), st end function simple_loss(x::AbstractVector, t::Real, l::OTFlow, ps) - (v, tr), _ = l((x, t), ps, NamedTuple()) - return sum(v .^ 2) / 2 - tr + (v, tr), _ = l((x, t), 
ps, NamedTuple()) + return sum(v .^ 2) / 2 - tr end function manual_gradient(x::AbstractVector, t::Real, l::OTFlow, ps) - s = vcat(x, t) - u0 = σ.(ps.K0 * s .+ ps.b0) - u1 = u0 .+ σ.(ps.K1 * u0 .+ ps.b1) + s = vcat(x, t) + u0 = sigma.(ps.K0 * s .+ ps.b0) + u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1) - v = -gradient(x, t, ps, l.d) - tr = -trace(x, t, ps, l.d) + v = -gradient(x, t, ps, l.d) + tr = -trace(x, t, ps, l.d) - # Simplified gradients (not full implementation) - grad_w = u1 - grad_A = (ps.A * s) * s' + # Simplified gradients (not full implementation) + grad_w = u1 + grad_A = (ps.A * s) * s' - return (w = grad_w, A = grad_A, b = similar(ps.b), c = 0.0, - K0 = zeros(l.m, l.d + 1), K1 = zeros(l.m, l.m), - b0 = zeros(l.m), b1 = zeros(l.m)) + return (w = grad_w, A = grad_A, b = similar(ps.b), c = 0.0, + K0 = zeros(l.m, l.d + 1), K1 = zeros(l.m, l.m), + b0 = zeros(l.m), b1 = zeros(l.m)) end From 9a0c2cabe8cedc6675db87146a15561d384183ba Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 26 Feb 2025 09:38:53 +0530 Subject: [PATCH 11/12] Made modifications in trainable parameters --- src/otflow.jl | 124 ++++++++++++++++++++++++-------------------------- 1 file changed, 59 insertions(+), 65 deletions(-) diff --git a/src/otflow.jl b/src/otflow.jl index 9ec45f987..f2545a474 100644 --- a/src/otflow.jl +++ b/src/otflow.jl @@ -1,24 +1,23 @@ +using LinearAlgebra + struct OTFlow <: AbstractLuxLayer - d::Int # Input dimension - m::Int # Hidden dimension - r::Int # Rank for low-rank approximation + d::Int + m::Int + r::Int end -# Constructor with default rank -OTFlow(d::Int, m::Int; r::Int=min(10,d)) = OTFlow(d, m, r) +OTFlow(d::Int, m::Int; r::Int = min(10, d)) = OTFlow(d, m, r) -# Initialize parameters and states function Lux.initialparameters(rng::AbstractRNG, l::OTFlow) - w = randn(rng, Float32, l.m) .* 0.01f0 - A = randn(rng, Float32, l.r, l.d + 1) .* 0.01f0 - b = zeros(Float32, l.d + 1) - c = zero(Float32) - K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01f0 - K1 = randn(rng, Float32, l.m, l.m) .* 0.01f0 - b0 = zeros(Float32, l.m) - b1 = zeros(Float32, l.m) - - return (; w, A, b, c, K0, K1, b0, b1) + w = randn(rng, Float32, l.m) .* 0.01f0 + A = randn(rng, Float32, l.r, l.d + 1) .* 0.01f0 + b = zeros(Float32, l.d + 1) + c = randn(rng, Float32, l.m) .* 0.01f0 + K0 = randn(rng, Float32, l.m, l.d + 1) .* 0.01f0 + K1 = randn(rng, Float32, l.m, l.m) .* 0.01f0 + b0 = zeros(Float32, l.m) + b1 = zeros(Float32, l.m) + return (; w, A, b, c, K0, K1, b0, b1) end sigma(x) = log(exp(x) + exp(-x)) @@ -26,71 +25,66 @@ sigma_prime(x) = tanh(x) sigma_double_prime(x) = 1 - tanh(x)^2 function resnet_forward(x::AbstractVector, t::Real, ps) - s = vcat(x, t) - u0 = sigma.(ps.K0 * s .+ ps.b0) - u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1) # h=1 as in paper - return u1 + s = vcat(x, t) + u0 = sigma.(ps.K0 * s .+ ps.b0) + u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1) + return u1 end function potential(x::AbstractVector, t::Real, ps) - s = vcat(x, t) - N = resnet_forward(x, t, ps) - quadratic_term = 0.5 * s' * (ps.A' * ps.A) * s - linear_term = ps.b' * s - return ps.w' * N + quadratic_term + linear_term + ps.c + s = vcat(x, t) + N = resnet_forward(x, t, ps) + quadratic_term = 0.5 * s' * (ps.A' * ps.A) * s + linear_term = ps.b' * s + neural_term = sum((ps.w .+ ps.c) .* N) + return neural_term + quadratic_term + linear_term end function gradient(x::AbstractVector, t::Real, ps, d::Int) - s = vcat(x, t) - u0 = sigma.(ps.K0 * s .+ ps.b0) - z1 = ps.w .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w) - z0 = ps.K0' * (sigma_prime.(ps.K0 
* s .+ ps.b0) .* z1) - - grad = z0 + (ps.A' * ps.A) * s + ps.b - return grad[1:d] + s = vcat(x, t) + u0 = sigma.(ps.K0 * s .+ ps.b0) + z1 = (ps.w .+ ps.c) .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* (ps.w .+ ps.c)) + z0 = ps.K0' * (sigma_prime.(ps.K0 * s .+ ps.b0) .* z1) + grad = z0 + (ps.A' * ps.A) * s + ps.b + return grad[1:d] end function trace(x::AbstractVector, t::Real, ps, d::Int) - s = vcat(x, t) - u0 = sigma.(ps.K0 * s .+ ps.b0) - z1 = ps.w .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w) - - K0_E = ps.K0[:, 1:d] - A_E = ps.A[:, 1:d] - - t0 = sum(sigma_double_prime.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2)) - J = Diagonal(sigma_prime.(ps.K0 * s .+ ps.b0)) * K0_E - t1 = sum(sigma_double_prime.(ps.K1 * u0 .+ ps.b1) .* ps.w .* (ps.K1 * J) .^ 2) - trace_A = tr(A_E' * A_E) - - return t0 + t1 + trace_A + s = vcat(x, t) + u0 = sigma.(ps.K0 * s .+ ps.b0) + z1 = (ps.w .+ ps.c) .+ ps.K1' * (sigma_prime.(ps.K1 * u0 .+ ps.b1) .* (ps.w .+ ps.c)) + K0_E = ps.K0[:, 1:d] + A_E = ps.A[:, 1:d] + t0 = sum(sigma_double_prime.(ps.K0 * s .+ ps.b0) .* z1 .* (K0_E .^ 2)) + J = Diagonal(sigma_prime.(ps.K0 * s .+ ps.b0)) * K0_E + t1 = sum(sigma_double_prime.(ps.K1 * u0 .+ ps.b1) .* (ps.w .+ ps.c) .* (ps.K1 * J) .^ 2) + trace_A = tr(A_E' * A_E) + return t0 + t1 + trace_A end function (l::OTFlow)(xt::Tuple{AbstractVector, Real}, ps, st) - x, t = xt - v = -gradient(x, t, ps, l.d) # v = -∇Φ - tr = -trace(x, t, ps, l.d) # tr(∇v) = -tr(∇²Φ) - return (v, tr), st + x, t = xt + v = -gradient(x, t, ps, l.d) + tr = -trace(x, t, ps, l.d) + return (v, tr), st end function simple_loss(x::AbstractVector, t::Real, l::OTFlow, ps) - (v, tr), _ = l((x, t), ps, NamedTuple()) - return sum(v .^ 2) / 2 - tr + (v, tr), _ = l((x, t), ps, NamedTuple()) + return sum(v .^ 2) / 2 - tr end function manual_gradient(x::AbstractVector, t::Real, l::OTFlow, ps) - s = vcat(x, t) - u0 = sigma.(ps.K0 * s .+ ps.b0) - u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1) - - v = -gradient(x, t, ps, l.d) - tr = -trace(x, t, ps, l.d) - - # Simplified gradients (not full implementation) - grad_w = u1 - grad_A = (ps.A * s) * s' - - return (w = grad_w, A = grad_A, b = similar(ps.b), c = 0.0, - K0 = zeros(l.m, l.d + 1), K1 = zeros(l.m, l.m), - b0 = zeros(l.m), b1 = zeros(l.m)) + s = vcat(x, t) + u0 = sigma.(ps.K0 * s .+ ps.b0) + u1 = u0 .+ sigma.(ps.K1 * u0 .+ ps.b1) + v = -gradient(x, t, ps, l.d) + tr = -trace(x, t, ps, l.d) + grad_w = u1 + grad_c = u1 + grad_A = (ps.A * s) * s' + grad_b = similar(ps.b) + return (w = grad_w, A = grad_A, b = grad_b, c = grad_c, + K0 = zeros(l.m, l.d + 1), K1 = zeros(l.m, l.m), + b0 = zeros(l.m), b1 = zeros(l.m)) end From 3073d3ff527a06e6d7851eb7368901285daf18bd Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 26 Feb 2025 09:42:17 +0530 Subject: [PATCH 12/12] Made modifications in trainable parameters --- src/otflow.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/otflow.jl b/src/otflow.jl index f2545a474..9b9de8d56 100644 --- a/src/otflow.jl +++ b/src/otflow.jl @@ -1,5 +1,3 @@ -using LinearAlgebra - struct OTFlow <: AbstractLuxLayer d::Int m::Int
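A minimal usage sketch of the layer as it stands at the end of this series, mirroring
test/otflow_tests.jl above (assuming OTFlow and simple_loss are reachable from
DiffEqFlux after these patches; the dimensions and inputs are illustrative):

    using DiffEqFlux, Lux, Random, ComponentArrays

    rng = Xoshiro(0)
    otflow = OTFlow(2, 4; r = 2)          # d = 2 inputs, m = 4 hidden units, rank-2 A
    ps, st = Lux.setup(rng, otflow)
    ps = ComponentArray(ps)

    x, t = Float32[1.0, 2.0], 0.5f0
    (v, tr), _ = otflow((x, t), ps, st)   # v = -∇Φ(x, t), tr = -tr(∇²Φ)
    loss = simple_loss(x, t, otflow, ps)  # ‖v‖² / 2 minus the trace term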