aGrUM  0.16.0
variableElimination_tpl.h
Go to the documentation of this file.
1 
31 #ifndef DOXYGEN_SHOULD_SKIP_THIS
32 
34 
41 
42 
43 namespace gum {
44 
45 
 46  // default constructor
 // NOTE(review): Doxygen-extracted listing — the signature line (original
 // line 48, `VariableElimination< GUM_SCALAR >::VariableElimination(`) was
 // dropped by the extraction. Code below is left untouched.
 // Builds the inference engine on BN, records the relevant-potentials and
 // barren-nodes strategies, and allocates a default triangulation that the
 // destructor owns and deletes.
 47  template < typename GUM_SCALAR >
 49  const IBayesNet< GUM_SCALAR >* BN,
 50  RelevantPotentialsFinderType relevant_type,
 51  FindBarrenNodesType barren_type) :
 52  JointTargetedInference< GUM_SCALAR >(BN) {
 53  // sets the relevant potential and the barren nodes finding algorithm
 54  setRelevantPotentialsFinderType(relevant_type);
 55  setFindBarrenNodesType(barren_type);
 56 
 57  // create a default triangulation (the user can change it afterwards)
 58  __triangulation = new DefaultTriangulation;
 59 
 60  // for debugging purposes
 61  GUM_CONSTRUCTOR(VariableElimination);
 62  }
63 
64 
 65  // destructor
 // NOTE(review): the signature line (original line 67) was lost in
 // extraction. Releases the three heap objects this class owns: the
 // junction tree (__JT), the triangulation strategy (__triangulation,
 // always non-null after construction), and the cached posterior.
 66  template < typename GUM_SCALAR >
 68  // remove the junction tree and the triangulation algorithm
 69  if (__JT != nullptr) delete __JT;
 70  delete __triangulation;
 71  if (__target_posterior != nullptr) delete __target_posterior;
 72 
 73  // for debugging purposes
 74  GUM_DESTRUCTOR(VariableElimination);
 75  }
76 
77 
 79  template < typename GUM_SCALAR >
 // NOTE(review): signature line (original line 80, `setTriangulation(`)
 // lost in extraction. Replaces the owned triangulation with a freshly
 // cloned copy of the caller's algorithm (newFactory()), so the caller
 // keeps ownership of `new_triangulation` itself.
 81  const Triangulation& new_triangulation) {
 82  delete __triangulation;
 83  __triangulation = new_triangulation.newFactory();
 84  }
85 
86 
 88  template < typename GUM_SCALAR >
 89  INLINE const JunctionTree*
 // NOTE(review): the rest of the signature (original lines 90-91) was lost
 // in extraction. Returns the junction tree built by the last query;
 // may be nullptr before any inference has been run — TODO confirm.
 92 
 93  return __JT;
 94  }
95 
96 
 98  template < typename GUM_SCALAR >
 // NOTE(review): signature line (original line 99) lost in extraction.
 // Stores the function pointer used to project (marginalize) potentials.
 100  Potential< GUM_SCALAR >* (*proj)(const Potential< GUM_SCALAR >&,
 101  const Set< const DiscreteVariable* >&)) {
 102  __projection_op = proj;
 103  }
104 
105 
 107  template < typename GUM_SCALAR >
 // NOTE(review): signature line (original line 108) lost in extraction.
 // Stores the function pointer used to combine (multiply) two potentials.
 109  Potential< GUM_SCALAR >* (*comb)(const Potential< GUM_SCALAR >&,
 110  const Potential< GUM_SCALAR >&)) {
 111  __combination_op = comb;
 112  }
113 
114 
 116  template < typename GUM_SCALAR >
 // Selects the algorithm used to prune potentials that are d-separated
 // from the queried variables.
 // NOTE(review): this listing is missing the signature (original lines
 // 117-118), the four `case` labels (121-123, 126-128, 131-133, 136-138,
 // each presumably assigning __findRelevantPotentials's target) and the
 // final `__find_relevant_potential_type = type;` assignment (147) —
 // only the `break`s and the `default` clause survived extraction.
 119  if (type != __find_relevant_potential_type) {
 120  switch (type) {
 124  break;
 125 
 129  break;
 130 
 134  break;
 135 
 139  break;
 140 
 141  default:
 142  GUM_ERROR(InvalidArgument,
 143  "setRelevantPotentialsFinderType for type "
 144  << (unsigned int)type << " is not implemented yet");
 145  }
 146 
 148  }
 149  }
150 
151 
 153  template < typename GUM_SCALAR >
 // Selects whether barren nodes are detected and discarded before
 // inference. Rejects (InvalidArgument) any type not handled by the
 // switch; otherwise records the new type.
 // NOTE(review): the accepting `case` labels (original lines 161-162)
 // were lost in extraction.
 155  FindBarrenNodesType type) {
 156  if (type != __barren_nodes_type) {
 157  // WARNING: if a new type is added here, method __createJT should
 158  // certainly
 159  // be updated as well, in particular its step 2.
 160  switch (type) {
 163 
 164  default:
 165  GUM_ERROR(InvalidArgument,
 166  "setFindBarrenNodesType for type "
 167  << (unsigned int)type << " is not implemented yet");
 168  }
 169 
 170  __barren_nodes_type = type;
 171  }
 172  }
173 
174 
 // NOTE(review): the following run of definitions are the inference
 // life-cycle callbacks (_onEvidenceAdded/_onEvidenceErased/_onAllEvidence
 // Erased/_onTargetAdded/_onTargetErased/_onModelChanged/..., judging by
 // the `bool`/`const IBayesNet*` parameters and the Doxygen tooltips at
 // the foot of this page — confirm against the class header). All bodies
 // are intentionally empty: VariableElimination rebuilds its junction
 // tree from scratch at each query (__createNewJT), so incremental
 // notifications carry no state to update. Most signature lines were lost
 // in extraction.
 176  template < typename GUM_SCALAR >
 178  bool) {}
 179 
 180 
 182  template < typename GUM_SCALAR >
 184  bool) {}
 185 
 186 
 188  template < typename GUM_SCALAR >
 190 
 191 
 193  template < typename GUM_SCALAR >
 195  bool) {}
 196 
 197 
 199  template < typename GUM_SCALAR >
 200  INLINE void
 202 
 203 
 205  template < typename GUM_SCALAR >
 206  INLINE void
 208 
 210  template < typename GUM_SCALAR >
 212  const IBayesNet< GUM_SCALAR >* bn) {}
 213 
 215  template < typename GUM_SCALAR >
 216  INLINE void
 218 
 219 
 221  template < typename GUM_SCALAR >
 222  INLINE void
 224 
 225 
 227  template < typename GUM_SCALAR >
 229 
 230 
 232  template < typename GUM_SCALAR >
 234 
 235 
 237  template < typename GUM_SCALAR >
 239 
 240 
 242  template < typename GUM_SCALAR >
 244 
245 
 247  template < typename GUM_SCALAR >
 // Builds a fresh junction tree tailored to the current query (`targets`
 // is the NodeSet parameter on the lost signature line, original 248).
 // Also fills __node_to_clique, __clique_potentials and __targets2clique.
 // NOTE(review): several structural lines were lost in extraction: the
 // `if` guarding step 2 (original 274), the switch over
 // __find_relevant_potential_type and its case labels (301-303, 312, 322),
 // the declarations feeding JT_elim_order/__node_to_clique (384, 390) and
 // the `__targets2clique =` assignment (468).
 249  // to create the JT, we first create the moral graph of the BN in the
 250  // following way in order to take into account the barren nodes and the
 251  // nodes that received evidence:
 252  // 1/ we create an undirected graph containing only the nodes and no edge
 253  // 2/ if we take into account barren nodes, remove them from the graph
 254  // 3/ if we take d-separation into account, remove the d-separated nodes
 255  // 4/ add edges so that each node and its parents in the BN form a clique
 256  // 5/ add edges so that the targets form a clique of the moral graph
 257  // 6/ remove the nodes that received hard evidence (by step 4/, their
 258  // parents are linked by edges, which is necessary for inference)
 259  //
 260  // At the end of step 6/, we have our moral graph and we can triangulate it
 261  // to get the new junction tree
 262 
 263  // 1/ create an undirected graph containing only the nodes and no edge
 264  const auto& bn = this->BN();
 265  __graph.clear();
 266  for (auto node : bn.dag())
 267  __graph.addNodeWithId(node);
 268 
 269  // 2/ if we wish to exploit barren nodes, we shall remove them from the BN
 270  // to do so: we identify all the nodes that are not targets and have
 271  // received no evidence and such that their descendants are neither targets
 272  // nor evidence nodes. Such nodes can be safely discarded from the BN
 273  // without altering the inference output
 275  // check that all the nodes are not targets, otherwise, there is no
 276  // barren node
 277  if (targets.size() != bn.size()) {
 278  BarrenNodesFinder finder(&(bn.dag()));
 279  finder.setTargets(&targets);
 280 
 281  NodeSet evidence_nodes;
 282  for (const auto& pair : this->evidence()) {
 283  evidence_nodes.insert(pair.first);
 284  }
 285  finder.setEvidence(&evidence_nodes);
 286 
 287  NodeSet barren_nodes = finder.barrenNodes();
 288 
 289  // remove the barren nodes from the moral graph
 290  for (const auto node : barren_nodes) {
 291  __graph.eraseNode(node);
 292  }
 293  }
 294  }
 295 
 296  // 3/ if we wish to exploit d-separation, remove all the nodes that are
 297  // d-separated from our targets
 298  {
 299  NodeSet requisite_nodes;
 300  bool dsep_analysis = false;
 304  BayesBall::requisiteNodes(bn.dag(),
 305  targets,
 306  this->hardEvidenceNodes(),
 307  this->softEvidenceNodes(),
 308  requisite_nodes);
 309  dsep_analysis = true;
 310  } break;
 311 
 313  dSeparation dsep;
 314  dsep.requisiteNodes(bn.dag(),
 315  targets,
 316  this->hardEvidenceNodes(),
 317  this->softEvidenceNodes(),
 318  requisite_nodes);
 319  dsep_analysis = true;
 320  } break;
 321 
 323 
 324  default: GUM_ERROR(FatalError, "not implemented yet");
 325  }
 326 
 327  // remove all the nodes that are not requisite
 328  if (dsep_analysis) {
 329  for (auto iter = __graph.beginSafe(); iter != __graph.endSafe(); ++iter) {
 330  if (!requisite_nodes.contains(*iter)
 331  && !this->hardEvidenceNodes().contains(*iter)) {
 332  __graph.eraseNode(*iter);
 333  }
 334  }
 335  }
 336  }
 337 
 338  // 4/ add edges so that each node and its parents in the BN form a clique
 339  for (const auto node : __graph) {
 340  const NodeSet& parents = bn.parents(node);
 341  for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
 342  // before adding an edge between node and its parent, check that the
 343  // parent belong to the graph. Actually, when d-separated nodes are
 344  // removed, it may be the case that the parents of hard evidence nodes
 345  // are removed. But the latter still exist in the graph.
 346  if (__graph.existsNode(*iter1)) __graph.addEdge(*iter1, node);
 347 
 348  auto iter2 = iter1;
 349  for (++iter2; iter2 != parents.cend(); ++iter2) {
 350  // before adding an edge, check that both extremities belong to
 351  // the graph. Actually, when d-separated nodes are removed, it may
 352  // be the case that the parents of hard evidence nodes are removed.
 353  // But the latter still exist in the graph.
 354  if (__graph.existsNode(*iter1) && __graph.existsNode(*iter2))
 355  __graph.addEdge(*iter1, *iter2);
 356  }
 357  }
 358  }
 359 
 360  // 5/ if targets contains several nodes, we shall add new edges into the
 361  // moral graph in order to ensure that there exists a clique containing
 362  // their joint distribution
 363  for (auto iter1 = targets.cbegin(); iter1 != targets.cend(); ++iter1) {
 364  auto iter2 = iter1;
 365  for (++iter2; iter2 != targets.cend(); ++iter2) {
 366  __graph.addEdge(*iter1, *iter2);
 367  }
 368  }
 369 
 370  // 6/ remove all the nodes that received hard evidence
 371  for (const auto node : this->hardEvidenceNodes()) {
 372  __graph.eraseNode(node);
 373  }
 374 
 375 
 376  // now, we can compute the new junction tree.
 377  if (__JT != nullptr) delete __JT;
 378  __triangulation->setGraph(&__graph, &(this->domainSizes()));
 379  const JunctionTree& triang_jt = __triangulation->junctionTree();
 380  __JT = new CliqueGraph(triang_jt);
 381 
 382  // indicate, for each node of the moral graph a clique in __JT that can
 383  // contain its conditional probability table
 385  __clique_potentials.clear();
 386  NodeSet emptyset;
 387  for (auto clique : *__JT)
 388  __clique_potentials.insert(clique, emptyset);
 389  const std::vector< NodeId >& JT_elim_order =
 391  NodeProperty< Size > elim_order(Size(JT_elim_order.size()));
 392  for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
 393  ++i)
 394  elim_order.insert(JT_elim_order[i], NodeId(i));
 395  const DAG& dag = bn.dag();
 396  for (const auto node : __graph) {
 397  // get the variables in the potential of node (and its parents)
 398  NodeId first_eliminated_node = node;
 399  Size elim_number = elim_order[first_eliminated_node];
 400 
 401  for (const auto parent : dag.parents(node)) {
 402  if (__graph.existsNode(parent) && (elim_order[parent] < elim_number)) {
 403  elim_number = elim_order[parent];
 404  first_eliminated_node = parent;
 405  }
 406  }
 407 
 408  // first_eliminated_node contains the first var (node or one of its
 409  // parents) eliminated => the clique created during its elimination
 410  // contains node and all of its parents => it can contain the potential
 411  // assigned to the node in the BN
 412  NodeId clique =
 413  __triangulation->createdJunctionTreeClique(first_eliminated_node);
 414  __node_to_clique.insert(node, clique);
 415  __clique_potentials[clique].insert(node);
 416  }
 417 
 418  // do the same for the nodes that received evidence. Here, we only store
 419  // the nodes whose at least one parent belongs to __graph (otherwise
 420  // their CPT is just a constant real number).
 421  for (const auto node : this->hardEvidenceNodes()) {
 422  // get the set of parents of the node that belong to __graph
 423  NodeSet pars(dag.parents(node).size());
 424  for (const auto par : dag.parents(node))
 425  if (__graph.exists(par)) pars.insert(par);
 426 
 427  if (!pars.empty()) {
 428  NodeId first_eliminated_node = *(pars.begin());
 429  Size elim_number = elim_order[first_eliminated_node];
 430 
 431  for (const auto parent : pars) {
 432  if (elim_order[parent] < elim_number) {
 433  elim_number = elim_order[parent];
 434  first_eliminated_node = parent;
 435  }
 436  }
 437 
 438  // first_eliminated_node contains the first var (node or one of its
 439  // parents) eliminated => the clique created during its elimination
 440  // contains node and all of its parents => it can contain the potential
 441  // assigned to the node in the BN
 442  NodeId clique =
 443  __triangulation->createdJunctionTreeClique(first_eliminated_node);
 444  __node_to_clique.insert(node, clique);
 445  __clique_potentials[clique].insert(node);
 446  }
 447  }
 448 
 449 
 450  // indicate a clique that contains all the nodes of targets
 451  __targets2clique = std::numeric_limits< NodeId >::max();
 452  {
 453  // remove from set all the nodes that received hard evidence (since they
 454  // do not belong to the join tree)
 455  NodeSet nodeset = targets;
 456  for (const auto node : this->hardEvidenceNodes())
 457  if (nodeset.contains(node)) nodeset.erase(node);
 458 
 459  if (!nodeset.empty()) {
 460  NodeId first_eliminated_node = *(nodeset.begin());
 461  Size elim_number = elim_order[first_eliminated_node];
 462  for (const auto node : nodeset) {
 463  if (elim_order[node] < elim_number) {
 464  elim_number = elim_order[node];
 465  first_eliminated_node = node;
 466  }
 467  }
 469  __triangulation->createdJunctionTreeClique(first_eliminated_node);
 470  }
 471  }
 472  }
473 
474 
 // NOTE(review): two empty overrides whose signatures (original lines
 // 477 and 483 — presumably _updateOutdatedBNStructure_/-Potentials_ per
 // the base class contract; confirm against the header) were lost in
 // extraction. Empty because the junction tree is rebuilt per query.
 476  template < typename GUM_SCALAR >
 478 
 479 
 482  template < typename GUM_SCALAR >
 484 
485 
 486  // find the potentials d-connected to a set of variables
 // NOTE(review): signature line (original 488) lost; per the tooltip text
 // this is the "GetAll" strategy. Body is deliberately empty: keeping
 // every potential means doing no relevance pruning at all.
 487  template < typename GUM_SCALAR >
 489  Set< const Potential< GUM_SCALAR >* >& pot_list,
 490  Set< const DiscreteVariable* >& kept_vars) {}
491 
492 
 493  // find the potentials d-connected to a set of variables
 // Maps kept_vars to node ids, computes the requisite nodes via
 // BayesBall::requisiteNodes, then erases from pot_list every potential
 // none of whose variables is requisite (safe-iterator erase-while-
 // iterating). Signature line (original 495) lost in extraction.
 494  template < typename GUM_SCALAR >
 496  Set< const Potential< GUM_SCALAR >* >& pot_list,
 497  Set< const DiscreteVariable* >& kept_vars) {
 498  // find the node ids of the kept variables
 499  NodeSet kept_ids;
 500  const auto& bn = this->BN();
 501  for (const auto var : kept_vars) {
 502  kept_ids.insert(bn.nodeId(*var));
 503  }
 504 
 505  // determine the set of potentials d-connected with the kept variables
 506  NodeSet requisite_nodes;
 507  BayesBall::requisiteNodes(bn.dag(),
 508  kept_ids,
 509  this->hardEvidenceNodes(),
 510  this->softEvidenceNodes(),
 511  requisite_nodes);
 512  for (auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
 513  const Sequence< const DiscreteVariable* >& vars =
 514  (**iter).variablesSequence();
 515  bool found = false;
 516  for (auto var : vars) {
 517  if (requisite_nodes.exists(bn.nodeId(*var))) {
 518  found = true;
 519  break;
 520  }
 521  }
 522 
 523  if (!found) { pot_list.erase(iter); }
 524  }
 525  }
526 
527 
 528  // find the potentials d-connected to a set of variables
 // Same contract as the variant above, but delegates the pruning in one
 // call. NOTE(review): the signature (original 530) and the callee line
 // (original 541 — presumably `BayesBall::relevantPotentials(bn,`; the
 // surviving arguments match that API) were lost in extraction.
 529  template < typename GUM_SCALAR >
 531  Set< const Potential< GUM_SCALAR >* >& pot_list,
 532  Set< const DiscreteVariable* >& kept_vars) {
 533  // find the node ids of the kept variables
 534  NodeSet kept_ids;
 535  const auto& bn = this->BN();
 536  for (const auto var : kept_vars) {
 537  kept_ids.insert(bn.nodeId(*var));
 538  }
 539 
 540  // determine the set of potentials d-connected with the kept variables
 542  kept_ids,
 543  this->hardEvidenceNodes(),
 544  this->softEvidenceNodes(),
 545  pot_list);
 546  }
547 
548 
 549  // find the potentials d-connected to a set of variables
 // Same contract as the other __findRelevantPotentials* variants, using
 // the dSeparation analyzer instead of BayesBall. Signature line
 // (original 551) lost in extraction.
 550  template < typename GUM_SCALAR >
 552  Set< const Potential< GUM_SCALAR >* >& pot_list,
 553  Set< const DiscreteVariable* >& kept_vars) {
 554  // find the node ids of the kept variables
 555  NodeSet kept_ids;
 556  const auto& bn = this->BN();
 557  for (const auto var : kept_vars) {
 558  kept_ids.insert(bn.nodeId(*var));
 559  }
 560 
 561  // determine the set of potentials d-connected with the kept variables
 562  dSeparation dsep;
 563  dsep.relevantPotentials(bn,
 564  kept_ids,
 565  this->hardEvidenceNodes(),
 566  this->softEvidenceNodes(),
 567  pot_list);
 568  }
569 
570 
 571  // find the potentials d-connected to a set of variables
 // Dispatcher: forwards to the pruning strategy selected by
 // setRelevantPotentialsFinderType. NOTE(review): the signature (original
 // 573), the `switch (__find_relevant_potential_type) {` line (576-577)
 // and the four `case` labels (581, 585, 589) were lost in extraction —
 // only the call + break pairs and the default clause survived.
 572  template < typename GUM_SCALAR >
 574  Set< const Potential< GUM_SCALAR >* >& pot_list,
 575  Set< const DiscreteVariable* >& kept_vars) {
 578  __findRelevantPotentialsWithdSeparation2(pot_list, kept_vars);
 579  break;
 580 
 582  __findRelevantPotentialsWithdSeparation(pot_list, kept_vars);
 583  break;
 584 
 586  __findRelevantPotentialsWithdSeparation3(pot_list, kept_vars);
 587  break;
 588 
 590  __findRelevantPotentialsGetAll(pot_list, kept_vars);
 591  break;
 592 
 593  default: GUM_ERROR(FatalError, "not implemented yet");
 594  }
 595  }
596 
597 
 598  // remove barren variables
 // Projects out of pot_list every variable of del_vars that appears in a
 // single potential and received no evidence (a "barren" variable: it can
 // be summed out locally without affecting the query). Returns the set of
 // freshly allocated projected potentials so the caller can free the ones
 // it does not keep. Name line (original 601) lost in extraction.
 599  template < typename GUM_SCALAR >
 600  Set< const Potential< GUM_SCALAR >* >
 602  __PotentialSet& pot_list, Set< const DiscreteVariable* >& del_vars) {
 603  // remove from del_vars the variables that received some evidence:
 604  // only those that did not receive evidence can be barren variables
 605  Set< const DiscreteVariable* > the_del_vars = del_vars;
 606  for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
 607  ++iter) {
 608  NodeId id = this->BN().nodeId(**iter);
 609  if (this->hardEvidenceNodes().exists(id)
 610  || this->softEvidenceNodes().exists(id)) {
 611  the_del_vars.erase(iter);
 612  }
 613  }
 614 
 615  // assign to each random variable the set of potentials that contain it
 616  HashTable< const DiscreteVariable*, __PotentialSet > var2pots;
 617  __PotentialSet empty_pot_set;
 618  for (const auto pot : pot_list) {
 619  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
 620  for (const auto var : vars) {
 621  if (the_del_vars.exists(var)) {
 622  if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
 623  var2pots[var].insert(pot);
 624  }
 625  }
 626  }
 627 
 628  // each variable with only one potential is a barren variable
 629  // assign to each potential with barren nodes its set of barren variables
 630  HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
 631  pot2barren_var;
 632  Set< const DiscreteVariable* > empty_var_set;
 633  for (auto elt : var2pots) {
 634  if (elt.second.size() == 1) { // here we have a barren variable
 635  const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
 636  if (!pot2barren_var.exists(pot)) {
 637  pot2barren_var.insert(pot, empty_var_set);
 638  }
 639  pot2barren_var[pot].insert(elt.first); // insert the barren variable
 640  }
 641  }
 642 
 643  // for each potential with barren variables, marginalize them.
 644  // if the potential has only barren variables, simply remove them from the
 645  // set of potentials, else just project the potential
 646  MultiDimProjection< GUM_SCALAR, Potential > projector(VENewprojPotential);
 647  __PotentialSet projected_pots;
 648  for (auto elt : pot2barren_var) {
 649  // remove the current potential from pot_list as, anyway, we will change
 650  // it
 651  const Potential< GUM_SCALAR >* pot = elt.first;
 652  pot_list.erase(pot);
 653 
 654  // check whether we need to add a projected new potential or not (i.e.,
 655  // whether there exist non-barren variables or not)
 656  if (pot->variablesSequence().size() != elt.second.size()) {
 657  auto new_pot = projector.project(*pot, elt.second);
 658  pot_list.insert(new_pot);
 659  projected_pots.insert(new_pot);
 660  }
 661  }
 662 
 663  return projected_pots;
 664  }
665 
666 
 667  // performs the collect phase of Lazy Propagation
 // Recursive collect over the junction tree: gathers the messages of all
 // neighbours of clique `id` except `from`, then combines them with id's
 // own potentials via __produceMessage. The returned pair is
 // (all potentials, temporary potentials to be freed by the caller).
 // NOTE(review): the parameter list (original 671, `(NodeId id, NodeId
 // from)` judging by the body) was lost in extraction.
 668  template < typename GUM_SCALAR >
 669  std::pair< Set< const Potential< GUM_SCALAR >* >,
 670  Set< const Potential< GUM_SCALAR >* > >
 672  // collect messages from all the neighbors
 673  std::pair< __PotentialSet, __PotentialSet > collect_messages;
 674  for (const auto other : __JT->neighbours(id)) {
 675  if (other != from) {
 676  std::pair< __PotentialSet, __PotentialSet > message(
 677  __collectMessage(other, id));
 678  collect_messages.first += message.first;
 679  collect_messages.second += message.second;
 680  }
 681  }
 682 
 683  // combine the collect messages with those of id's clique
 684  return __produceMessage(id, from, std::move(collect_messages));
 685  }
686 
687 
 688  // get the CPT + evidence of a node projected w.r.t. hard evidence
 // Returns (potentials to use, temporary potentials the caller must
 // free). Hard-evidence variables are summed out of the CPT up front so
 // they never enter message passing. Signature line (original 692,
 // taking the `node` id) lost in extraction.
 689  template < typename GUM_SCALAR >
 690  std::pair< Set< const Potential< GUM_SCALAR >* >,
 691  Set< const Potential< GUM_SCALAR >* > >
 693  std::pair< __PotentialSet, __PotentialSet > res;
 694  const auto& bn = this->BN();
 695 
 696  // get the CPT's of the node
 697  // beware: all the potentials that are defined over some nodes
 698  // including hard evidence must be projected so that these nodes are
 699  // removed from the potential
 700  // also beware that the CPT of a hard evidence node may be defined over
 701  // parents that do not belong to __graph and that are not hard evidence.
 702  // In this case, those parents have been removed by d-separation and it is
 703  // easy to show that, in this case all the parents have been removed, so
 704  // that the CPT does not need to be taken into account
 705  const auto& evidence = this->evidence();
 706  const auto& hard_evidence = this->hardEvidence();
 707  if (__graph.exists(node) || this->hardEvidenceNodes().contains(node)) {
 708  const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
 709  const auto& variables = cpt.variablesSequence();
 710 
 711  // check if the parents of a hard evidence node do not belong to __graph
 712  // and are not themselves hard evidence, discard the CPT, it is useless
 713  // for inference
 714  if (this->hardEvidenceNodes().contains(node)) {
 715  for (const auto var : variables) {
 716  NodeId xnode = bn.nodeId(*var);
 717  if (!this->hardEvidenceNodes().contains(xnode)
 718  && !__graph.existsNode(xnode))
 719  return res;
 720  }
 721  }
 722 
 723  // get the list of nodes with hard evidence in cpt
 724  NodeSet hard_nodes;
 725  for (const auto var : variables) {
 726  NodeId xnode = bn.nodeId(*var);
 727  if (this->hardEvidenceNodes().contains(xnode)) hard_nodes.insert(xnode);
 728  }
 729 
 730  // if hard_nodes contains hard evidence nodes, perform a projection
 731  // and insert the result into the appropriate clique, else insert
 732  // directly cpt into the clique
 733  if (hard_nodes.empty()) {
 734  res.first.insert(&cpt);
 735  } else {
 736  // marginalize out the hard evidence nodes: if the cpt is defined
 737  // only over nodes that received hard evidence, do not consider it
 738  // as a potential anymore
 739  if (hard_nodes.size() != variables.size()) {
 740  // perform the projection with a combine and project instance
 741  Set< const DiscreteVariable* > hard_variables;
 742  __PotentialSet marg_cpt_set{&cpt};
 743  for (const auto xnode : hard_nodes) {
 744  marg_cpt_set.insert(evidence[xnode]);
 745  hard_variables.insert(&(bn.variable(xnode)));
 746  }
 747  // perform the combination of those potentials and their projection
 748  MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
 749  combine_and_project(__combination_op, VENewprojPotential);
 750  __PotentialSet new_cpt_list =
 751  combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
 752 
 753  // there should be only one potential in new_cpt_list
 754  if (new_cpt_list.size() != 1) {
 755  // remove the CPT created to avoid memory leaks
 756  for (auto pot : new_cpt_list) {
 757  if (!marg_cpt_set.contains(pot)) delete pot;
 758  }
 759  GUM_ERROR(FatalError,
 760  "the projection of a potential containing "
 761  << "hard evidence is empty!");
 762  }
 763  const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
 764  res.first.insert(projected_cpt);
 765  res.second.insert(projected_cpt);
 766  }
 767  }
 768 
 769  // if the node received some soft evidence, add it
 770  if (evidence.exists(node) && !hard_evidence.exists(node)) {
 771  res.first.insert(this->evidence()[node]);
 772  }
 773  }
 774 
 775  return res;
 776  }
777 
778 
 779  // creates the message sent by clique from_id to clique to_id
 // Combines the incoming messages with from_id's own potentials, then (if
 // from_id != to_id) marginalizes everything down onto the separator.
 // Returns (message potentials, temporaries to free). The class name on
 // the signature (original 783) was lost in extraction.
 780  template < typename GUM_SCALAR >
 781  std::pair< Set< const Potential< GUM_SCALAR >* >,
 782  Set< const Potential< GUM_SCALAR >* > >
 784  NodeId from_id,
 785  NodeId to_id,
 786  std::pair< Set< const Potential< GUM_SCALAR >* >,
 787  Set< const Potential< GUM_SCALAR >* > >&& incoming_messages) {
 788  // get the messages sent by adjacent nodes to from_id
 789  std::pair< Set< const Potential< GUM_SCALAR >* >,
 790  Set< const Potential< GUM_SCALAR >* > >
 791  pot_list(std::move(incoming_messages));
 792 
 793  // get the potentials of the clique
 794  for (const auto node : __clique_potentials[from_id]) {
 795  auto new_pots = __NodePotentials(node);
 796  pot_list.first += new_pots.first;
 797  pot_list.second += new_pots.second;
 798  }
 799 
 800  // if from_id = to_id: this is the endpoint of a collect
 801  if (!__JT->existsEdge(from_id, to_id)) {
 802  return pot_list;
 803  } else {
 804  // get the set of variables that need be removed from the potentials
 805  const NodeSet& from_clique = __JT->clique(from_id);
 806  const NodeSet& separator = __JT->separator(from_id, to_id);
 807  Set< const DiscreteVariable* > del_vars(from_clique.size());
 808  Set< const DiscreteVariable* > kept_vars(separator.size());
 809  const auto& bn = this->BN();
 810 
 811  for (const auto node : from_clique) {
 812  if (!separator.contains(node)) {
 813  del_vars.insert(&(bn.variable(node)));
 814  } else {
 815  kept_vars.insert(&(bn.variable(node)));
 816  }
 817  }
 818 
 819  // pot_list now contains all the potentials to multiply and marginalize
 820  // => combine the messages
 821  __PotentialSet new_pot_list =
 822  __marginalizeOut(pot_list.first, del_vars, kept_vars);
 823 
 824  // remove all the potentials that are equal to ones (as probability
 825  // matrix multiplications are tensorial, such potentials are useless)
 826  for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
 827  ++iter) {
 828  const auto pot = *iter;
 829  if (pot->variablesSequence().size() == 1) {
 830  bool is_all_ones = true;
 831  for (Instantiation inst(*pot); !inst.end(); ++inst) {
 832  if ((*pot)[inst] < __1_minus_epsilon) {
 833  is_all_ones = false;
 834  break;
 835  }
 836  }
 837  if (is_all_ones) {
 838  if (!pot_list.first.exists(pot)) delete pot;
 839  new_pot_list.erase(iter);
 840  continue;
 841  }
 842  }
 843  }
 844 
 845  // remove the unnecessary temporary messages
 846  for (auto iter = pot_list.second.beginSafe();
 847  iter != pot_list.second.endSafe();
 848  ++iter) {
 849  if (!new_pot_list.contains(*iter)) {
 850  delete *iter;
 851  pot_list.second.erase(iter);
 852  }
 853  }
 854 
 855  // keep track of all the newly created potentials
 856  for (const auto pot : new_pot_list) {
 857  if (!pot_list.first.contains(pot)) { pot_list.second.insert(pot); }
 858  }
 859 
 860  // return the new set of potentials
 861  return std::pair< __PotentialSet, __PotentialSet >(
 862  std::move(new_pot_list), std::move(pot_list.second));
 863  }
 864  }
865 
866 
 867  // remove variables del_vars from the list of potentials pot_list
 // Pipeline: relevance pruning -> optional barren-variable removal ->
 // combine-and-project -> cleanup of temporaries and zero-dimension
 // results. NOTE(review): the signature line (original 870), the `if`
 // guarding the barren step (880) and the combination-op argument line
 // (887) were lost in extraction.
 868  template < typename GUM_SCALAR >
 869  Set< const Potential< GUM_SCALAR >* >
 871  Set< const Potential< GUM_SCALAR >* > pot_list,
 872  Set< const DiscreteVariable* >& del_vars,
 873  Set< const DiscreteVariable* >& kept_vars) {
 874  // use d-separation analysis to check which potentials shall be combined
 875  __findRelevantPotentialsXX(pot_list, kept_vars);
 876 
 877  // remove the potentials corresponding to barren variables if we want
 878  // to exploit barren nodes
 879  __PotentialSet barren_projected_potentials;
 881  barren_projected_potentials = __removeBarrenVariables(pot_list, del_vars);
 882  }
 883 
 884  // create a combine and project operator that will perform the
 885  // marginalization
 886  MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
 888  __PotentialSet new_pot_list =
 889  combine_and_project.combineAndProject(pot_list, del_vars);
 890 
 891  // remove all the potentials that were created due to projections of
 892  // barren nodes and that are not part of the new_pot_list: these
 893  // potentials were just temporary potentials
 894  for (auto iter = barren_projected_potentials.beginSafe();
 895  iter != barren_projected_potentials.endSafe();
 896  ++iter) {
 897  if (!new_pot_list.exists(*iter)) delete *iter;
 898  }
 899 
 900  // remove all the potentials that have no dimension
 901  for (auto iter_pot = new_pot_list.beginSafe();
 902  iter_pot != new_pot_list.endSafe();
 903  ++iter_pot) {
 904  if ((*iter_pot)->variablesSequence().size() == 0) {
 905  // as we have already marginalized out variables that received evidence,
 906  // it may be the case that, after combining and projecting, some
 907  // potentials might be empty. In this case, we shall keep their
 908  // constant and remove them from memory
 909  // # TODO: keep the constants!
 910  delete *iter_pot;
 911  new_pot_list.erase(iter_pot);
 912  }
 913  }
 914 
 915  return new_pot_list;
 916  }
917 
918 
 919  // performs a whole inference
 // NOTE(review): the signature and empty body (original line 921) were
 // lost in extraction; the method is a no-op because all the work is
 // done lazily inside the posterior computations below.
 920  template < typename GUM_SCALAR >
 922 
923 
 925  template < typename GUM_SCALAR >
 926  Potential< GUM_SCALAR >*
 // Computes the (unnormalized) posterior of node `id`: builds a junction
 // tree for the single-target query, runs a collect on id's clique and
 // marginalizes everything but id. Caller owns the returned potential.
 // NOTE(review): the signature line (original 927) and the combination-op
 // argument of fast_combination (968) were lost in extraction.
 928  const auto& bn = this->BN();
 929 
 930  // hard evidence do not belong to the join tree
 931  // # TODO: check for sets of inconsistent hard evidence
 932  if (this->hardEvidenceNodes().contains(id)) {
 933  return new Potential< GUM_SCALAR >(*(this->evidence()[id]));
 934  }
 935 
 936  // if we still need to perform some inference task, do it
 937  __createNewJT(NodeSet{id});
 938  NodeId clique_of_id = __node_to_clique[id];
 939  auto pot_list = __collectMessage(clique_of_id, clique_of_id);
 940 
 941  // get the set of variables that need be removed from the potentials
 942  const NodeSet& nodes = __JT->clique(clique_of_id);
 943  Set< const DiscreteVariable* > kept_vars{&(bn.variable(id))};
 944  Set< const DiscreteVariable* > del_vars(nodes.size());
 945  for (const auto node : nodes) {
 946  if (node != id) del_vars.insert(&(bn.variable(node)));
 947  }
 948 
 949  // pot_list now contains all the potentials to multiply and marginalize
 950  // => combine the messages
 951  __PotentialSet new_pot_list =
 952  __marginalizeOut(pot_list.first, del_vars, kept_vars);
 953  Potential< GUM_SCALAR >* joint = nullptr;
 954 
 955  if (new_pot_list.size() == 1) {
 956  joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
 957  // if joint already existed, create a copy, so that we can put it into
 958  // the __target_posterior property
 959  if (pot_list.first.exists(joint)) {
 960  joint = new Potential< GUM_SCALAR >(*joint);
 961  } else {
 962  // remove the joint from new_pot_list so that it will not be
 963  // removed just after the else block
 964  new_pot_list.clear();
 965  }
 966  } else {
 967  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
 969  joint = fast_combination.combine(new_pot_list);
 970  }
 971 
 972  // remove the potentials that were created in new_pot_list
 973  for (auto pot : new_pot_list)
 974  if (!pot_list.first.exists(pot)) delete pot;
 975 
 976  // remove all the temporary potentials created in pot_list
 977  for (auto pot : pot_list.second)
 978  delete pot;
 979 
 980  // check that the joint posterior is different from a 0 vector: this would
 981  // indicate that some hard evidence are not compatible (their joint
 982  // probability is equal to 0)
 983  bool nonzero_found = false;
 984  for (Instantiation inst(*joint); !inst.end(); ++inst) {
 985  if ((*joint)[inst]) {
 986  nonzero_found = true;
 987  break;
 988  }
 989  }
 990  if (!nonzero_found) {
 991  // remove joint from memory to avoid memory leaks
 992  delete joint;
 993  GUM_ERROR(IncompatibleEvidence,
 994  "some evidence entered into the Bayes "
 995  "net are incompatible (their joint proba = 0)");
 996  }
 997 
 998  return joint;
 999  }
1000 
1001 
 1003  template < typename GUM_SCALAR >
 1004  const Potential< GUM_SCALAR >&
 // Normalized single-node posterior. Caches the result in
 // __target_posterior (freeing any previous one); the returned reference
 // stays valid until the next query or destruction. Signature line
 // (original 1005) lost in extraction.
 1006  // compute the joint posterior and normalize
 1007  auto joint = _unnormalizedJointPosterior(id);
 1008  joint->normalize();
 1009 
 1010  if (__target_posterior != nullptr) delete __target_posterior;
 1011  __target_posterior = joint;
 1012 
 1013  return *joint;
 1014  }
1015 
1016 
 1017  // returns the marginal a posteriori proba of a given node
 // NOTE(review): despite the (stale) comment above, this overload computes
 // the unnormalized JOINT posterior over the node set `set` (name line,
 // original 1020, lost in extraction). Also lost: the
 // `auto pot_list = __collectMessage(__targets2clique, __targets2clique);`
 // line (original 1051) and the combination-op arguments (1044, 1091).
 1018  template < typename GUM_SCALAR >
 1019  Potential< GUM_SCALAR >*
 1021  const NodeSet& set) {
 1022  // hard evidence do not belong to the join tree, so extract the nodes
 1023  // from targets that are not hard evidence
 1024  NodeSet targets = set, hard_ev_nodes;
 1025  for (const auto node : this->hardEvidenceNodes()) {
 1026  if (targets.contains(node)) {
 1027  targets.erase(node);
 1028  hard_ev_nodes.insert(node);
 1029  }
 1030  }
 1031 
 1032  // if all the nodes have received hard evidence, then compute the
 1033  // joint posterior directly by multiplying the hard evidence potentials
 1034  const auto& evidence = this->evidence();
 1035  if (targets.empty()) {
 1036  __PotentialSet pot_list;
 1037  for (const auto node : set) {
 1038  pot_list.insert(evidence[node]);
 1039  }
 1040  if (pot_list.size() == 1) {
 1041  return new Potential< GUM_SCALAR >(**(pot_list.begin()));
 1042  } else {
 1043  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
 1045  return fast_combination.combine(pot_list);
 1046  }
 1047  }
 1048 
 1049  // if we still need to perform some inference task, do it
 1050  __createNewJT(set);
 1052 
 1053  // get the set of variables that need be removed from the potentials
 1054  const NodeSet& nodes = __JT->clique(__targets2clique);
 1055  Set< const DiscreteVariable* > del_vars(nodes.size());
 1056  Set< const DiscreteVariable* > kept_vars(targets.size());
 1057  const auto& bn = this->BN();
 1058  for (const auto node : nodes) {
 1059  if (!targets.contains(node)) {
 1060  del_vars.insert(&(bn.variable(node)));
 1061  } else {
 1062  kept_vars.insert(&(bn.variable(node)));
 1063  }
 1064  }
 1065 
 1066  // pot_list now contains all the potentials to multiply and marginalize
 1067  // => combine the messages
 1068  __PotentialSet new_pot_list =
 1069  __marginalizeOut(pot_list.first, del_vars, kept_vars);
 1070  Potential< GUM_SCALAR >* joint = nullptr;
 1071 
 1072  if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
 1073  joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
 1074  // if pot already existed, create a copy, so that we can put it into
 1075  // the __target_posteriors property
 1076  if (pot_list.first.exists(joint)) {
 1077  joint = new Potential< GUM_SCALAR >(*joint);
 1078  } else {
 1079  // remove the joint from new_pot_list so that it will not be
 1080  // removed just after the next else block
 1081  new_pot_list.clear();
 1082  }
 1083  } else {
 1084  // combine all the potentials in new_pot_list with all the hard evidence
 1085  // of the nodes in set
 1086  __PotentialSet new_new_pot_list = new_pot_list;
 1087  for (const auto node : hard_ev_nodes) {
 1088  new_new_pot_list.insert(evidence[node]);
 1089  }
 1090  MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
 1092  joint = fast_combination.combine(new_new_pot_list);
 1093  }
 1094 
 1095  // remove the potentials that were created in new_pot_list
 1096  for (auto pot : new_pot_list)
 1097  if (!pot_list.first.exists(pot)) delete pot;
 1098 
 1099  // remove all the temporary potentials created in pot_list
 1100  for (auto pot : pot_list.second)
 1101  delete pot;
 1102 
 1103  // check that the joint posterior is different from a 0 vector: this would
 1104  // indicate that some hard evidence are not compatible
 1105  bool nonzero_found = false;
 1106  for (Instantiation inst(*joint); !inst.end(); ++inst) {
 1107  if ((*joint)[inst]) {
 1108  nonzero_found = true;
 1109  break;
 1110  }
 1111  }
 1112  if (!nonzero_found) {
 1113  // remove joint from memory to avoid memory leaks
 1114  delete joint;
 1115  GUM_ERROR(IncompatibleEvidence,
 1116  "some evidence entered into the Bayes "
 1117  "net are incompatible (their joint proba = 0)");
 1118  }
 1119 
 1120  return joint;
 1121  }
1122 
1123 
 1125  template < typename GUM_SCALAR >
 1126  const Potential< GUM_SCALAR >&
 // Normalized joint posterior over `set`. Caches the result in
 // __target_posterior (freeing any previous one). Signature line
 // (original 1127) lost in extraction.
 1128  // compute the joint posterior and normalize
 1129  auto joint = _unnormalizedJointPosterior(set);
 1130  joint->normalize();
 1131 
 1132  if (__target_posterior != nullptr) delete __target_posterior;
 1133  __target_posterior = joint;
 1134 
 1135  return *joint;
 1136  }
1137 
1138 
 1140  template < typename GUM_SCALAR >
 1141  const Potential< GUM_SCALAR >&
 // Overload used when the wanted target is a subset of a declared joint
 // target: `declared_target` is deliberately ignored because a new
 // junction tree is built per query anyway, so the wanted set can be
 // answered directly. Signature line (original 1142) lost in extraction.
 1143  const NodeSet& wanted_target, const NodeSet& declared_target) {
 1144  return _jointPosterior(wanted_target);
 1145  }
1146 
1147 
1148 } /* namespace gum */
1149 
1150 #endif // DOXYGEN_SHOULD_SKIP_THIS
void __setProjectionFunction(Potential< GUM_SCALAR > *(*proj)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
sets the operator for performing the projections
bool contains(const Key &k) const
Indicates whether a given elements belong to the set.
Definition: set_tpl.h:581
NodeId __targets2clique
indicate a clique that contains all the nodes of the target
VariableElimination(const IBayesNet< GUM_SCALAR > *BN, RelevantPotentialsFinderType relevant_type=RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS, FindBarrenNodesType=FindBarrenNodesType::FIND_BARREN_NODES)
default constructor
void _onAllTargetsErased() final
fired before a all single and joint_targets are removed
__PotentialSet __removeBarrenVariables(__PotentialSet &pot_list, Set< const DiscreteVariable * > &del_vars)
virtual void clear()
removes all the nodes and edges from the graph
Definition: undiGraph_inl.h:43
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
const NodeProperty< const Potential< GUM_SCALAR > *> & evidence() const
returns the set of evidence
static INLINE Potential< GUM_SCALAR > * VENewprojPotential(const Potential< GUM_SCALAR > &t1, const Set< const DiscreteVariable * > &del_vars)
void __setCombinationFunction(Potential< GUM_SCALAR > *(*comb)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &))
sets the operator for performing the combinations
node_iterator_safe beginSafe() const
a begin iterator to parse the set of nodes contained in the NodeGraphPart
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
const Potential< GUM_SCALAR > & _posterior(NodeId id) final
returns the posterior of a given variable
JunctionTree * __JT
the junction tree used to answer the last inference query
RelevantPotentialsFinderType __find_relevant_potential_type
the type of relevant potential finding algorithm to be used
void __findRelevantPotentialsXX(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
void setFindBarrenNodesType(FindBarrenNodesType type)
sets how we determine barren nodes
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
static void relevantPotentials(const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
update a set of potentials, keeping only those d-connected with query variables given evidence ...
Definition: BayesBall_tpl.h:35
void _onEvidenceErased(const NodeId id, bool isHardEvidence) final
fired before an evidence is removed
std::pair< __PotentialSet, __PotentialSet > __collectMessage(NodeId id, NodeId from)
actually perform the collect phase
Set< const Potential< GUM_SCALAR > *> __PotentialSet
void _makeInference() final
called when the inference has to be performed effectively
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Potential< GUM_SCALAR > * __target_posterior
the posterior computed during the last inference
Triangulation * __triangulation
the triangulation class creating the junction tree used for inference
void _updateOutdatedBNPotentials() final
prepares inference when the latter is in OutdatedBNPotentials state
void erase(const Key &k)
Erases an element from the set.
Definition: set_tpl.h:656
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
const Potential< GUM_SCALAR > & _jointPosterior(const NodeSet &set) final
returns the posterior of a declared target set
void __findRelevantPotentialsGetAll(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
void __findRelevantPotentialsWithdSeparation(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
RelevantPotentialsFinderType
type of algorithm for determining the relevant potentials for combinations using some d-separation an...
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const JunctionTree * junctionTree(NodeId id)
returns the join tree used to compute the posterior of node id
void _onMarginalTargetErased(const NodeId id) final
fired before a single target is removed
void _onAllMarginalTargetsErased() final
fired before a all the single targets are removed
void __findRelevantPotentialsWithdSeparation3(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
virtual void eraseNode(const NodeId id)
remove a node and its adjacent edges from the graph
Definition: undiGraph_inl.h:58
Potential< GUM_SCALAR > *(* __combination_op)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &)
the operator for performing the combinations
FindBarrenNodesType
type of algorithm to determine barren nodes
UndiGraph __graph
the undigraph extracted from the BN and used to construct the join tree
const NodeSet & softEvidenceNodes() const
returns the set of nodes with soft evidence
virtual const NodeProperty< Size > & domainSizes() const final
get the domain sizes of the random variables of the BN
void _onAllEvidenceErased(bool contains_hard_evidence) final
fired before all the evidence are erased
void __findRelevantPotentialsWithdSeparation2(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
virtual NodeId createdJunctionTreeClique(const NodeId id)=0
returns the Id of the clique created by the elimination of a given node during the triangulation proc...
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const NodeProperty< Idx > & hardEvidence() const
indicate for each node with hard evidence which value it took
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
HashTable< NodeId, NodeSet > __clique_potentials
for each BN node, indicate in which clique its CPT will be stored
void _onJointTargetErased(const NodeSet &set) final
fired before a joint target is removed
void _onJointTargetAdded(const NodeSet &set) final
fired after a new joint target is inserted
FindBarrenNodesType __barren_nodes_type
the type of barren nodes computation we wish
void _onAllJointTargetsErased() final
fired before a all the joint targets are removed
void setTriangulation(const Triangulation &new_triangulation)
use a new triangulation algorithm
void setRelevantPotentialsFinderType(RelevantPotentialsFinderType type)
sets how we determine the relevant potentials to combine
const NodeSet & hardEvidenceNodes() const
returns the set of nodes with hard evidence
HashTable< NodeId, NodeId > __node_to_clique
for each node of __graph (~ in the Bayes net), associate an ID in the JT
Potential< GUM_SCALAR > * _unnormalizedJointPosterior(NodeId id) final
returns a fresh potential equal to P(argument,evidence)
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
Definition: cliqueGraph.h:302
__PotentialSet __marginalizeOut(__PotentialSet pot_list, Set< const DiscreteVariable * > &del_vars, Set< const DiscreteVariable * > &kept_vars)
removes variables del_vars from a list of potentials and returns the resulting list ...
void clear()
Removes all the elements in the hash table.
std::pair< __PotentialSet, __PotentialSet > __NodePotentials(NodeId node)
returns the CPT + evidence of a node projected w.r.t. hard evidence
void _updateOutdatedBNStructure() final
prepares inference when the latter is in OutdatedBNStructure state
void _onAllMarginalTargetsAdded() final
fired after all the nodes of the BN are added as single targets
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
std::pair< __PotentialSet, __PotentialSet > __produceMessage(NodeId from_id, NodeId to_id, std::pair< __PotentialSet, __PotentialSet > &&incoming_messages)
creates the message sent by clique from_id to clique to_id
virtual Triangulation * newFactory() const =0
returns a fresh triangulation of the same type as the current object but with an empty graph ...
virtual const NodeSet & targets() const noexcept final
returns the list of marginal targets
const node_iterator_safe & endSafe() const noexcept
the end iterator to parse the set of nodes contained in the NodeGraphPart
void __createNewJT(const NodeSet &targets)
create a new junction tree as well as its related data structures
~VariableElimination() final
destructor
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:48
virtual const CliqueGraph & junctionTree()=0
returns a compatible junction tree
const GUM_SCALAR __1_minus_epsilon
for comparisons with 1 - epsilon
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
virtual const std::vector< NodeId > & eliminationOrder()=0
returns an elimination ordering compatible with the triangulated graph
void _onMarginalTargetAdded(const NodeId id) final
fired after a new single target is inserted
virtual const IBayesNet< GUM_SCALAR > & BN() const final
Returns a constant reference over the IBayesNet referenced by this class.
Potential< GUM_SCALAR > *(* __projection_op)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &)
the operator for performing the projections
Size NodeId
Type for node ids.
Definition: graphElements.h:98
void _onEvidenceChanged(const NodeId id, bool hasChangedSoftHard) final
fired after an evidence is changed, in particular when its status (soft/hard) changes ...
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:613
void(VariableElimination< GUM_SCALAR >::* __findRelevantPotentials)(Set< const Potential< GUM_SCALAR > * > &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
void _onEvidenceAdded(const NodeId id, bool isHardEvidence) final
fired after a new evidence is inserted
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
virtual void setGraph(const UndiGraph *graph, const NodeProperty< Size > *domsizes)=0
initialize the triangulation data structures for a new graph
static void requisiteNodes(const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.
Definition: BayesBall.cpp:36
virtual void _onBayesNetChanged(const IBayesNet< GUM_SCALAR > *bn) final
fired after a new Bayes net has been assigned to the engine
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.