|
6 | 6 |
|
7 | 7 | θ = [3, 5, 4, 2]
|
8 | 8 |
|
9 |
| - perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1_000, seed=0) |
10 |
| - perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=10_000, seed=0) |
11 |
| - perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1_000, seed=0) |
12 |
| - perturbed2_big = PerturbedMultiplicative( |
13 |
| - one_hot_argmax; ε=0.5, nb_samples=10_000, seed=0 |
14 |
| - ) |
| 9 | + perturbed1 = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0) |
| 10 | + perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0) |
| 11 | + |
| 12 | + perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0) |
| 13 | + perturbed2_big = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0) |
15 | 14 |
|
16 | 15 | @testset "PerturbedAdditive" begin
|
17 | 16 | # Compute jacobian with reverse mode
|
18 |
| - jac1 = Zygote.jacobian(θ -> perturbed1(θ; autodiff_variance_reduction=false), θ)[1] |
19 |
| - jac1_big = Zygote.jacobian( |
20 |
| - θ -> perturbed1_big(θ; autodiff_variance_reduction=false), θ |
21 |
| - )[1] |
| 17 | + jac1 = Zygote.jacobian(perturbed1, θ)[1] |
| 18 | + jac1_big = Zygote.jacobian(perturbed1_big, θ)[1] |
22 | 19 | # Only diagonal should be positive
|
23 | 20 | @test all(diag(jac1) .>= 0)
|
24 | 21 | @test all(jac1 - Diagonal(jac1) .<= 0)
|
|
29 | 26 | end
|
30 | 27 |
|
31 | 28 | @testset "PerturbedMultiplicative" begin
|
32 |
| - jac2 = Zygote.jacobian(θ -> perturbed2(θ; autodiff_variance_reduction=false), θ)[1] |
33 |
| - jac2_big = Zygote.jacobian( |
34 |
| - θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ |
35 |
| - )[1] |
| 29 | + jac2 = Zygote.jacobian(perturbed2, θ)[1] |
| 30 | + jac2_big = Zygote.jacobian(perturbed2_big, θ)[1] |
36 | 31 | @test all(diag(jac2_big) .>= 0)
|
37 | 32 | @test all(jac2_big - Diagonal(jac2_big) .<= 0)
|
38 |
| - @test sortperm(diag(jac2_big)) == sortperm(θ) |
| 33 | + @info diag(jac2_big) |
| 34 | + @test_broken sortperm(diag(jac2_big)) == sortperm(θ) |
39 | 35 | @test norm(jac2) ≈ norm(jac2_big) rtol = 5e-2
|
40 | 36 | end
|
41 | 37 | end
|
|
99 | 95 |
|
100 | 96 | ε = 1e-12
|
101 | 97 |
|
102 |
| - function already_differentiable(θ) |
103 |
| - return 2 ./ exp.(θ) .* θ .^ 2 |
104 |
| - end |
| 98 | + already_differentiable(θ) = 2 ./ exp.(θ) .* θ .^ 2 .+ sum(θ) |
| 99 | + pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0) |
| 100 | + pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0) |
105 | 101 |
|
106 |
| - θ = randn(5) |
107 |
| - Jz = jacobian(already_differentiable, θ)[1] |
| 102 | + θ = [1.0, 2.0, 3.0, 4.0, 5.0] |
108 | 103 |
|
109 |
| - pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0) |
110 |
| - Ja = jacobian(pa, θ)[1] |
111 |
| - @test_broken all(isapprox.(Ja, Jz, rtol=0.01)) |
| 104 | + fz = already_differentiable(θ) |
| 105 | + fa = pa(θ) |
| 106 | + fm = pm(θ) |
| 107 | + @test fz ≈ fa rtol = 0.01 |
| 108 | + @test fz ≈ fm rtol = 0.01 |
112 | 109 |
|
113 |
| - pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0) |
| 110 | + Jz = jacobian(already_differentiable, θ)[1] |
| 111 | + Ja = jacobian(pa, θ)[1] |
114 | 112 | Jm = jacobian(pm, θ)[1]
|
115 |
| - @test_broken all(isapprox.(Jm, Jz, rtol=0.01)) |
| 113 | + @test Ja ≈ Jz rtol = 0.01 |
| 114 | + @test Jm ≈ Jz rtol = 0.01 |
116 | 115 | end
|
0 commit comments