aGrUM 0.20.2
a C++ library for (probabilistic) graphical models
lazyPropagation_tpl.h
/**
 *
 *  Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *  info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of lazy propagation for inference in
 * Bayesian networks.
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */
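// Usage sketch (not part of the original file): it assumes a BayesNet< double >
// named bn has already been built and that n1, n2 are valid node ids; the
// calls below are the public API inherited from the generic inference engines:
//
//   gum::LazyPropagation< double > inf(&bn);
//   inf.addEvidence(n1, 0);                 // hard evidence: n1 set to label 0
//   inf.makeInference();                    // full collect over the join tree
//   const auto& post = inf.posterior(n2);   // P(n2 | evidence)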
#ifndef DOXYGEN_SHOULD_SKIP_THIS
#  include <algorithm>

#  include <agrum/BN/algorithms/BayesBall.h>
#  include <agrum/BN/algorithms/barrenNodesFinder.h>
#  include <agrum/BN/algorithms/dSeparation.h>
#  include <agrum/tools/graphs/algorithms/binaryJoinTreeConverterDefault.h>
#  include <agrum/tools/multidim/instantiation.h>
#  include <agrum/tools/multidim/utils/operators/multiDimCombineAndProjectDefault.h>
#  include <agrum/tools/multidim/utils/operators/multiDimProjection.h>

namespace gum {
  // default constructor
  template < typename GUM_SCALAR >
  INLINE LazyPropagation< GUM_SCALAR >::LazyPropagation(
     const IBayesNet< GUM_SCALAR >* BN,
     RelevantPotentialsFinderType   relevant_type,
     FindBarrenNodesType            barren_type,
     bool                           use_binary_join_tree) :
      JointTargetedInference< GUM_SCALAR >(BN),
      EvidenceInference< GUM_SCALAR >(BN),
      use_binary_join_tree__(use_binary_join_tree) {
    // sets the relevant potential and the barren nodes finding algorithm
    setRelevantPotentialsFinderType(relevant_type);
    setFindBarrenNodesType(barren_type);

    // create a default triangulation (the user can change it afterwards)
    triangulation__ = new DefaultTriangulation;

    // for debugging purposes
    GUM_CONSTRUCTOR(LazyPropagation);
  }


  // destructor
  template < typename GUM_SCALAR >
  INLINE LazyPropagation< GUM_SCALAR >::~LazyPropagation() {
    // remove all the potentials created during the last message passing
    for (const auto& pots: created_potentials__)
      for (const auto pot: pots.second)
        delete pot;

    // remove the potentials created after removing the nodes that received
    // hard evidence
    for (const auto& pot: hard_ev_projected_CPTs__)
      delete pot.second;

    // remove all the posteriors computed
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;

    // remove the junction tree and the triangulation algorithm
    if (JT__ != nullptr) delete JT__;
    if (junctionTree__ != nullptr) delete junctionTree__;
    delete triangulation__;

    // for debugging purposes
    GUM_DESTRUCTOR(LazyPropagation);
  }


  /// set a new triangulation algorithm
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::setTriangulation(
     const Triangulation& new_triangulation) {
    delete triangulation__;
    triangulation__    = new_triangulation.newFactory();
    is_new_jt_needed__ = true;
    this->setOutdatedStructureState_();
  }


  /// returns the current join (or junction) tree used
  template < typename GUM_SCALAR >
  INLINE const JoinTree* LazyPropagation< GUM_SCALAR >::joinTree() {
    createNewJT__();

    return JT__;
  }

  /// returns the current junction tree
  template < typename GUM_SCALAR >
  INLINE const JunctionTree* LazyPropagation< GUM_SCALAR >::junctionTree() {
    createNewJT__();

    return junctionTree__;
  }


  /// sets how we determine the relevant potentials to combine
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::setRelevantPotentialsFinderType(
     RelevantPotentialsFinderType type) {
    if (type != find_relevant_potential_type__) {
      switch (type) {
        case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
          findRelevantPotentials__ = &LazyPropagation<
             GUM_SCALAR >::findRelevantPotentialsWithdSeparation2__;
          break;

        case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES:
          findRelevantPotentials__ = &LazyPropagation<
             GUM_SCALAR >::findRelevantPotentialsWithdSeparation__;
          break;

        case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009:
          findRelevantPotentials__ = &LazyPropagation<
             GUM_SCALAR >::findRelevantPotentialsWithdSeparation3__;
          break;

        case RelevantPotentialsFinderType::FIND_ALL:
          findRelevantPotentials__
             = &LazyPropagation< GUM_SCALAR >::findRelevantPotentialsGetAll__;
          break;

        default:
          GUM_ERROR(InvalidArgument,
                    "setRelevantPotentialsFinderType for type "
                       << (unsigned int)type << " is not implemented yet");
      }

      find_relevant_potential_type__ = type;

      // indicate that all messages need be reconstructed to take into account
      // the change in d-separation analysis
      invalidateAllMessages__();
    }
  }


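  // For instance (illustration only; inf denotes a LazyPropagation instance
  // as in the usage sketch near the top of this file):
  //   inf.setRelevantPotentialsFinderType(
  //      RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS);
  // restricts every message computation to the potentials d-connected to the
  // variables kept in the message, at the price of a Bayes-ball pass.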
  /// sets the operator for performing the projections
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::setProjectionFunction__(
     Potential< GUM_SCALAR >* (*proj)(const Potential< GUM_SCALAR >&,
                                      const Set< const DiscreteVariable* >&)) {
    projection_op__ = proj;
  }


  /// sets the operator for performing the combinations
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::setCombinationFunction__(
     Potential< GUM_SCALAR >* (*comb)(const Potential< GUM_SCALAR >&,
                                      const Potential< GUM_SCALAR >&)) {
    combination_op__ = comb;
  }

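  // Both setters above take plain function pointers. A sketch of a compatible
  // projection function (myProj is a hypothetical helper, not part of aGrUM;
  // it assumes Potential::margSumOut, which sums the given variables out):
  //
  //   template < typename GUM_SCALAR >
  //   Potential< GUM_SCALAR >* myProj(const Potential< GUM_SCALAR >&        pot,
  //                                   const Set< const DiscreteVariable* >& del) {
  //     return new Potential< GUM_SCALAR >(pot.margSumOut(del));
  //   }
  //
  // The returned potential must be a fresh object: the engine deletes the
  // potentials it records as created during message passing.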
  /// invalidate all messages, posteriors and created potentials
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::invalidateAllMessages__() {
    // remove all the messages computed
    for (auto& potset: separator_potentials__)
      potset.second.clear();
    for (auto& mess_computed: messages_computed__)
      mess_computed.second = false;

    // remove all the created potentials
    for (const auto& potset: created_potentials__)
      for (const auto pot: potset.second)
        delete pot;

    // remove all the posteriors
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;

    // indicate that new messages need be computed
    if (this->isInferenceReady() || this->isInferenceDone())
      this->setOutdatedPotentialsState_();
  }


  /// sets how we determine barren nodes
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::setFindBarrenNodesType(
     FindBarrenNodesType type) {
    if (type != barren_nodes_type__) {
      // WARNING: if a new type is added here, method createJT__ should
      // certainly be updated as well, in particular its step 2.
      switch (type) {
        case FindBarrenNodesType::FIND_BARREN_NODES:
        case FindBarrenNodesType::FIND_NO_BARREN_NODES:
          break;

        default:
          GUM_ERROR(InvalidArgument,
                    "setFindBarrenNodesType for type "
                       << (unsigned int)type << " is not implemented yet");
      }

      barren_nodes_type__ = type;

      // potentially, we may need to reconstruct a junction tree
      this->setOutdatedStructureState_();
    }
  }


  /// fired when a new evidence is inserted
  template < typename GUM_SCALAR >
  INLINE void
     LazyPropagation< GUM_SCALAR >::onEvidenceAdded_(const NodeId id,
                                                     bool isHardEvidence) {
    // if we have a new hard evidence, this modifies the undigraph over which
    // the join tree is created. This is also the case if id is not a node
    // of the undigraph
    if (isHardEvidence || !graph__.exists(id))
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_ADDED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed. This necessarily means
        // that the current saved change is an EVIDENCE_ERASED. So if we
        // erased the evidence and added some again, this corresponds to an
        // EVIDENCE_MODIFIED
        evidence_changes__[id] = EvidenceChangeType::EVIDENCE_MODIFIED;
      }
    }
  }


  /// fired when an evidence is removed
  template < typename GUM_SCALAR >
  INLINE void
     LazyPropagation< GUM_SCALAR >::onEvidenceErased_(const NodeId id,
                                                      bool isHardEvidence) {
    // if we delete a hard evidence, this modifies the undigraph over which
    // the join tree is created.
    if (isHardEvidence)
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_ERASED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
        // been added and is now erased, this is similar to not having created
        // it. If the evidence was only modified, it already existed in the
        // last inference and we should now indicate that it has been removed.
        if (evidence_changes__[id] == EvidenceChangeType::EVIDENCE_ADDED)
          evidence_changes__.erase(id);
        else
          evidence_changes__[id] = EvidenceChangeType::EVIDENCE_ERASED;
      }
    }
  }


  /// fired when all the evidence are erased
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::onAllEvidenceErased_(
     bool has_hard_evidence) {
    if (has_hard_evidence || !this->hardEvidenceNodes().empty())
      is_new_jt_needed__ = true;
    else {
      for (const auto node: this->softEvidenceNodes()) {
        try {
          evidence_changes__.insert(node, EvidenceChangeType::EVIDENCE_ERASED);
        } catch (DuplicateElement&) {
          // here, the evidence change already existed and it is necessarily an
          // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
          // been added and is now erased, this is similar to not having created
          // it. If the evidence was only modified, it already existed in the
          // last inference and we should now indicate that it has been removed.
          if (evidence_changes__[node] == EvidenceChangeType::EVIDENCE_ADDED)
            evidence_changes__.erase(node);
          else
            evidence_changes__[node] = EvidenceChangeType::EVIDENCE_ERASED;
        }
      }
    }
  }


  /// fired when an evidence is changed
  template < typename GUM_SCALAR >
  INLINE void
     LazyPropagation< GUM_SCALAR >::onEvidenceChanged_(const NodeId id,
                                                       bool hasChangedSoftHard) {
    if (hasChangedSoftHard)
      is_new_jt_needed__ = true;
    else {
      try {
        evidence_changes__.insert(id, EvidenceChangeType::EVIDENCE_MODIFIED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED. So we should keep this state to indicate that this
        // evidence is new w.r.t. the last inference
      }
    }
  }


  /// fired after a new Bayes net has been assigned to the engine
  template < typename GUM_SCALAR >
  INLINE void
     LazyPropagation< GUM_SCALAR >::onModelChanged_(const GraphicalModel* bn) {}


  /// fired after a new target is inserted
  template < typename GUM_SCALAR >
  INLINE void
     LazyPropagation< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId id) {}


  /// fired before a target is removed
  template < typename GUM_SCALAR >
  INLINE void
     LazyPropagation< GUM_SCALAR >::onMarginalTargetErased_(const NodeId id) {}


  /// fired after a new set target is inserted
  template < typename GUM_SCALAR >
  INLINE void
     LazyPropagation< GUM_SCALAR >::onJointTargetAdded_(const NodeSet& set) {}


  /// fired before a set target is removed
  template < typename GUM_SCALAR >
  INLINE void
     LazyPropagation< GUM_SCALAR >::onJointTargetErased_(const NodeSet& set) {}


  /// fired after all the nodes of the BN are added as single targets
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}


  /// fired before all the single targets are removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onAllMarginalTargetsErased_() {}


  /// fired before all the joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onAllJointTargetsErased_() {}


  /// fired before all the single and joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onAllTargetsErased_() {}


  // check whether a new junction tree is really needed for the next inference
  template < typename GUM_SCALAR >
  bool LazyPropagation< GUM_SCALAR >::isNewJTNeeded__() const {
    // if we do not have a JT or if is_new_jt_needed__ is set to true, then
    // we know that we need to create a new join tree
    if ((JT__ == nullptr) || is_new_jt_needed__) return true;

    // if some targets do not belong to the join tree and, consequently,
    // to the undigraph that was used to construct the join tree, then we need
    // to create a new JT. This situation may occur if we constructed the
    // join tree after pruning irrelevant/barren nodes from the BN.
    // However, note that the nodes that received hard evidence do not belong
    // to the graph and, therefore, should not be taken into account
    const auto& hard_ev_nodes = this->hardEvidenceNodes();
    for (const auto node: this->targets()) {
      if (!graph__.exists(node) && !hard_ev_nodes.exists(node)) return true;
    }
    for (const auto& joint_target: this->jointTargets()) {
      // here, we need to check that at least one clique contains all the
      // nodes of the joint target.
      bool containing_clique_found = false;
      for (const auto node: joint_target) {
        bool found = true;
        try {
          const NodeSet& clique = JT__->clique(node_to_clique__[node]);
          for (const auto xnode: joint_target) {
            if (!clique.contains(xnode) && !hard_ev_nodes.exists(xnode)) {
              found = false;
              break;
            }
          }
        } catch (NotFound&) { found = false; }

        if (found) {
          containing_clique_found = true;
          break;
        }
      }

      if (!containing_clique_found) return true;
    }

    // if some new evidence have been added on nodes that do not belong
    // to graph__, then we potentially have to reconstruct the join tree
    for (const auto& change: evidence_changes__) {
      if ((change.second == EvidenceChangeType::EVIDENCE_ADDED)
          && !graph__.exists(change.first))
        return true;
    }

    // here, the current JT is exactly what we need for the next inference
    return false;
  }


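  // A tiny worked example of the moralization performed by createNewJT__
  // (comment added for clarity): for a BN with arcs A -> C and B -> C, no
  // evidence and all nodes kept, step 3/ below adds the edges A-C, B-C and
  // A-B ("marrying" C's parents), so that {A, B, C} ends up inside a common
  // clique of the triangulated graph.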
  /// create a new junction tree as well as its related data structures
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::createNewJT__() {
    // to create the JT, we first create the moral graph of the BN in the
    // following way in order to take into account the barren nodes and the
    // nodes that received evidence:
    // 1/ we create an undirected graph containing only the nodes and no edge
    // 2/ if we take into account barren nodes, remove them from the graph
    // 3/ add edges so that each node and its parents in the BN form a clique
    // 4/ add edges so that set targets are cliques of the moral graph
    // 5/ remove the nodes that received hard evidence (by step 3/, their
    //    parents are linked by edges, which is necessary for inference)
    //
    // At the end of step 5/, we have our moral graph and we can triangulate it
    // to get the new junction tree

    // 1/ create an undirected graph containing only the nodes and no edge
    const auto& bn = this->BN();
    graph__.clear();
    for (const auto node: bn.dag())
      graph__.addNodeWithId(node);

    // 2/ if we wish to exploit barren nodes, we shall remove them from the BN.
    // To do so, we identify all the nodes that are not targets and have
    // received no evidence and such that their descendants are neither
    // targets nor evidence nodes. Such nodes can be safely discarded from
    // the BN without altering the inference output
    if (barren_nodes_type__ == FindBarrenNodesType::FIND_BARREN_NODES) {
      // identify the barren nodes
      NodeSet target_nodes = this->targets();
      for (const auto& nodeset: this->jointTargets()) {
        target_nodes += nodeset;
      }

      // check that all the nodes are not targets, otherwise, there is no
      // barren node
      if (target_nodes.size() != bn.size()) {
        BarrenNodesFinder finder(&(bn.dag()));
        finder.setTargets(&target_nodes);

        NodeSet evidence_nodes;
        for (const auto& pair: this->evidence()) {
          evidence_nodes.insert(pair.first);
        }
        finder.setEvidence(&evidence_nodes);

        NodeSet barren_nodes = finder.barrenNodes();

        // remove the barren nodes from the moral graph
        for (const auto node: barren_nodes) {
          graph__.eraseNode(node);
        }
      }
    }

    // 3/ add edges so that each node and its parents in the BN form a clique
    for (const auto node: graph__) {
      const NodeSet& parents = bn.parents(node);
      for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
        graph__.addEdge(*iter1, node);
        auto iter2 = iter1;
        for (++iter2; iter2 != parents.cend(); ++iter2) {
          graph__.addEdge(*iter1, *iter2);
        }
      }
    }

    // 4/ if there exist some joint targets, we shall add new edges
    // into the moral graph in order to ensure that there exists a clique
    // containing each joint
    for (const auto& nodeset: this->jointTargets()) {
      for (auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
        auto iter2 = iter1;
        for (++iter2; iter2 != nodeset.cend(); ++iter2) {
          graph__.addEdge(*iter1, *iter2);
        }
      }
    }

    // 5/ remove all the nodes that received hard evidence
    hard_ev_nodes__ = this->hardEvidenceNodes();
    for (const auto node: hard_ev_nodes__) {
      graph__.eraseNode(node);
    }


    // now, we can compute the new junction tree. To speed-up computations
    // (essentially, those of a distribution phase), we construct from this
    // junction tree a binary join tree
    if (JT__ != nullptr) delete JT__;
    if (junctionTree__ != nullptr) delete junctionTree__;

    triangulation__->setGraph(&graph__, &(this->domainSizes()));
    const JunctionTree& triang_jt = triangulation__->junctionTree();
    if (use_binary_join_tree__) {
      BinaryJoinTreeConverterDefault bjt_converter;
      NodeSet                        emptyset;
      JT__ = new CliqueGraph(
         bjt_converter.convert(triang_jt, this->domainSizes(), emptyset));
    } else {
      JT__ = new CliqueGraph(triang_jt);
    }
    junctionTree__ = new CliqueGraph(triang_jt);


    // indicate, for each node of the moral graph, a clique in JT__ that can
    // contain its conditional probability table
    node_to_clique__.clear();
    const std::vector< NodeId >& JT_elim_order
       = triangulation__->eliminationOrder();
    NodeProperty< int > elim_order(Size(JT_elim_order.size()));
    for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
         ++i)
      elim_order.insert(JT_elim_order[i], (int)i);
    const DAG& dag = bn.dag();
    for (const auto node: graph__) {
      // get the variables in the potential of node (and its parents)
      NodeId first_eliminated_node = node;
      int    elim_number           = elim_order[first_eliminated_node];

      for (const auto parent: dag.parents(node)) {
        if (graph__.existsNode(parent) && (elim_order[parent] < elim_number)) {
          elim_number           = elim_order[parent];
          first_eliminated_node = parent;
        }
      }

      // first_eliminated_node contains the first var (node or one of its
      // parents) eliminated => the clique created during its elimination
      // contains node and all of its parents => it can contain the potential
      // assigned to the node in the BN
      node_to_clique__.insert(
         node,
         triangulation__->createdJunctionTreeClique(first_eliminated_node));
    }

    // do the same for the nodes that received evidence. Here, we only store
    // the nodes for which at least one parent belongs to graph__ (otherwise
    // their CPT is just a constant real number).
    for (const auto node: hard_ev_nodes__) {
      // get the set of parents of the node that belong to graph__
      NodeSet pars(dag.parents(node).size());
      for (const auto par: dag.parents(node))
        if (graph__.exists(par)) pars.insert(par);

      if (!pars.empty()) {
        NodeId first_eliminated_node = *(pars.begin());
        int    elim_number           = elim_order[first_eliminated_node];

        for (const auto parent: pars) {
          if (elim_order[parent] < elim_number) {
            elim_number           = elim_order[parent];
            first_eliminated_node = parent;
          }
        }

        // first_eliminated_node contains the first var (node or one of its
        // parents) eliminated => the clique created during its elimination
        // contains node and all of its parents => it can contain the potential
        // assigned to the node in the BN
        node_to_clique__.insert(
           node,
           triangulation__->createdJunctionTreeClique(first_eliminated_node));
      }
    }

    // indicate for each joint_target a clique that contains it
    joint_target_to_clique__.clear();
    for (const auto& set: this->jointTargets()) {
      // remove from set all the nodes that received hard evidence (since they
      // do not belong to the join tree)
      NodeSet nodeset = set;
      for (const auto node: hard_ev_nodes__)
        if (nodeset.contains(node)) nodeset.erase(node);

      if (!nodeset.empty()) {
        // the clique we are looking for is the one that was created when
        // the first element of nodeset was eliminated
        NodeId first_eliminated_node = *(nodeset.begin());
        int    elim_number           = elim_order[first_eliminated_node];
        for (const auto node: nodeset) {
          if (elim_order[node] < elim_number) {
            elim_number           = elim_order[node];
            first_eliminated_node = node;
          }
        }

        joint_target_to_clique__.insert(
           set,
           triangulation__->createdJunctionTreeClique(first_eliminated_node));
      }
    }


    // compute the roots of JT__'s connected components
    computeJoinTreeRoots__();

    // create empty potential lists into the cliques of the join tree as well
    // as empty lists of evidence
    PotentialSet__ empty_set;
    clique_potentials__.clear();
    node_to_soft_evidence__.clear();
    for (const auto node: *JT__) {
      clique_potentials__.insert(node, empty_set);
    }

    // remove all the potentials created during the last inference
    for (const auto& potlist: created_potentials__)
      for (const auto pot: potlist.second)
        delete pot;
    created_potentials__.clear();

    // remove all the potentials created to take into account hard evidence
    // during the last inference
    for (const auto pot_pair: hard_ev_projected_CPTs__)
      delete pot_pair.second;
    hard_ev_projected_CPTs__.clear();

    // remove all the constants created due to projections of CPTs that were
    // defined over only hard evidence nodes
    constants__.clear();

    // create empty lists of potentials for the messages and indicate that no
    // message has been computed yet
    separator_potentials__.clear();
    messages_computed__.clear();
    for (const auto& edge: JT__->edges()) {
      const Arc arc1(edge.first(), edge.second());
      separator_potentials__.insert(arc1, empty_set);
      messages_computed__.insert(arc1, false);
      const Arc arc2(Arc(edge.second(), edge.first()));
      separator_potentials__.insert(arc2, empty_set);
      messages_computed__.insert(arc2, false);
    }

    // remove all the posteriors computed so far
    for (const auto& pot: target_posteriors__)
      delete pot.second;
    target_posteriors__.clear();
    for (const auto& pot: joint_target_posteriors__)
      delete pot.second;
    joint_target_posteriors__.clear();


    // put all the CPTs of the Bayes net nodes into the cliques.
    // Here, beware: all the potentials that are defined over some nodes
    // including hard evidence must be projected so that these nodes are
    // removed from the potential
    const auto& evidence      = this->evidence();
    const auto& hard_evidence = this->hardEvidence();
    for (const auto node: dag) {
      if (graph__.exists(node) || hard_ev_nodes__.contains(node)) {
        const Potential< GUM_SCALAR >& cpt = bn.cpt(node);

        // get the list of nodes with hard evidence in cpt
        NodeSet     hard_nodes;
        const auto& variables = cpt.variablesSequence();
        for (const auto var: variables) {
          NodeId xnode = bn.nodeId(*var);
          if (hard_ev_nodes__.contains(xnode)) hard_nodes.insert(xnode);
        }

        // if hard_nodes contains hard evidence nodes, perform a projection
        // and insert the result into the appropriate clique, else insert
        // directly cpt into the clique
        if (hard_nodes.empty()) {
          clique_potentials__[node_to_clique__[node]].insert(&cpt);
        } else {
          // marginalize out the hard evidence nodes: if the cpt is defined
          // only over nodes that received hard evidence, do not consider it
          // as a potential anymore but as a constant
          if (hard_nodes.size() == variables.size()) {
            Instantiation inst;
            const auto&   vars = cpt.variablesSequence();
            for (const auto var: vars)
              inst << *var;
            for (Size i = 0; i < hard_nodes.size(); ++i) {
              inst.chgVal(variables[i],
                          hard_evidence[bn.nodeId(*(variables[i]))]);
            }
            constants__.insert(node, cpt.get(inst));
          } else {
            // perform the projection with a combine and project instance
            Set< const DiscreteVariable* > hard_variables;
            PotentialSet__                 marg_cpt_set{&cpt};
            for (const auto xnode: hard_nodes) {
              marg_cpt_set.insert(evidence[xnode]);
              hard_variables.insert(&(bn.variable(xnode)));
            }

            // perform the combination of those potentials and their projection
            MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
               combine_and_project(combination_op__, LPNewprojPotential);
            PotentialSet__ new_cpt_list
               = combine_and_project.combineAndProject(marg_cpt_set,
                                                       hard_variables);

            // there should be only one potential in new_cpt_list
            if (new_cpt_list.size() != 1) {
              // remove the CPT created to avoid memory leaks
              for (const auto pot: new_cpt_list) {
                if (!marg_cpt_set.contains(pot)) delete pot;
              }
              GUM_ERROR(FatalError,
                        "the projection of a potential containing "
                           << "hard evidence is empty!");
            }
            const Potential< GUM_SCALAR >* projected_cpt
               = *(new_cpt_list.begin());
            clique_potentials__[node_to_clique__[node]].insert(projected_cpt);
            hard_ev_projected_CPTs__.insert(node, projected_cpt);
          }
        }
      }
    }

    // we shall now add all the potentials of the soft evidence
    for (const auto node: this->softEvidenceNodes()) {
      node_to_soft_evidence__.insert(node, evidence[node]);
      clique_potentials__[node_to_clique__[node]].insert(evidence[node]);
    }

    // indicate that the data structures are up to date.
    evidence_changes__.clear();
    is_new_jt_needed__ = false;
  }


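  // Note on the update contract (comment added for clarity): the generic
  // inference engine calls updateOutdatedStructure_ when a change may have
  // altered the graph underlying the join tree (e.g., hard evidence added or
  // removed), and updateOutdatedPotentials_ when only the values of the
  // evidence changed; the second path avoids any retriangulation.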
  /// prepare the inference structures w.r.t. new targets, soft/hard evidence
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::updateOutdatedStructure_() {
    // check if a new JT is really needed. If so, create it
    if (isNewJTNeeded__()) {
      createNewJT__();
    } else {
      // here, we can answer the next queries without reconstructing all the
      // junction tree. All we need to do is to indicate that we should
      // update the potentials and messages for these queries
      updateOutdatedPotentials_();
    }
  }


  /// invalidate all the messages sent from a given clique
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::diffuseMessageInvalidations__(
     NodeId   from_id,
     NodeId   to_id,
     NodeSet& invalidated_cliques) {
    // invalidate the current clique
    invalidated_cliques.insert(to_id);

    // invalidate the current arc
    const Arc arc(from_id, to_id);
    bool&     message_computed = messages_computed__[arc];
    if (message_computed) {
      message_computed = false;
      separator_potentials__[arc].clear();
      if (created_potentials__.exists(arc)) {
        auto& arc_created_potentials = created_potentials__[arc];
        for (const auto pot: arc_created_potentials)
          delete pot;
        arc_created_potentials.clear();
      }

      // go on with the diffusion
      for (const auto node_id: JT__->neighbours(to_id)) {
        if (node_id != from_id)
          diffuseMessageInvalidations__(to_id, node_id, invalidated_cliques);
      }
    }
  }


  /// update the potentials stored in the cliques and invalidate outdated
  /// messages
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::updateOutdatedPotentials_() {
    // compute the set of CPTs that were projected due to hard evidence and
    // whose hard evidence have changed, so that they need a new projection.
    // By the way, remove these CPTs since they are no more needed.
    // Here, only the values of the hard evidence can have changed (else a
    // fully new join tree would have been computed).
    // Note also that we know that the CPTs still contain some variable(s) after
    // the projection (else they should be constants)
    NodeSet hard_nodes_changed(hard_ev_nodes__.size());
    for (const auto node: hard_ev_nodes__)
      if (evidence_changes__.exists(node)) hard_nodes_changed.insert(node);

    NodeSet     nodes_with_projected_CPTs_changed;
    const auto& bn = this->BN();
    for (auto pot_iter = hard_ev_projected_CPTs__.beginSafe();
         pot_iter != hard_ev_projected_CPTs__.endSafe();
         ++pot_iter) {
      for (const auto var: bn.cpt(pot_iter.key()).variablesSequence()) {
        if (hard_nodes_changed.contains(bn.nodeId(*var))) {
          nodes_with_projected_CPTs_changed.insert(pot_iter.key());
          delete pot_iter.val();
          clique_potentials__[node_to_clique__[pot_iter.key()]].erase(
             pot_iter.val());
          hard_ev_projected_CPTs__.erase(pot_iter);
          break;
        }
      }
    }


    // invalidate all the messages that are no more correct: start from each of
    // the nodes whose soft evidence has changed and perform a diffusion from
    // the clique into which the soft evidence has been entered, indicating that
    // the messages spreading from this clique are now invalid. At the same time,
    // if there were potentials created on the arcs over which the messages were
    // sent, remove them from memory. For all the cliques that received some
    // projected CPT that should now be changed, do the same.
    NodeSet invalidated_cliques(JT__->size());
    for (const auto& pair: evidence_changes__) {
      if (node_to_clique__.exists(pair.first)) {
        const auto clique = node_to_clique__[pair.first];
        invalidated_cliques.insert(clique);
        for (const auto neighbor: JT__->neighbours(clique)) {
          diffuseMessageInvalidations__(clique, neighbor, invalidated_cliques);
        }
      }
    }

    // now, add to the set of invalidated cliques those that contain projected
    // CPTs that were changed.
    for (const auto node: nodes_with_projected_CPTs_changed) {
      const auto clique = node_to_clique__[node];
      invalidated_cliques.insert(clique);
      for (const auto neighbor: JT__->neighbours(clique)) {
        diffuseMessageInvalidations__(clique, neighbor, invalidated_cliques);
      }
    }


    // now we shall remove all the posteriors that belong to the
    // invalidated cliques. First, cope only with the nodes that did not
    // receive hard evidence since the other nodes do not belong to the
    // join tree
    for (auto iter = target_posteriors__.beginSafe();
         iter != target_posteriors__.endSafe();
         ++iter) {
      if (graph__.exists(iter.key())
          && (invalidated_cliques.exists(node_to_clique__[iter.key()]))) {
        delete iter.val();
        target_posteriors__.erase(iter);
      }
    }

    // now cope with the nodes that received hard evidence
    for (auto iter = target_posteriors__.beginSafe();
         iter != target_posteriors__.endSafe();
         ++iter) {
      if (hard_nodes_changed.contains(iter.key())) {
        delete iter.val();
        target_posteriors__.erase(iter);
      }
    }

    // finally, cope with joint targets
    for (auto iter = joint_target_posteriors__.beginSafe();
         iter != joint_target_posteriors__.endSafe();
         ++iter) {
      if (invalidated_cliques.exists(joint_target_to_clique__[iter.key()])) {
        delete iter.val();
        joint_target_posteriors__.erase(iter);
      }
    }


    // remove all the evidence that were entered into node_to_soft_evidence__
    // and clique_potentials__ and add the new soft ones
    for (const auto& pot_pair: node_to_soft_evidence__) {
      clique_potentials__[node_to_clique__[pot_pair.first]].erase(
         pot_pair.second);
    }
    node_to_soft_evidence__.clear();

    const auto& evidence = this->evidence();
    for (const auto node: this->softEvidenceNodes()) {
      node_to_soft_evidence__.insert(node, evidence[node]);
      clique_potentials__[node_to_clique__[node]].insert(evidence[node]);
    }


    // now add the projections of the CPTs due to newly changed hard evidence:
    // if we are performing updateOutdatedPotentials_, this means that the
    // set of nodes that received hard evidence has not changed, only
    // their instantiations can have changed. So, if there is an entry
    // for node in constants__, there will still be such an entry after
    // performing the new projections. Idem for hard_ev_projected_CPTs__
    for (const auto node: nodes_with_projected_CPTs_changed) {
      // perform the projection with a combine and project instance
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node);
      const auto&                    variables = cpt.variablesSequence();
      Set< const DiscreteVariable* > hard_variables;
      PotentialSet__                 marg_cpt_set{&cpt};
      for (const auto var: variables) {
        NodeId xnode = bn.nodeId(*var);
        if (hard_ev_nodes__.exists(xnode)) {
          marg_cpt_set.insert(evidence[xnode]);
          hard_variables.insert(var);
        }
      }

      // perform the combination of those potentials and their projection
      MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
         combine_and_project(combination_op__, LPNewprojPotential);
      PotentialSet__ new_cpt_list
         = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

      // there should be only one potential in new_cpt_list
      if (new_cpt_list.size() != 1) {
        // remove the CPT created to avoid memory leaks
        for (const auto pot: new_cpt_list) {
          if (!marg_cpt_set.contains(pot)) delete pot;
        }
        GUM_ERROR(FatalError,
                  "the projection of a potential containing "
                     << "hard evidence is empty!");
      }
      const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
      clique_potentials__[node_to_clique__[node]].insert(projected_cpt);
      hard_ev_projected_CPTs__.insert(node, projected_cpt);
    }

    // update the constants
    const auto& hard_evidence = this->hardEvidence();
    for (auto& node_cst: constants__) {
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node_cst.first);
      const auto&                    variables = cpt.variablesSequence();
      Instantiation                  inst;
      for (const auto var: variables)
        inst << *var;
      for (const auto var: variables) {
        inst.chgVal(var, hard_evidence[bn.nodeId(*var)]);
      }
      node_cst.second = cpt.get(inst);
    }

    // indicate that all changes have been performed
    evidence_changes__.clear();
  }


  /// compute a root for each connected component of JT__
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::computeJoinTreeRoots__() {
    // get the set of cliques in which we can find the targets and joint_targets
    NodeSet clique_targets;
    for (const auto node: this->targets()) {
      try {
        clique_targets.insert(node_to_clique__[node]);
      } catch (Exception&) {}
    }
    for (const auto& set: this->jointTargets()) {
      try {
        clique_targets.insert(joint_target_to_clique__[set]);
      } catch (Exception&) {}
    }

    // put in a vector these cliques and their size
    std::vector< std::pair< NodeId, Size > > possible_roots(
       clique_targets.size());
    const auto& bn = this->BN();
    std::size_t i  = 0;
    for (const auto clique_id: clique_targets) {
      const auto& clique   = JT__->clique(clique_id);
      Size        dom_size = 1;
      for (const auto node: clique) {
        dom_size *= bn.variable(node).domainSize();
      }
      possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
      ++i;
    }

    // sort the cliques by increasing domain size
    std::sort(possible_roots.begin(),
              possible_roots.end(),
              [](const std::pair< NodeId, Size >& a,
                 const std::pair< NodeId, Size >& b) -> bool {
                return a.second < b.second;
              });

    // pick up the clique with the smallest size in each connected component
    NodeProperty< bool >                  marked = JT__->nodesProperty(false);
    std::function< void(NodeId, NodeId) > diffuse_marks
       = [&marked, &diffuse_marks, this](NodeId node, NodeId from) {
           if (!marked[node]) {
             marked[node] = true;
             for (const auto neigh: JT__->neighbours(node))
               if ((neigh != from) && !marked[neigh])
                 diffuse_marks(neigh, node);
           }
         };
    roots__.clear();
    for (const auto xclique: possible_roots) {
      NodeId clique = xclique.first;
      if (!marked[clique]) {
        roots__.insert(clique);
        diffuse_marks(clique, clique);
      }
    }
  }


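  // The collect phase below is a classical depth-first message pass (comment
  // added for clarity): collectMessage__(r, r) recursively asks each
  // neighbouring clique to send its message toward r, so that, on return,
  // every message on every path toward r has been computed and cached.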
  // performs the collect phase of Lazy Propagation
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::collectMessage__(NodeId id,
                                                              NodeId from) {
    for (const auto other: JT__->neighbours(id)) {
      if ((other != from) && !messages_computed__[Arc(other, id)])
        collectMessage__(other, id);
    }

    if ((id != from) && !messages_computed__[Arc(id, from)]) {
      produceMessage__(id, from);
    }
  }


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::findRelevantPotentialsGetAll__(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {}


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::findRelevantPotentialsWithdSeparation__(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {
    // find the node ids of the kept variables
    NodeSet     kept_ids;
    const auto& bn = this->BN();
    for (const auto var: kept_vars) {
      kept_ids.insert(bn.nodeId(*var));
    }

    // determine the set of potentials d-connected with the kept variables
    NodeSet requisite_nodes;
    BayesBall::requisiteNodes(bn.dag(),
                              kept_ids,
                              this->hardEvidenceNodes(),
                              this->softEvidenceNodes(),
                              requisite_nodes);
    for (auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
      const Sequence< const DiscreteVariable* >& vars
         = (**iter).variablesSequence();
      bool found = false;
      for (const auto var: vars) {
        if (requisite_nodes.exists(bn.nodeId(*var))) {
          found = true;
          break;
        }
      }

      if (!found) { pot_list.erase(iter); }
    }
  }


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::findRelevantPotentialsWithdSeparation2__(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {
    // find the node ids of the kept variables
    NodeSet     kept_ids;
    const auto& bn = this->BN();
    for (const auto var: kept_vars) {
      kept_ids.insert(bn.nodeId(*var));
    }

    // determine the set of potentials d-connected with the kept variables
    BayesBall::relevantPotentials(bn,
                                  kept_ids,
                                  this->hardEvidenceNodes(),
                                  this->softEvidenceNodes(),
                                  pot_list);
  }


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::findRelevantPotentialsWithdSeparation3__(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {
    // find the node ids of the kept variables
    NodeSet     kept_ids;
    const auto& bn = this->BN();
    for (const auto var: kept_vars) {
      kept_ids.insert(bn.nodeId(*var));
    }

    // determine the set of potentials d-connected with the kept variables
    dSeparation dsep;
    dsep.relevantPotentials(bn,
                            kept_ids,
                            this->hardEvidenceNodes(),
                            this->softEvidenceNodes(),
                            pot_list);
  }


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::findRelevantPotentialsXX__(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {
    switch (find_relevant_potential_type__) {
      case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
        findRelevantPotentialsWithdSeparation2__(pot_list, kept_vars);
        break;

      case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES:
        findRelevantPotentialsWithdSeparation__(pot_list, kept_vars);
        break;

      case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009:
        findRelevantPotentialsWithdSeparation3__(pot_list, kept_vars);
        break;

      case RelevantPotentialsFinderType::FIND_ALL:
        findRelevantPotentialsGetAll__(pot_list, kept_vars);
        break;

      default:
        GUM_ERROR(FatalError, "not implemented yet");
    }
  }


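  // Micro-example of a barren variable (comment added for clarity): with
  // potentials P(A) and P(B|A), if B is neither a target nor an evidence
  // node, B is barren: summing it out of P(B|A) yields a potential that is
  // constant (equal to 1), so P(B|A) can simply be dropped from the
  // combination without changing the inference result.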
  // remove barren variables
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* >
     LazyPropagation< GUM_SCALAR >::removeBarrenVariables__(
        PotentialSet__&                 pot_list,
        Set< const DiscreteVariable* >& del_vars) {
    // remove from del_vars the variables that received some evidence:
    // only those that did not receive evidence can be barren variables
    Set< const DiscreteVariable* > the_del_vars = del_vars;
    for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
         ++iter) {
      NodeId id = this->BN().nodeId(**iter);
      if (this->hardEvidenceNodes().exists(id)
          || this->softEvidenceNodes().exists(id)) {
        the_del_vars.erase(iter);
      }
    }

    // assign to each random variable the set of potentials that contain it
    HashTable< const DiscreteVariable*, PotentialSet__ > var2pots;
    PotentialSet__                                       empty_pot_set;
    for (const auto pot: pot_list) {
      const Sequence< const DiscreteVariable* >& vars
         = pot->variablesSequence();
      for (const auto var: vars) {
        if (the_del_vars.exists(var)) {
          if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
          var2pots[var].insert(pot);
        }
      }
    }

    // each variable with only one potential is a barren variable
    // assign to each potential with barren nodes its set of barren variables
    HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
                                   pot2barren_var;
    Set< const DiscreteVariable* > empty_var_set;
    for (const auto elt: var2pots) {
      if (elt.second.size() == 1) {   // here we have a barren variable
        const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
        if (!pot2barren_var.exists(pot)) {
          pot2barren_var.insert(pot, empty_var_set);
        }
        pot2barren_var[pot].insert(elt.first);   // insert the barren variable
      }
    }

    // for each potential with barren variables, marginalize them.
    // if the potential has only barren variables, simply remove them from the
    // set of potentials, else just project the potential
    MultiDimProjection< GUM_SCALAR, Potential > projector(LPNewprojPotential);
    PotentialSet__                              projected_pots;
    for (const auto elt: pot2barren_var) {
      // remove the current potential from pot_list as, anyway, we will change
      // it
      const Potential< GUM_SCALAR >* pot = elt.first;
      pot_list.erase(pot);

      // check whether we need to add a projected new potential or not (i.e.,
      // whether there exist non-barren variables or not)
      if (pot->variablesSequence().size() != elt.second.size()) {
        auto new_pot = projector.project(*pot, elt.second);
        pot_list.insert(new_pot);
        projected_pots.insert(new_pot);
      }
    }

    return projected_pots;
  }


  // remove variables del_vars from the list of potentials pot_list
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* >
     LazyPropagation< GUM_SCALAR >::marginalizeOut__(
        Set< const Potential< GUM_SCALAR >* > pot_list,
        Set< const DiscreteVariable* >&       del_vars,
        Set< const DiscreteVariable* >&       kept_vars) {
    // use d-separation analysis to check which potentials shall be combined
    findRelevantPotentialsXX__(pot_list, kept_vars);

    // remove the potentials corresponding to barren variables if we want
    // to exploit barren nodes
    PotentialSet__ barren_projected_potentials;
    if (barren_nodes_type__ == FindBarrenNodesType::FIND_BARREN_NODES) {
      barren_projected_potentials = removeBarrenVariables__(pot_list, del_vars);
    }

    // create a combine and project operator that will perform the
    // marginalization
    MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
       combine_and_project(combination_op__, projection_op__);
    PotentialSet__ new_pot_list
       = combine_and_project.combineAndProject(pot_list, del_vars);

    // remove all the potentials that were created due to projections of
    // barren nodes and that are not part of the new_pot_list: these
    // potentials were just temporary potentials
    for (auto iter = barren_projected_potentials.beginSafe();
         iter != barren_projected_potentials.endSafe();
         ++iter) {
      if (!new_pot_list.exists(*iter)) delete *iter;
    }

    // remove all the potentials that have no dimension
    for (auto iter_pot = new_pot_list.beginSafe();
         iter_pot != new_pot_list.endSafe();
         ++iter_pot) {
      if ((*iter_pot)->variablesSequence().size() == 0) {
        // as we have already marginalized out variables that received evidence,
        // it may be the case that, after combining and projecting, some
        // potentials might be empty. In this case, we shall keep their
        // constant and remove them from memory
        // # TODO: keep the constants!
        delete *iter_pot;
        new_pot_list.erase(iter_pot);
      }
    }

    return new_pot_list;
  }


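  // In Shafer-Shenoy terms, the message built below is (comment added for
  // clarity):
  //   phi_{i->j} = proj_{S_ij} ( prod of the potentials stored in clique i
  //                              * prod_{k != j} phi_{k->i} )
  // i.e., the product of the clique's potentials with all incoming messages
  // except the one coming from the destination, projected onto the separator
  // S_ij between cliques i and j.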
  // creates the message sent by clique from_id to clique to_id
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::produceMessage__(NodeId from_id,
                                                       NodeId to_id) {
    // get the potentials of the clique.
    PotentialSet__ pot_list = clique_potentials__[from_id];

    // add the messages sent by adjacent nodes to from_id
    for (const auto other_id: JT__->neighbours(from_id))
      if (other_id != to_id)
        pot_list += separator_potentials__[Arc(other_id, from_id)];

    // get the set of variables that need be removed from the potentials
    const NodeSet& from_clique = JT__->clique(from_id);
    const NodeSet& separator   = JT__->separator(from_id, to_id);
    Set< const DiscreteVariable* > del_vars(from_clique.size());
    Set< const DiscreteVariable* > kept_vars(separator.size());
    const auto&                    bn = this->BN();

    for (const auto node: from_clique) {
      if (!separator.contains(node)) {
        del_vars.insert(&(bn.variable(node)));
      } else {
        kept_vars.insert(&(bn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    PotentialSet__ new_pot_list
       = marginalizeOut__(pot_list, del_vars, kept_vars);

    // keep track of the newly created potentials but first replace all the
    // potentials whose values are all equal by constant potentials (nbrDim=0)
    // with this very value (as probability matrix multiplications
    // are tensorial, replacing the former potential by constants provides the
    // same computation results but speeds up these computations)
    const Arc arc(from_id, to_id);

    if (!created_potentials__.exists(arc))
      created_potentials__.insert(arc, PotentialSet__());

    for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
         ++iter) {
      const auto pot = *iter;

      /*
      if (pot->variablesSequence().size() == 1) {
        bool          is_constant = true;
        Instantiation inst(*pot);
        GUM_SCALAR    first_val = pot->get(inst);

        for (++inst; !inst.end(); ++inst) {
          if (pot->get(inst) != first_val) {
            is_constant = false;
            break;
          }
        }

        if (is_constant) {
          // if pot is not a message sent by a separator or a potential stored
          // into the clique, we can remove it since it is now useless
          if (!pot_list.exists(pot)) delete pot;
          new_pot_list.erase(iter);

          // add the new constant potential to new_pot_list
          const auto    new_pot = new Potential<GUM_SCALAR>;
          Instantiation new_inst(new_pot);
          new_pot->set(new_inst, first_val);
          new_pot_list.insert(new_pot);
          created_potentials__[arc].insert(new_pot);
          continue;
        }
      }
      */

      if (!pot_list.exists(pot)) { created_potentials__[arc].insert(pot); }
    }

    separator_potentials__[arc] = std::move(new_pot_list);
    messages_computed__[arc]    = true;
  }


  // performs a whole inference
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::makeInference_() {
    // collect messages for all single targets
    for (const auto node: this->targets()) {
      // perform only collects in the join tree for nodes that have
      // not received hard evidence (those that received hard evidence were
      // not included into the join tree for speed-up reasons)
      if (graph__.exists(node)) {
        collectMessage__(node_to_clique__[node], node_to_clique__[node]);
      }
    }

    // collect messages for all set targets
    // by parsing joint_target_to_clique__, we ensure that the cliques that
    // are referenced belong to the join tree (even if some of the nodes in
    // their associated joint_target do not belong to graph__)
    for (const auto set: joint_target_to_clique__)
      collectMessage__(set.second, set.second);
  }


/// returns a fresh potential equal to P(1st arg,evidence)
1395
template
<
typename
GUM_SCALAR
>
1396
Potential
<
GUM_SCALAR
>*
1397
LazyPropagation
<
GUM_SCALAR
>::
unnormalizedJointPosterior_
(
NodeId
id
) {
1398
const
auto
&
bn
=
this
->
BN
();
1399
1400
// hard evidence do not belong to the join tree
1401
// # TODO: check for sets of inconsistent hard evidence
1402
if
(
this
->
hardEvidenceNodes
().
contains
(
id
)) {
1403
return
new
Potential
<
GUM_SCALAR
>(*(
this
->
evidence
()[
id
]));
1404
}
1405
1406
// if we still need to perform some inference task, do it (this should
1407
// already have been done by makeInference_)
1408
NodeId
clique_of_id
=
node_to_clique__
[
id
];
1409
collectMessage__
(
clique_of_id
,
clique_of_id
);
1410
1411
// now we just need to create the product of the potentials of the clique
1412
// containing id with the messages received by this clique and
1413
// marginalize out all variables except id
1414
PotentialSet__
pot_list
=
clique_potentials__
[
clique_of_id
];
1415
1416
// add the messages sent by adjacent nodes to targetClique
1417
for
(
const
auto
other
:
JT__
->
neighbours
(
clique_of_id
))
1418
pot_list
+=
separator_potentials__
[
Arc
(
other
,
clique_of_id
)];
1419
1420
// get the set of variables that need be removed from the potentials
1421
const
NodeSet
&
nodes
=
JT__
->
clique
(
clique_of_id
);
1422
Set
<
const
DiscreteVariable
* >
kept_vars
{&(
bn
.
variable
(
id
))};
1423
Set
<
const
DiscreteVariable
* >
del_vars
(
nodes
.
size
());
1424
for
(
const
auto
node
:
nodes
) {
1425
if
(
node
!=
id
)
del_vars
.
insert
(&(
bn
.
variable
(
node
)));
1426
}
1427
1428
// pot_list now contains all the potentials to multiply and marginalize
1429
// => combine the messages
1430
PotentialSet__
new_pot_list
=
marginalizeOut__
(
pot_list
,
del_vars
,
kept_vars
);
1431
Potential
<
GUM_SCALAR
>*
joint
=
nullptr
;
1432
1433
if
(
new_pot_list
.
size
() == 1) {
1434
joint
=
const_cast
<
Potential
<
GUM_SCALAR
>* >(*(
new_pot_list
.
begin
()));
1435
// if pot already existed, create a copy, so that we can put it into
1436
// the target_posteriors__ property
1437
if
(
pot_list
.
exists
(
joint
)) {
1438
joint
=
new
Potential
<
GUM_SCALAR
>(*
joint
);
1439
}
else
{
1440
// remove the joint from new_pot_list so that it will not be
1441
// removed just after the else block
1442
new_pot_list
.
clear
();
1443
}
1444
}
else
{
1445
MultiDimCombinationDefault
<
GUM_SCALAR
,
Potential
>
fast_combination
(
1446
combination_op__
);
1447
joint
=
fast_combination
.
combine
(
new_pot_list
);
1448
}
1449
1450
// remove the potentials that were created in new_pot_list
1451
for
(
const
auto
pot
:
new_pot_list
)
1452
if
(!
pot_list
.
exists
(
pot
))
delete
pot
;
1453
1454
// check that the joint posterior is different from a 0 vector: this would
1455
// indicate that some hard evidence are not compatible (their joint
1456
// probability is equal to 0)
1457
bool
nonzero_found
=
false
;
1458
for
(
Instantiation
inst
(*
joint
); !
inst
.
end
(); ++
inst
) {
1459
if
(
joint
->
get
(
inst
)) {
1460
nonzero_found
=
true
;
1461
break
;
1462
}
1463
}
1464
if
(!
nonzero_found
) {
1465
// remove joint from memory to avoid memory leaks
1466
delete
joint
;
1467
GUM_ERROR
(
IncompatibleEvidence
,
1468
"some evidence entered into the Bayes "
1469
"net are incompatible (their joint proba = 0)"
);
1470
}
1471
return
joint
;
1472
}
1473
1474
1475
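  // Illustration (hypothetical names): the zero-vector check above is what
  // surfaces contradictory hard evidence to the user. For instance, if
  // P(A = 0, B = 1) = 0 in `bn`:
  //
  //   gum::LazyPropagation< double > ie(&bn);
  //   ie.addEvidence(a, 0);       // hard evidence A = 0
  //   ie.addEvidence(b, 1);       // hard evidence B = 1
  //   try {
  //     ie.posterior(c);          // the collect yields an all-zero potential
  //   } catch (const gum::IncompatibleEvidence&) { /* expected */ }
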
  /// returns the posterior of a given variable
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     LazyPropagation< GUM_SCALAR >::posterior_(NodeId id) {
    // check if we have already computed the posterior
    if (target_posteriors__.exists(id)) { return *(target_posteriors__[id]); }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(id);
    if (joint->sum() != 1)   // hard test for ReadOnly CPT (as aggregator)
      joint->normalize();
    target_posteriors__.insert(id, joint);

    return *joint;
  }

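  // Note (illustration, hypothetical names): posteriors are cached, so
  // repeated queries on the same target are free after the first one:
  //
  //   const auto& p1 = ie.posterior(node);   // computed, then stored in
  //   const auto& p2 = ie.posterior(node);   // target_posteriors__
  //   assert(&p1 == &p2);                    // same cached potential
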
  // returns a fresh potential equal to P(set of targets, evidence)
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     LazyPropagation< GUM_SCALAR >::unnormalizedJointPosterior_(
        const NodeSet& set) {
    // nodes that received hard evidence do not belong to the join tree, so
    // extract from the targets the nodes that received no hard evidence
    NodeSet targets = set, hard_ev_nodes;
    for (const auto node: this->hardEvidenceNodes()) {
      if (targets.contains(node)) {
        targets.erase(node);
        hard_ev_nodes.insert(node);
      }
    }

    // if all the nodes have received hard evidence, then compute the
    // joint posterior directly by multiplying the hard evidence potentials
    const auto& evidence = this->evidence();
    if (targets.empty()) {
      PotentialSet__ pot_list;
      for (const auto node: set) {
        pot_list.insert(evidence[node]);
      }
      if (pot_list.size() == 1) {
        auto pot = new Potential< GUM_SCALAR >(**(pot_list.begin()));
        return pot;
      } else {
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
           combination_op__);
        return fast_combination.combine(pot_list);
      }
    }

    // if we still need to perform some inference task, do it: so, first,
    // determine the clique on which we should perform collect to compute
    // the unnormalized joint posterior of a set of nodes containing "targets"
    NodeId clique_of_set;
    try {
      clique_of_set = joint_target_to_clique__[set];
    } catch (NotFound&) {
      // here, the precise set of targets does not belong to the set of targets
      // defined by the user. So we will try to find a clique in the junction
      // tree that contains "targets":

      // 1/ we should check that all the nodes belong to the join tree
      for (const auto node: targets) {
        if (!graph__.exists(node)) {
          GUM_ERROR(UndefinedElement, node << " is not a target node");
        }
      }

      // 2/ the clique created by the first-eliminated node among the targets
      // is the one we are looking for
      const std::vector< NodeId >& JT_elim_order
         = triangulation__->eliminationOrder();

      NodeProperty< int > elim_order(Size(JT_elim_order.size()));
      for (std::size_t i = std::size_t(0), size = JT_elim_order.size();
           i < size;
           ++i)
        elim_order.insert(JT_elim_order[i], (int)i);
      NodeId first_eliminated_node = *(targets.begin());
      int    elim_number           = elim_order[first_eliminated_node];
      for (const auto node: targets) {
        if (elim_order[node] < elim_number) {
          elim_number           = elim_order[node];
          first_eliminated_node = node;
        }
      }

      clique_of_set
         = triangulation__->createdJunctionTreeClique(first_eliminated_node);

      // 3/ check that clique_of_set contains all the nodes in the target
      const NodeSet& clique_nodes = JT__->clique(clique_of_set);
      for (const auto node: targets) {
        if (!clique_nodes.contains(node)) {
          GUM_ERROR(UndefinedElement, set << " is not a joint target");
        }
      }

      // add the discovered clique to joint_target_to_clique__
      joint_target_to_clique__.insert(set, clique_of_set);
    }

    // now perform a collect on the clique
    collectMessage__(clique_of_set, clique_of_set);

    // now we just need to create the product of the potentials of the clique
    // containing set with the messages received by this clique and
    // marginalize out all variables except set
    PotentialSet__ pot_list = clique_potentials__[clique_of_set];

    // add the messages sent by the cliques adjacent to clique_of_set
    for (const auto other: JT__->neighbours(clique_of_set))
      pot_list += separator_potentials__[Arc(other, clique_of_set)];

    // get the set of variables that need to be removed from the potentials
    const NodeSet&                 nodes = JT__->clique(clique_of_set);
    Set< const DiscreteVariable* > del_vars(nodes.size());
    Set< const DiscreteVariable* > kept_vars(targets.size());
    const auto&                    bn = this->BN();
    for (const auto node: nodes) {
      if (!targets.contains(node)) {
        del_vars.insert(&(bn.variable(node)));
      } else {
        kept_vars.insert(&(bn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    PotentialSet__ new_pot_list = marginalizeOut__(pot_list, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint = nullptr;

    if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));

      // if the potential already existed, create a copy, so that we can put
      // it into the joint_target_posteriors__ property
      if (pot_list.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // deleted just after the next else block
        new_pot_list.clear();
      }
    } else {
      // combine all the potentials in new_pot_list with all the hard evidence
      // of the nodes in set
      PotentialSet__ new_new_pot_list = new_pot_list;
      for (const auto node: hard_ev_nodes) {
        new_new_pot_list.insert(evidence[node]);
      }
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
         combination_op__);
      joint = fast_combination.combine(new_new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (const auto pot: new_pot_list)
      if (!pot_list.exists(pot)) delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some hard evidence is not compatible
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
    }

    return joint;
  }

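  // Illustration (hypothetical names): a joint posterior can be computed
  // only if some clique of the join tree contains all the soft-target
  // nodes, which is guaranteed when the set was declared beforehand:
  //
  //   gum::LazyPropagation< double > ie(&bn);
  //   ie.addJointTarget(gum::NodeSet{a, b});
  //   ie.makeInference();
  //   const auto& pab = ie.jointPosterior(gum::NodeSet{a, b});
  //
  // Querying an undeclared set may still succeed if, by chance, one clique
  // covers it; otherwise the UndefinedElement exception above is raised.
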
  /// returns the posterior of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     LazyPropagation< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
    // check if we have already computed the posterior
    if (joint_target_posteriors__.exists(set)) {
      return *(joint_target_posteriors__[set]);
    }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(set);
    joint->normalize();
    joint_target_posteriors__.insert(set, joint);

    return *joint;
  }

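  // In other words, unnormalizedJointPosterior_(set) returns a potential
  // proportional to P(set, e), so after normalize() the cached potential is
  //
  //   P(set | e) = P(set, e) / sum_{set} P(set, e),
  //
  // i.e., exactly the joint posterior of the target set given evidence e.
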
  /// returns the posterior of a subset of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >& LazyPropagation< GUM_SCALAR >::jointPosterior_(
     const NodeSet& wanted_target,
     const NodeSet& declared_target) {
    // check if we have already computed the posterior of wanted_target
    if (joint_target_posteriors__.exists(wanted_target))
      return *(joint_target_posteriors__[wanted_target]);

    // here, we will have to compute the posterior of declared_target and
    // marginalize out all the variables that do not belong to wanted_target

    // check if we have already computed the posterior of declared_target
    if (!joint_target_posteriors__.exists(declared_target)) {
      jointPosterior_(declared_target);
    }

    // marginalize out all the variables that do not belong to wanted_target
    const auto&                    bn = this->BN();
    Set< const DiscreteVariable* > del_vars;
    for (const auto node: declared_target)
      if (!wanted_target.contains(node)) del_vars.insert(&(bn.variable(node)));
    Potential< GUM_SCALAR >* pot = new Potential< GUM_SCALAR >(
       joint_target_posteriors__[declared_target]->margSumOut(del_vars));

    // save the result into the cache
    joint_target_posteriors__.insert(wanted_target, pot);

    return *pot;
  }

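  // Illustration (hypothetical names): once {a, b, c} has been declared as
  // a joint target, any subset can be served from its cached posterior by
  // summing out the extra variables, without any new message passing:
  //
  //   ie.addJointTarget(gum::NodeSet{a, b, c});
  //   ie.makeInference();
  //   const auto& pab = ie.jointPosterior(gum::NodeSet{a, b});
  //   // presumably resolved as jointPosterior_({a,b}, {a,b,c}), i.e., by
  //   // margSumOut on the posterior of {a, b, c}
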
  template < typename GUM_SCALAR >
  GUM_SCALAR LazyPropagation< GUM_SCALAR >::evidenceProbability() {
    // here, we should check that find_relevant_potential_type__ is equal to
    // FIND_ALL. Otherwise, the computations could be wrong.
    RelevantPotentialsFinderType old_relevant_type
       = find_relevant_potential_type__;

    // if the relevant potentials finder is not equal to FIND_ALL, all the
    // current computations may lead to incorrect results, so we shall
    // discard them
    if (old_relevant_type != RelevantPotentialsFinderType::FIND_ALL) {
      find_relevant_potential_type__ = RelevantPotentialsFinderType::FIND_ALL;
      is_new_jt_needed__             = true;
      this->setOutdatedStructureState_();
    }

    // perform inference in each connected component
    this->makeInference();

    // for each connected component, select a variable X and compute the
    // joint probability of X and evidence e. Then marginalize out X to get
    // p(e) in this connected component. Finally, multiply all the p(e) that
    // we got and the elements in constants__. The result is the probability
    // of evidence

    GUM_SCALAR prob_ev = 1;
    for (const auto root: roots__) {
      // get a node in the clique
      NodeId                   node = *(JT__->clique(root).begin());
      Potential< GUM_SCALAR >* tmp  = unnormalizedJointPosterior_(node);
      GUM_SCALAR               sum  = 0;
      for (Instantiation iter(*tmp); !iter.end(); ++iter)
        sum += tmp->get(iter);

      prob_ev *= sum;
      delete tmp;
    }

    for (const auto& projected_cpt: constants__)
      prob_ev *= projected_cpt.second;

    // put back the relevant potential type selected by the user
    find_relevant_potential_type__ = old_relevant_type;

    return prob_ev;
  }

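  // Usage sketch (illustration only, hypothetical names): the probability
  // of the entered evidence, here P(A = 0), is the product over connected
  // components computed above:
  //
  //   gum::LazyPropagation< double > ie(&bn);
  //   ie.addEvidence(a, 0);
  //   double p_e = ie.evidenceProbability();   // calls makeInference() itself
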

}   /* namespace gum */


#endif   // DOXYGEN_SHOULD_SKIP_THIS