@@ -312,11 +312,11 @@ julia> ν = [0.0, 1.0];
312
312
313
313
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
314
314
315
- julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
315
+ julia> round.( sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4 )
316
316
3×2 Matrix{Float64}:
317
- 0.0 0.499964
318
- 0.0 0.200188
319
- 0.0 0.29983
317
+ 0.0 0.5
318
+ 0.0 0.2002
319
+ 0.0 0.2998
320
320
```
321
321
322
322
It is possible to provide multiple target marginals as columns of a matrix. In this case the
@@ -325,10 +325,10 @@ optimal transport costs are returned:
325
325
```jldoctest sinkhorn_unbalanced
326
326
julia> ν = [0.0 0.5; 1.0 0.5];
327
327
328
- julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6 )
328
+ julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4 )
329
329
2-element Vector{Float64}:
330
- 0.949709
331
- 0.449411
330
+ 0.9497
331
+ 0.4494
332
332
```
333
333
334
334
See also: [`sinkhorn_unbalanced2`](@ref)
@@ -371,20 +371,19 @@ julia> ν = [0.0, 1.0];
371
371
372
372
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
373
373
374
- julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
375
- 1-element Vector{Float64}:
376
- 0.949709
374
+ julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
375
+ 0.9497
377
376
```
378
377
379
378
It is possible to provide multiple target marginals as columns of a matrix:
380
379
381
380
```jldoctest sinkhorn_unbalanced2
382
381
julia> ν = [0.0 0.5; 1.0 0.5];
383
382
384
- julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6 )
383
+ julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4 )
385
384
2-element Vector{Float64}:
386
- 0.949709
387
- 0.449411
385
+ 0.9497
386
+ 0.4494
388
387
```
389
388
390
389
See also: [`sinkhorn_unbalanced`](@ref)
@@ -516,3 +515,35 @@ Python function.
516
515
function entropic_gromov_wasserstein (μ, ν, Cμ, Cν, ε, loss= " square_loss" ; kwargs... )
517
516
return pot. gromov. entropic_gromov_wasserstein (Cμ, Cν, μ, ν, loss, ε; kwargs... )
518
517
end
518
+
519
+ """
520
+ mm_unbalanced(a, b, M, reg_m; kwargs...)
521
+
522
+ Solve the unbalanced optimal transport problem and return the OT plan.
523
+ The function solves the following optimization problem:
524
+
525
+ ```math
526
+ W = \\ min_\\ gamma \\ quad \\ langle \\ gamma, \\ mathbf{M} \\ rangle_F +
527
+ \\ mathrm{reg_{m1}} \\ cdot \\ mathrm{div}(\\ gamma \\ mathbf{1}, \\ mathbf{a}) +
528
+ \\ mathrm{reg_{m2}} \\ cdot \\ mathrm{div}(\\ gamma^T \\ mathbf{1}, \\ mathbf{b}) +
529
+ \\ mathrm{reg} \\ cdot \\ mathrm{div}(\\ gamma, \\ mathbf{c}) \\\\
530
+
531
+ s.t.
532
+ \\ gamma \\ geq 0
533
+ ```
534
+
535
+ where:
536
+
537
+ - ``\\ mathbf{M}`` is the (``dim_a``, ``dim_b``) metric cost matrix.
538
+ - ``\\ mathbf{a}`` and ``\\ mathbf{b}`` are source and target unbalanced distributions.
539
+ - ``\\ mathbf{c}`` is a reference distribution for the regularization.
540
+ - ``\\ mathrm{reg_m}`` is the marginal relaxation term
541
+
542
+ This function is a wrapper of the function
543
+ [`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the
544
+ Python Optimal Transport package. Keyword arguments are listed in the documentation of the
545
+ Python function.
546
+ """
547
+ function mm_unbalanced (a, b, M, reg_m; kwargs... )
548
+ return pot. unbalanced. mm_unbalanced (a, b, M, reg_m; kwargs... )
549
+ end
0 commit comments