aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
variableElimination_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of Variable Elimination for inference in
25
* Bayesian networks.
26
*
27
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
28
*/
29
30
#
ifndef
DOXYGEN_SHOULD_SKIP_THIS
31
32
#
include
<
agrum
/
BN
/
inference
/
variableElimination
.
h
>
33
34
#
include
<
agrum
/
BN
/
algorithms
/
BayesBall
.
h
>
35
#
include
<
agrum
/
BN
/
algorithms
/
dSeparation
.
h
>
36
#
include
<
agrum
/
tools
/
graphs
/
algorithms
/
binaryJoinTreeConverterDefault
.
h
>
37
#
include
<
agrum
/
tools
/
multidim
/
instantiation
.
h
>
38
#
include
<
agrum
/
tools
/
multidim
/
utils
/
operators
/
multiDimCombineAndProjectDefault
.
h
>
39
#
include
<
agrum
/
tools
/
multidim
/
utils
/
operators
/
multiDimProjection
.
h
>
40
41
42
namespace
gum
{
43
44
45
// default constructor
46
template
<
typename
GUM_SCALAR >
47
INLINE VariableElimination<
GUM_SCALAR
>::
VariableElimination
(
48
const
IBayesNet
<
GUM_SCALAR
>*
BN
,
49
RelevantPotentialsFinderType
relevant_type
,
50
FindBarrenNodesType
barren_type
) :
51
JointTargetedInference
<
GUM_SCALAR
>(
BN
) {
52
// sets the relevant potential and the barren nodes finding algorithm
53
setRelevantPotentialsFinderType
(
relevant_type
);
54
setFindBarrenNodesType
(
barren_type
);
55
56
// create a default triangulation (the user can change it afterwards)
57
_triangulation_
=
new
DefaultTriangulation
;
58
59
// for debugging purposessetRequiredInference
60
GUM_CONSTRUCTOR
(
VariableElimination
);
61
}
62
63
64
// destructor
65
template
<
typename
GUM_SCALAR
>
66
INLINE
VariableElimination
<
GUM_SCALAR
>::~
VariableElimination
() {
67
// remove the junction tree and the triangulation algorithm
68
if
(
_JT_
!=
nullptr
)
delete
_JT_
;
69
delete
_triangulation_
;
70
if
(
_target_posterior_
!=
nullptr
)
delete
_target_posterior_
;
71
72
// for debugging purposes
73
GUM_DESTRUCTOR
(
VariableElimination
);
74
}
75
76
77
/// set a new triangulation algorithm
78
template
<
typename
GUM_SCALAR
>
79
void
VariableElimination
<
GUM_SCALAR
>::
setTriangulation
(
const
Triangulation
&
new_triangulation
) {
80
delete
_triangulation_
;
81
_triangulation_
=
new_triangulation
.
newFactory
();
82
}
83
84
85
/// returns the current join tree used
86
template
<
typename
GUM_SCALAR
>
87
INLINE
const
JunctionTree
*
VariableElimination
<
GUM_SCALAR
>::
junctionTree
(
NodeId
id
) {
88
_createNewJT_
(
NodeSet
{
id
});
89
90
return
_JT_
;
91
}
92
93
94
/// sets the operator for performing the projections
95
template
<
typename
GUM_SCALAR
>
96
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
_setProjectionFunction_
(
Potential
<
GUM_SCALAR
>* (
97
*
proj
)(
const
Potential
<
GUM_SCALAR
>&,
const
Set
<
const
DiscreteVariable
* >&)) {
98
_projection_op_
=
proj
;
99
}
100
101
102
/// sets the operator for performing the combinations
103
template
<
typename
GUM_SCALAR
>
104
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
_setCombinationFunction_
(
Potential
<
GUM_SCALAR
>* (
105
*
comb
)(
const
Potential
<
GUM_SCALAR
>&,
const
Potential
<
GUM_SCALAR
>&)) {
106
_combination_op_
=
comb
;
107
}
108
109
110
/// sets how we determine the relevant potentials to combine
111
template
<
typename
GUM_SCALAR
>
112
void
VariableElimination
<
GUM_SCALAR
>::
setRelevantPotentialsFinderType
(
113
RelevantPotentialsFinderType
type
) {
114
if
(
type
!=
_find_relevant_potential_type_
) {
115
switch
(
type
) {
116
case
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_POTENTIALS
:
117
_findRelevantPotentials_
118
= &
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsWithdSeparation2_
;
119
break
;
120
121
case
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_NODES
:
122
_findRelevantPotentials_
123
= &
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsWithdSeparation_
;
124
break
;
125
126
case
RelevantPotentialsFinderType
::
DSEP_KOLLER_FRIEDMAN_2009
:
127
_findRelevantPotentials_
128
= &
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsWithdSeparation3_
;
129
break
;
130
131
case
RelevantPotentialsFinderType
::
FIND_ALL
:
132
_findRelevantPotentials_
133
= &
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsGetAll_
;
134
break
;
135
136
default
:
137
GUM_ERROR
(
InvalidArgument
,
138
"setRelevantPotentialsFinderType for type "
<< (
unsigned
int
)
type
139
<<
" is not implemented yet"
);
140
}
141
142
_find_relevant_potential_type_
=
type
;
143
}
144
}
145
146
147
/// sets how we determine barren nodes
148
template
<
typename
GUM_SCALAR
>
149
void
VariableElimination
<
GUM_SCALAR
>::
setFindBarrenNodesType
(
FindBarrenNodesType
type
) {
150
if
(
type
!=
_barren_nodes_type_
) {
151
// WARNING: if a new type is added here, method _createJT_ should
152
// certainly
153
// be updated as well, in particular its step 2.
154
switch
(
type
) {
155
case
FindBarrenNodesType
::
FIND_BARREN_NODES
:
156
case
FindBarrenNodesType
::
FIND_NO_BARREN_NODES
:
157
break
;
158
159
default
:
160
GUM_ERROR
(
InvalidArgument
,
161
"setFindBarrenNodesType for type "
<< (
unsigned
int
)
type
162
<<
" is not implemented yet"
);
163
}
164
165
_barren_nodes_type_
=
type
;
166
}
167
}
168
169
170
/// fired when a new evidence is inserted
171
template
<
typename
GUM_SCALAR
>
172
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onEvidenceAdded_
(
const
NodeId
,
bool
) {}
173
174
175
/// fired when an evidence is removed
176
template
<
typename
GUM_SCALAR
>
177
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onEvidenceErased_
(
const
NodeId
,
bool
) {}
178
179
180
/// fired when all the evidence are erased
181
template
<
typename
GUM_SCALAR
>
182
void
VariableElimination
<
GUM_SCALAR
>::
onAllEvidenceErased_
(
bool
) {}
183
184
185
/// fired when an evidence is changed
186
template
<
typename
GUM_SCALAR
>
187
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onEvidenceChanged_
(
const
NodeId
,
bool
) {}
188
189
190
/// fired after a new target is inserted
191
template
<
typename
GUM_SCALAR
>
192
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onMarginalTargetAdded_
(
const
NodeId
) {}
193
194
195
/// fired before a target is removed
196
template
<
typename
GUM_SCALAR
>
197
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onMarginalTargetErased_
(
const
NodeId
) {}
198
199
/// fired after a new Bayes net has been assigned to the engine
200
template
<
typename
GUM_SCALAR
>
201
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onModelChanged_
(
const
GraphicalModel
*
bn
) {}
202
203
/// fired after a new set target is inserted
204
template
<
typename
GUM_SCALAR
>
205
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onJointTargetAdded_
(
const
NodeSet
&) {}
206
207
208
/// fired before a set target is removed
209
template
<
typename
GUM_SCALAR
>
210
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onJointTargetErased_
(
const
NodeSet
&) {}
211
212
213
/// fired after all the nodes of the BN are added as single targets
214
template
<
typename
GUM_SCALAR
>
215
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onAllMarginalTargetsAdded_
() {}
216
217
218
/// fired before a all the single_targets are removed
219
template
<
typename
GUM_SCALAR
>
220
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onAllMarginalTargetsErased_
() {}
221
222
223
/// fired before a all the joint_targets are removed
224
template
<
typename
GUM_SCALAR
>
225
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onAllJointTargetsErased_
() {}
226
227
228
/// fired before a all the single and joint_targets are removed
229
template
<
typename
GUM_SCALAR
>
230
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
onAllTargetsErased_
() {}
231
232
233
/// create a new junction tree as well as its related data structures
234
template
<
typename
GUM_SCALAR
>
235
void
VariableElimination
<
GUM_SCALAR
>::
_createNewJT_
(
const
NodeSet
&
targets
) {
236
// to create the JT, we first create the moral graph of the BN in the
237
// following way in order to take into account the barren nodes and the
238
// nodes that received evidence:
239
// 1/ we create an undirected graph containing only the nodes and no edge
240
// 2/ if we take into account barren nodes, remove them from the graph
241
// 3/ if we take d-separation into account, remove the d-separated nodes
242
// 4/ add edges so that each node and its parents in the BN form a clique
243
// 5/ add edges so that the targets form a clique of the moral graph
244
// 6/ remove the nodes that received hard evidence (by step 4/, their
245
// parents are linked by edges, which is necessary for inference)
246
//
247
// At the end of step 6/, we have our moral graph and we can triangulate it
248
// to get the new junction tree
249
250
// 1/ create an undirected graph containing only the nodes and no edge
251
const
auto
&
bn
=
this
->
BN
();
252
_graph_
.
clear
();
253
for
(
auto
node
:
bn
.
dag
())
254
_graph_
.
addNodeWithId
(
node
);
255
256
// 2/ if we wish to exploit barren nodes, we shall remove them from the BN
257
// to do so: we identify all the nodes that are not targets and have
258
// received no evidence and such that their descendants are neither targets
259
// nor evidence nodes. Such nodes can be safely discarded from the BN
260
// without altering the inference output
261
if
(
_barren_nodes_type_
==
FindBarrenNodesType
::
FIND_BARREN_NODES
) {
262
// check that all the nodes are not targets, otherwise, there is no
263
// barren node
264
if
(
targets
.
size
() !=
bn
.
size
()) {
265
BarrenNodesFinder
finder
(&(
bn
.
dag
()));
266
finder
.
setTargets
(&
targets
);
267
268
NodeSet
evidence_nodes
;
269
for
(
const
auto
&
pair
:
this
->
evidence
()) {
270
evidence_nodes
.
insert
(
pair
.
first
);
271
}
272
finder
.
setEvidence
(&
evidence_nodes
);
273
274
NodeSet
barren_nodes
=
finder
.
barrenNodes
();
275
276
// remove the barren nodes from the moral graph
277
for
(
const
auto
node
:
barren_nodes
) {
278
_graph_
.
eraseNode
(
node
);
279
}
280
}
281
}
282
283
// 3/ if we wish to exploit d-separation, remove all the nodes that are
284
// d-separated from our targets
285
{
286
NodeSet
requisite_nodes
;
287
bool
dsep_analysis
=
false
;
288
switch
(
_find_relevant_potential_type_
) {
289
case
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_POTENTIALS
:
290
case
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_NODES
: {
291
BayesBall
::
requisiteNodes
(
bn
.
dag
(),
292
targets
,
293
this
->
hardEvidenceNodes
(),
294
this
->
softEvidenceNodes
(),
295
requisite_nodes
);
296
dsep_analysis
=
true
;
297
}
break
;
298
299
case
RelevantPotentialsFinderType
::
DSEP_KOLLER_FRIEDMAN_2009
: {
300
dSeparation
dsep
;
301
dsep
.
requisiteNodes
(
bn
.
dag
(),
302
targets
,
303
this
->
hardEvidenceNodes
(),
304
this
->
softEvidenceNodes
(),
305
requisite_nodes
);
306
dsep_analysis
=
true
;
307
}
break
;
308
309
case
RelevantPotentialsFinderType
::
FIND_ALL
:
310
break
;
311
312
default
:
313
GUM_ERROR
(
FatalError
,
"not implemented yet"
)
314
}
315
316
// remove all the nodes that are not requisite
317
if
(
dsep_analysis
) {
318
for
(
auto
iter
=
_graph_
.
beginSafe
();
iter
!=
_graph_
.
endSafe
(); ++
iter
) {
319
if
(!
requisite_nodes
.
contains
(*
iter
) && !
this
->
hardEvidenceNodes
().
contains
(*
iter
)) {
320
_graph_
.
eraseNode
(*
iter
);
321
}
322
}
323
}
324
}
325
326
// 4/ add edges so that each node and its parents in the BN form a clique
327
for
(
const
auto
node
:
_graph_
) {
328
const
NodeSet
&
parents
=
bn
.
parents
(
node
);
329
for
(
auto
iter1
=
parents
.
cbegin
();
iter1
!=
parents
.
cend
(); ++
iter1
) {
330
// before adding an edge between node and its parent, check that the
331
// parent belong to the graph. Actually, when d-separated nodes are
332
// removed, it may be the case that the parents of hard evidence nodes
333
// are removed. But the latter still exist in the graph.
334
if
(
_graph_
.
existsNode
(*
iter1
))
_graph_
.
addEdge
(*
iter1
,
node
);
335
336
auto
iter2
=
iter1
;
337
for
(++
iter2
;
iter2
!=
parents
.
cend
(); ++
iter2
) {
338
// before adding an edge, check that both extremities belong to
339
// the graph. Actually, when d-separated nodes are removed, it may
340
// be the case that the parents of hard evidence nodes are removed.
341
// But the latter still exist in the graph.
342
if
(
_graph_
.
existsNode
(*
iter1
) &&
_graph_
.
existsNode
(*
iter2
))
343
_graph_
.
addEdge
(*
iter1
, *
iter2
);
344
}
345
}
346
}
347
348
// 5/ if targets contains several nodes, we shall add new edges into the
349
// moral graph in order to ensure that there exists a clique containing
350
// thier joint distribution
351
for
(
auto
iter1
=
targets
.
cbegin
();
iter1
!=
targets
.
cend
(); ++
iter1
) {
352
auto
iter2
=
iter1
;
353
for
(++
iter2
;
iter2
!=
targets
.
cend
(); ++
iter2
) {
354
_graph_
.
addEdge
(*
iter1
, *
iter2
);
355
}
356
}
357
358
// 6/ remove all the nodes that received hard evidence
359
for
(
const
auto
node
:
this
->
hardEvidenceNodes
()) {
360
_graph_
.
eraseNode
(
node
);
361
}
362
363
364
// now, we can compute the new junction tree.
365
if
(
_JT_
!=
nullptr
)
delete
_JT_
;
366
_triangulation_
->
setGraph
(&
_graph_
, &(
this
->
domainSizes
()));
367
const
JunctionTree
&
triang_jt
=
_triangulation_
->
junctionTree
();
368
_JT_
=
new
CliqueGraph
(
triang_jt
);
369
370
// indicate, for each node of the moral graph a clique in _JT_ that can
371
// contain its conditional probability table
372
_node_to_clique_
.
clear
();
373
_clique_potentials_
.
clear
();
374
NodeSet
emptyset
;
375
for
(
auto
clique
: *
_JT_
)
376
_clique_potentials_
.
insert
(
clique
,
emptyset
);
377
const
std
::
vector
<
NodeId
>&
JT_elim_order
=
_triangulation_
->
eliminationOrder
();
378
NodeProperty
<
Size
>
elim_order
(
Size
(
JT_elim_order
.
size
()));
379
for
(
std
::
size_t
i
=
std
::
size_t
(0),
size
=
JT_elim_order
.
size
();
i
<
size
; ++
i
)
380
elim_order
.
insert
(
JT_elim_order
[
i
],
NodeId
(
i
));
381
const
DAG
&
dag
=
bn
.
dag
();
382
for
(
const
auto
node
:
_graph_
) {
383
// get the variables in the potential of node (and its parents)
384
NodeId
first_eliminated_node
=
node
;
385
Size
elim_number
=
elim_order
[
first_eliminated_node
];
386
387
for
(
const
auto
parent
:
dag
.
parents
(
node
)) {
388
if
(
_graph_
.
existsNode
(
parent
) && (
elim_order
[
parent
] <
elim_number
)) {
389
elim_number
=
elim_order
[
parent
];
390
first_eliminated_node
=
parent
;
391
}
392
}
393
394
// first_eliminated_node contains the first var (node or one of its
395
// parents) eliminated => the clique created during its elimination
396
// contains node and all of its parents => it can contain the potential
397
// assigned to the node in the BN
398
NodeId
clique
=
_triangulation_
->
createdJunctionTreeClique
(
first_eliminated_node
);
399
_node_to_clique_
.
insert
(
node
,
clique
);
400
_clique_potentials_
[
clique
].
insert
(
node
);
401
}
402
403
// do the same for the nodes that received evidence. Here, we only store
404
// the nodes whose at least one parent belongs to _graph_ (otherwise
405
// their CPT is just a constant real number).
406
for
(
const
auto
node
:
this
->
hardEvidenceNodes
()) {
407
// get the set of parents of the node that belong to _graph_
408
NodeSet
pars
(
dag
.
parents
(
node
).
size
());
409
for
(
const
auto
par
:
dag
.
parents
(
node
))
410
if
(
_graph_
.
exists
(
par
))
pars
.
insert
(
par
);
411
412
if
(!
pars
.
empty
()) {
413
NodeId
first_eliminated_node
= *(
pars
.
begin
());
414
Size
elim_number
=
elim_order
[
first_eliminated_node
];
415
416
for
(
const
auto
parent
:
pars
) {
417
if
(
elim_order
[
parent
] <
elim_number
) {
418
elim_number
=
elim_order
[
parent
];
419
first_eliminated_node
=
parent
;
420
}
421
}
422
423
// first_eliminated_node contains the first var (node or one of its
424
// parents) eliminated => the clique created during its elimination
425
// contains node and all of its parents => it can contain the potential
426
// assigned to the node in the BN
427
NodeId
clique
=
_triangulation_
->
createdJunctionTreeClique
(
first_eliminated_node
);
428
_node_to_clique_
.
insert
(
node
,
clique
);
429
_clique_potentials_
[
clique
].
insert
(
node
);
430
}
431
}
432
433
434
// indicate a clique that contains all the nodes of targets
435
_targets2clique_
=
std
::
numeric_limits
<
NodeId
>::
max
();
436
{
437
// remove from set all the nodes that received hard evidence (since they
438
// do not belong to the join tree)
439
NodeSet
nodeset
=
targets
;
440
for
(
const
auto
node
:
this
->
hardEvidenceNodes
())
441
if
(
nodeset
.
contains
(
node
))
nodeset
.
erase
(
node
);
442
443
if
(!
nodeset
.
empty
()) {
444
NodeId
first_eliminated_node
= *(
nodeset
.
begin
());
445
Size
elim_number
=
elim_order
[
first_eliminated_node
];
446
for
(
const
auto
node
:
nodeset
) {
447
if
(
elim_order
[
node
] <
elim_number
) {
448
elim_number
=
elim_order
[
node
];
449
first_eliminated_node
=
node
;
450
}
451
}
452
_targets2clique_
=
_triangulation_
->
createdJunctionTreeClique
(
first_eliminated_node
);
453
}
454
}
455
}
456
457
458
/// prepare the inference structures w.r.t. new targets, soft/hard evidence
459
template
<
typename
GUM_SCALAR
>
460
void
VariableElimination
<
GUM_SCALAR
>::
updateOutdatedStructure_
() {}
461
462
463
/// update the potentials stored in the cliques and invalidate outdated
464
/// messages
465
template
<
typename
GUM_SCALAR
>
466
void
VariableElimination
<
GUM_SCALAR
>::
updateOutdatedPotentials_
() {}
467
468
469
// find the potentials d-connected to a set of variables
470
template
<
typename
GUM_SCALAR
>
471
void
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsGetAll_
(
472
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
473
Set
<
const
DiscreteVariable
* >&
kept_vars
) {}
474
475
476
// find the potentials d-connected to a set of variables
477
template
<
typename
GUM_SCALAR
>
478
void
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsWithdSeparation_
(
479
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
480
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
481
// find the node ids of the kept variables
482
NodeSet
kept_ids
;
483
const
auto
&
bn
=
this
->
BN
();
484
for
(
const
auto
var
:
kept_vars
) {
485
kept_ids
.
insert
(
bn
.
nodeId
(*
var
));
486
}
487
488
// determine the set of potentials d-connected with the kept variables
489
NodeSet
requisite_nodes
;
490
BayesBall
::
requisiteNodes
(
bn
.
dag
(),
491
kept_ids
,
492
this
->
hardEvidenceNodes
(),
493
this
->
softEvidenceNodes
(),
494
requisite_nodes
);
495
for
(
auto
iter
=
pot_list
.
beginSafe
();
iter
!=
pot_list
.
endSafe
(); ++
iter
) {
496
const
Sequence
<
const
DiscreteVariable
* >&
vars
= (**
iter
).
variablesSequence
();
497
bool
found
=
false
;
498
for
(
auto
var
:
vars
) {
499
if
(
requisite_nodes
.
exists
(
bn
.
nodeId
(*
var
))) {
500
found
=
true
;
501
break
;
502
}
503
}
504
505
if
(!
found
) {
pot_list
.
erase
(
iter
); }
506
}
507
}
508
509
510
// find the potentials d-connected to a set of variables
511
template
<
typename
GUM_SCALAR
>
512
void
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsWithdSeparation2_
(
513
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
514
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
515
// find the node ids of the kept variables
516
NodeSet
kept_ids
;
517
const
auto
&
bn
=
this
->
BN
();
518
for
(
const
auto
var
:
kept_vars
) {
519
kept_ids
.
insert
(
bn
.
nodeId
(*
var
));
520
}
521
522
// determine the set of potentials d-connected with the kept variables
523
BayesBall
::
relevantPotentials
(
bn
,
524
kept_ids
,
525
this
->
hardEvidenceNodes
(),
526
this
->
softEvidenceNodes
(),
527
pot_list
);
528
}
529
530
531
// find the potentials d-connected to a set of variables
532
template
<
typename
GUM_SCALAR
>
533
void
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsWithdSeparation3_
(
534
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
535
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
536
// find the node ids of the kept variables
537
NodeSet
kept_ids
;
538
const
auto
&
bn
=
this
->
BN
();
539
for
(
const
auto
var
:
kept_vars
) {
540
kept_ids
.
insert
(
bn
.
nodeId
(*
var
));
541
}
542
543
// determine the set of potentials d-connected with the kept variables
544
dSeparation
dsep
;
545
dsep
.
relevantPotentials
(
bn
,
546
kept_ids
,
547
this
->
hardEvidenceNodes
(),
548
this
->
softEvidenceNodes
(),
549
pot_list
);
550
}
551
552
553
// find the potentials d-connected to a set of variables
554
template
<
typename
GUM_SCALAR
>
555
void
VariableElimination
<
GUM_SCALAR
>::
_findRelevantPotentialsXX_
(
556
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
557
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
558
switch
(
_find_relevant_potential_type_
) {
559
case
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_POTENTIALS
:
560
_findRelevantPotentialsWithdSeparation2_
(
pot_list
,
kept_vars
);
561
break
;
562
563
case
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_NODES
:
564
_findRelevantPotentialsWithdSeparation_
(
pot_list
,
kept_vars
);
565
break
;
566
567
case
RelevantPotentialsFinderType
::
DSEP_KOLLER_FRIEDMAN_2009
:
568
_findRelevantPotentialsWithdSeparation3_
(
pot_list
,
kept_vars
);
569
break
;
570
571
case
RelevantPotentialsFinderType
::
FIND_ALL
:
572
_findRelevantPotentialsGetAll_
(
pot_list
,
kept_vars
);
573
break
;
574
575
default
:
576
GUM_ERROR
(
FatalError
,
"not implemented yet"
)
577
}
578
}
579
580
581
// remove barren variables
582
template
<
typename
GUM_SCALAR
>
583
Set
<
const
Potential
<
GUM_SCALAR
>* >
VariableElimination
<
GUM_SCALAR
>::
_removeBarrenVariables_
(
584
_PotentialSet_
&
pot_list
,
585
Set
<
const
DiscreteVariable
* >&
del_vars
) {
586
// remove from del_vars the variables that received some evidence:
587
// only those that did not received evidence can be barren variables
588
Set
<
const
DiscreteVariable
* >
the_del_vars
=
del_vars
;
589
for
(
auto
iter
=
the_del_vars
.
beginSafe
();
iter
!=
the_del_vars
.
endSafe
(); ++
iter
) {
590
NodeId
id
=
this
->
BN
().
nodeId
(**
iter
);
591
if
(
this
->
hardEvidenceNodes
().
exists
(
id
) ||
this
->
softEvidenceNodes
().
exists
(
id
)) {
592
the_del_vars
.
erase
(
iter
);
593
}
594
}
595
596
// assign to each random variable the set of potentials that contain it
597
HashTable
<
const
DiscreteVariable
*,
_PotentialSet_
>
var2pots
;
598
_PotentialSet_
empty_pot_set
;
599
for
(
const
auto
pot
:
pot_list
) {
600
const
Sequence
<
const
DiscreteVariable
* >&
vars
=
pot
->
variablesSequence
();
601
for
(
const
auto
var
:
vars
) {
602
if
(
the_del_vars
.
exists
(
var
)) {
603
if
(!
var2pots
.
exists
(
var
)) {
var2pots
.
insert
(
var
,
empty_pot_set
); }
604
var2pots
[
var
].
insert
(
pot
);
605
}
606
}
607
}
608
609
// each variable with only one potential is a barren variable
610
// assign to each potential with barren nodes its set of barren variables
611
HashTable
<
const
Potential
<
GUM_SCALAR
>*,
Set
<
const
DiscreteVariable
* > >
pot2barren_var
;
612
Set
<
const
DiscreteVariable
* >
empty_var_set
;
613
for
(
auto
elt
:
var2pots
) {
614
if
(
elt
.
second
.
size
() == 1) {
// here we have a barren variable
615
const
Potential
<
GUM_SCALAR
>*
pot
= *(
elt
.
second
.
begin
());
616
if
(!
pot2barren_var
.
exists
(
pot
)) {
pot2barren_var
.
insert
(
pot
,
empty_var_set
); }
617
pot2barren_var
[
pot
].
insert
(
elt
.
first
);
// insert the barren variable
618
}
619
}
620
621
// for each potential with barren variables, marginalize them.
622
// if the potential has only barren variables, simply remove them from the
623
// set of potentials, else just project the potential
624
MultiDimProjection
<
GUM_SCALAR
,
Potential
>
projector
(
VENewprojPotential
);
625
_PotentialSet_
projected_pots
;
626
for
(
auto
elt
:
pot2barren_var
) {
627
// remove the current potential from pot_list as, anyway, we will change
628
// it
629
const
Potential
<
GUM_SCALAR
>*
pot
=
elt
.
first
;
630
pot_list
.
erase
(
pot
);
631
632
// check whether we need to add a projected new potential or not (i.e.,
633
// whether there exist non-barren variables or not)
634
if
(
pot
->
variablesSequence
().
size
() !=
elt
.
second
.
size
()) {
635
auto
new_pot
=
projector
.
project
(*
pot
,
elt
.
second
);
636
pot_list
.
insert
(
new_pot
);
637
projected_pots
.
insert
(
new_pot
);
638
}
639
}
640
641
return
projected_pots
;
642
}
643
644
645
// performs the collect phase of Lazy Propagation
646
template
<
typename
GUM_SCALAR
>
647
std
::
pair
<
Set
<
const
Potential
<
GUM_SCALAR
>* >,
Set
<
const
Potential
<
GUM_SCALAR
>* > >
648
VariableElimination
<
GUM_SCALAR
>::
_collectMessage_
(
NodeId
id
,
NodeId
from
) {
649
// collect messages from all the neighbors
650
std
::
pair
<
_PotentialSet_
,
_PotentialSet_
>
collect_messages
;
651
for
(
const
auto
other
:
_JT_
->
neighbours
(
id
)) {
652
if
(
other
!=
from
) {
653
std
::
pair
<
_PotentialSet_
,
_PotentialSet_
>
message
(
_collectMessage_
(
other
,
id
));
654
collect_messages
.
first
+=
message
.
first
;
655
collect_messages
.
second
+=
message
.
second
;
656
}
657
}
658
659
// combine the collect messages with those of id's clique
660
return
_produceMessage_
(
id
,
from
,
std
::
move
(
collect_messages
));
661
}
662
663
664
// get the CPT + evidence of a node projected w.r.t. hard evidence
665
template
<
typename
GUM_SCALAR
>
666
std
::
pair
<
Set
<
const
Potential
<
GUM_SCALAR
>* >,
Set
<
const
Potential
<
GUM_SCALAR
>* > >
667
VariableElimination
<
GUM_SCALAR
>::
_NodePotentials_
(
NodeId
node
) {
668
std
::
pair
<
_PotentialSet_
,
_PotentialSet_
>
res
;
669
const
auto
&
bn
=
this
->
BN
();
670
671
// get the CPT's of the node
672
// beware: all the potentials that are defined over some nodes
673
// including hard evidence must be projected so that these nodes are
674
// removed from the potential
675
// also beware that the CPT of a hard evidence node may be defined over
676
// parents that do not belong to _graph_ and that are not hard evidence.
677
// In this case, those parents have been removed by d-separation and it is
678
// easy to show that, in this case all the parents have been removed, so
679
// that the CPT does not need to be taken into account
680
const
auto
&
evidence
=
this
->
evidence
();
681
const
auto
&
hard_evidence
=
this
->
hardEvidence
();
682
if
(
_graph_
.
exists
(
node
) ||
this
->
hardEvidenceNodes
().
contains
(
node
)) {
683
const
Potential
<
GUM_SCALAR
>&
cpt
=
bn
.
cpt
(
node
);
684
const
auto
&
variables
=
cpt
.
variablesSequence
();
685
686
// check if the parents of a hard evidence node do not belong to _graph_
687
// and are not themselves hard evidence, discard the CPT, it is useless
688
// for inference
689
if
(
this
->
hardEvidenceNodes
().
contains
(
node
)) {
690
for
(
const
auto
var
:
variables
) {
691
NodeId
xnode
=
bn
.
nodeId
(*
var
);
692
if
(!
this
->
hardEvidenceNodes
().
contains
(
xnode
) && !
_graph_
.
existsNode
(
xnode
))
return
res
;
693
}
694
}
695
696
// get the list of nodes with hard evidence in cpt
697
NodeSet
hard_nodes
;
698
for
(
const
auto
var
:
variables
) {
699
NodeId
xnode
=
bn
.
nodeId
(*
var
);
700
if
(
this
->
hardEvidenceNodes
().
contains
(
xnode
))
hard_nodes
.
insert
(
xnode
);
701
}
702
703
// if hard_nodes contains hard evidence nodes, perform a projection
704
// and insert the result into the appropriate clique, else insert
705
// directly cpt into the clique
706
if
(
hard_nodes
.
empty
()) {
707
res
.
first
.
insert
(&
cpt
);
708
}
else
{
709
// marginalize out the hard evidence nodes: if the cpt is defined
710
// only over nodes that received hard evidence, do not consider it
711
// as a potential anymore
712
if
(
hard_nodes
.
size
() !=
variables
.
size
()) {
713
// perform the projection with a combine and project instance
714
Set
<
const
DiscreteVariable
* >
hard_variables
;
715
_PotentialSet_
marg_cpt_set
{&
cpt
};
716
for
(
const
auto
xnode
:
hard_nodes
) {
717
marg_cpt_set
.
insert
(
evidence
[
xnode
]);
718
hard_variables
.
insert
(&(
bn
.
variable
(
xnode
)));
719
}
720
// perform the combination of those potentials and their projection
721
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
Potential
>
combine_and_project
(
722
_combination_op_
,
723
VENewprojPotential
);
724
_PotentialSet_
new_cpt_list
725
=
combine_and_project
.
combineAndProject
(
marg_cpt_set
,
hard_variables
);
726
727
// there should be only one potential in new_cpt_list
728
if
(
new_cpt_list
.
size
() != 1) {
729
// remove the CPT created to avoid memory leaks
730
for
(
auto
pot
:
new_cpt_list
) {
731
if
(!
marg_cpt_set
.
contains
(
pot
))
delete
pot
;
732
}
733
GUM_ERROR
(
FatalError
,
734
"the projection of a potential containing "
735
<<
"hard evidence is empty!"
);
736
}
737
const
Potential
<
GUM_SCALAR
>*
projected_cpt
= *(
new_cpt_list
.
begin
());
738
res
.
first
.
insert
(
projected_cpt
);
739
res
.
second
.
insert
(
projected_cpt
);
740
}
741
}
742
743
// if the node received some soft evidence, add it
744
if
(
evidence
.
exists
(
node
) && !
hard_evidence
.
exists
(
node
)) {
745
res
.
first
.
insert
(
this
->
evidence
()[
node
]);
746
}
747
}
748
749
return
res
;
750
}
751
752
753
// creates the message sent by clique from_id to clique to_id
754
template
<
typename
GUM_SCALAR
>
755
std
::
pair
<
Set
<
const
Potential
<
GUM_SCALAR
>* >,
Set
<
const
Potential
<
GUM_SCALAR
>* > >
756
VariableElimination
<
GUM_SCALAR
>::
_produceMessage_
(
757
NodeId
from_id
,
758
NodeId
to_id
,
759
std
::
pair
<
Set
<
const
Potential
<
GUM_SCALAR
>* >,
Set
<
const
Potential
<
GUM_SCALAR
>* > >&&
760
incoming_messages
) {
761
// get the messages sent by adjacent nodes to from_id
762
std
::
pair
<
Set
<
const
Potential
<
GUM_SCALAR
>* >,
Set
<
const
Potential
<
GUM_SCALAR
>* > >
763
pot_list
(
std
::
move
(
incoming_messages
));
764
765
// get the potentials of the clique
766
for
(
const
auto
node
:
_clique_potentials_
[
from_id
]) {
767
auto
new_pots
=
_NodePotentials_
(
node
);
768
pot_list
.
first
+=
new_pots
.
first
;
769
pot_list
.
second
+=
new_pots
.
second
;
770
}
771
772
// if from_id = to_id: this is the endpoint of a collect
773
if
(!
_JT_
->
existsEdge
(
from_id
,
to_id
)) {
774
return
pot_list
;
775
}
else
{
776
// get the set of variables that need be removed from the potentials
777
const
NodeSet
&
from_clique
=
_JT_
->
clique
(
from_id
);
778
const
NodeSet
&
separator
=
_JT_
->
separator
(
from_id
,
to_id
);
779
Set
<
const
DiscreteVariable
* >
del_vars
(
from_clique
.
size
());
780
Set
<
const
DiscreteVariable
* >
kept_vars
(
separator
.
size
());
781
const
auto
&
bn
=
this
->
BN
();
782
783
for
(
const
auto
node
:
from_clique
) {
784
if
(!
separator
.
contains
(
node
)) {
785
del_vars
.
insert
(&(
bn
.
variable
(
node
)));
786
}
else
{
787
kept_vars
.
insert
(&(
bn
.
variable
(
node
)));
788
}
789
}
790
791
// pot_list now contains all the potentials to multiply and marginalize
792
// => combine the messages
793
_PotentialSet_
new_pot_list
=
_marginalizeOut_
(
pot_list
.
first
,
del_vars
,
kept_vars
);
794
795
/*
796
for the moment, remove this test: due to some optimizations, some
797
potentials might have all their cells greater than 1.
798
799
// remove all the potentials that are equal to ones (as probability
800
// matrix multiplications are tensorial, such potentials are useless)
801
for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
802
++iter) {
803
const auto pot = *iter;
804
if (pot->variablesSequence().size() == 1) {
805
bool is_all_ones = true;
806
for (Instantiation inst(*pot); !inst.end(); ++inst) {
807
if ((*pot)[inst] < _one_minus_epsilon_) {
808
is_all_ones = false;
809
break;
810
}
811
}
812
if (is_all_ones) {
813
if (!pot_list.first.exists(pot)) delete pot;
814
new_pot_list.erase(iter);
815
continue;
816
}
817
}
818
}
819
*/
820
821
// remove the unnecessary temporary messages
822
for
(
auto
iter
=
pot_list
.
second
.
beginSafe
();
iter
!=
pot_list
.
second
.
endSafe
(); ++
iter
) {
823
if
(!
new_pot_list
.
contains
(*
iter
)) {
824
delete
*
iter
;
825
pot_list
.
second
.
erase
(
iter
);
826
}
827
}
828
829
// keep track of all the newly created potentials
830
for
(
const
auto
pot
:
new_pot_list
) {
831
if
(!
pot_list
.
first
.
contains
(
pot
)) {
pot_list
.
second
.
insert
(
pot
); }
832
}
833
834
// return the new set of potentials
835
return
std
::
pair
<
_PotentialSet_
,
_PotentialSet_
>(
std
::
move
(
new_pot_list
),
836
std
::
move
(
pot_list
.
second
));
837
}
838
}
839
840
841
// remove variables del_vars from the list of potentials pot_list
842
template
<
typename
GUM_SCALAR
>
843
Set
<
const
Potential
<
GUM_SCALAR
>* >
VariableElimination
<
GUM_SCALAR
>::
_marginalizeOut_
(
844
Set
<
const
Potential
<
GUM_SCALAR
>* >
pot_list
,
845
Set
<
const
DiscreteVariable
* >&
del_vars
,
846
Set
<
const
DiscreteVariable
* >&
kept_vars
) {
847
// use d-separation analysis to check which potentials shall be combined
848
_findRelevantPotentialsXX_
(
pot_list
,
kept_vars
);
849
850
// remove the potentials corresponding to barren variables if we want
851
// to exploit barren nodes
852
_PotentialSet_
barren_projected_potentials
;
853
if
(
_barren_nodes_type_
==
FindBarrenNodesType
::
FIND_BARREN_NODES
) {
854
barren_projected_potentials
=
_removeBarrenVariables_
(
pot_list
,
del_vars
);
855
}
856
857
// create a combine and project operator that will perform the
858
// marginalization
859
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
Potential
>
combine_and_project
(
_combination_op_
,
860
_projection_op_
);
861
_PotentialSet_
new_pot_list
=
combine_and_project
.
combineAndProject
(
pot_list
,
del_vars
);
862
863
// remove all the potentials that were created due to projections of
864
// barren nodes and that are not part of the new_pot_list: these
865
// potentials were just temporary potentials
866
for
(
auto
iter
=
barren_projected_potentials
.
beginSafe
();
867
iter
!=
barren_projected_potentials
.
endSafe
();
868
++
iter
) {
869
if
(!
new_pot_list
.
exists
(*
iter
))
delete
*
iter
;
870
}
871
872
// remove all the potentials that have no dimension
873
for
(
auto
iter_pot
=
new_pot_list
.
beginSafe
();
iter_pot
!=
new_pot_list
.
endSafe
(); ++
iter_pot
) {
874
if
((*
iter_pot
)->
variablesSequence
().
size
() == 0) {
875
// as we have already marginalized out variables that received evidence,
876
// it may be the case that, after combining and projecting, some
877
// potentials might be empty. In this case, we shall keep their
878
// constant and remove them from memory
879
// # TODO: keep the constants!
880
delete
*
iter_pot
;
881
new_pot_list
.
erase
(
iter_pot
);
882
}
883
}
884
885
return
new_pot_list
;
886
}
887
888
889
// performs a whole inference
890
template
<
typename
GUM_SCALAR
>
891
INLINE
void
VariableElimination
<
GUM_SCALAR
>::
makeInference_
() {}
892
893
894
/// returns a fresh potential equal to P(1st arg,evidence)
895
template
<
typename
GUM_SCALAR
>
896
Potential
<
GUM_SCALAR
>*
897
VariableElimination
<
GUM_SCALAR
>::
unnormalizedJointPosterior_
(
NodeId
id
) {
898
const
auto
&
bn
=
this
->
BN
();
899
900
// hard evidence do not belong to the join tree
901
// # TODO: check for sets of inconsistent hard evidence
902
if
(
this
->
hardEvidenceNodes
().
contains
(
id
)) {
903
return
new
Potential
<
GUM_SCALAR
>(*(
this
->
evidence
()[
id
]));
904
}
905
906
// if we still need to perform some inference task, do it
907
_createNewJT_
(
NodeSet
{
id
});
908
NodeId
clique_of_id
=
_node_to_clique_
[
id
];
909
auto
pot_list
=
_collectMessage_
(
clique_of_id
,
clique_of_id
);
910
911
// get the set of variables that need be removed from the potentials
912
const
NodeSet
&
nodes
=
_JT_
->
clique
(
clique_of_id
);
913
Set
<
const
DiscreteVariable
* >
kept_vars
{&(
bn
.
variable
(
id
))};
914
Set
<
const
DiscreteVariable
* >
del_vars
(
nodes
.
size
());
915
for
(
const
auto
node
:
nodes
) {
916
if
(
node
!=
id
)
del_vars
.
insert
(&(
bn
.
variable
(
node
)));
917
}
918
919
// pot_list now contains all the potentials to multiply and marginalize
920
// => combine the messages
921
_PotentialSet_
new_pot_list
=
_marginalizeOut_
(
pot_list
.
first
,
del_vars
,
kept_vars
);
922
Potential
<
GUM_SCALAR
>*
joint
=
nullptr
;
923
924
if
(
new_pot_list
.
size
() == 1) {
925
joint
=
const_cast
<
Potential
<
GUM_SCALAR
>* >(*(
new_pot_list
.
begin
()));
926
// if joint already existed, create a copy, so that we can put it into
927
// the _target_posterior_ property
928
if
(
pot_list
.
first
.
exists
(
joint
)) {
929
joint
=
new
Potential
<
GUM_SCALAR
>(*
joint
);
930
}
else
{
931
// remove the joint from new_pot_list so that it will not be
932
// removed just after the else block
933
new_pot_list
.
clear
();
934
}
935
}
else
{
936
MultiDimCombinationDefault
<
GUM_SCALAR
,
Potential
>
fast_combination
(
_combination_op_
);
937
joint
=
fast_combination
.
combine
(
new_pot_list
);
938
}
939
940
// remove the potentials that were created in new_pot_list
941
for
(
auto
pot
:
new_pot_list
)
942
if
(!
pot_list
.
first
.
exists
(
pot
))
delete
pot
;
943
944
// remove all the temporary potentials created in pot_list
945
for
(
auto
pot
:
pot_list
.
second
)
946
delete
pot
;
947
948
// check that the joint posterior is different from a 0 vector: this would
949
// indicate that some hard evidence are not compatible (their joint
950
// probability is equal to 0)
951
bool
nonzero_found
=
false
;
952
for
(
Instantiation
inst
(*
joint
); !
inst
.
end
(); ++
inst
) {
953
if
((*
joint
)[
inst
]) {
954
nonzero_found
=
true
;
955
break
;
956
}
957
}
958
if
(!
nonzero_found
) {
959
// remove joint from memory to avoid memory leaks
960
delete
joint
;
961
GUM_ERROR
(
IncompatibleEvidence
,
962
"some evidence entered into the Bayes "
963
"net are incompatible (their joint proba = 0)"
);
964
}
965
966
return
joint
;
967
}
968
969
970
/// returns the posterior of a given variable
971
template
<
typename
GUM_SCALAR
>
972
const
Potential
<
GUM_SCALAR
>&
VariableElimination
<
GUM_SCALAR
>::
posterior_
(
NodeId
id
) {
973
// compute the joint posterior and normalize
974
auto
joint
=
unnormalizedJointPosterior_
(
id
);
975
if
(
joint
->
sum
() != 1)
// hard test for ReadOnly CPT (as aggregator)
976
joint
->
normalize
();
977
978
if
(
_target_posterior_
!=
nullptr
)
delete
_target_posterior_
;
979
_target_posterior_
=
joint
;
980
981
return
*
joint
;
982
}
983
984
985
// returns the marginal a posteriori proba of a given node
986
template
<
typename
GUM_SCALAR
>
987
Potential
<
GUM_SCALAR
>*
988
VariableElimination
<
GUM_SCALAR
>::
unnormalizedJointPosterior_
(
const
NodeSet
&
set
) {
989
// hard evidence do not belong to the join tree, so extract the nodes
990
// from targets that are not hard evidence
991
NodeSet
targets
=
set
,
hard_ev_nodes
;
992
for
(
const
auto
node
:
this
->
hardEvidenceNodes
()) {
993
if
(
targets
.
contains
(
node
)) {
994
targets
.
erase
(
node
);
995
hard_ev_nodes
.
insert
(
node
);
996
}
997
}
998
999
// if all the nodes have received hard evidence, then compute the
1000
// joint posterior directly by multiplying the hard evidence potentials
1001
const
auto
&
evidence
=
this
->
evidence
();
1002
if
(
targets
.
empty
()) {
1003
_PotentialSet_
pot_list
;
1004
for
(
const
auto
node
:
set
) {
1005
pot_list
.
insert
(
evidence
[
node
]);
1006
}
1007
if
(
pot_list
.
size
() == 1) {
1008
return
new
Potential
<
GUM_SCALAR
>(**(
pot_list
.
begin
()));
1009
}
else
{
1010
MultiDimCombinationDefault
<
GUM_SCALAR
,
Potential
>
fast_combination
(
_combination_op_
);
1011
return
fast_combination
.
combine
(
pot_list
);
1012
}
1013
}
1014
1015
// if we still need to perform some inference task, do it
1016
_createNewJT_
(
set
);
1017
auto
pot_list
=
_collectMessage_
(
_targets2clique_
,
_targets2clique_
);
1018
1019
// get the set of variables that need be removed from the potentials
1020
const
NodeSet
&
nodes
=
_JT_
->
clique
(
_targets2clique_
);
1021
Set
<
const
DiscreteVariable
* >
del_vars
(
nodes
.
size
());
1022
Set
<
const
DiscreteVariable
* >
kept_vars
(
targets
.
size
());
1023
const
auto
&
bn
=
this
->
BN
();
1024
for
(
const
auto
node
:
nodes
) {
1025
if
(!
targets
.
contains
(
node
)) {
1026
del_vars
.
insert
(&(
bn
.
variable
(
node
)));
1027
}
else
{
1028
kept_vars
.
insert
(&(
bn
.
variable
(
node
)));
1029
}
1030
}
1031
1032
// pot_list now contains all the potentials to multiply and marginalize
1033
// => combine the messages
1034
_PotentialSet_
new_pot_list
=
_marginalizeOut_
(
pot_list
.
first
,
del_vars
,
kept_vars
);
1035
Potential
<
GUM_SCALAR
>*
joint
=
nullptr
;
1036
1037
if
((
new_pot_list
.
size
() == 1) &&
hard_ev_nodes
.
empty
()) {
1038
joint
=
const_cast
<
Potential
<
GUM_SCALAR
>* >(*(
new_pot_list
.
begin
()));
1039
// if pot already existed, create a copy, so that we can put it into
1040
// the _target_posteriors_ property
1041
if
(
pot_list
.
first
.
exists
(
joint
)) {
1042
joint
=
new
Potential
<
GUM_SCALAR
>(*
joint
);
1043
}
else
{
1044
// remove the joint from new_pot_list so that it will not be
1045
// removed just after the next else block
1046
new_pot_list
.
clear
();
1047
}
1048
}
else
{
1049
// combine all the potentials in new_pot_list with all the hard evidence
1050
// of the nodes in set
1051
_PotentialSet_
new_new_pot_list
=
new_pot_list
;
1052
for
(
const
auto
node
:
hard_ev_nodes
) {
1053
new_new_pot_list
.
insert
(
evidence
[
node
]);
1054
}
1055
MultiDimCombinationDefault
<
GUM_SCALAR
,
Potential
>
fast_combination
(
_combination_op_
);
1056
joint
=
fast_combination
.
combine
(
new_new_pot_list
);
1057
}
1058
1059
// remove the potentials that were created in new_pot_list
1060
for
(
auto
pot
:
new_pot_list
)
1061
if
(!
pot_list
.
first
.
exists
(
pot
))
delete
pot
;
1062
1063
// remove all the temporary potentials created in pot_list
1064
for
(
auto
pot
:
pot_list
.
second
)
1065
delete
pot
;
1066
1067
// check that the joint posterior is different from a 0 vector: this would
1068
// indicate that some hard evidence are not compatible
1069
bool
nonzero_found
=
false
;
1070
for
(
Instantiation
inst
(*
joint
); !
inst
.
end
(); ++
inst
) {
1071
if
((*
joint
)[
inst
]) {
1072
nonzero_found
=
true
;
1073
break
;
1074
}
1075
}
1076
if
(!
nonzero_found
) {
1077
// remove joint from memory to avoid memory leaks
1078
delete
joint
;
1079
GUM_ERROR
(
IncompatibleEvidence
,
1080
"some evidence entered into the Bayes "
1081
"net are incompatible (their joint proba = 0)"
);
1082
}
1083
1084
return
joint
;
1085
}
1086
1087
1088
/// returns the posterior of a given set of variables
1089
template
<
typename
GUM_SCALAR
>
1090
const
Potential
<
GUM_SCALAR
>&
1091
VariableElimination
<
GUM_SCALAR
>::
jointPosterior_
(
const
NodeSet
&
set
) {
1092
// compute the joint posterior and normalize
1093
auto
joint
=
unnormalizedJointPosterior_
(
set
);
1094
joint
->
normalize
();
1095
1096
if
(
_target_posterior_
!=
nullptr
)
delete
_target_posterior_
;
1097
_target_posterior_
=
joint
;
1098
1099
return
*
joint
;
1100
}
1101
1102
1103
/// returns the posterior of a given set of variables
1104
template
<
typename
GUM_SCALAR
>
1105
const
Potential
<
GUM_SCALAR
>&
1106
VariableElimination
<
GUM_SCALAR
>::
jointPosterior_
(
const
NodeSet
&
wanted_target
,
1107
const
NodeSet
&
declared_target
) {
1108
return
jointPosterior_
(
wanted_target
);
1109
}
1110
1111
1112
}
/* namespace gum */
1113
1114
#
endif
// DOXYGEN_SHOULD_SKIP_THIS
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643