-
Notifications
You must be signed in to change notification settings - Fork 11
Reorganize sinkhorn_unbalanced
, improve convergence checks, and fix GPU issues
#80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4aa431c
55a6edd
7b5e2cf
b50289c
cb14d9e
ebee41d
6acb85e
eabbaad
85f1c30
c587490
980e9ab
9585ccd
31dc3f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "OptimalTransport" | ||
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" | ||
authors = ["zsteve <[email protected]>"] | ||
version = "0.3.3" | ||
version = "0.3.4" | ||
|
||
[deps] | ||
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -304,97 +304,227 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...) | |
end | ||
|
||
""" | ||
sinkhorn_unbalanced(mu, nu, C, lambda1, lambda2, eps; tol = 1e-9, max_iter = 1000, verbose = false, proxdiv_F1 = nothing, proxdiv_F2 = nothing) | ||
sinkhorn_unbalanced(μ, ν, C, λ1::Real, λ2::Real, ε; kwargs...) | ||
|
||
Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`, | ||
using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)` | ||
for the marginals `mu`, `nu` respectively. | ||
Compute the optimal transport plan for the unbalanced entropically regularized optimal | ||
transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size | ||
`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation | ||
terms `λ1` and `λ2`. | ||
|
||
For full generality, the user can specify the soft marginal constraints ``(F_1(\\cdot | \\mu), F_2(\\cdot | \\nu))`` to the problem | ||
The optimal transport plan `γ` is of the same size as `C` and solves | ||
```math | ||
\\inf_{\\gamma} \\langle \\gamma, C \\rangle | ||
+ \\varepsilon \\Omega(\\gamma) | ||
+ \\lambda_1 \\operatorname{KL}(\\gamma 1 | \\mu) | ||
+ \\lambda_2 \\operatorname{KL}(\\gamma^{\\mathsf{T}} 1 | \\nu), | ||
``` | ||
where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic | ||
regularization term and ``\\operatorname{KL}`` is the Kullback-Leibler divergence. | ||
|
||
The keyword arguments supported here are the same as those in the `sinkhorn_unbalanced` | ||
for unbalanced optimal transport problems with general soft marginal constraints. | ||
""" | ||
function sinkhorn_unbalanced( | ||
μ, ν, C, λ1::Real, λ2::Real, ε; proxdiv_F1=nothing, proxdiv_F2=nothing, kwargs... | ||
) | ||
if proxdiv_F1 !== nothing && proxdiv_F2 !== nothing | ||
Base.depwarn( | ||
"keyword arguments `proxdiv_F1` and `proxdiv_F2` are deprecated", | ||
:sinkhorn_unbalanced, | ||
) | ||
|
||
# have to wrap the "proxdiv" functions since the signature changed | ||
# ε was fixed in the function, so we ignore it | ||
proxdiv_F1_wrapper(s, p, _) = copyto!(s, proxdiv_F1(s, p)) | ||
proxdiv_F2_wrapper(s, p, _) = copyto!(s, proxdiv_F2(s, p)) | ||
|
||
return sinkhorn_unbalanced( | ||
μ, ν, C, proxdiv_F1_wrapper, proxdiv_F2_wrapper, ε; kwargs... | ||
) | ||
end | ||
|
||
# define "proxdiv" functions for the unbalanced OT problem | ||
proxdivF!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ))) | ||
proxdivF1!(s, p, ε) = proxdivF!(s, p, ε, λ1) | ||
proxdivF2!(s, p, ε) = proxdivF!(s, p, ε, λ2) | ||
|
||
return sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, ε; kwargs...) | ||
end | ||
|
||
""" | ||
sinkhorn_unbalanced( | ||
μ, ν, C, proxdivF1!, proxdivF2!, ε; | ||
atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000, | ||
) | ||
|
||
Compute the optimal transport plan for the unbalanced entropically regularized optimal | ||
transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size | ||
`(length(μ), length(ν))`, entropic regularization parameter `ε`, and soft marginal | ||
constraints ``F_1`` and ``F_2`` with "proxdiv" functions `proxdivF!` and `proxdivG!`. | ||
|
||
The optimal transport plan `γ` is of the same size as `C` and solves | ||
```math | ||
\\inf_{\\gamma} \\langle \\gamma, C \\rangle | ||
+ \\varepsilon \\Omega(\\gamma) | ||
+ F_1(\\gamma 1, \\mu) | ||
+ F_2(\\gamma^{\\mathsf{T}} 1, \\nu), | ||
``` | ||
where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic | ||
regularization term and ``F_1(\\cdot, \\mu)`` and ``F_2(\\cdot, \\nu)`` are soft marginal | ||
constraints for the source and target marginals. | ||
|
||
The functions `proxdivF1!(s, p, ε)` and `proxdivF2!(s, p, ε)` evaluate the "proxdiv" | ||
functions of ``F_1(\\cdot, p)`` and ``F_2(\\cdot, p)`` at ``s`` for the entropic | ||
regularization parameter ``\\varepsilon``. They have to be mutating and overwrite the first | ||
argument `s` with the result of their computations. | ||
|
||
Mathematically, the "proxdiv" functions are defined as | ||
```math | ||
\\operatorname{proxdiv}_{F_i}(s, p, \\varepsilon) | ||
= \\operatorname{prox}^{\\operatorname{KL}}_{F_i(\\cdot, p)/\\varepsilon}(s) \\oslash s | ||
``` | ||
where ``\\oslash`` denotes element-wise division and | ||
``\\operatorname{prox}_{F_i(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}`` is the proximal | ||
operator of ``F_i(\\cdot, p)/\\varepsilon`` for the Kullback-Leibler | ||
(``\\operatorname{KL}``) divergence. It is defined as | ||
```math | ||
\\operatorname{prox}_{F}^{\\operatorname{KL}}(x) | ||
= \\operatorname{argmin}_{y} F(y) + \\operatorname{KL}(y|x) | ||
``` | ||
and can be computed in closed-form for specific choices of ``F``. For instance, if | ||
``F(\\cdot, p) = \\lambda \\operatorname{KL}(\\cdot | p)`` (``\\lambda > 0``), then | ||
```math | ||
\\min_\\gamma \\epsilon \\mathrm{KL}(\\gamma | \\exp(-C/\\epsilon)) + F_1(\\gamma_1 | \\mu) + F_2(\\gamma_2 | \\nu) | ||
\\operatorname{prox}_{F(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}(x) | ||
= x^{\\frac{\\varepsilon}{\\varepsilon + \\lambda}} p^{\\frac{\\lambda}{\\varepsilon + \\lambda}}, | ||
``` | ||
where all operators are acting pointwise.[^CPSV18] | ||
|
||
Every `check_convergence` steps it is assessed if the algorithm is converged by checking if | ||
the iterates of the scaling factor in the current and previous iteration satisfy | ||
`isapprox(vcat(a, b), vcat(aprev, bprev); atol=atol, rtol=rtol)` where `a` and `b` are the | ||
current iterates and `aprev` and `bprev` the previous ones. The default `rtol` depends on | ||
the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the computation is stopped. | ||
|
||
via `math\\mathrm{proxdiv}_{F_1}(s, p)` and `math\\mathrm{proxdiv}_{F_2}(s, p)` (see Chizat et al., 2016 for details on this). If specified, the algorithm will use the user-specified F1, F2 rather than the default (a KL-divergence). | ||
[^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609. | ||
|
||
See also: [`sinkhorn_unbalanced2`](@ref) | ||
""" | ||
function sinkhorn_unbalanced( | ||
mu, | ||
nu, | ||
μ, | ||
ν, | ||
C, | ||
lambda1, | ||
lambda2, | ||
eps; | ||
tol=1e-9, | ||
max_iter=1000, | ||
verbose=false, | ||
proxdiv_F1=nothing, | ||
proxdiv_F2=nothing, | ||
proxdivF1!, | ||
proxdivF2!, | ||
ε; | ||
tol=nothing, | ||
atol=tol, | ||
rtol=nothing, | ||
max_iter=nothing, | ||
maxiter=max_iter, | ||
check_convergence::Int=10, | ||
) | ||
function proxdiv_KL(s, eps, lambda, p) | ||
return @. (s^(eps / (eps + lambda)) * p^(lambda / (eps + lambda))) / s | ||
# deprecations | ||
if tol !== nothing | ||
Base.depwarn( | ||
"keyword argument `tol` is deprecated, please use `atol` and `rtol`", | ||
:sinkhorn_unbalanced, | ||
) | ||
end | ||
if max_iter !== nothing | ||
Base.depwarn( | ||
"keyword argument `max_iter` is deprecated, please use `maxiter`", | ||
:sinkhorn_unbalanced, | ||
) | ||
end | ||
|
||
# compute Gibbs kernel | ||
K = @. exp(-C / ε) | ||
|
||
a = ones(size(mu, 1)) | ||
b = ones(size(nu, 1)) | ||
a_old = a | ||
b_old = b | ||
tmp_a = zeros(size(nu, 1)) | ||
tmp_b = zeros(size(mu, 1)) | ||
# set default values of squared tolerances | ||
T = float(Base.promote_eltype(μ, ν, K)) | ||
sqatol = atol === nothing ? 0 : atol^2 | ||
sqrtol = rtol === nothing ? (sqatol > zero(sqatol) ? zero(T) : eps(T)) : rtol^2 | ||
|
||
K = @. exp(-C / eps) | ||
# initialize iterates | ||
a = similar(μ, T) | ||
sum!(a, K) | ||
proxdivF1!(a, μ, ε) | ||
b = similar(ν, T) | ||
mul!(b, K', a) | ||
proxdivF2!(b, ν, ε) | ||
|
||
iter = 1 | ||
# caches for convergence checks | ||
a_old = similar(a) | ||
b_old = similar(b) | ||
|
||
while true | ||
a_old = a | ||
b_old = b | ||
tmp_b = K * b | ||
if proxdiv_F1 == nothing | ||
a = proxdiv_KL(tmp_b, eps, lambda1, mu) | ||
else | ||
a = proxdiv_F1(tmp_b, mu) | ||
end | ||
tmp_a = K' * a | ||
if proxdiv_F2 == nothing | ||
b = proxdiv_KL(tmp_a, eps, lambda2, nu) | ||
else | ||
b = proxdiv_F2(tmp_a, nu) | ||
isconverged = false | ||
_maxiter = maxiter === nothing ? 1_000 : maxiter | ||
for iter in 1:_maxiter | ||
# update cache if necessary | ||
ischeck = iter % check_convergence == 0 | ||
if ischeck | ||
copyto!(a_old, a) | ||
copyto!(b_old, b) | ||
end | ||
iter += 1 | ||
if iter % 10 == 0 | ||
err_a = | ||
maximum(abs.(a - a_old)) / max(maximum(abs.(a)), maximum(abs.(a_old)), 1) | ||
err_b = | ||
maximum(abs.(b - b_old)) / max(maximum(abs.(b)), maximum(abs.(b_old)), 1) | ||
if verbose | ||
println("Iteration $iter, err = ", 0.5 * (err_a + err_b)) | ||
end | ||
if (0.5 * (err_a + err_b) < tol) || iter > max_iter | ||
|
||
# compute next iterates | ||
mul!(a, K, b) | ||
proxdivF1!(a, μ, ε) | ||
mul!(b, K', a) | ||
proxdivF2!(b, ν, ε) | ||
|
||
# check convergence of the scaling factors | ||
if ischeck | ||
# compute norm of current and previous scaling factors and their difference | ||
sqnorm_a_b = sum(abs2, a) + sum(abs2, b) | ||
sqnorm_a_b_old = sum(abs2, a_old) + sum(abs2, b_old) | ||
a_old .-= a | ||
b_old .-= b | ||
sqeuclidean_a_b = sum(abs2, a_old) + sum(abs2, b_old) | ||
@debug "Sinkhorn algorithm (" * | ||
string(iter) * | ||
"/" * | ||
string(_maxiter) * | ||
": squared Euclidean distance of iterates = " * | ||
string(sqeuclidean_a_b) | ||
|
||
# check convergence of `a` | ||
if sqeuclidean_a_b < max(sqatol, sqrtol * max(sqnorm_a_b, sqnorm_a_b_old)) | ||
@debug "Sinkhorn algorithm ($iter/$_maxiter): converged" | ||
isconverged = true | ||
break | ||
end | ||
end | ||
end | ||
if iter > max_iter && verbose | ||
println("Warning: exited before convergence") | ||
|
||
if !isconverged | ||
@warn "Sinkhorn algorithm ($_maxiter/$_maxiter): not converged" | ||
end | ||
return Diagonal(a) * K * Diagonal(b) | ||
|
||
return K .* a .* b' | ||
end | ||
|
||
""" | ||
sinkhorn_unbalanced2(mu, nu, C, lambda1, lambda2, eps; plan=nothing, kwargs...) | ||
sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) | ||
sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...) | ||
|
||
Computes the optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`, | ||
using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)` | ||
for the marginals mu, nu respectively. | ||
Compute the optimal transport plan for the unbalanced entropically regularized optimal | ||
transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size | ||
`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation | ||
terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and | ||
`proxdivF2!`. | ||
|
||
A pre-computed optimal transport `plan` may be provided. | ||
A pre-computed optimal transport `plan` may be provided. The other keyword arguments | ||
supported here are the same as those in the [`sinkhorn_unbalanced`](@ref) for unbalanced | ||
optimal transport problems with general soft marginal constraints. | ||
|
||
See also: [`sinkhorn_unbalanced`](@ref) | ||
""" | ||
function sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) | ||
function sinkhorn_unbalanced2( | ||
μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; plan=nothing, kwargs... | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we rename to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe. I still think the main problem is not the name of the functions but the amount of functions - they are all doing the same thing but for slightly different problems (many even for the same) and different algorithms. So the natural approach would be to be able to dispatch both on the problem and the algorithm, which would also solve the problem that #66 tries to address but doesn't fix in a general and extendable way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In any case, I would suggest that both renaming and reorganization of functions should be done in a separate PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, you are right. It's better that we decide on the naming, and then submit in another PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's a problem since if we have different functions for every combination of problem and algorithm (such as e.g.
|
||
γ = if plan === nothing | ||
sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; kwargs...) | ||
sinkhorn_unbalanced(μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; kwargs...) | ||
else | ||
# check dimensions | ||
size(C) == (length(μ), length(ν)) || | ||
|
Uh oh!
There was an error while loading. Please reload this page.