aGrUM 0.20.2 — a C++ library for (probabilistic) graphical models
Source listing: ShaferShenoyMNInference_tpl.h
/**
 *
 *  Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) et Christophe GONZALES(@AMU)
 *  (@AMU) info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of Shafer-Shenoy's propagation for inference in
 * Markov Networks.
 */

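/*
 * Usage sketch (illustrative note added to this listing, not part of the
 * original file): the engine is driven through the generic Markov-net
 * inference API. The calls below (fastPrototype, addEvidence, posterior)
 * are assumed from the usual aGrUM interface:
 *
 *   auto mn = gum::MarkovNet< double >::fastPrototype("A-B;B-C;C-D");
 *   gum::ShaferShenoyMNInference< double > ie(&mn);
 *   ie.addEvidence(mn.idFromName("A"), 0);   // hard evidence on A
 *   ie.makeInference();                      // runs the propagation below
 *   const auto& posterior_C = ie.posterior(mn.idFromName("C"));
 */
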
#ifndef DOXYGEN_SHOULD_SKIP_THIS
#  include <agrum/MN/inference/ShaferShenoyMNInference.h>

#  include <agrum/tools/graphs/algorithms/binaryJoinTreeConverterDefault.h>
#  include <agrum/tools/multidim/instantiation.h>
#  include <agrum/tools/multidim/utils/operators/multiDimCombineAndProjectDefault.h>
#  include <agrum/tools/multidim/utils/operators/multiDimProjection.h>

namespace gum {
  // default constructor
  template < typename GUM_SCALAR >
  INLINE ShaferShenoyMNInference< GUM_SCALAR >::ShaferShenoyMNInference(
     const IMarkovNet< GUM_SCALAR >* MN,
     bool                            use_binary_join_tree) :
      JointTargetedMNInference< GUM_SCALAR >(MN),
      EvidenceMNInference< GUM_SCALAR >(MN),
      use_binary_join_tree__(use_binary_join_tree) {
    // create a default triangulation (the user can change it afterwards)
    triangulation__ = new DefaultTriangulation;

    // for debugging purposes
    GUM_CONSTRUCTOR(ShaferShenoyMNInference);
  }


  // destructor
  template < typename GUM_SCALAR >
  INLINE ShaferShenoyMNInference< GUM_SCALAR >::~ShaferShenoyMNInference() {
    // remove all the potentials created during the last message passing
    for (const auto& pots: created_messages__) {
      for (const auto pot: pots.second)
        delete pot;
    }

    // remove the potentials created after removing the nodes that received
    // hard evidence
    for (const auto& pot: hard_ev_projected_factors__) {
      delete pot.second;
    }

    // remove all the potentials in clique_potentials__ that do not belong
    // to clique_potentials_list__: in this case, those potentials have been
    // created by combination of the corresponding list of potentials in
    // clique_potentials_list__. In other words, the size of this list is
    // strictly greater than 1.
    for (auto pot: clique_potentials__) {
      if (clique_potentials_list__[pot.first].size() > 1) { delete pot.second; }
    }

    // remove all the posteriors computed
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;

    // remove the junction tree and the triangulation algorithm
    if (propagator__ != nullptr) delete propagator__;
    if (junctionTree__ != nullptr) delete junctionTree__;
    delete triangulation__;

    // for debugging purposes
    GUM_DESTRUCTOR(ShaferShenoyMNInference);
  }


  /// set a new triangulation algorithm
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::setTriangulation(
     const Triangulation& new_triangulation) {
    delete triangulation__;
    triangulation__    = new_triangulation.newFactory();
    is_new_jt_needed__ = true;
    this->setOutdatedStructureState_();
  }


  /// returns the current join tree used
  template < typename GUM_SCALAR >
  INLINE const JoinTree* ShaferShenoyMNInference< GUM_SCALAR >::joinTree() {
    createNewJT__();

    return propagator__;
  }

  /// returns the current junction tree
  template < typename GUM_SCALAR >
  INLINE const JunctionTree*
     ShaferShenoyMNInference< GUM_SCALAR >::junctionTree() {
    createNewJT__();

    return junctionTree__;
  }


  /// sets the operator for performing the projections
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::setProjectionFunction__(
     Potential< GUM_SCALAR >* (*proj)(const Potential< GUM_SCALAR >&,
                                      const Set< const DiscreteVariable* >&)) {
    projection_op__ = proj;
  }


  /// sets the operator for performing the combinations
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::setCombinationFunction__(
     Potential< GUM_SCALAR >* (*comb)(const Potential< GUM_SCALAR >&,
                                      const Potential< GUM_SCALAR >&)) {
    combination_op__ = comb;
  }


  /// invalidate all messages, posteriors and created potentials
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::invalidateAllMessages__() {
    // remove all the messages computed
    for (auto& potset: separator_potentials__)
      potset.second.clear();
    for (auto& mess_computed: messages_computed__)
      mess_computed.second = false;

    // remove all the created potentials
    for (const auto& potset: created_messages__)
      for (const auto pot: potset.second)
        delete pot;

    // remove all the posteriors
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;

    // indicate that new messages need be computed
    if (this->isInferenceReady() || this->isInferenceDone())
      this->setOutdatedPotentialsState_();
  }

  /// fired when a new evidence is inserted
  template < typename GUM_SCALAR >
  INLINE void
     ShaferShenoyMNInference< GUM_SCALAR >::onEvidenceAdded_(const NodeId id,
                                                             bool isHardEvidence) {
    // if we have a new hard evidence, this modifies the undigraph over which
    // the join tree is created. This is also the case if id is not a node
    // of the undigraph
    if (isHardEvidence || !reduced_graph__.exists(id))
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_ADDED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed. This necessarily means
        // that the current saved change is an EVIDENCE_ERASED. So if we
        // erased the evidence and added some again, this corresponds to an
        // EVIDENCE_MODIFIED
        evidence_changes__[id] = EvidenceChangeType::EVIDENCE_MODIFIED;
      }
    }
  }


  /// fired when an evidence is removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onEvidenceErased_(
     const NodeId id,
     bool         isHardEvidence) {
    // if we delete a hard evidence, this modifies the undigraph over which
    // the join tree is created.
    if (isHardEvidence)
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_ERASED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
        // been added and is now erased, this is similar to not having created
        // it. If the evidence was only modified, it already existed in the
        // last inference and we should now indicate that it has been removed.
        if (evidence_changes__[id] == EvidenceChangeType::EVIDENCE_ADDED)
          evidence_changes__.erase(id);
        else
          evidence_changes__[id] = EvidenceChangeType::EVIDENCE_ERASED;
      }
    }
  }


  /// fired when all the evidence are erased
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::onAllEvidenceErased_(
     bool has_hard_evidence) {
    if (has_hard_evidence || !this->hardEvidenceNodes().empty())
      is_new_jt_needed__ = true;
    else {
      for (const auto node: this->softEvidenceNodes()) {
        try {
          evidence_changes__.insert(node, EvidenceChangeType::EVIDENCE_ERASED);
        } catch (DuplicateElement&) {
          // here, the evidence change already existed and it is necessarily an
          // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
          // been added and is now erased, this is similar to not having created
          // it. If the evidence was only modified, it already existed in the
          // last inference and we should now indicate that it has been removed.
          if (evidence_changes__[node] == EvidenceChangeType::EVIDENCE_ADDED)
            evidence_changes__.erase(node);
          else
            evidence_changes__[node] = EvidenceChangeType::EVIDENCE_ERASED;
        }
      }
    }
  }


  /// fired when an evidence is changed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onEvidenceChanged_(
     const NodeId id,
     bool         hasChangedSoftHard) {
    if (hasChangedSoftHard)
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_MODIFIED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED. So we should keep this state to indicate that this
        // evidence is new w.r.t. the last inference
      }
    }
  }


  /// fired after a new target is inserted
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onMarginalTargetAdded_(
     const NodeId id) {}


  /// fired before a target is removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onMarginalTargetErased_(
     const NodeId id) {}


  /// fired after a new set target is inserted
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onJointTargetAdded_(
     const NodeSet& set) {}


  /// fired before a set target is removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onJointTargetErased_(
     const NodeSet& set) {}


  /// fired after all the nodes of the MN are added as single targets
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onAllMarginalTargetsAdded_() {
  }


  /// fired before all the single targets are removed
  template < typename GUM_SCALAR >
  INLINE void
     ShaferShenoyMNInference< GUM_SCALAR >::onAllMarginalTargetsErased_() {}

  /// fired after a new Markov net has been assigned to the engine
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onMarkovNetChanged_(
     const IMarkovNet< GUM_SCALAR >* mn) {}

  /// fired before all the joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onAllJointTargetsErased_() {}


  /// fired before all the single and joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::onAllTargetsErased_() {}

  // check whether a new junction tree is really needed for the next inference
  template < typename GUM_SCALAR >
  bool ShaferShenoyMNInference< GUM_SCALAR >::isNewJTNeeded__() const {
    // if we do not have a JT or if is_new_jt_needed__ is set to true, then
    // we know that we need to create a new join tree
    if ((propagator__ == nullptr) || is_new_jt_needed__) return true;

    // if some targets do not belong to the join tree and, consequently,
    // to the undigraph that was used to construct the join tree, then we need
    // to create a new JT. This situation may occur if we constructed the
    // join tree after pruning irrelevant/barren nodes from the MN.
    // However, note that the nodes that received hard evidence do not belong to
    // the graph and, therefore, should not be taken into account
    const auto& hard_ev_nodes = this->hardEvidenceNodes();
    for (const auto node: this->targets()) {
      if (!reduced_graph__.exists(node) && !hard_ev_nodes.exists(node))
        return true;
    }
    for (const auto& joint_target: this->jointTargets()) {
      // here, we need to check that at least one clique contains all the
      // nodes of the joint target.
      bool containing_clique_found = false;
      for (const auto node: joint_target) {
        bool found = true;
        try {
          const NodeSet& clique = propagator__->clique(node_to_clique__[node]);
          for (const auto xnode: joint_target) {
            if (!clique.contains(xnode) && !hard_ev_nodes.exists(xnode)) {
              found = false;
              break;
            }
          }
        } catch (NotFound&) { found = false; }

        if (found) {
          containing_clique_found = true;
          break;
        }
      }

      if (!containing_clique_found) return true;
    }

    // if some new evidence have been added on nodes that do not belong
    // to reduced_graph__, then we potentially have to reconstruct the join tree
    for (const auto& change: evidence_changes__) {
      if ((change.second == EvidenceChangeType::EVIDENCE_ADDED)
          && !reduced_graph__.exists(change.first))
        return true;
    }

    // here, the current JT is exactly what we need for the next inference
    return false;
  }


  /// create a new junction tree as well as its related data structures
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::createNewJT__() {
    // to create the JT, we first create the moral graph of the MN in the
    // following way in order to take into account the barren nodes and the
    // nodes that received evidence:
    // 1/ we create an undirected graph containing only the nodes and no edge
    // 2/ add edges so that set targets are cliques of the moral graph
    // 3/ remove the nodes that received hard evidence (their neighbours are
    //    already linked by edges, which is necessary for inference)
    //
    // At the end of step 3/, we have our moral graph and we can triangulate it
    // to get the new junction tree

    // 1/ create an undirected graph containing only the nodes and no edge
    const auto& mn  = this->MN();
    reduced_graph__ = mn.graph();

    // 2/ if there exists some joint targets, we shall add new edges into the
    // moral graph in order to ensure that there exists a clique containing
    // each joint
    for (const auto& nodeset: this->jointTargets())
      for (const auto n1: nodeset)
        for (const auto n2: nodeset)
          if (n1 < n2) reduced_graph__.addEdge(n1, n2);

    // 3/ remove all the nodes that received hard evidence
    hard_ev_nodes__ = this->hardEvidenceNodes();
    for (const auto node: hard_ev_nodes__) {
      reduced_graph__.eraseNode(node);
    }

    // now, we can compute the new junction tree. To speed-up computations
    // (essentially, those of a distribution phase), we construct from this
    // junction tree a binary join tree
    if (propagator__ != nullptr) delete propagator__;
    if (junctionTree__ != nullptr) delete junctionTree__;

    triangulation__->setGraph(&reduced_graph__, &(this->domainSizes()));
    const JunctionTree& triang_jt = triangulation__->junctionTree();
    if (use_binary_join_tree__) {
      BinaryJoinTreeConverterDefault bjt_converter;
      NodeSet                        emptyset;
      propagator__ = new CliqueGraph(
         bjt_converter.convert(triang_jt, this->domainSizes(), emptyset));
    } else {
      propagator__ = new CliqueGraph(triang_jt);
    }
    junctionTree__ = new CliqueGraph(triang_jt);

    const std::vector< NodeId >& JT_elim_order
       = triangulation__->eliminationOrder();
    Size                size_elim_order = JT_elim_order.size();
    NodeProperty< Idx > elim_order(size_elim_order);
    for (Idx i = Idx(0); i < size_elim_order; ++i)
      elim_order.insert(JT_elim_order[i], i);

    // indicate, for each factor of the Markov network, a clique in propagator__
    // that can contain its potential
    factor_to_clique__.clear();
    // TO REMOVE const UndiGraph& graph = mn.graph();
    for (const auto& kv: mn.factors()) {
      const auto& factor = kv.first;   // kv.second is the Potential()
      // not a true id since smallest_elim_number == size_elim_order
      NodeId first_eliminated_node = 0;
      // (an impossible value: smallest_elim_number lies in [0...size_elim_order-1])
      Idx smallest_elim_number = size_elim_order;
      for (const auto nod: factor) {
        if (reduced_graph__.exists(nod)) {
          if (elim_order[nod] < smallest_elim_number) {
            smallest_elim_number  = elim_order[nod];
            first_eliminated_node = nod;
          }
        }
      }

      if (smallest_elim_number != size_elim_order) {
        // first_eliminated_node contains the first var (node or one of its
        // parents) eliminated => the clique created during its elimination
        // contains node and all of its parents => it can contain the potential
        // assigned to the node in the MN
        factor_to_clique__.insert(
           factor,
           triangulation__->createdJunctionTreeClique(first_eliminated_node));
      }
    }

    // attribute a clique in the join tree to each node in the MN
    // (using a NodeProperty instead of mn.smallestFactorFromNode():
    //  could there be a better choice?)
    node_to_clique__.clear();
    for (const NodeId node: mn.nodes()) {
      if (reduced_graph__.exists(node)) {
        node_to_clique__.insert(
           node,
           factor_to_clique__[mn.smallestFactorFromNode(node)]);
      }
    }

    // indicate for each joint_target a clique that contains it
    joint_target_to_clique__.clear();
    for (const auto& target: this->jointTargets()) {
      // remove from target all the nodes that received hard evidence (since they
      // do not belong to the join tree)
      NodeSet nodeset = target;
      for (const auto node: hard_ev_nodes__)
        if (nodeset.contains(node)) nodeset.erase(node);

      if (!nodeset.empty()) {
        // not a true id since smallest_elim_number == size_elim_order
        NodeId first_eliminated_node = 0;
        // (smallest_elim_number lies in [0...size_elim_order-1])
        Idx smallest_elim_number = size_elim_order;
        for (const auto nod: nodeset) {
          if (reduced_graph__.exists(nod)) {
            if (elim_order[nod] < smallest_elim_number) {
              smallest_elim_number  = elim_order[nod];
              first_eliminated_node = nod;
            }
          }
        }

        if (smallest_elim_number != size_elim_order) {
          // first_eliminated_node contains the first var (node or one of its
          // parents) eliminated => the clique created during its elimination
          // contains node and all of its parents => it can contain the potential
          // assigned to the node in the MN
          factor_to_clique__.insert(
             nodeset,
             triangulation__->createdJunctionTreeClique(first_eliminated_node));
        }
      }
    }

    // compute the roots of propagator__'s connected components
    computeJoinTreeRoots__();

    // remove all the Shafer-Shenoy potentials stored into the cliques
    for (const auto& xpot: clique_potentials__) {
      if (clique_potentials_list__[xpot.first].size() > 1) delete xpot.second;
    }
    clique_potentials_list__.clear();

    // create empty potential lists into the cliques of the joint tree as well
    // as empty lists of evidence
    PotentialSet__ empty_set;
    node_to_soft_evidence__.clear();
    for (const auto clik: *propagator__) {
      clique_potentials_list__.insert(clik, empty_set);
    }

    // remove all the potentials created during the last inference
    for (const auto& potlist: created_messages__)
      for (auto pot: potlist.second)
        delete pot;
    created_messages__.clear();

    // remove all the potentials created to take into account hard evidence
    // during the last inference
    for (auto pot_pair: hard_ev_projected_factors__)
      delete pot_pair.second;
    hard_ev_projected_factors__.clear();

    // remove all the constants created due to projections of factors that were
    // defined over only hard evidence nodes
    //__constants.clear();

    // create empty lists of potentials for the messages and indicate that no
    // message has been computed yet
    separator_potentials__.clear();
    messages_computed__.clear();
    for (const auto& edge: propagator__->edges()) {
      const Arc arc1(edge.first(), edge.second());
      separator_potentials__.insert(arc1, empty_set);
      messages_computed__.insert(arc1, false);
      const Arc arc2(Arc(edge.second(), edge.first()));
      separator_potentials__.insert(arc2, empty_set);
      messages_computed__.insert(arc2, false);
    }

    // remove all the posteriors computed so far
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    target_posteriors__.clear();
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;
    joint_target_posteriors__.clear();


    // put all the factors of the Markov net nodes into the cliques
    // here, beware: all the potentials that are defined over some nodes
    // including hard evidence must be projected so that these nodes are
    // removed from the potential
    const auto& evidence = this->evidence();
    for (const auto& kv: mn.factors()) {
      const auto& factor = kv.first;
      const auto& pot    = *(kv.second);

      NodeSet hard_nodes_in_factor;
      for (const auto xnode: factor) {
        if (hard_ev_nodes__.contains(xnode)) hard_nodes_in_factor.insert(xnode);
      }

      // if hard_nodes_in_factor contains hard evidence nodes, perform a
      // projection and insert the result into the appropriate clique, else
      // insert factor directly into the clique
      if (hard_nodes_in_factor.empty()) {
        clique_potentials_list__[factor_to_clique__[factor]].insert(&pot);
      } else {
        // marginalize out the hard evidence nodes: if the factor is defined
        // only over nodes that received hard evidence, do not consider it
        // as a potential anymore but as a constant
        if (hard_nodes_in_factor.size() != factor.size()) {
          // perform the projection with a combine and project instance
          Set< const DiscreteVariable* > hard_variables;
          PotentialSet__                 marg_factor_set;
          marg_factor_set.insert(&pot);
          for (const auto xnode: hard_nodes_in_factor) {
            marg_factor_set.insert(evidence[xnode]);
            hard_variables.insert(&(mn.variable(xnode)));
          }

          // perform the combination of those potentials and their projection
          MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
             combine_and_project(combination_op__, SSNewMNprojPotential);
          PotentialSet__ new_factor_list
             = combine_and_project.combineAndProject(marg_factor_set,
                                                     hard_variables);

          // there should be only one potential in new_factor_list
          if (new_factor_list.size() != 1) {
            for (auto pot: new_factor_list) {
              if (!marg_factor_set.contains(pot)) delete pot;
            }
            GUM_ERROR(FatalError,
                      "the projection of a potential containing "
                         << "hard evidence is empty!");
          }
          const Potential< GUM_SCALAR >* projected_factor
             = *(new_factor_list.begin());
          clique_potentials_list__[factor_to_clique__[factor]].insert(
             projected_factor);
          hard_ev_projected_factors__.insert(factor, projected_factor);
        }
      }
    }

    // we shall now add all the potentials of the soft evidence
    for (const auto node: this->softEvidenceNodes()) {
      node_to_soft_evidence__.insert(node, evidence[node]);
      clique_potentials_list__[node_to_clique__[node]].insert(evidence[node]);
    }

    // now, in clique_potentials_list__, for each clique, we have the list of
    // potentials that must be combined in order to produce the Shafer-Shenoy's
    // potential stored into the clique. So, perform this combination and
    // store the result in clique_potentials__
    clique_potentials__.clear();
    MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
       combination_op__);
    for (const auto& xpotset: clique_potentials_list__) {
      const auto& potset = xpotset.second;
      if (potset.size() > 0) {
        // here, there will be an entry in clique_potentials__
        // If there is only one element in potset, this element shall be
        // stored into clique_potentials__, else all the elements of potset
        // shall be combined and their result shall be stored
        if (potset.size() == 1) {
          clique_potentials__.insert(xpotset.first, *(potset.cbegin()));
        } else {
          auto joint = fast_combination.combine(potset);
          clique_potentials__.insert(xpotset.first, joint);
        }
      }
    }

    // indicate that the data structures are up to date.
    evidence_changes__.clear();
    is_new_jt_needed__ = false;
  }


  /// prepare the inference structures w.r.t. new targets, soft/hard evidence
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::updateOutdatedStructure_() {
    // check if a new JT is really needed. If so, create it
    if (isNewJTNeeded__()) {
      createNewJT__();
    } else {
      // here, we can answer the next queries without reconstructing all the
      // junction tree. All we need to do is to indicate that we should
      // update the potentials and messages for these queries
      updateOutdatedPotentials_();
    }
  }


  /// invalidate all the messages sent from a given clique
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::diffuseMessageInvalidations__(
     NodeId   from_id,
     NodeId   to_id,
     NodeSet& invalidated_cliques) {
    // invalidate the current clique
    invalidated_cliques.insert(to_id);

    // invalidate the current arc
    const Arc arc(from_id, to_id);
    bool&     message_computed = messages_computed__[arc];
    if (message_computed) {
      message_computed = false;
      separator_potentials__[arc].clear();
      if (created_messages__.exists(arc)) {
        auto& arc_created_potentials = created_messages__[arc];
        for (auto pot: arc_created_potentials)
          delete pot;
        arc_created_potentials.clear();
      }

      // go on with the diffusion
      for (const auto node_id: propagator__->neighbours(to_id)) {
        if (node_id != from_id)
          diffuseMessageInvalidations__(to_id, node_id, invalidated_cliques);
      }
    }
  }


  /// update the potentials stored in the cliques and invalidate outdated
  /// messages
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::updateOutdatedPotentials_() {
    // for each clique, indicate whether the potential stored into
    // clique_ss_potentials__[clique] is the result of a combination. In this
    // case, it has been allocated by the combination and will need to be
    // deallocated if its clique has been invalidated
    const auto& mn = this->MN();

    NodeProperty< bool > to_deallocate(clique_potentials_list__.size());
    for (auto pot_iter = clique_potentials_list__.cbegin();
         pot_iter != clique_potentials_list__.cend();
         ++pot_iter) {
      to_deallocate.insert(pot_iter.key(), (pot_iter.val().size() > 1));
    }

    // compute the set of factors that were projected due to hard evidence and
    // whose hard evidence have changed, so that they need a new projection.
    // By the way, remove these factors since they are no longer needed.
    // Here only the values of the hard evidence can have changed (else a
    // fully new join tree would have been computed).
    // Note also that we know that the factors still contain some variable(s)
    // after the projection (else they should be constants)
    NodeSet        hard_nodes_changed(hard_ev_nodes__.size());
    Set< NodeId >  hard_cliques_changed;
    Set< NodeSet > hard_factors_changed;
    for (const auto node: hard_ev_nodes__) {
      if (evidence_changes__.exists(node)) {
        hard_nodes_changed.insert(node);
        for (const auto& elt: mn.factors()) {
          const auto& chgFactor = elt.first;
          if (chgFactor.contains(node)) {
            if (!hard_factors_changed.contains(chgFactor)) {
              hard_factors_changed.insert(chgFactor);

              const auto&  chgPot    = hard_ev_projected_factors__[chgFactor];
              const NodeId chgClique = factor_to_clique__[chgFactor];
              clique_potentials_list__[chgClique].erase(chgPot);
              delete chgPot;
              hard_ev_projected_factors__.erase(chgFactor);

              if (!hard_cliques_changed.contains(chgClique))
                hard_cliques_changed.insert(chgClique);
            }
          }
        }
      }
    }

    // invalidate all the messages that are no more correct: start from each of
    // the nodes whose soft evidence has changed and perform a diffusion from
    // the clique into which the soft evidence has been entered, indicating that
    // the messages spreading from this clique are now invalid. At the same time,
    // if there were potentials created on the arcs over which the messages were
    // sent, remove them from memory. For all the cliques that received some
    // projected factors that should now be changed, do the same.
    NodeSet invalidated_cliques(propagator__->size());
    for (const auto& pair: evidence_changes__) {
      if (node_to_clique__.exists(pair.first)) {
        const auto clique = node_to_clique__[pair.first];
        invalidated_cliques.insert(clique);
        for (const auto neighbor: propagator__->neighbours(clique)) {
          diffuseMessageInvalidations__(clique, neighbor, invalidated_cliques);
        }
      }
    }

    // now, add to the set of invalidated cliques those that contain projected
    // factors that were changed.
    for (const auto clique: hard_cliques_changed) {
      invalidated_cliques.insert(clique);
      for (const auto neighbor: propagator__->neighbours(clique)) {
        diffuseMessageInvalidations__(clique, neighbor, invalidated_cliques);
      }
    }

    // now that we know the cliques whose set of potentials have been changed,
    // we can discard their corresponding Shafer-Shenoy potential
    for (const auto clique: invalidated_cliques)
      if (clique_potentials__.exists(clique) && to_deallocate[clique])
        delete clique_potentials__[clique];

    // now we shall remove all the posteriors that belong to the
    // invalidated cliques. First, cope only with the nodes that did not
    // receive hard evidence, since the other nodes do not belong to the
    // join tree
    if (!target_posteriors__.empty()) {
      for (auto iter = target_posteriors__.beginSafe();
           iter != target_posteriors__.endSafe();
           ++iter) {
        if (reduced_graph__.exists(iter.key())
            && (invalidated_cliques.exists(node_to_clique__[iter.key()]))) {
          delete iter.val();
          target_posteriors__.erase(iter);
        }
      }

      // now cope with the nodes that received hard evidence
      for (auto iter = target_posteriors__.beginSafe();
           iter != target_posteriors__.endSafe();
           ++iter) {
        if (hard_nodes_changed.contains(iter.key())) {
          delete iter.val();
          target_posteriors__.erase(iter);
        }
      }
    }

    // finally, cope with joint targets
    for (auto iter = joint_target_posteriors__.beginSafe();
         iter != joint_target_posteriors__.endSafe();
         ++iter) {
      if (invalidated_cliques.exists(joint_target_to_clique__[iter.key()])) {
        delete iter.val();
        joint_target_posteriors__.erase(iter);
      }
    }

    // remove all the evidence that were entered into node_to_soft_evidence__
    // and clique_potentials_list__ and add the new soft ones
    for (auto& pot_pair: node_to_soft_evidence__)
      clique_potentials_list__[node_to_clique__[pot_pair.first]].erase(
         pot_pair.second);
    node_to_soft_evidence__.clear();

    const auto& evidence = this->evidence();
    for (const auto node: this->softEvidenceNodes()) {
      node_to_soft_evidence__.insert(node, evidence[node]);
      clique_potentials_list__[node_to_clique__[node]].insert(evidence[node]);
    }


    // Now add the projections of the factors due to newly changed hard evidence:
    // if we are performing updateOutdatedPotentials_, this means that the
    // set of nodes that received hard evidence has not been changed, only
    // their instantiations can have been changed. So, if there is an entry
    // for node in constants__, there will still be such an entry after
    // performing the new projections. Idem for hard_ev_projected_factors__

    // for (const auto node: factors_with_projected_factors_changed) {
    for (const auto& iter_factor: mn.factors()) {
      NodeSet hard_nodes = hard_nodes_changed * iter_factor.first;
      if (hard_nodes.empty()) continue;   // no change in iter_factor

      // perform the projection with a combine and project instance
      PotentialSet__                 marg_factor_set{iter_factor.second};
      Set< const DiscreteVariable* > hard_variables;
      for (const auto xnode: hard_nodes) {
        marg_factor_set.insert(evidence[xnode]);
        hard_variables.insert(&mn.variable(xnode));
      }

      // perform the combination of those potentials and their projection
      MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
         combine_and_project(combination_op__, SSNewMNprojPotential);
      PotentialSet__ new_potentials_list
         = combine_and_project.combineAndProject(marg_factor_set, hard_variables);

      // there should be only one potential in new_potentials_list
      if (new_potentials_list.size() != 1) {
        // remove the potentials created to avoid memory leaks
        for (auto pot: new_potentials_list) {
          if (!marg_factor_set.contains(pot)) delete pot;
        }
        GUM_ERROR(FatalError,
                  "the projection of a potential containing "
                     << "hard evidence is empty!");
      }

      const Potential< GUM_SCALAR >* projected_factor
         = *(new_potentials_list.begin());
      clique_potentials_list__[factor_to_clique__[iter_factor.first]].insert(
         projected_factor);
      hard_ev_projected_factors__.insert(iter_factor.first, projected_factor);
    }

    // here, the list of potentials stored in the invalidated cliques have
    // been updated. So, now, we can combine them to produce the Shafer-Shenoy
    // potential stored into the clique
    MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
       combination_op__);
    for (const auto clique: invalidated_cliques) {
      const auto& potset = clique_potentials_list__[clique];

      if (potset.size() > 0) {
        // here, there will be an entry in clique_potentials__
        // If there is only one element in potset, this element shall be
        // stored into clique_potentials__, else all the elements of potset
        // shall be combined and their result shall be stored
        if (potset.size() == 1) {
          clique_potentials__[clique] = *(potset.cbegin());
        } else {
          auto joint                  = fast_combination.combine(potset);
          clique_potentials__[clique] = joint;
        }
      }
    }

    // update the constants
    /*const auto& hard_evidence = this->hardEvidence();
    for (auto& node_cst: constants__) {
      const Potential< GUM_SCALAR >& cpt = mn.cpt(node_cst.first);
      const auto& variables = cpt.variablesSequence();
      Instantiation inst;
      for (const auto var: variables)
        inst << *var;
      for (const auto var: variables) {
        inst.chgVal(var, hard_evidence[mn.nodeId(*var)]);
      }
      node_cst.second = cpt[inst];
    }*/

    // indicate that all changes have been performed
    evidence_changes__.clear();
  }


  /// compute a root for each connected component of propagator__
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::computeJoinTreeRoots__() {
    const auto& mn = this->MN();

    // get the set of cliques in which we can find the targets and joint_targets
    NodeSet clique_targets;
    for (const auto node: this->targets()) {
      try {
        clique_targets.insert(node_to_clique__[node]);
      } catch (Exception&) {}
    }
    for (const auto& set: this->jointTargets()) {
      try {
        clique_targets.insert(joint_target_to_clique__[set]);
      } catch (Exception&) {}
    }

    // put in a vector these cliques and their size
    std::vector< std::pair< NodeId, Size > > possible_roots(clique_targets.size());
    std::size_t                              i = 0;
    for (const auto clique_id: clique_targets) {
      const auto& clique   = propagator__->clique(clique_id);
      Size        dom_size = 1;
      for (const auto node: clique) {
        dom_size *= mn.variable(node).domainSize();
      }
      possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
      ++i;
    }

    // sort the cliques by increasing domain size
    std::sort(possible_roots.begin(),
              possible_roots.end(),
              [](const std::pair< NodeId, Size >& a,
                 const std::pair< NodeId, Size >& b) -> bool {
                return a.second < b.second;
              });

    // pick up the clique with the smallest size in each connected component
    NodeProperty< bool > marked = propagator__->nodesProperty(false);
    std::function< void(NodeId, NodeId) > diffuse_marks
       = [&marked, &diffuse_marks, this](NodeId node, NodeId from) {
           if (!marked[node]) {
             marked[node] = true;
             for (const auto neigh: propagator__->neighbours(node))
               if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
           }
         };
    roots__.clear();
    for (const auto xclique: possible_roots) {
      NodeId clique = xclique.first;
      if (!marked[clique]) {
        roots__.insert(clique);
        diffuse_marks(clique, clique);
      }
    }
  }


  // performs the collect phase of Shafer-Shenoy's propagation
  template < typename GUM_SCALAR >
  INLINE void
     ShaferShenoyMNInference< GUM_SCALAR >::collectMessage__(NodeId id,
                                                             NodeId from) {
    for (const auto other: propagator__->neighbours(id)) {
      if ((other != from) && !messages_computed__[Arc(other, id)])
        collectMessage__(other, id);
    }

    if ((id != from) && !messages_computed__[Arc(id, from)]) {
      produceMessage__(id, from);
    }
  }

  // remove variables del_vars from the list of potentials pot_list
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* >
     ShaferShenoyMNInference< GUM_SCALAR >::marginalizeOut__(
        Set< const Potential< GUM_SCALAR >* > pot_list,
        Set< const DiscreteVariable* >&       del_vars,
        Set< const DiscreteVariable* >&       kept_vars) {
    // remove the potentials corresponding to barren variables if we want
    // to exploit barren nodes
    /*PotentialSet__ barren_projected_potentials;
    if (barren_nodes_type__ == FindBarrenNodesType::FIND_BARREN_NODES) {
      barren_projected_potentials = removeBarrenVariables__(pot_list, del_vars);
    }*/

    // create a combine and project operator that will perform the
    // marginalization
    MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
       combination_op__,
       projection_op__);
    PotentialSet__ new_pot_list
       = combine_and_project.combineAndProject(pot_list, del_vars);

    // remove all the potentials that were created due to projections of
    // barren nodes and that are not part of the new_pot_list: these
    // potentials were just temporary potentials
    /*for (auto iter = barren_projected_potentials.beginSafe();
         iter != barren_projected_potentials.endSafe();
         ++iter) {
      if (!new_pot_list.exists(*iter)) delete *iter;
    }*/

    // remove all the potentials that have no dimension
    for (auto iter_pot = new_pot_list.beginSafe();
         iter_pot != new_pot_list.endSafe();
         ++iter_pot) {
      if ((*iter_pot)->variablesSequence().size() == 0) {
        // as we have already marginalized out variables that received evidence,
        // it may be the case that, after combining and projecting, some
        // potentials might be empty. In this case, we shall keep their
        // constant and remove them from memory
        // # TODO: keep the constants!
        delete *iter_pot;
        new_pot_list.erase(iter_pot);
      }
    }

    return new_pot_list;
  }

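  /*
   * Informal note added to this listing: combineAndProject multiplies the
   * potentials involving the variables to be eliminated and then sums those
   * variables out. For instance, with pot_list = { phi1(A,B), phi2(B,C) } and
   * del_vars = { B }, the returned set would contain the single potential
   * sum_B phi1(A,B) * phi2(B,C), defined over the kept variables A and C.
   */
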
  // creates the message sent by clique from_id to clique to_id
  template < typename GUM_SCALAR >
  void ShaferShenoyMNInference< GUM_SCALAR >::produceMessage__(NodeId from_id,
                                                               NodeId to_id) {
    // get the potentials of the clique.
    PotentialSet__ pot_list;
    if (clique_potentials__.exists(from_id))
      pot_list.insert(clique_potentials__[from_id]);

    // add the messages sent by adjacent nodes to from_id
    for (const auto other_id: propagator__->neighbours(from_id))
      if (other_id != to_id)
        pot_list += separator_potentials__[Arc(other_id, from_id)];

    // get the set of variables that need be removed from the potentials
    const NodeSet& from_clique = propagator__->clique(from_id);
    const NodeSet& separator   = propagator__->separator(from_id, to_id);
    Set< const DiscreteVariable* > del_vars(from_clique.size());
    Set< const DiscreteVariable* > kept_vars(separator.size());
    const auto&                    mn = this->MN();

    for (const auto node: from_clique) {
      if (!separator.contains(node)) {
        del_vars.insert(&(mn.variable(node)));
      } else {
        kept_vars.insert(&(mn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    PotentialSet__ new_pot_list = marginalizeOut__(pot_list, del_vars, kept_vars);

    /*
      for the moment, remove this test: due to some optimizations, some
      potentials might have all their cells greater than 1.

    // remove all the potentials that are equal to ones (as probability
    // matrix multiplications are tensorial, such potentials are useless)
    for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
         ++iter) {
      const auto pot = *iter;
      if (pot->variablesSequence().size() == 1) {
        bool is_all_ones = true;
        for (Instantiation inst(*pot); !inst.end(); ++inst) {
          if ((*pot)[inst] < one_minus_epsilon__) {
            is_all_ones = false;
            break;
          }
        }
        if (is_all_ones) {
          if (!pot_list.exists(pot)) delete pot;
          new_pot_list.erase(iter);
          continue;
        }
      }
    }
    */

    // if there are still potentials in new_pot_list, combine them to
    // produce the message on the separator
    const Arc arc(from_id, to_id);
    if (!new_pot_list.empty()) {
      if (new_pot_list.size() == 1) {   // there is only one potential
        // in new_pot_list, so this corresponds to our message on
        // the separator
        auto pot                    = *(new_pot_list.begin());
        separator_potentials__[arc] = std::move(new_pot_list);
        if (!pot_list.exists(pot)) {
          if (!created_messages__.exists(arc))
            created_messages__.insert(arc, PotentialSet__());
          created_messages__[arc].insert(pot);
        }
      } else {
        // create the message in the separator
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
           combination_op__);
        auto joint = fast_combination.combine(new_pot_list);
        separator_potentials__[arc].insert(joint);
        if (!created_messages__.exists(arc))
          created_messages__.insert(arc, PotentialSet__());
        created_messages__[arc].insert(joint);

        // remove the temporary messages created in new_pot_list
        for (const auto pot: new_pot_list) {
          if (!pot_list.exists(pot)) { delete pot; }
        }
      }
    }

    messages_computed__[arc] = true;
  }

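  /*
   * For readability (informal note added to this listing): the message built
   * above follows the classical Shafer-Shenoy rule. Writing phi_i for the
   * potential of clique i and S_ij for the separator between cliques i and j,
   * the message sent from i to j is
   *
   *   phi_{i->j} = proj_{S_ij}( phi_i * prod_{k in neighbours(i), k != j} phi_{k->i} )
   *
   * i.e. the combination of the clique potential with all incoming messages
   * except the one coming from j, projected onto the separator.
   */
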
  // performs a whole inference
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMNInference< GUM_SCALAR >::makeInference_() {
    // collect messages for all single targets
    for (const auto node: this->targets()) {
      // perform only collects in the join tree for nodes that have
      // not received hard evidence (those that received hard evidence were
      // not included into the join tree for speed-up reasons)
      if (reduced_graph__.exists(node)) {
        collectMessage__(node_to_clique__[node], node_to_clique__[node]);
      }
    }

    // collect messages for all set targets
    // by parsing joint_target_to_clique__, we ensure that the cliques that
    // are referenced belong to the join tree (even if some of the nodes in
    // their associated joint_target do not belong to reduced_graph__)
    for (const auto set: joint_target_to_clique__)
      collectMessage__(set.second, set.second);
  }


  /// returns a fresh potential equal to P(1st arg, evidence)
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     ShaferShenoyMNInference< GUM_SCALAR >::unnormalizedJointPosterior_(
        NodeId id) {
    const auto& mn = this->MN();

    // hard evidence do not belong to the join tree
    // # TODO: check for sets of inconsistent hard evidence
    if (this->hardEvidenceNodes().contains(id)) {
      return new Potential< GUM_SCALAR >(*(this->evidence()[id]));
    }

    // if we still need to perform some inference task, do it (this should
    // already have been done by makeInference_)
    NodeId clique_of_id = node_to_clique__[id];
    collectMessage__(clique_of_id, clique_of_id);

    // now we just need to create the product of the potentials of the clique
    // containing id with the messages received by this clique and
    // marginalize out all variables except id
    PotentialSet__ pot_list;
    if (clique_potentials__.exists(clique_of_id))
      pot_list.insert(clique_potentials__[clique_of_id]);

    // add the messages sent by adjacent nodes to the target clique
    for (const auto other: propagator__->neighbours(clique_of_id))
      pot_list += separator_potentials__[Arc(other, clique_of_id)];

    // get the set of variables that need be removed from the potentials
    const NodeSet& nodes = propagator__->clique(clique_of_id);
    Set< const DiscreteVariable* > kept_vars{&(mn.variable(id))};
    Set< const DiscreteVariable* > del_vars(nodes.size());
    for (const auto node: nodes) {
      if (node != id) del_vars.insert(&(mn.variable(node)));
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    PotentialSet__ new_pot_list = marginalizeOut__(pot_list, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint = nullptr;

    if (new_pot_list.size() == 1) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if pot already existed, create a copy, so that we can put it into
      // the target_posteriors__ property
      if (pot_list.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // removed just after the else block
        new_pot_list.clear();
      }
    } else {
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
         combination_op__);
      joint = fast_combination.combine(new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (auto pot: new_pot_list)
      if (!pot_list.exists(pot)) delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some hard evidence are not compatible (their joint
    // probability is equal to 0)
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Markov "
                "net are incompatible (their joint proba = 0)");
    }
    return joint;
  }


  /// returns the posterior of a given variable
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     ShaferShenoyMNInference< GUM_SCALAR >::posterior_(NodeId id) {
    // check if we have already computed the posterior
    if (target_posteriors__.exists(id)) { return *(target_posteriors__[id]); }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(id);
    joint->normalize();
    target_posteriors__.insert(id, joint);

    return *joint;
  }


  // returns the unnormalized a posteriori probability of a given set of nodes
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     ShaferShenoyMNInference< GUM_SCALAR >::unnormalizedJointPosterior_(
        const NodeSet& set) {
    // hard evidence do not belong to the join tree, so extract the nodes
    // from targets that are not hard evidence
    NodeSet targets = set, hard_ev_nodes;
    for (const auto node: this->hardEvidenceNodes()) {
      if (targets.contains(node)) {
        targets.erase(node);
        hard_ev_nodes.insert(node);
      }
    }

    // if all the nodes have received hard evidence, then compute the
    // joint posterior directly by multiplying the hard evidence potentials
    const auto& evidence = this->evidence();
    if (targets.empty()) {
      PotentialSet__ pot_list;
      for (const auto node: set) {
        pot_list.insert(evidence[node]);
      }
      if (pot_list.size() == 1) {
        auto pot = new Potential< GUM_SCALAR >(**(pot_list.begin()));
        return pot;
      } else {
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
           combination_op__);
        return fast_combination.combine(pot_list);
      }
    }


    // if we still need to perform some inference task, do it: so, first,
    // determine the clique on which we should perform collect to compute
    // the unnormalized joint posterior of a set of nodes containing "set"
    NodeId clique_of_set;
    try {
      clique_of_set = joint_target_to_clique__[set];
    } catch (NotFound&) {
      // here, the precise set of targets does not belong to the set of targets
      // defined by the user. So we will try to find a clique in the junction
      // tree that contains "targets":

      // 1/ we should check that all the nodes belong to the join tree
      for (const auto node: targets) {
        if (!reduced_graph__.exists(node)) {
          GUM_ERROR(UndefinedElement, node << " is not a target node");
        }
      }

      // 2/ the clique created by the first eliminated node among target is the
      // one we are looking for
      const std::vector< NodeId >& JT_elim_order
         = triangulation__->eliminationOrder();
      NodeProperty< int > elim_order(Size(JT_elim_order.size()));
      for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
           ++i)
        elim_order.insert(JT_elim_order[i], (int)i);
      NodeId first_eliminated_node = *(targets.begin());
      int    elim_number           = elim_order[first_eliminated_node];
      for (const auto node: targets) {
        if (elim_order[node] < elim_number) {
          elim_number           = elim_order[node];
          first_eliminated_node = node;
        }
      }
      clique_of_set
         = triangulation__->createdJunctionTreeClique(first_eliminated_node);

      // 3/ check that clique_of_set contains all the nodes in the target
      const NodeSet& clique_nodes = propagator__->clique(clique_of_set);
      for (const auto node: targets) {
        if (!clique_nodes.contains(node)) {
          GUM_ERROR(UndefinedElement, set << " is not a joint target");
        }
      }

      // add the discovered clique to joint_target_to_clique__
      joint_target_to_clique__.insert(set, clique_of_set);
    }

    // now perform a collect on the clique
    collectMessage__(clique_of_set, clique_of_set);

    // now we just need to create the product of the potentials of the clique
    // containing set with the messages received by this clique and
    // marginalize out all variables except set
    PotentialSet__ pot_list;
    if (clique_potentials__.exists(clique_of_set))
      pot_list.insert(clique_potentials__[clique_of_set]);

    // add the messages sent by adjacent nodes to targetClique
    for (const auto other: propagator__->neighbours(clique_of_set))
      pot_list += separator_potentials__[Arc(other, clique_of_set)];

    // get the set of variables that need be removed from the potentials
    const NodeSet& nodes = propagator__->clique(clique_of_set);
    Set< const DiscreteVariable* > del_vars(nodes.size());
    Set< const DiscreteVariable* > kept_vars(targets.size());
    const auto&                    mn = this->MN();
    for (const auto node: nodes) {
      if (!targets.contains(node)) {
        del_vars.insert(&(mn.variable(node)));
      } else {
        kept_vars.insert(&(mn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    PotentialSet__ new_pot_list = marginalizeOut__(pot_list, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint = nullptr;

    if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if pot already existed, create a copy, so that we can put it into
      // the target_posteriors__ property
      if (pot_list.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // removed just after the next else block
        new_pot_list.clear();
      }
    } else {
      // combine all the potentials in new_pot_list with all the hard evidence
      // of the nodes in set
      PotentialSet__ new_new_pot_list = new_pot_list;
      for (const auto node: hard_ev_nodes) {
        new_new_pot_list.insert(evidence[node]);
      }
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
         combination_op__);
      joint = fast_combination.combine(new_new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (auto pot: new_pot_list)
      if (!pot_list.exists(pot)) delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some hard evidence are not compatible
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Markov "
                "net are incompatible (their joint proba = 0)");
    }

    return joint;
  }


  /// returns the posterior of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     ShaferShenoyMNInference< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
    // check if we have already computed the posterior
    if (joint_target_posteriors__.exists(set)) {
      return *(joint_target_posteriors__[set]);
    }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(set);
    joint->normalize();
    joint_target_posteriors__.insert(set, joint);

    return *joint;
  }


  /// returns the posterior of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     ShaferShenoyMNInference< GUM_SCALAR >::jointPosterior_(
        const NodeSet& wanted_target,
        const NodeSet& declared_target) {
    // check if we have already computed the posterior of wanted_target
    if (joint_target_posteriors__.exists(wanted_target))
      return *(joint_target_posteriors__[wanted_target]);

    // here, we will have to compute the posterior of declared_target and
    // marginalize out all the variables that do not belong to wanted_target

    // check if we have already computed the posterior of declared_target
    if (!joint_target_posteriors__.exists(declared_target)) {
      jointPosterior_(declared_target);
    }

    // marginalize out all the variables that do not belong to wanted_target
    const auto&                    mn = this->MN();
    Set< const DiscreteVariable* > del_vars;
    for (const auto node: declared_target)
      if (!wanted_target.contains(node)) del_vars.insert(&(mn.variable(node)));
    auto pot = new Potential< GUM_SCALAR >(
       joint_target_posteriors__[declared_target]->margSumOut(del_vars));

    // save the result into the cache
    joint_target_posteriors__.insert(wanted_target, pot);

    return *pot;
  }


  template < typename GUM_SCALAR >
  GUM_SCALAR ShaferShenoyMNInference< GUM_SCALAR >::evidenceProbability() {
    // perform inference in each connected component
    this->makeInference();

    // for each connected component, select a variable X and compute the
    // joint probability of X and evidence e. Then marginalize-out X to get
    // p(e) in this connected component. Finally, multiply all the p(e) that
    // we got and the elements in constants__. The result is the probability
    // of evidence

    GUM_SCALAR prob_ev = 1;
    for (const auto root: roots__) {
      // get a node in the clique
      NodeId                   node = *(propagator__->clique(root).begin());
      Potential< GUM_SCALAR >* tmp  = unnormalizedJointPosterior_(node);
      GUM_SCALAR               sum  = 0;
      for (Instantiation iter(*tmp); !iter.end(); ++iter)
        sum += tmp->get(iter);
      prob_ev *= sum;
      delete tmp;
    }

    // for (const auto& projected_cpt: constants__)
    //   prob_ev *= projected_cpt.second;

    return prob_ev;
  }

  template < typename GUM_SCALAR >
  bool ShaferShenoyMNInference< GUM_SCALAR >::isExactJointComputable_(
     const NodeSet& vars) {
    if (JointTargetedMNInference< GUM_SCALAR >::isExactJointComputable_(vars))
      return true;

    this->prepareInference();

    for (const auto& cliknode: this->propagator__->nodes()) {
      GUM_TRACE_VAR(cliknode);
      const auto clique = propagator__->clique(cliknode);
      GUM_TRACE_VAR(clique);
      if (vars == clique) return true;
      GUM_CHECKPOINT;
    }
    GUM_CHECKPOINT;
    return false;
  }

  template < typename GUM_SCALAR >
  NodeSet ShaferShenoyMNInference< GUM_SCALAR >::superForJointComputable_(
     const NodeSet& vars) {
    const auto superset
       = JointTargetedMNInference< GUM_SCALAR >::superForJointComputable_(vars);
    if (!superset.empty()) return superset;

    this->prepareInference();

    for (const auto& cliknode: propagator__->nodes()) {
      const auto clique = propagator__->clique(cliknode);
      if (vars.isProperSubsetOf(clique)) return clique;
    }


    return NodeSet();
  }

}   /* namespace gum */

#endif   // DOXYGEN_SHOULD_SKIP_THIS