30 #ifndef DOXYGEN_SHOULD_SKIP_THIS 41 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
44 TABLE< GUM_SCALAR >* (*combine)(
const TABLE< GUM_SCALAR >&,
45 const TABLE< GUM_SCALAR >&),
46 TABLE< GUM_SCALAR >* (*project)(
const TABLE< GUM_SCALAR >&,
47 const Set< const DiscreteVariable* >&)) :
48 MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
49 __combination(new MultiDimCombinationDefault< GUM_SCALAR, TABLE >(combine)),
50 __projection(new MultiDimProjection< GUM_SCALAR, TABLE >(project)) {
56 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
59 const MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >& from) :
68 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
78 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
79 MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >*
81 return new MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >(*this);
85 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
86 Set< const TABLE< GUM_SCALAR >* >
88 Set<
const TABLE< GUM_SCALAR >* > table_set,
89 Set< const DiscreteVariable* > del_vars) {
100 Set< const DiscreteVariable* > all_vars;
102 for (
const auto ptrTab : table_set) {
103 for (
const auto ptrVar : ptrTab->variablesSequence()) {
104 all_vars.insert(ptrVar);
108 nb_vars = all_vars.size();
112 HashTable< const DiscreteVariable*, Set< const TABLE< GUM_SCALAR >* > >
113 tables_per_var(nb_vars);
120 HashTable<
const DiscreteVariable*,
121 HashTable< const DiscreteVariable*, unsigned int > >
122 tables_vars_per_var(nb_vars);
126 Set< const TABLE< GUM_SCALAR >* > empty_set(table_set.size());
127 HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
129 for (
const auto ptrVar : del_vars) {
130 tables_per_var.insert(ptrVar, empty_set);
131 tables_vars_per_var.insert(ptrVar, empty_hash);
135 for (
const auto ptrTab : table_set) {
136 const Sequence< const DiscreteVariable* >& vars =
137 ptrTab->variablesSequence();
139 for (
const auto ptrVar : vars) {
140 if (del_vars.contains(ptrVar)) {
142 tables_per_var[ptrVar].insert(ptrTab);
145 HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
146 tables_vars_per_var[ptrVar];
148 for (
const auto xptrVar : vars) {
150 ++iter_vars[xptrVar];
151 }
catch (
const NotFound&) { iter_vars.insert(xptrVar, 1); }
159 PriorityQueue< const DiscreteVariable*, double > product_size;
162 for (
const auto& elt : tables_vars_per_var) {
164 const auto ptrVar = elt.first;
165 const auto& hashvars = elt.second;
167 if (hashvars.size()) {
168 for (
const auto& xelt : hashvars) {
169 size *= (
double) xelt.first->domainSize();
172 product_size.insert(ptrVar, size);
178 Set< const TABLE< GUM_SCALAR >* > tmp_marginals(table_set.size());
182 while (!product_size.empty()) {
184 const DiscreteVariable* del_var = product_size.pop();
185 del_vars.erase(del_var);
188 Set< const TABLE< GUM_SCALAR >* >& tables_to_combine =
189 tables_per_var[del_var];
192 if (tables_to_combine.size() == 0)
continue;
197 TABLE< GUM_SCALAR >* joint;
199 bool joint_to_delete =
false;
201 if (tables_to_combine.size() == 1) {
203 const_cast< TABLE< GUM_SCALAR >*
>(*(tables_to_combine.begin()));
204 joint_to_delete =
false;
208 joint_to_delete =
true;
212 Set< const DiscreteVariable* > del_one_var;
213 del_one_var << del_var;
215 TABLE< GUM_SCALAR >* marginal =
__projection->project(*joint, del_one_var);
218 if (joint_to_delete)
delete joint;
226 for (
const auto ptrTab : tables_to_combine) {
227 const Sequence< const DiscreteVariable* >& table_vars =
228 ptrTab->variablesSequence();
229 const Size tab_vars_size = table_vars.size();
231 for (
Size i = 0; i < tab_vars_size; ++i) {
232 if (del_vars.contains(table_vars[i])) {
236 HashTable< const DiscreteVariable*, unsigned int >&
237 table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
238 double div_size = 1.0;
240 for (
Size j = 0; j < tab_vars_size; ++j) {
241 unsigned int k = --table_vars_of_var_i[table_vars[j]];
244 div_size *= table_vars[j]->domainSize();
245 table_vars_of_var_i.erase(table_vars[j]);
249 tables_per_var[table_vars[i]].erase(ptrTab);
251 if (div_size != 1.0) {
252 product_size.setPriority(
253 table_vars[i], product_size.priority(table_vars[i]) / div_size);
258 if (tmp_marginals.contains(ptrTab)) {
260 tmp_marginals.erase(ptrTab);
263 table_set.erase(ptrTab);
266 tables_per_var.erase(del_var);
269 const Sequence< const DiscreteVariable* >& marginal_vars =
270 marginal->variablesSequence();
272 for (
const auto mvar : marginal_vars) {
273 if (del_vars.contains(mvar)) {
275 tables_per_var[mvar].insert(marginal);
278 HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
279 tables_vars_per_var[mvar];
280 double mult_size = 1.0;
282 for (
const auto var : marginal_vars) {
286 catch (
const NotFound&) {
287 iter_vars.insert(var, 1);
288 mult_size *= (
double) var->domainSize();
292 if (mult_size != 1.0) {
293 product_size.setPriority(mvar,
294 product_size.priority(mvar) * mult_size);
299 table_set.insert(marginal);
300 tmp_marginals.insert(marginal);
312 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
315 TABLE< GUM_SCALAR >* (*combine)(
const TABLE< GUM_SCALAR >&,
316 const TABLE< GUM_SCALAR >&)) {
321 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
322 INLINE TABLE< GUM_SCALAR >* (
324 const TABLE< GUM_SCALAR >&,
const TABLE< GUM_SCALAR >&) {
329 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
332 const MultiDimCombination< GUM_SCALAR, TABLE >& comb_class) {
338 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
341 TABLE< GUM_SCALAR >* (*proj)(
const TABLE< GUM_SCALAR >&,
342 const Set< const DiscreteVariable* >&)) {
347 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
348 INLINE TABLE< GUM_SCALAR >* (
350 const TABLE< GUM_SCALAR >&,
const Set< const DiscreteVariable* >&) {
355 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
358 const MultiDimProjection< GUM_SCALAR, TABLE >& proj_class) {
365 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
367 const Set<
const Sequence< const DiscreteVariable* >* >& table_set,
368 Set< const DiscreteVariable* > del_vars)
const {
383 Set< const DiscreteVariable* > all_vars;
385 for (
const auto ptrSeq : table_set) {
386 for (
const auto ptrVar : *ptrSeq) {
387 all_vars.insert(ptrVar);
391 nb_vars = all_vars.size();
396 HashTable<
const DiscreteVariable*,
397 Set< const Sequence< const DiscreteVariable* >* > >
398 tables_per_var(nb_vars);
405 HashTable<
const DiscreteVariable*,
406 HashTable< const DiscreteVariable*, unsigned int > >
407 tables_vars_per_var(nb_vars);
411 Set< const Sequence< const DiscreteVariable* >* > empty_set(
413 HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
415 for (
const auto ptrVar : del_vars) {
416 tables_per_var.insert(ptrVar, empty_set);
417 tables_vars_per_var.insert(ptrVar, empty_hash);
421 for (
const auto ptrSeq : table_set) {
422 const Sequence< const DiscreteVariable* >& vars = *ptrSeq;
424 for (
const auto ptrVar : vars) {
425 if (del_vars.contains(ptrVar)) {
427 tables_per_var[ptrVar].insert(ptrSeq);
430 HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
431 tables_vars_per_var[ptrVar];
433 for (
const auto xptrVar : vars) {
435 ++iter_vars[xptrVar];
436 }
catch (
const NotFound&) { iter_vars.insert(xptrVar, 1); }
444 PriorityQueue< const DiscreteVariable*, double > product_size;
447 for (
const auto& elt : tables_vars_per_var) {
449 const auto ptrVar = elt.first;
450 const auto hashvars = elt.second;
452 if (hashvars.size()) {
453 for (
const auto& xelt : hashvars) {
454 size *= (
double) xelt.first->domainSize();
457 product_size.insert(ptrVar, size);
462 float nb_operations = 0;
466 Set< const Sequence< const DiscreteVariable* >* > tmp_marginals(
471 while (!product_size.empty()) {
473 const DiscreteVariable* del_var = product_size.pop();
474 del_vars.erase(del_var);
477 Set< const Sequence< const DiscreteVariable* >* >& tables_to_combine =
478 tables_per_var[del_var];
481 if (tables_to_combine.size() == 0)
continue;
486 Sequence< const DiscreteVariable* >* joint;
488 bool joint_to_delete =
false;
490 if (tables_to_combine.size() == 1) {
491 joint =
const_cast< Sequence< const DiscreteVariable* >*
>(
492 *(tables_to_combine.beginSafe()));
493 joint_to_delete =
false;
496 joint =
new Sequence< const DiscreteVariable* >;
498 for (
const auto ptrSeq : tables_to_combine) {
499 for (
const auto ptrVar : *ptrSeq) {
500 if (!joint->exists(ptrVar)) { joint->insert(ptrVar); }
504 joint_to_delete =
true;
507 nb_operations +=
__combination->nbOperations(tables_to_combine);
511 Set< const DiscreteVariable* > del_one_var;
512 del_one_var << del_var;
514 nb_operations +=
__projection->nbOperations(*joint, del_one_var);
517 Sequence< const DiscreteVariable* >* marginal;
519 if (joint_to_delete) {
522 marginal =
new Sequence< const DiscreteVariable* >(*joint);
525 marginal->erase(del_var);
533 for (
const auto ptrSeq : tables_to_combine) {
534 const Sequence< const DiscreteVariable* >& table_vars = *ptrSeq;
535 const Size tab_vars_size = table_vars.size();
537 for (
Size i = 0; i < tab_vars_size; ++i) {
538 if (del_vars.contains(table_vars[i])) {
541 HashTable< const DiscreteVariable*, unsigned int >&
542 table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
543 double div_size = 1.0;
545 for (
Size j = 0; j < tab_vars_size; ++j) {
546 unsigned int k = --table_vars_of_var_i[table_vars[j]];
549 div_size *= table_vars[j]->domainSize();
550 table_vars_of_var_i.erase(table_vars[j]);
554 tables_per_var[table_vars[i]].erase(ptrSeq);
556 if (div_size != 1.0) {
557 product_size.setPriority(
558 table_vars[i], product_size.priority(table_vars[i]) / div_size);
563 if (tmp_marginals.contains(ptrSeq)) {
565 tmp_marginals.erase(ptrSeq);
569 tables_per_var.erase(del_var);
572 for (
const auto mvar : *marginal) {
573 if (del_vars.contains(mvar)) {
575 tables_per_var[mvar].insert(marginal);
578 HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
579 tables_vars_per_var[mvar];
580 double mult_size = 1.0;
582 for (
const auto var : *marginal) {
585 }
catch (
const NotFound&) {
586 iter_vars.insert(var, 1);
587 mult_size *= (
double) var->domainSize();
591 if (mult_size != 1.0) {
592 product_size.setPriority(mvar,
593 product_size.priority(mvar) * mult_size);
598 tmp_marginals.insert(marginal);
602 for (
auto iter = tmp_marginals.beginSafe();
603 iter != tmp_marginals.endSafe();
608 return nb_operations;
613 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
615 const Set<
const TABLE< GUM_SCALAR >* >&
set,
616 const Set< const DiscreteVariable* >& del_vars)
const {
618 Set< const Sequence< const DiscreteVariable* >* > var_set(
set.size());
620 for (
const auto ptrTab :
set) {
621 var_set << &(ptrTab->variablesSequence());
629 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
630 std::pair< long, long >
632 const Set<
const Sequence< const DiscreteVariable* >* >& table_set,
633 Set< const DiscreteVariable* > del_vars)
const {
648 Set< const DiscreteVariable* > all_vars;
650 for (
const auto ptrSeq : table_set) {
651 for (
const auto ptrVar : *ptrSeq) {
652 all_vars.insert(ptrVar);
656 nb_vars = all_vars.size();
660 HashTable<
const DiscreteVariable*,
661 Set< const Sequence< const DiscreteVariable* >* > >
662 tables_per_var(nb_vars);
668 HashTable<
const DiscreteVariable*,
669 HashTable< const DiscreteVariable*, unsigned int > >
670 tables_vars_per_var(nb_vars);
674 Set< const Sequence< const DiscreteVariable* >* > empty_set(
676 HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
678 for (
const auto ptrVar : del_vars) {
679 tables_per_var.insert(ptrVar, empty_set);
680 tables_vars_per_var.insert(ptrVar, empty_hash);
684 for (
const auto ptrSeq : table_set) {
685 const Sequence< const DiscreteVariable* >& vars = *ptrSeq;
687 for (
const auto ptrVar : vars) {
688 if (del_vars.contains(ptrVar)) {
690 tables_per_var[ptrVar].insert(ptrSeq);
693 HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
694 tables_vars_per_var[ptrVar];
696 for (
const auto xptrVar : vars) {
698 ++iter_vars[xptrVar];
699 }
catch (
const NotFound&) { iter_vars.insert(xptrVar, 1); }
707 PriorityQueue< const DiscreteVariable*, double > product_size;
710 for (
const auto& elt : tables_vars_per_var) {
712 const auto ptrVar = elt.first;
713 const auto hashvars = elt.second;
715 if (hashvars.size()) {
716 for (
const auto& xelt : hashvars) {
717 size *= (
double) xelt.first->domainSize();
720 product_size.insert(ptrVar, size);
726 long current_memory = 0;
730 Set< const Sequence< const DiscreteVariable* >* > tmp_marginals(
736 while (!product_size.empty()) {
738 const DiscreteVariable* del_var = product_size.pop();
739 del_vars.erase(del_var);
742 Set< const Sequence< const DiscreteVariable* >* >& tables_to_combine =
743 tables_per_var[del_var];
746 if (tables_to_combine.size() == 0)
continue;
751 Sequence< const DiscreteVariable* >* joint;
753 bool joint_to_delete =
false;
755 if (tables_to_combine.size() == 1) {
756 joint =
const_cast< Sequence< const DiscreteVariable* >*
>(
757 *(tables_to_combine.beginSafe()));
758 joint_to_delete =
false;
761 joint =
new Sequence< const DiscreteVariable* >;
763 for (
const auto ptrSeq : tables_to_combine) {
764 for (
const auto ptrVar : *ptrSeq) {
765 if (!joint->exists(ptrVar)) { joint->insert(ptrVar); }
769 joint_to_delete =
true;
772 std::pair< long, long > comb_memory =
775 if ((std::numeric_limits< long >::max() - current_memory
777 || (std::numeric_limits< long >::max() - current_memory
778 < comb_memory.second)) {
779 GUM_ERROR(OutOfBounds,
"memory usage out of long int range");
782 if (current_memory + comb_memory.first > max_memory) {
783 max_memory = current_memory + comb_memory.first;
786 current_memory += comb_memory.second;
790 Set< const DiscreteVariable* > del_one_var;
791 del_one_var << del_var;
793 std::pair< long, long > comb_memory =
796 if ((std::numeric_limits< long >::max() - current_memory < comb_memory.first)
797 || (std::numeric_limits< long >::max() - current_memory
798 < comb_memory.second)) {
799 GUM_ERROR(OutOfBounds,
"memory usage out of long int range");
802 if (current_memory + comb_memory.first > max_memory) {
803 max_memory = current_memory + comb_memory.first;
806 current_memory += comb_memory.second;
809 Sequence< const DiscreteVariable* >* marginal;
811 if (joint_to_delete) {
814 marginal =
new Sequence< const DiscreteVariable* >(*joint);
817 marginal->erase(del_var);
825 for (
const auto ptrSeq : tables_to_combine) {
826 const Sequence< const DiscreteVariable* >& table_vars = *ptrSeq;
827 const Size tab_vars_size = table_vars.size();
829 for (
Size i = 0; i < tab_vars_size; ++i) {
830 if (del_vars.contains(table_vars[i])) {
833 HashTable< const DiscreteVariable*, unsigned int >&
834 table_vars_of_var_i = tables_vars_per_var[table_vars[i]];
835 double div_size = 1.0;
837 for (
Size j = 0; j < tab_vars_size; ++j) {
838 Size k = --table_vars_of_var_i[table_vars[j]];
841 div_size *= table_vars[j]->domainSize();
842 table_vars_of_var_i.erase(table_vars[j]);
846 tables_per_var[table_vars[i]].erase(ptrSeq);
849 product_size.setPriority(
850 table_vars[i], product_size.priority(table_vars[i]) / div_size);
855 if (tmp_marginals.contains(ptrSeq)) {
858 for (
const auto ptrVar : *ptrSeq) {
859 del_size *= ptrVar->domainSize();
862 current_memory -= long(del_size);
865 tmp_marginals.erase(ptrSeq);
869 tables_per_var.erase(del_var);
872 for (
const auto mvar : *marginal) {
873 if (del_vars.contains(mvar)) {
875 tables_per_var[mvar].insert(marginal);
878 HashTable< const DiscreteVariable*, unsigned int >& iter_vars =
879 tables_vars_per_var[mvar];
880 double mult_size = 1.0;
882 for (
const auto var : *marginal) {
885 }
catch (
const NotFound&) {
886 iter_vars.insert(var, 1);
887 mult_size *= (
double) var->domainSize();
891 if (mult_size != 1) {
892 product_size.setPriority(mvar,
893 product_size.priority(mvar) * mult_size);
898 tmp_marginals.insert(marginal);
902 for (
auto iter = tmp_marginals.beginSafe();
903 iter != tmp_marginals.endSafe();
908 return std::pair< long, long >(max_memory, current_memory);
913 template <
typename GUM_SCALAR,
template <
typename >
class TABLE >
914 std::pair< long, long >
916 const Set<
const TABLE< GUM_SCALAR >* >&
set,
917 const Set< const DiscreteVariable* >& del_vars)
const {
919 Set< const Sequence< const DiscreteVariable* >* > var_set(
set.size());
921 for (
const auto ptrTab :
set) {
922 var_set << &(ptrTab->variablesSequence());
virtual ~MultiDimCombineAndProjectDefault()
Destructor.
virtual void setCombineFunction(TABLE< GUM_SCALAR > *(*combine)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &))
changes the function used for combining two TABLES
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
MultiDimProjection< GUM_SCALAR, TABLE > * __projection
the class used for the projections
virtual float nbOperations(const Set< const TABLE< GUM_SCALAR > * > &set, const Set< const DiscreteVariable * > &del_vars) const
returns a rough estimate of the number of operations that will be performed to compute the combinatio...
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
virtual TABLE< GUM_SCALAR > *(*)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &) combineFunction()
Returns the current combination function.
virtual void setProjectionClass(const MultiDimProjection< GUM_SCALAR, TABLE > &proj_class)
Changes the class that performs the projections.
virtual TABLE< GUM_SCALAR > *(*)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable *> &) projectFunction()
returns the current projection function
virtual void setCombinationClass(const MultiDimCombination< GUM_SCALAR, TABLE > &comb_class)
changes the class that performs the combinations
MultiDimCombination< GUM_SCALAR, TABLE > * __combination
the class used for the combinations
MultiDimCombineAndProject()
default constructor
MultiDimCombineAndProjectDefault(TABLE< GUM_SCALAR > *(*combine)(const TABLE< GUM_SCALAR > &, const TABLE< GUM_SCALAR > &), TABLE< GUM_SCALAR > *(*project)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
Default constructor.
virtual MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE > * newFactory() const
virtual constructor
std::size_t Size
In aGrUM, hashed values are unsigned long int.
virtual std::pair< long, long > memoryUsage(const Set< const TABLE< GUM_SCALAR > * > &set, const Set< const DiscreteVariable * > &del_vars) const
returns the memory consumption used during the combinations and projections
#define GUM_ERROR(type, msg)
virtual void setProjectFunction(TABLE< GUM_SCALAR > *(*proj)(const TABLE< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
Changes the function used for projecting TABLES.
virtual Set< const TABLE< GUM_SCALAR > *> combineAndProject(Set< const TABLE< GUM_SCALAR > * > set, Set< const DiscreteVariable * > del_vars)
creates and returns the result of the projection over the variables not in del_vars of the combinatio...