
Reorganize sinkhorn_unbalanced, improve convergence checks, and fix GPU issues #80

Merged
merged 13 commits into from
May 28, 2021
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "OptimalTransport"
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
authors = ["zsteve <[email protected]>"]
version = "0.3.3"
version = "0.3.4"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
252 changes: 191 additions & 61 deletions src/OptimalTransport.jl
@@ -304,97 +304,227 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
end

"""
sinkhorn_unbalanced(μ, ν, C, λ1::Real, λ2::Real, ε; kwargs...)

Compute the optimal transport plan for the unbalanced entropically regularized optimal
transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation
terms `λ1` and `λ2`.

The optimal transport plan `γ` is of the same size as `C` and solves
```math
\\inf_{\\gamma} \\langle \\gamma, C \\rangle
+ \\varepsilon \\Omega(\\gamma)
+ \\lambda_1 \\operatorname{KL}(\\gamma 1 | \\mu)
+ \\lambda_2 \\operatorname{KL}(\\gamma^{\\mathsf{T}} 1 | \\nu),
```
where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic
regularization term and ``\\operatorname{KL}`` is the Kullback-Leibler divergence.

The keyword arguments supported here are the same as those of the `sinkhorn_unbalanced`
method for unbalanced optimal transport problems with general soft marginal constraints.
"""
function sinkhorn_unbalanced(
μ, ν, C, λ1::Real, λ2::Real, ε; proxdiv_F1=nothing, proxdiv_F2=nothing, kwargs...
)
if proxdiv_F1 !== nothing && proxdiv_F2 !== nothing
Base.depwarn(
"keyword arguments `proxdiv_F1` and `proxdiv_F2` are deprecated",
:sinkhorn_unbalanced,
)

# have to wrap the "proxdiv" functions since the signature changed
# ε was fixed in the function, so we ignore it
proxdiv_F1_wrapper(s, p, _) = copyto!(s, proxdiv_F1(s, p))
proxdiv_F2_wrapper(s, p, _) = copyto!(s, proxdiv_F2(s, p))

return sinkhorn_unbalanced(
μ, ν, C, proxdiv_F1_wrapper, proxdiv_F2_wrapper, ε; kwargs...
)
end

# define "proxdiv" functions for the unbalanced OT problem
proxdivF!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ)))
proxdivF1!(s, p, ε) = proxdivF!(s, p, ε, λ1)
proxdivF2!(s, p, ε) = proxdivF!(s, p, ε, λ2)

return sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, ε; kwargs...)
end
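
For illustration, a minimal usage sketch of this method. The histogram sizes, the relaxation weights `λ1 = λ2 = 1.0`, and `ε = 0.01` are arbitrary choices, and the cost matrix is built with `pairwise` from Distances.jl (a dependency of the package):

```julia
using OptimalTransport, Distances

μ = fill(1 / 50, 50)                       # source histogram
ν = fill(1 / 100, 100)                     # target histogram
x, y = rand(1, 50), rand(1, 100)           # support points (as columns)
C = pairwise(SqEuclidean(), x, y; dims=2)  # 50×100 cost matrix

γ = sinkhorn_unbalanced(μ, ν, C, 1.0, 1.0, 0.01)
size(γ) == (length(μ), length(ν))          # true
```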

"""
sinkhorn_unbalanced(
μ, ν, C, proxdivF1!, proxdivF2!, ε;
atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000,
)

Compute the optimal transport plan for the unbalanced entropically regularized optimal
transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
`(length(μ), length(ν))`, entropic regularization parameter `ε`, and soft marginal
constraints ``F_1`` and ``F_2`` with "proxdiv" functions `proxdivF1!` and `proxdivF2!`.

The optimal transport plan `γ` is of the same size as `C` and solves
```math
\\inf_{\\gamma} \\langle \\gamma, C \\rangle
+ \\varepsilon \\Omega(\\gamma)
+ F_1(\\gamma 1, \\mu)
+ F_2(\\gamma^{\\mathsf{T}} 1, \\nu),
```
where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic
regularization term and ``F_1(\\cdot, \\mu)`` and ``F_2(\\cdot, \\nu)`` are soft marginal
constraints for the source and target marginals.

The functions `proxdivF1!(s, p, ε)` and `proxdivF2!(s, p, ε)` evaluate the "proxdiv"
functions of ``F_1(\\cdot, p)`` and ``F_2(\\cdot, p)`` at ``s`` for the entropic
regularization parameter ``\\varepsilon``. They have to be mutating and overwrite the first
argument `s` with the result of their computations.

Mathematically, the "proxdiv" functions are defined as
```math
\\operatorname{proxdiv}_{F_i}(s, p, \\varepsilon)
= \\operatorname{prox}^{\\operatorname{KL}}_{F_i(\\cdot, p)/\\varepsilon}(s) \\oslash s
```
where ``\\oslash`` denotes element-wise division and
``\\operatorname{prox}_{F_i(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}`` is the proximal
operator of ``F_i(\\cdot, p)/\\varepsilon`` for the Kullback-Leibler
(``\\operatorname{KL}``) divergence. It is defined as
```math
\\operatorname{prox}_{F}^{\\operatorname{KL}}(x)
= \\operatorname{argmin}_{y} F(y) + \\operatorname{KL}(y|x)
```
and can be computed in closed-form for specific choices of ``F``. For instance, if
``F(\\cdot, p) = \\lambda \\operatorname{KL}(\\cdot | p)`` (``\\lambda > 0``), then
```math
\\operatorname{prox}_{F(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}(x)
= x^{\\frac{\\varepsilon}{\\varepsilon + \\lambda}} p^{\\frac{\\lambda}{\\varepsilon + \\lambda}},
```
where all operators are acting pointwise.[^CPSV18]
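
For example, the KL soft marginal constraints ``F_i(\\cdot, p) = \\lambda_i \\operatorname{KL}(\\cdot | p)`` from above correspond to the following "proxdiv" functions (a sketch; the values of `λ1` and `λ2` are arbitrary):

```julia
λ1, λ2 = 1.0, 1.0
proxdivF1!(s, p, ε) = (s .= (p ./ s) .^ (λ1 / (ε + λ1)))
proxdivF2!(s, p, ε) = (s .= (p ./ s) .^ (λ2 / (ε + λ2)))
```

Passing these functions as `sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, ε)` is equivalent to calling `sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε)`.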

Every `check_convergence` steps, the algorithm is considered converged if the scaling
factors of the current and the previous iteration satisfy
`isapprox(vcat(a, b), vcat(aprev, bprev); atol=atol, rtol=rtol)`, where `a` and `b` are the
current iterates and `aprev` and `bprev` the previous ones. The default `rtol` depends on
the types of `μ`, `ν`, and `C`. The computation is stopped after at most `maxiter`
iterations.

[^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609.

See also: [`sinkhorn_unbalanced2`](@ref)
"""
function sinkhorn_unbalanced(
μ,
ν,
C,
proxdivF1!,
proxdivF2!,
ε;
tol=nothing,
atol=tol,
rtol=nothing,
max_iter=nothing,
maxiter=max_iter,
check_convergence::Int=10,
)
# deprecations
if tol !== nothing
Base.depwarn(
"keyword argument `tol` is deprecated, please use `atol` and `rtol`",
:sinkhorn_unbalanced,
)
end
if max_iter !== nothing
Base.depwarn(
"keyword argument `max_iter` is deprecated, please use `maxiter`",
:sinkhorn_unbalanced,
)
end

# compute Gibbs kernel
K = @. exp(-C / ε)

# set default values of squared tolerances
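# (squared values are used since convergence is checked below via squared Euclidean distances)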
T = float(Base.promote_eltype(μ, ν, K))
sqatol = atol === nothing ? 0 : atol^2
sqrtol = rtol === nothing ? (sqatol > zero(sqatol) ? zero(T) : eps(T)) : rtol^2

# initialize iterates
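# (equivalent to one update starting from b ≡ 1: a = proxdivF1(K * 1, μ, ε), then b = proxdivF2(K' * a, ν, ε))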
a = similar(μ, T)
sum!(a, K)
proxdivF1!(a, μ, ε)
b = similar(ν, T)
mul!(b, K', a)
proxdivF2!(b, ν, ε)

# caches for convergence checks
a_old = similar(a)
b_old = similar(b)

isconverged = false
_maxiter = maxiter === nothing ? 1_000 : maxiter
for iter in 1:_maxiter
# update cache if necessary
ischeck = iter % check_convergence == 0
if ischeck
copyto!(a_old, a)
copyto!(b_old, b)
end

# compute next iterates
mul!(a, K, b)
proxdivF1!(a, μ, ε)
mul!(b, K', a)
proxdivF2!(b, ν, ε)

# check convergence of the scaling factors
if ischeck
# compute norm of current and previous scaling factors and their difference
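# (this is the squared form of the `isapprox(vcat(a, b), vcat(a_old, b_old); atol=atol, rtol=rtol)`
# check described in the docstring)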
sqnorm_a_b = sum(abs2, a) + sum(abs2, b)
sqnorm_a_b_old = sum(abs2, a_old) + sum(abs2, b_old)
a_old .-= a
b_old .-= b
sqeuclidean_a_b = sum(abs2, a_old) + sum(abs2, b_old)
@debug "Sinkhorn algorithm (" *
string(iter) *
"/" *
string(_maxiter) *
": squared Euclidean distance of iterates = " *
string(sqeuclidean_a_b)

# check convergence of `a`
if sqeuclidean_a_b < max(sqatol, sqrtol * max(sqnorm_a_b, sqnorm_a_b_old))
@debug "Sinkhorn algorithm ($iter/$_maxiter): converged"
isconverged = true
break
end
end
end

if !isconverged
@warn "Sinkhorn algorithm ($_maxiter/$_maxiter): not converged"
end

return K .* a .* b'
end
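
The returned plan `K .* a .* b'` is the diagonal scaling `Diagonal(a) * K * Diagonal(b)` written with broadcasting (entrywise `a[i] * K[i, j] * b[j]`). A quick sanity check of this identity with arbitrary small arrays:

```julia
using LinearAlgebra

a, b = rand(3), rand(4)
K = rand(3, 4)
Diagonal(a) * K * Diagonal(b) ≈ K .* a .* b'  # true
```

Plain broadcasting keeps the computation in ordinary array operations, which is also convenient when `K`, `a`, and `b` are GPU arrays.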

"""
sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...)
sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...)

Compute the optimal transport cost for the unbalanced entropically regularized optimal
transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation
terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and
`proxdivF2!`.

A pre-computed optimal transport `plan` may be provided. The other keyword arguments
supported here are the same as those of the [`sinkhorn_unbalanced`](@ref) method for
unbalanced optimal transport problems with general soft marginal constraints.
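
For illustration, a minimal sketch (sizes, relaxation weights, and `ε` are arbitrary) computing the cost directly and from a precomputed plan:

```julia
using OptimalTransport, Distances

μ, ν = fill(1 / 20, 20), fill(1 / 30, 30)
C = pairwise(SqEuclidean(), rand(1, 20), rand(1, 30); dims=2)

c = sinkhorn_unbalanced2(μ, ν, C, 1.0, 1.0, 0.01)

γ = sinkhorn_unbalanced(μ, ν, C, 1.0, 1.0, 0.01)
c ≈ sinkhorn_unbalanced2(μ, ν, C, 1.0, 1.0, 0.01; plan=γ)  # true
```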

See also: [`sinkhorn_unbalanced`](@ref)
"""
function sinkhorn_unbalanced2(
μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; plan=nothing, kwargs...
)
Member

Shouldn't we rename to sinkhorn_unbalanced_cost?

Member Author

Maybe. I still think the main problem is not the name of the functions but the amount of functions - they are all doing the same thing but for slightly different problems (many even for the same) and different algorithms. So the natural approach would be to be able to dispatch both on the problem and the algorithm, which would also solve the problem that #66 tries to address but doesn't fix in a general and extendable way.

Member Author

In any case, I would suggest that both renaming and reorganization of functions should be done in a separate PR.

Member

Yeah, you are right. It's better that we decide on the naming, and then submit in another PR.
" I still think the main problem is not the name of the functions but the amount of functions - they are all doing the same thing but for slightly different problems (many even for the same) and different algorithms." Why is this a problem?

Member Author

I think it's a problem since if we have different functions for every combination of problem and algorithm (such as e.g. sinkhorn, sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_unbalanced etc.)

  • the API is unstructured and difficult to navigate for users,
  • it is very difficult to compose functionality (e.g. if I would like to use epsilon scaling with the unbalanced Sinkhorn algorithm I have to write a new function instead of just composing epsilon scaling with the unbalanced algorithm),
  • we completely neglect multiple dispatch which arguably is the biggest feature of Julia.

γ = if plan === nothing
sinkhorn_unbalanced(μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; kwargs...)
else
# check dimensions
size(C) == (length(μ), length(ν)) ||