From 4aa431c44764b2e703169ef548ac6c7e7c2ce1d3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 27 May 2021 20:26:28 +0200 Subject: [PATCH 01/10] Reorganize `sinkhorn_unbalanced` and improve convergence checks --- src/OptimalTransport.jl | 250 ++++++++++++++++++++++++++++++---------- 1 file changed, 190 insertions(+), 60 deletions(-) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index b44cf6cd..6e4a2609 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -239,97 +239,227 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...) end """ - sinkhorn_unbalanced(mu, nu, C, lambda1, lambda2, eps; tol = 1e-9, max_iter = 1000, verbose = false, proxdiv_F1 = nothing, proxdiv_F2 = nothing) + sinkhorn_unbalanced( + μ, ν, C, λ1::Real, λ2::Real, ε; + atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000, + ) + +Compute the optimal transport plan for the unbalanced entropic regularization optimal +transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size +`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation +terms `λ1` and `λ2`. -Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`, -using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)` -for the marginals `mu`, `nu` respectively. +The optimal transport plan `γ` is of the same size as `C` and solves +```math +\\inf_{\\gamma} \\langle \\gamma, C \\rangle ++ \\varepsilon \\Omega(\\gamma) ++ \\lambda_1 \\operatorname{KL}(\\gamma 1, \\mu) ++ \\lambda_2 \\operatorname{KL}(\\gamma^{\\mathsf{T}} 1, \\nu), +``` +where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic +regularization term and ``\\operatorname{KL}`` is the Kullback-Leibler divergence. 
-For full generality, the user can specify the soft marginal constraints ``(F_1(\\cdot | \\mu), F_2(\\cdot | \\nu))`` to the problem +Every `check_convergence` steps a convergence check of the error of the scaling factors +with absolute tolerance `atol` and relative tolerance `rtol` is performed. The default +`rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the +computation is stopped. +""" +function sinkhorn_unbalanced( + μ, ν, C, λ1::Real, λ2::Real, ε; proxdiv_F1=nothing, proxdiv_F2=nothing, kwargs... +) + if proxdiv_F1 !== nothing && proxdiv_F2 !== nothing + Base.depwarn( + "keyword arguments `proxdiv_F1` and `proxdiv_F2` are deprecated", + :sinkhorn_unbalanced, + ) + + # have to wrap the "proxdiv" functions since the signature changed + # ε was fixed in the function, so we ignore it + proxdiv_F1_wrapper(s, p, _) = copyto!(s, proxdiv_F1(s, p)) + proxdiv_F2_wrapper(s, p, _) = copyto!(s, proxdiv_F2(s, p)) + + return sinkhorn_unbalanced( + μ, ν, C, proxdiv_F1_wrapper, proxdiv_F2_wrapper, ε; kwargs... + ) + end + # define "proxdiv" functions for the unbalanced OT problem + proxdivF!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ))) + proxdivF1!(s, p, ε) = proxdivF!(s, p, ε, λ1) + proxdivF2!(s, p, ε) = proxdivF!(s, p, ε, λ2) + + return sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, ε; kwargs...) +end + +""" + sinkhorn_unbalanced( + μ, ν, C, proxdivF1!, proxdivF2!, ε; + atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000, + ) + +Compute the optimal transport plan for the unbalanced entropic regularization optimal +transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size +`(length(μ), length(ν))`, entropic regularization parameter `ε`, and soft marginal +constraints ``F_1`` and ``F_2`` with "proxdiv" functions `proxdivF!` and `proxdivG!`. 
+ +The optimal transport plan `γ` is of the same size as `C` and solves ```math -\\min_\\gamma \\epsilon \\mathrm{KL}(\\gamma | \\exp(-C/\\epsilon)) + F_1(\\gamma_1 | \\mu) + F_2(\\gamma_2 | \\nu) +\\inf_{\\gamma} \\langle \\gamma, C \\rangle ++ \\varepsilon \\Omega(\\gamma) ++ F_1(\\gamma 1, \\mu) ++ F_2(\\gamma^{\\mathsf{T}} 1, \\nu), ``` +where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic +regularization term and ``F_1(\\cdot, \\mu)`` and ``F_2(\\cdot, \\nu)` are soft marginal +constraints for the source and target marginals. -via `math\\mathrm{proxdiv}_{F_1}(s, p)` and `math\\mathrm{proxdiv}_{F_2}(s, p)` (see Chizat et al., 2016 for details on this). If specified, the algorithm will use the user-specified F1, F2 rather than the default (a KL-divergence). +The functions `proxdivF1!(s, p, ε)` and `proxdivF2!(s, p, ε)` evaluate the "proxdiv" +functions of ``F_1(\\cdot, p)`` and ``F_2(\\cdot, p)`` at ``s`` for the entropic +regularization parameter ``\\varepsilon``. They have to be mutating and overwrite the first +argument `s` with the result of their computations. + +Mathematically, the "proxdiv" functions are defined as +```math +\\operatorname{proxdiv}_{F_i}(s, p, \\varepsilon) += \\operatorname{prox}^{\\operatorname{KL}}_{F_i(\\cdot, p)/\\varepsilon}(s) \\oslash s +``` +where ``\\oslash`` denotes element-wise division and +``\\operatorname{prox}_{F_i(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}`` is the proximal +operator of ``F_i(\\cdot, p)/\\varepsilon`` for the Kullback-Leibler +(``\\operatorname{KL}``) divergence. It is defined as +```math +\\operatorname{prox}_{F}^{\\operatorname{KL}}(x) += \\operatorname{argmin}_{y} F(y) + \\operatorname{KL}(y|x) +``` +and can be computed in closed-form for specific choices of ``F``. 
For instance, if +``F(\\cdot, p) = \\lambda \\operatorname{KL}(\\cdot | p)`` (``\\lambda > 0``), then +```math +\\operatorname{prox}_{F(\\cdot | p)/\\varepsilon}^{\\operatorname{KL}}(x) += x^{\\frac{\\varepsilon}{\\varepsilon + \\lambda}} p^{\\frac{\\lambda}{\\varepsilon + \\lambda}}, +``` +where all operators are acting pointwise.[^CPSV18] + +Every `check_convergence` steps a convergence check of the error of the scaling factors +with absolute tolerance `atol` and relative tolerance `rtol` is performed. The default +`rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the +computation is stopped. + +[^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609. """ function sinkhorn_unbalanced( - mu, - nu, + μ, + ν, C, - lambda1, - lambda2, - eps; - tol=1e-9, - max_iter=1000, - verbose=false, - proxdiv_F1=nothing, - proxdiv_F2=nothing, + proxdivF1!, + proxdivF2!, + ε; + tol=nothing, + atol=tol, + rtol=nothing, + max_iter=nothing, + maxiter=max_iter, + check_convergence::Int=10, ) - function proxdiv_KL(s, eps, lambda, p) - return @. (s^(eps / (eps + lambda)) * p^(lambda / (eps + lambda))) / s + # deprecations + if tol !== nothing + Base.depwarn( + "keyword argument `tol` is deprecated, please use `atol` and `rtol`", + :sinkhorn_unbalanced, + ) + end + if max_iter !== nothing + Base.depwarn( + "keyword argument `max_iter` is deprecated, please use `maxiter`", + :sinkhorn_unbalanced, + ) end - a = ones(size(mu, 1)) - b = ones(size(nu, 1)) - a_old = a - b_old = b - tmp_a = zeros(size(nu, 1)) - tmp_b = zeros(size(mu, 1)) + # compute Gibbs kernel + K = @. exp(-C / ε) - K = @. exp(-C / eps) + # set default values of squared tolerances + T = float(Base.promote_eltype(μ, ν, K)) + sqatol = atol === nothing ? 0 : atol^2 + sqrtol = rtol === nothing ? (sqatol > zero(sqatol) ? 
zero(T) : eps(T)) : rtol^2 - iter = 1 + # initialize iterates + a = similar(μ, T) + sum!(a, K) + proxdivF1!(a, μ, ε) + b = similar(ν, T) + mul!(b, K', a) + proxdivF2!(b, ν, ε) - while true - a_old = a - b_old = b - tmp_b = K * b - if proxdiv_F1 == nothing - a = proxdiv_KL(tmp_b, eps, lambda1, mu) - else - a = proxdiv_F1(tmp_b, mu) - end - tmp_a = K' * a - if proxdiv_F2 == nothing - b = proxdiv_KL(tmp_a, eps, lambda2, nu) - else - b = proxdiv_F2(tmp_a, nu) + # caches for convergence checks + a_old = similar(a) + b_old = similar(b) + + isconverged = false + _maxiter = maxiter === nothing ? 1_000 : maxiter + for iter in 1:_maxiter + # update cache if necessary + ischeck = iter % check_convergence == 0 + if ischeck + copyto!(a_old, a) + copyto!(b_old, b) end - iter += 1 - if iter % 10 == 0 - err_a = - maximum(abs.(a - a_old)) / max(maximum(abs.(a)), maximum(abs.(a_old)), 1) - err_b = - maximum(abs.(b - b_old)) / max(maximum(abs.(b)), maximum(abs.(b_old)), 1) - if verbose - println("Iteration $iter, err = ", 0.5 * (err_a + err_b)) - end - if (0.5 * (err_a + err_b) < tol) || iter > max_iter + + # compute next iterates + mul!(a, K, b) + proxdivF1!(a, μ, ε) + mul!(b, K', a) + proxdivF2!(b, ν, ε) + + # check convergence of the scaling factors + if ischeck + # compute norm of current and previous scaling factors and their difference + sqnorm_a_b = sum(abs2, a) + sum(abs2, b) + sqnorm_a_b_old = sum(abs2, a_old) + sum(abs2, b_old) + a_old .-= a + b_old .-= b + sqeuclidean_a_b = sum(abs2, a_old) + sum(abs2, b_old) + @debug "Sinkhorn algorithm (" * + string(iter) * + "/" * + string(_maxiter) * + ": squared Euclidean distance of iterates = " * + string(sqeuclidean_a_b) + + # check convergence of `a` + if sqeuclidean_a_b < max(sqatol, sqrtol * max(sqnorm_a_b, sqnorm_a_b_old)) + @debug "Sinkhorn algorithm ($iter/$_maxiter): converged" + isconverged = true break end end end - if iter > max_iter && verbose - println("Warning: exited before convergence") + + if !isconverged + @warn 
"Sinkhorn algorithm ($_maxiter/$_maxiter): not converged" end - return Diagonal(a) * K * Diagonal(b) + + return K .* a .* b' end """ - sinkhorn_unbalanced2(mu, nu, C, lambda1, lambda2, eps; plan=nothing, kwargs...) + sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) + sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...) -Computes the optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`, -using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)` -for the marginals mu, nu respectively. +Compute the optimal transport plan for the unbalanced entropic regularization optimal +transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size +`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation +terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and +`proxdivF2!`. A pre-computed optimal transport `plan` may be provided. See also: [`sinkhorn_unbalanced`](@ref) """ -function sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) +function sinkhorn_unbalanced2( + μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; plan=nothing, kwargs... +) γ = if plan === nothing - sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; kwargs...) + sinkhorn_unbalanced(μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; kwargs...) 
else # check dimensions size(C) == (length(μ), length(ν)) || From 55a6edd38da2f9b16dbe27014e4efed35b0a7fd6 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 27 May 2021 20:27:10 +0200 Subject: [PATCH 02/10] Extend tests --- test/runtests.jl | 100 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 94 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 920a240a..7bbe653b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -136,31 +136,119 @@ end @testset "unbalanced transport" begin M = 250 N = 200 + @testset "example" begin μ = fill(1 / N, M) ν = fill(1 / N, N) C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + # compute optimal transport plan eps = 0.01 lambda = 1 γ = sinkhorn_unbalanced(μ, ν, C, lambda, lambda, eps) + + # compare with POT γ_pot = POT.sinkhorn_unbalanced(μ, ν, C, eps, lambda; stopThr=1e-9) + @test γ_pot ≈ γ - # compute optimal transport map - @test norm(γ - γ_pot, Inf) < 1e-9 + # compute optimal transport cost + c = sinkhorn_unbalanced2(μ, ν, C, lambda, lambda, eps; maxiter=5_000) - c = sinkhorn_unbalanced2(μ, ν, C, lambda, lambda, eps; max_iter=5_000) + # compare with POT c_pot = POT.sinkhorn_unbalanced2( μ, ν, C, eps, lambda; numItermax=5_000, stopThr=1e-9 )[1] + @test c_pot ≈ c - @test c ≈ c_pot atol = 1e-9 - - # ensure that provided map is used + # ensure that provided plan is used c2 = sinkhorn_unbalanced2(similar(μ), similar(ν), C, rand(), rand(), rand(); plan=γ) @test c2 ≈ c end + + @testset "proxdiv operators" begin + μ = fill(1 / N, M) + ν = fill(1 / N, N) + C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + + # compute optimal transport plan and cost with real-valued parameters + eps = 0.01 + lambda1 = 0.4 + lambda2 = 0.5 + γ = sinkhorn_unbalanced(μ, ν, C, lambda1, lambda2, eps) + c = sinkhorn_unbalanced2(μ, ν, C, lambda1, lambda2, eps) + @test c ≈ dot(γ, C) + + # compute optimal transport plan and cost with manual "proxdiv" functions + proxdivF1!(s, p, ε) = (s 
.= s .^ (ε / (ε + 0.4)) .* p .^ (0.4 / (ε + 0.4)) ./ s) + proxdivF2!(s, p, ε) = (s .= s .^ (ε / (ε + 0.5)) .* p .^ (0.5 / (ε + 0.5)) ./ s) + γ_proxdiv = sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, eps) + c_proxdiv = sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, eps) + @test γ_proxdiv ≈ γ + @test c_proxdiv ≈ c + end + + @testset "consistency with balanced OT" begin + μ = fill(1 / M, M) + ν = fill(1 / N, N) + C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + + # compute optimal transport plan and cost with manual "proxdiv" functions for + # balanced OT + ε = 0.01 + proxdivF!(s, p, ε) = (s .= p ./ s) + γ = sinkhorn_unbalanced(μ, ν, C, proxdivF!, proxdivF!, ε) + c = sinkhorn_unbalanced2(μ, ν, C, proxdivF!, proxdivF!, ε) + @test c ≈ dot(γ, C) + + # compute optimal transport plan and cost for balanced OT + γ_balanced = sinkhorn(μ, ν, C, ε) + c_balanced = sinkhorn2(μ, ν, C, ε) + @test γ_balanced ≈ γ rtol=1e-4 + @test c_balanced ≈ c rtol=1e-4 + end + + @testset "deprecations" begin + μ = fill(1 / N, M) + ν = fill(1 / N, N) + C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + + # compute optimal transport plan and cost with real-valued parameters + ε = 0.01 + λ1 = 0.4 + λ2 = 0.5 + γ = sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε) + c = sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε) + + # compute optimal transport plan and cost with manual "proxdiv" functions + # as keyword arguments + proxdivF1(s, p) = s .^ (0.01 / (0.01 + 0.4)) .* p .^ (0.4 / (0.01 + 0.4)) ./ s + proxdivF2(s, p) = s .^ (0.01 / (0.01 + 0.5)) .* p .^ (0.5 / (0.01 + 0.5)) ./ s + γ_proxdiv = @test_deprecated sinkhorn_unbalanced( + μ, ν, C, rand(), rand(), ε; proxdiv_F1=proxdivF1, proxdiv_F2=proxdivF2 + ) + c_proxdiv = @test_deprecated sinkhorn_unbalanced2( + μ, ν, C, rand(), rand(), ε; proxdiv_F1=proxdivF1, proxdiv_F2=proxdivF2 + ) + @test γ_proxdiv ≈ γ + @test c_proxdiv ≈ c + + # deprecated `tol` keyword argument + γ = sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; atol=1e-7) + c = 
sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; atol=1e-7) + γ_tol = @test_deprecated sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; tol=1e-7) + c_tol = @test_deprecated sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; tol=1e-7) + @test γ_tol == γ + @test c_tol == c + + # deprecated `max_iter` keyword argument + γ = sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; maxiter=50) + c = sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; maxiter=50) + γ_max_iter = @test_deprecated sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; max_iter=50) + c_max_iter = @test_deprecated sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; max_iter=50) + @test γ_max_iter == γ + @test c_max_iter == c + end end @testset "stabilized sinkhorn" begin From 7b5e2cf6d71f1ba8f05e6ab19186ad8883a42663 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 27 May 2021 20:27:28 +0200 Subject: [PATCH 03/10] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0bd68983..de483865 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.2" +version = "0.3.3" [deps] Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" From b50289c950191c0eb42b3f7e2014dae5aa1bb58f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 27 May 2021 21:00:44 +0200 Subject: [PATCH 04/10] Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7bbe653b..f0ef2ac1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -204,8 +204,8 @@ end # compute optimal transport plan and cost for balanced OT γ_balanced = sinkhorn(μ, ν, C, ε) c_balanced = sinkhorn2(μ, ν, C, ε) - @test γ_balanced ≈ γ rtol=1e-4 - @test c_balanced ≈ c rtol=1e-4 + @test γ_balanced ≈ γ rtol = 1e-4 + @test c_balanced ≈ c rtol = 1e-4 end @testset "deprecations" 
begin From cb14d9ee3eacdae010233a558e84687ad68802ce Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 28 May 2021 00:33:57 +0200 Subject: [PATCH 05/10] Remove redundant information --- src/OptimalTransport.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 6e4a2609..9c24bc44 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -239,10 +239,7 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...) end """ - sinkhorn_unbalanced( - μ, ν, C, λ1::Real, λ2::Real, ε; - atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000, - ) + sinkhorn_unbalanced(μ, ν, C, λ1::Real, λ2::Real, ε; kwargs...) Compute the optimal transport plan for the unbalanced entropic regularization optimal transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size @@ -259,10 +256,8 @@ The optimal transport plan `γ` is of the same size as `C` and solves where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic regularization term and ``\\operatorname{KL}`` is the Kullback-Leibler divergence. -Every `check_convergence` steps a convergence check of the error of the scaling factors -with absolute tolerance `atol` and relative tolerance `rtol` is performed. The default -`rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the -computation is stopped. +The keyword arguments supported here are the same as those of the `sinkhorn_unbalanced` +method for unbalanced optimal transport problems with general soft marginal constraints. """ function sinkhorn_unbalanced( μ, ν, C, λ1::Real, λ2::Real, ε; proxdiv_F1=nothing, proxdiv_F2=nothing, kwargs... @@ -451,7 +446,9 @@ transport problem with source and target marginals `μ` and `ν`, cost matrix `C terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and `proxdivF2!`. 
-A pre-computed optimal transport `plan` may be provided. +A pre-computed optimal transport `plan` may be provided. The other keyword arguments +supported here are the same as those of the [`sinkhorn_unbalanced`](@ref) method for +unbalanced optimal transport problems with general soft marginal constraints. See also: [`sinkhorn_unbalanced`](@ref) """ From ebee41d20ebeeb49fd43e0f7fc03a500d6d8c007 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 28 May 2021 00:34:21 +0200 Subject: [PATCH 06/10] Explain `rtol` and `atol` more clearly --- src/OptimalTransport.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 9c24bc44..5900dd62 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -334,12 +334,15 @@ and can be computed in closed-form for specific choices of ``F``. For instance, ``` where all operators are acting pointwise.[^CPSV18] -Every `check_convergence` steps a convergence check of the error of the scaling factors -with absolute tolerance `atol` and relative tolerance `rtol` is performed. The default -`rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the -computation is stopped. +Every `check_convergence` steps, convergence is assessed by checking whether the scaling +factors of the current and the previous iteration satisfy +`isapprox(vcat(a, b), vcat(aprev, bprev); atol=atol, rtol=rtol)`, where `a` and `b` are the +current iterates and `aprev` and `bprev` the previous ones. The default `rtol` depends on +the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the computation is stopped. [^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609. 
+ +See also: [`sinkhorn_unbalanced2`](@ref) """ function sinkhorn_unbalanced( μ, From eabbaad2545cad1d5680bc3e7ce0be774fddf477 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 28 May 2021 00:42:45 +0200 Subject: [PATCH 07/10] Use term `entropically regularized OT` --- src/OptimalTransport.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 5900dd62..3d8bcc3c 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -241,7 +241,7 @@ end """ sinkhorn_unbalanced(μ, ν, C, λ1::Real, λ2::Real, ε; kwargs...) -Compute the optimal transport plan for the unbalanced entropic regularization optimal +Compute the optimal transport plan for the unbalanced entropically regularized optimal transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation terms `λ1` and `λ2`. @@ -292,7 +292,7 @@ end atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000, ) -Compute the optimal transport plan for the unbalanced entropic regularization optimal +Compute the optimal transport plan for the unbalanced entropically regularized optimal transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, entropic regularization parameter `ε`, and soft marginal constraints ``F_1`` and ``F_2`` with "proxdiv" functions `proxdivF!` and `proxdivG!`. @@ -443,7 +443,7 @@ end sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...) 
-Compute the optimal transport plan for the unbalanced entropic regularization optimal +Compute the optimal transport cost for the unbalanced entropically regularized optimal transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and From 980e9ab2628ad3b7316709ec17bfe4ed12a3743c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 28 May 2021 02:02:05 +0200 Subject: [PATCH 08/10] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d1fa13f1..6effd863 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.3" +version = "0.3.4" [deps] Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" From 9585ccd6478e20a17eb2a0fec1813af2359b1ab2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 28 May 2021 02:28:28 +0200 Subject: [PATCH 09/10] Fix docstring --- src/OptimalTransport.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 7f1bed21..40b509e8 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -315,8 +315,8 @@ The optimal transport plan `γ` is of the same size as `C` and solves ```math \\inf_{\\gamma} \\langle \\gamma, C \\rangle + \\varepsilon \\Omega(\\gamma) -+ \\lambda_1 \\operatorname{KL}(\\gamma 1, \\mu) -+ \\lambda_2 \\operatorname{KL}(\\gamma^{\\mathsf{T}} 1, \\nu), ++ \\lambda_1 \\operatorname{KL}(\\gamma 1 | \\mu) ++ \\lambda_2 \\operatorname{KL}(\\gamma^{\\mathsf{T}} 1 | \\nu), ``` where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic regularization term and ``\\operatorname{KL}`` is the Kullback-Leibler divergence. 
From 31dc3f62b658526ec6bf087eb302edad4c4be27e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 28 May 2021 02:43:03 +0200 Subject: [PATCH 10/10] More docstring fixes --- src/OptimalTransport.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 40b509e8..7879b70c 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -370,7 +370,7 @@ The optimal transport plan `γ` is of the same size as `C` and solves + F_2(\\gamma^{\\mathsf{T}} 1, \\nu), ``` where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic -regularization term and ``F_1(\\cdot, \\mu)`` and ``F_2(\\cdot, \\nu)` are soft marginal +regularization term and ``F_1(\\cdot, \\mu)`` and ``F_2(\\cdot, \\nu)`` are soft marginal constraints for the source and target marginals. The functions `proxdivF1!(s, p, ε)` and `proxdivF2!(s, p, ε)` evaluate the "proxdiv" @@ -394,7 +394,7 @@ operator of ``F_i(\\cdot, p)/\\varepsilon`` for the Kullback-Leibler and can be computed in closed-form for specific choices of ``F``. For instance, if ``F(\\cdot, p) = \\lambda \\operatorname{KL}(\\cdot | p)`` (``\\lambda > 0``), then ```math -\\operatorname{prox}_{F(\\cdot | p)/\\varepsilon}^{\\operatorname{KL}}(x) +\\operatorname{prox}_{F(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}(x) = x^{\\frac{\\varepsilon}{\\varepsilon + \\lambda}} p^{\\frac{\\lambda}{\\varepsilon + \\lambda}}, ``` where all operators are acting pointwise.[^CPSV18]