aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
variableElimination.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of a variable elimination algorithm
25
* for inference in Bayesian networks.
26
*
27
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
28
*/
29
#
ifndef
GUM_VARIABLE_ELIMINATION_H
30
#
define
GUM_VARIABLE_ELIMINATION_H
31
32
#
include
<
utility
>
33
34
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
35
#
include
<
agrum
/
BN
/
algorithms
/
barrenNodesFinder
.
h
>
36
#
include
<
agrum
/
BN
/
inference
/
tools
/
jointTargetedInference
.
h
>
37
#
include
<
agrum
/
BN
/
inference
/
tools
/
relevantPotentialsFinderType
.
h
>
38
#
include
<
agrum
/
agrum
.
h
>
39
#
include
<
agrum
/
tools
/
graphs
/
algorithms
/
triangulations
/
defaultTriangulation
.
h
>
40
41
namespace
gum
{
42
43
44
// the function used to combine two tables
45
template
<
typename
GUM_SCALAR
>
46
INLINE
static
Potential
<
GUM_SCALAR
>*
VENewmultiPotential
(
const
Potential
<
GUM_SCALAR
>&
t1
,
47
const
Potential
<
GUM_SCALAR
>&
t2
) {
48
return
new
Potential
<
GUM_SCALAR
>(
t1
*
t2
);
49
}
50
51
// the function used to combine two tables
52
template
<
typename
GUM_SCALAR
>
53
INLINE
static
Potential
<
GUM_SCALAR
>*
54
VENewprojPotential
(
const
Potential
<
GUM_SCALAR
>&
t1
,
55
const
Set
<
const
DiscreteVariable
* >&
del_vars
) {
56
return
new
Potential
<
GUM_SCALAR
>(
t1
.
margSumOut
(
del_vars
));
57
}
58
59
60
/**
61
* @class VariableElimination VariableElimination.h
62
* <agrum/BN/inference/variableElimination.h>
63
* @brief Implementation of a Shafer-Shenoy's-like version of lazy
64
* propagation for inference in Bayesian networks
65
* @ingroup bn_inference
66
*/
67
template
<
typename
GUM_SCALAR
>
68
class
VariableElimination
:
public
JointTargetedInference
<
GUM_SCALAR
> {
69
public
:
70
// ############################################################################
71
/// @name Constructors / Destructors
72
// ############################################################################
73
/// @{
74
75
/// default constructor
76
explicit
VariableElimination
(
const
IBayesNet
<
GUM_SCALAR
>*
BN
,
77
RelevantPotentialsFinderType
relevant_type
78
=
RelevantPotentialsFinderType
::
DSEP_BAYESBALL_POTENTIALS
,
79
FindBarrenNodesType
=
FindBarrenNodesType
::
FIND_BARREN_NODES
);
80
81
/// avoid copy constructors
82
VariableElimination
(
const
VariableElimination
<
GUM_SCALAR
>&) =
delete
;
83
84
/// avoid copy operators
85
VariableElimination
<
GUM_SCALAR
>&
operator
=(
const
VariableElimination
<
GUM_SCALAR
>&) =
delete
;
86
87
/// destructor
88
~
VariableElimination
()
final
;
89
90
/// @}
91
92
93
// ############################################################################
94
/// @name Accessors / Modifiers
95
// ############################################################################
96
/// @{
97
98
/// use a new triangulation algorithm
99
void
setTriangulation
(
const
Triangulation
&
new_triangulation
);
100
101
/// sets how we determine the relevant potentials to combine
102
/** When a clique sends a message to a separator, it first constitute the
103
* set of the potentials it contains and of the potentials contained in the
104
* messages it received. If RelevantPotentialsFinderType = FIND_ALL,
105
* all these potentials are combined and projected to produce the message
106
* sent to the separator.
107
* If RelevantPotentialsFinderType = DSEP_BAYESBALL_NODES, then only the
108
* set of potentials d-connected to the variables of the separator are kept
109
* for combination and projection. */
110
void
setRelevantPotentialsFinderType
(
RelevantPotentialsFinderType
type
);
111
112
/// sets how we determine barren nodes
113
/** Barren nodes are unnecessary for probability inference, so they can
114
* be safely discarded in this case (type = FIND_BARREN_NODES). This
115
* speeds-up inference. However, there are some cases in which we do not
116
* want to remove barren nodes, typically when we want to answer queries
117
* such as Most Probable Explanations (MPE). */
118
void
setFindBarrenNodesType
(
FindBarrenNodesType
type
);
119
120
/// returns the join tree used for compute the posterior of node id
121
const
JunctionTree
*
junctionTree
(
NodeId
id
);
122
123
/// @}
124
125
126
protected
:
127
/// fired when the stage is changed
128
void
onStateChanged_
()
final
{};
129
130
/// fired after a new evidence is inserted
131
void
onEvidenceAdded_
(
const
NodeId
id
,
bool
isHardEvidence
)
final
;
132
133
/// fired before an evidence is removed
134
void
onEvidenceErased_
(
const
NodeId
id
,
bool
isHardEvidence
)
final
;
135
136
/// fired before all the evidence are erased
137
void
onAllEvidenceErased_
(
bool
contains_hard_evidence
)
final
;
138
139
/** @brief fired after an evidence is changed, in particular when its status
140
* (soft/hard) changes
141
*
142
* @param nodeId the node of the changed evidence
143
* @param hasChangedSoftHard true if the evidence has changed from Soft to
144
* Hard or from Hard to Soft
145
*/
146
void
onEvidenceChanged_
(
const
NodeId
id
,
bool
hasChangedSoftHard
)
final
;
147
148
/// fired after a new single target is inserted
149
/** @param id The target variable's id. */
150
void
onMarginalTargetAdded_
(
const
NodeId
id
)
final
;
151
152
/// fired before a single target is removed
153
/** @param id The target variable's id. */
154
void
onMarginalTargetErased_
(
const
NodeId
id
)
final
;
155
156
/// fired after a new Bayes net has been assigned to the engine
157
virtual
void
onModelChanged_
(
const
GraphicalModel
*
bn
)
final
;
158
159
/// fired after a new joint target is inserted
160
/** @param set The set of target variable's ids. */
161
void
onJointTargetAdded_
(
const
NodeSet
&
set
)
final
;
162
163
/// fired before a joint target is removed
164
/** @param set The set of target variable's ids. */
165
void
onJointTargetErased_
(
const
NodeSet
&
set
)
final
;
166
167
/// fired after all the nodes of the BN are added as single targets
168
void
onAllMarginalTargetsAdded_
()
final
;
169
170
/// fired before a all the single targets are removed
171
void
onAllMarginalTargetsErased_
()
final
;
172
173
/// fired before a all the joint targets are removed
174
void
onAllJointTargetsErased_
()
final
;
175
176
/// fired before a all single and joint_targets are removed
177
void
onAllTargetsErased_
()
final
;
178
179
/// prepares inference when the latter is in OutdatedStructure state
180
/** Note that the values of evidence are not necessarily
181
* known and can be changed between updateOutdatedStructure_ and
182
* makeInference_. */
183
void
updateOutdatedStructure_
()
final
;
184
185
/// prepares inference when the latter is in OutdatedPotentials state
186
/** Note that the values of evidence are not necessarily
187
* known and can be changed between updateOutdatedPotentials_ and
188
* makeInference_. */
189
void
updateOutdatedPotentials_
()
final
;
190
191
/// called when the inference has to be performed effectively
192
/** Once the inference is done, fillPosterior_ can be called. */
193
void
makeInference_
()
final
;
194
195
196
/// returns the posterior of a given variable
197
/** @param id The variable's id. */
198
const
Potential
<
GUM_SCALAR
>&
posterior_
(
NodeId
id
)
final
;
199
200
/// returns the posterior of a declared target set
201
/** @param set The set of ids of the variables whose joint posterior is
202
* looked for. */
203
const
Potential
<
GUM_SCALAR
>&
jointPosterior_
(
const
NodeSet
&
set
)
final
;
204
205
/** @brief asks derived classes for the joint posterior of a set of
206
* variables not declared as a joint target
207
*
208
* @param wanted_target The set of ids of the variables whose joint
209
* posterior is looked for.
210
* @param declared_target the joint target declared by the user that contains
211
* set */
212
const
Potential
<
GUM_SCALAR
>&
jointPosterior_
(
const
NodeSet
&
wanted_target
,
213
const
NodeSet
&
declared_target
)
final
;
214
215
/// returns a fresh potential equal to P(argument,evidence)
216
Potential
<
GUM_SCALAR
>*
unnormalizedJointPosterior_
(
NodeId
id
)
final
;
217
218
/// returns a fresh potential equal to P(argument,evidence)
219
Potential
<
GUM_SCALAR
>*
unnormalizedJointPosterior_
(
const
NodeSet
&
set
)
final
;
220
221
222
private
:
223
typedef
Set
<
const
Potential
<
GUM_SCALAR
>* >
_PotentialSet_
;
224
typedef
SetIteratorSafe
<
const
Potential
<
GUM_SCALAR
>* >
_PotentialSetIterator_
;
225
226
227
/// the type of relevant potential finding algorithm to be used
228
RelevantPotentialsFinderType
_find_relevant_potential_type_
;
229
230
/** @brief update a set of potentials: the remaining are those to be
231
* combined to produce a message on a separator */
232
void
(
VariableElimination
<
GUM_SCALAR
>::*
_findRelevantPotentials_
)(
233
Set
<
const
Potential
<
GUM_SCALAR
>* >&
pot_list
,
234
Set
<
const
DiscreteVariable
* >&
kept_vars
);
235
236
/// the type of barren nodes computation we wish
237
FindBarrenNodesType
_barren_nodes_type_
;
238
239
/// the operator for performing the projections
240
Potential
<
GUM_SCALAR
>* (*
_projection_op_
)(
const
Potential
<
GUM_SCALAR
>&,
241
const
Set
<
const
DiscreteVariable
* >&){
242
VENewprojPotential
};
243
244
/// the operator for performing the combinations
245
Potential
<
GUM_SCALAR
>* (*
_combination_op_
)(
const
Potential
<
GUM_SCALAR
>&,
246
const
Potential
<
GUM_SCALAR
>&){
247
VENewmultiPotential
};
248
249
/// the triangulation class creating the junction tree used for inference
250
Triangulation
*
_triangulation_
;
251
252
/// the undigraph extracted from the BN and used to construct the join tree
253
/** If all nodes are targets, this graph corresponds to the moral graph
254
* of the BN. Otherwise, it may be a subgraph of this moral graph. For
255
* instance if the BN is A->B->C and only B is a target, _graph_ will be
256
* equal to A-B if we exploit barren nodes (C is a barren node and,
257
* therefore, can be removed for inference). */
258
UndiGraph
_graph_
;
259
260
/// the junction tree used to answer the last inference query
261
JunctionTree
*
_JT_
{
nullptr
};
262
263
/// for each node of _graph_ (~ in the Bayes net), associate an ID in the JT
264
HashTable
<
NodeId
,
NodeId
>
_node_to_clique_
;
265
266
/// for each BN node, indicate in which clique its CPT will be stored
267
HashTable
<
NodeId
,
NodeSet
>
_clique_potentials_
;
268
269
/// indicate a clique that contains all the nodes of the target
270
NodeId
_targets2clique_
;
271
272
/// the posterior computed during the last inference
273
/** the posterior is owned by VariableElimination. */
274
Potential
<
GUM_SCALAR
>*
_target_posterior_
{
nullptr
};
275
276
/// for comparisons with 1 - epsilon
277
const
GUM_SCALAR
_one_minus_epsilon_
{
GUM_SCALAR
(1.0 - 1e-6)};
278
279
280
/// create a new junction tree as well as its related data structures
281
void
_createNewJT_
(
const
NodeSet
&
targets
);
282
283
/// sets the operator for performing the projections
284
void
_setProjectionFunction_
(
Potential
<
GUM_SCALAR
>* (
285
*
proj
)(
const
Potential
<
GUM_SCALAR
>&,
const
Set
<
const
DiscreteVariable
* >&));
286
287
/// sets the operator for performing the combinations
288
void
_setCombinationFunction_
(
Potential
<
GUM_SCALAR
>* (*
comb
)(
const
Potential
<
GUM_SCALAR
>&,
289
const
Potential
<
GUM_SCALAR
>&));
290
291
/** @brief update a set of potentials: the remaining are those to be
292
* combined
293
* to produce a message on a separator */
294
void
_findRelevantPotentialsWithdSeparation_
(
_PotentialSet_
&
pot_list
,
295
Set
<
const
DiscreteVariable
* >&
kept_vars
);
296
297
/** @brief update a set of potentials: the remaining are those to be
298
* combined
299
* to produce a message on a separator */
300
void
_findRelevantPotentialsWithdSeparation2_
(
_PotentialSet_
&
pot_list
,
301
Set
<
const
DiscreteVariable
* >&
kept_vars
);
302
303
/** @brief update a set of potentials: the remaining are those to be
304
* combined
305
* to produce a message on a separator */
306
void
_findRelevantPotentialsWithdSeparation3_
(
_PotentialSet_
&
pot_list
,
307
Set
<
const
DiscreteVariable
* >&
kept_vars
);
308
309
/** @brief update a set of potentials: the remaining are those to be
310
* combined
311
* to produce a message on a separator */
312
void
_findRelevantPotentialsGetAll_
(
_PotentialSet_
&
pot_list
,
313
Set
<
const
DiscreteVariable
* >&
kept_vars
);
314
315
/** @brief update a set of potentials: the remaining are those to be
316
* combined
317
* to produce a message on a separator */
318
void
_findRelevantPotentialsXX_
(
_PotentialSet_
&
pot_list
,
319
Set
<
const
DiscreteVariable
* >&
kept_vars
);
320
321
// remove barren variables and return the newly created projected potentials
322
_PotentialSet_
_removeBarrenVariables_
(
_PotentialSet_
&
pot_list
,
323
Set
<
const
DiscreteVariable
* >&
del_vars
);
324
325
/// actually perform the collect phase
326
std
::
pair
<
_PotentialSet_
,
_PotentialSet_
>
_collectMessage_
(
NodeId
id
,
NodeId
from
);
327
328
/// returns the CPT + evidence of a node projected w.r.t. hard evidence
329
std
::
pair
<
_PotentialSet_
,
_PotentialSet_
>
_NodePotentials_
(
NodeId
node
);
330
331
/// creates the message sent by clique from_id to clique to_id
332
std
::
pair
<
_PotentialSet_
,
_PotentialSet_
>
333
_produceMessage_
(
NodeId
from_id
,
334
NodeId
to_id
,
335
std
::
pair
<
_PotentialSet_
,
_PotentialSet_
>&&
incoming_messages
);
336
337
/** @brief removes variables del_vars from a list of potentials and
338
* returns the resulting list */
339
_PotentialSet_
_marginalizeOut_
(
_PotentialSet_
pot_list
,
340
Set
<
const
DiscreteVariable
* >&
del_vars
,
341
Set
<
const
DiscreteVariable
* >&
kept_vars
);
342
};
343
344
345
#
ifndef
GUM_NO_EXTERN_TEMPLATE_CLASS
346
extern
template
class
VariableElimination<
double
>;
347
#
endif
348
349
350
}
/* namespace gum */
351
352
353
#
include
<
agrum
/
BN
/
inference
/
variableElimination_tpl
.
h
>
354
355
356
#
endif
/* GUM_VARIABLE_ELIMINATION_H */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643