Skip to content

Commit dce4994

Browse files
committed
wip parametric tree solve
1 parent b4a6b50 commit dce4994

File tree

7 files changed

+336
-18
lines changed

7 files changed

+336
-18
lines changed

src/CliqueStateMachine/services/CliqueStateMachine.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ function updateFromSubgraph_StateMachine(csmc::CliqStateMachineContainer)
960960
logCSM(
961961
csmc,
962962
"CSM-5 Clique $(csmc.cliq.id) finished, solveKey=$(csmc.solveKey)";
963-
loglevel = Logging.Info,
963+
loglevel = Logging.Debug,
964964
)
965965
return IncrementalInference.exitStateMachine
966966
end

src/Factors/GenericFunctions.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,82 @@ function (cf::CalcFactor{<:ManifoldFactor})(X, p, q)
9999
return distanceTangent2Point(cf.factor.M, X, p, q)
100100
end
101101

102+
103+
## ======================================================================================
104+
## adjoint factor - adjoint action applied to the measurement
105+
## ======================================================================================
106+
function Ad(::Union{typeof(SpecialEuclidean(2)), typeof(SpecialEuclidean(3))}, p, X)
107+
t = p.x[1]
108+
R = p.x[2]
109+
v = X.x[1]
110+
Ω = X.x[2]
111+
ArrayPartition(-R*Ω*R'*t + R*v, R*Ω*R')
112+
end
113+
114+
function Ad(::typeof(SpecialEuclidean(3)), p)
115+
t = p.x[1]
116+
R = p.x[2]
117+
vcat(
118+
hcat(R, skew(t)*R),
119+
hcat(zero(SMatrix{3,3,Float64}), R)
120+
)
121+
end
122+
123+
function Ad(::typeof(SpecialEuclidean(2)), p)
124+
t = p.x[1]
125+
R = p.x[2]
126+
vcat(
127+
hcat(R, -SA[0 -1; 1 0]*t),
128+
SA[0 0 1]
129+
)
130+
end
131+
132+
struct AdFactor{F <: AbstractManifoldMinimize} <: AbstractManifoldMinimize
133+
factor::F
134+
end
135+
136+
function (cf::CalcFactor{<:AdFactor})(Xϵ, p, q)
137+
# M = getManifold(cf.factor)
138+
# p,q ∈ M
139+
# Xϵ ∈ TϵM
140+
# ϵ = identity_element(M)
141+
# transform measurement from TϵM to TpM (global to local coordinates)
142+
# Adₚ⁻¹ = AdjointMatrix(M, p)⁻¹ = AdjointMatrix(M, p⁻¹)
143+
# Xp = Adₚ⁻¹ * Xϵᵛ
144+
# ad = Ad(M, inv(M, p))
145+
# Xp = Ad(M, inv(M, p), Xϵ)
146+
# Xp = adjoint_action(M, inv(M, p), Xϵ)
147+
#TODO is vector transport supposed to be the same?
148+
# Xp = vector_transport_to(M, ϵ, Xϵ, p)
149+
150+
# Transform measurement covariance
151+
# ᵉΣₚ = Adₚ ᵖΣₚ Adₚᵀ
152+
#TODO test if transforming sqrt_iΣ is the same as Σ
153+
# Σ = ad * inv(cf.sqrt_iΣ^2) * ad'
154+
# sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), sqrt(inv(Σ)))
155+
# sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), ad * cf.sqrt_iΣ * ad')
156+
Xp =
157+
158+
child_cf = CalcFactorResidual(
159+
cf.faclbl,
160+
cf.factor.factor,
161+
cf.varOrder,
162+
cf.varOrderIdxs,
163+
cf.meas,
164+
cf.sqrt_iΣ,
165+
cf.cache,
166+
)
167+
return child_cf(Xp, p, q)
168+
end
169+
170+
getMeasurementParametric(f::AdFactor) = getMeasurementParametric(f.factor)
171+
172+
getManifold(f::AdFactor) = getManifold(f.factor)
173+
function getSample(cf::CalcFactor{<:AdFactor})
174+
M = getManifold(cf)
175+
return sampleTangent(M, cf.factor.factor.Z)
176+
end
177+
102178
## ======================================================================================
103179
## ManifoldPrior
104180
## ======================================================================================

src/manifolds/services/ManifoldSampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ function getSample(cf::CalcFactor{<:AbstractPrior})
132132
end
133133

134134
function getSample(cf::CalcFactor{<:AbstractRelative})
135-
M =getManifold(cf)
135+
M = getManifold(cf)
136136
if hasfield(typeof(cf.factor), :Z)
137137
X = sampleTangent(M, cf.factor.Z)
138138
else

src/parametric/services/ParametricCSMFunctions.jl

Lines changed: 184 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Notes
66
- Parametric state machine function nr. 3
77
"""
8-
function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
8+
function solveUp_ParametricStateMachine_Old(csmc::CliqStateMachineContainer)
99
infocsm(csmc, "Par-3, Solving Up")
1010

1111
setCliqueDrawColor!(csmc.cliq, "red")
@@ -96,6 +96,145 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
9696
return waitForDown_StateMachine
9797
end
9898

99+
# solve relatives ignoring any priors keeping `from` at ϵ
100+
# if clique has priors : solve to get a prior on `from`
101+
# send messages as factors or just the beliefs? for now factors
102+
function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
103+
infocsm(csmc, "Par-3, Solving Up")
104+
105+
setCliqueDrawColor!(csmc.cliq, "red")
106+
# csmc.drawtree ? drawTree(csmc.tree, show=false, filepath=joinpath(getSolverParams(csmc.dfg).logpath,"bt.pdf")) : nothing
107+
108+
msgfcts = Symbol[]
109+
110+
for (idx, upmsg) in getMessageBuffer(csmc.cliq).upRx #get cached messages taken from children saved in this clique
111+
child_factors = addMsgFactors_Parametric!(csmc.cliqSubFg, upmsg, UpwardPass)
112+
append!(msgfcts, getLabel.(child_factors)) # addMsgFactors_Parametric!
113+
end
114+
logCSM(csmc, "length mgsfcts=$(length(msgfcts))")
115+
infocsm(csmc, "length mgsfcts=$(length(msgfcts))")
116+
117+
# store the cliqSubFg for later debugging
118+
_dbgCSMSaveSubFG(csmc, "fg_beforeupsolve")
119+
120+
subfg = csmc.cliqSubFg
121+
122+
frontals = getCliqFrontalVarIds(csmc.cliq)
123+
separators = getCliqSeparatorVarIds(csmc.cliq)
124+
125+
# if its a root do full solve
126+
if length(getParent(csmc.tree, csmc.cliq)) == 0
127+
# M, vartypeslist, lm_r, Σ = solve_RLM(subfg; is_sparse=false, finiteDiffCovariance=true)
128+
autoinitParametric!(subfg)
129+
M, vartypeslist, lm_r, Σ = solveGraphParametric!(subfg; is_sparse=false, finiteDiffCovariance=true, damping_term_min=1e-18)
130+
131+
else
132+
133+
# select first seperator as constant reference at the identity element
134+
isempty(separators) && @warn "empty separators solving cliq $(csmc.cliq.id.value)" ls(subfg) lsf(subfg)
135+
from = first(separators)
136+
from_v = getVariable(subfg, from)
137+
getSolverData(from_v, :parametric).val[1] = getPointIdentity(getVariableType(from_v))
138+
139+
#TODO handle priors
140+
# Variables that are free to move
141+
free_vars = [frontals; separators[2:end]]
142+
# Solve for the free variables
143+
144+
@assert !isempty(lsf(subfg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(subfg)) lsf=$(lsf(subfg))"
145+
146+
# M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from];)
147+
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from]; finiteDiffCovariance=false, damping_term_min=1e-18)
148+
149+
end
150+
151+
# FIXME check solve convergence
152+
if !true
153+
@error "Par-3, clique $(csmc.cliq.id) failed to converge in upsolve" result
154+
# propagate error to cleanly exit all cliques
155+
putErrorUp(csmc)
156+
if length(getParent(csmc.tree, csmc.cliq)) == 0
157+
putErrorDown(csmc)
158+
return IncrementalInference.exitStateMachine
159+
end
160+
161+
return waitForDown_StateMachine
162+
end
163+
164+
logCSM(csmc, "$(csmc.cliq.id): subfg solve converged sending messages")
165+
166+
# Pack results in massage factors
167+
168+
sigmas = extractMarginalsAP(M, vartypeslist, Σ)
169+
170+
# FIXME fix MsgRelativeType
171+
relative_message_factors = MsgRelativeType();
172+
for (i, to) in enumerate(vartypeslist)
173+
if to in separators
174+
#assume full dim factor
175+
factype = selectFactorType(subfg, from, to)
176+
# make S symetrical
177+
# S = sigmas[i] # FIXME for some reason SMatrix is not invertable even though it is!!!!!!!!
178+
S = Matrix(sigmas[i])# FIXME
179+
S = (S + S') / 2
180+
# @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical"
181+
!all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical")
182+
# @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag"
183+
!all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag")
184+
185+
186+
M_to = getManifold(getVariableType(subfg, to))
187+
ϵ = getPointIdentity(M_to)
188+
μ = vee(M_to, ϵ, log(M_to, ϵ, lm_r[i]))
189+
190+
message_factor = AdFactor(factype(MvNormal(μ, S)))
191+
192+
193+
# logCSM(csmc, "$(csmc.cliq.id): Z=$(getMeasurementParametric(message_factor))"; loglevel = Logging.Warn)
194+
195+
push!(relative_message_factors, (variables=[from, to], likelihood=message_factor))
196+
end
197+
end
198+
199+
# Done with solve delete factors
200+
#TODO confirm, maybe don't delete mesage factors on subgraph, maybe delete if its priors, but not conditionals
201+
# deleteMsgFactors!(csmc.cliqSubFg)
202+
203+
# store the cliqSubFg for later debugging
204+
_dbgCSMSaveSubFG(csmc, "fg_afterupsolve")
205+
206+
# cliqueLikelihood = calculateMarginalCliqueLikelihood(vardict, Σ, varIds, cliqSeparatorVarIds)
207+
208+
#Fill in CliqueLikelihood
209+
beliefMsg = LikelihoodMessage(;
210+
sender = (; id = csmc.cliq.id.value, step = csmc._csm_iter),
211+
status = UPSOLVED,
212+
variableOrder = separators,
213+
# cliqueLikelihood,
214+
jointmsg = _MsgJointLikelihood(;relatives=relative_message_factors),
215+
msgType = ParametricMessage(),
216+
)
217+
218+
# @assert length(separators) <= 2 "TODO length(separators) = $(length(separators)) > 2 in clique $(csmc.cliq.id.value)"
219+
@assert isempty(lsfPriors(csmc.cliqSubFg)) || csmc.cliq.id.value == 1 "TODO priors in clique $(csmc.cliq.id.value)"
220+
# if length(lsfPriors(csmc.cliqSubFg)) > 0 || length(separators) > 2
221+
# for si in cliqSeparatorVarIds
222+
# vnd = getSolverData(getVariable(csmc.cliqSubFg, si), :parametric)
223+
# beliefMsg.belief[si] = TreeBelief(deepcopy(vnd))
224+
# end
225+
# end
226+
227+
for e in getEdgesParent(csmc.tree, csmc.cliq)
228+
logCSM(csmc, "$(csmc.cliq.id): put! on edge $(e)")
229+
getMessageBuffer(csmc.cliq).upTx = deepcopy(beliefMsg)
230+
putBeliefMessageUp!(csmc.tree, e, beliefMsg)
231+
end
232+
233+
return waitForDown_StateMachine
234+
end
235+
236+
global g_n = nothing
237+
99238
"""
100239
$SIGNATURES
101240
@@ -120,6 +259,15 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
120259
logCSM(csmc, "$(csmc.cliq.id): Updating separator $msym from message $(belief.val)")
121260
vnd.val .= belief.val
122261
vnd.bw .= belief.bw
262+
263+
p = belief.val[1]
264+
265+
S = belief.bw
266+
S = (S + S') / 2
267+
vnd.bw .= S
268+
269+
nd = MvNormal(getCoordinates(Main.Pose2, p), S)
270+
addFactor!(csmc.cliqSubFg, [msym], Main.PriorPose2(nd))
123271
end
124272
end
125273
end
@@ -132,23 +280,48 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
132280
#only down solve if its not a root
133281
if length(getParent(csmc.tree, csmc.cliq)) != 0
134282
frontals = getCliqFrontalVarIds(csmc.cliq)
135-
vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals)
283+
# vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals)
136284
#TEMP testing difference
137285
# vardict, result = solveGraphParametric(csmc.cliqSubFg)
138286
# Pack all results in variables
139-
if result.g_converged || result.f_converged
287+
@assert !isempty(lsf(csmc.cliqSubFg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(csmc.cliqSubFg)) lsf=$(lsf(csmc.cliqSubFg))"
288+
289+
# M, vartypeslist, lm_r, Σ = solve_RLM_conditional(csmc.cliqSubFg, frontals; finiteDiffCovariance=false, damping_term_min=1e-18)
290+
M, vartypeslist, lm_r, Σ = solve_RLM(csmc.cliqSubFg; finiteDiffCovariance=false, damping_term_min=1e-18)
291+
sigmas = extractMarginalsAP(M, vartypeslist, Σ)
292+
293+
if true # TODO check for convergence result.g_converged || result.f_converged
140294
logCSM(
141295
csmc,
142296
"$(csmc.cliq.id): subfg optim converged updating variables";
143-
loglevel = Logging.Info,
297+
loglevel = Logging.Debug,
144298
)
145-
for (v, val) in vardict
146-
logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel = Logging.Info)
147-
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
148-
#Update subfg variables
149-
vnd.val[1] = val.val
150-
vnd.bw .= val.cov
299+
for (i, v) in enumerate(vartypeslist)
300+
if v in frontals
301+
# logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug)
302+
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
303+
304+
S = Matrix(sigmas[i])# FIXME
305+
S = (S + S') / 2
306+
# @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical"
307+
!all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical")
308+
# @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag"
309+
!all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag")
310+
311+
312+
#Update subfg variables
313+
vnd.val[1] = lm_r[i]
314+
vnd.bw .= S
315+
end
151316
end
317+
# for (v, val) in vardict
318+
# logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug)
319+
# vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
320+
321+
# #Update subfg variables
322+
# vnd.val[1] = val.val
323+
# vnd.bw .= val.cov
324+
# end
152325
else
153326
@error "Par-5, clique $(csmc.cliq.id) failed to converge in down solve" result
154327
#propagate error to cleanly exit all cliques
@@ -169,7 +342,7 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
169342
for fi in cliqFrontalVarIds
170343
vnd = getSolverData(getVariable(csmc.cliqSubFg, fi), :parametric)
171344
beliefMsg.belief[fi] = TreeBelief(vnd)
172-
logCSM(csmc, "$(csmc.cliq.id): down message $fi : $beliefMsg"; loglevel = Logging.Info)
345+
logCSM(csmc, "$(csmc.cliq.id): down message $fi"; beliefMsg=beliefMsg.belief[fi], loglevel = Logging.Debug)
173346
end
174347

175348
# pass through the frontal variables that were sent from above

0 commit comments

Comments
 (0)