aGrUM  0.17.2
a C++ library for (probabilistic) graphical models
multiDimCombineAndProjectDefault_tpl.h
Go to the documentation of this file.
1 
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 
32 # include <limits>
33 
34 # include <agrum/agrum.h>
35 
37 
38 namespace gum {
39 
40  // default constructor
    // Builds the two helper objects from the function pointers: a
    // MultiDimCombinationDefault wrapping `combine` and a MultiDimProjection
    // wrapping `project`. Both are owned by this object and are deleted by
    // the destructor.
    // NOTE(review): listing lines 42-43 (the constructor's name line) are
    // missing from this extraction.
    // NOTE(review): both helpers are allocated with raw `new` in the
    // initializer list; if the second allocation throws, the first one is
    // leaked, because the destructor does not run for an object whose
    // constructor threw. Holding them in std::unique_ptr would avoid this.
41  template < typename GUM_SCALAR, template < typename > class TABLE >
44  TABLE< GUM_SCALAR >* (*combine)(const TABLE< GUM_SCALAR >&,
45  const TABLE< GUM_SCALAR >&),
46  TABLE< GUM_SCALAR >* (*project)(const TABLE< GUM_SCALAR >&,
47  const Set< const DiscreteVariable* >&)) :
48  MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
49  __combination(new MultiDimCombinationDefault< GUM_SCALAR, TABLE >(combine)),
50  __projection(new MultiDimProjection< GUM_SCALAR, TABLE >(project)) {
51  // for debugging purposes
52  GUM_CONSTRUCTOR(MultiDimCombineAndProjectDefault);
53  }
54 
55  // copy constructor
    // NOTE(review): listing lines 57-58 (signature), 61-62 (presumably the
    // initializers of __combination and __projection) and 64 are missing
    // from this extraction, so how the helpers are copied cannot be checked
    // here. Since the destructor deletes both pointers, the copy must own
    // its own helpers (e.g. obtained through newFactory()) rather than reuse
    // the pointers of `from` -- confirm against the real header.
56  template < typename GUM_SCALAR, template < typename > class TABLE >
59  const MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >& from) :
60  MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
63  // for debugging purposes
65  }
66 
67  // destructor
    // Deletes the combination and projection helpers owned by this object.
    // NOTE(review): listing lines 69-70 (the destructor's name line) are
    // missing from this extraction.
68  template < typename GUM_SCALAR, template < typename > class TABLE >
71  // for debugging purposes
72  GUM_DESTRUCTOR(MultiDimCombineAndProjectDefault);
73  delete __combination;
74  delete __projection;
75  }
76 
77  // virtual constructor
    // Returns a heap-allocated copy of *this made by the copy constructor;
    // the caller is responsible for deleting it.
    // NOTE(review): listing line 80 (the method's name line) is missing from
    // this extraction.
78  template < typename GUM_SCALAR, template < typename > class TABLE >
79  MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >*
81  return new MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >(*this);
82  }
83 
84  // combine and project
    // Combines the tables of table_set and marginalizes out every variable of
    // del_vars, one variable at a time. At each step, the variable whose
    // elimination involves the smallest product of domain sizes is removed
    // first (priority queue product_size). All the tables containing it are
    // combined (with __combination when there are several of them), then
    // the variable is projected out with __projection.
    // Both sets are taken by value, so the caller's sets are not modified.
    // Returns the final set of tables. Input tables that contain no variable
    // of del_vars are returned unchanged, together with the last tables
    // allocated here by __projection->project(). Intermediate results that
    // a later step consumes are deleted as soon as they are used.
    // NOTE(review): the returned set mixes the caller's pointers with newly
    // allocated tables; presumably callers delete the latter by comparing
    // with their input -- confirm the ownership contract.
    // NOTE(review): not exception-safe: if combine() or project() throws,
    // `joint` (when created by combine()) and every table still in
    // tmp_marginals are leaked.
    // NOTE(review): listing line 87 (the method's name line) is missing from
    // this extraction.
85  template < typename GUM_SCALAR, template < typename > class TABLE >
86  Set< const TABLE< GUM_SCALAR >* >
88  Set< const TABLE< GUM_SCALAR >* > table_set,
89  Set< const DiscreteVariable* > del_vars) {
90  // when we remove a variable, we need to combine all the tables containing
91  // this variable in order to produce a new unique table containing this
92  // variable. Removing the variable is then performed by marginalizing it
93  // out of the table. In the combineAndProject algorithm, we wish to remove
94  // first variables that produce small tables. This should speed up the
95  // marginalizing process
96  Size nb_vars;
97  {
98  // determine the set of all the variables involved in the tables.
99  // this should help sizing correctly the hashtables
100  Set< const DiscreteVariable* > all_vars;
101 
102  for ( const auto ptrTab : table_set) {
103  for (const auto ptrVar : ptrTab->variablesSequence()) {
104  all_vars.insert(ptrVar);
105  }
106  }
107 
108  nb_vars = all_vars.size();
109  }
110 
111  // the tables containing a given variable to be deleted
112  HashTable< const DiscreteVariable*, Set< const TABLE< GUM_SCALAR >* > >
113  tables_per_var(nb_vars);
114 
115  // for a given variable X to be deleted, the list of all the variables of
116  // the tables containing X (actually, we also count the number of tables
117  // containing the variable. This is more efficient for computing and
118  // updating the product_size priority queue (see below) when some tables
119  // are removed)
120  HashTable< const DiscreteVariable*,
121  HashTable< const DiscreteVariable*, unsigned int > >
122  tables_vars_per_var(nb_vars);
123 
124  // initialize tables_vars_per_var and tables_per_var
125  {
126  Set< const TABLE< GUM_SCALAR >* > empty_set(table_set.size());
127  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
128 
129  for (const auto ptrVar : del_vars) {
130  tables_per_var.insert(ptrVar, empty_set);
131  tables_vars_per_var.insert(ptrVar, empty_hash);
132  }
133 
134  // update properly tables_per_var and tables_vars_per_var
135  for (const auto ptrTab : table_set) {
136  const Sequence< const DiscreteVariable* >& vars =
137  ptrTab->variablesSequence();
138 
139  for (const auto ptrVar : vars) {
140  if (del_vars.contains(ptrVar)) {
141  // add the table to the set of tables related to ptrVar
142  tables_per_var[ptrVar].insert(ptrTab);
143 
144  // add the variables of the table to tables_vars_per_var[ptrVar]
145  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
146  tables_vars_per_var[ptrVar];
147 
148  for (const auto xptrVar : vars) {
149  try {
150  ++iter_vars[xptrVar];
151  } catch (const NotFound&) { iter_vars.insert(xptrVar, 1); }
152  }
153  }
154  }
155  }
156  }
157 
158  // the sizes of the tables produced when removing a given discrete variable
159  PriorityQueue< const DiscreteVariable*, double > product_size;
160 
161  // initialize properly product_size
    // (variables of del_vars that appear in no table get no entry, so they
    // are never eliminated)
162  for (const auto& elt : tables_vars_per_var) {
163  double size = 1.0;
164  const auto ptrVar = elt.first;
165  const auto& hashvars = elt.second; // HashTable<const DiscreteVariable*, unsigned int>
166 
167  if (hashvars.size()) {
168  for (const auto& xelt : hashvars) {
169  size *= (double) xelt.first->domainSize();
170  }
171 
172  product_size.insert(ptrVar, size);
173  }
174  }
175 
176  // create a set of the temporary tables created during the
177  // marginalization process (useful for deallocating temporary tables)
178  Set< const TABLE< GUM_SCALAR >* > tmp_marginals(table_set.size());
179 
180  // now, remove all the variables in del_vars, starting from those that
181  // produce the smallest tables
182  while (!product_size.empty()) {
183  // get the best variable to remove
184  const DiscreteVariable* del_var = product_size.pop();
    // del_var is no longer pending: the bookkeeping below then never touches
    // its own entries, in particular tables_per_var[del_var], which is
    // iterated through tables_to_combine
185  del_vars.erase(del_var);
186 
187  // get the set of tables to combine
188  Set< const TABLE< GUM_SCALAR >* >& tables_to_combine =
189  tables_per_var[del_var];
190 
191  // if there are no tables to combine, do nothing
192  if (tables_to_combine.size() == 0) continue;
193 
194  // compute the combination of all the tables: if there is only one table,
195  // there is nothing to do, else we shall use the MultiDimCombination
196  // to perform the combination
197  TABLE< GUM_SCALAR >* joint;
198 
199  bool joint_to_delete = false;
200 
201  if (tables_to_combine.size() == 1) {
    // const is cast away only so that both branches share one pointer type;
    // this table is not owned here and is never deleted (joint_to_delete
    // stays false)
202  joint =
203  const_cast< TABLE< GUM_SCALAR >* >(*(tables_to_combine.begin()));
204  joint_to_delete = false;
205  }
206  else {
207  joint = __combination->combine(tables_to_combine);
208  joint_to_delete = true;
209  }
210 
211  // compute the table resulting from marginalizing out del_var from joint
212  Set< const DiscreteVariable* > del_one_var;
213  del_one_var << del_var;
214 
215  TABLE< GUM_SCALAR >* marginal = __projection->project(*joint, del_one_var);
216 
217  // remove the temporary joint if needed
218  if (joint_to_delete) delete joint;
219 
220  // update tables_vars_per_var : remove the variables of the TABLEs we
221  // combined from this hashtable
222  // update accordingly tables_per_var : remove these TABLEs
223  // update accordingly product_size : when a variable is no longer used by
224  // any TABLE, divide product_size by its domain size
225 
226  for (const auto ptrTab : tables_to_combine) {
227  const Sequence< const DiscreteVariable* >& table_vars =
228  ptrTab->variablesSequence();
229  const Size tab_vars_size = table_vars.size();
230 
231  for (Size i = 0; i < tab_vars_size; ++i) {
232  if (del_vars.contains(table_vars[i])) {
233  // ok, here we have a variable that needed to be removed => update
234  // product_size, tables_per_var and tables_vars_per_var: here,
235  // the update corresponds to removing table ptrTab
236  HashTable< const DiscreteVariable*, unsigned int >&
237  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
238  double div_size = 1.0;
239 
240  for (Size j = 0; j < tab_vars_size; ++j) {
241  unsigned int k = --table_vars_of_var_i[table_vars[j]];
242 
243  if (k == 0) {
244  div_size *= table_vars[j]->domainSize();
245  table_vars_of_var_i.erase(table_vars[j]);
246  }
247  }
248 
249  tables_per_var[table_vars[i]].erase(ptrTab);
250 
251  if (div_size != 1.0) {
252  product_size.setPriority(
253  table_vars[i], product_size.priority(table_vars[i]) / div_size);
254  }
255  }
256  }
257 
    // tables produced by an earlier step are freed once consumed
    // NOTE(review): the two erase calls reuse ptrTab's value after delete
    // (implementation-defined); erasing before deleting would be cleaner
258  if (tmp_marginals.contains(ptrTab)) {
259  delete ptrTab;
260  tmp_marginals.erase(ptrTab);
261  }
262 
263  table_set.erase(ptrTab);
264  }
265 
    // safe now that the loop over tables_to_combine (a reference into
    // tables_per_var[del_var]) is finished
266  tables_per_var.erase(del_var);
267 
268  // add the new projected marginal to the list of TABLES
269  const Sequence< const DiscreteVariable* >& marginal_vars =
270  marginal->variablesSequence();
271 
272  for (const auto mvar : marginal_vars) {
273  if (del_vars.contains(mvar)) {
274  // add the new marginal table to the set of tables of mvar
275  tables_per_var[mvar].insert(marginal);
276 
277  // add the variables of the table to tables_vars_per_var[mvar]
278  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
279  tables_vars_per_var[mvar];
280  double mult_size = 1.0;
281 
282  for (const auto var : marginal_vars) {
283  try {
284  ++iter_vars[var];
285  }
286  catch (const NotFound&) {
287  iter_vars.insert(var, 1);
288  mult_size *= (double) var->domainSize();
289  }
290  }
291 
292  if (mult_size != 1.0) {
293  product_size.setPriority(mvar,
294  product_size.priority(mvar) * mult_size);
295  }
296  }
297  }
298 
299  table_set.insert(marginal);
300  tmp_marginals.insert(marginal);
301  }
302 
303  // here, tmp_marginals contains all the newly created tables and
304  // table_set contains the list of the tables resulting from the
305  // marginalizing out of del_vars of the combination of the tables
306  // of table_set. Note in particular that it will contain all the
307  // potentials with no dimension (constants)
    // The tables still in tmp_marginals are not freed here: they are part
    // of the returned set.
308  return table_set;
309  }
310 
311  // changes the function used for combining two TABLES
    // Forwards the new function to the owned combination helper.
    // NOTE(review): listing line 314 (the method's name line) is missing from
    // this extraction.
312  template < typename GUM_SCALAR, template < typename > class TABLE >
313  INLINE void
315  TABLE< GUM_SCALAR >* (*combine)(const TABLE< GUM_SCALAR >&,
316  const TABLE< GUM_SCALAR >&)) {
317  __combination->setCombineFunction(combine);
318  }
319 
320  // returns the current combination function
    // Delegates to the owned combination helper.
    // NOTE(review): listing line 323 (the method's name line) is missing from
    // this extraction.
321  template < typename GUM_SCALAR, template < typename > class TABLE >
322  INLINE TABLE< GUM_SCALAR >* (
324  const TABLE< GUM_SCALAR >&, const TABLE< GUM_SCALAR >&) {
325  return __combination->combineFunction();
326  }
327 
328  // changes the class that performs the combinations
    // Replaces the owned combination helper with a new copy obtained from
    // comb_class.newFactory(); the previous helper is deleted first.
    // NOTE(review): if newFactory() throws, __combination is left pointing
    // to the already-deleted helper, and the destructor deletes it a second
    // time. Calling newFactory() before the delete avoids this.
    // NOTE(review): listing line 331 (the method's name line) is missing from
    // this extraction.
329  template < typename GUM_SCALAR, template < typename > class TABLE >
330  INLINE void
332  const MultiDimCombination< GUM_SCALAR, TABLE >& comb_class) {
333  delete __combination;
334  __combination = comb_class.newFactory();
335  }
336 
337  // changes the function used for projecting TABLES
    // Forwards the new function to the owned projection helper.
    // NOTE(review): listing line 340 (the method's name line) is missing from
    // this extraction.
338  template < typename GUM_SCALAR, template < typename > class TABLE >
339  INLINE void
341  TABLE< GUM_SCALAR >* (*proj)(const TABLE< GUM_SCALAR >&,
342  const Set< const DiscreteVariable* >&)) {
343  __projection->setProjectFunction(proj);
344  }
345 
346  // returns the current projection function
    // Delegates to the owned projection helper.
    // NOTE(review): listing line 349 (the method's name line) is missing from
    // this extraction.
347  template < typename GUM_SCALAR, template < typename > class TABLE >
348  INLINE TABLE< GUM_SCALAR >* (
350  const TABLE< GUM_SCALAR >&, const Set< const DiscreteVariable* >&) {
351  return __projection->projectFunction();
352  }
353 
354  // changes the class that performs the projections
    // Replaces the owned projection helper with a new copy obtained from
    // proj_class.newFactory(); the previous helper is deleted first.
    // NOTE(review): if newFactory() throws, __projection is left pointing to
    // the already-deleted helper, and the destructor deletes it a second
    // time. Calling newFactory() before the delete avoids this.
    // NOTE(review): listing line 357 (the method's name line) is missing from
    // this extraction.
355  template < typename GUM_SCALAR, template < typename > class TABLE >
356  INLINE void
358  const MultiDimProjection< GUM_SCALAR, TABLE >& proj_class) {
359  delete __projection;
360  __projection = proj_class.newFactory();
361  }
362 
    // returns a rough estimate of the number of operations that
    // combineAndProject would perform on tables whose variables are given by
    // table_set (one Sequence per table), deleting the variables of del_vars.
    // It replays the greedy elimination of combineAndProject on variable
    // sequences only: combining tables amounts to taking the union of their
    // sequences, and projecting amounts to erasing a variable. The estimate
    // sums __combination->nbOperations() and __projection->nbOperations()
    // over all steps.
    // The sequences of table_set are neither modified nor freed; the
    // intermediate sequences allocated here are all deleted before returning.
    // NOTE(review): listing lines 363-364 and 366 (the comment and the
    // method's name line) are missing from this extraction.
    // NOTE(review): if either nbOperations() call throws, `joint` and the
    // sequences held in tmp_marginals are leaked.
365  template < typename GUM_SCALAR, template < typename > class TABLE >
367  const Set< const Sequence< const DiscreteVariable* >* >& table_set,
368  Set< const DiscreteVariable* > del_vars) const {
369  // when we remove a variable, we need to combine all the tables containing
370  // this variable in order to produce a new unique table containing this
371  // variable. Here, we do not have the tables but only their variables
372  // (dimensions), but the principle is identical. Removing a variable is then
373  // performed by marginalizing it out of the table or, equivalently, by
374  // removing it from the table's list of variables. In the
375  // combineAndProjectDefault algorithm, we wish to remove first variables
376  // that would produce small tables. This should speed up the whole
377  // marginalizing process.
378 
379  Size nb_vars;
380  {
381  // determine the set of all the variables involved in the tables.
382  // this should help sizing correctly the hashtables
383  Set< const DiscreteVariable* > all_vars;
384 
385  for ( const auto ptrSeq : table_set) {
386  for (const auto ptrVar : *ptrSeq) {
387  all_vars.insert(ptrVar);
388  }
389  }
390 
391  nb_vars = all_vars.size();
392  }
393 
394  // the tables (actually their variables) containing a given variable
395  // to be deleted
396  HashTable< const DiscreteVariable*,
397  Set< const Sequence< const DiscreteVariable* >* > >
398  tables_per_var(nb_vars);
399 
400  // for a given variable X to be deleted, the list of all the variables of
401  // the tables containing X (actually, we count the number of tables
402  // containing the variable. This is more efficient for computing and
403  // updating the product_size priority queue (see below) when some tables
404  // are removed)
405  HashTable< const DiscreteVariable*,
406  HashTable< const DiscreteVariable*, unsigned int > >
407  tables_vars_per_var(nb_vars);
408 
409  // initialize tables_vars_per_var and tables_per_var
410  {
411  Set< const Sequence< const DiscreteVariable* >* > empty_set(
412  table_set.size());
413  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
414 
415  for (const auto ptrVar : del_vars) {
416  tables_per_var.insert(ptrVar, empty_set);
417  tables_vars_per_var.insert(ptrVar, empty_hash);
418  }
419 
420  // update properly tables_per_var and tables_vars_per_var
421  for (const auto ptrSeq : table_set) {
422  const Sequence< const DiscreteVariable* >& vars = *ptrSeq;
423 
424  for (const auto ptrVar : vars) {
425  if (del_vars.contains(ptrVar)) {
426  // add the table's variables to the set of those related to ptrVar
427  tables_per_var[ptrVar].insert(ptrSeq);
428 
429  // add the variables of the table to tables_vars_per_var[ptrVar]
430  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
431  tables_vars_per_var[ptrVar];
432 
433  for (const auto xptrVar : vars) {
434  try {
435  ++iter_vars[xptrVar];
436  } catch (const NotFound&) { iter_vars.insert(xptrVar, 1); }
437  }
438  }
439  }
440  }
441  }
442 
443  // the sizes of the tables produced when removing a given discrete variable
444  PriorityQueue< const DiscreteVariable*, double > product_size;
445 
446  // initialize properly product_size
447  for (const auto& elt : tables_vars_per_var) {
448  double size = 1.0;
449  const auto ptrVar = elt.first;
    // NOTE(review): `const auto` copies the whole HashTable on every
    // iteration; `const auto&` (as in combineAndProject) would avoid the
    // copy without changing the result
450  const auto hashvars = elt.second; // HashTable<const DiscreteVariable*, unsigned int>
451 
452  if (hashvars.size()) {
453  for (const auto& xelt : hashvars) {
454  size *= (double) xelt.first->domainSize();
455  }
456 
457  product_size.insert(ptrVar, size);
458  }
459  }
460 
461  // the resulting number of operations
462  float nb_operations = 0;
463 
464  // create a set of the temporary table's variables created during the
465  // marginalization process (useful for deallocating temporary tables)
466  Set< const Sequence< const DiscreteVariable* >* > tmp_marginals(
467  table_set.size());
468 
469  // now, remove all the variables in del_vars, starting from those that
470  // produce the smallest tables
471  while (!product_size.empty()) {
472  // get the best variable to remove
473  const DiscreteVariable* del_var = product_size.pop();
    // del_var is no longer pending: the bookkeeping below then never touches
    // tables_per_var[del_var], which is iterated through tables_to_combine
474  del_vars.erase(del_var);
475 
476  // get the set of tables to combine
477  Set< const Sequence< const DiscreteVariable* >* >& tables_to_combine =
478  tables_per_var[del_var];
479 
480  // if there are no tables to combine, do nothing
481  if (tables_to_combine.size() == 0) continue;
482 
483  // compute the combination of all the tables: if there is only one table,
484  // there is nothing to do, else we shall use the MultiDimCombination
485  // to perform the combination
486  Sequence< const DiscreteVariable* >* joint;
487 
488  bool joint_to_delete = false;
489 
490  if (tables_to_combine.size() == 1) {
    // not owned here: it is copied below before being modified
491  joint = const_cast< Sequence< const DiscreteVariable* >* >(
492  *(tables_to_combine.beginSafe()));
493  joint_to_delete = false;
494  } else {
495  // here, compute the union of all the variables of the tables to combine
496  joint = new Sequence< const DiscreteVariable* >;
497 
498  for (const auto ptrSeq : tables_to_combine) {
499  for (const auto ptrVar : *ptrSeq) {
500  if (!joint->exists(ptrVar)) { joint->insert(ptrVar); }
501  }
502  }
503 
504  joint_to_delete = true;
505 
506  // update the number of operations performed
507  nb_operations += __combination->nbOperations(tables_to_combine);
508  }
509 
510  // update the number of operations performed by marginalizing out del_var
511  Set< const DiscreteVariable* > del_one_var;
512  del_one_var << del_var;
513 
514  nb_operations += __projection->nbOperations(*joint, del_one_var);
515 
516  // compute the table resulting from marginalizing out del_var from joint
    // (a freshly allocated joint is reused in place; a caller-provided
    // sequence is copied so that it is never modified)
517  Sequence< const DiscreteVariable* >* marginal;
518 
519  if (joint_to_delete) {
520  marginal = joint;
521  } else {
522  marginal = new Sequence< const DiscreteVariable* >(*joint);
523  }
524 
525  marginal->erase(del_var);
526 
527  // update tables_vars_per_var : remove the variables of the TABLEs we
528  // combined from this hashtable
529  // update accordingly tables_per_var : remove these TABLEs
530  // update accordingly product_size : when a variable is no longer used by
531  // any TABLE, divide product_size by its domain size
532 
533  for (const auto ptrSeq : tables_to_combine) {
534  const Sequence< const DiscreteVariable* >& table_vars = *ptrSeq;
535  const Size tab_vars_size = table_vars.size();
536 
537  for (Size i = 0; i < tab_vars_size; ++i) {
538  if (del_vars.contains(table_vars[i])) {
539  // ok, here we have a variable that needed to be removed => update
540  // product_size, tables_per_var and tables_vars_per_var
541  HashTable< const DiscreteVariable*, unsigned int >&
542  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
543  double div_size = 1.0;
544 
545  for (Size j = 0; j < tab_vars_size; ++j) {
546  unsigned int k = --table_vars_of_var_i[table_vars[j]];
547 
548  if (k == 0) {
549  div_size *= table_vars[j]->domainSize();
550  table_vars_of_var_i.erase(table_vars[j]);
551  }
552  }
553 
554  tables_per_var[table_vars[i]].erase(ptrSeq);
555 
556  if (div_size != 1.0) {
557  product_size.setPriority(
558  table_vars[i], product_size.priority(table_vars[i]) / div_size);
559  }
560  }
561  }
562 
    // sequences produced by an earlier step are freed once consumed
563  if (tmp_marginals.contains(ptrSeq)) {
564  delete ptrSeq;
565  tmp_marginals.erase(ptrSeq);
566  }
567  }
568 
569  tables_per_var.erase(del_var);
570 
571  // add the new projected marginal to the list of TABLES
572  for (const auto mvar : *marginal) {
573  if (del_vars.contains(mvar)) {
574  // add the new marginal table to the set of tables of mvar
575  tables_per_var[mvar].insert(marginal);
576 
577  // add the variables of the table to tables_vars_per_var[mvar]
578  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
579  tables_vars_per_var[mvar];
580  double mult_size = 1.0;
581 
582  for (const auto var : *marginal) {
583  try {
584  ++iter_vars[var];
585  } catch (const NotFound&) {
586  iter_vars.insert(var, 1);
587  mult_size *= (double) var->domainSize();
588  }
589  }
590 
591  if (mult_size != 1.0) {
592  product_size.setPriority(mvar,
593  product_size.priority(mvar) * mult_size);
594  }
595  }
596  }
597 
598  tmp_marginals.insert(marginal);
599  }
600 
601  // here, tmp_marginals contains all the newly created tables
602  for (auto iter = tmp_marginals.beginSafe();
603  iter != tmp_marginals.endSafe();
604  ++iter) {
605  delete *iter;
606  }
607 
608  return nb_operations;
609  }
610 
    // returns a rough estimate of the number of operations performed to
    // combine the tables of `set` and project out del_vars. It collects a
    // pointer to each table's own variable sequence (no copy) and delegates
    // to the Sequence-based overload.
    // NOTE(review): listing lines 611-612 and 614 (the comment and the
    // method's name line) are missing from this extraction.
613  template < typename GUM_SCALAR, template < typename > class TABLE >
615  const Set< const TABLE< GUM_SCALAR >* >& set,
616  const Set< const DiscreteVariable* >& del_vars) const {
617  // create the set of sets of discrete variables involved in the tables
618  Set< const Sequence< const DiscreteVariable* >* > var_set(set.size());
619 
620  for (const auto ptrTab : set) {
621  var_set << &(ptrTab->variablesSequence());
622  }
623 
624  return nbOperations(var_set, del_vars);
625  }
626 
627  // returns the memory consumption used during the combinations and
628  // projections
    // Replays, on variable sequences only, the greedy elimination order used
    // by combineAndProject: combining tables is the union of their
    // sequences, and projecting erases one variable. Along the way it
    // accumulates the estimates returned by __combination->memoryUsage() and
    // __projection->memoryUsage().
    // Returns (max_memory, current_memory):
    //  - max_memory is the largest value reached by current_memory plus the
    //    first element of a step's estimate (0 if there is no step);
    //  - current_memory is the running total once all eliminations are done.
    //    It grows by the second element of each estimate and shrinks by the
    //    size of every intermediate sequence that is consumed.
    // The first element of each estimate is presumably the peak extra memory
    // of the step, and the second the memory it leaves allocated -- confirm
    // against MultiDimCombination/MultiDimProjection::memoryUsage.
    // Throws OutOfBounds when adding a step's estimate to current_memory
    // could overflow a long.
    // NOTE(review): listing line 631 (the method's name line) is missing
    // from this extraction.
    // NOTE(review): the OutOfBounds paths leak `joint` (allocated above them
    // when several tables are combined) and every sequence in
    // tmp_marginals; these should be freed before GUM_ERROR.
629  template < typename GUM_SCALAR, template < typename > class TABLE >
630  std::pair< long, long >
632  const Set< const Sequence< const DiscreteVariable* >* >& table_set,
633  Set< const DiscreteVariable* > del_vars) const {
634  // when we remove a variable, we need to combine all the tables containing
635  // this variable in order to produce a new unique table containing this
636  // variable. Here, we do not have the tables but only their variables
637  // (dimensions), but the principle is identical. Removing a variable is then
638  // performed by marginalizing it out of the table or, equivalently, by
639  // removing it from the table's list of variables. In the
640  // combineAndProjectDefault algorithm, we wish to remove first variables
641  // that would produce small tables. This should speed up the whole
642  // marginalizing process.
643 
644  Size nb_vars;
645  {
646  // determine the set of all the variables involved in the tables.
647  // this should help sizing correctly the hashtables
648  Set< const DiscreteVariable* > all_vars;
649 
650  for ( const auto ptrSeq : table_set) {
651  for (const auto ptrVar : *ptrSeq) {
652  all_vars.insert(ptrVar);
653  }
654  }
655 
656  nb_vars = all_vars.size();
657  }
658 
659  // the tables (actually their variables) containing a given variable
660  HashTable< const DiscreteVariable*,
661  Set< const Sequence< const DiscreteVariable* >* > >
662  tables_per_var(nb_vars);
663  // for a given variable X to be deleted, the list of all the variables of
664  // the tables containing X (actually, we count the number of tables
665  // containing the variable. This is more efficient for computing and
666  // updating the product_size priority queue (see below) when some tables
667  // are removed)
668  HashTable< const DiscreteVariable*,
669  HashTable< const DiscreteVariable*, unsigned int > >
670  tables_vars_per_var(nb_vars);
671 
672  // initialize tables_vars_per_var and tables_per_var
673  {
674  Set< const Sequence< const DiscreteVariable* >* > empty_set(
675  table_set.size());
676  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
677 
678  for (const auto ptrVar : del_vars) {
679  tables_per_var.insert(ptrVar, empty_set);
680  tables_vars_per_var.insert(ptrVar, empty_hash);
681  }
682 
683  // update properly tables_per_var and tables_vars_per_var
684  for (const auto ptrSeq : table_set) {
685  const Sequence< const DiscreteVariable* >& vars = *ptrSeq;
686 
687  for (const auto ptrVar : vars) {
688  if (del_vars.contains(ptrVar)) {
689  // add the table's variables to the set of those related to ptrVar
690  tables_per_var[ptrVar].insert(ptrSeq);
691 
692  // add the variables of the table to tables_vars_per_var[ptrVar]
693  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
694  tables_vars_per_var[ptrVar];
695 
696  for (const auto xptrVar : vars) {
697  try {
698  ++iter_vars[xptrVar];
699  } catch (const NotFound&) { iter_vars.insert(xptrVar, 1); }
700  }
701  }
702  }
703  }
704  }
705 
706  // the sizes of the tables produced when removing a given discrete variable
707  PriorityQueue< const DiscreteVariable*, double > product_size;
708 
709  // initialize properly product_size
710  for (const auto& elt : tables_vars_per_var) {
711  double size = 1.0;
712  const auto ptrVar = elt.first;
    // NOTE(review): `const auto` copies the whole HashTable on every
    // iteration; `const auto&` (as in combineAndProject) would avoid the
    // copy without changing the result
713  const auto hashvars = elt.second; // HashTable<const DiscreteVariable*, unsigned int>
714 
715  if (hashvars.size()) {
716  for (const auto& xelt : hashvars) {
717  size *= (double) xelt.first->domainSize();
718  }
719 
720  product_size.insert(ptrVar, size);
721  }
722  }
723 
724  // the resulting memory consumptions
725  long max_memory = 0;
726  long current_memory = 0;
727 
728  // create a set of the temporary table's variables created during the
729  // marginalization process (useful for deallocating temporary tables)
730  Set< const Sequence< const DiscreteVariable* >* > tmp_marginals(
731  table_set.size());
732 
733  // now, remove all the variables in del_vars, starting from those that
734  // produce
735  // the smallest tables
736  while (!product_size.empty()) {
737  // get the best variable to remove
738  const DiscreteVariable* del_var = product_size.pop();
    // del_var is no longer pending: the bookkeeping below then never touches
    // tables_per_var[del_var], which is iterated through tables_to_combine
739  del_vars.erase(del_var);
740 
741  // get the set of tables to combine
742  Set< const Sequence< const DiscreteVariable* >* >& tables_to_combine =
743  tables_per_var[del_var];
744 
745  // if there are no tables to combine, do nothing
746  if (tables_to_combine.size() == 0) continue;
747 
748  // compute the combination of all the tables: if there is only one table,
749  // there is nothing to do, else we shall use the MultiDimCombination
750  // to perform the combination
751  Sequence< const DiscreteVariable* >* joint;
752 
753  bool joint_to_delete = false;
754 
755  if (tables_to_combine.size() == 1) {
    // not owned here: it is copied below before being modified
756  joint = const_cast< Sequence< const DiscreteVariable* >* >(
757  *(tables_to_combine.beginSafe()));
758  joint_to_delete = false;
759  } else {
760  // here, compute the union of all the variables of the tables to combine
761  joint = new Sequence< const DiscreteVariable* >;
762 
763  for (const auto ptrSeq : tables_to_combine) {
764  for (const auto ptrVar : *ptrSeq) {
765  if (!joint->exists(ptrVar)) { joint->insert(ptrVar); }
766  }
767  }
768 
769  joint_to_delete = true;
770 
771  // update the memory usage with that of the combination
772  std::pair< long, long > comb_memory =
773  __combination->memoryUsage(tables_to_combine);
774 
775  if ((std::numeric_limits< long >::max() - current_memory
776  < comb_memory.first)
777  || (std::numeric_limits< long >::max() - current_memory
778  < comb_memory.second)) {
779  GUM_ERROR(OutOfBounds, "memory usage out of long int range");
780  }
781 
782  if (current_memory + comb_memory.first > max_memory) {
783  max_memory = current_memory + comb_memory.first;
784  }
785 
786  current_memory += comb_memory.second;
787  }
788 
789  // update the memory usage with that of marginalizing out del_var
790  Set< const DiscreteVariable* > del_one_var;
791  del_one_var << del_var;
792 
    // (despite its name, comb_memory holds the projection's estimate here)
793  std::pair< long, long > comb_memory =
794  __projection->memoryUsage(*joint, del_one_var);
795 
796  if ((std::numeric_limits< long >::max() - current_memory < comb_memory.first)
797  || (std::numeric_limits< long >::max() - current_memory
798  < comb_memory.second)) {
799  GUM_ERROR(OutOfBounds, "memory usage out of long int range");
800  }
801 
802  if (current_memory + comb_memory.first > max_memory) {
803  max_memory = current_memory + comb_memory.first;
804  }
805 
806  current_memory += comb_memory.second;
807 
808  // compute the table resulting from marginalizing out del_var from joint
    // (a freshly allocated joint is reused in place; a caller-provided
    // sequence is copied so that it is never modified)
809  Sequence< const DiscreteVariable* >* marginal;
810 
811  if (joint_to_delete) {
812  marginal = joint;
813  } else {
814  marginal = new Sequence< const DiscreteVariable* >(*joint);
815  }
816 
817  marginal->erase(del_var);
818 
819  // update tables_vars_per_var : remove the variables of the TABLEs we
820  // combined from this hashtable
821  // update accordingly tables_per_var : remove these TABLEs
822  // update accordingly product_size : when a variable is no longer used by
823  // any TABLE, divide product_size by its domain size
824 
825  for (const auto ptrSeq : tables_to_combine) {
826  const Sequence< const DiscreteVariable* >& table_vars = *ptrSeq;
827  const Size tab_vars_size = table_vars.size();
828 
829  for (Size i = 0; i < tab_vars_size; ++i) {
830  if (del_vars.contains(table_vars[i])) {
831  // ok, here we have a variable that needed to be removed => update
832  // product_size, tables_per_var and tables_vars_per_var
833  HashTable< const DiscreteVariable*, unsigned int >&
834  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
835  double div_size = 1.0;
836 
837  for (Size j = 0; j < tab_vars_size; ++j) {
838  Size k = --table_vars_of_var_i[table_vars[j]];
839 
840  if (k == 0) {
841  div_size *= table_vars[j]->domainSize();
842  table_vars_of_var_i.erase(table_vars[j]);
843  }
844  }
845 
846  tables_per_var[table_vars[i]].erase(ptrSeq);
847 
848  if (div_size != 1) {
849  product_size.setPriority(
850  table_vars[i], product_size.priority(table_vars[i]) / div_size);
851  }
852  }
853  }
854 
    // an intermediate sequence produced by an earlier step is consumed here:
    // its size (product of domain sizes) is released and it is freed
855  if (tmp_marginals.contains(ptrSeq)) {
856  Size del_size = 1;
857 
    // NOTE(review): this product is not overflow-checked, unlike the
    // additions to current_memory above
858  for (const auto ptrVar : *ptrSeq) {
859  del_size *= ptrVar->domainSize();
860  }
861 
862  current_memory -= long(del_size);
863 
864  delete ptrSeq;
865  tmp_marginals.erase(ptrSeq);
866  }
867  }
868 
869  tables_per_var.erase(del_var);
870 
871  // add the new projected marginal to the list of TABLES
872  for (const auto mvar : *marginal) {
873  if (del_vars.contains(mvar)) {
874  // add the new marginal table to the set of tables of mvar
875  tables_per_var[mvar].insert(marginal);
876 
877  // add the variables of the table to tables_vars_per_var[mvar]
878  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
879  tables_vars_per_var[mvar];
880  double mult_size = 1.0;
881 
882  for (const auto var : *marginal) {
883  try {
884  ++iter_vars[var];
885  } catch (const NotFound&) {
886  iter_vars.insert(var, 1);
887  mult_size *= (double) var->domainSize();
888  }
889  }
890 
891  if (mult_size != 1) {
892  product_size.setPriority(mvar,
893  product_size.priority(mvar) * mult_size);
894  }
895  }
896  }
897 
898  tmp_marginals.insert(marginal);
899  }
900 
901  // here, tmp_marginals contains all the newly created tables
902  for (auto iter = tmp_marginals.beginSafe();
903  iter != tmp_marginals.endSafe();
904  ++iter) {
905  delete *iter;
906  }
907 
908  return std::pair< long, long >(max_memory, current_memory);
909  }
910 
911  // returns the memory consumption used during the combinations and
912  // projections
    // Collects a pointer to each table's own variable sequence (no copy) and
    // delegates to the Sequence-based overload, returning its
    // (max_memory, current_memory) pair.
    // NOTE(review): listing line 915 (the method's name line) is missing from
    // this extraction.
913  template < typename GUM_SCALAR, template < typename > class TABLE >
914  std::pair< long, long >
916  const Set< const TABLE< GUM_SCALAR >* >& set,
917  const Set< const DiscreteVariable* >& del_vars) const {
918  // create the set of sets of discrete variables involved in the tables
919  Set< const Sequence< const DiscreteVariable* >* > var_set(set.size());
920 
921  for (const auto ptrTab : set) {
922  var_set << &(ptrTab->variablesSequence());
923  }
924 
925  return memoryUsage(var_set, del_vars);
926  }
927 
928 } /* namespace gum */
929 
930 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
virtual ~MultiDimCombineAndProjectDefault()
Destructor.
virtual void setCombineFunction(TABLE< GUM_SCALAR > *(*combine)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &))
changes the function used for combining two TABLES
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
Definition: agrum.h:25
MultiDimProjection< GUM_SCALAR, TABLE > * __projection
the class used for the projections
virtual float nbOperations(const Set< const TABLE< GUM_SCALAR > * > &set, const Set< const DiscreteVariable * > &del_vars) const
returns a rough estimate of the number of operations that will be performed to compute the combination and projection.
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
virtual TABLE< GUM_SCALAR >* (*combineFunction())(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &)
Returns the current combination function.
virtual void setProjectionClass(const MultiDimProjection< GUM_SCALAR, TABLE > &proj_class)
Changes the class that performs the projections.
virtual TABLE< GUM_SCALAR >* (*projectFunction())(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &)
returns the current projection function
virtual void setCombinationClass(const MultiDimCombination< GUM_SCALAR, TABLE > &comb_class)
changes the class that performs the combinations
MultiDimCombination< GUM_SCALAR, TABLE > * __combination
the class used for the combinations
MultiDimCombineAndProject()
default constructor
MultiDimCombineAndProjectDefault(TABLE< GUM_SCALAR > *(*combine)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &), TABLE< GUM_SCALAR > *(*project)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
Default constructor.
virtual MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE > * newFactory() const
virtual constructor
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:48
virtual std::pair< long, long > memoryUsage(const Set< const TABLE< GUM_SCALAR > * > &set, const Set< const DiscreteVariable * > &del_vars) const
returns the memory consumption used during the combinations and projections
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
virtual void setProjectFunction(TABLE< GUM_SCALAR > *(*proj)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
Changes the function used for projecting TABLES.
virtual Set< const TABLE< GUM_SCALAR > *> combineAndProject(Set< const TABLE< GUM_SCALAR > * > set, Set< const DiscreteVariable * > del_vars)
creates and returns the result of the projection over the variables not in del_vars of the combination of the tables within set