ShaferShenoyInference_tpl.h
/**
 *
 *  Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *  info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of Shafer-Shenoy's propagation for inference in
 * Bayesian networks.
 */
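
/*
 * Typical usage (an illustrative sketch only, not part of the library sources;
 * `bn`, `a` and `b` are hypothetical names for a previously built
 * BayesNet< double > and two of its node ids):
 *
 *   gum::ShaferShenoyInference< double > ie(&bn);
 *   ie.addEvidence(a, 1);                      // hard evidence on node a
 *   ie.makeInference();
 *   const auto& posterior = ie.posterior(b);   // posterior of b given the evidence
 *
 * addEvidence, makeInference and posterior are inherited from the generic
 * inference base classes used below.
 */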
#ifndef DOXYGEN_SHOULD_SKIP_THIS
#  include <agrum/BN/inference/ShaferShenoyInference.h>

#  include <agrum/BN/algorithms/BayesBall.h>
#  include <agrum/BN/algorithms/barrenNodesFinder.h>
#  include <agrum/BN/algorithms/dSeparation.h>
#  include <agrum/tools/graphs/algorithms/binaryJoinTreeConverterDefault.h>
#  include <agrum/tools/multidim/instantiation.h>
#  include <agrum/tools/multidim/utils/operators/multiDimCombineAndProjectDefault.h>
#  include <agrum/tools/multidim/utils/operators/multiDimProjection.h>


namespace gum {
  // default constructor
  template < typename GUM_SCALAR >
  INLINE ShaferShenoyInference< GUM_SCALAR >::ShaferShenoyInference(
     const IBayesNet< GUM_SCALAR >* BN,
     FindBarrenNodesType            barren_type,
     bool                           use_binary_join_tree) :
      JointTargetedInference< GUM_SCALAR >(BN),
      EvidenceInference< GUM_SCALAR >(BN),
      use_binary_join_tree__(use_binary_join_tree) {
    // sets the barren nodes finding algorithm
    setFindBarrenNodesType(barren_type);

    // create a default triangulation (the user can change it afterwards)
    triangulation__ = new DefaultTriangulation;

    // for debugging purposes
    GUM_CONSTRUCTOR(ShaferShenoyInference);
  }


  // destructor
  template < typename GUM_SCALAR >
  INLINE ShaferShenoyInference< GUM_SCALAR >::~ShaferShenoyInference() {
    // remove all the potentials created during the last message passing
    for (const auto& pots: created_potentials__)
      for (const auto pot: pots.second)
        delete pot;

    // remove the potentials created after removing the nodes that received
    // hard evidence
    for (const auto& pot: hard_ev_projected_CPTs__)
      delete pot.second;

    // remove all the potentials in clique_ss_potential__ that do not belong
    // to clique_potentials__: in this case, those potentials have been
    // created by combination of the corresponding list of potentials in
    // clique_potentials__. In other words, the size of this list is strictly
    // greater than 1.
    for (auto pot: clique_ss_potential__) {
      if (clique_potentials__[pot.first].size() > 1) delete pot.second;
    }

    // remove all the posteriors computed
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;

    // remove the junction tree and the triangulation algorithm
    if (JT__ != nullptr) delete JT__;
    if (junctionTree__ != nullptr) delete junctionTree__;
    delete triangulation__;

    // for debugging purposes
    GUM_DESTRUCTOR(ShaferShenoyInference);
  }


  /// set a new triangulation algorithm
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::setTriangulation(
     const Triangulation& new_triangulation) {
    delete triangulation__;
    triangulation__    = new_triangulation.newFactory();
    is_new_jt_needed__ = true;
    this->setOutdatedStructureState_();
  }
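
  /*
   * Illustrative sketch (assuming an existing inference engine `ie`): any
   * gum::Triangulation subclass can be substituted for the default one, e.g.
   *
   *   gum::DefaultTriangulation triang;   // or any other Triangulation subclass
   *   ie.setTriangulation(triang);        // the engine stores its own copy (newFactory())
   *
   * A new join tree is then flagged as needed, so the next inference rebuilds it.
   */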

  /// returns the current join tree used
  template < typename GUM_SCALAR >
  INLINE const JoinTree* ShaferShenoyInference< GUM_SCALAR >::joinTree() {
    createNewJT__();

    return JT__;
  }

  /// returns the current junction tree
  template < typename GUM_SCALAR >
  INLINE const JunctionTree* ShaferShenoyInference< GUM_SCALAR >::junctionTree() {
    createNewJT__();

    return junctionTree__;
  }


  /// sets the operator for performing the projections
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::setProjectionFunction__(
     Potential< GUM_SCALAR >* (*proj)(const Potential< GUM_SCALAR >&,
                                      const Set< const DiscreteVariable* >&)) {
    projection_op__ = proj;
  }


  /// sets the operator for performing the combinations
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::setCombinationFunction__(
     Potential< GUM_SCALAR >* (*comb)(const Potential< GUM_SCALAR >&,
                                      const Potential< GUM_SCALAR >&)) {
    combination_op__ = comb;
  }


  /// invalidate all messages, posteriors and created potentials
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::invalidateAllMessages__() {
    // remove all the messages computed
    for (auto& potset: separator_potentials__)
      potset.second.clear();
    for (auto& mess_computed: messages_computed__)
      mess_computed.second = false;

    // remove all the created potentials
    for (const auto& potset: created_potentials__)
      for (const auto pot: potset.second)
        delete pot;

    // remove all the posteriors
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;

    // indicate that new messages need be computed
    if (this->isInferenceReady() || this->isInferenceDone())
      this->setOutdatedPotentialsState_();
  }


  /// sets how we determine barren nodes
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::setFindBarrenNodesType(
     FindBarrenNodesType type) {
    if (type != barren_nodes_type__) {
      // WARNING: if a new type is added here, method createJT__ should
      // certainly be updated as well, in particular its step 2.
      switch (type) {
        case FindBarrenNodesType::FIND_BARREN_NODES:
        case FindBarrenNodesType::FIND_NO_BARREN_NODES:
          break;

        default:
          GUM_ERROR(InvalidArgument,
                    "setFindBarrenNodesType for type "
                       << (unsigned int)type << " is not implemented yet");
      }

      barren_nodes_type__ = type;

      // potentially, we may need to reconstruct a junction tree
      this->setOutdatedStructureState_();
    }
  }


  /// fired when a new evidence is inserted
  template < typename GUM_SCALAR >
  INLINE void
     ShaferShenoyInference< GUM_SCALAR >::onEvidenceAdded_(const NodeId id,
                                                           bool isHardEvidence) {
    // if we have a new hard evidence, this modifies the undigraph over which
    // the join tree is created. This is also the case if id is not a node
    // of the undigraph
    if (isHardEvidence || !graph__.exists(id))
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_ADDED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed. This necessarily means
        // that the current saved change is an EVIDENCE_ERASED. So if we
        // erased the evidence and added some again, this corresponds to an
        // EVIDENCE_MODIFIED
        evidence_changes__[id] = EvidenceChangeType::EVIDENCE_MODIFIED;
      }
    }
  }


  /// fired when an evidence is removed
  template < typename GUM_SCALAR >
  INLINE void
     ShaferShenoyInference< GUM_SCALAR >::onEvidenceErased_(const NodeId id,
                                                            bool isHardEvidence) {
    // if we delete a hard evidence, this modifies the undigraph over which
    // the join tree is created.
    if (isHardEvidence)
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_ERASED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
        // been added and is now erased, this is similar to not having created
        // it. If the evidence was only modified, it already existed in the
        // last inference and we should now indicate that it has been removed.
        if (evidence_changes__[id] == EvidenceChangeType::EVIDENCE_ADDED)
          evidence_changes__.erase(id);
        else
          evidence_changes__[id] = EvidenceChangeType::EVIDENCE_ERASED;
      }
    }
  }


  /// fired when all the evidence are erased
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::onAllEvidenceErased_(
     bool has_hard_evidence) {
    if (has_hard_evidence || !this->hardEvidenceNodes().empty())
      is_new_jt_needed__ = true;
    else {
      for (const auto node: this->softEvidenceNodes()) {
        try {
          evidence_changes__.insert(node, EvidenceChangeType::EVIDENCE_ERASED);
        } catch (DuplicateElement&) {
          // here, the evidence change already existed and it is necessarily an
          // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
          // been added and is now erased, this is similar to not having created
          // it. If the evidence was only modified, it already existed in the
          // last inference and we should now indicate that it has been removed.
          if (evidence_changes__[node] == EvidenceChangeType::EVIDENCE_ADDED)
            evidence_changes__.erase(node);
          else
            evidence_changes__[node] = EvidenceChangeType::EVIDENCE_ERASED;
        }
      }
    }
  }


  /// fired when an evidence is changed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onEvidenceChanged_(
     const NodeId id,
     bool         hasChangedSoftHard) {
    if (hasChangedSoftHard)
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_MODIFIED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED. So we should keep this state to indicate that this
        // evidence is new w.r.t. the last inference
      }
    }
  }


  /// fired after a new target is inserted
  template < typename GUM_SCALAR >
  INLINE void
     ShaferShenoyInference< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId id) {
  }


  /// fired before a target is removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onMarginalTargetErased_(
     const NodeId id) {}


  /// fired after a new set target is inserted
  template < typename GUM_SCALAR >
  INLINE void
     ShaferShenoyInference< GUM_SCALAR >::onJointTargetAdded_(const NodeSet& set) {
  }


  /// fired before a set target is removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onJointTargetErased_(
     const NodeSet& set) {}


  /// fired after all the nodes of the BN are added as single targets
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}


  /// fired before all the single targets are removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onAllMarginalTargetsErased_() {}

  /// fired after a new Bayes net has been assigned to the engine
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onModelChanged_(
     const GraphicalModel* bn) {}

  /// fired before all the joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onAllJointTargetsErased_() {}


  /// fired before all the single and joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onAllTargetsErased_() {}


  // check whether a new junction tree is really needed for the next inference
  template < typename GUM_SCALAR >
  bool ShaferShenoyInference< GUM_SCALAR >::isNewJTNeeded__() const {
    // if we do not have a JT or if new_jt_needed__ is set to true, then
    // we know that we need to create a new join tree
    if ((JT__ == nullptr) || is_new_jt_needed__) return true;

    // if some targets do not belong to the join tree and, consequently,
    // to the undigraph that was used to construct the join tree, then we need
    // to create a new JT. This situation may occur if we constructed the
    // join tree after pruning irrelevant/barren nodes from the BN.
    // However, note that the nodes that received hard evidence do not belong
    // to the graph and, therefore, should not be taken into account
    const auto& hard_ev_nodes = this->hardEvidenceNodes();
    for (const auto node: this->targets()) {
      if (!graph__.exists(node) && !hard_ev_nodes.exists(node)) return true;
    }
    for (const auto& joint_target: this->jointTargets()) {
      // here, we need to check that at least one clique contains all the
      // nodes of the joint target.
      bool containing_clique_found = false;
      for (const auto node: joint_target) {
        bool found = true;
        try {
          const NodeSet& clique = JT__->clique(node_to_clique__[node]);
          for (const auto xnode: joint_target) {
            if (!clique.contains(xnode) && !hard_ev_nodes.exists(xnode)) {
              found = false;
              break;
            }
          }
        } catch (NotFound&) { found = false; }

        if (found) {
          containing_clique_found = true;
          break;
        }
      }

      if (!containing_clique_found) return true;
    }

    // if some new evidence have been added on nodes that do not belong
    // to graph__, then we potentially have to reconstruct the join tree
    for (const auto& change: evidence_changes__) {
      if ((change.second == EvidenceChangeType::EVIDENCE_ADDED)
          && !graph__.exists(change.first))
        return true;
    }

    // here, the current JT is exactly what we need for the next inference
    return false;
  }


  /// create a new junction tree as well as its related data structures
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::createNewJT__() {
    // to create the JT, we first create the moral graph of the BN in the
    // following way in order to take into account the barren nodes and the
    // nodes that received evidence:
    // 1/ we create an undirected graph containing only the nodes and no edge
    // 2/ if we take into account barren nodes, remove them from the graph
    // 3/ add edges so that each node and its parents in the BN form a clique
    // 4/ add edges so that set targets are cliques of the moral graph
    // 5/ remove the nodes that received hard evidence (by step 3/, their
    //    parents are linked by edges, which is necessary for inference)
    //
    // At the end of step 5/, we have our moral graph and we can triangulate it
    // to get the new junction tree
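
    // (illustrative example: with arcs A->C and B->C and hard evidence on C,
    // step 3/ adds the edges A-C, B-C and A-B (moralization); step 5/ then
    // removes C, and the A-B edge created in step 3/ is what preserves the
    // dependence between A and B induced by conditioning on C)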
    // 1/ create an undirected graph containing only the nodes and no edge
    const auto& bn = this->BN();
    graph__.clear();
    for (auto node: bn.dag())
      graph__.addNodeWithId(node);

    // 2/ if we wish to exploit barren nodes, we shall remove them from the
    // BN. To do so, we identify all the nodes that are not targets and have
    // received no evidence and such that their descendants are neither
    // targets nor evidence nodes. Such nodes can be safely discarded from
    // the BN without altering the inference output
    if (barren_nodes_type__ == FindBarrenNodesType::FIND_BARREN_NODES) {
      // identify the barren nodes
      NodeSet target_nodes = this->targets();
      for (const auto& nodeset: this->jointTargets()) {
        target_nodes += nodeset;
      }

      // check that all the nodes are not targets, otherwise, there is no
      // barren node
      if (target_nodes.size() != bn.size()) {
        BarrenNodesFinder finder(&(bn.dag()));
        finder.setTargets(&target_nodes);

        NodeSet evidence_nodes;
        for (const auto& pair: this->evidence()) {
          evidence_nodes.insert(pair.first);
        }
        finder.setEvidence(&evidence_nodes);

        NodeSet barren_nodes = finder.barrenNodes();

        // remove the barren nodes from the moral graph
        for (const auto node: barren_nodes) {
          graph__.eraseNode(node);
        }
      }
    }

    // 3/ add edges so that each node and its parents in the BN form a clique
    for (const auto node: graph__) {
      const NodeSet& parents = bn.parents(node);
      for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
        graph__.addEdge(*iter1, node);
        auto iter2 = iter1;
        for (++iter2; iter2 != parents.cend(); ++iter2) {
          graph__.addEdge(*iter1, *iter2);
        }
      }
    }

    // 4/ if there exist some joint targets, we shall add new edges into the
    // moral graph in order to ensure that there exists a clique containing
    // each joint
    for (const auto& nodeset: this->jointTargets()) {
      for (auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
        auto iter2 = iter1;
        for (++iter2; iter2 != nodeset.cend(); ++iter2) {
          graph__.addEdge(*iter1, *iter2);
        }
      }
    }

    // 5/ remove all the nodes that received hard evidence
    hard_ev_nodes__ = this->hardEvidenceNodes();
    for (const auto node: hard_ev_nodes__) {
      graph__.eraseNode(node);
    }


    // now, we can compute the new junction tree. To speed-up computations
    // (essentially, those of a distribution phase), we construct from this
    // junction tree a binary join tree
    if (JT__ != nullptr) delete JT__;
    if (junctionTree__ != nullptr) delete junctionTree__;

    triangulation__->setGraph(&graph__, &(this->domainSizes()));
    const JunctionTree& triang_jt = triangulation__->junctionTree();
    if (use_binary_join_tree__) {
      BinaryJoinTreeConverterDefault bjt_converter;
      NodeSet                        emptyset;
      JT__ = new CliqueGraph(
         bjt_converter.convert(triang_jt, this->domainSizes(), emptyset));
    } else {
      JT__ = new CliqueGraph(triang_jt);
    }
    junctionTree__ = new CliqueGraph(triang_jt);


    // indicate, for each node of the moral graph, a clique in JT__ that can
    // contain its conditional probability table
    node_to_clique__.clear();
    const std::vector< NodeId >& JT_elim_order
       = triangulation__->eliminationOrder();
    NodeProperty< int > elim_order(Size(JT_elim_order.size()));
    for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
         ++i)
      elim_order.insert(JT_elim_order[i], (int)i);
    const DAG& dag = bn.dag();
    for (const auto node: graph__) {
      // get the variables in the potential of node (and its parents)
      NodeId first_eliminated_node = node;
      int    elim_number           = elim_order[first_eliminated_node];

      for (const auto parent: dag.parents(node)) {
        if (graph__.existsNode(parent) && (elim_order[parent] < elim_number)) {
          elim_number           = elim_order[parent];
          first_eliminated_node = parent;
        }
      }

      // first_eliminated_node contains the first var (node or one of its
      // parents) eliminated => the clique created during its elimination
      // contains node and all of its parents => it can contain the potential
      // assigned to the node in the BN
      node_to_clique__.insert(
         node,
         triangulation__->createdJunctionTreeClique(first_eliminated_node));
    }

    // do the same for the nodes that received evidence. Here, we only store
    // the nodes for which at least one parent belongs to graph__ (otherwise
    // their CPT is just a constant real number).
    for (const auto node: hard_ev_nodes__) {
      // get the set of parents of the node that belong to graph__
      NodeSet pars(dag.parents(node).size());
      for (const auto par: dag.parents(node))
        if (graph__.exists(par)) pars.insert(par);

      if (!pars.empty()) {
        NodeId first_eliminated_node = *(pars.begin());
        int    elim_number           = elim_order[first_eliminated_node];

        for (const auto parent: pars) {
          if (elim_order[parent] < elim_number) {
            elim_number           = elim_order[parent];
            first_eliminated_node = parent;
          }
        }

        // first_eliminated_node contains the first var (node or one of its
        // parents) eliminated => the clique created during its elimination
        // contains node and all of its parents => it can contain the potential
        // assigned to the node in the BN
        node_to_clique__.insert(
           node,
           triangulation__->createdJunctionTreeClique(first_eliminated_node));
      }
    }

    // indicate for each joint_target a clique that contains it
    joint_target_to_clique__.clear();
    for (const auto& set: this->jointTargets()) {
      // remove from set all the nodes that received hard evidence (since they
      // do not belong to the join tree)
      NodeSet nodeset = set;
      for (const auto node: hard_ev_nodes__)
        if (nodeset.contains(node)) nodeset.erase(node);

      if (!nodeset.empty()) {
        // the clique we are looking for is the one that was created when
        // the first element of nodeset was eliminated
        NodeId first_eliminated_node = *(nodeset.begin());
        int    elim_number           = elim_order[first_eliminated_node];
        for (const auto node: nodeset) {
          if (elim_order[node] < elim_number) {
            elim_number           = elim_order[node];
            first_eliminated_node = node;
          }
        }

        joint_target_to_clique__.insert(
           set,
           triangulation__->createdJunctionTreeClique(first_eliminated_node));
      }
    }


    // compute the roots of JT__'s connected components
    computeJoinTreeRoots__();

    // remove all the Shafer-Shenoy potentials stored into the cliques
    for (const auto& xpot: clique_ss_potential__) {
      if (clique_potentials__[xpot.first].size() > 1) delete xpot.second;
    }

    // create empty potential lists into the cliques of the joint tree as well
    // as empty lists of evidence
    PotentialSet__ empty_set;
    clique_potentials__.clear();
    node_to_soft_evidence__.clear();
    for (const auto node: *JT__) {
      clique_potentials__.insert(node, empty_set);
    }

    // remove all the potentials created during the last inference
    for (auto& potlist: created_potentials__)
      for (auto pot: potlist.second)
        delete pot;
    created_potentials__.clear();

    // remove all the potentials created to take into account hard evidence
    // during the last inference
    for (auto pot_pair: hard_ev_projected_CPTs__)
      delete pot_pair.second;
    hard_ev_projected_CPTs__.clear();

    // remove all the constants created due to projections of CPTs that were
    // defined over only hard evidence nodes
    constants__.clear();

    // create empty lists of potentials for the messages and indicate that no
    // message has been computed yet
    separator_potentials__.clear();
    messages_computed__.clear();
    for (const auto& edge: JT__->edges()) {
      const Arc arc1(edge.first(), edge.second());
      separator_potentials__.insert(arc1, empty_set);
      messages_computed__.insert(arc1, false);
      const Arc arc2(Arc(edge.second(), edge.first()));
      separator_potentials__.insert(arc2, empty_set);
      messages_computed__.insert(arc2, false);
    }

    // remove all the posteriors computed so far
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    target_posteriors__.clear();
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;
    joint_target_posteriors__.clear();


    // put all the CPT's of the Bayes net nodes into the cliques.
    // Here, beware: all the potentials that are defined over some nodes
    // including hard evidence must be projected so that these nodes are
    // removed from the potential
    const auto& evidence      = this->evidence();
    const auto& hard_evidence = this->hardEvidence();
    for (const auto node: dag) {
      if (graph__.exists(node) || hard_ev_nodes__.contains(node)) {
        const Potential< GUM_SCALAR >& cpt = bn.cpt(node);

        // get the list of nodes with hard evidence in cpt
        NodeSet     hard_nodes;
        const auto& variables = cpt.variablesSequence();
        for (const auto var: variables) {
          NodeId xnode = bn.nodeId(*var);
          if (hard_ev_nodes__.contains(xnode)) hard_nodes.insert(xnode);
        }

        // if hard_nodes contains hard evidence nodes, perform a projection
        // and insert the result into the appropriate clique, else insert
        // directly cpt into the clique
        if (hard_nodes.empty()) {
          clique_potentials__[node_to_clique__[node]].insert(&cpt);
        } else {
          // marginalize out the hard evidence nodes: if the cpt is defined
          // only over nodes that received hard evidence, do not consider it
          // as a potential anymore but as a constant
          if (hard_nodes.size() == variables.size()) {
            Instantiation inst;
            const auto&   vars = cpt.variablesSequence();
            for (auto var: vars)
              inst << *var;
            for (Size i = 0; i < hard_nodes.size(); ++i) {
              inst.chgVal(variables[i],
                          hard_evidence[bn.nodeId(*(variables[i]))]);
            }
            constants__.insert(node, cpt[inst]);
          } else {
            // perform the projection with a combine and project instance
            Set< const DiscreteVariable* > hard_variables;
            PotentialSet__                 marg_cpt_set{&cpt};
            for (const auto xnode: hard_nodes) {
              marg_cpt_set.insert(evidence[xnode]);
              hard_variables.insert(&(bn.variable(xnode)));
            }

            // perform the combination of those potentials and their projection
            MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
               combine_and_project(combination_op__, SSNewprojPotential);
            PotentialSet__ new_cpt_list
               = combine_and_project.combineAndProject(marg_cpt_set,
                                                       hard_variables);

            // there should be only one potential in new_cpt_list
            if (new_cpt_list.size() != 1) {
              // remove the CPT created to avoid memory leaks
              for (auto pot: new_cpt_list) {
                if (!marg_cpt_set.contains(pot)) delete pot;
              }
              GUM_ERROR(FatalError,
                        "the projection of a potential containing "
                           << "hard evidence is empty!");
            }
            const Potential< GUM_SCALAR >* projected_cpt
               = *(new_cpt_list.begin());
            clique_potentials__[node_to_clique__[node]].insert(projected_cpt);
            hard_ev_projected_CPTs__.insert(node, projected_cpt);
          }
        }
      }
    }

    // we shall now add all the potentials of the soft evidence
    for (const auto node: this->softEvidenceNodes()) {
      node_to_soft_evidence__.insert(node, evidence[node]);
      clique_potentials__[node_to_clique__[node]].insert(evidence[node]);
    }

    // now, in clique_potentials__, for each clique, we have the list of
    // potentials that must be combined in order to produce the Shafer-Shenoy
    // potential stored into the clique. So, perform this combination and
    // store the result in clique_ss_potential__
    clique_ss_potential__.clear();
    MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
       combination_op__);
    for (const auto& xpotset: clique_potentials__) {
      const auto& potset = xpotset.second;
      if (potset.size() > 0) {
        // here, there will be an entry in clique_ss_potential__.
        // If there is only one element in potset, this element shall be
        // stored into clique_ss_potential__, else all the elements of potset
        // shall be combined and their result shall be stored
        if (potset.size() == 1) {
          clique_ss_potential__.insert(xpotset.first, *(potset.cbegin()));
        } else {
          auto joint = fast_combination.combine(potset);
          clique_ss_potential__.insert(xpotset.first, joint);
        }
      }
    }

    // indicate that the data structures are up to date.
    evidence_changes__.clear();
    is_new_jt_needed__ = false;
  }


  /// prepare the inference structures w.r.t. new targets, soft/hard evidence
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::updateOutdatedStructure_() {
    // check if a new JT is really needed. If so, create it
    if (isNewJTNeeded__()) {
      createNewJT__();
    } else {
      // here, we can answer the next queries without reconstructing all the
      // junction tree. All we need to do is to indicate that we should
      // update the potentials and messages for these queries
      updateOutdatedPotentials_();
    }
  }


  /// invalidate all the messages sent from a given clique
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::diffuseMessageInvalidations__(
     NodeId   from_id,
     NodeId   to_id,
     NodeSet& invalidated_cliques) {
    // invalidate the current clique
    invalidated_cliques.insert(to_id);

    // invalidate the current arc
    const Arc arc(from_id, to_id);
    bool&     message_computed = messages_computed__[arc];
    if (message_computed) {
      message_computed = false;
      separator_potentials__[arc].clear();
      if (created_potentials__.exists(arc)) {
        auto& arc_created_potentials = created_potentials__[arc];
        for (auto pot: arc_created_potentials)
          delete pot;
        arc_created_potentials.clear();
      }

      // go on with the diffusion
      for (const auto node_id: JT__->neighbours(to_id)) {
        if (node_id != from_id)
          diffuseMessageInvalidations__(to_id, node_id, invalidated_cliques);
      }
    }
  }


  /// update the potentials stored in the cliques and invalidate outdated
  /// messages
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::updateOutdatedPotentials_() {
    // for each clique, indicate whether the potential stored into
    // clique_ss_potentials__[clique] is the result of a combination. In this
    // case, it has been allocated by the combination and will need to be
    // deallocated if its clique has been invalidated
    NodeProperty< bool > ss_potential_to_deallocate(clique_potentials__.size());
    for (auto pot_iter = clique_potentials__.cbegin();
         pot_iter != clique_potentials__.cend();
         ++pot_iter) {
      ss_potential_to_deallocate.insert(pot_iter.key(),
                                        (pot_iter.val().size() > 1));
    }

    // compute the set of CPTs that were projected due to hard evidence and
    // whose hard evidence have changed, so that they need a new projection.
    // By the way, remove these CPTs since they are no more needed.
    // Here only the values of the hard evidence can have changed (else a
    // fully new join tree would have been computed).
    // Note also that we know that the CPTs still contain some variable(s) after
    // the projection (else they should be constants)
    NodeSet hard_nodes_changed(hard_ev_nodes__.size());
    for (const auto node: hard_ev_nodes__)
      if (evidence_changes__.exists(node)) hard_nodes_changed.insert(node);

    NodeSet     nodes_with_projected_CPTs_changed;
    const auto& bn = this->BN();
    for (auto pot_iter = hard_ev_projected_CPTs__.beginSafe();
         pot_iter != hard_ev_projected_CPTs__.endSafe();
         ++pot_iter) {
      for (const auto var: bn.cpt(pot_iter.key()).variablesSequence()) {
        if (hard_nodes_changed.contains(bn.nodeId(*var))) {
          nodes_with_projected_CPTs_changed.insert(pot_iter.key());
          delete pot_iter.val();
          clique_potentials__[node_to_clique__[pot_iter.key()]].erase(
             pot_iter.val());
          hard_ev_projected_CPTs__.erase(pot_iter);
          break;
        }
      }
    }


    // invalidate all the messages that are no more correct: start from each of
    // the nodes whose soft evidence has changed and perform a diffusion from
    // the clique into which the soft evidence has been entered, indicating that
    // the messages spreading from this clique are now invalid. At the same time,
    // if there were potentials created on the arcs over which the messages were
    // sent, remove them from memory. For all the cliques that received some
    // projected CPT that should now be changed, do the same.
    NodeSet invalidated_cliques(JT__->size());
    for (const auto& pair: evidence_changes__) {
      if (node_to_clique__.exists(pair.first)) {
        const auto clique = node_to_clique__[pair.first];
        invalidated_cliques.insert(clique);
        for (const auto neighbor: JT__->neighbours(clique)) {
          diffuseMessageInvalidations__(clique, neighbor, invalidated_cliques);
        }
      }
    }

    // now, add to the set of invalidated cliques those that contain projected
    // CPTs that were changed.
    for (auto node: nodes_with_projected_CPTs_changed) {
      const auto clique = node_to_clique__[node];
      invalidated_cliques.insert(clique);
      for (const auto neighbor: JT__->neighbours(clique)) {
        diffuseMessageInvalidations__(clique, neighbor, invalidated_cliques);
      }
    }

    // now that we know the cliques whose set of potentials have been changed,
    // we can discard their corresponding Shafer-Shenoy potential
    for (const auto clique: invalidated_cliques) {
      if (clique_ss_potential__.exists(clique)
          && ss_potential_to_deallocate[clique]) {
        delete clique_ss_potential__[clique];
      }
    }


    // now we shall remove all the posteriors that belong to the
    // invalidated cliques. First, cope only with the nodes that did not
    // receive hard evidence since the other nodes do not belong to the
    // join tree
    for (auto iter = target_posteriors__.beginSafe();
         iter != target_posteriors__.endSafe();
         ++iter) {
      if (graph__.exists(iter.key())
          && (invalidated_cliques.exists(node_to_clique__[iter.key()]))) {
        delete iter.val();
        target_posteriors__.erase(iter);
      }
    }

    // now cope with the nodes that received hard evidence
    for (auto iter = target_posteriors__.beginSafe();
         iter != target_posteriors__.endSafe();
         ++iter) {
      if (hard_nodes_changed.contains(iter.key())) {
        delete iter.val();
        target_posteriors__.erase(iter);
      }
    }

    // finally, cope with joint targets
    for (auto iter = joint_target_posteriors__.beginSafe();
         iter != joint_target_posteriors__.endSafe();
         ++iter) {
      if (invalidated_cliques.exists(joint_target_to_clique__[iter.key()])) {
        delete iter.val();
        joint_target_posteriors__.erase(iter);
      }
    }


    // remove all the evidence that were entered into node_to_soft_evidence__
    // and clique_potentials__ and add the new soft ones
    for (auto& pot_pair: node_to_soft_evidence__) {
      clique_potentials__[node_to_clique__[pot_pair.first]].erase(
         pot_pair.second);
    }
    node_to_soft_evidence__.clear();

    const auto& evidence = this->evidence();
    for (const auto node: this->softEvidenceNodes()) {
      node_to_soft_evidence__.insert(node, evidence[node]);
      clique_potentials__[node_to_clique__[node]].insert(evidence[node]);
    }


    // Now add the projections of the CPTs due to newly changed hard evidence:
    // if we are performing updateOutdatedPotentials_, this means that the
    // set of nodes that received hard evidence has not been changed, only
    // their instantiations can have been changed. So, if there is an entry
    // for node in constants__, there will still be such an entry after
    // performing the new projections. Idem for hard_ev_projected_CPTs__
    for (const auto node: nodes_with_projected_CPTs_changed) {
      // perform the projection with a combine and project instance
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node);
      const auto&                    variables = cpt.variablesSequence();
      Set< const DiscreteVariable* > hard_variables;
      PotentialSet__                 marg_cpt_set{&cpt};
      for (const auto var: variables) {
        NodeId xnode = bn.nodeId(*var);
        if (hard_ev_nodes__.exists(xnode)) {
          marg_cpt_set.insert(evidence[xnode]);
          hard_variables.insert(var);
        }
      }

      // perform the combination of those potentials and their projection
      MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
         combine_and_project(combination_op__, SSNewprojPotential);
      PotentialSet__ new_cpt_list
         = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

      // there should be only one potential in new_cpt_list
      if (new_cpt_list.size() != 1) {
        // remove the CPT created to avoid memory leaks
        for (auto pot: new_cpt_list) {
          if (!marg_cpt_set.contains(pot)) delete pot;
        }
        GUM_ERROR(FatalError,
                  "the projection of a potential containing "
                     << "hard evidence is empty!");
      }
      const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
      clique_potentials__[node_to_clique__[node]].insert(projected_cpt);
      hard_ev_projected_CPTs__.insert(node, projected_cpt);
    }

    // here, the lists of potentials stored in the invalidated cliques have
    // been updated. So, now, we can combine them to produce the Shafer-Shenoy
    // potential stored into the clique
    MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
       combination_op__);
    for (const auto clique: invalidated_cliques) {
      const auto& potset = clique_potentials__[clique];

      if (potset.size() > 0) {
        // here, there will be an entry in clique_ss_potential__.
        // If there is only one element in potset, this element shall be
        // stored into clique_ss_potential__, else all the elements of potset
        // shall be combined and their result shall be stored
        if (potset.size() == 1) {
          clique_ss_potential__[clique] = *(potset.cbegin());
        } else {
          auto joint                    = fast_combination.combine(potset);
          clique_ss_potential__[clique] = joint;
        }
      }
    }


    // update the constants
    const auto& hard_evidence = this->hardEvidence();
    for (auto& node_cst: constants__) {
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node_cst.first);
      const auto&                    variables = cpt.variablesSequence();
      Instantiation                  inst;
      for (const auto var: variables)
        inst << *var;
      for (const auto var: variables) {
        inst.chgVal(var, hard_evidence[bn.nodeId(*var)]);
      }
      node_cst.second = cpt[inst];
    }

    // indicate that all changes have been performed
    evidence_changes__.clear();
  }


  /// compute a root for each connected component of JT__
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::computeJoinTreeRoots__() {
    // get the set of cliques in which we can find the targets and joint_targets
    NodeSet clique_targets;
    for (const auto node: this->targets()) {
      try {
        clique_targets.insert(node_to_clique__[node]);
      } catch (Exception&) {}
    }
    for (const auto& set: this->jointTargets()) {
      try {
        clique_targets.insert(joint_target_to_clique__[set]);
      } catch (Exception&) {}
    }

    // put in a vector these cliques and their size
    std::vector< std::pair< NodeId, Size > > possible_roots(
       clique_targets.size());
    const auto& bn = this->BN();
    std::size_t i  = 0;
    for (const auto clique_id: clique_targets) {
      const auto& clique   = JT__->clique(clique_id);
      Size        dom_size = 1;
      for (const auto node: clique) {
        dom_size *= bn.variable(node).domainSize();
      }
      possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
      ++i;
    }

    // sort the cliques by increasing domain size
    std::sort(possible_roots.begin(),
              possible_roots.end(),
              [](const std::pair< NodeId, Size >& a,
                 const std::pair< NodeId, Size >& b) -> bool {
                return a.second < b.second;
              });

    // pick up the clique with the smallest size in each connected component
    NodeProperty< bool >                  marked = JT__->nodesProperty(false);
    std::function< void(NodeId, NodeId) > diffuse_marks
       = [&marked, &diffuse_marks, this](NodeId node, NodeId from) {
           if (!marked[node]) {
             marked[node] = true;
             for (const auto neigh: JT__->neighbours(node))
               if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
           }
         };
    roots__.clear();
    for (const auto xclique: possible_roots) {
      NodeId clique = xclique.first;
      if (!marked[clique]) {
        roots__.insert(clique);
        diffuse_marks(clique, clique);
      }
    }
  }


  // performs the collect phase of Shafer-Shenoy's propagation
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::collectMessage__(NodeId id,
                                                                    NodeId from) {
    for (const auto other: JT__->neighbours(id)) {
      if ((other != from) && !messages_computed__[Arc(other, id)])
        collectMessage__(other, id);
    }

    if ((id != from) && !messages_computed__[Arc(id, from)]) {
      produceMessage__(id, from);
    }
  }


  // remove barren variables
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* >
     ShaferShenoyInference< GUM_SCALAR >::removeBarrenVariables__(
        PotentialSet__&                 pot_list,
        Set< const DiscreteVariable* >& del_vars) {
    // remove from del_vars the variables that received some evidence:
    // only those that did not receive evidence can be barren variables
    Set< const DiscreteVariable* > the_del_vars = del_vars;
    for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
         ++iter) {
      NodeId id = this->BN().nodeId(**iter);
      if (this->hardEvidenceNodes().exists(id)
          || this->softEvidenceNodes().exists(id)) {
        the_del_vars.erase(iter);
      }
    }

    // assign to each random variable the set of potentials that contain it
    HashTable< const DiscreteVariable*, PotentialSet__ > var2pots;
    PotentialSet__                                       empty_pot_set;
    for (const auto pot: pot_list) {
      const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
      for (const auto var: vars) {
        if (the_del_vars.exists(var)) {
          if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
          var2pots[var].insert(pot);
        }
      }
    }

    // each variable with only one potential is a barren variable
    // assign to each potential with barren nodes its set of barren variables
    HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
                                   pot2barren_var;
    Set< const DiscreteVariable* > empty_var_set;
    for (auto elt: var2pots) {
      if (elt.second.size() == 1) {   // here we have a barren variable
        const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
        if (!pot2barren_var.exists(pot)) {
          pot2barren_var.insert(pot, empty_var_set);
        }
        pot2barren_var[pot].insert(elt.first);   // insert the barren variable
      }
    }

    // for each potential with barren variables, marginalize them.
    // if the potential has only barren variables, simply remove them from the
    // set of potentials, else just project the potential
    MultiDimProjection< GUM_SCALAR, Potential > projector(SSNewprojPotential);
    PotentialSet__                              projected_pots;
    for (auto elt: pot2barren_var) {
      // remove the current potential from pot_list as, anyway, we will change
      // it
      const Potential< GUM_SCALAR >* pot = elt.first;
      pot_list.erase(pot);

      // check whether we need to add a projected new potential or not (i.e.,
      // whether there exist non-barren variables or not)
      if (pot->variablesSequence().size() != elt.second.size()) {
        auto new_pot = projector.project(*pot, elt.second);
        pot_list.insert(new_pot);
        projected_pots.insert(new_pot);
      }
    }

    return projected_pots;
  }


  // remove variables del_vars from the list of potentials pot_list
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* >
     ShaferShenoyInference< GUM_SCALAR >::marginalizeOut__(
        Set< const Potential< GUM_SCALAR >* > pot_list,
        Set< const DiscreteVariable* >&       del_vars,
        Set< const DiscreteVariable* >&       kept_vars) {
    // remove the potentials corresponding to barren variables if we want
    // to exploit barren nodes
    PotentialSet__ barren_projected_potentials;
    if (barren_nodes_type__ == FindBarrenNodesType::FIND_BARREN_NODES) {
      barren_projected_potentials = removeBarrenVariables__(pot_list, del_vars);
    }

    // create a combine and project operator that will perform the
    // marginalization
    MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
       combination_op__,
       projection_op__);
    PotentialSet__ new_pot_list
       = combine_and_project.combineAndProject(pot_list, del_vars);

    // remove all the potentials that were created due to projections of
    // barren nodes and that are not part of the new_pot_list: these
    // potentials were just temporary potentials
    for (auto iter = barren_projected_potentials.beginSafe();
         iter != barren_projected_potentials.endSafe();
         ++iter) {
      if (!new_pot_list.exists(*iter)) delete *iter;
    }

    // remove all the potentials that have no dimension
    for (auto iter_pot = new_pot_list.beginSafe();
         iter_pot != new_pot_list.endSafe();
         ++iter_pot) {
      if ((*iter_pot)->variablesSequence().size() == 0) {
        // as we have already marginalized out variables that received evidence,
        // it may be the case that, after combining and projecting, some
        // potentials might be empty. In this case, we shall keep their
        // constant and remove them from memory
        // # TODO: keep the constants!
        delete *iter_pot;
        new_pot_list.erase(iter_pot);
      }
    }

    return new_pot_list;
  }


  // creates the message sent by clique from_id to clique to_id
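  // (informally, the Shafer-Shenoy message computed here is
  //     psi_{from->to} = proj_S( phi_from * prod_{k in neighbours(from), k != to} psi_{k->from} )
  //  where phi_from is the combined potential stored in clique from_id and S is
  //  the separator between from_id and to_id; only the variables of S are kept)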
template
<
typename
GUM_SCALAR
>
1209
void
ShaferShenoyInference
<
GUM_SCALAR
>::
produceMessage__
(
NodeId
from_id
,
1210
NodeId
to_id
) {
1211
// get the potentials of the clique.
1212
PotentialSet__
pot_list
;
1213
if
(
clique_ss_potential__
.
exists
(
from_id
))
1214
pot_list
.
insert
(
clique_ss_potential__
[
from_id
]);
1215
1216
// add the messages sent by adjacent nodes to from_id
1217
for
(
const
auto
other_id
:
JT__
->
neighbours
(
from_id
))
1218
if
(
other_id
!=
to_id
)
1219
pot_list
+=
separator_potentials__
[
Arc
(
other_id
,
from_id
)];
1220
1221
// get the set of variables that need be removed from the potentials
1222
const
NodeSet
&
from_clique
=
JT__
->
clique
(
from_id
);
1223
const
NodeSet
&
separator
=
JT__
->
separator
(
from_id
,
to_id
);
1224
Set
<
const
DiscreteVariable
* >
del_vars
(
from_clique
.
size
());
1225
Set
<
const
DiscreteVariable
* >
kept_vars
(
separator
.
size
());
1226
const
auto
&
bn
=
this
->
BN
();
1227
1228
for
(
const
auto
node
:
from_clique
) {
1229
if
(!
separator
.
contains
(
node
)) {
1230
del_vars
.
insert
(&(
bn
.
variable
(
node
)));
1231
}
else
{
1232
kept_vars
.
insert
(&(
bn
.
variable
(
node
)));
1233
}
1234
}
1235
1236
// pot_list now contains all the potentials to multiply and marginalize
1237
// => combine the messages
1238
PotentialSet__
new_pot_list
=
marginalizeOut__
(
pot_list
,
del_vars
,
kept_vars
);
1239
1240
/*
1241
for the moment, remove this test: due to some optimizations, some
1242
potentials might have all their cells greater than 1.
1243
1244
// remove all the potentials that are equal to ones (as probability
1245
// matrix multiplications are tensorial, such potentials are useless)
1246
for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
1247
++iter) {
1248
const auto pot = *iter;
1249
if (pot->variablesSequence().size() == 1) {
1250
bool is_all_ones = true;
1251
for (Instantiation inst(*pot); !inst.end(); ++inst) {
1252
if ((*pot)[inst] < one_minus_epsilon__) {
1253
is_all_ones = false;
1254
break;
1255
}
1256
}
1257
if (is_all_ones) {
1258
if (!pot_list.exists(pot)) delete pot;
1259
new_pot_list.erase(iter);
1260
continue;
1261
}
1262
}
1263
}
1264
*/
1265
1266
// if there are still potentials in new_pot_list, combine them to
1267
// produce the message on the separator
1268
const
Arc
arc
(
from_id
,
to_id
);
1269
if
(!
new_pot_list
.
empty
()) {
1270
if
(
new_pot_list
.
size
() == 1) {
// there is only one potential
1271
// in new_pot_list, so this corresponds to our message on
1272
// the separator
1273
auto
pot
= *(
new_pot_list
.
begin
());
1274
separator_potentials__
[
arc
] =
std
::
move
(
new_pot_list
);
1275
if
(!
pot_list
.
exists
(
pot
)) {
1276
if
(!
created_potentials__
.
exists
(
arc
))
1277
created_potentials__
.
insert
(
arc
,
PotentialSet__
());
1278
created_potentials__
[
arc
].
insert
(
pot
);
1279
}
1280
}
else
{
1281
// create the message in the separator
1282
MultiDimCombinationDefault
<
GUM_SCALAR
,
Potential
>
fast_combination
(
1283
combination_op__
);
1284
auto
joint
=
fast_combination
.
combine
(
new_pot_list
);
1285
separator_potentials__
[
arc
].
insert
(
joint
);
1286
if
(!
created_potentials__
.
exists
(
arc
))
1287
created_potentials__
.
insert
(
arc
,
PotentialSet__
());
1288
created_potentials__
[
arc
].
insert
(
joint
);
1289
1290
// remove the temporary messages created in new_pot_list
1291
for
(
const
auto
pot
:
new_pot_list
) {
1292
if
(!
pot_list
.
exists
(
pot
)) {
delete
pot
; }
1293
}
1294
}
1295
}
1296
1297
messages_computed__
[
arc
] =
true
;
1298
}
1299
1300
1301
// performs a whole inference
1302
template
<
typename
GUM_SCALAR
>
1303
INLINE
void
ShaferShenoyInference
<
GUM_SCALAR
>::
makeInference_
() {
1304
// collect messages for all single targets
1305
for
(
const
auto
node
:
this
->
targets
()) {
1306
// perform only collects in the join tree for nodes that have
1307
// not received hard evidence (those that received hard evidence were
1308
// not included into the join tree for speed-up reasons)
1309
if
(
graph__
.
exists
(
node
)) {
1310
collectMessage__
(
node_to_clique__
[
node
],
node_to_clique__
[
node
]);
1311
}
1312
}
1313
1314
// collect messages for all set targets
1315
// by parsing joint_target_to_clique__, we ensure that the cliques that
1316
// are referenced belong to the join tree (even if some of the nodes in
1317
// their associated joint_target do not belong to graph__)
1318
for
(
const
auto
set
:
joint_target_to_clique__
)
1319
collectMessage__
(
set
.
second
,
set
.
second
);
1320
}
1321
1322
1323
/// returns a fresh potential equal to P(1st arg,evidence)
1324
template
<
typename
GUM_SCALAR
>
1325
Potential
<
GUM_SCALAR
>*
1326
ShaferShenoyInference
<
GUM_SCALAR
>::
unnormalizedJointPosterior_
(
NodeId
id
) {
1327
const
auto
&
bn
=
this
->
BN
();
1328
1329
// hard evidence do not belong to the join tree
1330
// # TODO: check for sets of inconsistent hard evidence
1331
if
(
this
->
hardEvidenceNodes
().
contains
(
id
)) {
1332
return
new
Potential
<
GUM_SCALAR
>(*(
this
->
evidence
()[
id
]));
1333
}
1334
1335
// if we still need to perform some inference task, do it (this should
1336
// already have been done by makeInference_)
1337
NodeId
clique_of_id
=
node_to_clique__
[
id
];
1338
collectMessage__
(
clique_of_id
,
clique_of_id
);
1339
1340
// now we just need to create the product of the potentials of the clique
1341
// containing id with the messages received by this clique and
1342
// marginalize out all variables except id
1343
PotentialSet__
pot_list
;
1344
if
(
clique_ss_potential__
.
exists
(
clique_of_id
))
1345
pot_list
.
insert
(
clique_ss_potential__
[
clique_of_id
]);
1346
1347
// add the messages sent by adjacent nodes to targetCliquxse
1348
for
(
const
auto
other
:
JT__
->
neighbours
(
clique_of_id
))
1349
pot_list
+=
separator_potentials__
[
Arc
(
other
,
clique_of_id
)];
1350
1351
// get the set of variables that need be removed from the potentials
1352
const
NodeSet
&
nodes
=
JT__
->
clique
(
clique_of_id
);
1353
Set
<
const
DiscreteVariable
* >
kept_vars
{&(
bn
.
variable
(
id
))};
1354
Set
<
const
DiscreteVariable
* >
del_vars
(
nodes
.
size
());
1355
for
(
const
auto
node
:
nodes
) {
1356
if
(
node
!=
id
)
del_vars
.
insert
(&(
bn
.
variable
(
node
)));
1357
}
1358
1359
// pot_list now contains all the potentials to multiply and marginalize
1360
// => combine the messages
1361
PotentialSet__
new_pot_list
=
marginalizeOut__
(
pot_list
,
del_vars
,
kept_vars
);
1362
Potential
<
GUM_SCALAR
>*
joint
=
nullptr
;
1363
1364
if
(
new_pot_list
.
size
() == 1) {
1365
joint
=
const_cast
<
Potential
<
GUM_SCALAR
>* >(*(
new_pot_list
.
begin
()));
1366
// if pot already existed, create a copy, so that we can put it into
1367
// the target_posteriors__ property
1368
if
(
pot_list
.
exists
(
joint
)) {
1369
joint
=
new
Potential
<
GUM_SCALAR
>(*
joint
);
1370
}
else
{
1371
// remove the joint from new_pot_list so that it will not be
1372
// removed just after the else block
1373
new_pot_list
.
clear
();
1374
}
1375
}
else
{
1376
MultiDimCombinationDefault
<
GUM_SCALAR
,
Potential
>
fast_combination
(
1377
combination_op__
);
1378
joint
=
fast_combination
.
combine
(
new_pot_list
);
1379
}
1380
1381
// remove the potentials that were created in new_pot_list
1382
for
(
auto
pot
:
new_pot_list
)
1383
if
(!
pot_list
.
exists
(
pot
))
delete
pot
;
1384
1385
// check that the joint posterior is different from a 0 vector: this would
1386
// indicate that some hard evidence are not compatible (their joint
1387
// probability is equal to 0)
1388
bool
nonzero_found
=
false
;
1389
for
(
Instantiation
inst
(*
joint
); !
inst
.
end
(); ++
inst
) {
1390
if
((*
joint
)[
inst
]) {
1391
nonzero_found
=
true
;
1392
break
;
1393
}
1394
}
1395
if
(!
nonzero_found
) {
1396
// remove joint from memory to avoid memory leaks
1397
delete
joint
;
1398
GUM_ERROR
(
IncompatibleEvidence
,
1399
"some evidence entered into the Bayes "
1400
"net are incompatible (their joint proba = 0)"
);
1401
}
1402
return
joint
;
1403
}
1404
1405
1406
/// returns the posterior of a given variable
1407
template
<
typename
GUM_SCALAR
>
1408
const
Potential
<
GUM_SCALAR
>&
1409
ShaferShenoyInference
<
GUM_SCALAR
>::
posterior_
(
NodeId
id
) {
1410
// check if we have already computed the posterior
1411
if
(
target_posteriors__
.
exists
(
id
)) {
return
*(
target_posteriors__
[
id
]); }
1412
1413
// compute the joint posterior and normalize
1414
auto
joint
=
unnormalizedJointPosterior_
(
id
);
1415
if
(
joint
->
sum
() != 1)
// hard test for ReadOnly CPT (as aggregator)
1416
joint
->
normalize
();
1417
target_posteriors__
.
insert
(
id
,
joint
);
1418
1419
return
*
joint
;
1420
}


  // returns the marginal a posteriori proba of a given set of nodes
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     ShaferShenoyInference< GUM_SCALAR >::unnormalizedJointPosterior_(
        const NodeSet& set) {
    // hard evidence nodes do not belong to the join tree, so extract the nodes
    // from targets that are not hard evidence
    NodeSet targets = set, hard_ev_nodes;
    for (const auto node: this->hardEvidenceNodes()) {
      if (targets.contains(node)) {
        targets.erase(node);
        hard_ev_nodes.insert(node);
      }
    }

    // if all the nodes have received hard evidence, then compute the
    // joint posterior directly by multiplying the hard evidence potentials
    const auto& evidence = this->evidence();
    if (targets.empty()) {
      PotentialSet__ pot_list;
      for (const auto node: set) {
        pot_list.insert(evidence[node]);
      }
      if (pot_list.size() == 1) {
        auto pot = new Potential< GUM_SCALAR >(**(pot_list.begin()));
        return pot;
      } else {
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
           combination_op__);
        return fast_combination.combine(pot_list);
      }
    }


    // if we still need to perform some inference task, do it: so, first,
    // determine the clique on which we should perform collect to compute
    // the unnormalized joint posterior of a set of nodes containing "set"
    NodeId clique_of_set;
    try {
      clique_of_set = joint_target_to_clique__[set];
    } catch (NotFound&) {
      // here, the precise set of targets does not belong to the set of targets
      // defined by the user. So we will try to find a clique in the junction
      // tree that contains "targets":

      // 1/ we should check that all the nodes belong to the join tree
      for (const auto node: targets) {
        if (!graph__.exists(node)) {
          GUM_ERROR(UndefinedElement, node << " is not a target node");
        }
      }

      // 2/ the clique created by the first eliminated node among the targets is
      // the one we are looking for
      const std::vector< NodeId >& JT_elim_order
         = triangulation__->eliminationOrder();
      NodeProperty< int > elim_order(Size(JT_elim_order.size()));
      for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
           ++i)
        elim_order.insert(JT_elim_order[i], (int)i);
      NodeId first_eliminated_node = *(targets.begin());
      int    elim_number           = elim_order[first_eliminated_node];
      for (const auto node: targets) {
        if (elim_order[node] < elim_number) {
          elim_number           = elim_order[node];
          first_eliminated_node = node;
        }
      }
      clique_of_set
         = triangulation__->createdJunctionTreeClique(first_eliminated_node);

      // 3/ check that clique_of_set contains all the nodes in the target
      const NodeSet& clique_nodes = JT__->clique(clique_of_set);
      for (const auto node: targets) {
        if (!clique_nodes.contains(node)) {
          GUM_ERROR(UndefinedElement, set << " is not a joint target");
        }
      }

      // add the discovered clique to joint_target_to_clique__
      joint_target_to_clique__.insert(set, clique_of_set);
    }

    // now perform a collect on the clique
    collectMessage__(clique_of_set, clique_of_set);

    // now we just need to create the product of the potentials of the clique
    // containing set with the messages received by this clique and
    // marginalize out all variables except set
    PotentialSet__ pot_list;
    if (clique_ss_potential__.exists(clique_of_set))
      pot_list.insert(clique_ss_potential__[clique_of_set]);

    // add the messages sent by adjacent nodes to targetClique
    for (const auto other: JT__->neighbours(clique_of_set))
      pot_list += separator_potentials__[Arc(other, clique_of_set)];

    // get the set of variables that need be removed from the potentials
    const NodeSet&                 nodes = JT__->clique(clique_of_set);
    Set< const DiscreteVariable* > del_vars(nodes.size());
    Set< const DiscreteVariable* > kept_vars(targets.size());
    const auto&                    bn = this->BN();
    for (const auto node: nodes) {
      if (!targets.contains(node)) {
        del_vars.insert(&(bn.variable(node)));
      } else {
        kept_vars.insert(&(bn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    PotentialSet__ new_pot_list = marginalizeOut__(pot_list, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint = nullptr;

    if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if pot already existed, create a copy, so that we can put it into
      // the target_posteriors__ property
      if (pot_list.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // removed just after the next else block
        new_pot_list.clear();
      }
    } else {
      // combine all the potentials in new_pot_list with all the hard evidence
      // of the nodes in set
      PotentialSet__ new_new_pot_list = new_pot_list;
      for (const auto node: hard_ev_nodes) {
        new_new_pot_list.insert(evidence[node]);
      }
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
         combination_op__);
      joint = fast_combination.combine(new_new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (auto pot: new_pot_list)
      if (!pot_list.exists(pot)) delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some hard evidence are not compatible
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
    }

    return joint;
  }


  /// returns the posterior of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     ShaferShenoyInference< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
    // check if we have already computed the posterior
    if (joint_target_posteriors__.exists(set)) {
      return *(joint_target_posteriors__[set]);
    }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(set);
    joint->normalize();
    joint_target_posteriors__.insert(set, joint);

    return *joint;
  }


  /// returns the posterior of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     ShaferShenoyInference< GUM_SCALAR >::jointPosterior_(
        const NodeSet& wanted_target,
        const NodeSet& declared_target) {
    // check if we have already computed the posterior of wanted_target
    if (joint_target_posteriors__.exists(wanted_target))
      return *(joint_target_posteriors__[wanted_target]);

    // here, we will have to compute the posterior of declared_target and
    // marginalize out all the variables that do not belong to wanted_target

    // check if we have already computed the posterior of declared_target
    if (!joint_target_posteriors__.exists(declared_target)) {
      jointPosterior_(declared_target);
    }

    // marginalize out all the variables that do not belong to wanted_target
    const auto&                    bn = this->BN();
    Set< const DiscreteVariable* > del_vars;
    for (const auto node: declared_target)
      if (!wanted_target.contains(node)) del_vars.insert(&(bn.variable(node)));
    auto pot = new Potential< GUM_SCALAR >(
       joint_target_posteriors__[declared_target]->margSumOut(del_vars));

    // save the result into the cache
    joint_target_posteriors__.insert(wanted_target, pot);

    return *pot;
  }
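
  // A minimal usage sketch for joint posteriors. It assumes the joint-target
  // API of JointTargetedInference (addJointTarget, jointPosterior); the network
  // and the target set are illustrative only:
  //
  //   auto bn = gum::BayesNet< double >::fastPrototype("A->B->C");
  //   gum::ShaferShenoyInference< double > ie(&bn);
  //   gum::NodeSet target;
  //   target.insert(bn.idFromName("B"));
  //   target.insert(bn.idFromName("C"));
  //   ie.addJointTarget(target);
  //   ie.makeInference();
  //   const auto& joint = ie.jointPosterior(target);   // served by jointPosterior_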


  template < typename GUM_SCALAR >
  GUM_SCALAR ShaferShenoyInference< GUM_SCALAR >::evidenceProbability() {
    // perform inference in each connected component
    this->makeInference();

    // for each connected component, select a variable X and compute the
    // joint probability of X and evidence e. Then marginalize out X to get
    // p(e) in this connected component. Finally, multiply all the p(e) that
    // we got and the elements in constants__. The result is the probability
    // of evidence

    GUM_SCALAR prob_ev = 1;
    for (const auto root: roots__) {
      // get a node in the clique
      NodeId                   node = *(JT__->clique(root).begin());
      Potential< GUM_SCALAR >* tmp  = unnormalizedJointPosterior_(node);
      GUM_SCALAR               sum  = 0;
      for (Instantiation iter(*tmp); !iter.end(); ++iter)
        sum += tmp->get(iter);
      prob_ev *= sum;
      delete tmp;
    }

    for (const auto& projected_cpt: constants__)
      prob_ev *= projected_cpt.second;

    return prob_ev;
  }
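
  // A minimal usage sketch for evidenceProbability(); the network and the
  // observed value are illustrative assumptions:
  //
  //   auto bn = gum::BayesNet< double >::fastPrototype("A->B");
  //   gum::ShaferShenoyInference< double > ie(&bn);
  //   ie.addEvidence(bn.idFromName("B"), 1);    // observe B = 1
  //   ie.makeInference();
  //   double p_e = ie.evidenceProbability();    // P(B = 1) under the network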


}   /* namespace gum */


#endif   // DOXYGEN_SHOULD_SKIP_THIS