aGrUM  0.14.2
multiDimCombineAndProjectDefault_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Pierre-Henri WUILLEMIN et Christophe GONZALES *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
27 #ifndef DOXYGEN_SHOULD_SKIP_THIS
28 
29 # include <limits>
30 
31 # include <agrum/agrum.h>
32 
34 
35 namespace gum {
36 
37  // default constructor
// Builds the default combine-and-project engine from two function pointers.
// @param combine  function combining two tables; it is handed to a newly
//                 allocated MultiDimCombinationDefault stored in __combination.
// @param project  function marginalizing a set of variables out of a table;
//                 it is handed to a newly allocated MultiDimProjection stored
//                 in __projection.
// Both helper objects are allocated with new here and deleted by the
// destructor of this class.
38  template < typename GUM_SCALAR, template < typename > class TABLE >
41  TABLE< GUM_SCALAR >* (*combine)(const TABLE< GUM_SCALAR >&,
42  const TABLE< GUM_SCALAR >&),
43  TABLE< GUM_SCALAR >* (*project)(const TABLE< GUM_SCALAR >&,
44  const Set< const DiscreteVariable* >&)) :
45  MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
46  __combination(new MultiDimCombinationDefault< GUM_SCALAR, TABLE >(combine)),
47  __projection(new MultiDimProjection< GUM_SCALAR, TABLE >(project)) {
48  // for debugging purposes
49  GUM_CONSTRUCTOR(MultiDimCombineAndProjectDefault);
50  }
51 
52  // copy constructor
// Copy constructor.
// @param from  the engine to copy.
// NOTE(review): the member initializers of this constructor are not visible
// in this listing (the numbering jumps from 57 to 60). Because the destructor
// deletes both __combination and __projection, they must be initialized with
// independent copies of from's helper objects rather than by copying the
// pointers; otherwise both engines would delete the same helpers — confirm.
53  template < typename GUM_SCALAR, template < typename > class TABLE >
56  const MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >& from) :
57  MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
60  // for debugging purposes
62  }
63 
64  // destructor
// Destructor: deletes the combination and projection helper objects
// (__combination and __projection) held by this engine.
65  template < typename GUM_SCALAR, template < typename > class TABLE >
68  // for debugging purposes
69  GUM_DESTRUCTOR(MultiDimCombineAndProjectDefault);
70  delete __combination;
71  delete __projection;
72  }
73 
74  // virtual constructor
// Virtual constructor: returns a new engine, allocated with new and built by
// the copy constructor from *this. The caller presumably takes ownership of
// the returned pointer and must delete it — confirm against the base-class
// contract.
75  template < typename GUM_SCALAR, template < typename > class TABLE >
76  MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >*
78  return new MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >(*this);
79  }
80 
81  // combine and project
// Eliminates every variable of del_vars from the (virtual) product of the
// tables of table_set, one variable at a time:
//  1. pick the variable whose elimination requires the smallest combined
//     table, i.e. the smallest product of the domain sizes of all the
//     variables of the tables containing it (product_size priority queue);
//  2. combine all the tables containing it (through __combination) and
//     marginalize it out of the result (through __projection);
//  3. replace those tables by the resulting marginal and update the
//     bookkeeping structures (tables_per_var, tables_vars_per_var,
//     product_size).
// @param table_set  the tables to combine (taken by value: the caller's set
//                   is not modified).
// @param del_vars   the variables to marginalize out (taken by value).
// @return the input tables that were never combined plus the marginals
//         produced by the last eliminations.
// NOTE(review): the marginals created here are heap-allocated and are not
// freed by this function when they end up in the returned set, so the caller
// presumably owns every returned table that was not in its input set —
// confirm.
82  template < typename GUM_SCALAR, template < typename > class TABLE >
83  Set< const TABLE< GUM_SCALAR >* >
85  Set< const TABLE< GUM_SCALAR >* > table_set,
86  Set< const DiscreteVariable* > del_vars) {
87  // when we remove a variable, we need to combine all the tables containing
88  // this
89  // variable in order to produce a new unique table containing this variable.
90  // removing the variable is then performed by marginalizing it out of the
91  // table. In the combineAndProject algorithm, we wish to remove first
92  // variables that produce small tables. This should speed up the
93  // marginalizing
94  // process
95
96  Size nb_vars;
97  {
98  // determine the set of all the variables involved in the tables.
99  // this should help sizing correctly the hashtables
100  Set< const DiscreteVariable* > all_vars;
101
102  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
103  table_set.beginSafe();
104  iter != table_set.endSafe();
105  ++iter) {
106  const Sequence< const DiscreteVariable* >& iter_vars =
107  (*iter)->variablesSequence();
108
110  iter_vars.beginSafe();
111  it != iter_vars.endSafe();
112  ++it) {
113  all_vars.insert(*it);
114  }
115  }
116
117  nb_vars = all_vars.size();
118  }
119
120  // the tables containing a given variable
121  HashTable< const DiscreteVariable*, Set< const TABLE< GUM_SCALAR >* > >
122  tables_per_var(nb_vars);
123  // for a given variable X to be deleted, the list of all the variables of
124  // the tables containing X (actually, we count the number of tables
125  // containing the variable. This is more efficient for computing and
126  // updating
127  // the product_size priority queue (see below) when some tables are removed)
128  HashTable< const DiscreteVariable*,
129  HashTable< const DiscreteVariable*, unsigned int > >
130  tables_vars_per_var(nb_vars);
131
132  // initialize tables_vars_per_var and tables_per_var
133  {
134  Set< const TABLE< GUM_SCALAR >* > empty_set(table_set.size());
135  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
136
138  del_vars.beginSafe();
139  iter != del_vars.endSafe();
140  ++iter) {
141  tables_per_var.insert(*iter, empty_set);
142  tables_vars_per_var.insert(*iter, empty_hash);
143  }
144
145  // update properly tables_per_var and tables_vars_per_var
146  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
147  table_set.beginSafe();
148  iter != table_set.endSafe();
149  ++iter) {
150  const Sequence< const DiscreteVariable* >& vars =
151  (*iter)->variablesSequence();
152
153  for (unsigned int i = 0; i < vars.size(); ++i) {
154  if (del_vars.contains(vars[i])) {
155  // add the table to the set of tables related to vars[i]
156  tables_per_var[vars[i]].insert(*iter);
157  // add the variables of the table to tables_vars_per_var[vars[i]]
158  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
159  tables_vars_per_var[vars[i]];
160
161  for (unsigned int j = 0; j < vars.size(); ++j) {
// increment the counter of vars[j], creating it with value 1 when the
// lookup throws NotFound (first table of vars[i] containing vars[j])
162  try {
163  ++iter_vars[vars[j]];
164  } catch (const NotFound&) { iter_vars.insert(vars[j], 1); }
165  }
166  }
167  }
168  }
169  }
170
171  // the sizes of the tables produced when removing a given discrete variable
// (more precisely: the product of the domain sizes of all the variables of
// the tables containing it, i.e. the size of their combination; variables of
// del_vars that appear in no table are not inserted)
172  PriorityQueue< const DiscreteVariable*, float > product_size;
173
174  // initialize properly product_size
175
176  for (typename HashTable< const DiscreteVariable*,
177  HashTable< const DiscreteVariable*, unsigned int > >::
178  const_iterator_safe iter = tables_vars_per_var.beginSafe();
179  iter != tables_vars_per_var.endSafe();
180  ++iter) {
181  float size = 1.0f;
182  const HashTable< const DiscreteVariable*, unsigned int >& vars = iter.val();
183
184  if (vars.size()) {
185  for (typename HashTable< const DiscreteVariable*,
186  unsigned int >::const_iterator_safe iter2 =
187  vars.beginSafe();
188  iter2 != vars.endSafe();
189  ++iter2) {
190  size *= iter2.key()->domainSize();
191  }
192
193  product_size.insert(iter.key(), size);
194  }
195  }
196
197  // create a set of the temporary tables created during the
198  // marginalization process (useful for deallocating temporary tables)
199  Set< const TABLE< GUM_SCALAR >* > tmp_marginals(table_set.size());
200
201  // now, remove all the variables in del_vars, starting from those that
202  // produce
203  // the smallest tables
204  while (!product_size.empty()) {
205  // get the best variable to remove
// NOTE(review): assumes PriorityQueue::pop() returns the variable with the
// lowest priority (smallest table size), as the comment above states.
206  const DiscreteVariable* del_var = product_size.pop();
207  del_vars.erase(del_var);
208
209  // get the set of tables to combine
// tables_to_combine aliases tables_per_var[del_var]. Since del_var has just
// been erased from del_vars, the update loop below never erases from this
// set while iterating over it; the alias is not used after
// tables_per_var.erase(del_var).
210  Set< const TABLE< GUM_SCALAR >* >& tables_to_combine =
211  tables_per_var[del_var];
212
213  // if there is no tables to combine, do nothing
214
215  if (tables_to_combine.size() == 0) continue;
216
217  // compute the combination of all the tables: if there is only one table,
218  // there is nothing to do, else we shall use the MultiDimCombination
219  // to perform the combination
220  TABLE< GUM_SCALAR >* joint;
221
222  bool joint_to_delete = false;
223
224  if (tables_to_combine.size() == 1) {
// the single table is used directly (no copy); it is only passed to
// __projection->project() and is not deleted as a joint
225  joint =
226  const_cast< TABLE< GUM_SCALAR >* >(*(tables_to_combine.beginSafe()));
227  joint_to_delete = false;
228  } else {
229  joint = __combination->combine(tables_to_combine);
230  joint_to_delete = true;
231  }
232
233  // compute the table resulting from marginalizing out del_var from joint
234  Set< const DiscreteVariable* > del_one_var;
235
236  del_one_var << del_var;
237
238  TABLE< GUM_SCALAR >* marginal = __projection->project(*joint, del_one_var);
239
240  // remove the temporary joint if needed
241  if (joint_to_delete) delete joint;
242
243  // update tables_vars_per_var : remove the variables of the TABLEs we
244  // combined from this hashtable
245  // update accordingly tables_per_vars : remove these TABLEs
246  // update accordingly product_size : when a variable is no more used by
247  // any TABLE, divide product_size by its domain size
248
249  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
250  tables_to_combine.beginSafe();
251  iter != tables_to_combine.endSafe();
252  ++iter) {
253  const Sequence< const DiscreteVariable* >& table_vars =
254  (*iter)->variablesSequence();
255
256  for (unsigned int i = 0; i < table_vars.size(); ++i) {
257  if (del_vars.contains(table_vars[i])) {
258  // ok, here we have a variable that needed to be removed => update
259  // product_size, tables_per_var and tables_vars_per_var
260  HashTable< const DiscreteVariable*, unsigned int >&
261  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
262  float div_size = 1.0f;
263
264  for (unsigned int j = 0; j < table_vars.size(); ++j) {
265  unsigned int k = --table_vars_of_var_i[table_vars[j]];
266
// a counter reaching 0 means no remaining table of table_vars[i] uses
// table_vars[j]: its domain size leaves the priority of table_vars[i]
267  if (k == 0) {
268  div_size *= table_vars[j]->domainSize();
269  table_vars_of_var_i.erase(table_vars[j]);
270  }
271  }
272
273  tables_per_var[table_vars[i]].erase(*iter);
274
275  if (div_size != 1) {
276  product_size.setPriority(
277  table_vars[i], product_size.priority(table_vars[i]) / div_size);
278  }
279  }
280  }
281
// NOTE(review): after `delete *iter`, the (now invalid) pointer value is
// still used as a key by tmp_marginals.erase() and table_set.erase(). Only
// the value is read, but any use of an invalid pointer value other than
// indirection/deallocation is implementation-defined in C++; erasing from
// both sets before deleting would avoid it.
282  if (tmp_marginals.contains(*iter)) {
283  delete *iter;
284  tmp_marginals.erase(*iter);
285  }
286
287  table_set.erase(*iter);
288  }
289
290  tables_per_var.erase(del_var);
291
292  // add the new projected marginal to the list of TABLES
293  const Sequence< const DiscreteVariable* >& marginal_vars =
294  marginal->variablesSequence();
295
296  for (unsigned int i = 0; i < marginal_vars.size(); ++i) {
297  if (del_vars.contains(marginal_vars[i])) {
298  // add the new marginal table to the set of tables of var i
299  tables_per_var[marginal_vars[i]].insert(marginal);
300
301  // add the variables of the table to tables_vars_per_var[vars[i]]
302  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
303  tables_vars_per_var[marginal_vars[i]];
304  float mult_size = 1.0f;
305
// variables newly brought into the neighbourhood of marginal_vars[i]
// (NotFound) enlarge its priority by their domain size
306  for (unsigned int j = 0; j < marginal_vars.size(); ++j) {
307  try {
308  ++iter_vars[marginal_vars[j]];
309  } catch (const NotFound&) {
310  iter_vars.insert(marginal_vars[j], 1);
311  mult_size *= marginal_vars[j]->domainSize();
312  }
313  }
314
315  if (mult_size != 1) {
316  product_size.setPriority(marginal_vars[i],
317  product_size.priority(marginal_vars[i])
318  * mult_size);
319  }
320  }
321  }
322
323  table_set.insert(marginal);
324
325  tmp_marginals.insert(marginal);
326  }
327
328  // here, tmp_marginals contains all the newly created tables and
329  // table_set contains the list of the tables resulting from the
330  // marginalizing out of del_vars of the combination of the tables
331  // of table_set
332
333  return table_set;
334  }
335 
336  // changes the function used for combining two TABLES
// Forwards the new combination function to the combination object currently
// held by this engine (__combination->setCombineFunction).
// @param combine  the function combining two tables.
337  template < typename GUM_SCALAR, template < typename > class TABLE >
338  INLINE void
340  TABLE< GUM_SCALAR >* (*combine)(const TABLE< GUM_SCALAR >&,
341  const TABLE< GUM_SCALAR >&)) {
342  __combination->setCombineFunction(combine);
343  }
344 
345  // returns the current combination function
// Returns the function pointer reported by __combination->combineFunction(),
// i.e. the function currently used to combine two tables.
346  template < typename GUM_SCALAR, template < typename > class TABLE >
347  INLINE TABLE< GUM_SCALAR >* (
349  const TABLE< GUM_SCALAR >&, const TABLE< GUM_SCALAR >&) {
350  return __combination->combineFunction();
351  }
352 
353  // changes the class that performs the combinations
// Replaces the combination object by a new one obtained from
// comb_class.newFactory(); the previous __combination is deleted first.
// @param comb_class  prototype of the new combination object (not retained).
// NOTE(review): deleting before cloning means that if newFactory() throws,
// __combination is left dangling (and deleted again by the destructor), and
// that passing the object currently held by this engine as comb_class would
// call newFactory() on a deleted object. Cloning into a temporary first and
// deleting afterwards would avoid both.
354  template < typename GUM_SCALAR, template < typename > class TABLE >
355  INLINE void
357  const MultiDimCombination< GUM_SCALAR, TABLE >& comb_class) {
358  delete __combination;
359  __combination = comb_class.newFactory();
360  }
361 
362  // changes the function used for projecting TABLES
// Forwards the new projection function to the projection object currently
// held by this engine (__projection->setProjectFunction).
// @param proj  the function marginalizing a set of variables out of a table.
363  template < typename GUM_SCALAR, template < typename > class TABLE >
364  INLINE void
366  TABLE< GUM_SCALAR >* (*proj)(const TABLE< GUM_SCALAR >&,
367  const Set< const DiscreteVariable* >&)) {
368  __projection->setProjectFunction(proj);
369  }
370 
371  // returns the current projection function
// Returns the function pointer reported by __projection->projectFunction(),
// i.e. the function currently used to project tables.
372  template < typename GUM_SCALAR, template < typename > class TABLE >
373  INLINE TABLE< GUM_SCALAR >* (
375  const TABLE< GUM_SCALAR >&, const Set< const DiscreteVariable* >&) {
376  return __projection->projectFunction();
377  }
378 
379  // changes the class that performs the projections
// Replaces the projection object by a new one obtained from
// proj_class.newFactory(); the previous __projection is deleted first.
// @param proj_class  prototype of the new projection object (not retained).
// NOTE(review): deleting before cloning means that if newFactory() throws,
// __projection is left dangling (and deleted again by the destructor), and
// that passing the object currently held by this engine as proj_class would
// call newFactory() on a deleted object. Cloning into a temporary first and
// deleting afterwards would avoid both.
380  template < typename GUM_SCALAR, template < typename > class TABLE >
381  INLINE void
383  const MultiDimProjection< GUM_SCALAR, TABLE >& proj_class) {
384  delete __projection;
385  __projection = proj_class.newFactory();
386  }
387 
// Estimates the number of operations that combineAndProject() would perform
// on tables whose variables are given by table_set when eliminating del_vars.
// It replays the same greedy elimination order as combineAndProject(), but
// on variable sequences only: each combination is costed by
// __combination->nbOperations() and each projection by
// __projection->nbOperations().
// @param table_set  the variable sequences of the tables (not modified).
// @param del_vars   the variables to eliminate (taken by value).
// @return the accumulated estimate, as a float.
// The intermediate "tables" are Sequence objects allocated here; they are all
// deleted before returning.
390  template < typename GUM_SCALAR, template < typename > class TABLE >
392  const Set< const Sequence< const DiscreteVariable* >* >& table_set,
393  Set< const DiscreteVariable* > del_vars) const {
394  // when we remove a variable, we need to combine all the tables containing
395  // this
396  // variable in order to produce a new unique table containing this variable.
397  // Here, we do not have the tables but only their variables (dimensions),
398  // but
399  // the principle is identical. Removing a variable is then performed by
400  // marginalizing it out of the table or, equivalently, to remove it from the
401  // table's list of variables. In the combineAndProjectDefault algorithm, we
402  // wish to remove first variables that would produce small tables. This
403  // should speed up the whole marginalizing process.
404
405  Size nb_vars;
406  {
407  // determine the set of all the variables involved in the tables.
408  // this should help sizing correctly the hashtables
409  Set< const DiscreteVariable* > all_vars;
410
411  for (typename Set< const Sequence< const DiscreteVariable* >* >::
412  const_iterator_safe iter = table_set.beginSafe();
413  iter != table_set.endSafe();
414  ++iter) {
415  const Sequence< const DiscreteVariable* >& iter_vars = **iter;
416
418  iter_vars.beginSafe();
419  it != iter_vars.endSafe();
420  ++it) {
421  all_vars.insert(*it);
422  }
423  }
424
425  nb_vars = all_vars.size();
426  }
427
428  // the tables (actually their variables) containing a given variable
429  HashTable< const DiscreteVariable*,
430  Set< const Sequence< const DiscreteVariable* >* > >
431  tables_per_var(nb_vars);
432  // for a given variable X to be deleted, the list of all the variables of
433  // the tables containing X (actually, we count the number of tables
434  // containing the variable. This is more efficient for computing and
435  // updating
436  // the product_size priority queue (see below) when some tables are removed)
437  HashTable< const DiscreteVariable*,
438  HashTable< const DiscreteVariable*, unsigned int > >
439  tables_vars_per_var(nb_vars);
440
441  // initialize tables_vars_per_var and tables_per_var
442  {
443  Set< const Sequence< const DiscreteVariable* >* > empty_set(
444  table_set.size());
445  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
446
448  del_vars.beginSafe();
449  iter != del_vars.endSafe();
450  ++iter) {
451  tables_per_var.insert(*iter, empty_set);
452  tables_vars_per_var.insert(*iter, empty_hash);
453  }
454
455  // update properly tables_per_var and tables_vars_per_var
456  for (typename Set< const Sequence< const DiscreteVariable* >* >::
457  const_iterator_safe iter = table_set.beginSafe();
458  iter != table_set.endSafe();
459  ++iter) {
460  const Sequence< const DiscreteVariable* >& vars = **iter;
461
462  for (unsigned int i = 0; i < vars.size(); ++i) {
463  if (del_vars.contains(vars[i])) {
464  // add the table's variables to the set of those related to vars[i]
465  tables_per_var[vars[i]].insert(*iter);
466  // add the variables of the table to tables_vars_per_var[vars[i]]
467  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
468  tables_vars_per_var[vars[i]];
469
470  for (unsigned int j = 0; j < vars.size(); ++j) {
// increment the counter of vars[j], creating it with value 1 when the
// lookup throws NotFound
471  try {
472  ++iter_vars[vars[j]];
473  } catch (const NotFound&) { iter_vars.insert(vars[j], 1); }
474  }
475  }
476  }
477  }
478  }
479
480  // the sizes of the tables produced when removing a given discrete variable
// (product of the domain sizes of all the variables of the tables containing
// it; variables of del_vars appearing in no table are not inserted)
481  PriorityQueue< const DiscreteVariable*, float > product_size;
482
483  // initialize properly product_size
484
485  for (typename HashTable< const DiscreteVariable*,
486  HashTable< const DiscreteVariable*, unsigned int > >::
487  const_iterator_safe iter = tables_vars_per_var.beginSafe();
488  iter != tables_vars_per_var.endSafe();
489  ++iter) {
490  float size = 1.0f;
491  const HashTable< const DiscreteVariable*, unsigned int >& vars = iter.val();
492
493  if (vars.size()) {
494  for (typename HashTable< const DiscreteVariable*,
495  unsigned int >::const_iterator_safe iter2 =
496  vars.beginSafe();
497  iter2 != vars.endSafe();
498  ++iter2) {
499  size *= iter2.key()->domainSize();
500  }
501
502  product_size.insert(iter.key(), size);
503  }
504  }
505
506  // the resulting number of operations
507  float nb_operations = 0;
508
509  // create a set of the temporary table's variables created during the
510  // marginalization process (useful for deallocating temporary tables)
511  Set< const Sequence< const DiscreteVariable* >* > tmp_marginals(
512  table_set.size());
513
514  // now, remove all the variables in del_vars, starting from those that
515  // produce
516  // the smallest tables
517  while (!product_size.empty()) {
518  // get the best variable to remove
// NOTE(review): assumes PriorityQueue::pop() returns the variable with the
// lowest priority (smallest table size), as the comment above states.
519  const DiscreteVariable* del_var = product_size.pop();
520  del_vars.erase(del_var);
521
522  // get the set of tables to combine
// tables_to_combine aliases tables_per_var[del_var]; del_var is no longer in
// del_vars, so the update loop below never erases from it while iterating.
523  Set< const Sequence< const DiscreteVariable* >* >& tables_to_combine =
524  tables_per_var[del_var];
525
526  // if there is no tables to combine, do nothing
527
528  if (tables_to_combine.size() == 0) continue;
529
530  // compute the combination of all the tables: if there is only one table,
531  // there is nothing to do, else we shall use the MultiDimCombination
532  // to perform the combination
533  Sequence< const DiscreteVariable* >* joint;
534
535  bool joint_to_delete = false;
536
537  if (tables_to_combine.size() == 1) {
// the single sequence is not modified: a copy of it is made below before
// del_var is erased (joint_to_delete == false)
538  joint = const_cast< Sequence< const DiscreteVariable* >* >(
539  *(tables_to_combine.beginSafe()));
540  joint_to_delete = false;
541  } else {
542  // here, compute the union of all the variables of the tables to combine
543  joint = new Sequence< const DiscreteVariable* >;
544
545  for (typename Set< const Sequence< const DiscreteVariable* >* >::
546  const_iterator_safe iter = tables_to_combine.beginSafe();
547  iter != tables_to_combine.endSafe();
548  ++iter) {
549  const Sequence< const DiscreteVariable* >& vars = **iter;
550
552  iter2 = vars.beginSafe();
553  iter2 != vars.endSafe();
554  ++iter2) {
555  if (!joint->exists(*iter2)) { joint->insert(*iter2); }
556  }
557  }
558
559  joint_to_delete = true;
560
561  // update the number of operations performed
562  nb_operations += __combination->nbOperations(tables_to_combine);
563  }
564
565  // update the number of operations performed by marginalizing out del_var
566  Set< const DiscreteVariable* > del_one_var;
567
568  del_one_var << del_var;
569
570  nb_operations += __projection->nbOperations(*joint, del_one_var);
571
572  // compute the table resulting from marginalizing out del_var from joint
// an owned joint is reused in place as the marginal; a borrowed one is copied
// so that the caller's sequences are never altered
573  Sequence< const DiscreteVariable* >* marginal;
574
575  if (joint_to_delete) {
576  marginal = joint;
577  } else {
578  marginal = new Sequence< const DiscreteVariable* >(*joint);
579  }
580
581  marginal->erase(del_var);
582
583  // update tables_vars_per_var : remove the variables of the TABLEs we
584  // combined from this hashtable
585  // update accordingly tables_per_vars : remove these TABLEs
586  // update accordingly product_size : when a variable is no more used by
587  // any TABLE, divide product_size by its domain size
588
589  for (typename Set< const Sequence< const DiscreteVariable* >* >::
590  const_iterator_safe iter = tables_to_combine.beginSafe();
591  iter != tables_to_combine.endSafe();
592  ++iter) {
593  const Sequence< const DiscreteVariable* >& table_vars = **iter;
594
595  for (unsigned int i = 0; i < table_vars.size(); ++i) {
596  if (del_vars.contains(table_vars[i])) {
597  // ok, here we have a variable that needed to be removed => update
598  // product_size, tables_per_var and tables_vars_per_var
599  HashTable< const DiscreteVariable*, unsigned int >&
600  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
601  float div_size = 1.0f;
602
603  for (unsigned int j = 0; j < table_vars.size(); ++j) {
604  unsigned int k = --table_vars_of_var_i[table_vars[j]];
605
606  if (k == 0) {
607  div_size *= table_vars[j]->domainSize();
608  table_vars_of_var_i.erase(table_vars[j]);
609  }
610  }
611
612  tables_per_var[table_vars[i]].erase(*iter);
613
614  if (div_size != 1) {
615  product_size.setPriority(
616  table_vars[i], product_size.priority(table_vars[i]) / div_size);
617  }
618  }
619  }
620
// NOTE(review): after `delete *iter`, the invalid pointer value is still used
// as a key by tmp_marginals.erase(); this is implementation-defined in C++.
// Erasing before deleting would avoid it.
621  if (tmp_marginals.contains(*iter)) {
622  delete *iter;
623  tmp_marginals.erase(*iter);
624  }
625  }
626
627  tables_per_var.erase(del_var);
628
629  // add the new projected marginal to the list of TABLES
630
631  for (unsigned int i = 0; i < marginal->size(); ++i) {
632  const DiscreteVariable* var_i = marginal->atPos(i);
633
634  if (del_vars.contains(var_i)) {
635  // add the new marginal table to the set of tables of var i
636  tables_per_var[var_i].insert(marginal);
637
638  // add the variables of the table to tables_vars_per_var[vars[i]]
639  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
640  tables_vars_per_var[var_i];
641  float mult_size = 1.0f;
642
643  for (unsigned int j = 0; j < marginal->size(); ++j) {
644  try {
645  ++iter_vars[marginal->atPos(j)];
646  } catch (const NotFound&) {
647  iter_vars.insert(marginal->atPos(j), 1);
648  mult_size *= marginal->atPos(j)->domainSize();
649  }
650  }
651
652  if (mult_size != 1) {
653  product_size.setPriority(var_i,
654  product_size.priority(var_i) * mult_size);
655  }
656  }
657  }
658
659  tmp_marginals.insert(marginal);
660  }
661
662  // here, tmp_marginals contains all the newly created tables
663
664  for (typename Set< const Sequence< const DiscreteVariable* >* >::
665  const_iterator_safe iter = tmp_marginals.beginSafe();
666  iter != tmp_marginals.endSafe();
667  ++iter) {
668  delete *iter;
669  }
670
671  return nb_operations;
672  }
673 
// Convenience overload: collects the address of each table's variable
// sequence and forwards to the sequence-based nbOperations().
// @param set       the tables to combine.
// @param del_vars  the variables to eliminate.
// @return the estimated number of operations.
// The sequences are referenced (their addresses are stored), not copied, so
// the tables of set must stay alive during the call.
676  template < typename GUM_SCALAR, template < typename > class TABLE >
678  const Set< const TABLE< GUM_SCALAR >* >& set,
679  const Set< const DiscreteVariable* >& del_vars) const {
680  // create the set of sets of discrete variables involved in the tables
681  Set< const Sequence< const DiscreteVariable* >* > var_set(set.size());
682
683  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
684  set.beginSafe();
685  iter != set.endSafe();
686  ++iter) {
687  var_set << &((*iter)->variablesSequence());
688  }
689
690  return nbOperations(var_set, del_vars);
691  }
692 
693  // returns the memory consumption used during the combinations and
694  // projections
695  template < typename GUM_SCALAR, template < typename > class TABLE >
696  std::pair< long, long >
698  const Set< const Sequence< const DiscreteVariable* >* >& table_set,
699  Set< const DiscreteVariable* > del_vars) const {
700  // when we remove a variable, we need to combine all the tables containing
701  // this
702  // variable in order to produce a new unique table containing this variable.
703  // Here, we do not have the tables but only their variables (dimensions),
704  // but
705  // the principle is identical. Removing a variable is then performed by
706  // marginalizing it out of the table or, equivalently, to remove it from the
707  // table's list of variables. In the combineAndProjectDefault algorithm, we
708  // wish to remove first variables that would produce small tables. This
709  // should speed up the whole marginalizing process.
710 
711  Size nb_vars;
712  {
713  // determine the set of all the variables involved in the tables.
714  // this should help sizing correctly the hashtables
715  Set< const DiscreteVariable* > all_vars;
716 
717  for (typename Set< const Sequence< const DiscreteVariable* >* >::
718  const_iterator_safe iter = table_set.beginSafe();
719  iter != table_set.endSafe();
720  ++iter) {
721  const Sequence< const DiscreteVariable* >& iter_vars = **iter;
722 
724  iter_vars.beginSafe();
725  it != iter_vars.endSafe();
726  ++it) {
727  all_vars.insert(*it);
728  }
729  }
730 
731  nb_vars = all_vars.size();
732  }
733 
734  // the tables (actually their variables) containing a given variable
735  HashTable< const DiscreteVariable*,
736  Set< const Sequence< const DiscreteVariable* >* > >
737  tables_per_var(nb_vars);
738  // for a given variable X to be deleted, the list of all the variables of
739  // the tables containing X (actually, we count the number of tables
740  // containing the variable. This is more efficient for computing and
741  // updating
742  // the product_size priority queue (see below) when some tables are removed)
743  HashTable< const DiscreteVariable*,
744  HashTable< const DiscreteVariable*, unsigned int > >
745  tables_vars_per_var(nb_vars);
746 
747  // initialize tables_vars_per_var and tables_per_var
748  {
749  Set< const Sequence< const DiscreteVariable* >* > empty_set(
750  table_set.size());
751  HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
752 
754  del_vars.beginSafe();
755  iter != del_vars.endSafe();
756  ++iter) {
757  tables_per_var.insert(*iter, empty_set);
758  tables_vars_per_var.insert(*iter, empty_hash);
759  }
760 
761  // update properly tables_per_var and tables_vars_per_var
762  for (typename Set< const Sequence< const DiscreteVariable* >* >::
763  const_iterator_safe iter = table_set.beginSafe();
764  iter != table_set.endSafe();
765  ++iter) {
766  const Sequence< const DiscreteVariable* >& vars = **iter;
767 
768  for (unsigned int i = 0; i < vars.size(); ++i) {
769  if (del_vars.contains(vars[i])) {
770  // add the table's variables to the set of those related to vars[i]
771  tables_per_var[vars[i]].insert(*iter);
772  // add the variables of the table to tables_vars_per_var[vars[i]]
773  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
774  tables_vars_per_var[vars[i]];
775 
776  for (unsigned int j = 0; j < vars.size(); ++j) {
777  try {
778  ++iter_vars[vars[j]];
779  } catch (const NotFound&) { iter_vars.insert(vars[j], 1); }
780  }
781  }
782  }
783  }
784  }
785 
786  // the sizes of the tables produced when removing a given discrete variable
787  PriorityQueue< const DiscreteVariable*, float > product_size;
788 
789  // initialize properly product_size
790 
791  for (typename HashTable< const DiscreteVariable*,
792  HashTable< const DiscreteVariable*, unsigned int > >::
793  const_iterator_safe iter = tables_vars_per_var.beginSafe();
794  iter != tables_vars_per_var.endSafe();
795  ++iter) {
796  float size = 1.0f;
797  const HashTable< const DiscreteVariable*, unsigned int >& vars = iter.val();
798 
799  if (vars.size()) {
800  for (typename HashTable< const DiscreteVariable*,
801  unsigned int >::const_iterator_safe iter2 =
802  vars.beginSafe();
803  iter2 != vars.endSafe();
804  ++iter2) {
805  size *= iter2.key()->domainSize();
806  }
807 
808  product_size.insert(iter.key(), size);
809  }
810  }
811 
812  // the resulting memory consumtions
813  long max_memory = 0;
814 
815  long current_memory = 0;
816 
817  // create a set of the temporary table's variables created during the
818  // marginalization process (useful for deallocating temporary tables)
819  Set< const Sequence< const DiscreteVariable* >* > tmp_marginals(
820  table_set.size());
821 
822  // now, remove all the variables in del_vars, starting from those that
823  // produce
824  // the smallest tables
825  while (!product_size.empty()) {
826  // get the best variable to remove
827  const DiscreteVariable* del_var = product_size.pop();
828  del_vars.erase(del_var);
829 
830  // get the set of tables to combine
831  Set< const Sequence< const DiscreteVariable* >* >& tables_to_combine =
832  tables_per_var[del_var];
833 
834  // if there is no tables to combine, do nothing
835 
836  if (tables_to_combine.size() == 0) continue;
837 
838  // compute the combination of all the tables: if there is only one table,
839  // there is nothing to do, else we shall use the MultiDimCombination
840  // to perform the combination
841  Sequence< const DiscreteVariable* >* joint;
842 
843  bool joint_to_delete = false;
844 
845  if (tables_to_combine.size() == 1) {
846  joint = const_cast< Sequence< const DiscreteVariable* >* >(
847  *(tables_to_combine.beginSafe()));
848  joint_to_delete = false;
849  } else {
850  // here, compute the union of all the variables of the tables to combine
851  joint = new Sequence< const DiscreteVariable* >;
852 
853  for (typename Set< const Sequence< const DiscreteVariable* >* >::
854  const_iterator_safe iter = tables_to_combine.beginSafe();
855  iter != tables_to_combine.endSafe();
856  ++iter) {
857  const Sequence< const DiscreteVariable* >& vars = **iter;
858 
860  iter2 = vars.beginSafe();
861  iter2 != vars.endSafe();
862  ++iter2) {
863  if (!joint->exists(*iter2)) { joint->insert(*iter2); }
864  }
865  }
866 
867  joint_to_delete = true;
868 
869  // update the number of operations performed
870  std::pair< long, long > comb_memory =
871  __combination->memoryUsage(tables_to_combine);
872 
873  if ((std::numeric_limits< long >::max() - current_memory
874  < comb_memory.first)
875  || (std::numeric_limits< long >::max() - current_memory
876  < comb_memory.second)) {
877  GUM_ERROR(OutOfBounds, "memory usage out of long int range");
878  }
879 
880  if (current_memory + comb_memory.first > max_memory) {
881  max_memory = current_memory + comb_memory.first;
882  }
883 
884  current_memory += comb_memory.second;
885  }
886 
887  // update the number of operations performed by marginalizing out del_var
888  Set< const DiscreteVariable* > del_one_var;
889 
890  del_one_var << del_var;
891 
892  std::pair< long, long > comb_memory =
893  __projection->memoryUsage(*joint, del_one_var);
894 
895  if ((std::numeric_limits< long >::max() - current_memory < comb_memory.first)
896  || (std::numeric_limits< long >::max() - current_memory
897  < comb_memory.second)) {
898  GUM_ERROR(OutOfBounds, "memory usage out of long int range");
899  }
900 
901  if (current_memory + comb_memory.first > max_memory) {
902  max_memory = current_memory + comb_memory.first;
903  }
904 
905  current_memory += comb_memory.second;
906 
907  // compute the table resulting from marginalizing out del_var from joint
908  Sequence< const DiscreteVariable* >* marginal;
909 
910  if (joint_to_delete) {
911  marginal = joint;
912  } else {
913  marginal = new Sequence< const DiscreteVariable* >(*joint);
914  }
915 
916  marginal->erase(del_var);
917 
918  // update tables_vars_per_var : remove the variables of the TABLEs we
919  // combined from this hashtable
920  // update accordingly tables_per_vars : remove these TABLEs
921  // update accordingly product_size : when a variable is no more used by
922  // any TABLE, divide product_size by its domain size
923 
924  for (typename Set< const Sequence< const DiscreteVariable* >* >::
925  const_iterator_safe iter = tables_to_combine.beginSafe();
926  iter != tables_to_combine.endSafe();
927  ++iter) {
928  const Sequence< const DiscreteVariable* >& table_vars = **iter;
929 
930  for (unsigned int i = 0; i < table_vars.size(); ++i) {
931  if (del_vars.contains(table_vars[i])) {
932  // ok, here we have a variable that needed to be removed => update
933  // product_size, tables_per_var and tables_vars_per_var
934  HashTable< const DiscreteVariable*, unsigned int >&
935  table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
936  float div_size = 1.0f;
937 
938  for (unsigned int j = 0; j < table_vars.size(); ++j) {
939  unsigned int k = --table_vars_of_var_i[table_vars[j]];
940 
941  if (k == 0) {
942  div_size *= table_vars[j]->domainSize();
943  table_vars_of_var_i.erase(table_vars[j]);
944  }
945  }
946 
947  tables_per_var[table_vars[i]].erase(*iter);
948 
949  if (div_size != 1) {
950  product_size.setPriority(
951  table_vars[i], product_size.priority(table_vars[i]) / div_size);
952  }
953  }
954  }
955 
956  if (tmp_marginals.contains(*iter)) {
957  Size del_size = 1;
958  const Sequence< const DiscreteVariable* >& del = **iter;
959 
961  iter_del = del.beginSafe();
962  iter_del != del.endSafe();
963  ++iter_del) {
964  del_size *= (*iter_del)->domainSize();
965  }
966 
967  current_memory -= long(del_size);
968 
969  delete *iter;
970  tmp_marginals.erase(*iter);
971  }
972  }
973 
974  tables_per_var.erase(del_var);
975 
976  // add the new projected marginal to the list of TABLES
977 
978  for (unsigned int i = 0; i < marginal->size(); ++i) {
979  const DiscreteVariable* var_i = marginal->atPos(i);
980 
981  if (del_vars.contains(var_i)) {
982  // add the new marginal table to the set of tables of var i
983  tables_per_var[var_i].insert(marginal);
984 
985  // add the variables of the table to tables_vars_per_var[vars[i]]
986  HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
987  tables_vars_per_var[var_i];
988  float mult_size = 1.0f;
989 
990  for (unsigned int j = 0; j < marginal->size(); ++j) {
991  try {
992  ++iter_vars[marginal->atPos(j)];
993  } catch (const NotFound&) {
994  iter_vars.insert(marginal->atPos(j), 1);
995  mult_size *= marginal->atPos(j)->domainSize();
996  }
997  }
998 
999  if (mult_size != 1) {
1000  product_size.setPriority(var_i,
1001  product_size.priority(var_i) * mult_size);
1002  }
1003  }
1004  }
1005 
1006  tmp_marginals.insert(marginal);
1007  }
1008 
1009  // here, tmp_marginals contains all the newly created tables
1010  for (typename Set< const Sequence< const DiscreteVariable* >* >::
1011  const_iterator_safe iter = tmp_marginals.beginSafe();
1012  iter != tmp_marginals.endSafe();
1013  ++iter) {
1014  delete *iter;
1015  }
1016 
1017  return std::pair< long, long >(max_memory, current_memory);
1018  }
1019 
1020  // returns the memory consumption used during the combinations and
1021  // projections
1022  template < typename GUM_SCALAR, template < typename > class TABLE >
1023  std::pair< long, long >
1025  const Set< const TABLE< GUM_SCALAR >* >& set,
1026  const Set< const DiscreteVariable* >& del_vars) const {
1027  // create the set of sets of discrete variables involved in the tables
1028  Set< const Sequence< const DiscreteVariable* >* > var_set(set.size());
1029 
1030  for (typename Set< const TABLE< GUM_SCALAR >* >::const_iterator_safe iter =
1031  set.beginSafe();
1032  iter != set.endSafe();
1033  ++iter) {
1034  var_set << &((*iter)->variablesSequence());
1035  }
1036 
1037  return memoryUsage(var_set, del_vars);
1038  }
1039 
1040 } /* namespace gum */
1041 
1042 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
virtual ~MultiDimCombineAndProjectDefault()
Destructor.
virtual void setCombineFunction(TABLE< GUM_SCALAR > *(*combine)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &))
changes the function used for combining two TABLES
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
MultiDimProjection< GUM_SCALAR, TABLE > * __projection
the class used for the projections
SetIteratorSafe< Key > const_iterator_safe
Types for STL compliance.
Definition: set.h:177
virtual float nbOperations(const Set< const TABLE< GUM_SCALAR > * > &set, const Set< const DiscreteVariable * > &del_vars) const
returns a rough estimate of the number of operations that will be performed to compute the combination
virtual TABLE< GUM_SCALAR > *(*)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &) combineFunction()
Returns the current combination function.
virtual void setProjectionClass(const MultiDimProjection< GUM_SCALAR, TABLE > &proj_class)
Changes the class that performs the projections.
virtual TABLE< GUM_SCALAR > *(*)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable *> &) projectFunction()
returns the current projection function
virtual void setCombinationClass(const MultiDimCombination< GUM_SCALAR, TABLE > &comb_class)
changes the class that performs the combinations
MultiDimCombination< GUM_SCALAR, TABLE > * __combination
the class used for the combinations
MultiDimCombineAndProject()
default constructor
MultiDimCombineAndProjectDefault(TABLE< GUM_SCALAR > *(*combine)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &), TABLE< GUM_SCALAR > *(*project)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
Default constructor.
virtual MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE > * newFactory() const
virtual constructor
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:45
virtual std::pair< long, long > memoryUsage(const Set< const TABLE< GUM_SCALAR > * > &set, const Set< const DiscreteVariable * > &del_vars) const
returns the memory consumption used during the combinations and projections
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52
virtual void setProjectFunction(TABLE< GUM_SCALAR > *(*proj)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
Changes the function used for projecting TABLES.
SequenceIteratorSafe< Key > const_iterator_safe
Types for STL compliance.
Definition: sequence.h:1035
virtual Set< const TABLE< GUM_SCALAR > *> combineAndProject(Set< const TABLE< GUM_SCALAR > * > set, Set< const DiscreteVariable * > del_vars)
creates and returns the result of the projection over the variables not in del_vars of the combination of the tables within set