@@ -59,17 +59,15 @@ Keyword arguments are passed to the underlying linear maximizer.
59
59
function InferOpt. compute_probability_distribution (
60
60
regularized:: RegularizedFrankWolfe , θ:: AbstractArray ; kwargs...
61
61
)
62
- shape = size (θ)
63
62
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs, implicit_kwargs) = regularized
64
63
f (y, θ) = Ω (y) - dot (θ, y)
65
64
f_grad1 (y, θ) = Ω_grad (y) - θ
66
- maximizer (θ; shape, kwargs... ) = vec ( linear_maximizer (reshape (θ, shape) ; kwargs... ) )
67
- lmo = LinearMaximizationOracleWithKwargs (maximizer, (; shape, kwargs... ) )
65
+ maximizer (θ; kwargs... ) = linear_maximizer (θ ; kwargs... )
66
+ lmo = LinearMaximizationOracleWithKwargs (maximizer, kwargs)
68
67
dfw = DiffFW (f, f_grad1, lmo; implicit_kwargs)
69
- weights, atoms = dfw. implicit (vec (θ); frank_wolfe_kwargs= frank_wolfe_kwargs)
70
- probadist = FixedAtomsProbabilityDistribution (
71
- map (atom -> reshape (atom, shape), atoms), weights
72
- )
68
+ weights, stats = dfw. implicit (θ, frank_wolfe_kwargs)
69
+ atoms = stats. active_set. atoms # TODO : make it public in DiffFW
70
+ probadist = FixedAtomsProbabilityDistribution (atoms, weights)
73
71
return probadist
74
72
end
75
73
0 commit comments