Skip to content

Commit ceb747d

Browse files
committed
Add mm_unbalanced function
1 parent a3c1d24 commit ceb747d

File tree

3 files changed

+47
-15
lines changed

3 files changed

+47
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PythonOT"
22
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
33
authors = ["David Widmann"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

src/PythonOT.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ export emd,
1212
barycenter_unbalanced,
1313
sinkhorn_unbalanced,
1414
sinkhorn_unbalanced2,
15-
empirical_sinkhorn_divergence
15+
empirical_sinkhorn_divergence,
16+
mm_unbalanced
1617

1718
const pot = PyCall.PyNULL()
1819

src/lib.jl

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -312,11 +312,11 @@ julia> ν = [0.0, 1.0];
312312
313313
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
314314
315-
julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
315+
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
316316
3×2 Matrix{Float64}:
317-
0.0 0.499964
318-
0.0 0.200188
319-
0.0 0.29983
317+
0.0 0.5
318+
0.0 0.2002
319+
0.0 0.2998
320320
```
321321
322322
It is possible to provide multiple target marginals as columns of a matrix. In this case the
@@ -325,10 +325,10 @@ optimal transport costs are returned:
325325
```jldoctest sinkhorn_unbalanced
326326
julia> ν = [0.0 0.5; 1.0 0.5];
327327
328-
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6)
328+
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
329329
2-element Vector{Float64}:
330-
0.949709
331-
0.449411
330+
0.9497
331+
0.4494
332332
```
333333
334334
See also: [`sinkhorn_unbalanced2`](@ref)
@@ -371,20 +371,19 @@ julia> ν = [0.0, 1.0];
371371
372372
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
373373
374-
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
375-
1-element Vector{Float64}:
376-
0.949709
374+
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
375+
0.9497
377376
```
378377
379378
It is possible to provide multiple target marginals as columns of a matrix:
380379
381380
```jldoctest sinkhorn_unbalanced2
382381
julia> ν = [0.0 0.5; 1.0 0.5];
383382
384-
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
383+
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
385384
2-element Vector{Float64}:
386-
0.949709
387-
0.449411
385+
0.9497
386+
0.4494
388387
```
389388
390389
See also: [`sinkhorn_unbalanced`](@ref)
@@ -516,3 +515,35 @@ Python function.
516515
function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...)
517516
return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...)
518517
end
518+
519+
"""
520+
mm_unbalanced(a, b, M, reg_m; kwargs...)
521+
522+
Solve the unbalanced optimal transport problem and return the OT plan.
523+
The function solves the following optimization problem:
524+
525+
```math
526+
W = \\min_\\gamma \\quad \\langle \\gamma, \\mathbf{M} \\rangle_F +
527+
\\mathrm{reg_{m1}} \\cdot \\mathrm{div}(\\gamma \\mathbf{1}, \\mathbf{a}) +
528+
\\mathrm{reg_{m2}} \\cdot \\mathrm{div}(\\gamma^T \\mathbf{1}, \\mathbf{b}) +
529+
\\mathrm{reg} \\cdot \\mathrm{div}(\\gamma, \\mathbf{c}) \\\\
530+
531+
s.t.
532+
\\gamma \\geq 0
533+
```
534+
535+
where:
536+
537+
- ``\\mathbf{M}`` is the (``dim_a``, ``dim_b``) metric cost matrix.
538+
- ``\\mathbf{a}`` and ``\\mathbf{b}`` are source and target unbalanced distributions.
539+
- ``\\mathbf{c}`` is a reference distribution for the regularization.
540+
- ``\\mathrm{reg_m}`` is the marginal relaxation term
541+
542+
This function is a wrapper of the function
543+
[`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the
544+
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
545+
Python function.
546+
"""
547+
function mm_unbalanced(a, b, M, reg_m; kwargs...)
548+
return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)
549+
end

0 commit comments

Comments
 (0)