Skip to content

Commit 01bbd5a

Browse files
authored
Merge pull request #122 from JuliaDecisionFocusedLearning/gd/fw
Adapt to latest DifferentiableFrankWolfe
2 parents 4d8c3ff + 7a5fbd4 commit 01bbd5a

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ InferOptFrankWolfeExt = ["DifferentiableFrankWolfe", "FrankWolfe", "ImplicitDiff
3131
ChainRulesCore = "1"
3232
DensityInterface = "0.4.0"
3333
DifferentiableExpectations = "0.2"
34-
DifferentiableFrankWolfe = "0.3"
34+
DifferentiableFrankWolfe = "0.4.1"
3535
Distributions = "0.25"
3636
DocStringExtensions = "0.9"
37-
FrankWolfe = "0.3"
38-
ImplicitDifferentiation = "0.6"
37+
FrankWolfe = "0.3,0.4"
38+
ImplicitDifferentiation = "0.7.2"
3939
LinearAlgebra = "1"
4040
Random = "1"
4141
RequiredInterfaces = "0.1.3"

ext/InferOptFrankWolfeExt.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,15 @@ Keyword arguments are passed to the underlying linear maximizer.
5959
function InferOpt.compute_probability_distribution(
6060
regularized::RegularizedFrankWolfe, θ::AbstractArray; kwargs...
6161
)
62-
shape = size(θ)
6362
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs, implicit_kwargs) = regularized
6463
f(y, θ) = Ω(y) - dot(θ, y)
6564
f_grad1(y, θ) = Ω_grad(y) - θ
66-
maximizer(θ; shape, kwargs...) = vec(linear_maximizer(reshape(θ, shape); kwargs...))
67-
lmo = LinearMaximizationOracleWithKwargs(maximizer, (; shape, kwargs...))
65+
maximizer(θ; kwargs...) = linear_maximizer(θ; kwargs...)
66+
lmo = LinearMaximizationOracleWithKwargs(maximizer, kwargs)
6867
dfw = DiffFW(f, f_grad1, lmo; implicit_kwargs)
69-
weights, atoms = dfw.implicit(vec(θ); frank_wolfe_kwargs=frank_wolfe_kwargs)
70-
probadist = FixedAtomsProbabilityDistribution(
71-
map(atom -> reshape(atom, shape), atoms), weights
72-
)
68+
weights, stats = dfw.implicit(θ, frank_wolfe_kwargs)
69+
atoms = stats.active_set.atoms # TODO: make it public in DiffFW
70+
probadist = FixedAtomsProbabilityDistribution(atoms, weights)
7371
return probadist
7472
end
7573

0 commit comments

Comments
 (0)