aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
marginalTargetedInference_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation of the generic class for the computation of
25  * (possibly incrementally) marginal posteriors
26  */
27 #include <iterator>
28 
29 namespace gum {
30 
31 
32  // Default Constructor
33  template < typename GUM_SCALAR >
34  MarginalTargetedInference< GUM_SCALAR >::MarginalTargetedInference(
35  const IBayesNet< GUM_SCALAR >* bn) :
36  BayesNetInference< GUM_SCALAR >(bn) {
37  // assign a BN if this has not been done before (due to virtual inheritance)
38  if (this->hasNoModel_()) {
39  BayesNetInference< GUM_SCALAR >::_setBayesNetDuringConstruction_(bn);
40  }
41 
42  // sets all the nodes as targets
43  if (bn != nullptr) {
44  _targeted_mode_ = false;
45  _targets_ = bn->dag().asNodeSet();
46  }
47 
48  GUM_CONSTRUCTOR(MarginalTargetedInference);
49  }
50 
51 
52  // Destructor
53  template < typename GUM_SCALAR >
54  MarginalTargetedInference< GUM_SCALAR >::~MarginalTargetedInference() {
55  GUM_DESTRUCTOR(MarginalTargetedInference);
56  }
57 
58 
59  // fired when a new BN is assigned to the inference engine
60  template < typename GUM_SCALAR >
61  void MarginalTargetedInference< GUM_SCALAR >::onModelChanged_(const GraphicalModel* bn) {
62  _targeted_mode_ = true;
63  _setAllMarginalTargets_();
64  }
65 
66 
67  // ##############################################################################
68  // Targets
69  // ##############################################################################
70 
71  // return true if variable is a target
72  template < typename GUM_SCALAR >
73  INLINE bool MarginalTargetedInference< GUM_SCALAR >::isTarget(NodeId node) const {
74  // check that the variable belongs to the bn
75  if (this->hasNoModel_())
76  GUM_ERROR(NullElement,
77  "No Bayes net has been assigned to the "
78  "inference algorithm");
79  if (!this->BN().dag().exists(node)) {
80  GUM_ERROR(UndefinedElement, node << " is not a NodeId in the bn")
81  }
82 
83  return _targets_.contains(node);
84  }
85 
86  // Add a single target to the list of targets
87  template < typename GUM_SCALAR >
88  INLINE bool MarginalTargetedInference< GUM_SCALAR >::isTarget(const std::string& nodeName) const {
89  return isTarget(this->BN().idFromName(nodeName));
90  }
91 
92 
93  // Clear all previously defined targets (single targets and sets of targets)
94  template < typename GUM_SCALAR >
95  INLINE void MarginalTargetedInference< GUM_SCALAR >::eraseAllTargets() {
96  onAllMarginalTargetsErased_();
97 
98  _targets_.clear();
99  setTargetedMode_(); // does nothing if already in targeted mode
100 
101  this->setState_(GraphicalModelInference< GUM_SCALAR >::StateOfInference::OutdatedStructure);
102  }
103 
104 
105  // Add a single target to the list of targets
106  template < typename GUM_SCALAR >
107  void MarginalTargetedInference< GUM_SCALAR >::addTarget(NodeId target) {
108  // check if the node belongs to the Bayesian network
109  if (this->hasNoModel_())
110  GUM_ERROR(NullElement,
111  "No Bayes net has been assigned to the "
112  "inference algorithm");
113 
114  if (!this->BN().dag().exists(target)) {
115  GUM_ERROR(UndefinedElement, target << " is not a NodeId in the bn")
116  }
117 
118  setTargetedMode_(); // does nothing if already in targeted mode
119  // add the new target
120  if (!_targets_.contains(target)) {
121  _targets_.insert(target);
122  onMarginalTargetAdded_(target);
123  this->setState_(GraphicalModelInference< GUM_SCALAR >::StateOfInference::OutdatedStructure);
124  }
125  }
126 
127 
128  // Add all nodes as targets
129  template < typename GUM_SCALAR >
130  void MarginalTargetedInference< GUM_SCALAR >::addAllTargets() {
131  // check if the node belongs to the Bayesian network
132  if (this->hasNoModel_())
133  GUM_ERROR(NullElement,
134  "No Bayes net has been assigned to the "
135  "inference algorithm");
136 
137 
138  setTargetedMode_(); // does nothing if already in targeted mode
139  for (const auto target: this->BN().dag()) {
140  if (!_targets_.contains(target)) {
141  _targets_.insert(target);
142  onMarginalTargetAdded_(target);
143  this->setState_(GraphicalModelInference< GUM_SCALAR >::StateOfInference::OutdatedStructure);
144  }
145  }
146  }
147 
148 
149  // Add a single target to the list of targets
150  template < typename GUM_SCALAR >
151  void MarginalTargetedInference< GUM_SCALAR >::addTarget(const std::string& nodeName) {
152  // check if the node belongs to the Bayesian network
153  if (this->hasNoModel_())
154  GUM_ERROR(NullElement,
155  "No Bayes net has been assigned to the "
156  "inference algorithm");
157 
158  addTarget(this->BN().idFromName(nodeName));
159  }
160 
161 
162  // removes an existing target
163  template < typename GUM_SCALAR >
164  void MarginalTargetedInference< GUM_SCALAR >::eraseTarget(NodeId target) {
165  // check if the node belongs to the Bayesian network
166  if (this->hasNoModel_())
167  GUM_ERROR(NullElement,
168  "No Bayes net has been assigned to the "
169  "inference algorithm");
170 
171  if (!this->BN().dag().exists(target)) {
172  GUM_ERROR(UndefinedElement, target << " is not a NodeId in the bn")
173  }
174 
175 
176  if (_targets_.contains(target)) {
177  _targeted_mode_ = true; // we do not use setTargetedMode_ because we do not
178  // want to clear the targets
179  onMarginalTargetErased_(target);
180  _targets_.erase(target);
181  this->setState_(GraphicalModelInference< GUM_SCALAR >::StateOfInference::OutdatedStructure);
182  }
183  }
184 
185 
186  // Add a single target to the list of targets
187  template < typename GUM_SCALAR >
188  void MarginalTargetedInference< GUM_SCALAR >::eraseTarget(const std::string& nodeName) {
189  // check if the node belongs to the Bayesian network
190  if (this->hasNoModel_())
191  GUM_ERROR(NullElement,
192  "No Bayes net has been assigned to the "
193  "inference algorithm");
194 
195  eraseTarget(this->BN().idFromName(nodeName));
196  }
197 
198 
199  // returns the list of single targets
200  template < typename GUM_SCALAR >
201  INLINE const NodeSet& MarginalTargetedInference< GUM_SCALAR >::targets() const noexcept {
202  return _targets_;
203  }
204 
205  // returns the list of single targets
206  template < typename GUM_SCALAR >
207  INLINE const Size MarginalTargetedInference< GUM_SCALAR >::nbrTargets() const noexcept {
208  return _targets_.size();
209  }
210 
211 
212  /// sets all the nodes of the Bayes net as targets
213  template < typename GUM_SCALAR >
214  void MarginalTargetedInference< GUM_SCALAR >::_setAllMarginalTargets_() {
215  _targets_.clear();
216  if (!this->hasNoModel_()) {
217  _targets_ = this->BN().dag().asNodeSet();
218  onAllMarginalTargetsAdded_();
219  }
220  }
221 
222 
223  // ##############################################################################
224  // Inference
225  // ##############################################################################
226 
227  // Compute the posterior of a node.
228  template < typename GUM_SCALAR >
229  const Potential< GUM_SCALAR >& MarginalTargetedInference< GUM_SCALAR >::posterior(NodeId node) {
230  if (this->hardEvidenceNodes().contains(node)) { return *(this->evidence()[node]); }
231 
232  if (!isTarget(node)) {
233  // throws UndefinedElement if var is not a target
234  GUM_ERROR(UndefinedElement, node << " is not a target node")
235  }
236 
237  if (!this->isInferenceDone()) { this->makeInference(); }
238 
239  return posterior_(node);
240  }
241 
242  // Compute the posterior of a node.
243  template < typename GUM_SCALAR >
244  const Potential< GUM_SCALAR >&
245  MarginalTargetedInference< GUM_SCALAR >::posterior(const std::string& nodeName) {
246  return posterior(this->BN().idFromName(nodeName));
247  }
248 
249  /* Entropy
250  * Compute Shanon's entropy of a node given the observation
251  */
252  template < typename GUM_SCALAR >
253  INLINE GUM_SCALAR MarginalTargetedInference< GUM_SCALAR >::H(NodeId X) {
254  return posterior(X).entropy();
255  }
256 
257  /* Entropy
258  * Compute Shanon's entropy of a node given the observation
259  */
260  template < typename GUM_SCALAR >
261  INLINE GUM_SCALAR MarginalTargetedInference< GUM_SCALAR >::H(const std::string& nodeName) {
262  return H(this->BN().idFromName(nodeName));
263  }
264 
265 
266  template < typename GUM_SCALAR >
267  Potential< GUM_SCALAR >
268  MarginalTargetedInference< GUM_SCALAR >::evidenceImpact(NodeId target, const NodeSet& evs) {
269  const auto& vtarget = this->BN().variable(target);
270 
271  if (evs.contains(target)) {
272  GUM_ERROR(InvalidArgument,
273  "Target <" << vtarget.name() << "> (" << target << ") can not be in evs (" << evs
274  << ").");
275  }
276  auto condset = this->BN().minimalCondSet(target, evs);
277 
278  Potential< GUM_SCALAR > res;
279  this->eraseAllTargets();
280  this->eraseAllEvidence();
281  res.add(this->BN().variable(target));
282  this->addTarget(target);
283  for (const auto& n: condset) {
284  res.add(this->BN().variable(n));
285  this->addEvidence(n, 0);
286  }
287 
288  Instantiation inst(res);
289  for (inst.setFirst(); !inst.end(); inst.incNotVar(vtarget)) {
290  // inferring
291  for (const auto& n: condset)
292  this->chgEvidence(n, inst.val(this->BN().variable(n)));
293  this->makeInference();
294  // populate res
295  for (inst.setFirstVar(vtarget); !inst.end(); inst.incVar(vtarget)) {
296  res.set(inst, this->posterior(target)[inst]);
297  }
298  inst.setFirstVar(vtarget); // remove inst.end() flag
299  }
300 
301  return res;
302  }
303 
304 
305  template < typename GUM_SCALAR >
306  Potential< GUM_SCALAR > MarginalTargetedInference< GUM_SCALAR >::evidenceImpact(
307  const std::string& target,
308  const std::vector< std::string >& evs) {
309  const auto& bn = this->BN();
310  return evidenceImpact(bn.idFromName(target), bn.nodeset(evs));
311  }
312 
313 
314  template < typename GUM_SCALAR >
315  INLINE bool MarginalTargetedInference< GUM_SCALAR >::isTargetedMode_() const {
316  return _targeted_mode_;
317  }
318  template < typename GUM_SCALAR >
319  INLINE void MarginalTargetedInference< GUM_SCALAR >::setTargetedMode_() {
320  if (!_targeted_mode_) {
321  _targets_.clear();
322  _targeted_mode_ = true;
323  }
324  }
325 } /* namespace gum */