Skip to content

Commit ba7217c

Browse files
committed
Add failing test cases for flatness check
1 parent 3259cd2 commit ba7217c

File tree

1 file changed

+43
-11
lines changed

1 file changed

+43
-11
lines changed

test/captured.jl

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,47 @@ end
2121
end
2222
end
2323

24-
# From PR#174
25-
@testset "PR#174" begin
26-
tc = LineSearchTestCase(
27-
[0.0, 1.0, 5.0, 3.541670844449739],
28-
[3003.592409634743, 2962.0378569864743, 2891.4462095232184, 3000.9760725116876],
29-
[-22332.321416890798, -20423.214551925797, 11718.185026267562, -22286.821227217057]
30-
)
31-
fdf = OnceDifferentiable(tc)
32-
hz = HagerZhang()
33-
α, val = hz(fdf.f, fdf.fdf, 1.0, fdf.fdf(0.0)...)
34-
@test_broken val <= minimum(tc)
24+
wolfec1(ls::BackTracking) = ls.c_1
25+
wolfec1(ls::StrongWolfe) = ls.c_1
26+
wolfec1(ls::MoreThuente) = ls.f_tol
27+
wolfec1(ls::HagerZhang{T}) where T = zero(T) # HZ uses Wolfe condition, but not necessarily based at zero
28+
29+
function amongtypes(x, types)
30+
t = typeof(x)
31+
return any(t <: ty for ty in types)
32+
end
33+
34+
@testset "Captured cases" begin
35+
lsalgs = (HagerZhang(), StrongWolfe(), MoreThuente(),
36+
BackTracking(), BackTracking(; order=2) )
37+
# From PR#174
38+
@testset "PR#174" begin
39+
tc = LineSearchTestCase(
40+
[0.0, 1.0, 5.0, 3.541670844449739],
41+
[3003.592409634743, 2962.0378569864743, 2891.4462095232184, 3000.9760725116876],
42+
[-22332.321416890798, -20423.214551925797, 11718.185026267562, -22286.821227217057]
43+
)
44+
fdf = OnceDifferentiable(tc)
45+
for ls in lsalgs
46+
α, val = ls(fdf.f, fdf.df, fdf.fdf, 1.0, fdf.fdf(0.0)...)
47+
@test 0 < α <= 5
48+
@test val < fdf.f(0) + wolfec1(ls) * α * fdf.df(0) # Armijo (first Wolfe) condition
49+
if amongtypes(ls, (StrongWolfe, MoreThuente, HagerZhang)) # these types do not just backtrack
50+
@test val <= tc.values[2]
51+
end
52+
end
53+
end
54+
@testset "Issue#175" begin
55+
tc = LineSearchTestCase(
56+
[0.0, 0.2, 0.1, 0.055223623837026156],
57+
[3.042968312396456, 3.1174112871667603, -3.5035848233450224, 0.5244246783151265],
58+
[-832.4270136930788, -505.3362249257043, 674.9478303586366, 738.3388472427769]
59+
)
60+
fdf = OnceDifferentiable(tc)
61+
for ls in lsalgs
62+
α, val = ls(fdf.f, fdf.df, fdf.fdf, 0.2, fdf.fdf(0.0)...)
63+
@test 0 < α < 0.2
64+
@test val < fdf.f(0) + wolfec1(ls) * α * fdf.df(0)
65+
end
66+
end
3567
end

0 commit comments

Comments
 (0)