aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
variableElimination_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of Variable Elimination for inference in
25
* Bayesian networks.
26
*
27
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
28
*/
29
30
#
ifndef
DOXYGEN_SHOULD_SKIP_THIS
31
32
#
include
<
agrum
/
BN
/
inference
/
variableElimination
.
h
>
33
34
#
include
<
agrum
/
BN
/
algorithms
/
BayesBall
.
h
>
35
#
include
<
agrum
/
BN
/
algorithms
/
dSeparation
.
h
>
36
#
include
<
agrum
/
tools
/
graphs
/
algorithms
/
binaryJoinTreeConverterDefault
.
h
>
37
#
include
<
agrum
/
tools
/
multidim
/
instantiation
.
h
>
38
#
include
<
agrum
/
tools
/
multidim
/
utils
/
operators
/
multiDimCombineAndProjectDefault
.
h
>
39
#
include
<
agrum
/
tools
/
multidim
/
utils
/
operators
/
multiDimProjection
.
h
>
40
41
42
namespace
gum
{
43
44
45
// default constructor
46
template
<
typename
GUM_SCALAR >
47
INLINE VariableElimination<
GUM_SCALAR
>::
VariableElimination
(
48
const
IBayesNet
<
GUM_SCALAR
>*
BN
,
49
RelevantPotentialsFinderType
relevant_type
,
50
FindBarrenNodesType
barren_type
) :
51
JointTargetedInference
<
GUM_SCALAR
>(
BN
) {
52
// sets the relevant potential and the barren nodes finding algorithm
53
setRelevantPotentialsFinderType
(
relevant_type
);
54
setFindBarrenNodesType
(
barren_type
);
55
56
// create a default triangulation (the user can change it afterwards)
57
triangulation__
=
new
DefaultTriangulation
;
58
59
// for debugging purposessetRequiredInference
60
GUM_CONSTRUCTOR
(
VariableElimination
);
61
}
62
63
64
// destructor
65
template
<
typename
GUM_SCALAR
>
66
INLINE
VariableElimination
<
GUM_SCALAR
>::~
VariableElimination
() {
67
// remove the junction tree and the triangulation algorithm
68
if
(
JT__
!=
nullptr
)
delete
JT__
;
69
delete
triangulation__
;
70
if
(
target_posterior__
!=
nullptr
)
delete
target_posterior__
;
71
72
// for debugging purposes
73
GUM_DESTRUCTOR
(
VariableElimination
);
74
}
75
76
77
/// set a new triangulation algorithm
78
template
<
typename
GUM_SCALAR
>
79
void
VariableElimination
<
GUM_SCALAR
>::
setTriangulation
(
80
const
Triangulation
&
new_triangulation
) {
81
delete
triangulation__
;
82
triangulation__
=
new_triangulation
.
newFactory
();
83
}
84
85
86
/// returns the current join tree used
87
template
<
typename
GUM_SCALAR
>
88
INLINE
const
JunctionTree
*
89
VariableElimination
<
GUM_SCALAR
>::
junctionTree
(
NodeId
id
) {
90
createNewJT__
(
NodeSet
{
id
});
91
92
return
JT__
;
93
}
94
95
96
/// sets the operator for performing the projections
97
template
<
typename
GUM_SCALAR
>
98
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
setProjectionFunction__
(
99
Potential
<
GUM_SCALAR
>* (*
proj
)(
const
Potential
<
GUM_SCALAR
>&,
100
const
Set
<
const
DiscreteVariable
* >&)) {
101
projection_op__
=
proj
;
102
}
103
104
105
/// sets the operator for performing the combinations
106
template
<
typename
GUM_SCALAR
>
107
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
setCombinationFunction__
(
108
Potential
<
GUM_SCALAR
>* (*
comb
)(
const
Potential
<
GUM_SCALAR
>&,
109
const
Potential
<
GUM_SCALAR
>&)) {
110
combination_op__
=
comb
;
111
}
112
113
114
/// sets how we determine the relevant potentials to combine
/// @param type the strategy used to discard potentials that cannot influence
///        the targets (d-separation analyses or none at all)
/// @throws InvalidArgument if type is not a handled enumerator
template < typename GUM_SCALAR >
void VariableElimination< GUM_SCALAR >::setRelevantPotentialsFinderType(
   RelevantPotentialsFinderType type) {
  // nothing to do when the strategy is unchanged
  if (type != find_relevant_potential_type__) {
    switch (type) {
      case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
        findRelevantPotentials__ = &VariableElimination<
           GUM_SCALAR >::findRelevantPotentialsWithdSeparation2__;
        break;

      case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES:
        findRelevantPotentials__ = &VariableElimination<
           GUM_SCALAR >::findRelevantPotentialsWithdSeparation__;
        break;

      case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009:
        findRelevantPotentials__ = &VariableElimination<
           GUM_SCALAR >::findRelevantPotentialsWithdSeparation3__;
        break;

      case RelevantPotentialsFinderType::FIND_ALL:
        findRelevantPotentials__
           = &VariableElimination< GUM_SCALAR >::findRelevantPotentialsGetAll__;
        break;

      default:
        GUM_ERROR(InvalidArgument,
                  "setRelevantPotentialsFinderType for type "
                     << (unsigned int)type << " is not implemented yet");
    }

    // record the new strategy only after the dispatch succeeded
    find_relevant_potential_type__ = type;
  }
}
149
150
151
/// sets how we determine barren nodes
/// @param type the barren-nodes detection strategy (detect or ignore)
/// @throws InvalidArgument if type is not a handled enumerator
template < typename GUM_SCALAR >
void VariableElimination< GUM_SCALAR >::setFindBarrenNodesType(
   FindBarrenNodesType type) {
  if (type != barren_nodes_type__) {
    // WARNING: if a new type is added here, method createJT__ should
    // certainly be updated as well, in particular its step 2.
    switch (type) {
      case FindBarrenNodesType::FIND_BARREN_NODES:
      case FindBarrenNodesType::FIND_NO_BARREN_NODES:
        break;

      default:
        GUM_ERROR(InvalidArgument,
                  "setFindBarrenNodesType for type "
                     << (unsigned int)type << " is not implemented yet");
    }

    barren_nodes_type__ = type;
  }
}
173
174
175
/// fired when a new evidence is inserted
// Intentionally a no-op: VariableElimination rebuilds all its data
// structures from scratch at inference time (see createNewJT__), so no
// incremental update is required here.
template < typename GUM_SCALAR >
INLINE void VariableElimination< GUM_SCALAR >::onEvidenceAdded_(const NodeId,
                                                                bool) {}
179
180
181
/// fired when an evidence is removed
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void VariableElimination< GUM_SCALAR >::onEvidenceErased_(const NodeId,
                                                                 bool) {}
185
186
187
/// fired when all the evidence are erased
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
void VariableElimination< GUM_SCALAR >::onAllEvidenceErased_(bool) {}
190
191
192
/// fired when an evidence is changed
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void VariableElimination< GUM_SCALAR >::onEvidenceChanged_(const NodeId,
                                                                  bool) {}
196
197
198
/// fired after a new single target is inserted
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void
   VariableElimination< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId) {}
202
203
204
/// fired before a single target is removed
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void
   VariableElimination< GUM_SCALAR >::onMarginalTargetErased_(const NodeId) {}
208
209
/// fired after a new Bayes net has been assigned to the engine
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void
   VariableElimination< GUM_SCALAR >::onModelChanged_(const GraphicalModel* bn) {
}
214
215
/// fired after a new joint target is inserted
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void
   VariableElimination< GUM_SCALAR >::onJointTargetAdded_(const NodeSet&) {}
219
220
221
/// fired before a joint target is removed
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void
   VariableElimination< GUM_SCALAR >::onJointTargetErased_(const NodeSet&) {}
225
226
227
/// fired after all the nodes of the BN are added as single targets
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}
230
231
232
/// fired before all the single targets are removed
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsErased_() {}
235
236
237
/// fired before all the joint targets are removed
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void VariableElimination< GUM_SCALAR >::onAllJointTargetsErased_() {}
240
241
242
/// fired before all the single and joint targets are removed
// Intentionally a no-op: structures are rebuilt at inference time.
template < typename GUM_SCALAR >
INLINE void VariableElimination< GUM_SCALAR >::onAllTargetsErased_() {}
245
246
247
/// create a new junction tree as well as its related data structures
/// @param targets the set of nodes whose joint distribution must end up in
///        one clique of the junction tree
template < typename GUM_SCALAR >
void VariableElimination< GUM_SCALAR >::createNewJT__(const NodeSet& targets) {
  // to create the JT, we first create the moral graph of the BN in the
  // following way in order to take into account the barren nodes and the
  // nodes that received evidence:
  // 1/ we create an undirected graph containing only the nodes and no edge
  // 2/ if we take into account barren nodes, remove them from the graph
  // 3/ if we take d-separation into account, remove the d-separated nodes
  // 4/ add edges so that each node and its parents in the BN form a clique
  // 5/ add edges so that the targets form a clique of the moral graph
  // 6/ remove the nodes that received hard evidence (by step 4/, their
  //    parents are linked by edges, which is necessary for inference)
  //
  // At the end of step 6/, we have our moral graph and we can triangulate it
  // to get the new junction tree

  // 1/ create an undirected graph containing only the nodes and no edge
  const auto& bn = this->BN();
  graph__.clear();
  for (auto node: bn.dag())
    graph__.addNodeWithId(node);

  // 2/ if we wish to exploit barren nodes, we shall remove them from the BN
  // to do so: we identify all the nodes that are not targets and have
  // received no evidence and such that their descendants are neither targets
  // nor evidence nodes. Such nodes can be safely discarded from the BN
  // without altering the inference output
  if (barren_nodes_type__ == FindBarrenNodesType::FIND_BARREN_NODES) {
    // check that all the nodes are not targets, otherwise, there is no
    // barren node
    if (targets.size() != bn.size()) {
      BarrenNodesFinder finder(&(bn.dag()));
      finder.setTargets(&targets);

      NodeSet evidence_nodes;
      for (const auto& pair: this->evidence()) {
        evidence_nodes.insert(pair.first);
      }
      finder.setEvidence(&evidence_nodes);

      NodeSet barren_nodes = finder.barrenNodes();

      // remove the barren nodes from the moral graph
      for (const auto node: barren_nodes) {
        graph__.eraseNode(node);
      }
    }
  }

  // 3/ if we wish to exploit d-separation, remove all the nodes that are
  // d-separated from our targets
  {
    NodeSet requisite_nodes;
    bool    dsep_analysis = false;
    switch (find_relevant_potential_type__) {
      case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
      case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES: {
        BayesBall::requisiteNodes(bn.dag(),
                                  targets,
                                  this->hardEvidenceNodes(),
                                  this->softEvidenceNodes(),
                                  requisite_nodes);
        dsep_analysis = true;
      } break;

      case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009: {
        dSeparation dsep;
        dsep.requisiteNodes(bn.dag(),
                            targets,
                            this->hardEvidenceNodes(),
                            this->softEvidenceNodes(),
                            requisite_nodes);
        dsep_analysis = true;
      } break;

      case RelevantPotentialsFinderType::FIND_ALL:
        break;

      default:
        GUM_ERROR(FatalError, "not implemented yet");
    }

    // remove all the nodes that are not requisite
    // (safe iterators allow erasing while traversing)
    if (dsep_analysis) {
      for (auto iter = graph__.beginSafe(); iter != graph__.endSafe(); ++iter) {
        if (!requisite_nodes.contains(*iter)
            && !this->hardEvidenceNodes().contains(*iter)) {
          graph__.eraseNode(*iter);
        }
      }
    }
  }

  // 4/ add edges so that each node and its parents in the BN form a clique
  for (const auto node: graph__) {
    const NodeSet& parents = bn.parents(node);
    for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
      // before adding an edge between node and its parent, check that the
      // parent belong to the graph. Actually, when d-separated nodes are
      // removed, it may be the case that the parents of hard evidence nodes
      // are removed. But the latter still exist in the graph.
      if (graph__.existsNode(*iter1)) graph__.addEdge(*iter1, node);

      auto iter2 = iter1;
      for (++iter2; iter2 != parents.cend(); ++iter2) {
        // before adding an edge, check that both extremities belong to
        // the graph. Actually, when d-separated nodes are removed, it may
        // be the case that the parents of hard evidence nodes are removed.
        // But the latter still exist in the graph.
        if (graph__.existsNode(*iter1) && graph__.existsNode(*iter2))
          graph__.addEdge(*iter1, *iter2);
      }
    }
  }

  // 5/ if targets contains several nodes, we shall add new edges into the
  // moral graph in order to ensure that there exists a clique containing
  // their joint distribution
  for (auto iter1 = targets.cbegin(); iter1 != targets.cend(); ++iter1) {
    auto iter2 = iter1;
    for (++iter2; iter2 != targets.cend(); ++iter2) {
      graph__.addEdge(*iter1, *iter2);
    }
  }

  // 6/ remove all the nodes that received hard evidence
  for (const auto node: this->hardEvidenceNodes()) {
    graph__.eraseNode(node);
  }

  // now, we can compute the new junction tree.
  if (JT__ != nullptr) delete JT__;
  triangulation__->setGraph(&graph__, &(this->domainSizes()));
  const JunctionTree& triang_jt = triangulation__->junctionTree();
  JT__ = new CliqueGraph(triang_jt);

  // indicate, for each node of the moral graph a clique in JT__ that can
  // contain its conditional probability table
  node_to_clique__.clear();
  clique_potentials__.clear();
  NodeSet emptyset;
  for (auto clique: *JT__)
    clique_potentials__.insert(clique, emptyset);
  // map each node id to its rank in the elimination order so that the
  // "first eliminated" node of a family can be found cheaply below
  const std::vector< NodeId >& JT_elim_order
     = triangulation__->eliminationOrder();
  NodeProperty< Size > elim_order(Size(JT_elim_order.size()));
  for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
       ++i)
    elim_order.insert(JT_elim_order[i], NodeId(i));
  const DAG& dag = bn.dag();
  for (const auto node: graph__) {
    // get the variables in the potential of node (and its parents)
    NodeId first_eliminated_node = node;
    Size   elim_number           = elim_order[first_eliminated_node];

    for (const auto parent: dag.parents(node)) {
      if (graph__.existsNode(parent) && (elim_order[parent] < elim_number)) {
        elim_number           = elim_order[parent];
        first_eliminated_node = parent;
      }
    }

    // first_eliminated_node contains the first var (node or one of its
    // parents) eliminated => the clique created during its elimination
    // contains node and all of its parents => it can contain the potential
    // assigned to the node in the BN
    NodeId clique
       = triangulation__->createdJunctionTreeClique(first_eliminated_node);
    node_to_clique__.insert(node, clique);
    clique_potentials__[clique].insert(node);
  }

  // do the same for the nodes that received evidence. Here, we only store
  // the nodes whose at least one parent belongs to graph__ (otherwise
  // their CPT is just a constant real number).
  for (const auto node: this->hardEvidenceNodes()) {
    // get the set of parents of the node that belong to graph__
    NodeSet pars(dag.parents(node).size());
    for (const auto par: dag.parents(node))
      if (graph__.exists(par)) pars.insert(par);

    if (!pars.empty()) {
      NodeId first_eliminated_node = *(pars.begin());
      Size   elim_number           = elim_order[first_eliminated_node];

      for (const auto parent: pars) {
        if (elim_order[parent] < elim_number) {
          elim_number           = elim_order[parent];
          first_eliminated_node = parent;
        }
      }

      // first_eliminated_node contains the first var (node or one of its
      // parents) eliminated => the clique created during its elimination
      // contains node and all of its parents => it can contain the potential
      // assigned to the node in the BN
      NodeId clique
         = triangulation__->createdJunctionTreeClique(first_eliminated_node);
      node_to_clique__.insert(node, clique);
      clique_potentials__[clique].insert(node);
    }
  }

  // indicate a clique that contains all the nodes of targets
  targets2clique__ = std::numeric_limits< NodeId >::max();
  {
    // remove from set all the nodes that received hard evidence (since they
    // do not belong to the join tree)
    NodeSet nodeset = targets;
    for (const auto node: this->hardEvidenceNodes())
      if (nodeset.contains(node)) nodeset.erase(node);

    if (!nodeset.empty()) {
      NodeId first_eliminated_node = *(nodeset.begin());
      Size   elim_number           = elim_order[first_eliminated_node];
      for (const auto node: nodeset) {
        if (elim_order[node] < elim_number) {
          elim_number           = elim_order[node];
          first_eliminated_node = node;
        }
      }
      targets2clique__
         = triangulation__->createdJunctionTreeClique(first_eliminated_node);
    }
  }
}
476
477
478
/// prepare the inference structures w.r.t. new targets, soft/hard evidence
// Intentionally a no-op: the junction tree is rebuilt from scratch by
// createNewJT__ when inference is performed.
template < typename GUM_SCALAR >
void VariableElimination< GUM_SCALAR >::updateOutdatedStructure_() {}
481
482
483
/// update the potentials stored in the cliques and invalidate outdated
/// messages
// Intentionally a no-op: potentials are recollected at each inference.
template < typename GUM_SCALAR >
void VariableElimination< GUM_SCALAR >::updateOutdatedPotentials_() {}
487
488
489
// find the potentials d-connected to a set of variables
// No-op by design: the FIND_ALL strategy keeps every potential of pot_list.
template < typename GUM_SCALAR >
void VariableElimination< GUM_SCALAR >::findRelevantPotentialsGetAll__(
   Set< const Potential< GUM_SCALAR >* >& pot_list,
   Set< const DiscreteVariable* >&        kept_vars) {}
494
495
496
// find the potentials d-connected to a set of variables
497
template
<
typename
GUM_SCALAR
>
498
void
VariableElimination
<
GUM_SCALAR
>::
findRelevantPotentialsWithdSeparation__
(
499
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
500
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
501
// find the node ids of the kept variables
502
NodeSet
kept_ids
;
503
const
auto
&
bn
=
this
->
BN
();
504
for
(
const
auto
var
:
kept_vars
) {
505
kept_ids
.
insert
(
bn
.
nodeId
(*
var
));
506
}
507
508
// determine the set of potentials d-connected with the kept variables
509
NodeSet
requisite_nodes
;
510
BayesBall
::
requisiteNodes
(
bn
.
dag
(),
511
kept_ids
,
512
this
->
hardEvidenceNodes
(),
513
this
->
softEvidenceNodes
(),
514
requisite_nodes
);
515
for
(
auto
iter
=
pot_list
.
beginSafe
();
iter
!=
pot_list
.
endSafe
(); ++
iter
) {
516
const
Sequence
<
const
DiscreteVariable
* >&
vars
517
= (**
iter
).
variablesSequence
();
518
bool
found
=
false
;
519
for
(
auto
var
:
vars
) {
520
if
(
requisite_nodes
.
exists
(
bn
.
nodeId
(*
var
))) {
521
found
=
true
;
522
break
;
523
}
524
}
525
526
if
(!
found
) {
pot_list
.
erase
(
iter
); }
527
}
528
}
529
530
531
// find the potentials d-connected to a set of variables
532
template
<
typename
GUM_SCALAR
>
533
void
VariableElimination
<
GUM_SCALAR
>::
findRelevantPotentialsWithdSeparation2__
(
534
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
535
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
536
// find the node ids of the kept variables
537
NodeSet
kept_ids
;
538
const
auto
&
bn
=
this
->
BN
();
539
for
(
const
auto
var
:
kept_vars
) {
540
kept_ids
.
insert
(
bn
.
nodeId
(*
var
));
541
}
542
543
// determine the set of potentials d-connected with the kept variables
544
BayesBall
::
relevantPotentials
(
bn
,
545
kept_ids
,
546
this
->
hardEvidenceNodes
(),
547
this
->
softEvidenceNodes
(),
548
pot_list
);
549
}
550
551
552
// find the potentials d-connected to a set of variables
553
template
<
typename
GUM_SCALAR
>
554
void
VariableElimination
<
GUM_SCALAR
>::
findRelevantPotentialsWithdSeparation3__
(
555
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
556
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
557
// find the node ids of the kept variables
558
NodeSet
kept_ids
;
559
const
auto
&
bn
=
this
->
BN
();
560
for
(
const
auto
var
:
kept_vars
) {
561
kept_ids
.
insert
(
bn
.
nodeId
(*
var
));
562
}
563
564
// determine the set of potentials d-connected with the kept variables
565
dSeparation
dsep
;
566
dsep
.
relevantPotentials
(
bn
,
567
kept_ids
,
568
this
->
hardEvidenceNodes
(),
569
this
->
softEvidenceNodes
(),
570
pot_list
);
571
}
572
573
574
// find the potentials d-connected to a set of variables
575
template
<
typename
GUM_SCALAR
>
576
void
VariableElimination
<
GUM_SCALAR
>::
findRelevantPotentialsXX__
(
577
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
578
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
579
switch
(
find_relevant_potential_type__
) {
580
case
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_POTENTIALS
:
581
findRelevantPotentialsWithdSeparation2__
(
pot_list
,
kept_vars
);
582
break
;
583
584
case
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_NODES
:
585
findRelevantPotentialsWithdSeparation__
(
pot_list
,
kept_vars
);
586
break
;
587
588
case
RelevantPotentialsFinderType
::
DSEP_KOLLER_FRIEDMAN_2009
:
589
findRelevantPotentialsWithdSeparation3__
(
pot_list
,
kept_vars
);
590
break
;
591
592
case
RelevantPotentialsFinderType
::
FIND_ALL
:
593
findRelevantPotentialsGetAll__
(
pot_list
,
kept_vars
);
594
break
;
595
596
default
:
597
GUM_ERROR
(
FatalError
,
"not implemented yet"
);
598
}
599
}
600
601
602
// remove barren variables
/// @param pot_list the potentials being combined; potentials containing only
///        barren variables are removed, others are replaced by projections
/// @param del_vars the variables to eliminate
/// @return the set of newly allocated projected potentials (owned by caller)
template < typename GUM_SCALAR >
Set< const Potential< GUM_SCALAR >* >
   VariableElimination< GUM_SCALAR >::removeBarrenVariables__(
      PotentialSet__&                 pot_list,
      Set< const DiscreteVariable* >& del_vars) {
  // remove from del_vars the variables that received some evidence:
  // only those that did not received evidence can be barren variables
  Set< const DiscreteVariable* > the_del_vars = del_vars;
  for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
       ++iter) {
    NodeId id = this->BN().nodeId(**iter);
    if (this->hardEvidenceNodes().exists(id)
        || this->softEvidenceNodes().exists(id)) {
      the_del_vars.erase(iter);
    }
  }

  // assign to each random variable the set of potentials that contain it
  HashTable< const DiscreteVariable*, PotentialSet__ > var2pots;
  PotentialSet__                                       empty_pot_set;
  for (const auto pot: pot_list) {
    const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
    for (const auto var: vars) {
      if (the_del_vars.exists(var)) {
        if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
        var2pots[var].insert(pot);
      }
    }
  }

  // each variable with only one potential is a barren variable
  // assign to each potential with barren nodes its set of barren variables
  HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
                                 pot2barren_var;
  Set< const DiscreteVariable* > empty_var_set;
  for (auto elt: var2pots) {
    if (elt.second.size() == 1) {   // here we have a barren variable
      const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
      if (!pot2barren_var.exists(pot)) {
        pot2barren_var.insert(pot, empty_var_set);
      }
      pot2barren_var[pot].insert(elt.first);   // insert the barren variable
    }
  }

  // for each potential with barren variables, marginalize them.
  // if the potential has only barren variables, simply remove them from the
  // set of potentials, else just project the potential
  MultiDimProjection< GUM_SCALAR, Potential > projector(VENewprojPotential);
  PotentialSet__                              projected_pots;
  for (auto elt: pot2barren_var) {
    // remove the current potential from pot_list as, anyway, we will change
    // it
    const Potential< GUM_SCALAR >* pot = elt.first;
    pot_list.erase(pot);

    // check whether we need to add a projected new potential or not (i.e.,
    // whether there exist non-barren variables or not)
    if (pot->variablesSequence().size() != elt.second.size()) {
      auto new_pot = projector.project(*pot, elt.second);
      pot_list.insert(new_pot);
      projected_pots.insert(new_pot);
    }
  }

  return projected_pots;
}
670
671
672
// performs the collect phase of Lazy Propagation
673
template
<
typename
GUM_SCALAR
>
674
std
::
pair
<
Set
<
const
Potential
<
GUM_SCALAR
>* >,
675
Set
<
const
Potential
<
GUM_SCALAR
>* > >
676
VariableElimination
<
GUM_SCALAR
>::
collectMessage__
(
NodeId
id
,
NodeId
from
) {
677
// collect messages from all the neighbors
678
std
::
pair
<
PotentialSet__
,
PotentialSet__
>
collect_messages
;
679
for
(
const
auto
other
:
JT__
->
neighbours
(
id
)) {
680
if
(
other
!=
from
) {
681
std
::
pair
<
PotentialSet__
,
PotentialSet__
>
message
(
682
collectMessage__
(
other
,
id
));
683
collect_messages
.
first
+=
message
.
first
;
684
collect_messages
.
second
+=
message
.
second
;
685
}
686
}
687
688
// combine the collect messages with those of id's clique
689
return
produceMessage__
(
id
,
from
,
std
::
move
(
collect_messages
));
690
}
691
692
693
// get the CPT + evidence of a node projected w.r.t. hard evidence
/// @param node the node whose potentials must be collected
/// @return a pair (potentials to combine, newly allocated potentials that the
///         caller must eventually delete)
template < typename GUM_SCALAR >
std::pair< Set< const Potential< GUM_SCALAR >* >,
           Set< const Potential< GUM_SCALAR >* > >
   VariableElimination< GUM_SCALAR >::NodePotentials__(NodeId node) {
  std::pair< PotentialSet__, PotentialSet__ > res;
  const auto&                                 bn = this->BN();

  // get the CPT's of the node
  // beware: all the potentials that are defined over some nodes
  // including hard evidence must be projected so that these nodes are
  // removed from the potential
  // also beware that the CPT of a hard evidence node may be defined over
  // parents that do not belong to graph__ and that are not hard evidence.
  // In this case, those parents have been removed by d-separation and it is
  // easy to show that, in this case all the parents have been removed, so
  // that the CPT does not need to be taken into account
  const auto& evidence      = this->evidence();
  const auto& hard_evidence = this->hardEvidence();
  if (graph__.exists(node) || this->hardEvidenceNodes().contains(node)) {
    const Potential< GUM_SCALAR >& cpt       = bn.cpt(node);
    const auto&                    variables = cpt.variablesSequence();

    // check if the parents of a hard evidence node do not belong to graph__
    // and are not themselves hard evidence, discard the CPT, it is useless
    // for inference
    if (this->hardEvidenceNodes().contains(node)) {
      for (const auto var: variables) {
        NodeId xnode = bn.nodeId(*var);
        if (!this->hardEvidenceNodes().contains(xnode)
            && !graph__.existsNode(xnode))
          return res;
      }
    }

    // get the list of nodes with hard evidence in cpt
    NodeSet hard_nodes;
    for (const auto var: variables) {
      NodeId xnode = bn.nodeId(*var);
      if (this->hardEvidenceNodes().contains(xnode)) hard_nodes.insert(xnode);
    }

    // if hard_nodes contains hard evidence nodes, perform a projection
    // and insert the result into the appropriate clique, else insert
    // directly cpt into the clique
    if (hard_nodes.empty()) {
      res.first.insert(&cpt);
    } else {
      // marginalize out the hard evidence nodes: if the cpt is defined
      // only over nodes that received hard evidence, do not consider it
      // as a potential anymore
      if (hard_nodes.size() != variables.size()) {
        // perform the projection with a combine and project instance
        Set< const DiscreteVariable* > hard_variables;
        PotentialSet__                 marg_cpt_set{&cpt};
        for (const auto xnode: hard_nodes) {
          marg_cpt_set.insert(evidence[xnode]);
          hard_variables.insert(&(bn.variable(xnode)));
        }
        // perform the combination of those potentials and their projection
        MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
                       combine_and_project(combination_op__, VENewprojPotential);
        PotentialSet__ new_cpt_list
           = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

        // there should be only one potential in new_cpt_list
        if (new_cpt_list.size() != 1) {
          // remove the CPT created to avoid memory leaks
          for (auto pot: new_cpt_list) {
            if (!marg_cpt_set.contains(pot)) delete pot;
          }
          GUM_ERROR(FatalError,
                    "the projection of a potential containing "
                       << "hard evidence is empty!");
        }
        const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
        res.first.insert(projected_cpt);
        // the projected CPT was allocated here, so record it for deletion
        res.second.insert(projected_cpt);
      }
    }

    // if the node received some soft evidence, add it
    if (evidence.exists(node) && !hard_evidence.exists(node)) {
      res.first.insert(this->evidence()[node]);
    }
  }

  return res;
}
782
783
784
// creates the message sent by clique from_id to clique to_id
/// @param from_id the clique producing the message
/// @param to_id the destination clique (if no JT edge exists between the two,
///        this is the endpoint of a collect and all potentials are returned)
/// @param incoming_messages the (potentials, owned temporaries) gathered from
///        from_id's other neighbours; ownership is taken over
/// @return a pair (message potentials, owned temporaries to delete later)
template < typename GUM_SCALAR >
std::pair< Set< const Potential< GUM_SCALAR >* >,
           Set< const Potential< GUM_SCALAR >* > >
   VariableElimination< GUM_SCALAR >::produceMessage__(
      NodeId from_id,
      NodeId to_id,
      std::pair< Set< const Potential< GUM_SCALAR >* >,
                 Set< const Potential< GUM_SCALAR >* > >&& incoming_messages) {
  // get the messages sent by adjacent nodes to from_id
  std::pair< Set< const Potential< GUM_SCALAR >* >,
             Set< const Potential< GUM_SCALAR >* > >
     pot_list(std::move(incoming_messages));

  // get the potentials of the clique
  for (const auto node: clique_potentials__[from_id]) {
    auto new_pots = NodePotentials__(node);
    pot_list.first += new_pots.first;
    pot_list.second += new_pots.second;
  }

  // if from_id = to_id: this is the endpoint of a collect
  if (!JT__->existsEdge(from_id, to_id)) {
    return pot_list;
  } else {
    // get the set of variables that need be removed from the potentials
    const NodeSet&                 from_clique = JT__->clique(from_id);
    const NodeSet&                 separator = JT__->separator(from_id, to_id);
    Set< const DiscreteVariable* > del_vars(from_clique.size());
    Set< const DiscreteVariable* > kept_vars(separator.size());
    const auto&                    bn = this->BN();

    for (const auto node: from_clique) {
      if (!separator.contains(node)) {
        del_vars.insert(&(bn.variable(node)));
      } else {
        kept_vars.insert(&(bn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    PotentialSet__ new_pot_list
       = marginalizeOut__(pot_list.first, del_vars, kept_vars);

    /*
      for the moment, remove this test: due to some optimizations, some
      potentials might have all their cells greater than 1.

      // remove all the potentials that are equal to ones (as probability
      // matrix multiplications are tensorial, such potentials are useless)
      for (auto iter = new_pot_list.beginSafe(); iter !=
      new_pot_list.endSafe();
           ++iter) {
        const auto pot = *iter;
        if (pot->variablesSequence().size() == 1) {
          bool is_all_ones = true;
          for (Instantiation inst(*pot); !inst.end(); ++inst) {
            if ((*pot)[inst] < one_minus_epsilon__) {
              is_all_ones = false;
              break;
            }
          }
          if (is_all_ones) {
            if (!pot_list.first.exists(pot)) delete pot;
            new_pot_list.erase(iter);
            continue;
          }
        }
      }
    */

    // remove the unnecessary temporary messages
    for (auto iter = pot_list.second.beginSafe();
         iter != pot_list.second.endSafe();
         ++iter) {
      if (!new_pot_list.contains(*iter)) {
        delete *iter;
        pot_list.second.erase(iter);
      }
    }

    // keep track of all the newly created potentials
    for (const auto pot: new_pot_list) {
      if (!pot_list.first.contains(pot)) { pot_list.second.insert(pot); }
    }

    // return the new set of potentials
    return std::pair< PotentialSet__, PotentialSet__ >(
       std::move(new_pot_list),
       std::move(pot_list.second));
  }
}
876
877
878
// remove variables del_vars from the list of potentials pot_list
//
// returns the potentials obtained by combining the relevant potentials of
// pot_list and projecting the result onto kept_vars. Potentials of the
// returned set that do not belong to pot_list were allocated here and must
// eventually be deleted by the caller
template < typename GUM_SCALAR >
Set< const Potential< GUM_SCALAR >* >
   VariableElimination< GUM_SCALAR >::marginalizeOut__(
      Set< const Potential< GUM_SCALAR >* > pot_list,
      Set< const DiscreteVariable* >&       del_vars,
      Set< const DiscreteVariable* >&       kept_vars) {
   // use d-separation analysis to check which potentials shall be combined
   // (presumably restricts pot_list to the relevant ones — the exact policy
   // depends on the findRelevantPotentialsXX__ strategy selected)
   findRelevantPotentialsXX__(pot_list, kept_vars);

   // remove the potentials corresponding to barren variables if we want
   // to exploit barren nodes
   PotentialSet__ barren_projected_potentials;
   if (barren_nodes_type__ == FindBarrenNodesType::FIND_BARREN_NODES) {
      barren_projected_potentials = removeBarrenVariables__(pot_list, del_vars);
   }

   // create a combine and project operator that will perform the
   // marginalization
   MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
      combination_op__,
      projection_op__);
   PotentialSet__ new_pot_list
      = combine_and_project.combineAndProject(pot_list, del_vars);

   // remove all the potentials that were created due to projections of
   // barren nodes and that are not part of the new_pot_list: these
   // potentials were just temporary potentials
   for (auto iter = barren_projected_potentials.beginSafe();
        iter != barren_projected_potentials.endSafe();
        ++iter) {
      if (!new_pot_list.exists(*iter)) delete *iter;
   }

   // remove all the potentials that have no dimension
   // (beginSafe/endSafe are used because we erase inside the loop)
   for (auto iter_pot = new_pot_list.beginSafe();
        iter_pot != new_pot_list.endSafe();
        ++iter_pot) {
      if ((*iter_pot)->variablesSequence().size() == 0) {
         // as we have already marginalized out variables that received
         // evidence, it may be the case that, after combining and projecting,
         // some potentials might be empty. In this case, we shall keep their
         // constant and remove them from memory
         // # TODO: keep the constants!
         delete *iter_pot;
         new_pot_list.erase(iter_pot);
      }
   }

   return new_pot_list;
}
// performs a whole inference
//
// Nothing to do here: VariableElimination works lazily, query by query.
// The actual work (junction tree creation via createNewJT__ and message
// collection via collectMessage__) is triggered inside
// unnormalizedJointPosterior_ each time a posterior is requested.
template < typename GUM_SCALAR >
INLINE void VariableElimination< GUM_SCALAR >::makeInference_() {}
/// returns a fresh potential equal to P(1st arg,evidence)
///
/// @warning the returned potential is freshly allocated: the caller takes
/// its ownership and is responsible for deleting it
/// @throws IncompatibleEvidence if the joint probability of the entered
/// evidence is equal to 0
template < typename GUM_SCALAR >
Potential< GUM_SCALAR >*
   VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId id) {
   const auto& bn = this->BN();

   // hard evidence do not belong to the join tree
   // # TODO: check for sets of inconsistent hard evidence
   if (this->hardEvidenceNodes().contains(id)) {
      // the posterior of a hard-evidenced node is just a copy of its evidence
      return new Potential< GUM_SCALAR >(*(this->evidence()[id]));
   }

   // if we still need to perform some inference task, do it: build a fresh
   // junction tree for this query and collect the messages toward the
   // clique containing id
   createNewJT__(NodeSet{id});
   NodeId clique_of_id = node_to_clique__[id];
   auto   pot_list = collectMessage__(clique_of_id, clique_of_id);

   // get the set of variables that need be removed from the potentials:
   // every variable of the clique except id itself
   const NodeSet&                 nodes = JT__->clique(clique_of_id);
   Set< const DiscreteVariable* > kept_vars{&(bn.variable(id))};
   Set< const DiscreteVariable* > del_vars(nodes.size());
   for (const auto node: nodes) {
      if (node != id) del_vars.insert(&(bn.variable(node)));
   }

   // pot_list now contains all the potentials to multiply and marginalize
   // => combine the messages
   PotentialSet__ new_pot_list
      = marginalizeOut__(pot_list.first, del_vars, kept_vars);
   Potential< GUM_SCALAR >* joint = nullptr;

   if (new_pot_list.size() == 1) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if joint already existed, create a copy, so that we can put it into
      // the target_posterior__ property
      if (pot_list.first.exists(joint)) {
         joint = new Potential< GUM_SCALAR >(*joint);
      } else {
         // remove the joint from new_pot_list so that it will not be
         // removed just after the else block
         new_pot_list.clear();
      }
   } else {
      // several potentials remain: multiply them all into a single one
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
         combination_op__);
      joint = fast_combination.combine(new_pot_list);
   }

   // remove the potentials that were created in new_pot_list
   for (auto pot: new_pot_list)
      if (!pot_list.first.exists(pot)) delete pot;

   // remove all the temporary potentials created in pot_list
   for (auto pot: pot_list.second)
      delete pot;

   // check that the joint posterior is different from a 0 vector: this would
   // indicate that some hard evidence are not compatible (their joint
   // probability is equal to 0)
   bool nonzero_found = false;
   for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
         nonzero_found = true;
         break;
      }
   }
   if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
   }

   return joint;
}
/// returns the posterior of a given variable
1015
template
<
typename
GUM_SCALAR
>
1016
const
Potential
<
GUM_SCALAR
>&
1017
VariableElimination
<
GUM_SCALAR
>::
posterior_
(
NodeId
id
) {
1018
// compute the joint posterior and normalize
1019
auto
joint
=
unnormalizedJointPosterior_
(
id
);
1020
if
(
joint
->
sum
() != 1)
// hard test for ReadOnly CPT (as aggregator)
1021
joint
->
normalize
();
1022
1023
if
(
target_posterior__
!=
nullptr
)
delete
target_posterior__
;
1024
target_posterior__
=
joint
;
1025
1026
return
*
joint
;
1027
}
1028
1029
1030
// returns the marginal a posteriori proba of a given node
//
// @warning the returned potential is freshly allocated: the caller takes
// its ownership and is responsible for deleting it
// @throws IncompatibleEvidence if the joint probability of the entered
// evidence is equal to 0
template < typename GUM_SCALAR >
Potential< GUM_SCALAR >*
   VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(
      const NodeSet& set) {
   // hard evidence do not belong to the join tree, so extract the nodes
   // from targets that are not hard evidence
   NodeSet targets = set, hard_ev_nodes;
   for (const auto node: this->hardEvidenceNodes()) {
      if (targets.contains(node)) {
         targets.erase(node);
         hard_ev_nodes.insert(node);
      }
   }

   // if all the nodes have received hard evidence, then compute the
   // joint posterior directly by multiplying the hard evidence potentials
   const auto& evidence = this->evidence();
   if (targets.empty()) {
      PotentialSet__ pot_list;
      for (const auto node: set) {
         pot_list.insert(evidence[node]);
      }
      if (pot_list.size() == 1) {
         // a single evidence potential: just return a copy of it
         return new Potential< GUM_SCALAR >(**(pot_list.begin()));
      } else {
         MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
            combination_op__);
         return fast_combination.combine(pot_list);
      }
   }

   // if we still need to perform some inference task, do it: build a fresh
   // junction tree for this query and collect the messages toward the
   // clique containing the (soft/non-evidenced) targets
   createNewJT__(set);
   auto pot_list = collectMessage__(targets2clique__, targets2clique__);

   // get the set of variables that need be removed from the potentials:
   // every variable of the clique that is not a target
   const NodeSet&                 nodes = JT__->clique(targets2clique__);
   Set< const DiscreteVariable* > del_vars(nodes.size());
   Set< const DiscreteVariable* > kept_vars(targets.size());
   const auto&                    bn = this->BN();
   for (const auto node: nodes) {
      if (!targets.contains(node)) {
         del_vars.insert(&(bn.variable(node)));
      } else {
         kept_vars.insert(&(bn.variable(node)));
      }
   }

   // pot_list now contains all the potentials to multiply and marginalize
   // => combine the messages
   PotentialSet__ new_pot_list
      = marginalizeOut__(pot_list.first, del_vars, kept_vars);
   Potential< GUM_SCALAR >* joint = nullptr;

   if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if pot already existed, create a copy, so that we can put it into
      // the target_posteriors__ property
      if (pot_list.first.exists(joint)) {
         joint = new Potential< GUM_SCALAR >(*joint);
      } else {
         // remove the joint from new_pot_list so that it will not be
         // removed just after the next else block
         new_pot_list.clear();
      }
   } else {
      // combine all the potentials in new_pot_list with all the hard evidence
      // of the nodes in set
      PotentialSet__ new_new_pot_list = new_pot_list;
      for (const auto node: hard_ev_nodes) {
         new_new_pot_list.insert(evidence[node]);
      }
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
         combination_op__);
      joint = fast_combination.combine(new_new_pot_list);
   }

   // remove the potentials that were created in new_pot_list
   for (auto pot: new_pot_list)
      if (!pot_list.first.exists(pot)) delete pot;

   // remove all the temporary potentials created in pot_list
   for (auto pot: pot_list.second)
      delete pot;

   // check that the joint posterior is different from a 0 vector: this would
   // indicate that some hard evidence are not compatible
   bool nonzero_found = false;
   for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
         nonzero_found = true;
         break;
      }
   }
   if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
   }

   return joint;
}
/// returns the posterior of a given set of variables
1138
template
<
typename
GUM_SCALAR
>
1139
const
Potential
<
GUM_SCALAR
>&
1140
VariableElimination
<
GUM_SCALAR
>::
jointPosterior_
(
const
NodeSet
&
set
) {
1141
// compute the joint posterior and normalize
1142
auto
joint
=
unnormalizedJointPosterior_
(
set
);
1143
joint
->
normalize
();
1144
1145
if
(
target_posterior__
!=
nullptr
)
delete
target_posterior__
;
1146
target_posterior__
=
joint
;
1147
1148
return
*
joint
;
1149
}
1150
1151
1152
/// returns the posterior of a given set of variables
///
/// @param wanted_target the set of nodes whose joint posterior is computed
/// @param declared_target the joint target declared to the inference engine;
/// it is unused here: since VariableElimination rebuilds a junction tree for
/// each query (see createNewJT__ called from unnormalizedJointPosterior_),
/// the posterior of wanted_target can be computed directly
template < typename GUM_SCALAR >
const Potential< GUM_SCALAR >&
   VariableElimination< GUM_SCALAR >::jointPosterior_(
      const NodeSet& wanted_target,
      const NodeSet& declared_target) {
   return jointPosterior_(wanted_target);
}
1160
1161
1162
}
/* namespace gum */
1163
1164
#
endif
// DOXYGEN_SHOULD_SKIP_THIS
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669