aGrUM  0.14.2
variableElimination_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES et Pierre-Henri WUILLEMIN *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
28 #ifndef DOXYGEN_SHOULD_SKIP_THIS
29 
31 
38 
39 
40 namespace gum {
41 
42 
43  // default constructor
44  template < typename GUM_SCALAR >
46  const IBayesNet< GUM_SCALAR >* BN,
47  RelevantPotentialsFinderType relevant_type,
48  FindBarrenNodesType barren_type) :
49  JointTargetedInference< GUM_SCALAR >(BN) {
50  // sets the relevant potential and the barren nodes finding algorithm
51  setRelevantPotentialsFinderType(relevant_type);
52  setFindBarrenNodesType(barren_type);
53 
54  // create a default triangulation (the user can change it afterwards)
55  __triangulation = new DefaultTriangulation;
56 
57  // for debugging purposessetRequiredInference
58  GUM_CONSTRUCTOR(VariableElimination);
59  }
60 
61 
62  // destructor
63  template < typename GUM_SCALAR >
65  // remove the junction tree and the triangulation algorithm
66  if (__JT != nullptr) delete __JT;
67  delete __triangulation;
68  if (__target_posterior != nullptr) delete __target_posterior;
69 
70  // for debugging purposes
71  GUM_DESTRUCTOR(VariableElimination);
72  }
73 
74 
76  template < typename GUM_SCALAR >
78  const Triangulation& new_triangulation) {
79  delete __triangulation;
80  __triangulation = new_triangulation.newFactory();
81  }
82 
83 
85  template < typename GUM_SCALAR >
86  INLINE const JunctionTree*
89 
90  return __JT;
91  }
92 
93 
95  template < typename GUM_SCALAR >
97  Potential< GUM_SCALAR >* (*proj)(const Potential< GUM_SCALAR >&,
98  const Set< const DiscreteVariable* >&)) {
99  __projection_op = proj;
100  }
101 
102 
104  template < typename GUM_SCALAR >
106  Potential< GUM_SCALAR >* (*comb)(const Potential< GUM_SCALAR >&,
107  const Potential< GUM_SCALAR >&)) {
108  __combination_op = comb;
109  }
110 
111 
113  template < typename GUM_SCALAR >
116  if (type != __find_relevant_potential_type) {
117  switch (type) {
121  break;
122 
126  break;
127 
131  break;
132 
136  break;
137 
138  default:
139  GUM_ERROR(InvalidArgument,
140  "setRelevantPotentialsFinderType for type "
141  << (unsigned int)type << " is not implemented yet");
142  }
143 
145  }
146  }
147 
148 
150  template < typename GUM_SCALAR >
152  FindBarrenNodesType type) {
153  if (type != __barren_nodes_type) {
154  // WARNING: if a new type is added here, method __createJT should
155  // certainly
156  // be updated as well, in particular its step 2.
157  switch (type) {
160 
161  default:
162  GUM_ERROR(InvalidArgument,
163  "setFindBarrenNodesType for type "
164  << (unsigned int)type << " is not implemented yet");
165  }
166 
167  __barren_nodes_type = type;
168  }
169  }
170 
171 
// NOTE(review): this span is a run of empty notification-handler stubs of
// VariableElimination (reacting to evidence/target/model changes). The
// extraction stripped every signature line that carried the handler name, so
// the names cannot be recovered from this view — presumably _onEvidenceAdded,
// _onEvidenceErased, _onAllEvidenceErased, _onEvidenceChanged, target-related
// handlers and a BN-change handler (see the `const IBayesNet<...>* bn` stub
// below) — TODO confirm against the aGrUM 0.14 sources. All visible bodies
// are `{}`: VariableElimination rebuilds its junction tree from scratch at
// each query, so it has nothing to invalidate incrementally.
173  template < typename GUM_SCALAR >
175  bool) {}
176 
177 
179  template < typename GUM_SCALAR >
181  bool) {}
182 
183 
185  template < typename GUM_SCALAR >
187 
188 
190  template < typename GUM_SCALAR >
192  bool) {}
193 
194 
196  template < typename GUM_SCALAR >
197  INLINE void
199 
200 
202  template < typename GUM_SCALAR >
203  INLINE void
205 
// handler taking the new Bayes net — likely the model-changed callback.
207  template < typename GUM_SCALAR >
209  const IBayesNet< GUM_SCALAR >* bn) {}
210 
212  template < typename GUM_SCALAR >
213  INLINE void
215 
216 
218  template < typename GUM_SCALAR >
219  INLINE void
221 
222 
224  template < typename GUM_SCALAR >
226 
227 
229  template < typename GUM_SCALAR >
231 
232 
234  template < typename GUM_SCALAR >
236 
237 
239  template < typename GUM_SCALAR >
241 
242 
244  template < typename GUM_SCALAR >
246  // to create the JT, we first create the moral graph of the BN in the
247  // following way in order to take into account the barren nodes and the
248  // nodes that received evidence:
249  // 1/ we create an undirected graph containing only the nodes and no edge
250  // 2/ if we take into account barren nodes, remove them from the graph
251  // 3/ if we take d-separation into account, remove the d-separated nodes
252  // 4/ add edges so that each node and its parents in the BN form a clique
253  // 5/ add edges so that the targets form a clique of the moral graph
254  // 6/ remove the nodes that received hard evidence (by step 4/, their
255  // parents are linked by edges, which is necessary for inference)
256  //
257  // At the end of step 6/, we have our moral graph and we can triangulate it
258  // to get the new junction tree
259 
260  // 1/ create an undirected graph containing only the nodes and no edge
261  const auto& bn = this->BN();
262  __graph.clear();
263  for (auto node : bn.dag())
264  __graph.addNodeWithId(node);
265 
266  // 2/ if we wish to exploit barren nodes, we shall remove them from the BN
267  // to do so: we identify all the nodes that are not targets and have
268  // received no evidence and such that their descendants are neither targets
269  // nor evidence nodes. Such nodes can be safely discarded from the BN
270  // without altering the inference output
272  // check that all the nodes are not targets, otherwise, there is no
273  // barren node
274  if (targets.size() != bn.size()) {
275  BarrenNodesFinder finder(&(bn.dag()));
276  finder.setTargets(&targets);
277 
278  NodeSet evidence_nodes;
279  for (const auto& pair : this->evidence()) {
280  evidence_nodes.insert(pair.first);
281  }
282  finder.setEvidence(&evidence_nodes);
283 
284  NodeSet barren_nodes = finder.barrenNodes();
285 
286  // remove the barren nodes from the moral graph
287  for (const auto node : barren_nodes) {
288  __graph.eraseNode(node);
289  }
290  }
291  }
292 
293  // 3/ if we wish to exploit d-separation, remove all the nodes that are
294  // d-separated from our targets
295  {
296  NodeSet requisite_nodes;
297  bool dsep_analysis = false;
301  BayesBall::requisiteNodes(bn.dag(),
302  targets,
303  this->hardEvidenceNodes(),
304  this->softEvidenceNodes(),
305  requisite_nodes);
306  dsep_analysis = true;
307  } break;
308 
310  dSeparation dsep;
311  dsep.requisiteNodes(bn.dag(),
312  targets,
313  this->hardEvidenceNodes(),
314  this->softEvidenceNodes(),
315  requisite_nodes);
316  dsep_analysis = true;
317  } break;
318 
320 
321  default: GUM_ERROR(FatalError, "not implemented yet");
322  }
323 
324  // remove all the nodes that are not requisite
325  if (dsep_analysis) {
326  for (auto iter = __graph.beginSafe(); iter != __graph.endSafe(); ++iter) {
327  if (!requisite_nodes.contains(*iter)
328  && !this->hardEvidenceNodes().contains(*iter)) {
329  __graph.eraseNode(*iter);
330  }
331  }
332  }
333  }
334 
335  // 4/ add edges so that each node and its parents in the BN form a clique
336  for (const auto node : __graph) {
337  const NodeSet& parents = bn.parents(node);
338  for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
339  // before adding an edge between node and its parent, check that the
340  // parent belong to the graph. Actually, when d-separated nodes are
341  // removed, it may be the case that the parents of hard evidence nodes
342  // are removed. But the latter still exist in the graph.
343  if (__graph.existsNode(*iter1)) __graph.addEdge(*iter1, node);
344 
345  auto iter2 = iter1;
346  for (++iter2; iter2 != parents.cend(); ++iter2) {
347  // before adding an edge, check that both extremities belong to
348  // the graph. Actually, when d-separated nodes are removed, it may
349  // be the case that the parents of hard evidence nodes are removed.
350  // But the latter still exist in the graph.
351  if (__graph.existsNode(*iter1) && __graph.existsNode(*iter2))
352  __graph.addEdge(*iter1, *iter2);
353  }
354  }
355  }
356 
357  // 5/ if targets contains several nodes, we shall add new edges into the
358  // moral graph in order to ensure that there exists a clique containing
359  // thier joint distribution
360  for (auto iter1 = targets.cbegin(); iter1 != targets.cend(); ++iter1) {
361  auto iter2 = iter1;
362  for (++iter2; iter2 != targets.cend(); ++iter2) {
363  __graph.addEdge(*iter1, *iter2);
364  }
365  }
366 
367  // 6/ remove all the nodes that received hard evidence
368  for (const auto node : this->hardEvidenceNodes()) {
369  __graph.eraseNode(node);
370  }
371 
372 
373  // now, we can compute the new junction tree.
374  if (__JT != nullptr) delete __JT;
375  __triangulation->setGraph(&__graph, &(this->domainSizes()));
376  const JunctionTree& triang_jt = __triangulation->junctionTree();
377  __JT = new CliqueGraph(triang_jt);
378 
379  // indicate, for each node of the moral graph a clique in __JT that can
380  // contain its conditional probability table
382  __clique_potentials.clear();
383  NodeSet emptyset;
384  for (auto clique : *__JT)
385  __clique_potentials.insert(clique, emptyset);
386  const std::vector< NodeId >& JT_elim_order =
388  NodeProperty< Size > elim_order(Size(JT_elim_order.size()));
389  for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
390  ++i)
391  elim_order.insert(JT_elim_order[i], NodeId(i));
392  const DAG& dag = bn.dag();
393  for (const auto node : __graph) {
394  // get the variables in the potential of node (and its parents)
395  NodeId first_eliminated_node = node;
396  Size elim_number = elim_order[first_eliminated_node];
397 
398  for (const auto parent : dag.parents(node)) {
399  if (__graph.existsNode(parent) && (elim_order[parent] < elim_number)) {
400  elim_number = elim_order[parent];
401  first_eliminated_node = parent;
402  }
403  }
404 
405  // first_eliminated_node contains the first var (node or one of its
406  // parents) eliminated => the clique created during its elimination
407  // contains node and all of its parents => it can contain the potential
408  // assigned to the node in the BN
409  NodeId clique =
410  __triangulation->createdJunctionTreeClique(first_eliminated_node);
411  __node_to_clique.insert(node, clique);
412  __clique_potentials[clique].insert(node);
413  }
414 
415  // do the same for the nodes that received evidence. Here, we only store
416  // the nodes whose at least one parent belongs to __graph (otherwise
417  // their CPT is just a constant real number).
418  for (const auto node : this->hardEvidenceNodes()) {
419  // get the set of parents of the node that belong to __graph
420  NodeSet pars(dag.parents(node).size());
421  for (const auto par : dag.parents(node))
422  if (__graph.exists(par)) pars.insert(par);
423 
424  if (!pars.empty()) {
425  NodeId first_eliminated_node = *(pars.begin());
426  Size elim_number = elim_order[first_eliminated_node];
427 
428  for (const auto parent : pars) {
429  if (elim_order[parent] < elim_number) {
430  elim_number = elim_order[parent];
431  first_eliminated_node = parent;
432  }
433  }
434 
435  // first_eliminated_node contains the first var (node or one of its
436  // parents) eliminated => the clique created during its elimination
437  // contains node and all of its parents => it can contain the potential
438  // assigned to the node in the BN
439  NodeId clique =
440  __triangulation->createdJunctionTreeClique(first_eliminated_node);
441  __node_to_clique.insert(node, clique);
442  __clique_potentials[clique].insert(node);
443  }
444  }
445 
446 
447  // indicate a clique that contains all the nodes of targets
448  __targets2clique = std::numeric_limits< NodeId >::max();
449  {
450  // remove from set all the nodes that received hard evidence (since they
451  // do not belong to the join tree)
452  NodeSet nodeset = targets;
453  for (const auto node : this->hardEvidenceNodes())
454  if (nodeset.contains(node)) nodeset.erase(node);
455 
456  if (!nodeset.empty()) {
457  NodeId first_eliminated_node = *(nodeset.begin());
458  Size elim_number = elim_order[first_eliminated_node];
459  for (const auto node : nodeset) {
460  if (elim_order[node] < elim_number) {
461  elim_number = elim_order[node];
462  first_eliminated_node = node;
463  }
464  }
466  __triangulation->createdJunctionTreeClique(first_eliminated_node);
467  }
468  }
469  }
470 
471 
// NOTE(review): two fully-stripped one-line definitions — presumably the
// empty _updateOutdatedBNStructure() / _updateOutdatedBNPotentials() stubs
// (VariableElimination rebuilds everything per query, so nothing to update
// incrementally) — TODO confirm against the aGrUM 0.14 sources.
473  template < typename GUM_SCALAR >
475 
476 
479  template < typename GUM_SCALAR >
481 
482 
483  // find the potentials d-connected to a set of variables
484  template < typename GUM_SCALAR >
486  Set< const Potential< GUM_SCALAR >* >& pot_list,
487  Set< const DiscreteVariable* >& kept_vars) {}
488 
489 
490  // find the potentials d-connected to a set of variables
491  template < typename GUM_SCALAR >
493  Set< const Potential< GUM_SCALAR >* >& pot_list,
494  Set< const DiscreteVariable* >& kept_vars) {
495  // find the node ids of the kept variables
496  NodeSet kept_ids;
497  const auto& bn = this->BN();
498  for (const auto var : kept_vars) {
499  kept_ids.insert(bn.nodeId(*var));
500  }
501 
502  // determine the set of potentials d-connected with the kept variables
503  NodeSet requisite_nodes;
504  BayesBall::requisiteNodes(bn.dag(),
505  kept_ids,
506  this->hardEvidenceNodes(),
507  this->softEvidenceNodes(),
508  requisite_nodes);
509  for (auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
510  const Sequence< const DiscreteVariable* >& vars =
511  (**iter).variablesSequence();
512  bool found = false;
513  for (auto var : vars) {
514  if (requisite_nodes.exists(bn.nodeId(*var))) {
515  found = true;
516  break;
517  }
518  }
519 
520  if (!found) { pot_list.erase(iter); }
521  }
522  }
523 
524 
525  // find the potentials d-connected to a set of variables
526  template < typename GUM_SCALAR >
528  Set< const Potential< GUM_SCALAR >* >& pot_list,
529  Set< const DiscreteVariable* >& kept_vars) {
530  // find the node ids of the kept variables
531  NodeSet kept_ids;
532  const auto& bn = this->BN();
533  for (const auto var : kept_vars) {
534  kept_ids.insert(bn.nodeId(*var));
535  }
536 
537  // determine the set of potentials d-connected with the kept variables
539  kept_ids,
540  this->hardEvidenceNodes(),
541  this->softEvidenceNodes(),
542  pot_list);
543  }
544 
545 
546  // find the potentials d-connected to a set of variables
547  template < typename GUM_SCALAR >
549  Set< const Potential< GUM_SCALAR >* >& pot_list,
550  Set< const DiscreteVariable* >& kept_vars) {
551  // find the node ids of the kept variables
552  NodeSet kept_ids;
553  const auto& bn = this->BN();
554  for (const auto var : kept_vars) {
555  kept_ids.insert(bn.nodeId(*var));
556  }
557 
558  // determine the set of potentials d-connected with the kept variables
559  dSeparation dsep;
560  dsep.relevantPotentials(bn,
561  kept_ids,
562  this->hardEvidenceNodes(),
563  this->softEvidenceNodes(),
564  pot_list);
565  }
566 
567 
568  // find the potentials d-connected to a set of variables
569  template < typename GUM_SCALAR >
571  Set< const Potential< GUM_SCALAR >* >& pot_list,
572  Set< const DiscreteVariable* >& kept_vars) {
575  __findRelevantPotentialsWithdSeparation2(pot_list, kept_vars);
576  break;
577 
579  __findRelevantPotentialsWithdSeparation(pot_list, kept_vars);
580  break;
581 
583  __findRelevantPotentialsWithdSeparation3(pot_list, kept_vars);
584  break;
585 
587  __findRelevantPotentialsGetAll(pot_list, kept_vars);
588  break;
589 
590  default: GUM_ERROR(FatalError, "not implemented yet");
591  }
592  }
593 
594 
595  // remove barren variables
596  template < typename GUM_SCALAR >
597  Set< const Potential< GUM_SCALAR >* >
599  __PotentialSet& pot_list, Set< const DiscreteVariable* >& del_vars) {
600  // remove from del_vars the variables that received some evidence:
601  // only those that did not received evidence can be barren variables
602  Set< const DiscreteVariable* > the_del_vars = del_vars;
603  for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
604  ++iter) {
605  NodeId id = this->BN().nodeId(**iter);
606  if (this->hardEvidenceNodes().exists(id)
607  || this->softEvidenceNodes().exists(id)) {
608  the_del_vars.erase(iter);
609  }
610  }
611 
612  // assign to each random variable the set of potentials that contain it
613  HashTable< const DiscreteVariable*, __PotentialSet > var2pots;
614  __PotentialSet empty_pot_set;
615  for (const auto pot : pot_list) {
616  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
617  for (const auto var : vars) {
618  if (the_del_vars.exists(var)) {
619  if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
620  var2pots[var].insert(pot);
621  }
622  }
623  }
624 
625  // each variable with only one potential is a barren variable
626  // assign to each potential with barren nodes its set of barren variables
627  HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
628  pot2barren_var;
629  Set< const DiscreteVariable* > empty_var_set;
630  for (auto elt : var2pots) {
631  if (elt.second.size() == 1) { // here we have a barren variable
632  const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
633  if (!pot2barren_var.exists(pot)) {
634  pot2barren_var.insert(pot, empty_var_set);
635  }
636  pot2barren_var[pot].insert(elt.first); // insert the barren variable
637  }
638  }
639 
640  // for each potential with barren variables, marginalize them.
641  // if the potential has only barren variables, simply remove them from the
642  // set of potentials, else just project the potential
643  MultiDimProjection< GUM_SCALAR, Potential > projector(VENewprojPotential);
644  __PotentialSet projected_pots;
645  for (auto elt : pot2barren_var) {
646  // remove the current potential from pot_list as, anyway, we will change
647  // it
648  const Potential< GUM_SCALAR >* pot = elt.first;
649  pot_list.erase(pot);
650 
651  // check whether we need to add a projected new potential or not (i.e.,
652  // whether there exist non-barren variables or not)
653  if (pot->variablesSequence().size() != elt.second.size()) {
654  auto new_pot = projector.project(*pot, elt.second);
655  pot_list.insert(new_pot);
656  projected_pots.insert(new_pot);
657  }
658  }
659 
660  return projected_pots;
661  }
662 
663 
664  // performs the collect phase of Lazy Propagation
665  template < typename GUM_SCALAR >
666  std::pair< Set< const Potential< GUM_SCALAR >* >,
667  Set< const Potential< GUM_SCALAR >* > >
669  // collect messages from all the neighbors
670  std::pair< __PotentialSet, __PotentialSet > collect_messages;
671  for (const auto other : __JT->neighbours(id)) {
672  if (other != from) {
673  std::pair< __PotentialSet, __PotentialSet > message(
674  __collectMessage(other, id));
675  collect_messages.first += message.first;
676  collect_messages.second += message.second;
677  }
678  }
679 
680  // combine the collect messages with those of id's clique
681  return __produceMessage(id, from, std::move(collect_messages));
682  }
683 
684 
685  // get the CPT + evidence of a node projected w.r.t. hard evidence
686  template < typename GUM_SCALAR >
687  std::pair< Set< const Potential< GUM_SCALAR >* >,
688  Set< const Potential< GUM_SCALAR >* > >
690  std::pair< __PotentialSet, __PotentialSet > res;
691  const auto& bn = this->BN();
692 
693  // get the CPT's of the node
694  // beware: all the potentials that are defined over some nodes
695  // including hard evidence must be projected so that these nodes are
696  // removed from the potential
697  // also beware that the CPT of a hard evidence node may be defined over
698  // parents that do not belong to __graph and that are not hard evidence.
699  // In this case, those parents have been removed by d-separation and it is
700  // easy to show that, in this case all the parents have been removed, so
701  // that the CPT does not need to be taken into account
702  const auto& evidence = this->evidence();
703  const auto& hard_evidence = this->hardEvidence();
704  if (__graph.exists(node) || this->hardEvidenceNodes().contains(node)) {
705  const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
706  const auto& variables = cpt.variablesSequence();
707 
708  // check if the parents of a hard evidence node do not belong to __graph
709  // and are not themselves hard evidence, discard the CPT, it is useless
710  // for inference
711  if (this->hardEvidenceNodes().contains(node)) {
712  for (const auto var : variables) {
713  NodeId xnode = bn.nodeId(*var);
714  if (!this->hardEvidenceNodes().contains(xnode)
715  && !__graph.existsNode(xnode))
716  return res;
717  }
718  }
719 
720  // get the list of nodes with hard evidence in cpt
721  NodeSet hard_nodes;
722  for (const auto var : variables) {
723  NodeId xnode = bn.nodeId(*var);
724  if (this->hardEvidenceNodes().contains(xnode)) hard_nodes.insert(xnode);
725  }
726 
727  // if hard_nodes contains hard evidence nodes, perform a projection
728  // and insert the result into the appropriate clique, else insert
729  // directly cpt into the clique
730  if (hard_nodes.empty()) {
731  res.first.insert(&cpt);
732  } else {
733  // marginalize out the hard evidence nodes: if the cpt is defined
734  // only over nodes that received hard evidence, do not consider it
735  // as a potential anymore
736  if (hard_nodes.size() != variables.size()) {
737  // perform the projection with a combine and project instance
738  Set< const DiscreteVariable* > hard_variables;
739  __PotentialSet marg_cpt_set{&cpt};
740  for (const auto xnode : hard_nodes) {
741  marg_cpt_set.insert(evidence[xnode]);
742  hard_variables.insert(&(bn.variable(xnode)));
743  }
744  // perform the combination of those potentials and their projection
745  MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
746  combine_and_project(__combination_op, VENewprojPotential);
747  __PotentialSet new_cpt_list =
748  combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
749 
750  // there should be only one potential in new_cpt_list
751  if (new_cpt_list.size() != 1) {
752  // remove the CPT created to avoid memory leaks
753  for (auto pot : new_cpt_list) {
754  if (!marg_cpt_set.contains(pot)) delete pot;
755  }
756  GUM_ERROR(FatalError,
757  "the projection of a potential containing "
758  << "hard evidence is empty!");
759  }
760  const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
761  res.first.insert(projected_cpt);
762  res.second.insert(projected_cpt);
763  }
764  }
765 
766  // if the node received some soft evidence, add it
767  if (evidence.exists(node) && !hard_evidence.exists(node)) {
768  res.first.insert(this->evidence()[node]);
769  }
770  }
771 
772  return res;
773  }
774 
775 
776  // creates the message sent by clique from_id to clique to_id
777  template < typename GUM_SCALAR >
778  std::pair< Set< const Potential< GUM_SCALAR >* >,
779  Set< const Potential< GUM_SCALAR >* > >
781  NodeId from_id,
782  NodeId to_id,
783  std::pair< Set< const Potential< GUM_SCALAR >* >,
784  Set< const Potential< GUM_SCALAR >* > >&& incoming_messages) {
785  // get the messages sent by adjacent nodes to from_id
786  std::pair< Set< const Potential< GUM_SCALAR >* >,
787  Set< const Potential< GUM_SCALAR >* > >
788  pot_list(std::move(incoming_messages));
789 
790  // get the potentials of the clique
791  for (const auto node : __clique_potentials[from_id]) {
792  auto new_pots = __NodePotentials(node);
793  pot_list.first += new_pots.first;
794  pot_list.second += new_pots.second;
795  }
796 
797  // if from_id = to_id: this is the endpoint of a collect
798  if (!__JT->existsEdge(from_id, to_id)) {
799  return pot_list;
800  } else {
801  // get the set of variables that need be removed from the potentials
802  const NodeSet& from_clique = __JT->clique(from_id);
803  const NodeSet& separator = __JT->separator(from_id, to_id);
804  Set< const DiscreteVariable* > del_vars(from_clique.size());
805  Set< const DiscreteVariable* > kept_vars(separator.size());
806  const auto& bn = this->BN();
807 
808  for (const auto node : from_clique) {
809  if (!separator.contains(node)) {
810  del_vars.insert(&(bn.variable(node)));
811  } else {
812  kept_vars.insert(&(bn.variable(node)));
813  }
814  }
815 
816  // pot_list now contains all the potentials to multiply and marginalize
817  // => combine the messages
818  __PotentialSet new_pot_list =
819  __marginalizeOut(pot_list.first, del_vars, kept_vars);
820 
821  // remove all the potentials that are equal to ones (as probability
822  // matrix multiplications are tensorial, such potentials are useless)
823  for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
824  ++iter) {
825  const auto pot = *iter;
826  if (pot->variablesSequence().size() == 1) {
827  bool is_all_ones = true;
828  for (Instantiation inst(*pot); !inst.end(); ++inst) {
829  if ((*pot)[inst] < __1_minus_epsilon) {
830  is_all_ones = false;
831  break;
832  }
833  }
834  if (is_all_ones) {
835  if (!pot_list.first.exists(pot)) delete pot;
836  new_pot_list.erase(iter);
837  continue;
838  }
839  }
840  }
841 
842  // remove the unnecessary temporary messages
843  for (auto iter = pot_list.second.beginSafe();
844  iter != pot_list.second.endSafe();
845  ++iter) {
846  if (!new_pot_list.contains(*iter)) {
847  delete *iter;
848  pot_list.second.erase(iter);
849  }
850  }
851 
852  // keep track of all the newly created potentials
853  for (const auto pot : new_pot_list) {
854  if (!pot_list.first.contains(pot)) { pot_list.second.insert(pot); }
855  }
856 
857  // return the new set of potentials
858  return std::pair< __PotentialSet, __PotentialSet >(
859  std::move(new_pot_list), std::move(pot_list.second));
860  }
861  }
862 
863 
864  // remove variables del_vars from the list of potentials pot_list
865  template < typename GUM_SCALAR >
866  Set< const Potential< GUM_SCALAR >* >
868  Set< const Potential< GUM_SCALAR >* > pot_list,
869  Set< const DiscreteVariable* >& del_vars,
870  Set< const DiscreteVariable* >& kept_vars) {
871  // use d-separation analysis to check which potentials shall be combined
872  __findRelevantPotentialsXX(pot_list, kept_vars);
873 
874  // remove the potentials corresponding to barren variables if we want
875  // to exploit barren nodes
876  __PotentialSet barren_projected_potentials;
878  barren_projected_potentials = __removeBarrenVariables(pot_list, del_vars);
879  }
880 
881  // create a combine and project operator that will perform the
882  // marginalization
883  MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
885  __PotentialSet new_pot_list =
886  combine_and_project.combineAndProject(pot_list, del_vars);
887 
888  // remove all the potentials that were created due to projections of
889  // barren nodes and that are not part of the new_pot_list: these
890  // potentials were just temporary potentials
891  for (auto iter = barren_projected_potentials.beginSafe();
892  iter != barren_projected_potentials.endSafe();
893  ++iter) {
894  if (!new_pot_list.exists(*iter)) delete *iter;
895  }
896 
897  // remove all the potentials that have no dimension
898  for (auto iter_pot = new_pot_list.beginSafe();
899  iter_pot != new_pot_list.endSafe();
900  ++iter_pot) {
901  if ((*iter_pot)->variablesSequence().size() == 0) {
902  // as we have already marginalized out variables that received evidence,
903  // it may be the case that, after combining and projecting, some
904  // potentials might be empty. In this case, we shall keep their
905  // constant and remove them from memory
906  // # TODO: keep the constants!
907  delete *iter_pot;
908  new_pot_list.erase(iter_pot);
909  }
910  }
911 
912  return new_pot_list;
913  }
914 
915 
916  // performs a whole inference
917  template < typename GUM_SCALAR >
919 
920 
922  template < typename GUM_SCALAR >
923  Potential< GUM_SCALAR >*
925  const auto& bn = this->BN();
926 
927  // hard evidence do not belong to the join tree
928  // # TODO: check for sets of inconsistent hard evidence
929  if (this->hardEvidenceNodes().contains(id)) {
930  return new Potential< GUM_SCALAR >(*(this->evidence()[id]));
931  }
932 
933  // if we still need to perform some inference task, do it
934  __createNewJT(NodeSet{id});
935  NodeId clique_of_id = __node_to_clique[id];
936  auto pot_list = __collectMessage(clique_of_id, clique_of_id);
937 
938  // get the set of variables that need be removed from the potentials
939  const NodeSet& nodes = __JT->clique(clique_of_id);
940  Set< const DiscreteVariable* > kept_vars{&(bn.variable(id))};
941  Set< const DiscreteVariable* > del_vars(nodes.size());
942  for (const auto node : nodes) {
943  if (node != id) del_vars.insert(&(bn.variable(node)));
944  }
945 
946  // pot_list now contains all the potentials to multiply and marginalize
947  // => combine the messages
948  __PotentialSet new_pot_list =
949  __marginalizeOut(pot_list.first, del_vars, kept_vars);
950  Potential< GUM_SCALAR >* joint = nullptr;
951 
952  if (new_pot_list.size() == 1) {
953  joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
954  // if joint already existed, create a copy, so that we can put it into
955  // the __target_posterior property
956  if (pot_list.first.exists(joint)) {
957  joint = new Potential< GUM_SCALAR >(*joint);
958  } else {
959  // remove the joint from new_pot_list so that it will not be
960  // removed just after the else block
961  new_pot_list.clear();
962  }
963  } else {
964  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
966  joint = fast_combination.combine(new_pot_list);
967  }
968 
969  // remove the potentials that were created in new_pot_list
970  for (auto pot : new_pot_list)
971  if (!pot_list.first.exists(pot)) delete pot;
972 
973  // remove all the temporary potentials created in pot_list
974  for (auto pot : pot_list.second)
975  delete pot;
976 
977  // check that the joint posterior is different from a 0 vector: this would
978  // indicate that some hard evidence are not compatible (their joint
979  // probability is equal to 0)
980  bool nonzero_found = false;
981  for (Instantiation inst(*joint); !inst.end(); ++inst) {
982  if ((*joint)[inst]) {
983  nonzero_found = true;
984  break;
985  }
986  }
987  if (!nonzero_found) {
988  // remove joint from memory to avoid memory leaks
989  delete joint;
990  GUM_ERROR(IncompatibleEvidence,
991  "some evidence entered into the Bayes "
992  "net are incompatible (their joint proba = 0)");
993  }
994 
995  return joint;
996  }
997 
998 
1000  template < typename GUM_SCALAR >
1001  const Potential< GUM_SCALAR >&
1003  // compute the joint posterior and normalize
1004  auto joint = _unnormalizedJointPosterior(id);
1005  joint->normalize();
1006 
1007  if (__target_posterior != nullptr) delete __target_posterior;
1008  __target_posterior = joint;
1009 
1010  return *joint;
1011  }
1012 
1013 
1014  // returns the marginal a posteriori proba of a given node
1015  template < typename GUM_SCALAR >
1016  Potential< GUM_SCALAR >*
1018  const NodeSet& set) {
1019  // hard evidence do not belong to the join tree, so extract the nodes
1020  // from targets that are not hard evidence
1021  NodeSet targets = set, hard_ev_nodes;
1022  for (const auto node : this->hardEvidenceNodes()) {
1023  if (targets.contains(node)) {
1024  targets.erase(node);
1025  hard_ev_nodes.insert(node);
1026  }
1027  }
1028 
1029  // if all the nodes have received hard evidence, then compute the
1030  // joint posterior directly by multiplying the hard evidence potentials
1031  const auto& evidence = this->evidence();
1032  if (targets.empty()) {
1033  __PotentialSet pot_list;
1034  for (const auto node : set) {
1035  pot_list.insert(evidence[node]);
1036  }
1037  if (pot_list.size() == 1) {
1038  return new Potential< GUM_SCALAR >(**(pot_list.begin()));
1039  } else {
1040  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1042  return fast_combination.combine(pot_list);
1043  }
1044  }
1045 
1046  // if we still need to perform some inference task, do it
1047  __createNewJT(set);
1049 
1050  // get the set of variables that need be removed from the potentials
1051  const NodeSet& nodes = __JT->clique(__targets2clique);
1052  Set< const DiscreteVariable* > del_vars(nodes.size());
1053  Set< const DiscreteVariable* > kept_vars(targets.size());
1054  const auto& bn = this->BN();
1055  for (const auto node : nodes) {
1056  if (!targets.contains(node)) {
1057  del_vars.insert(&(bn.variable(node)));
1058  } else {
1059  kept_vars.insert(&(bn.variable(node)));
1060  }
1061  }
1062 
1063  // pot_list now contains all the potentials to multiply and marginalize
1064  // => combine the messages
1065  __PotentialSet new_pot_list =
1066  __marginalizeOut(pot_list.first, del_vars, kept_vars);
1067  Potential< GUM_SCALAR >* joint = nullptr;
1068 
1069  if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1070  joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
1071  // if pot already existed, create a copy, so that we can put it into
1072  // the __target_posteriors property
1073  if (pot_list.first.exists(joint)) {
1074  joint = new Potential< GUM_SCALAR >(*joint);
1075  } else {
1076  // remove the joint from new_pot_list so that it will not be
1077  // removed just after the next else block
1078  new_pot_list.clear();
1079  }
1080  } else {
1081  // combine all the potentials in new_pot_list with all the hard evidence
1082  // of the nodes in set
1083  __PotentialSet new_new_pot_list = new_pot_list;
1084  for (const auto node : hard_ev_nodes) {
1085  new_new_pot_list.insert(evidence[node]);
1086  }
1087  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1089  joint = fast_combination.combine(new_new_pot_list);
1090  }
1091 
1092  // remove the potentials that were created in new_pot_list
1093  for (auto pot : new_pot_list)
1094  if (!pot_list.first.exists(pot)) delete pot;
1095 
1096  // remove all the temporary potentials created in pot_list
1097  for (auto pot : pot_list.second)
1098  delete pot;
1099 
1100  // check that the joint posterior is different from a 0 vector: this would
1101  // indicate that some hard evidence are not compatible
1102  bool nonzero_found = false;
1103  for (Instantiation inst(*joint); !inst.end(); ++inst) {
1104  if ((*joint)[inst]) {
1105  nonzero_found = true;
1106  break;
1107  }
1108  }
1109  if (!nonzero_found) {
1110  // remove joint from memory to avoid memory leaks
1111  delete joint;
1112  GUM_ERROR(IncompatibleEvidence,
1113  "some evidence entered into the Bayes "
1114  "net are incompatible (their joint proba = 0)");
1115  }
1116 
1117  return joint;
1118  }
1119 
1120 
1122  template < typename GUM_SCALAR >
1123  const Potential< GUM_SCALAR >&
1125  // compute the joint posterior and normalize
1126  auto joint = _unnormalizedJointPosterior(set);
1127  joint->normalize();
1128 
1129  if (__target_posterior != nullptr) delete __target_posterior;
1130  __target_posterior = joint;
1131 
1132  return *joint;
1133  }
1134 
1135 
1137  template < typename GUM_SCALAR >
1138  const Potential< GUM_SCALAR >&
1140  const NodeSet& wanted_target, const NodeSet& declared_target) {
1141  return _jointPosterior(wanted_target);
1142  }
1143 
1144 
1145 } /* namespace gum */
1146 
1147 #endif // DOXYGEN_SHOULD_SKIP_THIS
void __setProjectionFunction(Potential< GUM_SCALAR > *(*proj)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
sets the operator for performing the projections
bool contains(const Key &k) const
Indicates whether a given elements belong to the set.
Definition: set_tpl.h:578
NodeId __targets2clique
indicate a clique that contains all the nodes of the target
VariableElimination(const IBayesNet< GUM_SCALAR > *BN, RelevantPotentialsFinderType relevant_type=RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS, FindBarrenNodesType=FindBarrenNodesType::FIND_BARREN_NODES)
default constructor
void _onAllTargetsErased() final
fired before a all single and joint_targets are removed
__PotentialSet __removeBarrenVariables(__PotentialSet &pot_list, Set< const DiscreteVariable * > &del_vars)
virtual void clear()
removes all the nodes and edges from the graph
Definition: undiGraph_inl.h:40
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
const NodeProperty< const Potential< GUM_SCALAR > *> & evidence() const
returns the set of evidence
static INLINE Potential< GUM_SCALAR > * VENewprojPotential(const Potential< GUM_SCALAR > &t1, const Set< const DiscreteVariable * > &del_vars)
void __setCombinationFunction(Potential< GUM_SCALAR > *(*comb)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &))
sets the operator for performing the combinations
node_iterator_safe beginSafe() const
a begin iterator to parse the set of nodes contained in the NodeGraphPart
Set< NodeId > NodeSet
Some typedefs and defines used as shortcuts ...
const Potential< GUM_SCALAR > & _posterior(NodeId id) final
returns the posterior of a given variable
JunctionTree * __JT
the junction tree used to answer the last inference query
RelevantPotentialsFinderType __find_relevant_potential_type
the type of relevant potential finding algorithm to be used
void __findRelevantPotentialsXX(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
void setFindBarrenNodesType(FindBarrenNodesType type)
sets how we determine barren nodes
d-separation analysis (as described in Koller & Friedman 2009)
static void relevantPotentials(const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
update a set of potentials, keeping only those d-connected with query variables given evidence ...
Definition: BayesBall_tpl.h:32
void _onEvidenceErased(const NodeId id, bool isHardEvidence) final
fired before an evidence is removed
std::pair< __PotentialSet, __PotentialSet > __collectMessage(NodeId id, NodeId from)
actually perform the collect phase
Set< const Potential< GUM_SCALAR > *> __PotentialSet
void _makeInference() final
called when the inference has to be performed effectively
The BayesBall algorithm (as described by Schachter).
Potential< GUM_SCALAR > * __target_posterior
the posterior computed during the last inference
Triangulation * __triangulation
the triangulation class creating the junction tree used for inference
void _updateOutdatedBNPotentials() final
prepares inference when the latter is in OutdatedBNPotentials state
void erase(const Key &k)
Erases an element from the set.
Definition: set_tpl.h:653
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
const Potential< GUM_SCALAR > & _jointPosterior(const NodeSet &set) final
returns the posterior of a declared target set
void __findRelevantPotentialsGetAll(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
void __findRelevantPotentialsWithdSeparation(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
RelevantPotentialsFinderType
type of algorithm for determining the relevant potentials for combinations using some d-separation an...
Implementation of a variable elimination algorithm for inference in Bayesian Networks.
const JunctionTree * junctionTree(NodeId id)
returns the join tree used for compute the posterior of node id
void _onMarginalTargetErased(const NodeId id) final
fired before a single target is removed
void _onAllMarginalTargetsErased() final
fired before a all the single targets are removed
void __findRelevantPotentialsWithdSeparation3(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
virtual void eraseNode(const NodeId id)
remove a node and its adjacent edges from the graph
Definition: undiGraph_inl.h:55
Potential< GUM_SCALAR > *(* __combination_op)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &)
the operator for performing the combinations
FindBarrenNodesType
type of algorithm to determine barren nodes
UndiGraph __graph
the undigraph extracted from the BN and used to construct the join tree
const NodeSet & softEvidenceNodes() const
returns the set of nodes with soft evidence
virtual const NodeProperty< Size > & domainSizes() const final
get the domain sizes of the random variables of the BN
void _onAllEvidenceErased(bool contains_hard_evidence) final
fired before all the evidence are erased
void __findRelevantPotentialsWithdSeparation2(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
virtual NodeId createdJunctionTreeClique(const NodeId id)=0
returns the Id of the clique created by the elimination of a given node during the triangulation proc...
const NodeProperty< Idx > & hardEvidence() const
indicate for each node with hard evidence which value it took
Header files of gum::Instantiation.
HashTable< NodeId, NodeSet > __clique_potentials
for each BN node, indicate in which clique its CPT will be stored
void _onJointTargetErased(const NodeSet &set) final
fired before a joint target is removed
void _onJointTargetAdded(const NodeSet &set) final
fired after a new joint target is inserted
FindBarrenNodesType __barren_nodes_type
the type of barren nodes computation we wish
void _onAllJointTargetsErased() final
fired before a all the joint targets are removed
void setTriangulation(const Triangulation &new_triangulation)
use a new triangulation algorithm
void setRelevantPotentialsFinderType(RelevantPotentialsFinderType type)
sets how we determine the relevant potentials to combine
const NodeSet & hardEvidenceNodes() const
returns the set of nodes with hard evidence
HashTable< NodeId, NodeId > __node_to_clique
for each node of __graph (~ in the Bayes net), associate an ID in the JT
Potential< GUM_SCALAR > * _unnormalizedJointPosterior(NodeId id) final
returns a fresh potential equal to P(argument,evidence)
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
Definition: cliqueGraph.h:299
__PotentialSet __marginalizeOut(__PotentialSet pot_list, Set< const DiscreteVariable * > &del_vars, Set< const DiscreteVariable * > &kept_vars)
removes variables del_vars from a list of potentials and returns the resulting list ...
void clear()
Removes all the elements in the hash table.
std::pair< __PotentialSet, __PotentialSet > __NodePotentials(NodeId node)
returns the CPT + evidence of a node projected w.r.t. hard evidence
void _updateOutdatedBNStructure() final
prepares inference when the latter is in OutdatedBNStructure state
void _onAllMarginalTargetsAdded() final
fired after all the nodes of the BN are added as single targets
An algorithm for converting a join tree into a binary join tree.
std::pair< __PotentialSet, __PotentialSet > __produceMessage(NodeId from_id, NodeId to_id, std::pair< __PotentialSet, __PotentialSet > &&incoming_messages)
creates the message sent by clique from_id to clique to_id
virtual Triangulation * newFactory() const =0
returns a fresh triangulation of the same type as the current object but with an empty graph ...
virtual const NodeSet & targets() const noexcept final
returns the list of marginal targets
const node_iterator_safe & endSafe() const noexcept
the end iterator to parse the set of nodes contained in the NodeGraphPart
void __createNewJT(const NodeSet &targets)
create a new junction tree as well as its related data structures
~VariableElimination() final
destructor
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:45
virtual const CliqueGraph & junctionTree()=0
returns a compatible junction tree
const GUM_SCALAR __1_minus_epsilon
for comparisons with 1 - epsilon
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
virtual const std::vector< NodeId > & eliminationOrder()=0
returns an elimination ordering compatible with the triangulated graph
void _onMarginalTargetAdded(const NodeId id) final
fired after a new single target is inserted
virtual const IBayesNet< GUM_SCALAR > & BN() const final
Returns a constant reference over the IBayesNet referenced by this class.
Potential< GUM_SCALAR > *(* __projection_op)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &)
the operator for performing the projections
Size NodeId
Type for node ids.
Definition: graphElements.h:97
void _onEvidenceChanged(const NodeId id, bool hasChangedSoftHard) final
fired after an evidence is changed, in particular when its status (soft/hard) changes ...
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:610
void(VariableElimination< GUM_SCALAR >::* __findRelevantPotentials)(Set< const Potential< GUM_SCALAR > * > &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
void _onEvidenceAdded(const NodeId id, bool isHardEvidence) final
fired after a new evidence is inserted
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52
virtual void setGraph(const UndiGraph *graph, const NodeProperty< Size > *domsizes)=0
initialize the triangulation data structures for a new graph
static void requisiteNodes(const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.
Definition: BayesBall.cpp:33
virtual void _onBayesNetChanged(const IBayesNet< GUM_SCALAR > *bn) final
fired after a new Bayes net has been assigned to the engine