aGrUM  0.13.2
variableElimination.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES et Pierre-Henri WUILLEMIN *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
27 #ifndef GUM_VARIABLE_ELIMINATION_H
28 #define GUM_VARIABLE_ELIMINATION_H
29 
30 #include <cmath>
31 #include <utility>
32 
36 #include <agrum/agrum.h>
38 
39 namespace gum {
40 
41 
42  // the function used to combine two tables
43  template < typename GUM_SCALAR >
44  INLINE static Potential< GUM_SCALAR >*
46  const Potential< GUM_SCALAR >& t2) {
47  return new Potential< GUM_SCALAR >(t1 * t2);
48  }
49 
50  // the function used to project a table (sums out the variables in del_vars)
51  template < typename GUM_SCALAR >
52  INLINE static Potential< GUM_SCALAR >*
54  const Set< const DiscreteVariable* >& del_vars) {
55  return new Potential< GUM_SCALAR >(t1.margSumOut(del_vars));
56  }
57 
58 
66  template < typename GUM_SCALAR >
67  class VariableElimination : public JointTargetedInference< GUM_SCALAR > {
68  public:
69  // ############################################################################
71  // ############################################################################
73 
75  explicit VariableElimination(
77  RelevantPotentialsFinderType relevant_type =
80 
83 
87 
89  ~VariableElimination() final;
90 
92 
93 
94  // ############################################################################
96  // ############################################################################
98 
100  void setTriangulation(const Triangulation& new_triangulation);
101 
103 
112 
114 
120 
122  const JunctionTree* junctionTree(NodeId id);
123 
125 
126 
127  protected:
129  void _onStateChanged() final{};
130 
132  void _onEvidenceAdded(NodeId id, bool isHardEvidence) final;
133 
135  void _onEvidenceErased(NodeId id, bool isHardEvidence) final;
136 
138  void _onAllEvidenceErased(bool contains_hard_evidence) final;
139 
147  void _onEvidenceChanged(NodeId id, bool hasChangedSoftHard) final;
148 
150 
151  void _onMarginalTargetAdded(NodeId id) final;
152 
154 
155  void _onMarginalTargetErased(NodeId id) final;
156 
158 
159  void _onJointTargetAdded(const NodeSet& set) final;
160 
162 
163  void _onJointTargetErased(const NodeSet& set) final;
164 
166  void _onAllMarginalTargetsAdded() final;
167 
169  void _onAllMarginalTargetsErased() final;
170 
172  void _onAllJointTargetsErased() final;
173 
175  void _onAllTargetsErased() final;
176 
178 
181  void _updateOutdatedBNStructure() final;
182 
184 
187  void _updateOutdatedBNPotentials() final;
188 
190 
191  void _makeInference() final;
192 
193 
195 
196  const Potential< GUM_SCALAR >& _posterior(NodeId id) final;
197 
199 
201  const Potential< GUM_SCALAR >& _jointPosterior(const NodeSet& set) final;
202 
211  _jointPosterior(const NodeSet& wanted_target,
212  const NodeSet& declared_target) final;
213 
216 
219 
220 
221  private:
225 
226 
229 
234  Set< const DiscreteVariable* >& kept_vars);
235 
238 
240  Potential< GUM_SCALAR >* (*__projection_op)(
243 
245  Potential< GUM_SCALAR >* (*__combination_op)(const Potential< GUM_SCALAR >&,
246  const Potential< GUM_SCALAR >&){
248 
251 
253 
259 
261  JunctionTree* __JT{nullptr};
262 
265 
268 
271 
273 
274  Potential< GUM_SCALAR >* __target_posterior{nullptr};
275 
277  const GUM_SCALAR __1_minus_epsilon{GUM_SCALAR(1.0 - 1e-6)};
278 
279 
281  void __createNewJT(const NodeSet& targets);
282 
284  void __setProjectionFunction(Potential< GUM_SCALAR >* (*proj)(
285  const Potential< GUM_SCALAR >&, const Set< const DiscreteVariable* >&));
286 
288  void __setCombinationFunction(Potential< GUM_SCALAR >* (*comb)(
289  const Potential< GUM_SCALAR >&, const Potential< GUM_SCALAR >&));
290 
295  __PotentialSet& pot_list, Set< const DiscreteVariable* >& kept_vars);
296 
301  __PotentialSet& pot_list, Set< const DiscreteVariable* >& kept_vars);
302 
307  __PotentialSet& pot_list, Set< const DiscreteVariable* >& kept_vars);
308 
312  void __findRelevantPotentialsGetAll(__PotentialSet& pot_list,
313  Set< const DiscreteVariable* >& kept_vars);
314 
318  void __findRelevantPotentialsXX(__PotentialSet& pot_list,
319  Set< const DiscreteVariable* >& kept_vars);
320 
321  // remove barren variables and return the newly created projected potentials
322  __PotentialSet
323  __removeBarrenVariables(__PotentialSet& pot_list,
325 
327  std::pair< __PotentialSet, __PotentialSet > __collectMessage(NodeId id,
328  NodeId from);
329 
331  std::pair< __PotentialSet, __PotentialSet > __NodePotentials(NodeId node);
332 
334  std::pair< __PotentialSet, __PotentialSet > __produceMessage(
335  NodeId from_id,
336  NodeId to_id,
337  std::pair< __PotentialSet, __PotentialSet >&& incoming_messages);
338 
341  __PotentialSet __marginalizeOut(__PotentialSet pot_list,
343  Set< const DiscreteVariable* >& kept_vars);
344  };
345 
346 
347  extern template class VariableElimination< float >;  // explicit instantiation declaration: suppresses implicit instantiation of the float specialization in including TUs; the definition must be provided elsewhere (presumably the matching .cpp — TODO confirm)
348  extern template class VariableElimination< double >;  // same as the float line above, for the double specialization
349 
350 
351 } /* namespace gum */
352 
353 
355 
356 
357 #endif /* GUM_VARIABLE_ELIMINATION_ */
void __setProjectionFunction(Potential< GUM_SCALAR > *(*proj)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
sets the operator for performing the projections
void _onEvidenceAdded(NodeId id, bool isHardEvidence) final
fired after a new evidence is inserted
void(VariableElimination< GUM_SCALAR >::* __findRelevantPotentials)(Set< const Potential< GUM_SCALAR > * > &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining ones are those to be combined to produce a message on a separato...
aGrUM's Potential is a multi-dimensional array with tensor operators.
Definition: potential.h:57
NodeId __targets2clique
indicate a clique that contains all the nodes of the target
VariableElimination(const IBayesNet< GUM_SCALAR > *BN, RelevantPotentialsFinderType relevant_type=RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS, FindBarrenNodesType=FindBarrenNodesType::FIND_BARREN_NODES)
default constructor
void _onAllTargetsErased() final
fired before all single and joint targets are removed
__PotentialSet __removeBarrenVariables(__PotentialSet &pot_list, Set< const DiscreteVariable * > &del_vars)
unsigned int NodeId
Type for node ids.
Definition: graphElements.h:97
static INLINE Potential< GUM_SCALAR > * VENewprojPotential(const Potential< GUM_SCALAR > &t1, const Set< const DiscreteVariable * > &del_vars)
void __setCombinationFunction(Potential< GUM_SCALAR > *(*comb)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &))
sets the operator for performing the combinations
Safe iterators for the Set class. Developers may consider using Set<x>::iterator_safe instead of SetIte...
Definition: set.h:808
Implementation of Variable Elimination for inference in Bayesian Networks.
const Potential< GUM_SCALAR > & _posterior(NodeId id) final
returns the posterior of a given variable
This file contains the abstract inference class definition for computing (incrementally) joint poster...
JunctionTree * __JT
the junction tree used to answer the last inference query
RelevantPotentialsFinderType __find_relevant_potential_type
the type of relevant potential finding algorithm to be used
void _onEvidenceChanged(NodeId id, bool hasChangedSoftHard) final
fired after an evidence is changed, in particular when its status (soft/hard) changes ...
void __findRelevantPotentialsXX(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining ones are those to be combined to produce a message on a separato...
static INLINE Potential< GUM_SCALAR > * VENewmultiPotential(const Potential< GUM_SCALAR > &t1, const Potential< GUM_SCALAR > &t2)
void setFindBarrenNodesType(FindBarrenNodesType type)
sets how we determine barren nodes
std::pair< __PotentialSet, __PotentialSet > __collectMessage(NodeId id, NodeId from)
actually perform the collect phase
Set< const Potential< GUM_SCALAR > * > __PotentialSet
void _makeInference() final
called when the inference has to be performed effectively
SetIteratorSafe< const Potential< GUM_SCALAR > * > __PotentialSetIterator
Potential< GUM_SCALAR > * __target_posterior
the posterior computed during the last inference
Triangulation * __triangulation
the triangulation class creating the junction tree used for inference
<agrum/BN/inference/variableElimination.h>
void _updateOutdatedBNPotentials() final
prepares inference when the latter is in OutdatedBNPotentials state
the type of algorithm to use to perform relevant reasoning in Bayes net inference ...
Potential< GUM_SCALAR > margSumOut(const Set< const DiscreteVariable * > &del_vars) const
Projection using sum as operation (and implementation-optimized operations)
void _onEvidenceErased(NodeId id, bool isHardEvidence) final
fired before an evidence is removed
Class representing the minimal interface for a Bayesian network.
Definition: IBayesNet.h:59
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
const Potential< GUM_SCALAR > & _jointPosterior(const NodeSet &set) final
returns the posterior of a declared target set
void __findRelevantPotentialsGetAll(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining ones are those to be combined to produce a message on a separato...
void __findRelevantPotentialsWithdSeparation(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining ones are those to be combined to produce a message on a separato...
RelevantPotentialsFinderType
type of algorithm for determining the relevant potentials for combinations using some d-separation an...
const JunctionTree * junctionTree(NodeId id)
returns the join tree used to compute the posterior of node id
void _onAllMarginalTargetsErased() final
fired before all the single targets are removed
void __findRelevantPotentialsWithdSeparation3(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining ones are those to be combined to produce a message on a separato...
VariableElimination< GUM_SCALAR > & operator=(const VariableElimination< GUM_SCALAR > &)=delete
avoid copy operators
Representation of a set. A Set is a structure that contains arbitrary elements.
Definition: set.h:162
FindBarrenNodesType
type of algorithm to determine barren nodes
UndiGraph __graph
the undigraph extracted from the BN and used to construct the join tree
void _onAllEvidenceErased(bool contains_hard_evidence) final
fired before all the evidence is erased
void __findRelevantPotentialsWithdSeparation2(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining ones are those to be combined to produce a message on a separato...
<agrum/BN/inference/jointTargetedInference.h>
HashTable< NodeId, NodeSet > __clique_potentials
for each BN node, indicate in which clique its CPT will be stored
void _onJointTargetErased(const NodeSet &set) final
fired before a joint target is removed
void _onJointTargetAdded(const NodeSet &set) final
fired after a new joint target is inserted
FindBarrenNodesType __barren_nodes_type
the type of barren nodes computation we wish
void _onAllJointTargetsErased() final
fired before all the joint targets are removed
void setTriangulation(const Triangulation &new_triangulation)
use a new triangulation algorithm
void setRelevantPotentialsFinderType(RelevantPotentialsFinderType type)
sets how we determine the relevant potentials to combine
Basic graph of cliques.
Definition: cliqueGraph.h:55
void _onStateChanged() final
fired when the state is changed
HashTable< NodeId, NodeId > __node_to_clique
for each node of __graph (~ in the Bayes net), associate an ID in the JT
Potential< GUM_SCALAR > * _unnormalizedJointPosterior(NodeId id) final
returns a fresh potential equal to P(argument,evidence)
void _onMarginalTargetAdded(NodeId id) final
fired after a new single target is inserted
Detect barren nodes for inference in Bayesian networks.
__PotentialSet __marginalizeOut(__PotentialSet pot_list, Set< const DiscreteVariable * > &del_vars, Set< const DiscreteVariable * > &kept_vars)
removes variables del_vars from a list of potentials and returns the resulting list ...
std::pair< __PotentialSet, __PotentialSet > __NodePotentials(NodeId node)
returns the CPT + evidence of a node projected w.r.t. hard evidence
virtual const NodeSet & targets() const noexcept final
returns the list of marginal targets
void _updateOutdatedBNStructure() final
prepares inference when the latter is in OutdatedBNStructure state
void _onAllMarginalTargetsAdded() final
fired after all the nodes of the BN are added as single targets
Class for computing default triangulations of graphs.
std::pair< __PotentialSet, __PotentialSet > __produceMessage(NodeId from_id, NodeId to_id, std::pair< __PotentialSet, __PotentialSet > &&incoming_messages)
creates the message sent by clique from_id to clique to_id
Base class for undirected graphs.
Definition: undiGraph.h:106
void __createNewJT(const NodeSet &targets)
create a new junction tree as well as its related data structures
~VariableElimination() final
destructor
Interface for all the triangulation methods.
Definition: triangulation.h:44
const GUM_SCALAR __1_minus_epsilon
for comparisons with 1 - epsilon
virtual const IBayesNet< GUM_SCALAR > & BN() const final
Returns a constant reference to the IBayesNet referenced by this class.
void _onMarginalTargetErased(NodeId id) final
fired before a single target is removed