aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
variableElimination.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of a variable elimination algorithm
25
* for inference in Bayesian networks.
26
*
27
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
28
*/
29
#
ifndef
GUM_VARIABLE_ELIMINATION_H
30
#
define
GUM_VARIABLE_ELIMINATION_H
31
32
#
include
<
utility
>
33
34
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
35
#
include
<
agrum
/
BN
/
algorithms
/
barrenNodesFinder
.
h
>
36
#
include
<
agrum
/
BN
/
inference
/
tools
/
jointTargetedInference
.
h
>
37
#
include
<
agrum
/
BN
/
inference
/
tools
/
relevantPotentialsFinderType
.
h
>
38
#
include
<
agrum
/
agrum
.
h
>
39
#
include
<
agrum
/
tools
/
graphs
/
algorithms
/
triangulations
/
defaultTriangulation
.
h
>
40
41
namespace
gum
{
42
43
44
// the function used to combine two tables
45
template
<
typename
GUM_SCALAR
>
46
INLINE
static
Potential
<
GUM_SCALAR
>*
47
VENewmultiPotential
(
const
Potential
<
GUM_SCALAR
>&
t1
,
48
const
Potential
<
GUM_SCALAR
>&
t2
) {
49
return
new
Potential
<
GUM_SCALAR
>(
t1
*
t2
);
50
}
51
52
// the function used to combine two tables
53
template
<
typename
GUM_SCALAR
>
54
INLINE
static
Potential
<
GUM_SCALAR
>*
55
VENewprojPotential
(
const
Potential
<
GUM_SCALAR
>&
t1
,
56
const
Set
<
const
DiscreteVariable
* >&
del_vars
) {
57
return
new
Potential
<
GUM_SCALAR
>(
t1
.
margSumOut
(
del_vars
));
58
}
59
60
61
/**
62
* @class VariableElimination VariableElimination.h
63
* <agrum/BN/inference/variableElimination.h>
64
* @brief Implementation of a Shafer-Shenoy's-like version of lazy
65
* propagation for inference in Bayesian networks
66
* @ingroup bn_inference
67
*/
68
template
<
typename
GUM_SCALAR
>
69
class
VariableElimination
:
public
JointTargetedInference
<
GUM_SCALAR
> {
70
public
:
71
// ############################################################################
72
/// @name Constructors / Destructors
73
// ############################################################################
74
/// @{
75
76
/// default constructor
77
explicit
VariableElimination
(
78
const
IBayesNet
<
GUM_SCALAR
>*
BN
,
79
RelevantPotentialsFinderType
relevant_type
80
=
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_POTENTIALS
,
81
FindBarrenNodesType
=
FindBarrenNodesType
::
FIND_BARREN_NODES
);
82
83
/// avoid copy constructors
84
VariableElimination
(
const
VariableElimination
<
GUM_SCALAR
>&) =
delete
;
85
86
/// avoid copy operators
87
VariableElimination
<
GUM_SCALAR
>&
88
operator
=(
const
VariableElimination
<
GUM_SCALAR
>&)
89
=
delete
;
90
91
/// destructor
92
~
VariableElimination
()
final
;
93
94
/// @}
95
96
97
// ############################################################################
98
/// @name Accessors / Modifiers
99
// ############################################################################
100
/// @{
101
102
/// use a new triangulation algorithm
103
void
setTriangulation
(
const
Triangulation
&
new_triangulation
);
104
105
/// sets how we determine the relevant potentials to combine
106
/** When a clique sends a message to a separator, it first constitute the
107
* set of the potentials it contains and of the potentials contained in the
108
* messages it received. If RelevantPotentialsFinderType = FIND_ALL,
109
* all these potentials are combined and projected to produce the message
110
* sent to the separator.
111
* If RelevantPotentialsFinderType = DSEP_BAYESBALL_NODES, then only the
112
* set of potentials d-connected to the variables of the separator are kept
113
* for combination and projection. */
114
void
setRelevantPotentialsFinderType
(
RelevantPotentialsFinderType
type
);
115
116
/// sets how we determine barren nodes
117
/** Barren nodes are unnecessary for probability inference, so they can
118
* be safely discarded in this case (type = FIND_BARREN_NODES). This
119
* speeds-up inference. However, there are some cases in which we do not
120
* want to remove barren nodes, typically when we want to answer queries
121
* such as Most Probable Explanations (MPE). */
122
void
setFindBarrenNodesType
(
FindBarrenNodesType
type
);
123
124
/// returns the join tree used for compute the posterior of node id
125
const
JunctionTree
*
junctionTree
(
NodeId
id
);
126
127
/// @}
128
129
130
protected
:
131
/// fired when the stage is changed
132
void
onStateChanged_
()
final
{};
133
134
/// fired after a new evidence is inserted
135
void
onEvidenceAdded_
(
const
NodeId
id
,
bool
isHardEvidence
)
final
;
136
137
/// fired before an evidence is removed
138
void
onEvidenceErased_
(
const
NodeId
id
,
bool
isHardEvidence
)
final
;
139
140
/// fired before all the evidence are erased
141
void
onAllEvidenceErased_
(
bool
contains_hard_evidence
)
final
;
142
143
/** @brief fired after an evidence is changed, in particular when its status
144
* (soft/hard) changes
145
*
146
* @param nodeId the node of the changed evidence
147
* @param hasChangedSoftHard true if the evidence has changed from Soft to
148
* Hard or from Hard to Soft
149
*/
150
void
onEvidenceChanged_
(
const
NodeId
id
,
bool
hasChangedSoftHard
)
final
;
151
152
/// fired after a new single target is inserted
153
/** @param id The target variable's id. */
154
void
onMarginalTargetAdded_
(
const
NodeId
id
)
final
;
155
156
/// fired before a single target is removed
157
/** @param id The target variable's id. */
158
void
onMarginalTargetErased_
(
const
NodeId
id
)
final
;
159
160
/// fired after a new Bayes net has been assigned to the engine
161
virtual
void
onModelChanged_
(
const
GraphicalModel
*
bn
)
final
;
162
163
/// fired after a new joint target is inserted
164
/** @param set The set of target variable's ids. */
165
void
onJointTargetAdded_
(
const
NodeSet
&
set
)
final
;
166
167
/// fired before a joint target is removed
168
/** @param set The set of target variable's ids. */
169
void
onJointTargetErased_
(
const
NodeSet
&
set
)
final
;
170
171
/// fired after all the nodes of the BN are added as single targets
172
void
onAllMarginalTargetsAdded_
()
final
;
173
174
/// fired before a all the single targets are removed
175
void
onAllMarginalTargetsErased_
()
final
;
176
177
/// fired before a all the joint targets are removed
178
void
onAllJointTargetsErased_
()
final
;
179
180
/// fired before a all single and joint_targets are removed
181
void
onAllTargetsErased_
()
final
;
182
183
/// prepares inference when the latter is in OutdatedStructure state
184
/** Note that the values of evidence are not necessarily
185
* known and can be changed between updateOutdatedStructure_ and
186
* makeInference_. */
187
void
updateOutdatedStructure_
()
final
;
188
189
/// prepares inference when the latter is in OutdatedPotentials state
190
/** Note that the values of evidence are not necessarily
191
* known and can be changed between updateOutdatedPotentials_ and
192
* makeInference_. */
193
void
updateOutdatedPotentials_
()
final
;
194
195
/// called when the inference has to be performed effectively
196
/** Once the inference is done, fillPosterior_ can be called. */
197
void
makeInference_
()
final
;
198
199
200
/// returns the posterior of a given variable
201
/** @param id The variable's id. */
202
const
Potential
<
GUM_SCALAR
>&
posterior_
(
NodeId
id
)
final
;
203
204
/// returns the posterior of a declared target set
205
/** @param set The set of ids of the variables whose joint posterior is
206
* looked for. */
207
const
Potential
<
GUM_SCALAR
>&
jointPosterior_
(
const
NodeSet
&
set
)
final
;
208
209
/** @brief asks derived classes for the joint posterior of a set of
210
* variables not declared as a joint target
211
*
212
* @param wanted_target The set of ids of the variables whose joint
213
* posterior is looked for.
214
* @param declared_target the joint target declared by the user that contains
215
* set */
216
const
Potential
<
GUM_SCALAR
>&
217
jointPosterior_
(
const
NodeSet
&
wanted_target
,
218
const
NodeSet
&
declared_target
)
final
;
219
220
/// returns a fresh potential equal to P(argument,evidence)
221
Potential
<
GUM_SCALAR
>*
unnormalizedJointPosterior_
(
NodeId
id
)
final
;
222
223
/// returns a fresh potential equal to P(argument,evidence)
224
Potential
<
GUM_SCALAR
>*
unnormalizedJointPosterior_
(
const
NodeSet
&
set
)
final
;
225
226
227
private
:
228
typedef
Set
<
const
Potential
<
GUM_SCALAR
>* >
PotentialSet__
;
229
typedef
SetIteratorSafe
<
const
Potential
<
GUM_SCALAR
>* >
230
PotentialSetIterator__
;
231
232
233
/// the type of relevant potential finding algorithm to be used
234
RelevantPotentialsFinderType
find_relevant_potential_type__
;
235
236
/** @brief update a set of potentials: the remaining are those to be
237
* combined to produce a message on a separator */
238
void
(
VariableElimination
<
GUM_SCALAR
>::*
findRelevantPotentials__
)(
239
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
240
Set
<
const
DiscreteVariable
* >&
kept_vars
);
241
242
/// the type of barren nodes computation we wish
243
FindBarrenNodesType
barren_nodes_type__
;
244
245
/// the operator for performing the projections
246
Potential
<
GUM_SCALAR
>* (*
projection_op__
)(
247
const
Potential
<
GUM_SCALAR
>&,
248
const
Set
<
const
DiscreteVariable
* >&){
VENewprojPotential
};
249
250
/// the operator for performing the combinations
251
Potential
<
GUM_SCALAR
>* (*
combination_op__
)(
const
Potential
<
GUM_SCALAR
>&,
252
const
Potential
<
GUM_SCALAR
>&){
253
VENewmultiPotential
};
254
255
/// the triangulation class creating the junction tree used for inference
256
Triangulation
*
triangulation__
;
257
258
/// the undigraph extracted from the BN and used to construct the join tree
259
/** If all nodes are targets, this graph corresponds to the moral graph
260
* of the BN. Otherwise, it may be a subgraph of this moral graph. For
261
* instance if the BN is A->B->C and only B is a target, graph__ will be
262
* equal to A-B if we exploit barren nodes (C is a barren node and,
263
* therefore, can be removed for inference). */
264
UndiGraph
graph__
;
265
266
/// the junction tree used to answer the last inference query
267
JunctionTree
*
JT__
{
nullptr
};
268
269
/// for each node of graph__ (~ in the Bayes net), associate an ID in the JT
270
HashTable
<
NodeId
,
NodeId
>
node_to_clique__
;
271
272
/// for each BN node, indicate in which clique its CPT will be stored
273
HashTable
<
NodeId
,
NodeSet
>
clique_potentials__
;
274
275
/// indicate a clique that contains all the nodes of the target
276
NodeId
targets2clique__
;
277
278
/// the posterior computed during the last inference
279
/** the posterior is owned by VariableElimination. */
280
Potential
<
GUM_SCALAR
>*
target_posterior__
{
nullptr
};
281
282
/// for comparisons with 1 - epsilon
283
const
GUM_SCALAR
one_minus_epsilon__
{
GUM_SCALAR
(1.0 - 1e-6)};
284
285
286
/// create a new junction tree as well as its related data structures
287
void
createNewJT__
(
const
NodeSet
&
targets
);
288
289
/// sets the operator for performing the projections
290
void
setProjectionFunction__
(
291
Potential
<
GUM_SCALAR
>* (*
proj
)(
const
Potential
<
GUM_SCALAR
>&,
292
const
Set
<
const
DiscreteVariable
* >&));
293
294
/// sets the operator for performing the combinations
295
void
setCombinationFunction__
(
Potential
<
GUM_SCALAR
>* (
296
*
comb
)(
const
Potential
<
GUM_SCALAR
>&,
const
Potential
<
GUM_SCALAR
>&));
297
298
/** @brief update a set of potentials: the remaining are those to be
299
* combined
300
* to produce a message on a separator */
301
void
findRelevantPotentialsWithdSeparation__
(
302
PotentialSet__
&
pot_list
,
303
Set
<
const
DiscreteVariable
* >&
kept_vars
);
304
305
/** @brief update a set of potentials: the remaining are those to be
306
* combined
307
* to produce a message on a separator */
308
void
findRelevantPotentialsWithdSeparation2__
(
309
PotentialSet__
&
pot_list
,
310
Set
<
const
DiscreteVariable
* >&
kept_vars
);
311
312
/** @brief update a set of potentials: the remaining are those to be
313
* combined
314
* to produce a message on a separator */
315
void
findRelevantPotentialsWithdSeparation3__
(
316
PotentialSet__
&
pot_list
,
317
Set
<
const
DiscreteVariable
* >&
kept_vars
);
318
319
/** @brief update a set of potentials: the remaining are those to be
320
* combined
321
* to produce a message on a separator */
322
void
findRelevantPotentialsGetAll__
(
PotentialSet__
&
pot_list
,
323
Set
<
const
DiscreteVariable
* >&
kept_vars
);
324
325
/** @brief update a set of potentials: the remaining are those to be
326
* combined
327
* to produce a message on a separator */
328
void
findRelevantPotentialsXX__
(
PotentialSet__
&
pot_list
,
329
Set
<
const
DiscreteVariable
* >&
kept_vars
);
330
331
// remove barren variables and return the newly created projected potentials
332
PotentialSet__
333
removeBarrenVariables__
(
PotentialSet__
&
pot_list
,
334
Set
<
const
DiscreteVariable
* >&
del_vars
);
335
336
/// actually perform the collect phase
337
std
::
pair
<
PotentialSet__
,
PotentialSet__
>
collectMessage__
(
NodeId
id
,
338
NodeId
from
);
339
340
/// returns the CPT + evidence of a node projected w.r.t. hard evidence
341
std
::
pair
<
PotentialSet__
,
PotentialSet__
>
NodePotentials__
(
NodeId
node
);
342
343
/// creates the message sent by clique from_id to clique to_id
344
std
::
pair
<
PotentialSet__
,
PotentialSet__
>
produceMessage__
(
345
NodeId
from_id
,
346
NodeId
to_id
,
347
std
::
pair
<
PotentialSet__
,
PotentialSet__
>&&
incoming_messages
);
348
349
/** @brief removes variables del_vars from a list of potentials and
350
* returns the resulting list */
351
PotentialSet__
marginalizeOut__
(
PotentialSet__
pot_list
,
352
Set
<
const
DiscreteVariable
* >&
del_vars
,
353
Set
<
const
DiscreteVariable
* >&
kept_vars
);
354
};
355
356
357
#
ifndef
GUM_NO_EXTERN_TEMPLATE_CLASS
358
extern
template
class
VariableElimination<
double
>;
359
#
endif
360
361
362
}
/* namespace gum */
363
364
365
#
include
<
agrum
/
BN
/
inference
/
variableElimination_tpl
.
h
>
366
367
368
#
endif
/* GUM_VARIABLE_ELIMINATION_H */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669