Skip to content

Commit 1c0510f

Browse files
committed
fix bug in perturbed multiplicative
1 parent b6cc895 commit 1c0510f

File tree

4 files changed

+58
-27
lines changed

4 files changed

+58
-27
lines changed

src/layers/perturbed/perturbation.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ function (pdc::AdditivePerturbation)(θ::AbstractArray)
4444
return product_distribution.+ ε * perturbation_dist)
4545
end
4646

47+
"""
48+
$TYPEDSIGNATURES
49+
50+
Compute the gradient of the logdensity of η = θ + εZ w.r.t. θ., with Z ∼ N(0, 1).
51+
"""
52+
function normal_additive_grad_logdensity(ε, η, θ)
53+
return ((η .- θ) ./ ε^2,)
54+
end
55+
4756
"""
4857
$TYPEDEF
4958
@@ -68,3 +77,13 @@ function (pdc::MultiplicativePerturbation)(θ::AbstractArray)
6877
(; perturbation_dist, ε) = pdc
6978
return product_distribution.* ExponentialOf* perturbation_dist - ε^2 / 2))
7079
end
80+
"""
81+
$TYPEDSIGNATURES
82+
83+
Compute the gradient of the logdensity of η = θ ⊙ exp(εZ - ε²/2) w.r.t. θ., with Z ∼ N(0, 1).
84+
!!! warning
85+
η should be a relization of θ, i.e. should be of the same sign.
86+
"""
87+
function normal_multiplicative_grad_logdensity(ε, η, θ)
88+
return (inv.(ε^2 .* θ) .* (log.(abs.(η)) - log.(abs.(θ)) .+^2 / 2)),)
89+
end

src/layers/perturbed/perturbed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function PerturbedAdditive(
9393
threaded=false,
9494
rng=Random.default_rng(),
9595
dist_logdensity_grad=if (perturbation_dist == Normal(0, 1))
96-
(η, θ) -> ((η .- θ) ./ ε^2,)
96+
FixFirst(normal_additive_grad_logdensity, ε)
9797
else
9898
nothing
9999
end,
@@ -126,7 +126,7 @@ function PerturbedMultiplicative(
126126
threaded=false,
127127
rng=Random.default_rng(),
128128
dist_logdensity_grad=if (perturbation_dist == Normal(0, 1))
129-
(η, θ) -> (inv.(ε^2 .* θ) .*.- θ),)
129+
FixFirst(normal_multiplicative_grad_logdensity, ε)
130130
else
131131
nothing
132132
end,

src/utils/utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,16 @@ struct Fix1Kwargs{F,K,T} <: Function
3333
end
3434

3535
(fk::Fix1Kwargs)(args...) = fk.f(fk.x, args...; fk.kwargs...)
36+
37+
"""
38+
$TYPEDEF
39+
40+
Callable struct that fixes the first argument of `f` to `x`.
41+
Compared to Base.Fix1, works on functions with more than two arguments.
42+
"""
43+
struct FixFirst{F,T}
44+
f::F
45+
x::T
46+
end
47+
48+
(fk::FixFirst)(args...) = fk.f(fk.x, args...)

test/perturbed.jl

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,16 @@
66

77
θ = [3, 5, 4, 2]
88

9-
perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1_000, seed=0)
10-
perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=10_000, seed=0)
11-
perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1_000, seed=0)
12-
perturbed2_big = PerturbedMultiplicative(
13-
one_hot_argmax; ε=0.5, nb_samples=10_000, seed=0
14-
)
9+
perturbed1 = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0)
10+
perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0)
11+
12+
perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0)
13+
perturbed2_big = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0)
1514

1615
@testset "PerturbedAdditive" begin
1716
# Compute jacobian with reverse mode
18-
jac1 = Zygote.jacobian-> perturbed1(θ; autodiff_variance_reduction=false), θ)[1]
19-
jac1_big = Zygote.jacobian(
20-
θ -> perturbed1_big(θ; autodiff_variance_reduction=false), θ
21-
)[1]
17+
jac1 = Zygote.jacobian(perturbed1, θ)[1]
18+
jac1_big = Zygote.jacobian(perturbed1_big, θ)[1]
2219
# Only diagonal should be positive
2320
@test all(diag(jac1) .>= 0)
2421
@test all(jac1 - Diagonal(jac1) .<= 0)
@@ -29,13 +26,12 @@
2926
end
3027

3128
@testset "PerturbedMultiplicative" begin
32-
jac2 = Zygote.jacobian-> perturbed2(θ; autodiff_variance_reduction=false), θ)[1]
33-
jac2_big = Zygote.jacobian(
34-
θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ
35-
)[1]
29+
jac2 = Zygote.jacobian(perturbed2, θ)[1]
30+
jac2_big = Zygote.jacobian(perturbed2_big, θ)[1]
3631
@test all(diag(jac2_big) .>= 0)
3732
@test all(jac2_big - Diagonal(jac2_big) .<= 0)
38-
@test sortperm(diag(jac2_big)) == sortperm(θ)
33+
@info diag(jac2_big)
34+
@test_broken sortperm(diag(jac2_big)) == sortperm(θ)
3935
@test norm(jac2) norm(jac2_big) rtol = 5e-2
4036
end
4137
end
@@ -99,18 +95,21 @@ end
9995

10096
ε = 1e-12
10197

102-
function already_differentiable(θ)
103-
return 2 ./ exp.(θ) .* θ .^ 2
104-
end
98+
already_differentiable(θ) = 2 ./ exp.(θ) .* θ .^ 2 .+ sum(θ)
99+
pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0)
100+
pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0)
105101

106-
θ = randn(5)
107-
Jz = jacobian(already_differentiable, θ)[1]
102+
θ = [1.0, 2.0, 3.0, 4.0, 5.0]
108103

109-
pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0)
110-
Ja = jacobian(pa, θ)[1]
111-
@test_broken all(isapprox.(Ja, Jz, rtol=0.01))
104+
fz = already_differentiable(θ)
105+
fa = pa(θ)
106+
fm = pm(θ)
107+
@test fz fa rtol = 0.01
108+
@test fz fm rtol = 0.01
112109

113-
pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0)
110+
Jz = jacobian(already_differentiable, θ)[1]
111+
Ja = jacobian(pa, θ)[1]
114112
Jm = jacobian(pm, θ)[1]
115-
@test_broken all(isapprox.(Jm, Jz, rtol=0.01))
113+
@test Ja Jz rtol = 0.01
114+
@test Jm Jz rtol = 0.01
116115
end

0 commit comments

Comments
 (0)