aGrUM  0.13.2
variableElimination_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES et Pierre-Henri WUILLEMIN *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
28 #ifndef DOXYGEN_SHOULD_SKIP_THIS
29 
31 
38 
39 
40 namespace gum {
41 
42 
43  // default constructor
// Builds a VariableElimination engine over the Bayes net BN, selecting the
// strategy used to prune irrelevant potentials (relevant_type) and the barren
// nodes detection algorithm (barren_type).
// NOTE(review): the qualified signature line (original line 45,
// "VariableElimination< GUM_SCALAR >::VariableElimination(") was elided by the
// doxygen extraction -- restore it from the original variableElimination_tpl.h.
 44  template < typename GUM_SCALAR >
 46  const IBayesNet< GUM_SCALAR >* BN,
 47  RelevantPotentialsFinderType relevant_type,
 48  FindBarrenNodesType barren_type) :
 49  JointTargetedInference< GUM_SCALAR >(BN) {
 50  // sets the relevant potential and the barren nodes finding algorithm
 51  setRelevantPotentialsFinderType(relevant_type);
 52  setFindBarrenNodesType(barren_type);
 53 
 54  // create a default triangulation (the user can change it afterwards)
 55  __triangulation = new DefaultTriangulation;
 56 
 57  // for debugging purposes
 58  GUM_CONSTRUCTOR(VariableElimination);
 59  }
60 
61 
 62  // destructor
// Releases everything this engine owns: the junction tree (if any), the
// triangulation algorithm and the cached posterior.
// NOTE(review): the qualified signature line (original line 64,
// "VariableElimination< GUM_SCALAR >::~VariableElimination() {") was elided by
// the doxygen extraction.
 63  template < typename GUM_SCALAR >
 65  // remove the junction tree and the triangulation algorithm
 66  if (__JT != nullptr) delete __JT;
// __triangulation is always allocated by the constructor, hence no null check
 67  delete __triangulation;
 68  if (__target_posterior != nullptr) delete __target_posterior;
 69 
 70  // for debugging purposes
 71  GUM_DESTRUCTOR(VariableElimination);
 72  }
73 
74 
 76  template < typename GUM_SCALAR >
// Replaces the current triangulation algorithm by a fresh copy of
// new_triangulation, obtained through its newFactory() virtual constructor
// (the engine keeps ownership of the copy; the argument is left untouched).
// NOTE(review): the qualified signature line (original line 77) was elided by
// the doxygen extraction.
 78  const Triangulation& new_triangulation) {
 79  delete __triangulation;
 80  __triangulation = new_triangulation.newFactory();
 81  }
82 
83 
 85  template < typename GUM_SCALAR >
 86  INLINE const JunctionTree*
// Returns the junction tree currently used by the engine.
// NOTE(review): original lines 87-88 (the qualified signature and, presumably,
// a call that (re)builds the junction tree for the queried target) were elided
// by the doxygen extraction -- confirm against the original file.
 89 
 90  return __JT;
 91  }
92 
93 
 95  template < typename GUM_SCALAR >
// Sets the operator used to project (marginalize) a potential over a subset of
// its variables during message production.
// NOTE(review): the qualified signature line (original line 96) was elided by
// the doxygen extraction.
 97  Potential< GUM_SCALAR >* (*proj)(const Potential< GUM_SCALAR >&,
 98  const Set< const DiscreteVariable* >&)) {
 99  __projection_op = proj;
 100  }
101 
102 
 104  template < typename GUM_SCALAR >
// Sets the operator used to combine (multiply) two potentials.
// NOTE(review): the qualified signature line (original line 105) was elided by
// the doxygen extraction.
 106  Potential< GUM_SCALAR >* (*comb)(const Potential< GUM_SCALAR >&,
 107  const Potential< GUM_SCALAR >&)) {
 108  __combination_op = comb;
 109  }
110 
111 
113  template < typename GUM_SCALAR >
116  if (type != __find_relevant_potential_type) {
117  switch (type) {
121  break;
122 
126  break;
127 
131  break;
132 
136  break;
137 
138  default:
139  GUM_ERROR(InvalidArgument,
140  "setRelevantPotentialsFinderType for type "
141  << (unsigned int)type << " is not implemented yet");
142  }
143 
145  }
146  }
147 
148 
150  template < typename GUM_SCALAR >
152  FindBarrenNodesType type) {
153  if (type != __barren_nodes_type) {
154  // WARNING: if a new type is added here, method __createJT should
155  // certainly
156  // be updated as well, in particular its step 2.
157  switch (type) {
160 
161  default:
162  GUM_ERROR(InvalidArgument,
163  "setFindBarrenNodesType for type "
164  << (unsigned int)type << " is not implemented yet");
165  }
166 
167  __barren_nodes_type = type;
168  }
169  }
170 
171 
173  template < typename GUM_SCALAR >
175 
176 
178  template < typename GUM_SCALAR >
180 
181 
183  template < typename GUM_SCALAR >
185 
186 
188  template < typename GUM_SCALAR >
190  }
191 
192 
194  template < typename GUM_SCALAR >
196 
197 
199  template < typename GUM_SCALAR >
201 
202 
204  template < typename GUM_SCALAR >
205  INLINE void
207 
208 
210  template < typename GUM_SCALAR >
211  INLINE void
213 
214 
216  template < typename GUM_SCALAR >
218 
219 
221  template < typename GUM_SCALAR >
223 
224 
226  template < typename GUM_SCALAR >
228 
229 
231  template < typename GUM_SCALAR >
233 
234 
236  template < typename GUM_SCALAR >
238  // to create the JT, we first create the moral graph of the BN in the
239  // following way in order to take into account the barren nodes and the
240  // nodes that received evidence:
241  // 1/ we create an undirected graph containing only the nodes and no edge
242  // 2/ if we take into account barren nodes, remove them from the graph
243  // 3/ if we take d-separation into account, remove the d-separated nodes
244  // 4/ add edges so that each node and its parents in the BN form a clique
245  // 5/ add edges so that the targets form a clique of the moral graph
246  // 6/ remove the nodes that received hard evidence (by step 4/, their
247  // parents are linked by edges, which is necessary for inference)
248  //
249  // At the end of step 6/, we have our moral graph and we can triangulate it
250  // to get the new junction tree
251 
252  // 1/ create an undirected graph containing only the nodes and no edge
253  const auto& bn = this->BN();
254  __graph.clear();
255  for (auto node : bn.dag())
256  __graph.addNodeWithId(node);
257 
258  // 2/ if we wish to exploit barren nodes, we shall remove them from the BN
259  // to do so: we identify all the nodes that are not targets and have
260  // received no evidence and such that their descendants are neither targets
261  // nor evidence nodes. Such nodes can be safely discarded from the BN
262  // without altering the inference output
264  // check that all the nodes are not targets, otherwise, there is no
265  // barren node
266  if (targets.size() != bn.size()) {
267  BarrenNodesFinder finder(&(bn.dag()));
268  finder.setTargets(&targets);
269 
270  NodeSet evidence_nodes;
271  for (const auto& pair : this->evidence()) {
272  evidence_nodes.insert(pair.first);
273  }
274  finder.setEvidence(&evidence_nodes);
275 
276  NodeSet barren_nodes = finder.barrenNodes();
277 
278  // remove the barren nodes from the moral graph
279  for (const auto node : barren_nodes) {
280  __graph.eraseNode(node);
281  }
282  }
283  }
284 
285  // 3/ if we wish to exploit d-separation, remove all the nodes that are
286  // d-separated from our targets
287  {
288  NodeSet requisite_nodes;
289  bool dsep_analysis = false;
293  BayesBall::requisiteNodes(bn.dag(),
294  targets,
295  this->hardEvidenceNodes(),
296  this->softEvidenceNodes(),
297  requisite_nodes);
298  dsep_analysis = true;
299  } break;
300 
302  dSeparation dsep;
303  dsep.requisiteNodes(bn.dag(),
304  targets,
305  this->hardEvidenceNodes(),
306  this->softEvidenceNodes(),
307  requisite_nodes);
308  dsep_analysis = true;
309  } break;
310 
312 
313  default: GUM_ERROR(FatalError, "not implemented yet");
314  }
315 
316  // remove all the nodes that are not requisite
317  if (dsep_analysis) {
318  for (auto iter = __graph.beginSafe(); iter != __graph.endSafe(); ++iter) {
319  if (!requisite_nodes.contains(*iter)
320  && !this->hardEvidenceNodes().contains(*iter)) {
321  __graph.eraseNode(*iter);
322  }
323  }
324  }
325  }
326 
327  // 4/ add edges so that each node and its parents in the BN form a clique
328  for (const auto node : __graph) {
329  const NodeSet& parents = bn.parents(node);
330  for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
331  // before adding an edge between node and its parent, check that the
332  // parent belong to the graph. Actually, when d-separated nodes are
333  // removed, it may be the case that the parents of hard evidence nodes
334  // are removed. But the latter still exist in the graph.
335  if (__graph.existsNode(*iter1)) __graph.addEdge(*iter1, node);
336 
337  auto iter2 = iter1;
338  for (++iter2; iter2 != parents.cend(); ++iter2) {
339  // before adding an edge, check that both extremities belong to
340  // the graph. Actually, when d-separated nodes are removed, it may
341  // be the case that the parents of hard evidence nodes are removed.
342  // But the latter still exist in the graph.
343  if (__graph.existsNode(*iter1) && __graph.existsNode(*iter2))
344  __graph.addEdge(*iter1, *iter2);
345  }
346  }
347  }
348 
349  // 5/ if targets contains several nodes, we shall add new edges into the
350  // moral graph in order to ensure that there exists a clique containing
 351  // their joint distribution
352  for (auto iter1 = targets.cbegin(); iter1 != targets.cend(); ++iter1) {
353  auto iter2 = iter1;
354  for (++iter2; iter2 != targets.cend(); ++iter2) {
355  __graph.addEdge(*iter1, *iter2);
356  }
357  }
358 
359  // 6/ remove all the nodes that received hard evidence
360  for (const auto node : this->hardEvidenceNodes()) {
361  __graph.eraseNode(node);
362  }
363 
364 
365  // now, we can compute the new junction tree.
366  if (__JT != nullptr) delete __JT;
367  __triangulation->setGraph(&__graph, &(this->domainSizes()));
368  const JunctionTree& triang_jt = __triangulation->junctionTree();
369  __JT = new CliqueGraph(triang_jt);
370 
371  // indicate, for each node of the moral graph a clique in __JT that can
372  // contain its conditional probability table
374  __clique_potentials.clear();
375  NodeSet emptyset;
376  for (auto clique : *__JT)
377  __clique_potentials.insert(clique, emptyset);
378  const std::vector< NodeId >& JT_elim_order =
380  NodeProperty< int > elim_order(Size(JT_elim_order.size()));
381  for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
382  ++i)
383  elim_order.insert(JT_elim_order[i], NodeId(i));
384  const DAG& dag = bn.dag();
385  for (const auto node : __graph) {
386  // get the variables in the potential of node (and its parents)
387  NodeId first_eliminated_node = node;
388  int elim_number = elim_order[first_eliminated_node];
389 
390  for (const auto parent : dag.parents(node)) {
391  if (__graph.existsNode(parent) && (elim_order[parent] < elim_number)) {
392  elim_number = elim_order[parent];
393  first_eliminated_node = parent;
394  }
395  }
396 
397  // first_eliminated_node contains the first var (node or one of its
398  // parents) eliminated => the clique created during its elimination
399  // contains node and all of its parents => it can contain the potential
400  // assigned to the node in the BN
401  NodeId clique =
402  __triangulation->createdJunctionTreeClique(first_eliminated_node);
403  __node_to_clique.insert(node, clique);
404  __clique_potentials[clique].insert(node);
405  }
406 
407  // do the same for the nodes that received evidence. Here, we only store
408  // the nodes whose at least one parent belongs to __graph (otherwise
409  // their CPT is just a constant real number).
410  for (const auto node : this->hardEvidenceNodes()) {
411  // get the set of parents of the node that belong to __graph
412  NodeSet pars(dag.parents(node).size());
413  for (const auto par : dag.parents(node))
414  if (__graph.exists(par)) pars.insert(par);
415 
416  if (!pars.empty()) {
417  NodeId first_eliminated_node = *(pars.begin());
418  int elim_number = elim_order[first_eliminated_node];
419 
420  for (const auto parent : pars) {
421  if (elim_order[parent] < elim_number) {
422  elim_number = elim_order[parent];
423  first_eliminated_node = parent;
424  }
425  }
426 
427  // first_eliminated_node contains the first var (node or one of its
428  // parents) eliminated => the clique created during its elimination
429  // contains node and all of its parents => it can contain the potential
430  // assigned to the node in the BN
431  NodeId clique =
432  __triangulation->createdJunctionTreeClique(first_eliminated_node);
433  __node_to_clique.insert(node, clique);
434  __clique_potentials[clique].insert(node);
435  }
436  }
437 
438 
439  // indicate a clique that contains all the nodes of targets
440  __targets2clique = std::numeric_limits< NodeId >::max();
441  {
442  // remove from set all the nodes that received hard evidence (since they
443  // do not belong to the join tree)
444  NodeSet nodeset = targets;
445  for (const auto node : this->hardEvidenceNodes())
446  if (nodeset.contains(node)) nodeset.erase(node);
447 
448  if (!nodeset.empty()) {
449  NodeId first_eliminated_node = *(nodeset.begin());
450  int elim_number = elim_order[first_eliminated_node];
451  for (const auto node : nodeset) {
452  if (elim_order[node] < elim_number) {
453  elim_number = elim_order[node];
454  first_eliminated_node = node;
455  }
456  }
458  __triangulation->createdJunctionTreeClique(first_eliminated_node);
459  }
460  }
461  }
462 
463 
465  template < typename GUM_SCALAR >
467 
468 
471  template < typename GUM_SCALAR >
473 
474 
 475  // find the potentials d-connected to a set of variables
 476  template < typename GUM_SCALAR >
// "GetAll" strategy: keep every potential, i.e., perform no d-separation-based
// pruning at all -- hence the deliberately empty body.
// NOTE(review): the qualified signature line (original line 477) was elided by
// the doxygen extraction.
 478  Set< const Potential< GUM_SCALAR >* >& pot_list,
 479  Set< const DiscreteVariable* >& kept_vars) {}
480 
481 
 482  // find the potentials d-connected to a set of variables
 483  template < typename GUM_SCALAR >
// BayesBall-on-nodes strategy: computes the requisite nodes w.r.t. the kept
// variables and the hard/soft evidence, then removes from pot_list every
// potential that mentions no requisite node.
// NOTE(review): the qualified signature line (original line 484) was elided by
// the doxygen extraction.
 485  Set< const Potential< GUM_SCALAR >* >& pot_list,
 486  Set< const DiscreteVariable* >& kept_vars) {
 487  // find the node ids of the kept variables
 488  NodeSet kept_ids;
 489  const auto& bn = this->BN();
 490  for (const auto var : kept_vars) {
 491  kept_ids.insert(bn.nodeId(*var));
 492  }
 493 
 494  // determine the set of potentials d-connected with the kept variables
 495  NodeSet requisite_nodes;
 496  BayesBall::requisiteNodes(bn.dag(),
 497  kept_ids,
 498  this->hardEvidenceNodes(),
 499  this->softEvidenceNodes(),
 500  requisite_nodes);
// safe iterators are required here: we erase from pot_list while iterating
 501  for (auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
 502  const Sequence< const DiscreteVariable* >& vars =
 503  (**iter).variablesSequence();
 504  bool found = false;
 505  for (auto var : vars) {
 506  if (requisite_nodes.exists(bn.nodeId(*var))) {
 507  found = true;
 508  break;
 509  }
 510  }
 511 
// a potential defined only over non-requisite nodes is d-separated: drop it
 512  if (!found) { pot_list.erase(iter); }
 513  }
 514  }
515 
516 
 517  // find the potentials d-connected to a set of variables
 518  template < typename GUM_SCALAR >
// BayesBall-on-potentials strategy: delegates the pruning of pot_list to a
// BayesBall routine working directly on potentials.
// NOTE(review): original lines 519 (the qualified signature) and 530 (the
// opening of the pruning call, presumably "BayesBall::relevantPotentials(bn,")
// were elided by the doxygen extraction -- restore from the original file.
 520  Set< const Potential< GUM_SCALAR >* >& pot_list,
 521  Set< const DiscreteVariable* >& kept_vars) {
 522  // find the node ids of the kept variables
 523  NodeSet kept_ids;
 524  const auto& bn = this->BN();
 525  for (const auto var : kept_vars) {
 526  kept_ids.insert(bn.nodeId(*var));
 527  }
 528 
 529  // determine the set of potentials d-connected with the kept variables
 531  kept_ids,
 532  this->hardEvidenceNodes(),
 533  this->softEvidenceNodes(),
 534  pot_list);
 535  }
536 
537 
 538  // find the potentials d-connected to a set of variables
 539  template < typename GUM_SCALAR >
// dSeparation-analysis strategy: delegates the pruning of pot_list to a
// dSeparation instance.
// NOTE(review): the qualified signature line (original line 540) was elided by
// the doxygen extraction.
 541  Set< const Potential< GUM_SCALAR >* >& pot_list,
 542  Set< const DiscreteVariable* >& kept_vars) {
 543  // find the node ids of the kept variables
 544  NodeSet kept_ids;
 545  const auto& bn = this->BN();
 546  for (const auto var : kept_vars) {
 547  kept_ids.insert(bn.nodeId(*var));
 548  }
 549 
 550  // determine the set of potentials d-connected with the kept variables
 551  dSeparation dsep;
 552  dsep.relevantPotentials(bn,
 553  kept_ids,
 554  this->hardEvidenceNodes(),
 555  this->softEvidenceNodes(),
 556  pot_list);
 557  }
558 
559 
560  // find the potentials d-connected to a set of variables
561  template < typename GUM_SCALAR >
563  Set< const Potential< GUM_SCALAR >* >& pot_list,
564  Set< const DiscreteVariable* >& kept_vars) {
567  __findRelevantPotentialsWithdSeparation2(pot_list, kept_vars);
568  break;
569 
571  __findRelevantPotentialsWithdSeparation(pot_list, kept_vars);
572  break;
573 
575  __findRelevantPotentialsWithdSeparation3(pot_list, kept_vars);
576  break;
577 
579  __findRelevantPotentialsGetAll(pot_list, kept_vars);
580  break;
581 
582  default: GUM_ERROR(FatalError, "not implemented yet");
583  }
584  }
585 
586 
587  // remove barren variables
588  template < typename GUM_SCALAR >
589  Set< const Potential< GUM_SCALAR >* >
591  __PotentialSet& pot_list, Set< const DiscreteVariable* >& del_vars) {
592  // remove from del_vars the variables that received some evidence:
593  // only those that did not received evidence can be barren variables
594  Set< const DiscreteVariable* > the_del_vars = del_vars;
595  for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
596  ++iter) {
597  NodeId id = this->BN().nodeId(**iter);
598  if (this->hardEvidenceNodes().exists(id)
599  || this->softEvidenceNodes().exists(id)) {
600  the_del_vars.erase(iter);
601  }
602  }
603 
604  // assign to each random variable the set of potentials that contain it
605  HashTable< const DiscreteVariable*, __PotentialSet > var2pots;
606  __PotentialSet empty_pot_set;
607  for (const auto pot : pot_list) {
608  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
609  for (const auto var : vars) {
610  if (the_del_vars.exists(var)) {
611  if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
612  var2pots[var].insert(pot);
613  }
614  }
615  }
616 
617  // each variable with only one potential is a barren variable
618  // assign to each potential with barren nodes its set of barren variables
619  HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
620  pot2barren_var;
621  Set< const DiscreteVariable* > empty_var_set;
622  for (auto elt : var2pots) {
623  if (elt.second.size() == 1) { // here we have a barren variable
624  const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
625  if (!pot2barren_var.exists(pot)) {
626  pot2barren_var.insert(pot, empty_var_set);
627  }
628  pot2barren_var[pot].insert(elt.first); // insert the barren variable
629  }
630  }
631 
632  // for each potential with barren variables, marginalize them.
633  // if the potential has only barren variables, simply remove them from the
634  // set of potentials, else just project the potential
635  MultiDimProjection< GUM_SCALAR, Potential > projector(VENewprojPotential);
636  __PotentialSet projected_pots;
637  for (auto elt : pot2barren_var) {
638  // remove the current potential from pot_list as, anyway, we will change
639  // it
640  const Potential< GUM_SCALAR >* pot = elt.first;
641  pot_list.erase(pot);
642 
643  // check whether we need to add a projected new potential or not (i.e.,
644  // whether there exist non-barren variables or not)
645  if (pot->variablesSequence().size() != elt.second.size()) {
646  auto new_pot = projector.project(*pot, elt.second);
647  pot_list.insert(new_pot);
648  projected_pots.insert(new_pot);
649  }
650  }
651 
652  return projected_pots;
653  }
654 
655 
 656  // performs the collect phase of Lazy Propagation
// Recursively collects, over the junction tree, the messages sent towards
// clique `id` by all its neighbours except `from`, then combines them with the
// potentials stored in clique `id` via __produceMessage.
// The pair layout follows __produceMessage: .first holds all potentials of the
// message, .second the temporaries created along the way (to be deleted later).
// NOTE(review): the qualified signature line (original line 660, presumably
// "VariableElimination< GUM_SCALAR >::__collectMessage(NodeId id, NodeId
// from) {") was elided by the doxygen extraction.
 657  template < typename GUM_SCALAR >
 658  std::pair< Set< const Potential< GUM_SCALAR >* >,
 659  Set< const Potential< GUM_SCALAR >* > >
 661  // collect messages from all the neighbors
 662  std::pair< __PotentialSet, __PotentialSet > collect_messages;
 663  for (const auto other : __JT->neighbours(id)) {
 664  if (other != from) {
 665  std::pair< __PotentialSet, __PotentialSet > message(
 666  __collectMessage(other, id));
 667  collect_messages.first += message.first;
 668  collect_messages.second += message.second;
 669  }
 670  }
 671 
 672  // combine the collect messages with those of id's clique
 673  return __produceMessage(id, from, std::move(collect_messages));
 674  }
675 
676 
677  // get the CPT + evidence of a node projected w.r.t. hard evidence
678  template < typename GUM_SCALAR >
679  std::pair< Set< const Potential< GUM_SCALAR >* >,
680  Set< const Potential< GUM_SCALAR >* > >
682  std::pair< __PotentialSet, __PotentialSet > res;
683  const auto& bn = this->BN();
684 
685  // get the CPT's of the node
686  // beware: all the potentials that are defined over some nodes
687  // including hard evidence must be projected so that these nodes are
688  // removed from the potential
689  // also beware that the CPT of a hard evidence node may be defined over
690  // parents that do not belong to __graph and that are not hard evidence.
691  // In this case, those parents have been removed by d-separation and it is
692  // easy to show that, in this case all the parents have been removed, so
693  // that the CPT does not need to be taken into account
694  const auto& evidence = this->evidence();
695  const auto& hard_evidence = this->hardEvidence();
696  if (__graph.exists(node) || this->hardEvidenceNodes().contains(node)) {
697  const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
698  const auto& variables = cpt.variablesSequence();
699 
700  // check if the parents of a hard evidence node do not belong to __graph
701  // and are not themselves hard evidence, discard the CPT, it is useless
702  // for inference
703  if (this->hardEvidenceNodes().contains(node)) {
704  for (const auto var : variables) {
705  NodeId xnode = bn.nodeId(*var);
706  if (!this->hardEvidenceNodes().contains(xnode)
707  && !__graph.existsNode(xnode))
708  return res;
709  }
710  }
711 
712  // get the list of nodes with hard evidence in cpt
713  NodeSet hard_nodes;
714  for (const auto var : variables) {
715  NodeId xnode = bn.nodeId(*var);
716  if (this->hardEvidenceNodes().contains(xnode)) hard_nodes.insert(xnode);
717  }
718 
719  // if hard_nodes contains hard evidence nodes, perform a projection
720  // and insert the result into the appropriate clique, else insert
721  // directly cpt into the clique
722  if (hard_nodes.empty()) {
723  res.first.insert(&cpt);
724  } else {
725  // marginalize out the hard evidence nodes: if the cpt is defined
726  // only over nodes that received hard evidence, do not consider it
727  // as a potential anymore
728  if (hard_nodes.size() != variables.size()) {
729  // perform the projection with a combine and project instance
730  Set< const DiscreteVariable* > hard_variables;
731  __PotentialSet marg_cpt_set{&cpt};
732  for (const auto xnode : hard_nodes) {
733  marg_cpt_set.insert(evidence[xnode]);
734  hard_variables.insert(&(bn.variable(xnode)));
735  }
736  // perform the combination of those potentials and their projection
737  MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
738  combine_and_project(__combination_op, VENewprojPotential);
739  __PotentialSet new_cpt_list =
740  combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
741 
742  // there should be only one potential in new_cpt_list
743  if (new_cpt_list.size() != 1) {
744  // remove the CPT created to avoid memory leaks
745  for (auto pot : new_cpt_list) {
746  if (!marg_cpt_set.contains(pot)) delete pot;
747  }
748  GUM_ERROR(FatalError,
749  "the projection of a potential containing "
750  << "hard evidence is empty!");
751  }
752  const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
753  res.first.insert(projected_cpt);
754  res.second.insert(projected_cpt);
755  }
756  }
757 
758  // if the node received some soft evidence, add it
759  if (evidence.exists(node) && !hard_evidence.exists(node)) {
760  res.first.insert(this->evidence()[node]);
761  }
762  }
763 
764  return res;
765  }
766 
767 
768  // creates the message sent by clique from_id to clique to_id
769  template < typename GUM_SCALAR >
770  std::pair< Set< const Potential< GUM_SCALAR >* >,
771  Set< const Potential< GUM_SCALAR >* > >
773  NodeId from_id,
774  NodeId to_id,
775  std::pair< Set< const Potential< GUM_SCALAR >* >,
776  Set< const Potential< GUM_SCALAR >* > >&& incoming_messages) {
777  // get the messages sent by adjacent nodes to from_id
778  std::pair< Set< const Potential< GUM_SCALAR >* >,
779  Set< const Potential< GUM_SCALAR >* > >
780  pot_list(std::move(incoming_messages));
781 
782  // get the potentials of the clique
783  for (const auto node : __clique_potentials[from_id]) {
784  auto new_pots = __NodePotentials(node);
785  pot_list.first += new_pots.first;
786  pot_list.second += new_pots.second;
787  }
788 
789  // if from_id = to_id: this is the endpoint of a collect
790  if (!__JT->existsEdge(from_id, to_id)) {
791  return pot_list;
792  } else {
793  // get the set of variables that need be removed from the potentials
794  const NodeSet& from_clique = __JT->clique(from_id);
795  const NodeSet& separator = __JT->separator(from_id, to_id);
796  Set< const DiscreteVariable* > del_vars(from_clique.size());
797  Set< const DiscreteVariable* > kept_vars(separator.size());
798  const auto& bn = this->BN();
799 
800  for (const auto node : from_clique) {
801  if (!separator.contains(node)) {
802  del_vars.insert(&(bn.variable(node)));
803  } else {
804  kept_vars.insert(&(bn.variable(node)));
805  }
806  }
807 
808  // pot_list now contains all the potentials to multiply and marginalize
809  // => combine the messages
810  __PotentialSet new_pot_list =
811  __marginalizeOut(pot_list.first, del_vars, kept_vars);
812 
813  // remove all the potentials that are equal to ones (as probability
814  // matrix multiplications are tensorial, such potentials are useless)
815  for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
816  ++iter) {
817  const auto pot = *iter;
818  if (pot->variablesSequence().size() == 1) {
819  bool is_all_ones = true;
820  for (Instantiation inst(*pot); !inst.end(); ++inst) {
821  if ((*pot)[inst] < __1_minus_epsilon) {
822  is_all_ones = false;
823  break;
824  }
825  }
826  if (is_all_ones) {
827  if (!pot_list.first.exists(pot)) delete pot;
828  new_pot_list.erase(iter);
829  continue;
830  }
831  }
832  }
833 
834  // remove the unnecessary temporary messages
835  for (auto iter = pot_list.second.beginSafe();
836  iter != pot_list.second.endSafe();
837  ++iter) {
838  if (!new_pot_list.contains(*iter)) {
839  delete *iter;
840  pot_list.second.erase(iter);
841  }
842  }
843 
844  // keep track of all the newly created potentials
845  for (const auto pot : new_pot_list) {
846  if (!pot_list.first.contains(pot)) { pot_list.second.insert(pot); }
847  }
848 
849  // return the new set of potentials
850  return std::pair< __PotentialSet, __PotentialSet >(
851  std::move(new_pot_list), std::move(pot_list.second));
852  }
853  }
854 
855 
856  // remove variables del_vars from the list of potentials pot_list
857  template < typename GUM_SCALAR >
858  Set< const Potential< GUM_SCALAR >* >
860  Set< const Potential< GUM_SCALAR >* > pot_list,
861  Set< const DiscreteVariable* >& del_vars,
862  Set< const DiscreteVariable* >& kept_vars) {
863  // use d-separation analysis to check which potentials shall be combined
864  __findRelevantPotentialsXX(pot_list, kept_vars);
865 
866  // remove the potentials corresponding to barren variables if we want
867  // to exploit barren nodes
868  __PotentialSet barren_projected_potentials;
870  barren_projected_potentials = __removeBarrenVariables(pot_list, del_vars);
871  }
872 
873  // create a combine and project operator that will perform the
874  // marginalization
875  MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
877  __PotentialSet new_pot_list =
878  combine_and_project.combineAndProject(pot_list, del_vars);
879 
880  // remove all the potentials that were created due to projections of
881  // barren nodes and that are not part of the new_pot_list: these
882  // potentials were just temporary potentials
883  for (auto iter = barren_projected_potentials.beginSafe();
884  iter != barren_projected_potentials.endSafe();
885  ++iter) {
886  if (!new_pot_list.exists(*iter)) delete *iter;
887  }
888 
889  // remove all the potentials that have no dimension
890  for (auto iter_pot = new_pot_list.beginSafe();
891  iter_pot != new_pot_list.endSafe();
892  ++iter_pot) {
893  if ((*iter_pot)->variablesSequence().size() == 0) {
894  // as we have already marginalized out variables that received evidence,
895  // it may be the case that, after combining and projecting, some
896  // potentials might be empty. In this case, we shall keep their
897  // constant and remove them from memory
898  // # TODO: keep the constants!
899  delete *iter_pot;
900  new_pot_list.erase(iter_pot);
901  }
902  }
903 
904  return new_pot_list;
905  }
906 
907 
908  // performs a whole inference
909  template < typename GUM_SCALAR >
911 
912 
914  template < typename GUM_SCALAR >
915  Potential< GUM_SCALAR >*
917  const auto& bn = this->BN();
918 
919  // hard evidence do not belong to the join tree
920  // # TODO: check for sets of inconsistent hard evidence
921  if (this->hardEvidenceNodes().contains(id)) {
922  return new Potential< GUM_SCALAR >(*(this->evidence()[id]));
923  }
924 
925  // if we still need to perform some inference task, do it
926  __createNewJT(NodeSet{id});
927  NodeId clique_of_id = __node_to_clique[id];
928  auto pot_list = __collectMessage(clique_of_id, clique_of_id);
929 
930  // get the set of variables that need be removed from the potentials
931  const NodeSet& nodes = __JT->clique(clique_of_id);
932  Set< const DiscreteVariable* > kept_vars{&(bn.variable(id))};
933  Set< const DiscreteVariable* > del_vars(nodes.size());
934  for (const auto node : nodes) {
935  if (node != id) del_vars.insert(&(bn.variable(node)));
936  }
937 
938  // pot_list now contains all the potentials to multiply and marginalize
939  // => combine the messages
940  __PotentialSet new_pot_list =
941  __marginalizeOut(pot_list.first, del_vars, kept_vars);
942  Potential< GUM_SCALAR >* joint = nullptr;
943 
944  if (new_pot_list.size() == 1) {
945  joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
946  // if joint already existed, create a copy, so that we can put it into
947  // the __target_posterior property
948  if (pot_list.first.exists(joint)) {
949  joint = new Potential< GUM_SCALAR >(*joint);
950  } else {
951  // remove the joint from new_pot_list so that it will not be
952  // removed just after the else block
953  new_pot_list.clear();
954  }
955  } else {
956  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
958  joint = fast_combination.combine(new_pot_list);
959  }
960 
961  // remove the potentials that were created in new_pot_list
962  for (auto pot : new_pot_list)
963  if (!pot_list.first.exists(pot)) delete pot;
964 
965  // remove all the temporary potentials created in pot_list
966  for (auto pot : pot_list.second)
967  delete pot;
968 
969  // check that the joint posterior is different from a 0 vector: this would
970  // indicate that some hard evidence are not compatible (their joint
971  // probability is equal to 0)
972  bool nonzero_found = false;
973  for (Instantiation inst(*joint); !inst.end(); ++inst) {
974  if ((*joint)[inst]) {
975  nonzero_found = true;
976  break;
977  }
978  }
979  if (!nonzero_found) {
980  // remove joint from memory to avoid memory leaks
981  delete joint;
982  GUM_ERROR(IncompatibleEvidence,
983  "some evidence entered into the Bayes "
984  "net are incompatible (their joint proba = 0)");
985  }
986 
987  return joint;
988  }
989 
990 
 992  template < typename GUM_SCALAR >
 993  const Potential< GUM_SCALAR >&
// Returns the normalized posterior of node `id`, caching it into
// __target_posterior (any previously cached posterior is freed first; the
// returned reference stays valid until the next posterior computation).
// NOTE(review): the qualified signature line (original line 994) was elided by
// the doxygen extraction.
 995  // compute the joint posterior and normalize
 996  auto joint = _unnormalizedJointPosterior(id);
 997  joint->normalize();
 998 
 999  if (__target_posterior != nullptr) delete __target_posterior;
 1000  __target_posterior = joint;
 1001 
 1002  return *joint;
 1003  }
1004 
1005 
1006  // returns the marginal a posteriori proba of a given node
1007  template < typename GUM_SCALAR >
1008  Potential< GUM_SCALAR >*
1010  const NodeSet& set) {
1011  // hard evidence do not belong to the join tree, so extract the nodes
1012  // from targets that are not hard evidence
1013  NodeSet targets = set, hard_ev_nodes;
1014  for (const auto node : this->hardEvidenceNodes()) {
1015  if (targets.contains(node)) {
1016  targets.erase(node);
1017  hard_ev_nodes.insert(node);
1018  }
1019  }
1020 
1021  // if all the nodes have received hard evidence, then compute the
1022  // joint posterior directly by multiplying the hard evidence potentials
1023  const auto& evidence = this->evidence();
1024  if (targets.empty()) {
1025  __PotentialSet pot_list;
1026  for (const auto node : set) {
1027  pot_list.insert(evidence[node]);
1028  }
1029  if (pot_list.size() == 1) {
1030  return new Potential< GUM_SCALAR >(**(pot_list.begin()));
1031  } else {
1032  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1034  return fast_combination.combine(pot_list);
1035  }
1036  }
1037 
1038  // if we still need to perform some inference task, do it
1039  __createNewJT(set);
1041 
1042  // get the set of variables that need be removed from the potentials
1043  const NodeSet& nodes = __JT->clique(__targets2clique);
1044  Set< const DiscreteVariable* > del_vars(nodes.size());
1045  Set< const DiscreteVariable* > kept_vars(targets.size());
1046  const auto& bn = this->BN();
1047  for (const auto node : nodes) {
1048  if (!targets.contains(node)) {
1049  del_vars.insert(&(bn.variable(node)));
1050  } else {
1051  kept_vars.insert(&(bn.variable(node)));
1052  }
1053  }
1054 
1055  // pot_list now contains all the potentials to multiply and marginalize
1056  // => combine the messages
1057  __PotentialSet new_pot_list =
1058  __marginalizeOut(pot_list.first, del_vars, kept_vars);
1059  Potential< GUM_SCALAR >* joint = nullptr;
1060 
1061  if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1062  joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
1063  // if pot already existed, create a copy, so that we can put it into
1064  // the __target_posteriors property
1065  if (pot_list.first.exists(joint)) {
1066  joint = new Potential< GUM_SCALAR >(*joint);
1067  } else {
1068  // remove the joint from new_pot_list so that it will not be
1069  // removed just after the next else block
1070  new_pot_list.clear();
1071  }
1072  } else {
1073  // combine all the potentials in new_pot_list with all the hard evidence
1074  // of the nodes in set
1075  __PotentialSet new_new_pot_list = new_pot_list;
1076  for (const auto node : hard_ev_nodes) {
1077  new_new_pot_list.insert(evidence[node]);
1078  }
1079  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1081  joint = fast_combination.combine(new_new_pot_list);
1082  }
1083 
1084  // remove the potentials that were created in new_pot_list
1085  for (auto pot : new_pot_list)
1086  if (!pot_list.first.exists(pot)) delete pot;
1087 
1088  // remove all the temporary potentials created in pot_list
1089  for (auto pot : pot_list.second)
1090  delete pot;
1091 
1092  // check that the joint posterior is different from a 0 vector: this would
1093  // indicate that some hard evidence are not compatible
1094  bool nonzero_found = false;
1095  for (Instantiation inst(*joint); !inst.end(); ++inst) {
1096  if ((*joint)[inst]) {
1097  nonzero_found = true;
1098  break;
1099  }
1100  }
1101  if (!nonzero_found) {
1102  // remove joint from memory to avoid memory leaks
1103  delete joint;
1104  GUM_ERROR(IncompatibleEvidence,
1105  "some evidence entered into the Bayes "
1106  "net are incompatible (their joint proba = 0)");
1107  }
1108 
1109  return joint;
1110  }
1111 
1112 
 1114  template < typename GUM_SCALAR >
 1115  const Potential< GUM_SCALAR >&
// Returns the normalized joint posterior over `set`, caching it into
// __target_posterior (any previously cached posterior is freed first; the
// returned reference stays valid until the next posterior computation).
// NOTE(review): the qualified signature line (original line 1116) was elided by
// the doxygen extraction.
 1117  // compute the joint posterior and normalize
 1118  auto joint = _unnormalizedJointPosterior(set);
 1119  joint->normalize();
 1120 
 1121  if (__target_posterior != nullptr) delete __target_posterior;
 1122  __target_posterior = joint;
 1123 
 1124  return *joint;
 1125  }
1126 
1127 
 1129  template < typename GUM_SCALAR >
 1130  const Potential< GUM_SCALAR >&
// Posterior of a subset (wanted_target) of a declared joint target: since
// VariableElimination rebuilds its junction tree per query, this is simply a
// direct joint query on wanted_target; declared_target is intentionally unused.
// NOTE(review): the qualified signature line (original line 1131) was elided by
// the doxygen extraction.
 1132  const NodeSet& wanted_target, const NodeSet& declared_target) {
 1133  return _jointPosterior(wanted_target);
 1134  }
1135 
1136 
1137 } /* namespace gum */
1138 
1139 #endif // DOXYGEN_SHOULD_SKIP_THIS
void __setProjectionFunction(Potential< GUM_SCALAR > *(*proj)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
sets the operator for performing the projections
void _onEvidenceAdded(NodeId id, bool isHardEvidence) final
fired after a new evidence is inserted
void(VariableElimination< GUM_SCALAR >::* __findRelevantPotentials)(Set< const Potential< GUM_SCALAR > * > &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
NodeId __targets2clique
indicate a clique that contains all the nodes of the target
VariableElimination(const IBayesNet< GUM_SCALAR > *BN, RelevantPotentialsFinderType relevant_type=RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS, FindBarrenNodesType=FindBarrenNodesType::FIND_BARREN_NODES)
default constructor
unsigned long Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:50
void _onAllTargetsErased() final
fired before a all single and joint_targets are removed
__PotentialSet __removeBarrenVariables(__PotentialSet &pot_list, Set< const DiscreteVariable * > &del_vars)
virtual void clear()
removes all the nodes and edges from the graph
Definition: undiGraph_inl.h:40
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
unsigned int NodeId
Type for node ids.
Definition: graphElements.h:97
static INLINE Potential< GUM_SCALAR > * VENewprojPotential(const Potential< GUM_SCALAR > &t1, const Set< const DiscreteVariable * > &del_vars)
const NodeProperty< const Potential< GUM_SCALAR > * > & evidence() const
returns the set of evidence
void __setCombinationFunction(Potential< GUM_SCALAR > *(*comb)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &))
sets the operator for performing the combinations
node_iterator_safe beginSafe() const
a begin iterator to parse the set of nodes contained in the NodeGraphPart
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
const Potential< GUM_SCALAR > & _posterior(NodeId id) final
returns the posterior of a given variable
JunctionTree * __JT
the junction tree used to answer the last inference query
RelevantPotentialsFinderType __find_relevant_potential_type
the type of relevant potential finding algorithm to be used
void _onEvidenceChanged(NodeId id, bool hasChangedSoftHard) final
fired after an evidence is changed, in particular when its status (soft/hard) changes ...
void __findRelevantPotentialsXX(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
void setFindBarrenNodesType(FindBarrenNodesType type)
sets how we determine barren nodes
d-separation analysis (as described in Koller & Friedman 2009)
static void relevantPotentials(const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
update a set of potentials, keeping only those d-connected with query variables given evidence ...
Definition: BayesBall_tpl.h:32
std::pair< __PotentialSet, __PotentialSet > __collectMessage(NodeId id, NodeId from)
actually perform the collect phase
Set< const Potential< GUM_SCALAR > * > __PotentialSet
void _makeInference() final
called when the inference has to be performed effectively
The BayesBall algorithm (as described by Shachter).
Potential< GUM_SCALAR > * __target_posterior
the posterior computed during the last inference
Triangulation * __triangulation
the triangulation class creating the junction tree used for inference
void _updateOutdatedBNPotentials() final
prepares inference when the latter is in OutdatedBNPotentials state
void erase(const Key &k)
Erases an element from the set.
Definition: set_tpl.h:656
bool contains(const Key &k) const
Indicates whether a given elements belong to the set.
Definition: set_tpl.h:581
void _onEvidenceErased(NodeId id, bool isHardEvidence) final
fired before an evidence is removed
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
const Potential< GUM_SCALAR > & _jointPosterior(const NodeSet &set) final
returns the posterior of a declared target set
void __findRelevantPotentialsGetAll(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
const NodeProperty< Idx > & hardEvidence() const
indicate for each node with hard evidence which value it took
void __findRelevantPotentialsWithdSeparation(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
RelevantPotentialsFinderType
type of algorithm for determining the relevant potentials for combinations using some d-separation an...
Implementation of a variable elimination algorithm for inference in Bayesian Networks.
const JunctionTree * junctionTree(NodeId id)
returns the join tree used for compute the posterior of node id
void _onAllMarginalTargetsErased() final
fired before a all the single targets are removed
void __findRelevantPotentialsWithdSeparation3(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
Potential< GUM_SCALAR > *(* __projection_op)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &)
the operator for performing the projections
virtual void eraseNode(const NodeId id)
remove a node and its adjacent edges from the graph
Definition: undiGraph_inl.h:55
Potential< GUM_SCALAR > *(* __combination_op)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &)
the operator for performing the combinations
FindBarrenNodesType
type of algorithm to determine barren nodes
UndiGraph __graph
the undigraph extracted from the BN and used to construct the join tree
virtual const NodeProperty< Size > & domainSizes() const final
get the domain sizes of the random variables of the BN
void _onAllEvidenceErased(bool contains_hard_evidence) final
fired before all the evidence are erased
void __findRelevantPotentialsWithdSeparation2(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
virtual NodeId createdJunctionTreeClique(const NodeId id)=0
returns the Id of the clique created by the elimination of a given node during the triangulation proc...
Header files of gum::Instantiation.
HashTable< NodeId, NodeSet > __clique_potentials
for each BN node, indicate in which clique its CPT will be stored
void _onJointTargetErased(const NodeSet &set) final
fired before a joint target is removed
void _onJointTargetAdded(const NodeSet &set) final
fired after a new joint target is inserted
FindBarrenNodesType __barren_nodes_type
the type of barren nodes computation we wish
void _onAllJointTargetsErased() final
fired before a all the joint targets are removed
void setTriangulation(const Triangulation &new_triangulation)
use a new triangulation algorithm
void setRelevantPotentialsFinderType(RelevantPotentialsFinderType type)
sets how we determine the relevant potentials to combine
HashTable< NodeId, NodeId > __node_to_clique
for each node of __graph (~ in the Bayes net), associate an ID in the JT
const NodeSet & hardEvidenceNodes() const
returns the set of nodes with hard evidence
Potential< GUM_SCALAR > * _unnormalizedJointPosterior(NodeId id) final
returns a fresh potential equal to P(argument,evidence)
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
Definition: cliqueGraph.h:299
void _onMarginalTargetAdded(NodeId id) final
fired after a new single target is inserted
__PotentialSet __marginalizeOut(__PotentialSet pot_list, Set< const DiscreteVariable * > &del_vars, Set< const DiscreteVariable * > &kept_vars)
removes variables del_vars from a list of potentials and returns the resulting list ...
void clear()
Removes all the elements in the hash table.
std::pair< __PotentialSet, __PotentialSet > __NodePotentials(NodeId node)
returns the CPT + evidence of a node projected w.r.t. hard evidence
virtual const NodeSet & targets() const noexcept final
returns the list of marginal targets
void _updateOutdatedBNStructure() final
prepares inference when the latter is in OutdatedBNStructure state
void _onAllMarginalTargetsAdded() final
fired after all the nodes of the BN are added as single targets
An algorithm for converting a join tree into a binary join tree.
const NodeSet & softEvidenceNodes() const
returns the set of nodes with soft evidence
std::pair< __PotentialSet, __PotentialSet > __produceMessage(NodeId from_id, NodeId to_id, std::pair< __PotentialSet, __PotentialSet > &&incoming_messages)
creates the message sent by clique from_id to clique to_id
virtual Triangulation * newFactory() const =0
returns a fresh triangulation of the same type as the current object but with an empty graph ...
const node_iterator_safe & endSafe() const noexcept
the end iterator to parse the set of nodes contained in the NodeGraphPart
void __createNewJT(const NodeSet &targets)
create a new junction tree as well as its related data structures
~VariableElimination() final
destructor
virtual const CliqueGraph & junctionTree()=0
returns a compatible junction tree
const GUM_SCALAR __1_minus_epsilon
for comparisons with 1 - epsilon
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
virtual const std::vector< NodeId > & eliminationOrder()=0
returns an elimination ordering compatible with the triangulated graph
virtual const IBayesNet< GUM_SCALAR > & BN() const final
Returns a constant reference over the IBayesNet referenced by this class.
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:613
#define GUM_ERROR(type, msg)
Definition: exceptions.h:66
void _onMarginalTargetErased(NodeId id) final
fired before a single target is removed
virtual void setGraph(const UndiGraph *graph, const NodeProperty< Size > *domsizes)=0
initialize the triangulation data structures for a new graph
static void requisiteNodes(const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.
Definition: BayesBall.cpp:33