aGrUM  0.16.0
multiDimCombineAndProjectDefault_tpl.h
Go to the documentation of this file.
1 
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 
32 # include <limits>
33 
34 # include <agrum/agrum.h>
35 
37 
38 namespace gum {
39 
40  // default constructor
    // Builds the object from two function pointers: `combine`, which takes two
    // TABLEs and returns a TABLE*, and `project`, which takes a TABLE and a set
    // of variables and returns a TABLE*. Each pointer is wrapped into a
    // heap-allocated helper (MultiDimCombinationDefault / MultiDimProjection)
    // stored in __combination / __projection.
    // NOTE(review): this listing jumps from source line 41 to 44; the skipped
    // lines presumably hold the qualified constructor name -- confirm against
    // the original header.
41  template < typename GUM_SCALAR, template < typename > class TABLE >
44  TABLE< GUM_SCALAR >* (*combine)(const TABLE< GUM_SCALAR >&,
45  const TABLE< GUM_SCALAR >&),
46  TABLE< GUM_SCALAR >* (*project)(const TABLE< GUM_SCALAR >&,
47  const Set< const DiscreteVariable* >&)) :
48  MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
49  __combination(new MultiDimCombinationDefault< GUM_SCALAR, TABLE >(combine)),
50  __projection(new MultiDimProjection< GUM_SCALAR, TABLE >(project)) {
51  // for debugging purposes
52  GUM_CONSTRUCTOR(MultiDimCombineAndProjectDefault);
53  }
54 
55  // copy constructor
    // Takes the object to copy as `from` and default-initialises the
    // MultiDimCombineAndProject base class.
    // NOTE(review): this listing skips source lines 57-58, 61-62 and 64, so
    // the way __combination and __projection are initialised from `from`, and
    // the statement inside the body, are not visible here -- confirm against
    // the original header before relying on deep- or shallow-copy semantics.
56  template < typename GUM_SCALAR, template < typename > class TABLE >
59  const MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >& from) :
60  MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
63  // for debugging purposes
65  }
66 
67  // destructor
    // Releases the two helper objects pointed to by __combination and
    // __projection with `delete`.
    // NOTE(review): this listing skips source lines 69-70, presumably the
    // qualified destructor name -- confirm against the original header.
68  template < typename GUM_SCALAR, template < typename > class TABLE >
71  // for debugging purposes
72  GUM_DESTRUCTOR(MultiDimCombineAndProjectDefault);
73  delete __combination;
74  delete __projection;
75  }
76 
77  // virtual constructor
    // Returns a pointer to a new MultiDimCombineAndProjectDefault allocated
    // with `new` and copy-constructed from *this.
    // NOTE(review): the result is a raw pointer; presumably the caller is in
    // charge of deleting it -- confirm against the class documentation.
    // NOTE(review): this listing skips source line 80, presumably the
    // qualified function name and its parameter list.
78  template < typename GUM_SCALAR, template < typename > class TABLE >
79  MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >*
81  return new MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >(*this);
82  }
83 
84  // combine and project
    // Eliminates the variables of del_vars from the tables of table_set. At
    // each step the variable with the smallest estimated resulting table size
    // is picked; all the tables containing it are combined (through
    // __combination) and the variable is projected out of the result (through
    // __projection).
    // Both parameters are taken by value, so the caller's sets are not
    // modified; the local copies are updated as the elimination proceeds.
    // Returns the local table_set: the input tables that were never combined
    // plus the tables created here that have not been combined again.
    // NOTE(review): the returned set mixes caller-provided tables with tables
    // allocated here; presumably the caller has to delete the new ones --
    // confirm against the class documentation.
    // NOTE(review): this listing skips source line 87 (presumably the
    // qualified function name) and lines 112 and 140 (presumably the
    // `for (...` headers of the loops whose remaining clauses are visible).
85  template < typename GUM_SCALAR, template < typename > class TABLE >
86  Set< const TABLE< GUM_SCALAR >* >
88  Set< const TABLE< GUM_SCALAR >* > table_set,
89  Set< const DiscreteVariable* > del_vars) {
90  // when we remove a variable, we need to combine all the tables containing
91  // this
92  // variable in order to produce a new unique table containing this variable.
93  // removing the variable is then performed by marginalizing it out of the
94  // table. In the combineAndProject algorithm, we wish to remove first
95  // variables that produce small tables. This should speed up the
96  // marginalizing
97  // process
98 
    // nb_vars is only used below as the initial size of the hash tables
99  Size nb_vars;
100  {
101  // determine the set of all the variables involved in the tables.
102  // this should help sizing correctly the hashtables
103  Set< const DiscreteVariable* > all_vars;
104 
105  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
106  table_set.beginSafe();
107  iter != table_set.endSafe();
108  ++iter) {
109  const Sequence< const DiscreteVariable* >& iter_vars =
110  (*iter)->variablesSequence();
111 
113  iter_vars.beginSafe();
114  it != iter_vars.endSafe();
115  ++it) {
116  all_vars.insert(*it);
117  }
118  }
119 
120  nb_vars = all_vars.size();
121  }
122 
123  // the tables containing a given variable
124  HashTable< const DiscreteVariable*, Set< const TABLE< GUM_SCALAR >* > >
125  tables_per_var(nb_vars);
126  // for a given variable X to be deleted, the list of all the variables of
127  // the tables containing X (actually, we count the number of tables
128  // containing the variable. This is more efficient for computing and
129  // updating
130  // the product_size priority queue (see below) when some tables are removed)
131  HashTable< const DiscreteVariable*,
132  HashTable< const DiscreteVariable*, unsigned int > >
133  tables_vars_per_var(nb_vars);
134 
135  // initialize tables_vars_per_var and tables_per_var
136  {
137  Set< const TABLE< GUM_SCALAR >* > empty_set(table_set.size());
138  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
139 
    // every variable to delete starts with an empty entry in both hash tables
141  del_vars.beginSafe();
142  iter != del_vars.endSafe();
143  ++iter) {
144  tables_per_var.insert(*iter, empty_set);
145  tables_vars_per_var.insert(*iter, empty_hash);
146  }
147 
148  // update properly tables_per_var and tables_vars_per_var
149  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
150  table_set.beginSafe();
151  iter != table_set.endSafe();
152  ++iter) {
153  const Sequence< const DiscreteVariable* >& vars =
154  (*iter)->variablesSequence();
155 
156  for (unsigned int i = 0; i < vars.size(); ++i) {
157  if (del_vars.contains(vars[i])) {
158  // add the table to the set of tables related to vars[i]
159  tables_per_var[vars[i]].insert(*iter);
160  // add the variables of the table to tables_vars_per_var[vars[i]]
161  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
162  tables_vars_per_var[vars[i]];
163 
164  for (unsigned int j = 0; j < vars.size(); ++j) {
    // increment the counter of vars[j]; a NotFound exception means that the
    // counter does not exist yet, in which case it is created with value 1
165  try {
166  ++iter_vars[vars[j]];
167  } catch (const NotFound&) { iter_vars.insert(vars[j], 1); }
168  }
169  }
170  }
171  }
172  }
173 
174  // the sizes of the tables produced when removing a given discrete variable
175  PriorityQueue< const DiscreteVariable*, float > product_size;
176 
177  // initialize properly product_size
178 
179  for (typename HashTable< const DiscreteVariable*,
180  HashTable< const DiscreteVariable*, unsigned int > >::
181  const_iterator_safe iter = tables_vars_per_var.beginSafe();
182  iter != tables_vars_per_var.endSafe();
183  ++iter) {
184  float size = 1.0f;
185  const HashTable< const DiscreteVariable*, unsigned int >& vars = iter.val();
186 
    // variables of del_vars that belong to no table have an empty `vars` and
    // are therefore never inserted into product_size
187  if (vars.size()) {
188  for (typename HashTable< const DiscreteVariable*,
189  unsigned int >::const_iterator_safe iter2 =
190  vars.beginSafe();
191  iter2 != vars.endSafe();
192  ++iter2) {
193  size *= iter2.key()->domainSize();
194  }
195 
196  product_size.insert(iter.key(), size);
197  }
198  }
199 
200  // create a set of the temporary tables created during the
201  // marginalization process (useful for deallocating temporary tables)
202  Set< const TABLE< GUM_SCALAR >* > tmp_marginals(table_set.size());
203 
204  // now, remove all the variables in del_vars, starting from those that
205  // produce
206  // the smallest tables
207  while (!product_size.empty()) {
208  // get the best variable to remove
209  const DiscreteVariable* del_var = product_size.pop();
    // from now on del_var no longer counts as a variable "to be deleted" in
    // the del_vars.contains() tests below
210  del_vars.erase(del_var);
211 
212  // get the set of tables to combine
213  Set< const TABLE< GUM_SCALAR >* >& tables_to_combine =
214  tables_per_var[del_var];
215 
216  // if there is no tables to combine, do nothing
217 
218  if (tables_to_combine.size() == 0) continue;
219 
220  // compute the combination of all the tables: if there is only one table,
221  // there is nothing to do, else we shall use the MultiDimCombination
222  // to perform the combination
223  TABLE< GUM_SCALAR >* joint;
224 
225  bool joint_to_delete = false;
226 
227  if (tables_to_combine.size() == 1) {
    // a single table: it is used as is, so it is not owned by `joint` and
    // joint_to_delete stays false
228  joint =
229  const_cast< TABLE< GUM_SCALAR >* >(*(tables_to_combine.beginSafe()));
230  joint_to_delete = false;
231  } else {
232  joint = __combination->combine(tables_to_combine);
233  joint_to_delete = true;
234  }
235 
236  // compute the table resulting from marginalizing out del_var from joint
237  Set< const DiscreteVariable* > del_one_var;
238 
239  del_one_var << del_var;
240 
241  TABLE< GUM_SCALAR >* marginal = __projection->project(*joint, del_one_var);
242 
243  // remove the temporary joint if needed
244  if (joint_to_delete) delete joint;
245 
246  // update tables_vars_per_var : remove the variables of the TABLEs we
247  // combined from this hashtable
248  // update accordingly tables_per_vars : remove these TABLEs
249  // update accordingly product_size : when a variable is no more used by
250  // any TABLE, divide product_size by its domain size
251 
252  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
253  tables_to_combine.beginSafe();
254  iter != tables_to_combine.endSafe();
255  ++iter) {
256  const Sequence< const DiscreteVariable* >& table_vars =
257  (*iter)->variablesSequence();
258 
259  for (unsigned int i = 0; i < table_vars.size(); ++i) {
260  if (del_vars.contains(table_vars[i])) {
261  // ok, here we have a variable that needed to be removed => update
262  // product_size, tables_per_var and tables_vars_per_var
263  HashTable< const DiscreteVariable*, unsigned int >&
264  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
265  float div_size = 1.0f;
266 
267  for (unsigned int j = 0; j < table_vars.size(); ++j) {
    // one fewer table containing table_vars[i] also contains table_vars[j]
268  unsigned int k = --table_vars_of_var_i[table_vars[j]];
269 
270  if (k == 0) {
271  div_size *= table_vars[j]->domainSize();
272  table_vars_of_var_i.erase(table_vars[j]);
273  }
274  }
275 
276  tables_per_var[table_vars[i]].erase(*iter);
277 
278  if (div_size != 1) {
279  product_size.setPriority(
280  table_vars[i], product_size.priority(table_vars[i]) / div_size);
281  }
282  }
283  }
284 
    // a table created by a previous iteration is freed once it has been
    // combined; after the delete, *iter is only used as a key of the sets
    // (it is never dereferenced again)
285  if (tmp_marginals.contains(*iter)) {
286  delete *iter;
287  tmp_marginals.erase(*iter);
288  }
289 
290  table_set.erase(*iter);
291  }
292 
293  tables_per_var.erase(del_var);
294 
295  // add the new projected marginal to the list of TABLES
296  const Sequence< const DiscreteVariable* >& marginal_vars =
297  marginal->variablesSequence();
298 
299  for (unsigned int i = 0; i < marginal_vars.size(); ++i) {
300  if (del_vars.contains(marginal_vars[i])) {
301  // add the new marginal table to the set of tables of var i
302  tables_per_var[marginal_vars[i]].insert(marginal);
303 
304  // add the variables of the table to tables_vars_per_var[vars[i]]
305  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
306  tables_vars_per_var[marginal_vars[i]];
307  float mult_size = 1.0f;
308 
309  for (unsigned int j = 0; j < marginal_vars.size(); ++j) {
    // only a variable whose counter has to be created (NotFound) enlarges the
    // estimated size associated with marginal_vars[i]
310  try {
311  ++iter_vars[marginal_vars[j]];
312  } catch (const NotFound&) {
313  iter_vars.insert(marginal_vars[j], 1);
314  mult_size *= marginal_vars[j]->domainSize();
315  }
316  }
317 
318  if (mult_size != 1) {
319  product_size.setPriority(marginal_vars[i],
320  product_size.priority(marginal_vars[i])
321  * mult_size);
322  }
323  }
324  }
325 
326  table_set.insert(marginal);
327 
328  tmp_marginals.insert(marginal);
329  }
330 
331  // here, tmp_marginals contains all the newly created tables and
332  // table_set contains the list of the tables resulting from the
333  // marginalizing out of del_vars of the combination of the tables
334  // of table_set
335 
336  return table_set;
337  }
338 
339  // changes the function used for combining two TABLES
    // Forwards the function pointer `combine` to the combination object held
    // in __combination (call to its setCombineFunction()).
    // NOTE(review): this listing skips source line 342, presumably the
    // qualified function name -- confirm against the original header.
340  template < typename GUM_SCALAR, template < typename > class TABLE >
341  INLINE void
343  TABLE< GUM_SCALAR >* (*combine)(const TABLE< GUM_SCALAR >&,
344  const TABLE< GUM_SCALAR >&)) {
345  __combination->setCombineFunction(combine);
346  }
347 
348  // returns the current combination function
    // The result is whatever __combination->combineFunction() returns, i.e. a
    // pointer to a function taking two const TABLE references and returning a
    // TABLE*.
    // NOTE(review): this listing skips source line 351, presumably the
    // qualified function name -- confirm against the original header.
349  template < typename GUM_SCALAR, template < typename > class TABLE >
350  INLINE TABLE< GUM_SCALAR >* (
352  const TABLE< GUM_SCALAR >&, const TABLE< GUM_SCALAR >&) {
353  return __combination->combineFunction();
354  }
355 
356  // changes the class that performs the combinations
    // Deletes the combination object currently held and replaces it with the
    // object returned by comb_class.newFactory(); comb_class itself is not
    // stored.
    // NOTE(review): the old object is deleted before newFactory() is called.
    // If newFactory() throws, __combination keeps pointing to freed memory and
    // the destructor would delete it a second time; consider creating the new
    // object first and deleting the old one afterwards.
    // NOTE(review): this listing skips source line 359, presumably the
    // qualified function name -- confirm against the original header.
357  template < typename GUM_SCALAR, template < typename > class TABLE >
358  INLINE void
360  const MultiDimCombination< GUM_SCALAR, TABLE >& comb_class) {
361  delete __combination;
362  __combination = comb_class.newFactory();
363  }
364 
365  // changes the function used for projecting TABLES
    // Forwards the function pointer `proj` to the projection object held in
    // __projection (call to its setProjectFunction()).
    // NOTE(review): this listing skips source line 368, presumably the
    // qualified function name -- confirm against the original header.
366  template < typename GUM_SCALAR, template < typename > class TABLE >
367  INLINE void
369  TABLE< GUM_SCALAR >* (*proj)(const TABLE< GUM_SCALAR >&,
370  const Set< const DiscreteVariable* >&)) {
371  __projection->setProjectFunction(proj);
372  }
373 
374  // returns the current projection function
    // The result is whatever __projection->projectFunction() returns, i.e. a
    // pointer to a function taking a const TABLE reference and a const
    // reference to a set of variables, and returning a TABLE*.
    // NOTE(review): this listing skips source line 377, presumably the
    // qualified function name -- confirm against the original header.
375  template < typename GUM_SCALAR, template < typename > class TABLE >
376  INLINE TABLE< GUM_SCALAR >* (
378  const TABLE< GUM_SCALAR >&, const Set< const DiscreteVariable* >&) {
379  return __projection->projectFunction();
380  }
381 
382  // changes the class that performs the projections
    // Deletes the projection object currently held and replaces it with the
    // object returned by proj_class.newFactory(); proj_class itself is not
    // stored.
    // NOTE(review): the old object is deleted before newFactory() is called.
    // If newFactory() throws, __projection keeps pointing to freed memory and
    // the destructor would delete it a second time; consider creating the new
    // object first and deleting the old one afterwards.
    // NOTE(review): this listing skips source line 385, presumably the
    // qualified function name -- confirm against the original header.
383  template < typename GUM_SCALAR, template < typename > class TABLE >
384  INLINE void
386  const MultiDimProjection< GUM_SCALAR, TABLE >& proj_class) {
387  delete __projection;
388  __projection = proj_class.newFactory();
389  }
390 
// Estimates the number of operations of the combine-and-project elimination
// using only the variable sequences of the tables: no table is combined or
// projected here, only sequences of variables are built and updated.
// table_set: the variable sequences of the tables. It is a const reference;
//   the sequences it contains are neither modified nor deleted here (a copy
//   is made before a variable is erased from one of them).
// del_vars: the variables to eliminate, taken by value and consumed locally.
// Returns nb_operations, a float accumulating the values returned by
// __combination->nbOperations() and __projection->nbOperations() at each
// elimination step.
// NOTE(review): this listing skips source lines 391-392 and 394 (presumably a
// leading comment, the return type and the qualified function name) and
// lines 420, 450 and 554 (presumably the `for (...` headers of the loops
// whose remaining clauses are visible) -- confirm against the original header.
393  template < typename GUM_SCALAR, template < typename > class TABLE >
395  const Set< const Sequence< const DiscreteVariable* >* >& table_set,
396  Set< const DiscreteVariable* > del_vars) const {
397  // when we remove a variable, we need to combine all the tables containing
398  // this
399  // variable in order to produce a new unique table containing this variable.
400  // Here, we do not have the tables but only their variables (dimensions),
401  // but
402  // the principle is identical. Removing a variable is then performed by
403  // marginalizing it out of the table or, equivalently, to remove it from the
404  // table's list of variables. In the combineAndProjectDefault algorithm, we
405  // wish to remove first variables that would produce small tables. This
406  // should speed up the whole marginalizing process.
407 
    // nb_vars is only used below as the initial size of the hash tables
408  Size nb_vars;
409  {
410  // determine the set of all the variables involved in the tables.
411  // this should help sizing correctly the hashtables
412  Set< const DiscreteVariable* > all_vars;
413 
414  for (typename Set< const Sequence< const DiscreteVariable* >* >::
415  const_iterator_safe iter = table_set.beginSafe();
416  iter != table_set.endSafe();
417  ++iter) {
418  const Sequence< const DiscreteVariable* >& iter_vars = **iter;
419 
421  iter_vars.beginSafe();
422  it != iter_vars.endSafe();
423  ++it) {
424  all_vars.insert(*it);
425  }
426  }
427 
428  nb_vars = all_vars.size();
429  }
430 
431  // the tables (actually their variables) containing a given variable
432  HashTable< const DiscreteVariable*,
433  Set< const Sequence< const DiscreteVariable* >* > >
434  tables_per_var(nb_vars);
435  // for a given variable X to be deleted, the list of all the variables of
436  // the tables containing X (actually, we count the number of tables
437  // containing the variable. This is more efficient for computing and
438  // updating
439  // the product_size priority queue (see below) when some tables are removed)
440  HashTable< const DiscreteVariable*,
441  HashTable< const DiscreteVariable*, unsigned int > >
442  tables_vars_per_var(nb_vars);
443 
444  // initialize tables_vars_per_var and tables_per_var
445  {
446  Set< const Sequence< const DiscreteVariable* >* > empty_set(
447  table_set.size());
448  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
449 
    // every variable to delete starts with an empty entry in both hash tables
451  del_vars.beginSafe();
452  iter != del_vars.endSafe();
453  ++iter) {
454  tables_per_var.insert(*iter, empty_set);
455  tables_vars_per_var.insert(*iter, empty_hash);
456  }
457 
458  // update properly tables_per_var and tables_vars_per_var
459  for (typename Set< const Sequence< const DiscreteVariable* >* >::
460  const_iterator_safe iter = table_set.beginSafe();
461  iter != table_set.endSafe();
462  ++iter) {
463  const Sequence< const DiscreteVariable* >& vars = **iter;
464 
465  for (unsigned int i = 0; i < vars.size(); ++i) {
466  if (del_vars.contains(vars[i])) {
467  // add the table's variables to the set of those related to vars[i]
468  tables_per_var[vars[i]].insert(*iter);
469  // add the variables of the table to tables_vars_per_var[vars[i]]
470  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
471  tables_vars_per_var[vars[i]];
472 
473  for (unsigned int j = 0; j < vars.size(); ++j) {
    // increment the counter of vars[j]; a NotFound exception means that the
    // counter does not exist yet, in which case it is created with value 1
474  try {
475  ++iter_vars[vars[j]];
476  } catch (const NotFound&) { iter_vars.insert(vars[j], 1); }
477  }
478  }
479  }
480  }
481  }
482 
483  // the sizes of the tables produced when removing a given discrete variable
484  PriorityQueue< const DiscreteVariable*, float > product_size;
485 
486  // initialize properly product_size
487 
488  for (typename HashTable< const DiscreteVariable*,
489  HashTable< const DiscreteVariable*, unsigned int > >::
490  const_iterator_safe iter = tables_vars_per_var.beginSafe();
491  iter != tables_vars_per_var.endSafe();
492  ++iter) {
493  float size = 1.0f;
494  const HashTable< const DiscreteVariable*, unsigned int >& vars = iter.val();
495 
    // variables of del_vars that belong to no table have an empty `vars` and
    // are therefore never inserted into product_size
496  if (vars.size()) {
497  for (typename HashTable< const DiscreteVariable*,
498  unsigned int >::const_iterator_safe iter2 =
499  vars.beginSafe();
500  iter2 != vars.endSafe();
501  ++iter2) {
502  size *= iter2.key()->domainSize();
503  }
504 
505  product_size.insert(iter.key(), size);
506  }
507  }
508 
509  // the resulting number of operations
510  float nb_operations = 0;
511 
512  // create a set of the temporary table's variables created during the
513  // marginalization process (useful for deallocating temporary tables)
514  Set< const Sequence< const DiscreteVariable* >* > tmp_marginals(
515  table_set.size());
516 
517  // now, remove all the variables in del_vars, starting from those that
518  // produce
519  // the smallest tables
520  while (!product_size.empty()) {
521  // get the best variable to remove
522  const DiscreteVariable* del_var = product_size.pop();
    // from now on del_var no longer counts as a variable "to be deleted" in
    // the del_vars.contains() tests below
523  del_vars.erase(del_var);
524 
525  // get the set of tables to combine
526  Set< const Sequence< const DiscreteVariable* >* >& tables_to_combine =
527  tables_per_var[del_var];
528 
529  // if there is no tables to combine, do nothing
530 
531  if (tables_to_combine.size() == 0) continue;
532 
533  // compute the combination of all the tables: if there is only one table,
534  // there is nothing to do, else we shall use the MultiDimCombination
535  // to perform the combination
536  Sequence< const DiscreteVariable* >* joint;
537 
538  bool joint_to_delete = false;
539 
540  if (tables_to_combine.size() == 1) {
    // a single sequence: `joint` merely points to it and does not own it
541  joint = const_cast< Sequence< const DiscreteVariable* >* >(
542  *(tables_to_combine.beginSafe()));
543  joint_to_delete = false;
544  } else {
545  // here, compute the union of all the variables of the tables to combine
546  joint = new Sequence< const DiscreteVariable* >;
547 
548  for (typename Set< const Sequence< const DiscreteVariable* >* >::
549  const_iterator_safe iter = tables_to_combine.beginSafe();
550  iter != tables_to_combine.endSafe();
551  ++iter) {
552  const Sequence< const DiscreteVariable* >& vars = **iter;
553 
555  iter2 = vars.beginSafe();
556  iter2 != vars.endSafe();
557  ++iter2) {
558  if (!joint->exists(*iter2)) { joint->insert(*iter2); }
559  }
560  }
561 
562  joint_to_delete = true;
563 
564  // update the number of operations performed
565  nb_operations += __combination->nbOperations(tables_to_combine);
566  }
567 
568  // update the number of operations performed by marginalizing out del_var
569  Set< const DiscreteVariable* > del_one_var;
570 
571  del_one_var << del_var;
572 
573  nb_operations += __projection->nbOperations(*joint, del_one_var);
574 
575  // compute the table resulting from marginalizing out del_var from joint
576  Sequence< const DiscreteVariable* >* marginal;
577 
    // a joint allocated above is reused as the marginal; otherwise the
    // sequence pointed to by joint is copied so that erasing del_var below
    // does not alter it
578  if (joint_to_delete) {
579  marginal = joint;
580  } else {
581  marginal = new Sequence< const DiscreteVariable* >(*joint);
582  }
583 
584  marginal->erase(del_var);
585 
586  // update tables_vars_per_var : remove the variables of the TABLEs we
587  // combined from this hashtable
588  // update accordingly tables_per_vars : remove these TABLEs
589  // update accordingly product_size : when a variable is no more used by
590  // any TABLE, divide product_size by its domain size
591 
592  for (typename Set< const Sequence< const DiscreteVariable* >* >::
593  const_iterator_safe iter = tables_to_combine.beginSafe();
594  iter != tables_to_combine.endSafe();
595  ++iter) {
596  const Sequence< const DiscreteVariable* >& table_vars = **iter;
597 
598  for (unsigned int i = 0; i < table_vars.size(); ++i) {
599  if (del_vars.contains(table_vars[i])) {
600  // ok, here we have a variable that needed to be removed => update
601  // product_size, tables_per_var and tables_vars_per_var
602  HashTable< const DiscreteVariable*, unsigned int >&
603  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
604  float div_size = 1.0f;
605 
606  for (unsigned int j = 0; j < table_vars.size(); ++j) {
    // one fewer table containing table_vars[i] also contains table_vars[j]
607  unsigned int k = --table_vars_of_var_i[table_vars[j]];
608 
609  if (k == 0) {
610  div_size *= table_vars[j]->domainSize();
611  table_vars_of_var_i.erase(table_vars[j]);
612  }
613  }
614 
615  tables_per_var[table_vars[i]].erase(*iter);
616 
617  if (div_size != 1) {
618  product_size.setPriority(
619  table_vars[i], product_size.priority(table_vars[i]) / div_size);
620  }
621  }
622  }
623 
    // only the sequences allocated by this function are in tmp_marginals, so
    // the sequences of the table_set argument are never deleted here
624  if (tmp_marginals.contains(*iter)) {
625  delete *iter;
626  tmp_marginals.erase(*iter);
627  }
628  }
629 
630  tables_per_var.erase(del_var);
631 
632  // add the new projected marginal to the list of TABLES
633 
634  for (unsigned int i = 0; i < marginal->size(); ++i) {
635  const DiscreteVariable* var_i = marginal->atPos(i);
636 
637  if (del_vars.contains(var_i)) {
638  // add the new marginal table to the set of tables of var i
639  tables_per_var[var_i].insert(marginal);
640 
641  // add the variables of the table to tables_vars_per_var[vars[i]]
642  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
643  tables_vars_per_var[var_i];
644  float mult_size = 1.0f;
645 
646  for (unsigned int j = 0; j < marginal->size(); ++j) {
    // only a variable whose counter has to be created (NotFound) enlarges the
    // estimated size associated with var_i
647  try {
648  ++iter_vars[marginal->atPos(j)];
649  } catch (const NotFound&) {
650  iter_vars.insert(marginal->atPos(j), 1);
651  mult_size *= marginal->atPos(j)->domainSize();
652  }
653  }
654 
655  if (mult_size != 1) {
656  product_size.setPriority(var_i,
657  product_size.priority(var_i) * mult_size);
658  }
659  }
660  }
661 
662  tmp_marginals.insert(marginal);
663  }
664 
665  // here, tmp_marginals contains all the newly created tables
666 
    // free the sequences allocated above that are still alive
667  for (typename Set< const Sequence< const DiscreteVariable* >* >::
668  const_iterator_safe iter = tmp_marginals.beginSafe();
669  iter != tmp_marginals.endSafe();
670  ++iter) {
671  delete *iter;
672  }
673 
674  return nb_operations;
675  }
676 
// Overload working on actual tables: collects the address of the variable
// sequence of every table of `set` and delegates to the nbOperations()
// overload that takes a set of variable sequences.
// set: the tables to consider; only their variablesSequence() is read.
// del_vars: the variables to eliminate, passed on unchanged.
// Returns the value computed by the delegated call.
// NOTE(review): this listing skips source lines 677-678 and 680 (presumably a
// leading comment, the return type and the qualified function name) --
// confirm against the original header.
679  template < typename GUM_SCALAR, template < typename > class TABLE >
681  const Set< const TABLE< GUM_SCALAR >* >& set,
682  const Set< const DiscreteVariable* >& del_vars) const {
683  // create the set of sets of discrete variables involved in the tables
684  Set< const Sequence< const DiscreteVariable* >* > var_set(set.size());
685 
686  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
687  set.beginSafe();
688  iter != set.endSafe();
689  ++iter) {
    // store a pointer to the sequence returned by the table (no copy is made)
690  var_set << &((*iter)->variablesSequence());
691  }
692 
693  return nbOperations(var_set, del_vars);
694  }
695 
696  // returns the memory consumption used during the combinations and
697  // projections
    // Works only on the variable sequences of the tables: no table is
    // combined or projected here.
    // table_set: the variable sequences of the tables. It is a const
    //   reference; the sequences it contains are neither modified nor deleted
    //   here (a copy is made before a variable is erased from one of them).
    // del_vars: the variables to eliminate, taken by value and consumed
    //   locally.
    // Returns the pair (max_memory, current_memory): max_memory is the
    // largest value reached by current_memory + the first component of a
    // step's memory usage; current_memory is the running sum of the second
    // components minus the sizes of the temporaries freed along the way.
    // Raises OutOfBounds (through GUM_ERROR) when adding a step's memory usage
    // to current_memory would exceed the largest value of a long.
    // NOTE(review): presumably .first is the peak memory of a step and
    // .second the memory still allocated after it -- confirm against the
    // memoryUsage() documentation of MultiDimCombination / MultiDimProjection.
    // NOTE(review): this listing skips source line 700 (presumably the
    // qualified function name) and lines 726, 756, 862 and 963 (presumably
    // the `for (...` headers of the loops whose remaining clauses are
    // visible) -- confirm against the original header.
698  template < typename GUM_SCALAR, template < typename > class TABLE >
699  std::pair< long, long >
701  const Set< const Sequence< const DiscreteVariable* >* >& table_set,
702  Set< const DiscreteVariable* > del_vars) const {
703  // when we remove a variable, we need to combine all the tables containing
704  // this
705  // variable in order to produce a new unique table containing this variable.
706  // Here, we do not have the tables but only their variables (dimensions),
707  // but
708  // the principle is identical. Removing a variable is then performed by
709  // marginalizing it out of the table or, equivalently, to remove it from the
710  // table's list of variables. In the combineAndProjectDefault algorithm, we
711  // wish to remove first variables that would produce small tables. This
712  // should speed up the whole marginalizing process.
713 
    // nb_vars is only used below as the initial size of the hash tables
714  Size nb_vars;
715  {
716  // determine the set of all the variables involved in the tables.
717  // this should help sizing correctly the hashtables
718  Set< const DiscreteVariable* > all_vars;
719 
720  for (typename Set< const Sequence< const DiscreteVariable* >* >::
721  const_iterator_safe iter = table_set.beginSafe();
722  iter != table_set.endSafe();
723  ++iter) {
724  const Sequence< const DiscreteVariable* >& iter_vars = **iter;
725 
727  iter_vars.beginSafe();
728  it != iter_vars.endSafe();
729  ++it) {
730  all_vars.insert(*it);
731  }
732  }
733 
734  nb_vars = all_vars.size();
735  }
736 
737  // the tables (actually their variables) containing a given variable
738  HashTable< const DiscreteVariable*,
739  Set< const Sequence< const DiscreteVariable* >* > >
740  tables_per_var(nb_vars);
741  // for a given variable X to be deleted, the list of all the variables of
742  // the tables containing X (actually, we count the number of tables
743  // containing the variable. This is more efficient for computing and
744  // updating
745  // the product_size priority queue (see below) when some tables are removed)
746  HashTable< const DiscreteVariable*,
747  HashTable< const DiscreteVariable*, unsigned int > >
748  tables_vars_per_var(nb_vars);
749 
750  // initialize tables_vars_per_var and tables_per_var
751  {
752  Set< const Sequence< const DiscreteVariable* >* > empty_set(
753  table_set.size());
754  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
755 
    // every variable to delete starts with an empty entry in both hash tables
757  del_vars.beginSafe();
758  iter != del_vars.endSafe();
759  ++iter) {
760  tables_per_var.insert(*iter, empty_set);
761  tables_vars_per_var.insert(*iter, empty_hash);
762  }
763 
764  // update properly tables_per_var and tables_vars_per_var
765  for (typename Set< const Sequence< const DiscreteVariable* >* >::
766  const_iterator_safe iter = table_set.beginSafe();
767  iter != table_set.endSafe();
768  ++iter) {
769  const Sequence< const DiscreteVariable* >& vars = **iter;
770 
771  for (unsigned int i = 0; i < vars.size(); ++i) {
772  if (del_vars.contains(vars[i])) {
773  // add the table's variables to the set of those related to vars[i]
774  tables_per_var[vars[i]].insert(*iter);
775  // add the variables of the table to tables_vars_per_var[vars[i]]
776  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
777  tables_vars_per_var[vars[i]];
778 
779  for (unsigned int j = 0; j < vars.size(); ++j) {
    // increment the counter of vars[j]; a NotFound exception means that the
    // counter does not exist yet, in which case it is created with value 1
780  try {
781  ++iter_vars[vars[j]];
782  } catch (const NotFound&) { iter_vars.insert(vars[j], 1); }
783  }
784  }
785  }
786  }
787  }
788 
789  // the sizes of the tables produced when removing a given discrete variable
790  PriorityQueue< const DiscreteVariable*, float > product_size;
791 
792  // initialize properly product_size
793 
794  for (typename HashTable< const DiscreteVariable*,
795  HashTable< const DiscreteVariable*, unsigned int > >::
796  const_iterator_safe iter = tables_vars_per_var.beginSafe();
797  iter != tables_vars_per_var.endSafe();
798  ++iter) {
799  float size = 1.0f;
800  const HashTable< const DiscreteVariable*, unsigned int >& vars = iter.val();
801 
    // variables of del_vars that belong to no table have an empty `vars` and
    // are therefore never inserted into product_size
802  if (vars.size()) {
803  for (typename HashTable< const DiscreteVariable*,
804  unsigned int >::const_iterator_safe iter2 =
805  vars.beginSafe();
806  iter2 != vars.endSafe();
807  ++iter2) {
808  size *= iter2.key()->domainSize();
809  }
810 
811  product_size.insert(iter.key(), size);
812  }
813  }
814 
815  // the resulting memory consumptions
816  long max_memory = 0;
817 
818  long current_memory = 0;
819 
820  // create a set of the temporary table's variables created during the
821  // marginalization process (useful for deallocating temporary tables)
822  Set< const Sequence< const DiscreteVariable* >* > tmp_marginals(
823  table_set.size());
824 
825  // now, remove all the variables in del_vars, starting from those that
826  // produce
827  // the smallest tables
828  while (!product_size.empty()) {
829  // get the best variable to remove
830  const DiscreteVariable* del_var = product_size.pop();
    // from now on del_var no longer counts as a variable "to be deleted" in
    // the del_vars.contains() tests below
831  del_vars.erase(del_var);
832 
833  // get the set of tables to combine
834  Set< const Sequence< const DiscreteVariable* >* >& tables_to_combine =
835  tables_per_var[del_var];
836 
837  // if there is no tables to combine, do nothing
838 
839  if (tables_to_combine.size() == 0) continue;
840 
841  // compute the combination of all the tables: if there is only one table,
842  // there is nothing to do, else we shall use the MultiDimCombination
843  // to perform the combination
844  Sequence< const DiscreteVariable* >* joint;
845 
846  bool joint_to_delete = false;
847 
848  if (tables_to_combine.size() == 1) {
    // a single sequence: `joint` merely points to it and does not own it
849  joint = const_cast< Sequence< const DiscreteVariable* >* >(
850  *(tables_to_combine.beginSafe()));
851  joint_to_delete = false;
852  } else {
853  // here, compute the union of all the variables of the tables to combine
854  joint = new Sequence< const DiscreteVariable* >;
855 
856  for (typename Set< const Sequence< const DiscreteVariable* >* >::
857  const_iterator_safe iter = tables_to_combine.beginSafe();
858  iter != tables_to_combine.endSafe();
859  ++iter) {
860  const Sequence< const DiscreteVariable* >& vars = **iter;
861 
863  iter2 = vars.beginSafe();
864  iter2 != vars.endSafe();
865  ++iter2) {
866  if (!joint->exists(*iter2)) { joint->insert(*iter2); }
867  }
868  }
869 
870  joint_to_delete = true;
871 
872  // update the memory usage induced by the combination
873  std::pair< long, long > comb_memory =
874  __combination->memoryUsage(tables_to_combine);
875 
    // guard against overflow: refuse to add a value that would push
    // current_memory beyond the largest long
876  if ((std::numeric_limits< long >::max() - current_memory
877  < comb_memory.first)
878  || (std::numeric_limits< long >::max() - current_memory
879  < comb_memory.second)) {
880  GUM_ERROR(OutOfBounds, "memory usage out of long int range");
881  }
882 
883  if (current_memory + comb_memory.first > max_memory) {
884  max_memory = current_memory + comb_memory.first;
885  }
886 
887  current_memory += comb_memory.second;
888  }
889 
890  // update the memory usage induced by marginalizing out del_var
891  Set< const DiscreteVariable* > del_one_var;
892 
893  del_one_var << del_var;
894 
    // despite its name, this comb_memory holds the memory usage reported by
    // the projection (the variable of the same name above is local to the
    // else branch)
895  std::pair< long, long > comb_memory =
896  __projection->memoryUsage(*joint, del_one_var);
897 
898  if ((std::numeric_limits< long >::max() - current_memory < comb_memory.first)
899  || (std::numeric_limits< long >::max() - current_memory
900  < comb_memory.second)) {
901  GUM_ERROR(OutOfBounds, "memory usage out of long int range");
902  }
903 
904  if (current_memory + comb_memory.first > max_memory) {
905  max_memory = current_memory + comb_memory.first;
906  }
907 
908  current_memory += comb_memory.second;
909 
910  // compute the table resulting from marginalizing out del_var from joint
911  Sequence< const DiscreteVariable* >* marginal;
912 
    // a joint allocated above is reused as the marginal; otherwise the
    // sequence pointed to by joint is copied so that erasing del_var below
    // does not alter it
913  if (joint_to_delete) {
914  marginal = joint;
915  } else {
916  marginal = new Sequence< const DiscreteVariable* >(*joint);
917  }
918 
919  marginal->erase(del_var);
920 
921  // update tables_vars_per_var : remove the variables of the TABLEs we
922  // combined from this hashtable
923  // update accordingly tables_per_vars : remove these TABLEs
924  // update accordingly product_size : when a variable is no more used by
925  // any TABLE, divide product_size by its domain size
926 
927  for (typename Set< const Sequence< const DiscreteVariable* >* >::
928  const_iterator_safe iter = tables_to_combine.beginSafe();
929  iter != tables_to_combine.endSafe();
930  ++iter) {
931  const Sequence< const DiscreteVariable* >& table_vars = **iter;
932 
933  for (unsigned int i = 0; i < table_vars.size(); ++i) {
934  if (del_vars.contains(table_vars[i])) {
935  // ok, here we have a variable that needed to be removed => update
936  // product_size, tables_per_var and tables_vars_per_var
937  HashTable< const DiscreteVariable*, unsigned int >&
938  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
939  float div_size = 1.0f;
940 
941  for (unsigned int j = 0; j < table_vars.size(); ++j) {
    // one fewer table containing table_vars[i] also contains table_vars[j]
942  unsigned int k = --table_vars_of_var_i[table_vars[j]];
943 
944  if (k == 0) {
945  div_size *= table_vars[j]->domainSize();
946  table_vars_of_var_i.erase(table_vars[j]);
947  }
948  }
949 
950  tables_per_var[table_vars[i]].erase(*iter);
951 
952  if (div_size != 1) {
953  product_size.setPriority(
954  table_vars[i], product_size.priority(table_vars[i]) / div_size);
955  }
956  }
957  }
958 
    // a temporary created by a previous step is freed here: its size (the
    // product of the domain sizes of its variables) is subtracted from
    // current_memory before the sequence is deleted
959  if (tmp_marginals.contains(*iter)) {
960  Size del_size = 1;
961  const Sequence< const DiscreteVariable* >& del = **iter;
962 
964  iter_del = del.beginSafe();
965  iter_del != del.endSafe();
966  ++iter_del) {
967  del_size *= (*iter_del)->domainSize();
968  }
969 
970  current_memory -= long(del_size);
971 
972  delete *iter;
973  tmp_marginals.erase(*iter);
974  }
975  }
976 
977  tables_per_var.erase(del_var);
978 
979  // add the new projected marginal to the list of TABLES
980 
981  for (unsigned int i = 0; i < marginal->size(); ++i) {
982  const DiscreteVariable* var_i = marginal->atPos(i);
983 
984  if (del_vars.contains(var_i)) {
985  // add the new marginal table to the set of tables of var i
986  tables_per_var[var_i].insert(marginal);
987 
988  // add the variables of the table to tables_vars_per_var[vars[i]]
989  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
990  tables_vars_per_var[var_i];
991  float mult_size = 1.0f;
992 
993  for (unsigned int j = 0; j < marginal->size(); ++j) {
    // only a variable whose counter has to be created (NotFound) enlarges the
    // estimated size associated with var_i
994  try {
995  ++iter_vars[marginal->atPos(j)];
996  } catch (const NotFound&) {
997  iter_vars.insert(marginal->atPos(j), 1);
998  mult_size *= marginal->atPos(j)->domainSize();
999  }
1000  }
1001 
1002  if (mult_size != 1) {
1003  product_size.setPriority(var_i,
1004  product_size.priority(var_i) * mult_size);
1005  }
1006  }
1007  }
1008 
1009  tmp_marginals.insert(marginal);
1010  }
1011 
1012  // here, tmp_marginals contains all the newly created tables
    // free the sequences allocated above that are still alive
1013  for (typename Set< const Sequence< const DiscreteVariable* >* >::
1014  const_iterator_safe iter = tmp_marginals.beginSafe();
1015  iter != tmp_marginals.endSafe();
1016  ++iter) {
1017  delete *iter;
1018  }
1019 
1020  return std::pair< long, long >(max_memory, current_memory);
1021  }
1022 
1023  // returns the memory consumption used during the combinations and
1024  // projections
1025  template < typename GUM_SCALAR, template < typename > class TABLE >
1026  std::pair< long, long >
1028  const Set< const TABLE< GUM_SCALAR >* >& set,
1029  const Set< const DiscreteVariable* >& del_vars) const {
1030  // create the set of sets of discrete variables involved in the tables
1031  Set< const Sequence< const DiscreteVariable* >* > var_set(set.size());
1032 
1033  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
1034  set.beginSafe();
1035  iter != set.endSafe();
1036  ++iter) {
1037  var_set << &((*iter)->variablesSequence());
1038  }
1039 
1040  return memoryUsage(var_set, del_vars);
1041  }
1042 
1043 } /* namespace gum */
1044 
1045 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
virtual ~MultiDimCombineAndProjectDefault()
Destructor.
virtual void setCombineFunction(TABLE< GUM_SCALAR > *(*combine)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &))
changes the function used for combining two TABLES
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
MultiDimProjection< GUM_SCALAR, TABLE > * __projection
the class used for the projections
SetIteratorSafe< Key > const_iterator_safe
Types for STL compliance.
Definition: set.h:180
virtual float nbOperations(const Set< const TABLE< GUM_SCALAR > * > &set, const Set< const DiscreteVariable * > &del_vars) const
returns a rough estimate of the number of operations that will be performed to compute the combination [...]
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
virtual TABLE< GUM_SCALAR > *(*)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &) combineFunction()
Returns the current combination function.
virtual void setProjectionClass(const MultiDimProjection< GUM_SCALAR, TABLE > &proj_class)
Changes the class that performs the projections.
virtual TABLE< GUM_SCALAR > *(*)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable *> &) projectFunction()
returns the current projection function
virtual void setCombinationClass(const MultiDimCombination< GUM_SCALAR, TABLE > &comb_class)
changes the class that performs the combinations
MultiDimCombination< GUM_SCALAR, TABLE > * __combination
the class used for the combinations
MultiDimCombineAndProject()
default constructor
MultiDimCombineAndProjectDefault(TABLE< GUM_SCALAR > *(*combine)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &), TABLE< GUM_SCALAR > *(*project)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
Default constructor.
virtual MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE > * newFactory() const
virtual constructor
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:48
virtual std::pair< long, long > memoryUsage(const Set< const TABLE< GUM_SCALAR > * > &set, const Set< const DiscreteVariable * > &del_vars) const
returns the memory consumption used during the combinations and projections
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
virtual void setProjectFunction(TABLE< GUM_SCALAR > *(*proj)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
Changes the function used for projecting TABLES.
SequenceIteratorSafe< Key > const_iterator_safe
Types for STL compliance.
Definition: sequence.h:1038
virtual Set< const TABLE< GUM_SCALAR > *> combineAndProject(Set< const TABLE< GUM_SCALAR > * > set, Set< const DiscreteVariable * > del_vars)
creates and returns the result of the projection over the variables not in del_vars of the combination [...]