
Commit 3b6f51d

Return Samples type instead of a matrix (#60)
* sample type change
* update_evidence! and update_temperature

1 parent 77d41d7 · commit 3b6f51d

8 files changed: +69 -8 lines

Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "TensorInference"
 uuid = "c2297e78-99bd-40ad-871d-f50e56b81012"
 authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
-version = "0.2.1"
+version = "0.3.0"

 [deps]
 Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

src/Core.jl (+17)

@@ -59,6 +59,23 @@ struct TensorNetworkModel{LT, ET, MT <: AbstractArray}
     mars::Vector{Vector{LT}}
 end

+"""
+$TYPEDSIGNATURES
+
+Update the evidence of a tensor network model without changing the set of observed variables.
+
+### Arguments
+- `tnet` is the [`TensorNetworkModel`](@ref) instance.
+- `evidence` is the new evidence; its keys must be a subset of the existing evidence.
+"""
+function update_evidence!(tnet::TensorNetworkModel, evidence::Dict)
+    for (k, v) in evidence
+        haskey(tnet.evidence, k) || error("`update_evidence!` can only update observed variables!")
+        tnet.evidence[k] = v
+    end
+    return tnet
+end
+
 function Base.show(io::IO, tn::TensorNetworkModel)
     open = getiyv(tn.code)
     variables = join([string_var(var, open, tn.evidence) for var in tn.vars], ", ")
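
For orientation, a minimal usage sketch of the new function (hypothetical: `model` stands for any existing `UAIModel`, and observing variable `3` with value `1` is an arbitrary choice; the pattern mirrors the test added in test/mar.jl below):

using TensorInference
# build the network once, with variable 3 observed as 1
tnet = TensorNetworkModel(model; evidence=Dict(3=>1))
marginals(tnet)                     # marginals conditioned on 3 => 1
# change only the observed value; the contraction order is reused
update_evidence!(tnet, Dict(3=>0))
marginals(tnet)                     # marginals conditioned on 3 => 0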

src/TensorInference.jl (+1 -1)

@@ -23,7 +23,7 @@ export problem_from_artifact, ArtifactProblemSpec
 export read_model, UAIModel, read_evidence, read_solution, read_queryvars, dataset_from_artifact

 # marginals
-export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals
+export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals, update_evidence!

 # MAP
 export most_probable_config, maximum_logp

src/generictensornetworks.jl (+25)

@@ -1,5 +1,8 @@
 using .GenericTensorNetworks: generate_tensors, GraphProblem, flavors, labels

+# update models
+export update_temperature
+
 """
 $TYPEDSIGNATURES
@@ -20,6 +23,24 @@ function TensorInference.TensorNetworkModel(problem::GraphProblem, β::Real; evi
     factors = [Factor((ix...,), t) for (ix, t) in zip(ixs, tensors)]
     return TensorNetworkModel(lbs, fill(nflavors, length(lbs)), factors; openvars=iy, evidence, optimizer, simplifier, mars)
 end
+
+"""
+$TYPEDSIGNATURES
+
+Update the temperature of a tensor network model.
+The program regenerates the tensors from the problem without re-optimizing the contraction order.
+
+### Arguments
+- `tnet` is the [`TensorNetworkModel`](@ref) instance.
+- `problem` is the target constraint satisfiability problem.
+- `β` is the inverse temperature.
+"""
+function update_temperature(tnet::TensorNetworkModel, problem::GraphProblem, β::Real)
+    tensors = generate_tensors(exp(β), problem)
+    alltensors = [tnet.tensors[1:end-length(tensors)]..., tensors...]
+    return TensorNetworkModel(tnet.vars, tnet.code, alltensors, tnet.evidence, tnet.mars)
+end
+
 function TensorInference.MMAPModel(problem::GraphProblem, β::Real;
         queryvars,
         evidence = Dict{labeltype(problem.code), Int}(),
@@ -37,6 +58,10 @@ function TensorInference.MMAPModel(problem::GraphProblem, β::Real;
         optimizer, simplifier,
         marginalize_optimizer, marginalize_simplifier)
 end
+function update_temperature(tnet::MMAPModel, problem::GraphProblem, β::Real)
+    error("We haven't had time to implement setting temperatures for `MMAPModel`.
+        It is about one or two hours of work. If you need it, please file an issue to let us know: https://github.com/TensorBFS/TensorInference.jl/issues")
+end

 @info "`TensorInference` loaded `GenericTensorNetworks` extension successfully,
 `TensorNetworkModel` and `MMAPModel` can be used for converting a `GraphProblem` to a probabilistic model now."
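
A hedged usage sketch of the new extension function (the independent-set instance is illustrative, following the GenericTensorNetworks examples; any `GraphProblem` works the same way, and the added test below exercises the same pattern with the suite's own `problem`):

using GenericTensorNetworks, TensorInference, Graphs
problem = IndependentSet(smallgraph(:petersen))   # illustrative GraphProblem
tnet = TensorNetworkModel(problem, 1.0)           # optimize the contraction order once, at β = 1.0
tnet = update_temperature(tnet, problem, 2.0)     # regenerate tensors at β = 2.0, reusing the order
probability(tnet)[]                               # partition-function value at the new temperature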

src/sampling.jl (+6 -4)

@@ -9,7 +9,7 @@ The sampled configurations are stored in `samples`, which is a vector of vectors.
 `labels` is a vector of variable names for labeling configurations.
 The `setmask` is a boolean indicator denoting whether the sampling process of a variable is complete.
 """
-struct Samples{L}
+struct Samples{L} <: AbstractVector{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
     samples::Matrix{Int} # size is nvars × nsample
     labels::Vector{L}
     setmask::BitVector
@@ -22,7 +22,9 @@ function setmask!(samples::Samples, eliminated_variables)
     end
     return samples
 end
-
+Base.getindex(s::Samples, i::Int) = view(s.samples, :, i)
+Base.length(s::Samples) = size(s.samples, 2)
+Base.size(s::Samples) = (size(s.samples, 2),)
 idx4labels(totalset, labels)::Vector{Int} = map(v->findfirst(==(v), totalset), labels)

 """
@@ -99,7 +101,7 @@ Returns a vector of vectors, each element being a configuration defined on `get_
 * `tn` is the tensor network model.
 * `n` is the number of samples to be returned.
 """
-function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::AbstractMatrix{Int}
+function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Samples
     # generate tropical tensors with its elements being log(p).
     xs = adapt_tensors(tn; usecuda, rescale = false)
     # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
@@ -125,7 +127,7 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::AbstractMatrix
         idx = findfirst(==(k), labels)
         samples.samples[idx, :] .= v
     end
-    return samples.samples
+    return samples
 end

 function generate_samples(se::SlicedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
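
To make the behavioral change concrete, a short sketch (assuming `tnet` is an already-constructed `TensorNetworkModel`): `sample` previously returned the raw `nvars × nsample` matrix and callers indexed it as `samples[var, i]`; the returned `Samples` object now behaves as a vector of per-sample column views.

samples = sample(tnet, 1000)            # a `Samples` object, no longer a Matrix{Int}
length(samples)                         # number of samples, here 1000
samples[1]                              # view of the first configuration (a column of samples.samples)
samples[1][3]                           # value of variable 3 in the first sample
count(s -> s[3] == 1, samples) / length(samples)   # empirical marginal of variable 3 taking value 1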

test/generictensornetworks.jl (+8)

@@ -12,6 +12,14 @@ using GenericTensorNetworks, TensorInference
     mars2 = TensorInference.normalize!(GenericTensorNetworks.solve(problem2, PartitionFunction(β)), 1)
     @test mars ≈ mars2

+    # update temperature
+    β2 = 3.0
+    model = update_temperature(model, problem, β2)
+    pa = probability(model)[]
+    model2 = TensorNetworkModel(problem, β2)
+    pb = probability(model2)[]
+    @test pa ≈ pb
+
     # mmap
     model = MMAPModel(problem, β; queryvars=[1,4])
     logp, config = most_probable_config(model)

test/mar.jl (+9)

@@ -122,4 +122,13 @@ end
     tnet34 = TensorNetworkModel(model; openvars=[3,4])
     @test mars[1] ≈ probability(tnet23)
     @test mars[2] ≈ probability(tnet34)
+
+    tnet1 = TensorNetworkModel(model; mars=[[2, 3], [3, 4]], evidence=Dict(3=>1))
+    tnet2 = TensorNetworkModel(model; mars=[[2, 3], [3, 4]], evidence=Dict(3=>0))
+    mars1 = marginals(tnet1)
+    mars2 = marginals(tnet2)
+    update_evidence!(tnet1, Dict(3=>0))
+    mars1b = marginals(tnet1)
+    @test !(mars1 ≈ mars2)
+    @test mars1b ≈ mars2
 end

test/sampling.jl (+2 -2)

@@ -50,13 +50,13 @@ using TensorInference, Test
     tnet = TensorNetworkModel(model)
     samples = sample(tnet, n)
    mars = getindex.(marginals(tnet), 2)
-    mars_sample = [count(i->samples[k, i]==(1), axes(samples, 2)) for k=1:8] ./ n
+    mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
     @test isapprox(mars, mars_sample, atol=0.05)

     # fix the evidence
     tnet = TensorNetworkModel(model, optimizer=TreeSA(), evidence=Dict(7=>1))
     samples = sample(tnet, n)
     mars = getindex.(marginals(tnet), 1)
-    mars_sample = [count(i->samples[k, i]==(0), axes(samples, 2)) for k=1:8] ./ n
+    mars_sample = [count(s->s[k]==(0), samples) for k=1:8] ./ n
     @test isapprox([mars[1:6]..., mars[8]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
 end
