aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
lazyPropagation_tpl.h
/**
 *
 *  Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *  info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of lazy propagation for inference in
 * Bayesian networks.
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */

#ifndef DOXYGEN_SHOULD_SKIP_THIS
#  include <algorithm>

#  include <agrum/BN/algorithms/BayesBall.h>
#  include <agrum/BN/algorithms/barrenNodesFinder.h>
#  include <agrum/BN/algorithms/dSeparation.h>
#  include <agrum/tools/graphs/algorithms/binaryJoinTreeConverterDefault.h>
#  include <agrum/tools/multidim/instantiation.h>
#  include <agrum/tools/multidim/utils/operators/multiDimCombineAndProjectDefault.h>
#  include <agrum/tools/multidim/utils/operators/multiDimProjection.h>


namespace gum {
  // default constructor
  template < typename GUM_SCALAR >
  INLINE LazyPropagation< GUM_SCALAR >::LazyPropagation(const IBayesNet< GUM_SCALAR >* BN,
                                                        RelevantPotentialsFinderType relevant_type,
                                                        FindBarrenNodesType          barren_type,
                                                        bool use_binary_join_tree) :
      JointTargetedInference< GUM_SCALAR >(BN),
      EvidenceInference< GUM_SCALAR >(BN), _use_binary_join_tree_(use_binary_join_tree) {
    // sets the relevant potential and the barren nodes finding algorithm
    setRelevantPotentialsFinderType(relevant_type);
    setFindBarrenNodesType(barren_type);

    // create a default triangulation (the user can change it afterwards)
    _triangulation_ = new DefaultTriangulation;

    GUM_CONSTRUCTOR(LazyPropagation);
  }
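
  /*
   * Minimal usage sketch (added here for illustration, not part of the
   * original file). It builds a two-node net and queries it with lazy
   * propagation; the names "r" and "w" and the CPT values are hypothetical:
   *
   *   #include <agrum/BN/BayesNet.h>
   *   #include <agrum/BN/inference/lazyPropagation.h>
   *   #include <agrum/tools/variables/labelizedVariable.h>
   *
   *   gum::BayesNet< double > bn;
   *   const auto r = bn.add(gum::LabelizedVariable("r", "rain", 2));
   *   const auto w = bn.add(gum::LabelizedVariable("w", "wet grass", 2));
   *   bn.addArc(r, w);
   *   bn.cpt(r).fillWith({0.8, 0.2});
   *   bn.cpt(w).fillWith({0.9, 0.1, 0.2, 0.8});
   *
   *   gum::LazyPropagation< double > ie(&bn);
   *   ie.addEvidence(w, 1);   // hard evidence: w takes its second label
   *   ie.makeInference();
   *   const auto& posterior_r = ie.posterior(r);   // P(r | w = 1)
   */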

  // destructor
  template < typename GUM_SCALAR >
  INLINE LazyPropagation< GUM_SCALAR >::~LazyPropagation() {
    // remove all the potentials created during the last message passing
    for (const auto& pots: _created_potentials_)
      for (const auto pot: pots.second)
        delete pot;

    // remove the potentials created after removing the nodes that received
    // hard evidence
    for (const auto& pot: _hard_ev_projected_CPTs_)
      delete pot.second;

    // remove all the posteriors computed
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;

    // remove the junction tree and the triangulation algorithm
    if (_JT_ != nullptr) delete _JT_;
    if (_junctionTree_ != nullptr) delete _junctionTree_;
    delete _triangulation_;

    GUM_DESTRUCTOR(LazyPropagation);
  }

  /// set a new triangulation algorithm
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::setTriangulation(const Triangulation& new_triangulation) {
    delete _triangulation_;
    _triangulation_    = new_triangulation.newFactory();
    _is_new_jt_needed_ = true;
    this->setOutdatedStructureState_();
  }
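
  /*
   * Sketch (illustration only): swapping in another triangulation heuristic.
   * Since the engine clones the argument through newFactory(), a stack-
   * allocated triangulation can be passed safely ("ie" is a LazyPropagation
   * instance as in the sketch above):
   *
   *   gum::DefaultTriangulation triangulation;
   *   ie.setTriangulation(triangulation);   // next inference rebuilds the JT
   */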

  /// returns the current join (or junction) tree used
  template < typename GUM_SCALAR >
  INLINE const JoinTree* LazyPropagation< GUM_SCALAR >::joinTree() {
    _createNewJT_();

    return _JT_;
  }

  /// returns the current junction tree
  template < typename GUM_SCALAR >
  INLINE const JunctionTree* LazyPropagation< GUM_SCALAR >::junctionTree() {
    _createNewJT_();

    return _junctionTree_;
  }
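
  /*
   * Sketch (illustration only): as implemented above, both accessors rebuild
   * the secondary structure before returning it, so they may be called at any
   * time ("ie" is a LazyPropagation instance):
   *
   *   const gum::JoinTree*     jt  = ie.joinTree();
   *   const gum::JunctionTree* jut = ie.junctionTree();
   */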

  /// sets how we determine the relevant potentials to combine
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::setRelevantPotentialsFinderType(
     RelevantPotentialsFinderType type) {
    if (type != _find_relevant_potential_type_) {
      switch (type) {
        case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
          _findRelevantPotentials_
             = &LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation2_;
          break;

        case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES:
          _findRelevantPotentials_
             = &LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation_;
          break;

        case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009:
          _findRelevantPotentials_
             = &LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation3_;
          break;

        case RelevantPotentialsFinderType::FIND_ALL:
          _findRelevantPotentials_ = &LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsGetAll_;
          break;

        default:
          GUM_ERROR(InvalidArgument,
                    "setRelevantPotentialsFinderType for type " << (unsigned int)type
                                                                << " is not implemented yet");
      }

      _find_relevant_potential_type_ = type;

      // indicate that all messages need be reconstructed to take into account
      // the change in d-separation analysis
      _invalidateAllMessages_();
    }
  }
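
  /*
   * Sketch (illustration only): selecting the d-separation analysis used to
   * prune irrelevant potentials. As implemented above, switching the finder
   * invalidates all the messages computed so far:
   *
   *   ie.setRelevantPotentialsFinderType(
   *      gum::RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS);
   */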

  /// sets the operator for performing the projections
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::_setProjectionFunction_(Potential< GUM_SCALAR >* (
     *proj)(const Potential< GUM_SCALAR >&, const Set< const DiscreteVariable* >&)) {
    _projection_op_ = proj;
  }


  /// sets the operator for performing the combinations
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::_setCombinationFunction_(Potential< GUM_SCALAR >* (
     *comb)(const Potential< GUM_SCALAR >&, const Potential< GUM_SCALAR >&)) {
    _combination_op_ = comb;
  }

  /// invalidate all messages, posteriors and created potentials
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_invalidateAllMessages_() {
    // remove all the messages computed
    for (auto& potset: _separator_potentials_)
      potset.second.clear();
    for (auto& mess_computed: _messages_computed_)
      mess_computed.second = false;

    // remove all the created potentials
    for (const auto& potset: _created_potentials_)
      for (const auto pot: potset.second)
        delete pot;

    // remove all the posteriors
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;

    // indicate that new messages need be computed
    if (this->isInferenceReady() || this->isInferenceDone()) this->setOutdatedPotentialsState_();
  }

  /// sets how we determine barren nodes
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::setFindBarrenNodesType(FindBarrenNodesType type) {
    if (type != _barren_nodes_type_) {
      // WARNING: if a new type is added here, method _createJT_ should
      // certainly be updated as well, in particular its step 2.
      switch (type) {
        case FindBarrenNodesType::FIND_BARREN_NODES:
        case FindBarrenNodesType::FIND_NO_BARREN_NODES:
          break;

        default:
          GUM_ERROR(InvalidArgument,
                    "setFindBarrenNodesType for type " << (unsigned int)type
                                                       << " is not implemented yet");
      }

      _barren_nodes_type_ = type;

      // potentially, we may need to reconstruct a junction tree
      this->setOutdatedStructureState_();
    }
  }
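
  /*
   * Sketch (illustration only): disabling barren-node pruning, e.g. to
   * compare inference runtimes ("ie" is a LazyPropagation instance):
   *
   *   ie.setFindBarrenNodesType(gum::FindBarrenNodesType::FIND_NO_BARREN_NODES);
   */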

  /// fired when a new evidence is inserted
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onEvidenceAdded_(const NodeId id,
                                                              bool         isHardEvidence) {
    // if we have a new hard evidence, this modifies the undigraph over which
    // the join tree is created. This is also the case if id is not a node
    // of the undigraph
    if (isHardEvidence || !_graph_.exists(id))
      _is_new_jt_needed_ = true;
    else {
      try {
        _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_ADDED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed. This necessarily means
        // that the current saved change is an EVIDENCE_ERASED. So if we
        // erased the evidence and added some again, this corresponds to an
        // EVIDENCE_MODIFIED
        _evidence_changes_[id] = EvidenceChangeType::EVIDENCE_MODIFIED;
      }
    }
  }


  /// fired when an evidence is removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onEvidenceErased_(const NodeId id,
                                                               bool         isHardEvidence) {
    // if we delete a hard evidence, this modifies the undigraph over which
    // the join tree is created.
    if (isHardEvidence)
      _is_new_jt_needed_ = true;
    else {
      try {
        _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_ERASED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
        // been added and is now erased, this is similar to not having created
        // it. If the evidence was only modified, it already existed in the
        // last inference and we should now indicate that it has been removed.
        if (_evidence_changes_[id] == EvidenceChangeType::EVIDENCE_ADDED)
          _evidence_changes_.erase(id);
        else
          _evidence_changes_[id] = EvidenceChangeType::EVIDENCE_ERASED;
      }
    }
  }

  /// fired when all the evidence are erased
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::onAllEvidenceErased_(bool has_hard_evidence) {
    if (has_hard_evidence || !this->hardEvidenceNodes().empty())
      _is_new_jt_needed_ = true;
    else {
      for (const auto node: this->softEvidenceNodes()) {
        try {
          _evidence_changes_.insert(node, EvidenceChangeType::EVIDENCE_ERASED);
        } catch (DuplicateElement&) {
          // here, the evidence change already existed and it is necessarily an
          // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
          // been added and is now erased, this is similar to not having created
          // it. If the evidence was only modified, it already existed in the
          // last inference and we should now indicate that it has been removed.
          if (_evidence_changes_[node] == EvidenceChangeType::EVIDENCE_ADDED)
            _evidence_changes_.erase(node);
          else
            _evidence_changes_[node] = EvidenceChangeType::EVIDENCE_ERASED;
        }
      }
    }
  }


  /// fired when an evidence is changed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onEvidenceChanged_(const NodeId id,
                                                                bool         hasChangedSoftHard) {
    if (hasChangedSoftHard)
      _is_new_jt_needed_ = true;
    else {
      try {
        _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_MODIFIED);
      } catch (DuplicateElement&) {
        // here, the evidence change already existed and it is necessarily an
        // EVIDENCE_ADDED. So we should keep this state to indicate that this
        // evidence is new w.r.t. the last inference
      }
    }
  }

  /// fired after a new Bayes net has been assigned to the engine
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onModelChanged_(const GraphicalModel* bn) {}


  /// fired after a new target is inserted
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId id) {}


  /// fired before a target is removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onMarginalTargetErased_(const NodeId id) {}


  /// fired after a new set target is inserted
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onJointTargetAdded_(const NodeSet& set) {}


  /// fired before a set target is removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onJointTargetErased_(const NodeSet& set) {}


  /// fired after all the nodes of the BN are added as single targets
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}


  /// fired before all the single targets are removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onAllMarginalTargetsErased_() {}


  /// fired before all the joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onAllJointTargetsErased_() {}


  /// fired before all the single and joint targets are removed
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::onAllTargetsErased_() {}

  // check whether a new junction tree is really needed for the next inference
  template < typename GUM_SCALAR >
  bool LazyPropagation< GUM_SCALAR >::_isNewJTNeeded_() const {
    // if we do not have a JT or if _new_jt_needed_ is set to true, then
    // we know that we need to create a new join tree
    if ((_JT_ == nullptr) || _is_new_jt_needed_) return true;

    // if some targets do not belong to the join tree and, consequently,
    // to the undigraph that was used to construct the join tree, then we need
    // to create a new JT. This situation may occur if we constructed the
    // join tree after pruning irrelevant/barren nodes from the BN.
    // However, note that the nodes that received hard evidence do not belong
    // to the graph and, therefore, should not be taken into account
    const auto& hard_ev_nodes = this->hardEvidenceNodes();
    for (const auto node: this->targets()) {
      if (!_graph_.exists(node) && !hard_ev_nodes.exists(node)) return true;
    }
    for (const auto& joint_target: this->jointTargets()) {
      // here, we need to check that at least one clique contains all the
      // nodes of the joint target.
      bool containing_clique_found = false;
      for (const auto node: joint_target) {
        bool found = true;
        try {
          const NodeSet& clique = _JT_->clique(_node_to_clique_[node]);
          for (const auto xnode: joint_target) {
            if (!clique.contains(xnode) && !hard_ev_nodes.exists(xnode)) {
              found = false;
              break;
            }
          }
        } catch (NotFound&) { found = false; }

        if (found) {
          containing_clique_found = true;
          break;
        }
      }

      if (!containing_clique_found) return true;
    }

    // if some new evidence have been added on nodes that do not belong
    // to _graph_, then we potentially have to reconstruct the join tree
    for (const auto& change: _evidence_changes_) {
      if ((change.second == EvidenceChangeType::EVIDENCE_ADDED) && !_graph_.exists(change.first))
        return true;
    }

    // here, the current JT is exactly what we need for the next inference
    return false;
  }

  /// create a new junction tree as well as its related data structures
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_createNewJT_() {
    // to create the JT, we first create the moral graph of the BN in the
    // following way in order to take into account the barren nodes and the
    // nodes that received evidence:
    // 1/ we create an undirected graph containing only the nodes and no edge
    // 2/ if we take into account barren nodes, remove them from the graph
    // 3/ add edges so that each node and its parents in the BN form a clique
    // 4/ add edges so that set targets are cliques of the moral graph
    // 5/ remove the nodes that received hard evidence (by step 3/, their
    //    parents are linked by edges, which is necessary for inference)
    //
    // At the end of step 5/, we have our moral graph and we can triangulate it
    // to get the new junction tree

    // 1/ create an undirected graph containing only the nodes and no edge
    const auto& bn = this->BN();
    _graph_.clear();
    for (const auto node: bn.dag())
      _graph_.addNodeWithId(node);

    // 2/ if we wish to exploit barren nodes, we shall remove them from the BN.
    // To do so, we identify all the nodes that are not targets, have received
    // no evidence, and are such that their descendants are neither targets
    // nor evidence nodes. Such nodes can be safely discarded from the BN
    // without altering the inference output
    if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
      // identify the barren nodes
      NodeSet target_nodes = this->targets();
      for (const auto& nodeset: this->jointTargets()) {
        target_nodes += nodeset;
      }

      // check that all the nodes are not targets, otherwise, there is no
      // barren node
      if (target_nodes.size() != bn.size()) {
        BarrenNodesFinder finder(&(bn.dag()));
        finder.setTargets(&target_nodes);

        NodeSet evidence_nodes;
        for (const auto& pair: this->evidence()) {
          evidence_nodes.insert(pair.first);
        }
        finder.setEvidence(&evidence_nodes);

        NodeSet barren_nodes = finder.barrenNodes();

        // remove the barren nodes from the moral graph
        for (const auto node: barren_nodes) {
          _graph_.eraseNode(node);
        }
      }
    }

    // 3/ add edges so that each node and its parents in the BN form a clique
    for (const auto node: _graph_) {
      const NodeSet& parents = bn.parents(node);
      for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
        _graph_.addEdge(*iter1, node);
        auto iter2 = iter1;
        for (++iter2; iter2 != parents.cend(); ++iter2) {
          _graph_.addEdge(*iter1, *iter2);
        }
      }
    }

    // 4/ if there exist some joint targets, we shall add new edges
    // into the moral graph in order to ensure that there exists a clique
    // containing each joint
    for (const auto& nodeset: this->jointTargets()) {
      for (auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
        auto iter2 = iter1;
        for (++iter2; iter2 != nodeset.cend(); ++iter2) {
          _graph_.addEdge(*iter1, *iter2);
        }
      }
    }

    // 5/ remove all the nodes that received hard evidence
    _hard_ev_nodes_ = this->hardEvidenceNodes();
    for (const auto node: _hard_ev_nodes_) {
      _graph_.eraseNode(node);
    }


    // now, we can compute the new junction tree. To speed up computations
    // (essentially, those of a distribution phase), we construct from this
    // junction tree a binary join tree
    if (_JT_ != nullptr) delete _JT_;
    if (_junctionTree_ != nullptr) delete _junctionTree_;

    _triangulation_->setGraph(&_graph_, &(this->domainSizes()));
    const JunctionTree& triang_jt = _triangulation_->junctionTree();
    if (_use_binary_join_tree_) {
      BinaryJoinTreeConverterDefault bjt_converter;
      NodeSet                        emptyset;
      _JT_ = new CliqueGraph(bjt_converter.convert(triang_jt, this->domainSizes(), emptyset));
    } else {
      _JT_ = new CliqueGraph(triang_jt);
    }
    _junctionTree_ = new CliqueGraph(triang_jt);


    // indicate, for each node of the moral graph, a clique in _JT_ that can
    // contain its conditional probability table
    _node_to_clique_.clear();
    const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
    NodeProperty< int >          elim_order(Size(JT_elim_order.size()));
    for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
      elim_order.insert(JT_elim_order[i], (int)i);
    const DAG& dag = bn.dag();
    for (const auto node: _graph_) {
      // get the variables in the potential of node (and its parents)
      NodeId first_eliminated_node = node;
      int    elim_number           = elim_order[first_eliminated_node];

      for (const auto parent: dag.parents(node)) {
        if (_graph_.existsNode(parent) && (elim_order[parent] < elim_number)) {
          elim_number           = elim_order[parent];
          first_eliminated_node = parent;
        }
      }

      // first_eliminated_node contains the first var (node or one of its
      // parents) eliminated => the clique created during its elimination
      // contains node and all of its parents => it can contain the potential
      // assigned to the node in the BN
      _node_to_clique_.insert(node,
                              _triangulation_->createdJunctionTreeClique(first_eliminated_node));
    }

    // do the same for the nodes that received evidence. Here, we only store
    // the nodes for which at least one parent belongs to _graph_ (otherwise
    // their CPT is just a constant real number).
    for (const auto node: _hard_ev_nodes_) {
      // get the set of parents of the node that belong to _graph_
      NodeSet pars(dag.parents(node).size());
      for (const auto par: dag.parents(node))
        if (_graph_.exists(par)) pars.insert(par);

      if (!pars.empty()) {
        NodeId first_eliminated_node = *(pars.begin());
        int    elim_number           = elim_order[first_eliminated_node];

        for (const auto parent: pars) {
          if (elim_order[parent] < elim_number) {
            elim_number           = elim_order[parent];
            first_eliminated_node = parent;
          }
        }

        // first_eliminated_node contains the first var (node or one of its
        // parents) eliminated => the clique created during its elimination
        // contains node and all of its parents => it can contain the potential
        // assigned to the node in the BN
        _node_to_clique_.insert(
           node,
           _triangulation_->createdJunctionTreeClique(first_eliminated_node));
      }
    }

    // indicate for each joint_target a clique that contains it
    _joint_target_to_clique_.clear();
    for (const auto& set: this->jointTargets()) {
      // remove from set all the nodes that received hard evidence (since they
      // do not belong to the join tree)
      NodeSet nodeset = set;
      for (const auto node: _hard_ev_nodes_)
        if (nodeset.contains(node)) nodeset.erase(node);

      if (!nodeset.empty()) {
        // the clique we are looking for is the one that was created when
        // the first element of nodeset was eliminated
        NodeId first_eliminated_node = *(nodeset.begin());
        int    elim_number           = elim_order[first_eliminated_node];
        for (const auto node: nodeset) {
          if (elim_order[node] < elim_number) {
            elim_number           = elim_order[node];
            first_eliminated_node = node;
          }
        }

        _joint_target_to_clique_.insert(
           set,
           _triangulation_->createdJunctionTreeClique(first_eliminated_node));
      }
    }


    // compute the roots of _JT_'s connected components
    _computeJoinTreeRoots_();

    // create empty potential lists into the cliques of the joint tree as well
    // as empty lists of evidence
    _PotentialSet_ empty_set;
    _clique_potentials_.clear();
    _node_to_soft_evidence_.clear();
    for (const auto node: *_JT_) {
      _clique_potentials_.insert(node, empty_set);
    }

    // remove all the potentials created during the last inference
    for (const auto& potlist: _created_potentials_)
      for (const auto pot: potlist.second)
        delete pot;
    _created_potentials_.clear();

    // remove all the potentials created to take into account hard evidence
    // during the last inference
    for (const auto pot_pair: _hard_ev_projected_CPTs_)
      delete pot_pair.second;
    _hard_ev_projected_CPTs_.clear();

    // remove all the constants created due to projections of CPTs that were
    // defined over only hard evidence nodes
    _constants_.clear();

    // create empty lists of potentials for the messages and indicate that no
    // message has been computed yet
    _separator_potentials_.clear();
    _messages_computed_.clear();
    for (const auto& edge: _JT_->edges()) {
      const Arc arc1(edge.first(), edge.second());
      _separator_potentials_.insert(arc1, empty_set);
      _messages_computed_.insert(arc1, false);
      const Arc arc2(Arc(edge.second(), edge.first()));
      _separator_potentials_.insert(arc2, empty_set);
      _messages_computed_.insert(arc2, false);
    }

    // remove all the posteriors computed so far
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    _target_posteriors_.clear();
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;
    _joint_target_posteriors_.clear();


    // put all the CPTs of the Bayes net nodes into the cliques.
    // Here, beware: all the potentials that are defined over some nodes
    // including hard evidence must be projected so that these nodes are
    // removed from the potential
    const auto& evidence      = this->evidence();
    const auto& hard_evidence = this->hardEvidence();
    for (const auto node: dag) {
      if (_graph_.exists(node) || _hard_ev_nodes_.contains(node)) {
        const Potential< GUM_SCALAR >& cpt = bn.cpt(node);

        // get the list of nodes with hard evidence in cpt
        NodeSet     hard_nodes;
        const auto& variables = cpt.variablesSequence();
        for (const auto var: variables) {
          NodeId xnode = bn.nodeId(*var);
          if (_hard_ev_nodes_.contains(xnode)) hard_nodes.insert(xnode);
        }

        // if hard_nodes contains hard evidence nodes, perform a projection
        // and insert the result into the appropriate clique, else insert
        // directly cpt into the clique
        if (hard_nodes.empty()) {
          _clique_potentials_[_node_to_clique_[node]].insert(&cpt);
        } else {
          // marginalize out the hard evidence nodes: if the cpt is defined
          // only over nodes that received hard evidence, do not consider it
          // as a potential anymore but as a constant
          if (hard_nodes.size() == variables.size()) {
            Instantiation inst;
            const auto&   vars = cpt.variablesSequence();
            for (const auto var: vars)
              inst << *var;
            for (Size i = 0; i < hard_nodes.size(); ++i) {
              inst.chgVal(variables[i], hard_evidence[bn.nodeId(*(variables[i]))]);
            }
            _constants_.insert(node, cpt.get(inst));
          } else {
            // perform the projection with a combine and project instance
            Set< const DiscreteVariable* > hard_variables;
            _PotentialSet_                 marg_cpt_set{&cpt};
            for (const auto xnode: hard_nodes) {
              marg_cpt_set.insert(evidence[xnode]);
              hard_variables.insert(&(bn.variable(xnode)));
            }

            // perform the combination of those potentials and their projection
            MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
               _combination_op_,
               LPNewprojPotential);
            _PotentialSet_ new_cpt_list
               = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

            // there should be only one potential in new_cpt_list
            if (new_cpt_list.size() != 1) {
              // remove the CPT created to avoid memory leaks
              for (const auto pot: new_cpt_list) {
                if (!marg_cpt_set.contains(pot)) delete pot;
              }
              GUM_ERROR(FatalError,
                        "the projection of a potential containing "
                           << "hard evidence is empty!");
            }
            const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
            _clique_potentials_[_node_to_clique_[node]].insert(projected_cpt);
            _hard_ev_projected_CPTs_.insert(node, projected_cpt);
          }
        }
      }
    }

    // we shall now add all the potentials of the soft evidence
    for (const auto node: this->softEvidenceNodes()) {
      _node_to_soft_evidence_.insert(node, evidence[node]);
      _clique_potentials_[_node_to_clique_[node]].insert(evidence[node]);
    }

    // indicate that the data structures are up to date.
    _evidence_changes_.clear();
    _is_new_jt_needed_ = false;
  }

  /// prepare the inference structures w.r.t. new targets, soft/hard evidence
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::updateOutdatedStructure_() {
    // check if a new JT is really needed. If so, create it
    if (_isNewJTNeeded_()) {
      _createNewJT_();
    } else {
      // here, we can answer the next queries without reconstructing all the
      // junction tree. All we need to do is to indicate that we should
      // update the potentials and messages for these queries
      updateOutdatedPotentials_();
    }
  }

  /// invalidate all the messages sent from a given clique
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_diffuseMessageInvalidations_(NodeId   from_id,
                                                                    NodeId   to_id,
                                                                    NodeSet& invalidated_cliques) {
    // invalidate the current clique
    invalidated_cliques.insert(to_id);

    // invalidate the current arc
    const Arc arc(from_id, to_id);
    bool&     message_computed = _messages_computed_[arc];
    if (message_computed) {
      message_computed = false;
      _separator_potentials_[arc].clear();
      if (_created_potentials_.exists(arc)) {
        auto& arc_created_potentials = _created_potentials_[arc];
        for (const auto pot: arc_created_potentials)
          delete pot;
        arc_created_potentials.clear();
      }

      // go on with the diffusion
      for (const auto node_id: _JT_->neighbours(to_id)) {
        if (node_id != from_id) _diffuseMessageInvalidations_(to_id, node_id, invalidated_cliques);
      }
    }
  }

  /// update the potentials stored in the cliques and invalidate outdated
  /// messages
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::updateOutdatedPotentials_() {
    // compute the set of CPTs that were projected due to hard evidence and
    // whose hard evidence have changed, so that they need a new projection.
    // By the way, remove these CPTs since they are no longer needed.
    // Here, only the values of the hard evidence can have changed (else a
    // fully new join tree would have been computed).
    // Note also that we know that the CPTs still contain some variable(s) after
    // the projection (else they should be constants)
    NodeSet hard_nodes_changed(_hard_ev_nodes_.size());
    for (const auto node: _hard_ev_nodes_)
      if (_evidence_changes_.exists(node)) hard_nodes_changed.insert(node);

    NodeSet     nodes_with_projected_CPTs_changed;
    const auto& bn = this->BN();
    for (auto pot_iter = _hard_ev_projected_CPTs_.beginSafe();
         pot_iter != _hard_ev_projected_CPTs_.endSafe();
         ++pot_iter) {
      for (const auto var: bn.cpt(pot_iter.key()).variablesSequence()) {
        if (hard_nodes_changed.contains(bn.nodeId(*var))) {
          nodes_with_projected_CPTs_changed.insert(pot_iter.key());
          delete pot_iter.val();
          _clique_potentials_[_node_to_clique_[pot_iter.key()]].erase(pot_iter.val());
          _hard_ev_projected_CPTs_.erase(pot_iter);
          break;
        }
      }
    }


    // invalidate all the messages that are no longer correct: start from each
    // of the nodes whose soft evidence has changed and perform a diffusion from
    // the clique into which the soft evidence has been entered, indicating that
    // the messages spreading from this clique are now invalid. At the same time,
    // if there were potentials created on the arcs over which the messages were
    // sent, remove them from memory. For all the cliques that received some
    // projected CPT that should now be changed, do the same.
    NodeSet invalidated_cliques(_JT_->size());
    for (const auto& pair: _evidence_changes_) {
      if (_node_to_clique_.exists(pair.first)) {
        const auto clique = _node_to_clique_[pair.first];
        invalidated_cliques.insert(clique);
        for (const auto neighbor: _JT_->neighbours(clique)) {
          _diffuseMessageInvalidations_(clique, neighbor, invalidated_cliques);
        }
      }
    }

    // now, add to the set of invalidated cliques those that contain projected
    // CPTs that were changed.
    for (const auto node: nodes_with_projected_CPTs_changed) {
      const auto clique = _node_to_clique_[node];
      invalidated_cliques.insert(clique);
      for (const auto neighbor: _JT_->neighbours(clique)) {
        _diffuseMessageInvalidations_(clique, neighbor, invalidated_cliques);
      }
    }


    // now we shall remove all the posteriors that belong to the
    // invalidated cliques. First, cope only with the nodes that did not
    // receive hard evidence since the other nodes do not belong to the
    // join tree
    for (auto iter = _target_posteriors_.beginSafe(); iter != _target_posteriors_.endSafe();
         ++iter) {
      if (_graph_.exists(iter.key())
          && (invalidated_cliques.exists(_node_to_clique_[iter.key()]))) {
        delete iter.val();
        _target_posteriors_.erase(iter);
      }
    }

    // now cope with the nodes that received hard evidence
    for (auto iter = _target_posteriors_.beginSafe(); iter != _target_posteriors_.endSafe();
         ++iter) {
      if (hard_nodes_changed.contains(iter.key())) {
        delete iter.val();
        _target_posteriors_.erase(iter);
      }
    }

    // finally, cope with joint targets
    for (auto iter = _joint_target_posteriors_.beginSafe();
         iter != _joint_target_posteriors_.endSafe();
         ++iter) {
      if (invalidated_cliques.exists(_joint_target_to_clique_[iter.key()])) {
        delete iter.val();
        _joint_target_posteriors_.erase(iter);
      }
    }


    // remove all the evidence that were entered into _node_to_soft_evidence_
    // and _clique_potentials_ and add the new soft ones
    for (const auto& pot_pair: _node_to_soft_evidence_) {
      _clique_potentials_[_node_to_clique_[pot_pair.first]].erase(pot_pair.second);
    }
    _node_to_soft_evidence_.clear();

    const auto& evidence = this->evidence();
    for (const auto node: this->softEvidenceNodes()) {
      _node_to_soft_evidence_.insert(node, evidence[node]);
      _clique_potentials_[_node_to_clique_[node]].insert(evidence[node]);
    }


    // Now add the projections of the CPTs due to newly changed hard evidence:
    // if we are performing updateOutdatedPotentials_, this means that the
    // set of nodes that received hard evidence has not changed, only
    // their instantiations can have changed. So, if there is an entry
    // for node in _constants_, there will still be such an entry after
    // performing the new projections. Idem for _hard_ev_projected_CPTs_
    for (const auto node: nodes_with_projected_CPTs_changed) {
      // perform the projection with a combine and project instance
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node);
      const auto&                    variables = cpt.variablesSequence();
      Set< const DiscreteVariable* > hard_variables;
      _PotentialSet_                 marg_cpt_set{&cpt};
      for (const auto var: variables) {
        NodeId xnode = bn.nodeId(*var);
        if (_hard_ev_nodes_.exists(xnode)) {
          marg_cpt_set.insert(evidence[xnode]);
          hard_variables.insert(var);
        }
      }

      // perform the combination of those potentials and their projection
      MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
         _combination_op_,
         LPNewprojPotential);
      _PotentialSet_ new_cpt_list
         = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

      // there should be only one potential in new_cpt_list
      if (new_cpt_list.size() != 1) {
        // remove the CPT created to avoid memory leaks
        for (const auto pot: new_cpt_list) {
          if (!marg_cpt_set.contains(pot)) delete pot;
        }
        GUM_ERROR(FatalError,
                  "the projection of a potential containing "
                     << "hard evidence is empty!");
      }
      const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
      _clique_potentials_[_node_to_clique_[node]].insert(projected_cpt);
      _hard_ev_projected_CPTs_.insert(node, projected_cpt);
    }

    // update the constants
    const auto& hard_evidence = this->hardEvidence();
    for (auto& node_cst: _constants_) {
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node_cst.first);
      const auto&                    variables = cpt.variablesSequence();
      Instantiation                  inst;
      for (const auto var: variables)
        inst << *var;
      for (const auto var: variables) {
        inst.chgVal(var, hard_evidence[bn.nodeId(*var)]);
      }
      node_cst.second = cpt.get(inst);
    }

    // indicate that all changes have been performed
    _evidence_changes_.clear();
  }

  /// compute a root for each connected component of _JT_
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_computeJoinTreeRoots_() {
    // get the set of cliques in which we can find the targets and joint_targets
    NodeSet clique_targets;
    for (const auto node: this->targets()) {
      try {
        clique_targets.insert(_node_to_clique_[node]);
      } catch (Exception&) {}
    }
    for (const auto& set: this->jointTargets()) {
      try {
        clique_targets.insert(_joint_target_to_clique_[set]);
      } catch (Exception&) {}
    }

    // put in a vector these cliques and their size
    std::vector< std::pair< NodeId, Size > > possible_roots(clique_targets.size());
    const auto&                              bn = this->BN();
    std::size_t                              i  = 0;
    for (const auto clique_id: clique_targets) {
      const auto& clique   = _JT_->clique(clique_id);
      Size        dom_size = 1;
      for (const auto node: clique) {
        dom_size *= bn.variable(node).domainSize();
      }
      possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
      ++i;
    }

    // sort the cliques by increasing domain size
    std::sort(possible_roots.begin(),
              possible_roots.end(),
              [](const std::pair< NodeId, Size >& a, const std::pair< NodeId, Size >& b) -> bool {
                return a.second < b.second;
              });

    // pick up the clique with the smallest size in each connected component
    NodeProperty< bool >                  marked = _JT_->nodesProperty(false);
    std::function< void(NodeId, NodeId) > diffuse_marks
       = [&marked, &diffuse_marks, this](NodeId node, NodeId from) {
           if (!marked[node]) {
             marked[node] = true;
             for (const auto neigh: _JT_->neighbours(node))
               if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
           }
         };
    _roots_.clear();
    for (const auto xclique: possible_roots) {
      NodeId clique = xclique.first;
      if (!marked[clique]) {
        _roots_.insert(clique);
        diffuse_marks(clique, clique);
      }
    }
  }

  // performs the collect phase of Lazy Propagation
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::_collectMessage_(NodeId id, NodeId from) {
    for (const auto other: _JT_->neighbours(id)) {
      if ((other != from) && !_messages_computed_[Arc(other, id)]) _collectMessage_(other, id);
    }

    if ((id != from) && !_messages_computed_[Arc(id, from)]) { _produceMessage_(id, from); }
  }

  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsGetAll_(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {}


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation_(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {
    // find the node ids of the kept variables
    NodeSet     kept_ids;
    const auto& bn = this->BN();
    for (const auto var: kept_vars) {
      kept_ids.insert(bn.nodeId(*var));
    }

    // determine the set of potentials d-connected with the kept variables
    NodeSet requisite_nodes;
    BayesBall::requisiteNodes(bn.dag(),
                              kept_ids,
                              this->hardEvidenceNodes(),
                              this->softEvidenceNodes(),
                              requisite_nodes);
    for (auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
      const Sequence< const DiscreteVariable* >& vars  = (**iter).variablesSequence();
      bool                                       found = false;
      for (const auto var: vars) {
        if (requisite_nodes.exists(bn.nodeId(*var))) {
          found = true;
          break;
        }
      }

      if (!found) { pot_list.erase(iter); }
    }
  }


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation2_(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {
    // find the node ids of the kept variables
    NodeSet     kept_ids;
    const auto& bn = this->BN();
    for (const auto var: kept_vars) {
      kept_ids.insert(bn.nodeId(*var));
    }

    // determine the set of potentials d-connected with the kept variables
    BayesBall::relevantPotentials(bn,
                                  kept_ids,
                                  this->hardEvidenceNodes(),
                                  this->softEvidenceNodes(),
                                  pot_list);
  }


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation3_(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {
    // find the node ids of the kept variables
    NodeSet     kept_ids;
    const auto& bn = this->BN();
    for (const auto var: kept_vars) {
      kept_ids.insert(bn.nodeId(*var));
    }

    // determine the set of potentials d-connected with the kept variables
    dSeparation dsep;
    dsep.relevantPotentials(bn,
                            kept_ids,
                            this->hardEvidenceNodes(),
                            this->softEvidenceNodes(),
                            pot_list);
  }


  // find the potentials d-connected to a set of variables
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_findRelevantPotentialsXX_(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {
    switch (_find_relevant_potential_type_) {
      case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
        _findRelevantPotentialsWithdSeparation2_(pot_list, kept_vars);
        break;

      case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES:
        _findRelevantPotentialsWithdSeparation_(pot_list, kept_vars);
        break;

      case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009:
        _findRelevantPotentialsWithdSeparation3_(pot_list, kept_vars);
        break;

      case RelevantPotentialsFinderType::FIND_ALL:
        _findRelevantPotentialsGetAll_(pot_list, kept_vars);
        break;

      default:
        GUM_ERROR(FatalError, "not implemented yet")
    }
  }

  // remove barren variables
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* > LazyPropagation< GUM_SCALAR >::_removeBarrenVariables_(
     _PotentialSet_&                 pot_list,
     Set< const DiscreteVariable* >& del_vars) {
    // remove from del_vars the variables that received some evidence:
    // only those that did not receive evidence can be barren variables
    Set< const DiscreteVariable* > the_del_vars = del_vars;
    for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe(); ++iter) {
      NodeId id = this->BN().nodeId(**iter);
      if (this->hardEvidenceNodes().exists(id) || this->softEvidenceNodes().exists(id)) {
        the_del_vars.erase(iter);
      }
    }

    // assign to each random variable the set of potentials that contain it
    HashTable< const DiscreteVariable*, _PotentialSet_ > var2pots;
    _PotentialSet_                                       empty_pot_set;
    for (const auto pot: pot_list) {
      const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
      for (const auto var: vars) {
        if (the_del_vars.exists(var)) {
          if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
          var2pots[var].insert(pot);
        }
      }
    }

    // each variable with only one potential is a barren variable
    // assign to each potential with barren nodes its set of barren variables
    HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > > pot2barren_var;
    Set< const DiscreteVariable* >                                              empty_var_set;
    for (const auto elt: var2pots) {
      if (elt.second.size() == 1) {   // here we have a barren variable
        const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
        if (!pot2barren_var.exists(pot)) { pot2barren_var.insert(pot, empty_var_set); }
        pot2barren_var[pot].insert(elt.first);   // insert the barren variable
      }
    }

    // for each potential with barren variables, marginalize them.
    // if the potential has only barren variables, simply remove them from the
    // set of potentials, else just project the potential
    MultiDimProjection< GUM_SCALAR, Potential > projector(LPNewprojPotential);
    _PotentialSet_                              projected_pots;
    for (const auto elt: pot2barren_var) {
      // remove the current potential from pot_list as, anyway, we will change
      // it
      const Potential< GUM_SCALAR >* pot = elt.first;
      pot_list.erase(pot);

      // check whether we need to add a projected new potential or not (i.e.,
      // whether there exist non-barren variables or not)
      if (pot->variablesSequence().size() != elt.second.size()) {
        auto new_pot = projector.project(*pot, elt.second);
        pot_list.insert(new_pot);
        projected_pots.insert(new_pot);
      }
    }

    return projected_pots;
  }

  // remove variables del_vars from the list of potentials pot_list
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* >
     LazyPropagation< GUM_SCALAR >::_marginalizeOut_(Set< const Potential< GUM_SCALAR >* > pot_list,
                                                     Set< const DiscreteVariable* >& del_vars,
                                                     Set< const DiscreteVariable* >& kept_vars) {
    // use d-separation analysis to check which potentials shall be combined
    _findRelevantPotentialsXX_(pot_list, kept_vars);

    // remove the potentials corresponding to barren variables if we want
    // to exploit barren nodes
    _PotentialSet_ barren_projected_potentials;
    if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
      barren_projected_potentials = _removeBarrenVariables_(pot_list, del_vars);
    }

    // create a combine and project operator that will perform the
    // marginalization
    MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(_combination_op_,
                                                                                  _projection_op_);
    _PotentialSet_ new_pot_list = combine_and_project.combineAndProject(pot_list, del_vars);

    // remove all the potentials that were created due to projections of
    // barren nodes and that are not part of the new_pot_list: these
    // potentials were just temporary potentials
    for (auto iter = barren_projected_potentials.beginSafe();
         iter != barren_projected_potentials.endSafe();
         ++iter) {
      if (!new_pot_list.exists(*iter)) delete *iter;
    }

    // remove all the potentials that have no dimension
    for (auto iter_pot = new_pot_list.beginSafe(); iter_pot != new_pot_list.endSafe(); ++iter_pot) {
      if ((*iter_pot)->variablesSequence().size() == 0) {
        // as we have already marginalized out variables that received evidence,
        // it may be the case that, after combining and projecting, some
        // potentials might be empty. In this case, we shall keep their
        // constant and remove them from memory
        // # TODO: keep the constants!
        delete *iter_pot;
        new_pot_list.erase(iter_pot);
      }
    }

    return new_pot_list;
  }

  // creates the message sent by clique from_id to clique to_id
  template < typename GUM_SCALAR >
  void LazyPropagation< GUM_SCALAR >::_produceMessage_(NodeId from_id, NodeId to_id) {
    // get the potentials of the clique.
    _PotentialSet_ pot_list = _clique_potentials_[from_id];

    // add the messages sent by adjacent nodes to from_id
    for (const auto other_id: _JT_->neighbours(from_id))
      if (other_id != to_id) pot_list += _separator_potentials_[Arc(other_id, from_id)];

    // get the set of variables that need be removed from the potentials
    const NodeSet&                 from_clique = _JT_->clique(from_id);
    const NodeSet&                 separator   = _JT_->separator(from_id, to_id);
    Set< const DiscreteVariable* > del_vars(from_clique.size());
    Set< const DiscreteVariable* > kept_vars(separator.size());
    const auto&                    bn = this->BN();

    for (const auto node: from_clique) {
      if (!separator.contains(node)) {
        del_vars.insert(&(bn.variable(node)));
      } else {
        kept_vars.insert(&(bn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    _PotentialSet_ new_pot_list = _marginalizeOut_(pot_list, del_vars, kept_vars);

    // keep track of the newly created potentials but first replace all the
    // potentials whose values are all equal by constant potentials (nbrDim=0)
    // with this very value (as probability matrix multiplications
    // are tensorial, replacing the former potential by constants provides the
    // same computation results but speeds up these computations)
    const Arc arc(from_id, to_id);

    if (!_created_potentials_.exists(arc)) _created_potentials_.insert(arc, _PotentialSet_());

    for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe(); ++iter) {
      const auto pot = *iter;

      /*
      if (pot->variablesSequence().size() == 1) {
        bool          is_constant = true;
        Instantiation inst(*pot);
        GUM_SCALAR    first_val = pot->get(inst);

        for (++inst; !inst.end(); ++inst) {
          if (pot->get(inst) != first_val) {
            is_constant = false;
            break;
          }
        }

        if (is_constant) {
          // if pot is not a message sent by a separator or a potential stored
          // into the clique, we can remove it since it is now useless
          if (!pot_list.exists(pot)) delete pot;
          new_pot_list.erase(iter);

          // add the new constant potential to new_pot_list
          const auto    new_pot = new Potential< GUM_SCALAR >;
          Instantiation new_inst(new_pot);
          new_pot->set(new_inst, first_val);
          new_pot_list.insert(new_pot);
          _created_potentials_[arc].insert(new_pot);
          continue;
        }
      }
      */

      if (!pot_list.exists(pot)) { _created_potentials_[arc].insert(pot); }
    }

    _separator_potentials_[arc] = std::move(new_pot_list);
    _messages_computed_[arc]    = true;
  }

  // performs a whole inference
  template < typename GUM_SCALAR >
  INLINE void LazyPropagation< GUM_SCALAR >::makeInference_() {
    // collect messages for all single targets
    for (const auto node: this->targets()) {
      // perform only collects in the join tree for nodes that have
      // not received hard evidence (those that received hard evidence were
      // not included into the join tree for speed-up reasons)
      if (_graph_.exists(node)) {
        _collectMessage_(_node_to_clique_[node], _node_to_clique_[node]);
      }
    }

    // collect messages for all set targets
    // by parsing _joint_target_to_clique_, we ensure that the cliques that
    // are referenced belong to the join tree (even if some of the nodes in
    // their associated joint_target do not belong to _graph_)
    for (const auto set: _joint_target_to_clique_)
      _collectMessage_(set.second, set.second);
  }
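
  /*
   * Sketch (illustration only): joint targets must be declared before
   * inference so that _createNewJT_ can force a clique containing them;
   * "x" and "y" are hypothetical node ids:
   *
   *   ie.addJointTarget(gum::NodeSet{x, y});
   *   ie.makeInference();
   *   const auto& joint_xy = ie.jointPosterior(gum::NodeSet{x, y});
   */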
  /// returns a fresh potential equal to P(1st arg,evidence)
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     LazyPropagation< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId id) {
    const auto& bn = this->BN();

    // hard evidence does not belong to the join tree
    // TODO: check for sets of inconsistent hard evidence
    if (this->hardEvidenceNodes().contains(id)) {
      return new Potential< GUM_SCALAR >(*(this->evidence()[id]));
    }

    // if we still need to perform some inference task, do it (this should
    // already have been done by makeInference_)
    NodeId clique_of_id = _node_to_clique_[id];
    _collectMessage_(clique_of_id, clique_of_id);

    // now we just need to create the product of the potentials of the clique
    // containing id with the messages received by this clique and
    // marginalize out all variables except id
    _PotentialSet_ pot_list = _clique_potentials_[clique_of_id];

    // add the messages sent by adjacent nodes to targetClique
    for (const auto other: _JT_->neighbours(clique_of_id))
      pot_list += _separator_potentials_[Arc(other, clique_of_id)];

    // get the set of variables that need to be removed from the potentials
    const NodeSet&                 nodes = _JT_->clique(clique_of_id);
    Set< const DiscreteVariable* > kept_vars{&(bn.variable(id))};
    Set< const DiscreteVariable* > del_vars(nodes.size());
    for (const auto node: nodes) {
      if (node != id) del_vars.insert(&(bn.variable(node)));
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    _PotentialSet_           new_pot_list = _marginalizeOut_(pot_list, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint        = nullptr;

    if (new_pot_list.size() == 1) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if pot already existed, create a copy, so that we can put it into
      // the _target_posteriors_ property
      if (pot_list.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // deleted just after the else block
        new_pot_list.clear();
      }
    } else {
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
      joint = fast_combination.combine(new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (const auto pot: new_pot_list)
      if (!pot_list.exists(pot)) delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some pieces of hard evidence are incompatible (their joint
    // probability is equal to 0)
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if (joint->get(inst)) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
    }
    return joint;
  }

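  // Reading of the function above: with C the clique containing id and
  // m_{D -> C} the messages sent to C by its neighbours D, the returned
  // potential is
  //   phi(x_id) = sum over (C \ {id}) of [ prod of the potentials of C
  //               * prod of the m_{D -> C} ]  =  P(x_id, evidence),
  // so the posterior P(x_id | evidence) is recovered by normalization, as
  // posterior_ does below.
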
  /// returns the posterior of a given variable
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >& LazyPropagation< GUM_SCALAR >::posterior_(NodeId id) {
    // check if we have already computed the posterior
    if (_target_posteriors_.exists(id)) { return *(_target_posteriors_[id]); }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(id);
    if (joint->sum() != 1)   // hard test for ReadOnly CPT (as aggregator)
      joint->normalize();
    _target_posteriors_.insert(id, joint);

    return *joint;
  }

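  // A minimal usage sketch of the public API that ends up in posterior_
  // (the network bn and the node names "A" and "B" are assumptions):
  //
  //   gum::LazyPropagation< double > ie(&bn);
  //   ie.addEvidence(bn.idFromName("A"), 1);
  //   ie.makeInference();
  //   const auto& p = ie.posterior(bn.idFromName("B"));   // P(B | A = 1)
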
  // returns the marginal a posteriori probability of a given set of nodes
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     LazyPropagation< GUM_SCALAR >::unnormalizedJointPosterior_(const NodeSet& set) {
    // hard evidence does not belong to the join tree, so extract the nodes
    // from targets that are not hard evidence
    NodeSet targets = set, hard_ev_nodes;
    for (const auto node: this->hardEvidenceNodes()) {
      if (targets.contains(node)) {
        targets.erase(node);
        hard_ev_nodes.insert(node);
      }
    }

    // if all the nodes have received hard evidence, then compute the
    // joint posterior directly by multiplying the hard evidence potentials
    const auto& evidence = this->evidence();
    if (targets.empty()) {
      _PotentialSet_ pot_list;
      for (const auto node: set) {
        pot_list.insert(evidence[node]);
      }
      if (pot_list.size() == 1) {
        auto pot = new Potential< GUM_SCALAR >(**(pot_list.begin()));
        return pot;
      } else {
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
        return fast_combination.combine(pot_list);
      }
    }

    // if we still need to perform some inference task, do it: so, first,
    // determine the clique on which we should perform collect to compute
    // the unnormalized joint posterior of a set of nodes containing "targets"
    NodeId clique_of_set;
    try {
      clique_of_set = _joint_target_to_clique_[set];
    } catch (NotFound&) {
      // here, the precise set of targets does not belong to the set of targets
      // defined by the user. So we will try to find a clique in the junction
      // tree that contains "targets":

      // 1/ we should check that all the nodes belong to the join tree
      for (const auto node: targets) {
        if (!_graph_.exists(node)) {
          GUM_ERROR(UndefinedElement, node << " is not a target node")
        }
      }

      // 2/ the clique created by the first eliminated node among the targets
      // is the one we are looking for
      const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();

      NodeProperty< int > elim_order(Size(JT_elim_order.size()));
      for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
        elim_order.insert(JT_elim_order[i], (int)i);
      NodeId first_eliminated_node = *(targets.begin());
      int    elim_number           = elim_order[first_eliminated_node];
      for (const auto node: targets) {
        if (elim_order[node] < elim_number) {
          elim_number           = elim_order[node];
          first_eliminated_node = node;
        }
      }

      clique_of_set = _triangulation_->createdJunctionTreeClique(first_eliminated_node);

      // 3/ check that clique_of_set contains all the nodes of the target
      const NodeSet& clique_nodes = _JT_->clique(clique_of_set);
      for (const auto node: targets) {
        if (!clique_nodes.contains(node)) {
          GUM_ERROR(UndefinedElement, set << " is not a joint target")
        }
      }

      // add the discovered clique to _joint_target_to_clique_
      _joint_target_to_clique_.insert(set, clique_of_set);
    }

    // now perform a collect on the clique
    _collectMessage_(clique_of_set, clique_of_set);

    // now we just need to create the product of the potentials of the clique
    // containing set with the messages received by this clique and
    // marginalize out all variables except set
    _PotentialSet_ pot_list = _clique_potentials_[clique_of_set];

    // add the messages sent by adjacent nodes to targetClique
    for (const auto other: _JT_->neighbours(clique_of_set))
      pot_list += _separator_potentials_[Arc(other, clique_of_set)];

    // get the set of variables that need to be removed from the potentials
    const NodeSet&                 nodes = _JT_->clique(clique_of_set);
    Set< const DiscreteVariable* > del_vars(nodes.size());
    Set< const DiscreteVariable* > kept_vars(targets.size());
    const auto&                    bn = this->BN();
    for (const auto node: nodes) {
      if (!targets.contains(node)) {
        del_vars.insert(&(bn.variable(node)));
      } else {
        kept_vars.insert(&(bn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    _PotentialSet_           new_pot_list = _marginalizeOut_(pot_list, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint        = nullptr;

    if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));

      // if pot already existed, create a copy, so that we can put it into
      // the _joint_target_posteriors_ property
      if (pot_list.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // deleted just after the next else block
        new_pot_list.clear();
      }
    } else {
      // combine all the potentials in new_pot_list with all the hard evidence
      // of the nodes in set
      _PotentialSet_ new_new_pot_list = new_pot_list;
      for (const auto node: hard_ev_nodes) {
        new_new_pot_list.insert(evidence[node]);
      }
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
      joint = fast_combination.combine(new_new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (const auto pot: new_pot_list)
      if (!pot_list.exists(pot)) delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some pieces of hard evidence are incompatible
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
    }

    return joint;
  }

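  // Note on the clique lookup in the catch block above: when the
  // triangulation eliminates a node X, it creates a clique containing X and
  // its not-yet-eliminated neighbours. Hence, if some clique of the join tree
  // contains the whole target set, it can only be the clique created for the
  // target eliminated first, since all the other targets are still
  // uneliminated at that point. For instance, with elimination order
  // (A, B, C) and target {B, C}, only the clique created when eliminating B
  // needs to be checked.
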
  /// returns the posterior of a given set of variables
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     LazyPropagation< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
    // check if we have already computed the posterior
    if (_joint_target_posteriors_.exists(set)) { return *(_joint_target_posteriors_[set]); }

    // compute the joint posterior and normalize
    auto joint = unnormalizedJointPosterior_(set);
    joint->normalize();
    _joint_target_posteriors_.insert(set, joint);

    return *joint;
  }

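  // A minimal usage sketch ending in jointPosterior_ (bn and the node names
  // are assumptions):
  //
  //   gum::NodeSet target{bn.idFromName("A"), bn.idFromName("B")};
  //   gum::LazyPropagation< double > ie(&bn);
  //   ie.addJointTarget(target);
  //   ie.makeInference();
  //   const auto& p = ie.jointPosterior(target);   // P(A, B | evidence)
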
  /// returns the posterior of a given set of variables, computed by
  /// marginalizing the posterior of a declared joint target
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     LazyPropagation< GUM_SCALAR >::jointPosterior_(const NodeSet& wanted_target,
                                                    const NodeSet& declared_target) {
    // check if we have already computed the posterior of wanted_target
    if (_joint_target_posteriors_.exists(wanted_target))
      return *(_joint_target_posteriors_[wanted_target]);

    // here, we will have to compute the posterior of declared_target and
    // marginalize out all the variables that do not belong to wanted_target

    // check if we have already computed the posterior of declared_target
    if (!_joint_target_posteriors_.exists(declared_target)) { jointPosterior_(declared_target); }

    // marginalize out all the variables that do not belong to wanted_target
    const auto&                    bn = this->BN();
    Set< const DiscreteVariable* > del_vars;
    for (const auto node: declared_target)
      if (!wanted_target.contains(node)) del_vars.insert(&(bn.variable(node)));
    Potential< GUM_SCALAR >* pot = new Potential< GUM_SCALAR >(
       _joint_target_posteriors_[declared_target]->margSumOut(del_vars));

    // save the result into the cache
    _joint_target_posteriors_.insert(wanted_target, pot);

    return *pot;
  }

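  // The marginalization above relies on the identity
  //   P(wanted_target | e) = sum over (declared_target \ wanted_target)
  //                          of P(declared_target | e),
  // which margSumOut computes directly, so no further message passing is
  // needed once the declared target's posterior has been cached.
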
  template < typename GUM_SCALAR >
  GUM_SCALAR LazyPropagation< GUM_SCALAR >::evidenceProbability() {
    // here, we should check that _find_relevant_potential_type_ is equal to
    // FIND_ALL. Otherwise, the computations could be wrong.
    RelevantPotentialsFinderType old_relevant_type = _find_relevant_potential_type_;

    // if the relevant potentials finder is not equal to FIND_ALL, all the
    // current computations may lead to incorrect results, so we shall
    // discard them
    if (old_relevant_type != RelevantPotentialsFinderType::FIND_ALL) {
      _find_relevant_potential_type_ = RelevantPotentialsFinderType::FIND_ALL;
      _is_new_jt_needed_             = true;
      this->setOutdatedStructureState_();
    }

    // perform inference in each connected component
    this->makeInference();

    // for each connected component, select a variable X and compute the
    // joint probability of X and evidence e. Then marginalize-out X to get
    // p(e) in this connected component. Finally, multiply all the p(e) that
    // we got and the elements in _constants_. The result is the probability
    // of evidence

    GUM_SCALAR prob_ev = 1;
    for (const auto root: _roots_) {
      // get a node in the clique
      NodeId                   node = *(_JT_->clique(root).begin());
      Potential< GUM_SCALAR >* tmp  = unnormalizedJointPosterior_(node);
      GUM_SCALAR               sum  = 0;
      for (Instantiation iter(*tmp); !iter.end(); ++iter)
        sum += tmp->get(iter);

      prob_ev *= sum;
      delete tmp;
    }

    for (const auto& projected_cpt: _constants_)
      prob_ev *= projected_cpt.second;

    // put back the relevant potential type selected by the user
    _find_relevant_potential_type_ = old_relevant_type;

    return prob_ev;
  }

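  // A minimal usage sketch for evidenceProbability (bn, the node names and
  // the evidence values are assumptions):
  //
  //   gum::LazyPropagation< double > ie(&bn);
  //   ie.addEvidence(bn.idFromName("A"), 0);
  //   ie.addEvidence(bn.idFromName("B"), 1);
  //   double p_e = ie.evidenceProbability();   // P(A = 0, B = 1)
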
} /* namespace gum */

#endif   // DOXYGEN_SHOULD_SKIP_THIS