aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
SVE_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Inline implementation of SVE.
25  *
26  * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 #include <agrum/PRM/classDependencyGraph.h>
29 #include <agrum/PRM/inference/SVE.h>
30 
31 namespace gum {
32  namespace prm {
33 
34  template < typename GUM_SCALAR >
36  const PRMAttribute< GUM_SCALAR >& a) {
37  std::stringstream s;
38  const auto& class_a = i.type().get(a.safeName());
39  s << &(a.type().variable()) << " - ";
40  s << i.name() << "." << a.safeName()
41  << ": input=" << i.type().isInputNode(class_a);
42  s << " output=" << i.type().isOutputNode(class_a)
43  << " inner=" << i.type().isInnerNode(class_a);
44  return s.str();
45  }
46 
47  template < typename GUM_SCALAR >
49  std::stringstream s;
50  s << i.name() << std::endl;
51  s << "Attributes: " << std::endl;
52  for (auto a: i) {
53  s << __print_attribute__(i, *(a.second));
54  }
55  if (i.type().slotChains().size()) {
56  s << std::endl << "SlotChains: " << std::endl;
57  for (auto sc: i.type().slotChains()) {
58  s << sc->name() << " ";
59  }
60  }
61  return s.str();
62  }
63 
64  template < typename GUM_SCALAR >
66  std::stringstream str;
67  for (auto i: s) {
68  str << __print_instance__(*(i.second)) << std::endl;
69  }
70  return str.str();
71  }
72 
/**
 * @brief Debug helper: render the names of the elements of a container.
 *
 * Elements are dereferenced with `->`, so they are expected to be
 * pointer-like objects exposing `name()`. Names are written between square
 * brackets, each one followed by a single space.
 *
 * @param l The container to print. It is taken by const reference so the
 *          container is no longer copied on every call.
 * @return A string of the form "[n1 n2 ]", or "[]" for an empty container.
 */
template < typename LIST >
std::string __print_list__(const LIST& l) {
  std::stringstream s;
  s << "[";
  for (const auto& elt: l) {
    s << elt->name() << " ";
  }
  s << "]";
  return s.str();
}
83 
84  template < typename GUM_SCALAR >
86  std::stringstream s;
87  s << "{";
88  for (auto var: pot.variablesSequence()) {
89  s << var << ", ";
90  }
91  s << "}";
92  return s.str();
93  }
94 
/**
 * @brief Debug helper: render a container of potential pointers.
 *
 * Each element is dereferenced and printed with __print_pot__, followed by
 * a single space. The whole is enclosed in square brackets.
 *
 * @param set The container to print. It is taken by const reference so the
 *            container is no longer copied on every call.
 * @return A string of the form "[{...} {...} ]", or "[]" when empty.
 */
template < typename SET >
std::string __print_set__(const SET& set) {
  std::stringstream s;
  s << "[";
  for (const auto& p: set) {
    s << __print_pot__(*p) << " ";
  }
  s << "]";
  return s.str();
}
105 
106  template < typename GUM_SCALAR >
107  SVE< GUM_SCALAR >::~SVE() {
109 
110  for (const auto& elt: elim_orders__)
111  delete elt.second;
112 
113  for (const auto& elt: lifted_pools__)
114  delete elt.second;
115 
116  if (class_elim_order__ != nullptr) delete class_elim_order__;
117 
118  for (const auto trash: lifted_trash__)
119  delete trash;
120 
121  for (auto set: delayedVariables__)
122  delete set.second;
123  }
124 
125  template < typename GUM_SCALAR >
126  void
128  NodeId node,
129  BucketSet& pool,
130  BucketSet& trash) {
133  // Downward elimination
134  List< const PRMInstance< GUM_SCALAR >* > elim_list;
136 
137  for (auto iter = query->beginInvRef(); iter != query->endInvRef(); ++iter) {
138  for (auto child = (*(iter.val())).begin(); child != (*(iter.val())).end();
139  ++child) {
140  if (!ignore.exists(child->first)) {
142  child->first,
143  pool,
144  trash,
145  elim_list,
146  ignore,
147  eliminated);
148  } else if (!eliminated.exists(child->first)) {
151  }
152  }
153  }
154 
155  // Eliminating all nodes in query instance, except query
159 
160  if (this->hasEvidence(query)) { insertEvidence__(query, pool); }
161 
162  for (auto attr = query->begin(); attr != query->end(); ++attr) {
163  pool.insert(
164  &(const_cast< Potential< GUM_SCALAR >& >((*(attr.val())).cpf())));
165  }
166 
167  for (size_t idx = 0; idx < t.eliminationOrder().size(); ++idx) {
168  if ((t.eliminationOrder()[idx] != node)
170  auto var_id = t.eliminationOrder()[idx];
171  const auto& var = bn.variable(var_id);
173  }
174  }
175 
177 
178  // Eliminating delayed variables, if any
181  }
182 
184  // Eliminating instance in elim_list
185  List< const PRMInstance< GUM_SCALAR >* > tmp_list;
186 
187  while (!elim_list.empty()) {
189  if (!ignore.exists(elim_list.front())) {
191  elim_list.front(),
192  pool,
193  trash,
194  elim_list,
195  ignore,
196  eliminated);
197  }
198  } else {
200  }
201 
203  }
204 
205  // Upward elimination
206  for (const auto chain: query->type().slotChains())
207  for (const auto parent: query->getInstances(chain->id()))
208  if (!ignore.exists(parent))
210  pool,
211  trash,
212  tmp_list,
213  ignore,
214  eliminated);
215  }
216 
217  template < typename GUM_SCALAR >
219  const PRMInstance< GUM_SCALAR >* i,
220  BucketSet& pool,
221  BucketSet& trash) {
223 
224  for (const auto var: *delayedVariables__[i]) {
226 
227  for (const auto pot: pool)
228  if (pot->contains(*var)) {
229  bucket->add(*pot);
231  }
232 
233  for (const auto pot: toRemove)
234  pool.erase(pot);
235 
236  for (const auto other: bucket->allVariables())
237  if (other != var) bucket->add(*other);
238 
242  }
243  }
244 
245  template < typename GUM_SCALAR >
247  const PRMInstance< GUM_SCALAR >* from,
248  const PRMInstance< GUM_SCALAR >* i,
249  BucketSet& pool,
250  BucketSet& trash,
251  List< const PRMInstance< GUM_SCALAR >* >& elim_list,
252  Set< const PRMInstance< GUM_SCALAR >* >& ignore,
253  Set< const PRMInstance< GUM_SCALAR >* >& eliminated) {
255  ignore.insert(i);
256  // Calling elimination over child instance
257  List< const PRMInstance< GUM_SCALAR >* > my_list;
258 
259  for (auto iter = i->beginInvRef(); iter != i->endInvRef(); ++iter) {
260  for (auto child = (*(iter.val())).begin(); child != (*(iter.val())).end();
261  ++child) {
262  if (!ignore.exists(child->first)) {
264  child->first,
265  pool,
266  trash,
267  my_list,
268  ignore,
269  eliminated);
270  } else if (!eliminated.exists(child->first)) {
273  }
274  }
275  }
276 
277  // Eliminating all nodes in current instance
279  pool,
280  trash,
281  (delayedVars.empty() ? 0 : &delayedVars));
283 
284  // Calling elimination over child's parents
285  for (const auto node: my_list) {
286  if (checkElimOrder__(i, node) && (node != from)) {
287  if (!ignore.exists(node)) {
289  node,
290  pool,
291  trash,
292  elim_list,
293  ignore,
294  eliminated);
295  }
296  } else if (node != from) {
298  }
299  }
300 
301  // Adding parents instance to elim_list
302  for (const auto chain: i->type().slotChains()) {
303  for (const auto inst: i->getInstances(chain->id())) {
304  if (inst != from) { elim_list.insert(inst); }
305  }
306  }
307  }
308 
309  template < typename GUM_SCALAR >
310  void
312  BucketSet& pool,
313  BucketSet& trash,
314  Set< NodeId >* delayedVars) {
315  if (this->hasEvidence(i)) {
317  } else {
319 
320  for (const auto agg: i->type().aggregates())
322 
323  try {
325 
326  std::vector< const DiscreteVariable* > elim;
327 
328  for (const auto node: getElimOrder__(i->type())) {
329  const auto& var = bn.variable(node);
330  if (delayedVars != nullptr) {
331  if (!delayedVars->exists(node)) {
332  const auto& var = bn.variable(node);
333  elim.push_back(&var);
334  }
335  } else {
336  elim.push_back(&var);
337  }
338  }
339 
341  } catch (NotFound&) {
342  // Raised if there is no inner nodes to eliminate
343  }
344  }
345 
346  // Eliminating delayed variables, if any
347  if (delayedVariables__.exists(i)) {
349  }
350  }
351 
352  template < typename GUM_SCALAR >
354  const PRMInstance< GUM_SCALAR >* i,
355  BucketSet& pool,
356  BucketSet& trash,
357  List< const PRMInstance< GUM_SCALAR >* >& elim_list,
358  Set< const PRMInstance< GUM_SCALAR >* >& ignore,
359  Set< const PRMInstance< GUM_SCALAR >* >& eliminated) {
360  // Downward elimination
361  ignore.insert(i);
362 
363  for (auto iter = i->beginInvRef(); iter != i->endInvRef(); ++iter) {
364  for (auto child = (*(iter.val())).begin(); child != (*(iter.val())).end();
365  ++child) {
366  if (!ignore.exists(child->first)) {
368  child->first,
369  pool,
370  trash,
371  elim_list,
372  ignore,
373  eliminated);
374  }
375  }
376  }
377 
378  // Eliminating all nodes in i instance
381  // Eliminating instance in elim_list
382  List< const PRMInstance< GUM_SCALAR >* > tmp_list;
383 
384  while (!elim_list.empty()) {
385  if (checkElimOrder__(i, elim_list.front())) {
386  if (!ignore.exists(elim_list.front())) {
388  elim_list.front(),
389  pool,
390  trash,
391  elim_list,
392  ignore,
393  eliminated);
394  }
395  } else {
397  }
398 
400  }
401 
402  // Upward elimination
403  for (const auto chain: i->type().slotChains()) {
404  for (const auto parent: i->getInstances(chain->id())) {
405  if (!ignore.exists(parent)) {
407  pool,
408  trash,
409  tmp_list,
410  ignore,
411  eliminated);
412  }
413  }
414  }
415  }
416 
417  template < typename GUM_SCALAR >
419  const PRMInstance< GUM_SCALAR >* i,
420  BucketSet& pool,
421  BucketSet& trash,
422  Set< NodeId >* delayedVars) {
423  // First we check if evidences are on inner nodes
424  bool inner = false;
425 
426  for (const auto& elt: this->evidence(i)) {
427  inner = i->type().isInputNode(i->get(elt.first))
428  || i->type().isInnerNode(i->get(elt.first));
429 
430  if (inner) { break; }
431  }
432 
433  // Evidence on inner nodes
434  if (inner) {
437 
438  // We need a local to not eliminate queried inner nodes of the same
439  // class
440  for (const auto& elt: *i) {
442  &(const_cast< Potential< GUM_SCALAR >& >(elt.second->cpf())));
443  }
444 
448  // Removing Output nodes of elimination order
451 
452  for (size_t idx = 0; idx < full_elim_order.size(); ++idx) {
453  auto var_id = full_elim_order[idx];
454  const auto& var = bn.variable(var_id);
455 
456  if (!i->type().isOutputNode(i->get(full_elim_order[idx]))) {
458  } else if (delayedVars != nullptr) {
461  }
462  } else {
464  }
465  }
466 
468 
469  // Now we add the new potentials in pool and eliminate output nodes
470  for (const auto pot: tmp_pool)
471  pool.insert(pot);
472 
473  if (!output_elim_order.empty())
475 
476  } else {
480 
481  for (const auto agg: i->type().aggregates())
483 
484  try {
485  std::vector< const DiscreteVariable* > elim;
486 
487  for (auto iter = getElimOrder__(i->type()).begin();
488  iter != getElimOrder__(i->type()).end();
489  ++iter) {
490  const auto& var = bn.variable(*iter);
491  if (delayedVars != nullptr) {
492  if (!delayedVars->exists(*iter)) { elim.push_back(&var); }
493  } else {
494  elim.push_back(&var);
495  }
496  }
497 
499  } catch (NotFound&) {
500  GUM_ERROR(FatalError, "there should be at least one node here.");
501  }
502  }
503  }
504 
505  template < typename GUM_SCALAR >
507  BucketSet& pool,
508  BucketSet& trash) {
510 
511  try {
512  lifted_pool = lifted_pools__[&(i->type())];
513  } catch (NotFound&) {
515  lifted_pool = lifted_pools__[&(i->type())];
516  }
517 
518  for (const auto lifted_pot: *lifted_pool) {
520  pool.insert(pot);
521  trash.insert(pot);
522  }
523  }
524 
525  template < typename GUM_SCALAR >
530 
531  for (const auto node: c.containerDag().nodes())
533  if (c.isOutputNode(c.get(node)))
534  outers.insert(node);
535  else if (!outers.exists(node))
536  inners.insert(node);
537 
539  const_cast< Potential< GUM_SCALAR >* >(&(c.get(node).cpf())));
540  } else if (PRMClassElement< GUM_SCALAR >::isAggregate(c.get(node))) {
541  outers.insert(node);
542 
543  // We need to put in the output_elim_order aggregator's parents which
544  // are
545  // inner nodes
546  for (const auto par: c.containerDag().parents(node))
548  && c.isInnerNode(c.get(par))) {
549  inners.erase(par);
550  outers.insert(par);
551  }
552  }
553 
554  // Now we proceed with the elimination of inner attributes
557 
559 
561 
563  &(bn.modalities()),
565 
566  for (size_t idx = 0; idx < inners.size(); ++idx)
568  *lifted_pool,
570 
571  // If there is not only inner and input Attributes
572  if (outers.size()) {
574  &c,
576  t.eliminationOrder().end()));
577  }
578  }
579 
580  template < typename GUM_SCALAR >
584  std::list< NodeId > l;
585 
586  for (const auto node: cdg.dag().nodes()) {
587  if (cdg.dag().parents(node).empty()) { l.push_back(node); }
588  }
589 
591 
592  while (!l.empty()) {
594 
597  }
598 
599  for (const auto child: cdg.dag().children(l.front())) {
601  }
602 
603  l.pop_front();
604  }
605 
607  for (auto c: class_elim_order) {
608  std::string name = c->name();
609  auto pos = name.find_first_of("<");
610  if (pos != std::string::npos) { name = name.substr(0, pos); }
611  try {
613  } catch (DuplicateElement&) {}
614  }
615  }
616 
617  template < typename GUM_SCALAR >
619  Potential< GUM_SCALAR >& m) {
620  const PRMInstance< GUM_SCALAR >* i = chain.first;
621  const PRMAttribute< GUM_SCALAR >* elt = chain.second;
623 
625 
627 
628  for (const auto pot: pool) {
629  if (pot->contains(elt->type().variable())) { result.push_back(pot); }
630  }
631 
632  while (result.size() > 1) {
633  auto& p1 = *(result.back());
634  result.pop_back();
635  auto& p2 = *(result.back());
636  result.pop_back();
637  auto mult = new Potential< GUM_SCALAR >(p1 * p2);
638  trash.insert(mult);
640  }
641 
642  m = *(result.back());
643  m.normalize();
644 
645  for (const auto pot: trash) {
646  delete pot;
647  }
648  }
649 
650  template < typename GUM_SCALAR >
651  void SVE< GUM_SCALAR >::joint_(const std::vector< Chain >& queries,
652  Potential< GUM_SCALAR >& j) {
653  GUM_ERROR(FatalError, "Not implemented.");
654  }
655 
656  template < typename GUM_SCALAR >
658  const PRMSystem< GUM_SCALAR >& system) :
660  class_elim_order__(0) {
662  }
663 
664  template < typename GUM_SCALAR >
665  INLINE void
667  BucketSet& pool) {
668  for (const auto& elt: this->evidence(i))
669  pool.insert(const_cast< Potential< GUM_SCALAR >* >(elt.second));
670  }
671 
672  template < typename GUM_SCALAR >
673  INLINE std::vector< NodeId >&
675  return *(elim_orders__[&c]);
676  }
677 
678  template < typename GUM_SCALAR >
680  auto pos = s.find_first_of("<");
681  if (pos != std::string::npos) { return s.substr(0, pos); }
682  return s;
683  }
684 
685  template < typename GUM_SCALAR >
687  const PRMInstance< GUM_SCALAR >* first,
688  const PRMInstance< GUM_SCALAR >* second) {
689  if (class_elim_order__ == 0) { initElimOrder__(); }
690 
691  auto first_name = trim__(first->type().name());
692  auto second_name = trim__(second->type().name());
695  }
696 
697  template < typename GUM_SCALAR >
699  const PRMInstance< GUM_SCALAR >* i,
700  const PRMAggregate< GUM_SCALAR >* agg) {
701  return &(const_cast< Potential< GUM_SCALAR >& >(i->get(agg->id()).cpf()));
702  }
703 
704  template < typename GUM_SCALAR >
706  // Do nothing
707  }
708 
709  template < typename GUM_SCALAR >
711  // Do nothing
712  }
713 
714  template < typename GUM_SCALAR >
715  INLINE void
717  const PRMInstance< GUM_SCALAR >* j,
718  NodeId id) {
719  try {
721  } catch (NotFound&) {
722  delayedVariables__.insert(i, new Set< const DiscreteVariable* >());
724  } catch (DuplicateElement&) {
725  // happens if j->get(id) is a parent of more than one variable in i
726  }
727 
728  static std::string dot = ".";
729 
730  try {
731  delayedVariablesCounters__[j->name() + dot + j->get(id).safeName()] += 1;
732  } catch (NotFound&) {
734  1);
735  }
736  }
737 
738  template < typename GUM_SCALAR >
740  return "SVE";
741  }
742 
743  } /* namespace prm */
744 } /* namespace gum */
std::string __print_pot__(const Potential< GUM_SCALAR > &pot)
Definition: SVE_tpl.h:85
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
std::string __print_list__(LIST l)
Definition: SVE_tpl.h:74
std::string __print_instance__(const PRMInstance< GUM_SCALAR > &i)
Definition: SVE_tpl.h:48
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)
std::string __print_attribute__(const PRMInstance< GUM_SCALAR > &i, const PRMAttribute< GUM_SCALAR > &a)
Definition: SVE_tpl.h:35
std::string __print_system__(const PRMSystem< GUM_SCALAR > &s)
Definition: SVE_tpl.h:65
std::string __print_set__(SET set)
Definition: SVE_tpl.h:96