aGrUM 0.20.3 - a C++ library for (probabilistic) graphical models
ShaferShenoyInference_tpl.h
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of Shafer-Shenoy's propagation for inference in
 * Bayesian networks.
 */

#ifndef DOXYGEN_SHOULD_SKIP_THIS
#  include <agrum/BN/inference/ShaferShenoyInference.h>

#  include <agrum/BN/algorithms/BayesBall.h>
#  include <agrum/BN/algorithms/barrenNodesFinder.h>
#  include <agrum/BN/algorithms/dSeparation.h>
#  include <agrum/tools/graphs/algorithms/binaryJoinTreeConverterDefault.h>
#  include <agrum/tools/multidim/instantiation.h>
#  include <agrum/tools/multidim/utils/operators/multiDimCombineAndProjectDefault.h>
#  include <agrum/tools/multidim/utils/operators/multiDimProjection.h>


namespace gum {
  // default constructor
  template < typename GUM_SCALAR >
  INLINE ShaferShenoyInference< GUM_SCALAR >::ShaferShenoyInference(
     const IBayesNet< GUM_SCALAR >* BN,
     FindBarrenNodesType            barren_type,
     bool                           use_binary_join_tree) :
      JointTargetedInference< GUM_SCALAR >(BN),
      EvidenceInference< GUM_SCALAR >(BN), _use_binary_join_tree_(use_binary_join_tree) {
    // sets the barren nodes finding algorithm
    setFindBarrenNodesType(barren_type);

    // create a default triangulation (the user can change it afterwards)
    _triangulation_ = new DefaultTriangulation;

    // for debugging purposes
    GUM_CONSTRUCTOR(ShaferShenoyInference);
  }

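  // Usage sketch (not part of this file; the network "bn" and the node names
  // "X" and "Y" are illustrative assumptions): once a BayesNet< double > has
  // been built, Shafer-Shenoy inference is driven through the public API, e.g.
  //
  //   gum::ShaferShenoyInference< double > inference(&bn);
  //   inference.addEvidence(bn.idFromName("X"), 1);
  //   inference.makeInference();
  //   const auto& posterior = inference.posterior(bn.idFromName("Y"));
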
  // destructor
  template < typename GUM_SCALAR >
  INLINE ShaferShenoyInference< GUM_SCALAR >::~ShaferShenoyInference() {
    // remove all the potentials created during the last message passing
    for (const auto& pots: _created_potentials_)
      for (const auto pot: pots.second)
        delete pot;

    // remove the potentials created after removing the nodes that received
    // hard evidence
    for (const auto& pot: _hard_ev_projected_CPTs_)
      delete pot.second;

    // remove all the potentials in _clique_ss_potential_ that do not belong
    // to _clique_potentials_ : in this case, those potentials have been
    // created by combination of the corresponding list of potentials in
    // _clique_potentials_. In other words, the size of this list is strictly
    // greater than 1.
    for (auto pot: _clique_ss_potential_) {
      if (_clique_potentials_[pot.first].size() > 1) delete pot.second;
    }

    // remove all the posteriors computed
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;

    // remove the junction tree and the triangulation algorithm
    if (_JT_ != nullptr) delete _JT_;
    if (_junctionTree_ != nullptr) delete _junctionTree_;
    delete _triangulation_;

    // for debugging purposes
    GUM_DESTRUCTOR(ShaferShenoyInference);
  }


  /// set a new triangulation algorithm
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::setTriangulation(
     const Triangulation& new_triangulation) {
    delete _triangulation_;
    _triangulation_    = new_triangulation.newFactory();
    _is_new_jt_needed_ = true;
    this->setOutdatedStructureState_();
  }


  /// returns the current join tree used
  template < typename GUM_SCALAR >
  INLINE const JoinTree* ShaferShenoyInference< GUM_SCALAR >::joinTree() {
    _createNewJT_();

    return _JT_;
  }

  /// returns the current junction tree
  template < typename GUM_SCALAR >
  INLINE const JunctionTree* ShaferShenoyInference< GUM_SCALAR >::junctionTree() {
    _createNewJT_();

    return _junctionTree_;
  }


  /// sets the operator for performing the projections
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::_setProjectionFunction_(
     Potential< GUM_SCALAR >* (*proj)(const Potential< GUM_SCALAR >&,
                                      const Set< const DiscreteVariable* >&)) {
    _projection_op_ = proj;
  }


  /// sets the operator for performing the combinations
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::_setCombinationFunction_(
     Potential< GUM_SCALAR >* (*comb)(const Potential< GUM_SCALAR >&,
                                      const Potential< GUM_SCALAR >&)) {
    _combination_op_ = comb;
  }

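  // Illustration only (an assumption, not part of the library): any operator
  // passed to the two setters above must match the expected signatures. For
  // instance, a hypothetical max-projection could look like:
  //
  //   template < typename GUM_SCALAR >
  //   gum::Potential< GUM_SCALAR >*
  //      myMaxProjection(const gum::Potential< GUM_SCALAR >&             pot,
  //                      const gum::Set< const gum::DiscreteVariable* >& del_vars) {
  //     // project pot by maximizing out del_vars instead of summing them out
  //     return new gum::Potential< GUM_SCALAR >(pot.margMaxOut(del_vars));
  //   }
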
  /// invalidate all messages, posteriors and created potentials
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::_invalidateAllMessages_() {
    // remove all the messages computed
    for (auto& potset: _separator_potentials_)
      potset.second.clear();
    for (auto& mess_computed: _messages_computed_)
      mess_computed.second = false;

    // remove all the created potentials
    for (const auto& potset: _created_potentials_)
      for (const auto pot: potset.second)
        delete pot;

    // remove all the posteriors
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;

    // indicate that new messages need be computed
    if (this->isInferenceReady() || this->isInferenceDone())
      this->setOutdatedPotentialsState_();
  }


  /// sets how we determine barren nodes
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::setFindBarrenNodesType(FindBarrenNodesType type) {
    if (type != _barren_nodes_type_) {
      // WARNING: if a new type is added here, method _createJT_ should
      // certainly be updated as well, in particular its step 2.
      switch (type) {
        case FindBarrenNodesType::FIND_BARREN_NODES:
        case FindBarrenNodesType::FIND_NO_BARREN_NODES:
          break;

        default:
          GUM_ERROR(InvalidArgument,
                    "setFindBarrenNodesType for type " << (unsigned int)type
                                                       << " is not implemented yet");
      }

      _barren_nodes_type_ = type;

      // potentially, we may need to reconstruct a junction tree
      this->setOutdatedStructureState_();
    }
  }

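  // Example (sketch, assuming an existing engine "inference"): barren-node
  // pruning can be disabled explicitly before running inference:
  //
  //   inference.setFindBarrenNodesType(gum::FindBarrenNodesType::FIND_NO_BARREN_NODES);
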
  /// fired when a new evidence is inserted
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onEvidenceAdded_(const NodeId id,
                                                                    bool         isHardEvidence) {
    // if we have a new hard evidence, this modifies the undigraph over which
    // the join tree is created. This is also the case if id is not a node of
    // the undigraph
    if (isHardEvidence || !_graph_.exists(id))
      _is_new_jt_needed_ = true;
    else {
      try {
        _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_ADDED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed. This necessarily means
        // that the current saved change is an EVIDENCE_ERASED. So if we
        // erased the evidence and added some again, this corresponds to an
        // EVIDENCE_MODIFIED
        _evidence_changes_[id] = EvidenceChangeType::EVIDENCE_MODIFIED;
      }
    }
  }


  /// fired when an evidence is removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onEvidenceErased_(const NodeId id,
                                                                     bool         isHardEvidence) {
    // if we delete a hard evidence, this modifies the undigraph over which
    // the join tree is created.
    if (isHardEvidence)
      _is_new_jt_needed_ = true;
    else {
      try {
        _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_ERASED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
        // been added and is now erased, this is similar to not having created
        // it. If the evidence was only modified, it already existed in the
        // last inference and we should now indicate that it has been removed.
        if (_evidence_changes_[id] == EvidenceChangeType::EVIDENCE_ADDED)
          _evidence_changes_.erase(id);
        else
          _evidence_changes_[id] = EvidenceChangeType::EVIDENCE_ERASED;
      }
    }
  }


  /// fired when all the evidence are erased
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::onAllEvidenceErased_(bool has_hard_evidence) {
    if (has_hard_evidence || !this->hardEvidenceNodes().empty())
      _is_new_jt_needed_ = true;
    else {
      for (const auto node: this->softEvidenceNodes()) {
        try {
          _evidence_changes_.insert(node, EvidenceChangeType::EVIDENCE_ERASED);
        } catch (DuplicateElement&) {
          // here, the evidence change already existed and it is necessarily an
          // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
          // been added and is now erased, this is similar to not having created
          // it. If the evidence was only modified, it already existed in the
          // last inference and we should now indicate that it has been removed.
          if (_evidence_changes_[node] == EvidenceChangeType::EVIDENCE_ADDED)
            _evidence_changes_.erase(node);
          else
            _evidence_changes_[node] = EvidenceChangeType::EVIDENCE_ERASED;
        }
      }
    }
  }


  /// fired when an evidence is changed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onEvidenceChanged_(const NodeId id,
                                                                      bool hasChangedSoftHard) {
    if (hasChangedSoftHard)
      _is_new_jt_needed_ = true;
    else {
      try {
        _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_MODIFIED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED. So we should keep this state to indicate that this
        // evidence is new w.r.t. the last inference
      }
    }
  }


  /// fired after a new target is inserted
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId id) {}


  /// fired before a target is removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onMarginalTargetErased_(const NodeId id) {}


  /// fired after a new set target is inserted
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onJointTargetAdded_(const NodeSet& set) {}


  /// fired before a set target is removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onJointTargetErased_(const NodeSet& set) {}


  /// fired after all the nodes of the BN are added as single targets
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}


  /// fired before all the single targets are removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onAllMarginalTargetsErased_() {}

  /// fired after a new Bayes net has been assigned to the engine
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onModelChanged_(const GraphicalModel* bn) {}

  /// fired before all the joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onAllJointTargetsErased_() {}


  /// fired before all the single and joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::onAllTargetsErased_() {}

  // check whether a new junction tree is really needed for the next inference
  template < typename GUM_SCALAR >
  bool ShaferShenoyInference< GUM_SCALAR >::_isNewJTNeeded_() const {
    // if we do not have a JT or if _is_new_jt_needed_ is set to true, then
    // we know that we need to create a new join tree
    if ((_JT_ == nullptr) || _is_new_jt_needed_) return true;

    // if some targets do not belong to the join tree and, consequently,
    // to the undigraph that was used to construct the join tree, then we need
    // to create a new JT. This situation may occur if we constructed the
    // join tree after pruning irrelevant/barren nodes from the BN.
    // However, note that the nodes that received hard evidence do not belong to
    // the graph and, therefore, should not be taken into account
    const auto& hard_ev_nodes = this->hardEvidenceNodes();
    for (const auto node: this->targets()) {
      if (!_graph_.exists(node) && !hard_ev_nodes.exists(node)) return true;
    }
    for (const auto& joint_target: this->jointTargets()) {
      // here, we need to check that at least one clique contains all the
      // nodes of the joint target.
      bool containing_clique_found = false;
      for (const auto node: joint_target) {
        bool found = true;
        try {
          const NodeSet& clique = _JT_->clique(_node_to_clique_[node]);
          for (const auto xnode: joint_target) {
            if (!clique.contains(xnode) && !hard_ev_nodes.exists(xnode)) {
              found = false;
              break;
            }
          }
        } catch (NotFound&) { found = false; }

        if (found) {
          containing_clique_found = true;
          break;
        }
      }

      if (!containing_clique_found) return true;
    }

    // if some new evidence have been added on nodes that do not belong
    // to _graph_, then we potentially have to reconstruct the join tree
    for (const auto& change: _evidence_changes_) {
      if ((change.second == EvidenceChangeType::EVIDENCE_ADDED) && !_graph_.exists(change.first))
        return true;
    }

    // here, the current JT is exactly what we need for the next inference
    return false;
  }

  /// create a new junction tree as well as its related data structures
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::_createNewJT_() {
    // to create the JT, we first create the moral graph of the BN in the
    // following way in order to take into account the barren nodes and the
    // nodes that received evidence:
    // 1/ we create an undirected graph containing only the nodes and no edge
    // 2/ if we take into account barren nodes, remove them from the graph
    // 3/ add edges so that each node and its parents in the BN form a clique
    // 4/ add edges so that set targets are cliques of the moral graph
    // 5/ remove the nodes that received hard evidence (by step 3/, their
    //    parents are linked by edges, which is necessary for inference)
    //
    // At the end of step 5/, we have our moral graph and we can triangulate it
    // to get the new junction tree

    // 1/ create an undirected graph containing only the nodes and no edge
    const auto& bn = this->BN();
    _graph_.clear();
    for (auto node: bn.dag())
      _graph_.addNodeWithId(node);

    // 2/ if we wish to exploit barren nodes, we shall remove them from the
    // BN. To do so, we identify all the nodes that are not targets and have
    // received no evidence and such that their descendants are neither
    // targets nor evidence nodes. Such nodes can be safely discarded from
    // the BN without altering the inference output
    if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
      // identify the barren nodes
      NodeSet target_nodes = this->targets();
      for (const auto& nodeset: this->jointTargets()) {
        target_nodes += nodeset;
      }

      // check that all the nodes are not targets, otherwise, there is no
      // barren node
      if (target_nodes.size() != bn.size()) {
        BarrenNodesFinder finder(&(bn.dag()));
        finder.setTargets(&target_nodes);

        NodeSet evidence_nodes;
        for (const auto& pair: this->evidence()) {
          evidence_nodes.insert(pair.first);
        }
        finder.setEvidence(&evidence_nodes);

        NodeSet barren_nodes = finder.barrenNodes();

        // remove the barren nodes from the moral graph
        for (const auto node: barren_nodes) {
          _graph_.eraseNode(node);
        }
      }
    }

    // 3/ add edges so that each node and its parents in the BN form a clique
    for (const auto node: _graph_) {
      const NodeSet& parents = bn.parents(node);
      for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
        _graph_.addEdge(*iter1, node);
        auto iter2 = iter1;
        for (++iter2; iter2 != parents.cend(); ++iter2) {
          _graph_.addEdge(*iter1, *iter2);
        }
      }
    }

    // 4/ if there exist some joint targets, we shall add new edges into the
    // moral graph in order to ensure that there exists a clique containing
    // each joint
    for (const auto& nodeset: this->jointTargets()) {
      for (auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
        auto iter2 = iter1;
        for (++iter2; iter2 != nodeset.cend(); ++iter2) {
          _graph_.addEdge(*iter1, *iter2);
        }
      }
    }

    // 5/ remove all the nodes that received hard evidence
    _hard_ev_nodes_ = this->hardEvidenceNodes();
    for (const auto node: _hard_ev_nodes_) {
      _graph_.eraseNode(node);
    }


    // now, we can compute the new junction tree. To speed-up computations
    // (essentially, those of a distribution phase), we construct from this
    // junction tree a binary join tree
    if (_JT_ != nullptr) delete _JT_;
    if (_junctionTree_ != nullptr) delete _junctionTree_;

    _triangulation_->setGraph(&_graph_, &(this->domainSizes()));
    const JunctionTree& triang_jt = _triangulation_->junctionTree();
    if (_use_binary_join_tree_) {
      BinaryJoinTreeConverterDefault bjt_converter;
      NodeSet                        emptyset;
      _JT_ = new CliqueGraph(bjt_converter.convert(triang_jt, this->domainSizes(), emptyset));
    } else {
      _JT_ = new CliqueGraph(triang_jt);
    }
    _junctionTree_ = new CliqueGraph(triang_jt);


    // indicate, for each node of the moral graph, a clique in _JT_ that can
    // contain its conditional probability table
    _node_to_clique_.clear();
    const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
    NodeProperty< int >          elim_order(Size(JT_elim_order.size()));
    for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
      elim_order.insert(JT_elim_order[i], (int)i);
    const DAG& dag = bn.dag();
    for (const auto node: _graph_) {
      // get the variables in the potential of node (and its parents)
      NodeId first_eliminated_node = node;
      int    elim_number           = elim_order[first_eliminated_node];

      for (const auto parent: dag.parents(node)) {
        if (_graph_.existsNode(parent) && (elim_order[parent] < elim_number)) {
          elim_number           = elim_order[parent];
          first_eliminated_node = parent;
        }
      }

      // first_eliminated_node contains the first var (node or one of its
      // parents) eliminated => the clique created during its elimination
      // contains node and all of its parents => it can contain the potential
      // assigned to the node in the BN
      _node_to_clique_.insert(node,
                              _triangulation_->createdJunctionTreeClique(first_eliminated_node));
    }

    // do the same for the nodes that received evidence. Here, we only store
    // the nodes for which at least one parent belongs to _graph_ (otherwise
    // their CPT is just a constant real number).
    for (const auto node: _hard_ev_nodes_) {
      // get the set of parents of the node that belong to _graph_
      NodeSet pars(dag.parents(node).size());
      for (const auto par: dag.parents(node))
        if (_graph_.exists(par)) pars.insert(par);

      if (!pars.empty()) {
        NodeId first_eliminated_node = *(pars.begin());
        int    elim_number           = elim_order[first_eliminated_node];

        for (const auto parent: pars) {
          if (elim_order[parent] < elim_number) {
            elim_number           = elim_order[parent];
            first_eliminated_node = parent;
          }
        }

        // first_eliminated_node contains the first var (node or one of its
        // parents) eliminated => the clique created during its elimination
        // contains node and all of its parents => it can contain the potential
        // assigned to the node in the BN
        _node_to_clique_.insert(node,
                                _triangulation_->createdJunctionTreeClique(first_eliminated_node));
      }
    }

    // indicate for each joint_target a clique that contains it
    _joint_target_to_clique_.clear();
    for (const auto& set: this->jointTargets()) {
      // remove from set all the nodes that received hard evidence (since they
      // do not belong to the join tree)
      NodeSet nodeset = set;
      for (const auto node: _hard_ev_nodes_)
        if (nodeset.contains(node)) nodeset.erase(node);

      if (!nodeset.empty()) {
        // the clique we are looking for is the one that was created when
        // the first element of nodeset was eliminated
        NodeId first_eliminated_node = *(nodeset.begin());
        int    elim_number           = elim_order[first_eliminated_node];
        for (const auto node: nodeset) {
          if (elim_order[node] < elim_number) {
            elim_number           = elim_order[node];
            first_eliminated_node = node;
          }
        }

        _joint_target_to_clique_.insert(
           set,
           _triangulation_->createdJunctionTreeClique(first_eliminated_node));
      }
    }


    // compute the roots of _JT_'s connected components
    _computeJoinTreeRoots_();

    // remove all the Shafer-Shenoy potentials stored into the cliques
    for (const auto& xpot: _clique_ss_potential_) {
      if (_clique_potentials_[xpot.first].size() > 1) delete xpot.second;
    }

    // create empty potential lists into the cliques of the joint tree as well
    // as empty lists of evidence
    _PotentialSet_ empty_set;
    _clique_potentials_.clear();
    _node_to_soft_evidence_.clear();
    for (const auto node: *_JT_) {
      _clique_potentials_.insert(node, empty_set);
    }

    // remove all the potentials created during the last inference
    for (auto& potlist: _created_potentials_)
      for (auto pot: potlist.second)
        delete pot;
    _created_potentials_.clear();

    // remove all the potentials created to take into account hard evidence
    // during the last inference
    for (auto pot_pair: _hard_ev_projected_CPTs_)
      delete pot_pair.second;
    _hard_ev_projected_CPTs_.clear();

    // remove all the constants created due to projections of CPTs that were
    // defined over only hard evidence nodes
    _constants_.clear();

    // create empty lists of potentials for the messages and indicate that no
    // message has been computed yet
    _separator_potentials_.clear();
    _messages_computed_.clear();
    for (const auto& edge: _JT_->edges()) {
      const Arc arc1(edge.first(), edge.second());
      _separator_potentials_.insert(arc1, empty_set);
      _messages_computed_.insert(arc1, false);
      const Arc arc2(Arc(edge.second(), edge.first()));
      _separator_potentials_.insert(arc2, empty_set);
      _messages_computed_.insert(arc2, false);
    }

    // remove all the posteriors computed so far
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    _target_posteriors_.clear();
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;
    _joint_target_posteriors_.clear();


    // put all the CPTs of the Bayes net nodes into the cliques
    // here, beware: all the potentials that are defined over some nodes
    // including hard evidence must be projected so that these nodes are
    // removed from the potential
    const auto& evidence      = this->evidence();
    const auto& hard_evidence = this->hardEvidence();
    for (const auto node: dag) {
      if (_graph_.exists(node) || _hard_ev_nodes_.contains(node)) {
        const Potential< GUM_SCALAR >& cpt = bn.cpt(node);

        // get the list of nodes with hard evidence in cpt
        NodeSet     hard_nodes;
        const auto& variables = cpt.variablesSequence();
        for (const auto var: variables) {
          NodeId xnode = bn.nodeId(*var);
          if (_hard_ev_nodes_.contains(xnode)) hard_nodes.insert(xnode);
        }

        // if hard_nodes contains hard evidence nodes, perform a projection
        // and insert the result into the appropriate clique, else insert
        // directly cpt into the clique
        if (hard_nodes.empty()) {
          _clique_potentials_[_node_to_clique_[node]].insert(&cpt);
        } else {
          // marginalize out the hard evidence nodes: if the cpt is defined
          // only over nodes that received hard evidence, do not consider it
          // as a potential anymore but as a constant
          if (hard_nodes.size() == variables.size()) {
            Instantiation inst;
            const auto&   vars = cpt.variablesSequence();
            for (auto var: vars)
              inst << *var;
            for (Size i = 0; i < hard_nodes.size(); ++i) {
              inst.chgVal(variables[i], hard_evidence[bn.nodeId(*(variables[i]))]);
            }
            _constants_.insert(node, cpt[inst]);
          } else {
            // perform the projection with a combine and project instance
            Set< const DiscreteVariable* > hard_variables;
            _PotentialSet_                 marg_cpt_set{&cpt};
            for (const auto xnode: hard_nodes) {
              marg_cpt_set.insert(evidence[xnode]);
              hard_variables.insert(&(bn.variable(xnode)));
            }

            // perform the combination of those potentials and their projection
            MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
               _combination_op_,
               SSNewprojPotential);
            _PotentialSet_ new_cpt_list
               = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

            // there should be only one potential in new_cpt_list
            if (new_cpt_list.size() != 1) {
              // remove the CPT created to avoid memory leaks
              for (auto pot: new_cpt_list) {
                if (!marg_cpt_set.contains(pot)) delete pot;
              }
              GUM_ERROR(FatalError,
                        "the projection of a potential containing "
                           << "hard evidence is empty!");
            }
            const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
            _clique_potentials_[_node_to_clique_[node]].insert(projected_cpt);
            _hard_ev_projected_CPTs_.insert(node, projected_cpt);
          }
        }
      }
    }

    // we shall now add all the potentials of the soft evidence
    for (const auto node: this->softEvidenceNodes()) {
      _node_to_soft_evidence_.insert(node, evidence[node]);
      _clique_potentials_[_node_to_clique_[node]].insert(evidence[node]);
    }

    // now, in _clique_potentials_, for each clique, we have the list of
    // potentials that must be combined in order to produce the Shafer-Shenoy's
    // potential stored into the clique. So, perform this combination and
    // store the result in _clique_ss_potential_
    _clique_ss_potential_.clear();
    MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
    for (const auto& xpotset: _clique_potentials_) {
      const auto& potset = xpotset.second;
      if (potset.size() > 0) {
        // here, there will be an entry in _clique_ss_potential_
        // If there is only one element in potset, this element shall be
        // stored into _clique_ss_potential_, else all the elements of potset
        // shall be combined and their result shall be stored
        if (potset.size() == 1) {
          _clique_ss_potential_.insert(xpotset.first, *(potset.cbegin()));
        } else {
          auto joint = fast_combination.combine(potset);
          _clique_ss_potential_.insert(xpotset.first, joint);
        }
      }
    }

    // indicate that the data structures are up to date.
    _evidence_changes_.clear();
    _is_new_jt_needed_ = false;
  }

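  // Sketch of how a user can influence the join tree built by _createNewJT_
  // (assumption: "inference" is an existing engine): the triangulation
  // strategy can be replaced before the next inference with any Triangulation
  // subclass, e.g. another DefaultTriangulation instance:
  //
  //   gum::DefaultTriangulation my_triangulation;
  //   inference.setTriangulation(my_triangulation);
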
  /// prepare the inference structures w.r.t. new targets, soft/hard evidence
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::updateOutdatedStructure_() {
    // check if a new JT is really needed. If so, create it
    if (_isNewJTNeeded_()) {
      _createNewJT_();
    } else {
      // here, we can answer the next queries without reconstructing all the
      // junction tree. All we need to do is to indicate that we should
      // update the potentials and messages for these queries
      updateOutdatedPotentials_();
    }
  }


  /// invalidate all the messages sent from a given clique
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::_diffuseMessageInvalidations_(
     NodeId   from_id,
     NodeId   to_id,
     NodeSet& invalidated_cliques) {
    // invalidate the current clique
    invalidated_cliques.insert(to_id);

    // invalidate the current arc
    const Arc arc(from_id, to_id);
    bool&     message_computed = _messages_computed_[arc];
    if (message_computed) {
      message_computed = false;
      _separator_potentials_[arc].clear();
      if (_created_potentials_.exists(arc)) {
        auto& arc_created_potentials = _created_potentials_[arc];
        for (auto pot: arc_created_potentials)
          delete pot;
        arc_created_potentials.clear();
      }

      // go on with the diffusion
      for (const auto node_id: _JT_->neighbours(to_id)) {
        if (node_id != from_id) _diffuseMessageInvalidations_(to_id, node_id, invalidated_cliques);
      }
    }
  }

  /// update the potentials stored in the cliques and invalidate outdated
  /// messages
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::updateOutdatedPotentials_() {
    // for each clique, indicate whether the potential stored into
    // _clique_ss_potentials_[clique] is the result of a combination. In this
    // case, it has been allocated by the combination and will need to be
    // deallocated if its clique has been invalidated
    NodeProperty< bool > ss_potential_to_deallocate(_clique_potentials_.size());
    for (auto pot_iter = _clique_potentials_.cbegin(); pot_iter != _clique_potentials_.cend();
         ++pot_iter) {
      ss_potential_to_deallocate.insert(pot_iter.key(), (pot_iter.val().size() > 1));
    }

    // compute the set of CPTs that were projected due to hard evidence and
    // whose hard evidence have changed, so that they need a new projection.
    // By the way, remove these CPTs since they are no more needed.
    // Here only the values of the hard evidence can have changed (else a
    // fully new join tree would have been computed).
    // Note also that we know that the CPTs still contain some variable(s) after
    // the projection (else they should be constants)
    NodeSet hard_nodes_changed(_hard_ev_nodes_.size());
    for (const auto node: _hard_ev_nodes_)
      if (_evidence_changes_.exists(node)) hard_nodes_changed.insert(node);

    NodeSet     nodes_with_projected_CPTs_changed;
    const auto& bn = this->BN();
    for (auto pot_iter = _hard_ev_projected_CPTs_.beginSafe();
         pot_iter != _hard_ev_projected_CPTs_.endSafe();
         ++pot_iter) {
      for (const auto var: bn.cpt(pot_iter.key()).variablesSequence()) {
        if (hard_nodes_changed.contains(bn.nodeId(*var))) {
          nodes_with_projected_CPTs_changed.insert(pot_iter.key());
          delete pot_iter.val();
          _clique_potentials_[_node_to_clique_[pot_iter.key()]].erase(pot_iter.val());
          _hard_ev_projected_CPTs_.erase(pot_iter);
          break;
        }
      }
    }


    // invalidate all the messages that are no more correct: start from each of
    // the nodes whose soft evidence has changed and perform a diffusion from
    // the clique into which the soft evidence has been entered, indicating that
    // the messages spreading from this clique are now invalid. At the same time,
    // if there were potentials created on the arcs over which the messages were
    // sent, remove them from memory. For all the cliques that received some
    // projected CPT that should now be changed, do the same.
    NodeSet invalidated_cliques(_JT_->size());
    for (const auto& pair: _evidence_changes_) {
      if (_node_to_clique_.exists(pair.first)) {
        const auto clique = _node_to_clique_[pair.first];
        invalidated_cliques.insert(clique);
        for (const auto neighbor: _JT_->neighbours(clique)) {
          _diffuseMessageInvalidations_(clique, neighbor, invalidated_cliques);
        }
      }
    }

    // now, add to the set of invalidated cliques those that contain projected
    // CPTs that were changed.
    for (auto node: nodes_with_projected_CPTs_changed) {
      const auto clique = _node_to_clique_[node];
      invalidated_cliques.insert(clique);
      for (const auto neighbor: _JT_->neighbours(clique)) {
        _diffuseMessageInvalidations_(clique, neighbor, invalidated_cliques);
      }
    }

    // now that we know the cliques whose set of potentials have been changed,
    // we can discard their corresponding Shafer-Shenoy potential
    for (const auto clique: invalidated_cliques) {
      if (_clique_ss_potential_.exists(clique) && ss_potential_to_deallocate[clique]) {
        delete _clique_ss_potential_[clique];
      }
    }


    // now we shall remove all the posteriors that belong to the
    // invalidated cliques. First, cope only with the nodes that did not
    // receive hard evidence since the other nodes do not belong to the
    // join tree
    for (auto iter = _target_posteriors_.beginSafe(); iter != _target_posteriors_.endSafe();
         ++iter) {
      if (_graph_.exists(iter.key())
          && (invalidated_cliques.exists(_node_to_clique_[iter.key()]))) {
        delete iter.val();
        _target_posteriors_.erase(iter);
      }
    }

    // now cope with the nodes that received hard evidence
    for (auto iter = _target_posteriors_.beginSafe(); iter != _target_posteriors_.endSafe();
         ++iter) {
      if (hard_nodes_changed.contains(iter.key())) {
        delete iter.val();
        _target_posteriors_.erase(iter);
      }
    }

    // finally, cope with joint targets
    for (auto iter = _joint_target_posteriors_.beginSafe();
         iter != _joint_target_posteriors_.endSafe();
         ++iter) {
      if (invalidated_cliques.exists(_joint_target_to_clique_[iter.key()])) {
        delete iter.val();
        _joint_target_posteriors_.erase(iter);
      }
    }


    // remove all the evidence that were entered into _node_to_soft_evidence_
    // and _clique_potentials_ and add the new soft ones
    for (auto& pot_pair: _node_to_soft_evidence_) {
      _clique_potentials_[_node_to_clique_[pot_pair.first]].erase(pot_pair.second);
    }
    _node_to_soft_evidence_.clear();

    const auto& evidence = this->evidence();
    for (const auto node: this->softEvidenceNodes()) {
      _node_to_soft_evidence_.insert(node, evidence[node]);
      _clique_potentials_[_node_to_clique_[node]].insert(evidence[node]);
    }


    // Now add the projections of the CPTs due to newly changed hard evidence:
    // if we are performing updateOutdatedPotentials_, this means that the
    // set of nodes that received hard evidence has not been changed, only
    // their instantiations can have been changed. So, if there is an entry
    // for node in _constants_, there will still be such an entry after
    // performing the new projections. Idem for _hard_ev_projected_CPTs_
    for (const auto node: nodes_with_projected_CPTs_changed) {
      // perform the projection with a combine and project instance
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node);
      const auto&                    variables = cpt.variablesSequence();
      Set< const DiscreteVariable* > hard_variables;
      _PotentialSet_                 marg_cpt_set{&cpt};
      for (const auto var: variables) {
        NodeId xnode = bn.nodeId(*var);
        if (_hard_ev_nodes_.exists(xnode)) {
          marg_cpt_set.insert(evidence[xnode]);
          hard_variables.insert(var);
        }
      }

      // perform the combination of those potentials and their projection
      MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
         _combination_op_,
         SSNewprojPotential);
      _PotentialSet_ new_cpt_list
         = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

      // there should be only one potential in new_cpt_list
      if (new_cpt_list.size() != 1) {
        // remove the CPT created to avoid memory leaks
        for (auto pot: new_cpt_list) {
          if (!marg_cpt_set.contains(pot)) delete pot;
        }
        GUM_ERROR(FatalError,
                  "the projection of a potential containing "
                     << "hard evidence is empty!");
      }
      const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
      _clique_potentials_[_node_to_clique_[node]].insert(projected_cpt);
      _hard_ev_projected_CPTs_.insert(node, projected_cpt);
    }

    // here, the lists of potentials stored in the invalidated cliques have
    // been updated. So, now, we can combine them to produce the Shafer-Shenoy
    // potential stored into the clique
    MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
    for (const auto clique: invalidated_cliques) {
      const auto& potset = _clique_potentials_[clique];

      if (potset.size() > 0) {
        // here, there will be an entry in _clique_ss_potential_
        // If there is only one element in potset, this element shall be
        // stored into _clique_ss_potential_, else all the elements of potset
        // shall be combined and their result shall be stored
        if (potset.size() == 1) {
          _clique_ss_potential_[clique] = *(potset.cbegin());
        } else {
          auto joint                    = fast_combination.combine(potset);
          _clique_ss_potential_[clique] = joint;
        }
      }
    }


    // update the constants
    const auto& hard_evidence = this->hardEvidence();
    for (auto& node_cst: _constants_) {
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node_cst.first);
      const auto&                    variables = cpt.variablesSequence();
      Instantiation                  inst;
      for (const auto var: variables)
        inst << *var;
      for (const auto var: variables) {
        inst.chgVal(var, hard_evidence[bn.nodeId(*var)]);
      }
      node_cst.second = cpt[inst];
    }

    // indicate that all changes have been performed
    _evidence_changes_.clear();
  }

  /// compute a root for each connected component of _JT_
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::_computeJoinTreeRoots_() {
    // get the set of cliques in which we can find the targets and joint_targets
    NodeSet clique_targets;
    for (const auto node: this->targets()) {
      try {
        clique_targets.insert(_node_to_clique_[node]);
      } catch (Exception&) {}
    }
    for (const auto& set: this->jointTargets()) {
      try {
        clique_targets.insert(_joint_target_to_clique_[set]);
      } catch (Exception&) {}
    }

    // put in a vector these cliques and their size
    std::vector< std::pair< NodeId, Size > > possible_roots(clique_targets.size());
    const auto&                              bn = this->BN();
    std::size_t                              i  = 0;
    for (const auto clique_id: clique_targets) {
      const auto& clique   = _JT_->clique(clique_id);
      Size        dom_size = 1;
      for (const auto node: clique) {
        dom_size *= bn.variable(node).domainSize();
      }
      possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
      ++i;
    }

    // sort the cliques by increasing domain size
    std::sort(possible_roots.begin(),
              possible_roots.end(),
              [](const std::pair< NodeId, Size >& a, const std::pair< NodeId, Size >& b) -> bool {
                return a.second < b.second;
              });

    // pick up the clique with the smallest size in each connected component
    NodeProperty< bool >                  marked = _JT_->nodesProperty(false);
    std::function< void(NodeId, NodeId) > diffuse_marks
       = [&marked, &diffuse_marks, this](NodeId node, NodeId from) {
           if (!marked[node]) {
             marked[node] = true;
             for (const auto neigh: _JT_->neighbours(node))
               if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
           }
         };
    _roots_.clear();
    for (const auto xclique: possible_roots) {
      NodeId clique = xclique.first;
      if (!marked[clique]) {
        _roots_.insert(clique);
        diffuse_marks(clique, clique);
      }
    }
  }

  // performs the collect phase of Lazy Propagation
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::_collectMessage_(NodeId id, NodeId from) {
    for (const auto other: _JT_->neighbours(id)) {
      if ((other != from) && !_messages_computed_[Arc(other, id)]) _collectMessage_(other, id);
    }

    if ((id != from) && !_messages_computed_[Arc(id, from)]) { _produceMessage_(id, from); }
  }

  // remove barren variables
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* >
     ShaferShenoyInference< GUM_SCALAR >::_removeBarrenVariables_(
        _PotentialSet_&                 pot_list,
        Set< const DiscreteVariable* >& del_vars) {
    // remove from del_vars the variables that received some evidence:
    // only those that did not receive evidence can be barren variables
    Set< const DiscreteVariable* > the_del_vars = del_vars;
    for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe(); ++iter) {
      NodeId id = this->BN().nodeId(**iter);
      if (this->hardEvidenceNodes().exists(id) || this->softEvidenceNodes().exists(id)) {
        the_del_vars.erase(iter);
      }
    }

    // assign to each random variable the set of potentials that contain it
    HashTable< const DiscreteVariable*, _PotentialSet_ > var2pots;
    _PotentialSet_                                       empty_pot_set;
    for (const auto pot: pot_list) {
      const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
      for (const auto var: vars) {
        if (the_del_vars.exists(var)) {
          if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
          var2pots[var].insert(pot);
        }
      }
    }

    // each variable with only one potential is a barren variable
    // assign to each potential with barren nodes its set of barren variables
    HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > > pot2barren_var;
    Set< const DiscreteVariable* >                                              empty_var_set;
    for (auto elt: var2pots) {
      if (elt.second.size() == 1) {   // here we have a barren variable
        const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
        if (!pot2barren_var.exists(pot)) { pot2barren_var.insert(pot, empty_var_set); }
        pot2barren_var[pot].insert(elt.first);   // insert the barren variable
      }
    }

    // for each potential with barren variables, marginalize them.
    // if the potential has only barren variables, simply remove it from the
    // set of potentials, else just project the potential
    MultiDimProjection< GUM_SCALAR, Potential > projector(SSNewprojPotential);
    _PotentialSet_                              projected_pots;
    for (auto elt: pot2barren_var) {
      // remove the current potential from pot_list as, anyway, we will change
      // it
      const Potential< GUM_SCALAR >* pot = elt.first;
      pot_list.erase(pot);

      // check whether we need to add a projected new potential or not (i.e.,
      // whether there exist non-barren variables or not)
      if (pot->variablesSequence().size() != elt.second.size()) {
        auto new_pot = projector.project(*pot, elt.second);
        pot_list.insert(new_pot);
        projected_pots.insert(new_pot);
      }
    }

    return projected_pots;
  }

  // remove variables del_vars from the list of potentials pot_list
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* > ShaferShenoyInference< GUM_SCALAR >::_marginalizeOut_(
     Set< const Potential< GUM_SCALAR >* > pot_list,
     Set< const DiscreteVariable* >&       del_vars,
     Set< const DiscreteVariable* >&       kept_vars) {
    // remove the potentials corresponding to barren variables if we want
    // to exploit barren nodes
    _PotentialSet_ barren_projected_potentials;
    if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
      barren_projected_potentials = _removeBarrenVariables_(pot_list, del_vars);
    }

    // create a combine and project operator that will perform the
    // marginalization
    MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(_combination_op_,
                                                                                  _projection_op_);
    _PotentialSet_ new_pot_list = combine_and_project.combineAndProject(pot_list, del_vars);

    // remove all the potentials that were created due to projections of
    // barren nodes and that are not part of the new_pot_list: these
    // potentials were just temporary potentials
    for (auto iter = barren_projected_potentials.beginSafe();
         iter != barren_projected_potentials.endSafe();
         ++iter) {
      if (!new_pot_list.exists(*iter)) delete *iter;
    }

    // remove all the potentials that have no dimension
    for (auto iter_pot = new_pot_list.beginSafe(); iter_pot != new_pot_list.endSafe(); ++iter_pot) {
      if ((*iter_pot)->variablesSequence().size() == 0) {
        // as we have already marginalized out variables that received evidence,
        // it may be the case that, after combining and projecting, some
        // potentials might be empty. In this case, we shall keep their
        // constant and remove them from memory
        // # TODO: keep the constants!
        delete *iter_pot;
        new_pot_list.erase(iter_pot);
      }
    }

    return new_pot_list;
  }

  // creates the message sent by clique from_id to clique to_id
  template < typename GUM_SCALAR >
  void ShaferShenoyInference< GUM_SCALAR >::_produceMessage_(NodeId from_id, NodeId to_id) {
    // get the potentials of the clique.
    _PotentialSet_ pot_list;
    if (_clique_ss_potential_.exists(from_id)) pot_list.insert(_clique_ss_potential_[from_id]);

    // add the messages sent by adjacent nodes to from_id
    for (const auto other_id: _JT_->neighbours(from_id))
      if (other_id != to_id) pot_list += _separator_potentials_[Arc(other_id, from_id)];

    // get the set of variables that need be removed from the potentials
    const NodeSet&                 from_clique = _JT_->clique(from_id);
    const NodeSet&                 separator   = _JT_->separator(from_id, to_id);
    Set< const DiscreteVariable* > del_vars(from_clique.size());
    Set< const DiscreteVariable* > kept_vars(separator.size());
    const auto&                    bn = this->BN();

    for (const auto node: from_clique) {
      if (!separator.contains(node)) {
        del_vars.insert(&(bn.variable(node)));
      } else {
        kept_vars.insert(&(bn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    _PotentialSet_ new_pot_list = _marginalizeOut_(pot_list, del_vars, kept_vars);

    /*
      for the moment, remove this test: due to some optimizations, some
      potentials might have all their cells greater than 1.

    // remove all the potentials that are equal to ones (as probability
    // matrix multiplications are tensorial, such potentials are useless)
    for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
         ++iter) {
      const auto pot = *iter;
      if (pot->variablesSequence().size() == 1) {
        bool is_all_ones = true;
        for (Instantiation inst(*pot); !inst.end(); ++inst) {
          if ((*pot)[inst] < _one_minus_epsilon_) {
            is_all_ones = false;
            break;
          }
        }
        if (is_all_ones) {
          if (!pot_list.exists(pot)) delete pot;
          new_pot_list.erase(iter);
          continue;
        }
      }
    }
    */

    // if there are still potentials in new_pot_list, combine them to
    // produce the message on the separator
    const Arc arc(from_id, to_id);
    if (!new_pot_list.empty()) {
      if (new_pot_list.size() == 1) {   // there is only one potential
        // in new_pot_list, so this corresponds to our message on
        // the separator
        auto pot                    = *(new_pot_list.begin());
        _separator_potentials_[arc] = std::move(new_pot_list);
        if (!pot_list.exists(pot)) {
          if (!_created_potentials_.exists(arc)) _created_potentials_.insert(arc, _PotentialSet_());
          _created_potentials_[arc].insert(pot);
        }
      } else {
        // create the message in the separator
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
        auto joint = fast_combination.combine(new_pot_list);
        _separator_potentials_[arc].insert(joint);
        if (!_created_potentials_.exists(arc)) _created_potentials_.insert(arc, _PotentialSet_());
        _created_potentials_[arc].insert(joint);

        // remove the temporary messages created in new_pot_list
        for (const auto pot: new_pot_list) {
          if (!pot_list.exists(pot)) { delete pot; }
        }
      }
    }

    _messages_computed_[arc] = true;
  }

  // performs a whole inference
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyInference< GUM_SCALAR >::makeInference_() {
    // collect messages for all single targets
    for (const auto node: this->targets()) {
      // perform only collects in the join tree for nodes that have
      // not received hard evidence (those that received hard evidence were
      // not included into the join tree for speed-up reasons)
      if (_graph_.exists(node)) {
        _collectMessage_(_node_to_clique_[node], _node_to_clique_[node]);
      }
    }

    // collect messages for all set targets
    // by parsing _joint_target_to_clique_, we ensure that the cliques that
    // are referenced belong to the join tree (even if some of the nodes in
    // their associated joint_target do not belong to _graph_)
    for (const auto set: _joint_target_to_clique_)
      _collectMessage_(set.second, set.second);
  }

  /// returns a fresh potential equal to P(1st arg, evidence)
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     ShaferShenoyInference< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId id) {
    const auto& bn = this->BN();

    // hard evidence do not belong to the join tree
    // # TODO: check for sets of inconsistent hard evidence
    if (this->hardEvidenceNodes().contains(id)) {
      return new Potential< GUM_SCALAR >(*(this->evidence()[id]));
    }

    // if we still need to perform some inference task, do it (this should
    // already have been done by makeInference_)
    NodeId clique_of_id = _node_to_clique_[id];
    _collectMessage_(clique_of_id, clique_of_id);

    // now we just need to create the product of the potentials of the clique
    // containing id with the messages received by this clique and
    // marginalize out all variables except id
    _PotentialSet_ pot_list;
    if (_clique_ss_potential_.exists(clique_of_id))
      pot_list.insert(_clique_ss_potential_[clique_of_id]);

    // add the messages sent by adjacent nodes to targetClique
    for (const auto other: _JT_->neighbours(clique_of_id))
      pot_list += _separator_potentials_[Arc(other, clique_of_id)];

    // get the set of variables that need be removed from the potentials
    const NodeSet&                 nodes = _JT_->clique(clique_of_id);
    Set< const DiscreteVariable* > kept_vars{&(bn.variable(id))};
    Set< const DiscreteVariable* > del_vars(nodes.size());
    for (const auto node: nodes) {
      if (node != id) del_vars.insert(&(bn.variable(node)));
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    _PotentialSet_           new_pot_list = _marginalizeOut_(pot_list, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint        = nullptr;

    if (new_pot_list.size() == 1) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if pot already existed, create a copy, so that we can put it into
      // the _target_posteriors_ property
      if (pot_list.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // removed just after the else block
        new_pot_list.clear();
      }
    } else {
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
      joint = fast_combination.combine(new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (auto pot: new_pot_list)
      if (!pot_list.exists(pot)) delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some hard evidence are not compatible (their joint
    // probability is equal to 0)
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
    }
    return joint;
  }

  /// returns the posterior of a given variable
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >& ShaferShenoyInference< GUM_SCALAR >::posterior_(NodeId id) {
    // check if we have already computed the posterior
    if (_target_posteriors_.exists(id)) { return *(_target_posteriors_[id]); }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(id);
    if (joint->sum() != 1)   // hard test for ReadOnly CPT (as aggregator)
      joint->normalize();
    _target_posteriors_.insert(id, joint);

    return *joint;
  }

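  // Usage sketch (assumptions: node ids idA and idB exist in the network and
  // an engine "inference" has been created): joint posteriors are requested
  // through the JointTargetedInference API, which ultimately relies on the two
  // unnormalizedJointPosterior_ overloads implemented above and below:
  //
  //   inference.addJointTarget(gum::NodeSet{idA, idB});
  //   inference.makeInference();
  //   const auto& joint = inference.jointPosterior(gum::NodeSet{idA, idB});
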
  // returns the marginal a posteriori proba of a given node
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     ShaferShenoyInference< GUM_SCALAR >::unnormalizedJointPosterior_(const NodeSet& set) {
    // hard evidence do not belong to the join tree, so extract the nodes
    // from targets that are not hard evidence
    NodeSet targets = set, hard_ev_nodes;
    for (const auto node: this->hardEvidenceNodes()) {
      if (targets.contains(node)) {
        targets.erase(node);
        hard_ev_nodes.insert(node);
      }
    }

    // if all the nodes have received hard evidence, then compute the
    // joint posterior directly by multiplying the hard evidence potentials
    const auto& evidence = this->evidence();
    if (targets.empty()) {
      _PotentialSet_ pot_list;
      for (const auto node: set) {
        pot_list.insert(evidence[node]);
      }
      if (pot_list.size() == 1) {
        auto pot = new Potential< GUM_SCALAR >(**(pot_list.begin()));
        return pot;
      } else {
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
        return fast_combination.combine(pot_list);
      }
    }


    // if we still need to perform some inference task, do it: so, first,
    // determine the clique on which we should perform collect to compute
    // the unnormalized joint posterior of a set of nodes containing "set"
    NodeId clique_of_set;
    try {
      clique_of_set = _joint_target_to_clique_[set];
    } catch (NotFound&) {
      // here, the precise set of targets does not belong to the set of targets
      // defined by the user. So we will try to find a clique in the junction
      // tree that contains "targets":

      // 1/ we should check that all the nodes belong to the join tree
      for (const auto node: targets) {
        if (!_graph_.exists(node)) { GUM_ERROR(UndefinedElement, node << " is not a target node") }
      }

      // 2/ the clique created by the first eliminated node among target is the
      // one we are looking for
      const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
      NodeProperty< int >          elim_order(Size(JT_elim_order.size()));
      for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
        elim_order.insert(JT_elim_order[i], (int)i);
      NodeId first_eliminated_node = *(targets.begin());
      int    elim_number           = elim_order[first_eliminated_node];
      for (const auto node: targets) {
        if (elim_order[node] < elim_number) {
          elim_number           = elim_order[node];
          first_eliminated_node = node;
        }
      }
      clique_of_set = _triangulation_->createdJunctionTreeClique(first_eliminated_node);

      // 3/ check that clique_of_set contains all the nodes in the target
      const NodeSet& clique_nodes = _JT_->clique(clique_of_set);
      for (const auto node: targets) {
        if (!clique_nodes.contains(node)) {
          GUM_ERROR(UndefinedElement, set << " is not a joint target")
        }
      }

      // add the discovered clique to _joint_target_to_clique_
      _joint_target_to_clique_.insert(set, clique_of_set);
    }

    // now perform a collect on the clique
    _collectMessage_(clique_of_set, clique_of_set);

    // now we just need to create the product of the potentials of the clique
    // containing set with the messages received by this clique and
    // marginalize out all variables except set
    _PotentialSet_ pot_list;
    if (_clique_ss_potential_.exists(clique_of_set))
      pot_list.insert(_clique_ss_potential_[clique_of_set]);

    // add the messages sent by adjacent nodes to targetClique
for
(
const
auto
other
:
_JT_
->
neighbours
(
clique_of_set
))
1460
pot_list
+=
_separator_potentials_
[
Arc
(
other
,
clique_of_set
)];
1461
1462
// get the set of variables that need be removed from the potentials
1463
const
NodeSet
&
nodes
=
_JT_
->
clique
(
clique_of_set
);
1464
Set
<
const
DiscreteVariable
* >
del_vars
(
nodes
.
size
());
1465
Set
<
const
DiscreteVariable
* >
kept_vars
(
targets
.
size
());
1466
const
auto
&
bn
=
this
->
BN
();
1467
for
(
const
auto
node
:
nodes
) {
1468
if
(!
targets
.
contains
(
node
)) {
1469
del_vars
.
insert
(&(
bn
.
variable
(
node
)));
1470
}
else
{
1471
kept_vars
.
insert
(&(
bn
.
variable
(
node
)));
1472
}
1473
}
1474
1475
// pot_list now contains all the potentials to multiply and marginalize
1476
// => combine the messages
1477
_PotentialSet_
new_pot_list
=
_marginalizeOut_
(
pot_list
,
del_vars
,
kept_vars
);
1478
Potential
<
GUM_SCALAR
>*
joint
=
nullptr
;
1479
1480
if
((
new_pot_list
.
size
() == 1) &&
hard_ev_nodes
.
empty
()) {
1481
joint
=
const_cast
<
Potential
<
GUM_SCALAR
>* >(*(
new_pot_list
.
begin
()));
1482
// if pot already existed, create a copy, so that we can put it into
1483
// the _target_posteriors_ property
1484
if
(
pot_list
.
exists
(
joint
)) {
1485
joint
=
new
Potential
<
GUM_SCALAR
>(*
joint
);
1486
}
else
{
1487
// remove the joint from new_pot_list so that it will not be
1488
// removed just after the next else block
1489
new_pot_list
.
clear
();
1490
}
1491
}
else
{
1492
// combine all the potentials in new_pot_list with all the hard evidence
1493
// of the nodes in set
1494
_PotentialSet_
new_new_pot_list
=
new_pot_list
;
1495
for
(
const
auto
node
:
hard_ev_nodes
) {
1496
new_new_pot_list
.
insert
(
evidence
[
node
]);
1497
}
1498
MultiDimCombinationDefault
<
GUM_SCALAR
,
Potential
>
fast_combination
(
_combination_op_
);
1499
joint
=
fast_combination
.
combine
(
new_new_pot_list
);
1500
}
1501
1502
// remove the potentials that were created in new_pot_list
1503
for
(
auto
pot
:
new_pot_list
)
1504
if
(!
pot_list
.
exists
(
pot
))
delete
pot
;
1505
1506
// check that the joint posterior is different from a 0 vector: this would
1507
// indicate that some hard evidence are not compatible
1508
bool
nonzero_found
=
false
;
1509
for
(
Instantiation
inst
(*
joint
); !
inst
.
end
(); ++
inst
) {
1510
if
((*
joint
)[
inst
]) {
1511
nonzero_found
=
true
;
1512
break
;
1513
}
1514
}
1515
if
(!
nonzero_found
) {
1516
// remove joint from memory to avoid memory leaks
1517
delete
joint
;
1518
GUM_ERROR
(
IncompatibleEvidence
,
1519
"some evidence entered into the Bayes "
1520
"net are incompatible (their joint proba = 0)"
);
1521
}
1522
1523
return
joint
;
1524
}
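
  /* In the targets.empty() branch above, every queried node carries hard
     evidence, so the unnormalized joint reduces to the product of the evidence
     potentials. A sketch of that degenerate case with explicit 0/1 evidence
     vectors (names and values are illustrative assumptions):

     @code
     gum::LabelizedVariable u("u", "", 2), v("v", "", 2);
     gum::Potential< double > ev_u, ev_v;
     ev_u.add(u);
     ev_v.add(v);
     ev_u.fillWith({0.0, 1.0});   // hard evidence u = 1
     ev_v.fillWith({1.0, 0.0});   // hard evidence v = 0

     // product of the evidence potentials: a single 1 at (u = 1, v = 0)
     gum::Potential< double > joint_uv = ev_u * ev_v;
     @endcode
  */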

  /// returns the posterior of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     ShaferShenoyInference< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
    // check if we have already computed the posterior
    if (_joint_target_posteriors_.exists(set)) { return *(_joint_target_posteriors_[set]); }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(set);
    joint->normalize();
    _joint_target_posteriors_.insert(set, joint);

    return *joint;
  }
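
  /* Usage sketch of the joint-target API served by jointPosterior_, reusing the
     illustrative network bn (nodes a and b) of the earlier sketch; whether or
     not the set contains hard-evidence nodes, the result is the normalized
     joint posterior:

     @code
     gum::ShaferShenoyInference< double > ie(&bn);
     ie.addEvidence(b, 0);
     ie.addJointTarget(gum::NodeSet{a, b});
     ie.makeInference();
     const auto& pAB = ie.jointPosterior(gum::NodeSet{a, b});   // P(a, b | b = 0)
     @endcode
  */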

  /// returns the posterior of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     ShaferShenoyInference< GUM_SCALAR >::jointPosterior_(const NodeSet& wanted_target,
                                                          const NodeSet& declared_target) {
    // check if we have already computed the posterior of wanted_target
    if (_joint_target_posteriors_.exists(wanted_target))
      return *(_joint_target_posteriors_[wanted_target]);

    // here, we will have to compute the posterior of declared_target and
    // marginalize out all the variables that do not belong to wanted_target

    // check if we have already computed the posterior of declared_target
    if (!_joint_target_posteriors_.exists(declared_target)) { jointPosterior_(declared_target); }

    // marginalize out all the variables that do not belong to wanted_target
    const auto&                    bn = this->BN();
    Set< const DiscreteVariable* > del_vars;
    for (const auto node: declared_target)
      if (!wanted_target.contains(node)) del_vars.insert(&(bn.variable(node)));
    auto pot = new Potential< GUM_SCALAR >(
       _joint_target_posteriors_[declared_target]->margSumOut(del_vars));

    // save the result into the cache
    _joint_target_posteriors_.insert(wanted_target, pot);

    return *pot;
  }
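
  /* The overload above answers a query that is a subset of an already declared
     joint target by summing the extra variables out of the cached posterior. A
     sketch, assuming a third node c added to the illustrative network and
     assuming (as in the JointTargetedInference workflow) that a subset of a
     declared joint target may be queried directly:

     @code
     ie.addJointTarget(gum::NodeSet{a, b, c});                 // declared target
     ie.makeInference();
     // {a, c} was never declared but is included in {a, b, c}: the posterior is
     // obtained by marginalizing b out of the cached joint (margSumOut above)
     const auto& pAC = ie.jointPosterior(gum::NodeSet{a, c});
     @endcode
  */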

  template < typename GUM_SCALAR >
  GUM_SCALAR ShaferShenoyInference< GUM_SCALAR >::evidenceProbability() {
    // perform inference in each connected component
    this->makeInference();

    // for each connected component, select a variable X and compute the
    // joint probability of X and evidence e. Then marginalize-out X to get
    // p(e) in this connected component. Finally, multiply all the p(e) that
    // we got and the elements in _constants_. The result is the probability
    // of evidence

    GUM_SCALAR prob_ev = 1;
    for (const auto root: _roots_) {
      // get a node in the clique
      NodeId                   node = *(_JT_->clique(root).begin());
      Potential< GUM_SCALAR >* tmp  = unnormalizedJointPosterior_(node);
      GUM_SCALAR               sum  = 0;
      for (Instantiation iter(*tmp); !iter.end(); ++iter)
        sum += tmp->get(iter);
      prob_ev *= sum;
      delete tmp;
    }

    for (const auto& projected_cpt: _constants_)
      prob_ev *= projected_cpt.second;

    return prob_ev;
  }
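
  /* A sketch of querying the probability of the entered evidence with the
     method above (same illustrative network and evidence as in the earlier
     sketches):

     @code
     gum::ShaferShenoyInference< double > ie(&bn);
     ie.addEvidence(a, 1);
     ie.addEvidence(b, 0);
     ie.makeInference();
     const double p_e = ie.evidenceProbability();   // P(a = 1, b = 0)
     @endcode
  */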

}   /* namespace gum */

#endif   // DOXYGEN_SHOULD_SKIP_THIS