aGrUM  0.14.2
loopyBeliefPropagation_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES et Pierre-Henri WUILLEMIN *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
24 #ifndef DOXYGEN_SHOULD_SKIP_THIS
25 
26 # include <algorithm>
# include <iterator>
# include <limits>
# include <random>
27 # include <sstream>
28 # include <string>
# include <vector>
29 
30 # define LBP_DEFAULT_MAXITER 100  // default passed to setMaxIter() by the constructor; NOTE(review): these macros are global and never #undef'd in this header
31 # define LBP_DEFAULT_EPSILON 1e-8  // default passed to setEpsilon() by the constructor
32 # define LBP_DEFAULT_MIN_EPSILON_RATE 1e-10  // default passed to setMinEpsilonRate() by the constructor
33 # define LBP_DEFAULT_PERIOD_SIZE 1  // default passed to setPeriodSize() by the constructor
34 # define LBP_DEFAULT_VERBOSITY false  // default passed to setVerbosity() by the constructor
35 
36 
37 // to ease parsing for IDE
39 
40 
41 namespace gum {
42 
44  template < typename GUM_SCALAR >  // Constructor: initialises the ApproximateInference base on `bn` and installs the LBP_DEFAULT_* stopping parameters.
46  const IBayesNet< GUM_SCALAR >* bn) :  // NOTE(review): signature line 45 is missing from this listing
47  ApproximateInference< GUM_SCALAR >(bn) {
48  // for debugging purposes
49  GUM_CONSTRUCTOR(LoopyBeliefPropagation);
50 
51  this->setEpsilon(LBP_DEFAULT_EPSILON);
52  this->setMinEpsilonRate(LBP_DEFAULT_MIN_EPSILON_RATE);
53  this->setMaxIter(LBP_DEFAULT_MAXITER);
54  this->setVerbosity(LBP_DEFAULT_VERBOSITY);
55  this->setPeriodSize(LBP_DEFAULT_PERIOD_SIZE);
56 
58  }  // NOTE(review): line 57 is missing from this listing; its content cannot be checked here
59 
61  template < typename GUM_SCALAR >  // Destructor (signature line 62 is missing from this listing).
63  GUM_DESTRUCTOR(LoopyBeliefPropagation);  // counterpart of the GUM_CONSTRUCTOR call in the constructor
64  }
65 
66 
67  template < typename GUM_SCALAR >  // Rebuilds __messages: both directions of every arc start as an all-ones potential (signature line 68 is missing from this listing).
69  __messages.clear();
70  for (const auto& tail : this->BN().nodes()) {
71  Potential< GUM_SCALAR > p;  // one potential per node, over that node's own variable
72  p.add(this->BN().variable(tail));
73  p.fill(static_cast< GUM_SCALAR >(1));  // unnormalised uniform message
74 
75  for (const auto& head : this->BN().children(tail)) {
76  __messages.insert(Arc(head, tail), p);  // child -> parent (lambda) message, over the parent's variable
77  __messages.insert(Arc(tail, head), p);  // parent -> child (pi) message, also over the parent's variable
78  }
79  }
80  }
81 
82  template < typename GUM_SCALAR >  // NOTE(review): lines 83-84 (signature and body) are missing from this listing; presumably _updateOutdatedBNStructure — TODO confirm
85  }
86 
87 
88  template < typename GUM_SCALAR >  // Pi product of X: X's CPT times every incoming parent message, summed down to X alone (signature line 90 is missing from this listing).
89  Potential< GUM_SCALAR >
91  const auto& varX = this->BN().variable(X);
92 
93  auto piX = this->BN().cpt(X);  // `auto` deduces a value type, so this is a copy of the CPT
94  for (const auto& U : this->BN().parents(X)) {
95  piX *= __messages[Arc(U, X)];  // message from parent U
96  }
97  piX = piX.margSumIn({&varX});  // sum out the parents; the result is not normalised
98 
99  return piX;
100  }
101 
102  template < typename GUM_SCALAR >  // Pi product of X leaving out parent `except`; the result keeps both X and `except` (signature line 104 is missing from this listing).
103  Potential< GUM_SCALAR >
105  NodeId except) {
106  const auto& varX = this->BN().variable(X);
107  const auto& varExcept = this->BN().variable(except);
108  auto piXexcept = this->BN().cpt(X);  // copy of X's CPT
109  for (const auto& U : this->BN().parents(X)) {
110  if (U != except) { piXexcept *= __messages[Arc(U, X)]; }  // the excluded parent's own message is not multiplied in
111  }
112  piXexcept = piXexcept.margSumIn({&varX, &varExcept});  // sum out every other parent
113  return piXexcept;
114  }
115 
116 
117  template < typename GUM_SCALAR >  // Lambda product of X: its evidence (or an all-ones factor) times every message from X's children (signature line 119 is missing from this listing).
118  Potential< GUM_SCALAR >
120  Potential< GUM_SCALAR > lamX;
121  if (this->hasEvidence(X)) {
122  lamX = *(this->evidence()[X]);  // start from the evidence potential on X
123  } else {
124  lamX.add(this->BN().variable(X));
125  lamX.fill(1);  // no evidence: neutral all-ones factor over X
126  }
127  for (const auto& Y : this->BN().children(X)) {
128  lamX *= __messages[Arc(Y, X)];  // message from child Y
129  }
130 
131  return lamX;
132  }
133 
134  template < typename GUM_SCALAR >  // Lambda product of X leaving out the message from child `except` (signature line 136 is missing from this listing).
135  Potential< GUM_SCALAR >
137  NodeId except) {
138  Potential< GUM_SCALAR > lamXexcept;
139  if (this->hasEvidence(X)) { // start from the evidence potential on X
140  lamXexcept = *this->evidence()[X];
141  } else {
142  lamXexcept.add(this->BN().variable(X));
143  lamXexcept.fill(1);  // no evidence: neutral all-ones factor over X
144  }
145  for (const auto& Y : this->BN().children(X)) {
146  if (Y != except) { lamXexcept *= __messages[Arc(Y, X)]; }  // skip the excluded child's message
147  }
148 
149  return lamXexcept;
150  }
151 
152 
153  template < typename GUM_SCALAR >  // Recomputes every message X sends to its parents and children and returns the largest KL change among them (signature line 154 is missing from this listing).
155  auto piX = __computeProdPi(X);
156  auto lamX = __computeProdLambda(X);
157 
158  GUM_SCALAR KL = 0;  // running maximum of the per-message divergences; this is the return value
159  Arc argKL(0, 0);  // NOTE(review): assigned below but never read in this function; candidate for removal
160 
161  // update lambda_par (for arc U->x)
162  for (const auto& U : this->BN().parents(X)) {
163  auto newLambda =
164  (__computeProdPi(X, U) * lamX).margSumIn({&this->BN().variable(U)});  // projected onto parent U's variable
165  newLambda.normalize();
166  auto ekl = static_cast< GUM_SCALAR >(0);
167  try {
168  ekl = __messages[Arc(X, U)].KL(newLambda);  // divergence between the stored message and its replacement
169  } catch (InvalidArgument&) {
170  GUM_ERROR(InvalidArgument, "Not compatible pi during computation");  // NOTE(review): this is the lambda loop, yet the text says "pi"
171  } catch (FatalError&) { // 0 misplaced: the divergence is treated as +infinity
172  ekl = std::numeric_limits< GUM_SCALAR >::infinity();
173  }
174  if (ekl > KL) {
175  KL = ekl;
176  argKL = Arc(X, U);
177  }
178  __messages.set(Arc(X, U), newLambda);  // in-place update: later reads of __messages see the new value
179  }
180 
181  // update pi_child (for arc x->child)
182  for (const auto& Y : this->BN().children(X)) {
183  auto newPi = (piX * __computeProdLambda(X, Y));  // X's pi product times lambda from every child except Y
184  newPi.normalize();
185  GUM_SCALAR ekl = KL;  // NOTE(review): initialised to KL here but to 0 in the parent loop; always overwritten by the try/catch below, so the difference is cosmetic
186  try {
187  ekl = __messages[Arc(X, Y)].KL(newPi);
188  } catch (InvalidArgument&) {
189  GUM_ERROR(InvalidArgument, "Not compatible pi during computation");
190  } catch (FatalError&) { // 0 misplaced: the divergence is treated as +infinity
191  ekl = std::numeric_limits< GUM_SCALAR >::infinity();
192  }
193  if (ekl > KL) {
194  KL = ekl;
195  argKL = Arc(X, Y);
196  }
197  __messages.set(Arc(X, Y), newPi);
198  }
199 
200  return KL;
201  }
202 
203  template < typename GUM_SCALAR >  // Resets all messages, then runs one update pass over the nodes in topological order (signature line 204 is missing from this listing).
205  __init_messages();
206  for (const auto& node : this->BN().topologicalOrder()) {
207  __updateNodeMessage(node);  // the returned divergence is discarded in this initial pass
208  }
209  }
210 
211 
213  template < typename GUM_SCALAR >
215  __initStats();
216  this->initApproximationScheme();
217 
218  std::vector< NodeId > shuffleIds;
219  for (const auto& node : this->BN().nodes())
220  shuffleIds.push_back(node);
221 
222  auto engine = std::default_random_engine{};
223 
224  GUM_SCALAR error = 0.0;
225  do {
226  std::shuffle(std::begin(shuffleIds), std::end(shuffleIds), engine);
228  for (const auto& node : shuffleIds) {
229  GUM_SCALAR e = __updateNodeMessage(node);
230  if (e > error) error = e;
231  }
232  } while (this->continueApproximationScheme(error));
233  }
234 
235 
237  template < typename GUM_SCALAR >  // Posterior of `id`: normalised product of its pi and lambda products (signature line 239 is missing from this listing).
238  INLINE const Potential< GUM_SCALAR >&
240  auto p = __computeProdPi(id) * __computeProdLambda(id);  // recomputed from the current messages on every call
241  p.normalize();
242  __posteriors.set(id, p);  // stored in the member so the returned reference outlives this call; presumably a later call for the same id overwrites it
243 
244  return __posteriors[id];
245  }
246 } /* namespace gum */
247 
248 #endif // DOXYGEN_SHOULD_SKIP_THIS
const NodeProperty< const Potential< GUM_SCALAR > *> & evidence() const
returns the set of evidence
This file contains loopy belief propagation (for BNs) class definitions.
void setPeriodSize(Size p)
How many samples between two checks of the stopping criteria.
void initApproximationScheme()
Initialise the scheme.
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
void setMinEpsilonRate(double rate)
Given that we approximate f(t), stopping criterion on d/dt(|f(t+1)-f(t)|).
Potential< GUM_SCALAR > __computeProdLambda(NodeId X)
void setVerbosity(bool v)
Set the verbosity on (true) or off (false).
virtual ~LoopyBeliefPropagation()
Destructor.
GUM_SCALAR __updateNodeMessage(NodeId X)
bool continueApproximationScheme(double error)
Update the scheme w.r.t. the new error.
virtual bool hasEvidence() const final
indicates whether some node(s) have received evidence
LoopyBeliefPropagation(const IBayesNet< GUM_SCALAR > *bn)
Default constructor.
KL is the base class for KL computation between 2 BNs.
ArcProperty< Potential< GUM_SCALAR > > __messages
virtual const Potential< GUM_SCALAR > & _posterior(NodeId id)
asks derived classes for the posterior of a given variable
void setMaxIter(Size max)
Stopping criterion on number of iterations.
void setEpsilon(double eps)
Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)|.
NodeProperty< Potential< GUM_SCALAR > > __posteriors
virtual void _makeInference()
called when the inference has to be performed effectively
Potential< GUM_SCALAR > __computeProdPi(NodeId X)
virtual const IBayesNet< GUM_SCALAR > & BN() const final
Returns a constant reference to the IBayesNet referenced by this class.
Size NodeId
Type for node ids.
Definition: graphElements.h:97
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52
virtual void _updateOutdatedBNStructure()
prepares inference when the latter is in OutdatedBNStructure state
void updateApproximationScheme(unsigned int incr=1)
Update the scheme w.r.t. the new error and increment steps.