diff --git a/Project.toml b/Project.toml index d1fa13f1..6effd863 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.3" +version = "0.3.4" [deps] Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 29c62063..7879b70c 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -304,97 +304,227 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...) end """ - sinkhorn_unbalanced(mu, nu, C, lambda1, lambda2, eps; tol = 1e-9, max_iter = 1000, verbose = false, proxdiv_F1 = nothing, proxdiv_F2 = nothing) + sinkhorn_unbalanced(μ, ν, C, λ1::Real, λ2::Real, ε; kwargs...) -Computes the optimal transport plan of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`, -using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)` -for the marginals `mu`, `nu` respectively. +Compute the optimal transport plan for the unbalanced entropically regularized optimal +transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size +`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation +terms `λ1` and `λ2`. -For full generality, the user can specify the soft marginal constraints ``(F_1(\\cdot | \\mu), F_2(\\cdot | \\nu))`` to the problem +The optimal transport plan `γ` is of the same size as `C` and solves +```math +\\inf_{\\gamma} \\langle \\gamma, C \\rangle ++ \\varepsilon \\Omega(\\gamma) ++ \\lambda_1 \\operatorname{KL}(\\gamma 1 | \\mu) ++ \\lambda_2 \\operatorname{KL}(\\gamma^{\\mathsf{T}} 1 | \\nu), +``` +where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic +regularization term and ``\\operatorname{KL}`` is the Kullback-Leibler divergence. + +The keyword arguments supported here are the same as those in the `sinkhorn_unbalanced` +for unbalanced optimal transport problems with general soft marginal constraints. +""" +function sinkhorn_unbalanced( + μ, ν, C, λ1::Real, λ2::Real, ε; proxdiv_F1=nothing, proxdiv_F2=nothing, kwargs... +) + if proxdiv_F1 !== nothing && proxdiv_F2 !== nothing + Base.depwarn( + "keyword arguments `proxdiv_F1` and `proxdiv_F2` are deprecated", + :sinkhorn_unbalanced, + ) + # have to wrap the "proxdiv" functions since the signature changed + # ε was fixed in the function, so we ignore it + proxdiv_F1_wrapper(s, p, _) = copyto!(s, proxdiv_F1(s, p)) + proxdiv_F2_wrapper(s, p, _) = copyto!(s, proxdiv_F2(s, p)) + + return sinkhorn_unbalanced( + μ, ν, C, proxdiv_F1_wrapper, proxdiv_F2_wrapper, ε; kwargs... + ) + end + + # define "proxdiv" functions for the unbalanced OT problem + proxdivF!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ))) + proxdivF1!(s, p, ε) = proxdivF!(s, p, ε, λ1) + proxdivF2!(s, p, ε) = proxdivF!(s, p, ε, λ2) + + return sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, ε; kwargs...) +end + +""" + sinkhorn_unbalanced( + μ, ν, C, proxdivF1!, proxdivF2!, ε; + atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000, + ) + +Compute the optimal transport plan for the unbalanced entropically regularized optimal +transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size +`(length(μ), length(ν))`, entropic regularization parameter `ε`, and soft marginal +constraints ``F_1`` and ``F_2`` with "proxdiv" functions `proxdivF!` and `proxdivG!`. + +The optimal transport plan `γ` is of the same size as `C` and solves +```math +\\inf_{\\gamma} \\langle \\gamma, C \\rangle ++ \\varepsilon \\Omega(\\gamma) ++ F_1(\\gamma 1, \\mu) ++ F_2(\\gamma^{\\mathsf{T}} 1, \\nu), +``` +where ``\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j} \\log \\gamma_{i,j}`` is the entropic +regularization term and ``F_1(\\cdot, \\mu)`` and ``F_2(\\cdot, \\nu)`` are soft marginal +constraints for the source and target marginals. + +The functions `proxdivF1!(s, p, ε)` and `proxdivF2!(s, p, ε)` evaluate the "proxdiv" +functions of ``F_1(\\cdot, p)`` and ``F_2(\\cdot, p)`` at ``s`` for the entropic +regularization parameter ``\\varepsilon``. They have to be mutating and overwrite the first +argument `s` with the result of their computations. + +Mathematically, the "proxdiv" functions are defined as +```math +\\operatorname{proxdiv}_{F_i}(s, p, \\varepsilon) += \\operatorname{prox}^{\\operatorname{KL}}_{F_i(\\cdot, p)/\\varepsilon}(s) \\oslash s +``` +where ``\\oslash`` denotes element-wise division and +``\\operatorname{prox}_{F_i(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}`` is the proximal +operator of ``F_i(\\cdot, p)/\\varepsilon`` for the Kullback-Leibler +(``\\operatorname{KL}``) divergence. It is defined as +```math +\\operatorname{prox}_{F}^{\\operatorname{KL}}(x) += \\operatorname{argmin}_{y} F(y) + \\operatorname{KL}(y|x) +``` +and can be computed in closed-form for specific choices of ``F``. For instance, if +``F(\\cdot, p) = \\lambda \\operatorname{KL}(\\cdot | p)`` (``\\lambda > 0``), then ```math -\\min_\\gamma \\epsilon \\mathrm{KL}(\\gamma | \\exp(-C/\\epsilon)) + F_1(\\gamma_1 | \\mu) + F_2(\\gamma_2 | \\nu) +\\operatorname{prox}_{F(\\cdot, p)/\\varepsilon}^{\\operatorname{KL}}(x) += x^{\\frac{\\varepsilon}{\\varepsilon + \\lambda}} p^{\\frac{\\lambda}{\\varepsilon + \\lambda}}, ``` +where all operators are acting pointwise.[^CPSV18] + +Every `check_convergence` steps it is assessed if the algorithm is converged by checking if +the iterates of the scaling factor in the current and previous iteration satisfy +`isapprox(vcat(a, b), vcat(aprev, bprev); atol=atol, rtol=rtol)` where `a` and `b` are the +current iterates and `aprev` and `bprev` the previous ones. The default `rtol` depends on +the types of `μ`, `ν`, and `C`. After `maxiter` iterations, the computation is stopped. -via `math\\mathrm{proxdiv}_{F_1}(s, p)` and `math\\mathrm{proxdiv}_{F_2}(s, p)` (see Chizat et al., 2016 for details on this). If specified, the algorithm will use the user-specified F1, F2 rather than the default (a KL-divergence). +[^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609. + +See also: [`sinkhorn_unbalanced2`](@ref) """ function sinkhorn_unbalanced( - mu, - nu, + μ, + ν, C, - lambda1, - lambda2, - eps; - tol=1e-9, - max_iter=1000, - verbose=false, - proxdiv_F1=nothing, - proxdiv_F2=nothing, + proxdivF1!, + proxdivF2!, + ε; + tol=nothing, + atol=tol, + rtol=nothing, + max_iter=nothing, + maxiter=max_iter, + check_convergence::Int=10, ) - function proxdiv_KL(s, eps, lambda, p) - return @. (s^(eps / (eps + lambda)) * p^(lambda / (eps + lambda))) / s + # deprecations + if tol !== nothing + Base.depwarn( + "keyword argument `tol` is deprecated, please use `atol` and `rtol`", + :sinkhorn_unbalanced, + ) end + if max_iter !== nothing + Base.depwarn( + "keyword argument `max_iter` is deprecated, please use `maxiter`", + :sinkhorn_unbalanced, + ) + end + + # compute Gibbs kernel + K = @. exp(-C / ε) - a = ones(size(mu, 1)) - b = ones(size(nu, 1)) - a_old = a - b_old = b - tmp_a = zeros(size(nu, 1)) - tmp_b = zeros(size(mu, 1)) + # set default values of squared tolerances + T = float(Base.promote_eltype(μ, ν, K)) + sqatol = atol === nothing ? 0 : atol^2 + sqrtol = rtol === nothing ? (sqatol > zero(sqatol) ? zero(T) : eps(T)) : rtol^2 - K = @. exp(-C / eps) + # initialize iterates + a = similar(μ, T) + sum!(a, K) + proxdivF1!(a, μ, ε) + b = similar(ν, T) + mul!(b, K', a) + proxdivF2!(b, ν, ε) - iter = 1 + # caches for convergence checks + a_old = similar(a) + b_old = similar(b) - while true - a_old = a - b_old = b - tmp_b = K * b - if proxdiv_F1 == nothing - a = proxdiv_KL(tmp_b, eps, lambda1, mu) - else - a = proxdiv_F1(tmp_b, mu) - end - tmp_a = K' * a - if proxdiv_F2 == nothing - b = proxdiv_KL(tmp_a, eps, lambda2, nu) - else - b = proxdiv_F2(tmp_a, nu) + isconverged = false + _maxiter = maxiter === nothing ? 1_000 : maxiter + for iter in 1:_maxiter + # update cache if necessary + ischeck = iter % check_convergence == 0 + if ischeck + copyto!(a_old, a) + copyto!(b_old, b) end - iter += 1 - if iter % 10 == 0 - err_a = - maximum(abs.(a - a_old)) / max(maximum(abs.(a)), maximum(abs.(a_old)), 1) - err_b = - maximum(abs.(b - b_old)) / max(maximum(abs.(b)), maximum(abs.(b_old)), 1) - if verbose - println("Iteration $iter, err = ", 0.5 * (err_a + err_b)) - end - if (0.5 * (err_a + err_b) < tol) || iter > max_iter + + # compute next iterates + mul!(a, K, b) + proxdivF1!(a, μ, ε) + mul!(b, K', a) + proxdivF2!(b, ν, ε) + + # check convergence of the scaling factors + if ischeck + # compute norm of current and previous scaling factors and their difference + sqnorm_a_b = sum(abs2, a) + sum(abs2, b) + sqnorm_a_b_old = sum(abs2, a_old) + sum(abs2, b_old) + a_old .-= a + b_old .-= b + sqeuclidean_a_b = sum(abs2, a_old) + sum(abs2, b_old) + @debug "Sinkhorn algorithm (" * + string(iter) * + "/" * + string(_maxiter) * + ": squared Euclidean distance of iterates = " * + string(sqeuclidean_a_b) + + # check convergence of `a` + if sqeuclidean_a_b < max(sqatol, sqrtol * max(sqnorm_a_b, sqnorm_a_b_old)) + @debug "Sinkhorn algorithm ($iter/$_maxiter): converged" + isconverged = true break end end end - if iter > max_iter && verbose - println("Warning: exited before convergence") + + if !isconverged + @warn "Sinkhorn algorithm ($_maxiter/$_maxiter): not converged" end - return Diagonal(a) * K * Diagonal(b) + + return K .* a .* b' end """ - sinkhorn_unbalanced2(mu, nu, C, lambda1, lambda2, eps; plan=nothing, kwargs...) + sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) + sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...) -Computes the optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and entropic regularization parameter `eps`, -using the unbalanced Sinkhorn algorithm [Chizat 2016] with KL-divergence terms for soft marginal constraints, with weights `(lambda1, lambda2)` -for the marginals mu, nu respectively. +Compute the optimal transport plan for the unbalanced entropically regularized optimal +transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size +`(length(μ), length(ν))`, entropic regularization parameter `ε`, and marginal relaxation +terms `λ1` and `λ2` or soft marginal constraints with "proxdiv" functions `proxdivF1!` and +`proxdivF2!`. -A pre-computed optimal transport `plan` may be provided. +A pre-computed optimal transport `plan` may be provided. The other keyword arguments +supported here are the same as those in the [`sinkhorn_unbalanced`](@ref) for unbalanced +optimal transport problems with general soft marginal constraints. See also: [`sinkhorn_unbalanced`](@ref) """ -function sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) +function sinkhorn_unbalanced2( + μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; plan=nothing, kwargs... +) γ = if plan === nothing - sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; kwargs...) + sinkhorn_unbalanced(μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; kwargs...) else # check dimensions size(C) == (length(μ), length(ν)) || diff --git a/test/runtests.jl b/test/runtests.jl index 1e0bd259..c795a416 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -209,31 +209,119 @@ end @testset "unbalanced transport" begin M = 250 N = 200 + @testset "example" begin μ = fill(1 / N, M) ν = fill(1 / N, N) C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + # compute optimal transport plan eps = 0.01 lambda = 1 γ = sinkhorn_unbalanced(μ, ν, C, lambda, lambda, eps) + + # compare with POT γ_pot = POT.sinkhorn_unbalanced(μ, ν, C, eps, lambda; stopThr=1e-9) + @test γ_pot ≈ γ - # compute optimal transport map - @test norm(γ - γ_pot, Inf) < 1e-9 + # compute optimal transport cost + c = sinkhorn_unbalanced2(μ, ν, C, lambda, lambda, eps; maxiter=5_000) - c = sinkhorn_unbalanced2(μ, ν, C, lambda, lambda, eps; max_iter=5_000) + # compare with POT c_pot = POT.sinkhorn_unbalanced2( μ, ν, C, eps, lambda; numItermax=5_000, stopThr=1e-9 )[1] + @test c_pot ≈ c - @test c ≈ c_pot atol = 1e-9 - - # ensure that provided map is used + # ensure that provided plan is used c2 = sinkhorn_unbalanced2(similar(μ), similar(ν), C, rand(), rand(), rand(); plan=γ) @test c2 ≈ c end + + @testset "proxdiv operators" begin + μ = fill(1 / N, M) + ν = fill(1 / N, N) + C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + + # compute optimal transport plan and cost with real-valued parameters + eps = 0.01 + lambda1 = 0.4 + lambda2 = 0.5 + γ = sinkhorn_unbalanced(μ, ν, C, lambda1, lambda2, eps) + c = sinkhorn_unbalanced2(μ, ν, C, lambda1, lambda2, eps) + @test c ≈ dot(γ, C) + + # compute optimal transport plan and cost with manual "proxdiv" functions + proxdivF1!(s, p, ε) = (s .= s .^ (ε / (ε + 0.4)) .* p .^ (0.4 / (ε + 0.4)) ./ s) + proxdivF2!(s, p, ε) = (s .= s .^ (ε / (ε + 0.5)) .* p .^ (0.5 / (ε + 0.5)) ./ s) + γ_proxdiv = sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, eps) + c_proxdiv = sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, eps) + @test γ_proxdiv ≈ γ + @test c_proxdiv ≈ c + end + + @testset "consistency with balanced OT" begin + μ = fill(1 / M, M) + ν = fill(1 / N, N) + C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + + # compute optimal transport plan and cost with manual "proxdiv" functions for + # balanced OT + ε = 0.01 + proxdivF!(s, p, ε) = (s .= p ./ s) + γ = sinkhorn_unbalanced(μ, ν, C, proxdivF!, proxdivF!, ε) + c = sinkhorn_unbalanced2(μ, ν, C, proxdivF!, proxdivF!, ε) + @test c ≈ dot(γ, C) + + # compute optimal transport plan and cost for balanced OT + γ_balanced = sinkhorn(μ, ν, C, ε) + c_balanced = sinkhorn2(μ, ν, C, ε) + @test γ_balanced ≈ γ rtol = 1e-4 + @test c_balanced ≈ c rtol = 1e-4 + end + + @testset "deprecations" begin + μ = fill(1 / N, M) + ν = fill(1 / N, N) + C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) + + # compute optimal transport plan and cost with real-valued parameters + ε = 0.01 + λ1 = 0.4 + λ2 = 0.5 + γ = sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε) + c = sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε) + + # compute optimal transport plan and cost with manual "proxdiv" functions + # as keyword arguments + proxdivF1(s, p) = s .^ (0.01 / (0.01 + 0.4)) .* p .^ (0.4 / (0.01 + 0.4)) ./ s + proxdivF2(s, p) = s .^ (0.01 / (0.01 + 0.5)) .* p .^ (0.5 / (0.01 + 0.5)) ./ s + γ_proxdiv = @test_deprecated sinkhorn_unbalanced( + μ, ν, C, rand(), rand(), ε; proxdiv_F1=proxdivF1, proxdiv_F2=proxdivF2 + ) + c_proxdiv = @test_deprecated sinkhorn_unbalanced2( + μ, ν, C, rand(), rand(), ε; proxdiv_F1=proxdivF1, proxdiv_F2=proxdivF2 + ) + @test γ_proxdiv ≈ γ + @test c_proxdiv ≈ c + + # deprecated `tol` keyword argument + γ = sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; atol=1e-7) + c = sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; atol=1e-7) + γ_tol = @test_deprecated sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; tol=1e-7) + c_tol = @test_deprecated sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; tol=1e-7) + @test γ_tol == γ + @test c_tol == c + + # deprecated `max_iter` keyword argument + γ = sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; maxiter=50) + c = sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; maxiter=50) + γ_max_iter = @test_deprecated sinkhorn_unbalanced(μ, ν, C, λ1, λ2, ε; max_iter=50) + c_max_iter = @test_deprecated sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; max_iter=50) + @test γ_max_iter == γ + @test c_max_iter == c + end end @testset "stabilized sinkhorn" begin