aGrUM
0.21.0
a C++ library for (probabilistic) graphical models
multiDimCombineAndProjectDefault_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief An efficient class for combining and projecting MultiDim tables
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */
28
29
#
ifndef
DOXYGEN_SHOULD_SKIP_THIS
30
31
#
include
<
limits
>
32
33
#
include
<
agrum
/
agrum
.
h
>
34
35
#
include
<
agrum
/
tools
/
multidim
/
utils
/
operators
/
multiDimCombineAndProjectDefault
.
h
>
36
37
namespace
gum
{
38
39
// default constructor
40
template
<
typename
GUM_SCALAR,
template
<
typename
>
class
TABLE >
41
MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >::MultiDimCombineAndProjectDefault(
42
TABLE< GUM_SCALAR >* (*combine)(
const
TABLE< GUM_SCALAR >&,
const
TABLE< GUM_SCALAR >&),
43
TABLE< GUM_SCALAR >* (*project)(
const
TABLE< GUM_SCALAR >&,
44
const
Set<
const
DiscreteVariable* >&)) :
45
MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
46
_combination_(
new
MultiDimCombinationDefault< GUM_SCALAR, TABLE >(combine)),
47
_projection_(
new
MultiDimProjection< GUM_SCALAR, TABLE >(project)) {
48
// for debugging purposes
49
GUM_CONSTRUCTOR(MultiDimCombineAndProjectDefault);
50
}
51
52
// copy constructor
53
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
54
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
MultiDimCombineAndProjectDefault
(
55
const
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>&
from
) :
56
MultiDimCombineAndProject
<
GUM_SCALAR
,
TABLE
>(),
57
_combination_
(
from
.
_combination_
->
newFactory
()),
58
_projection_
(
from
.
_projection_
->
newFactory
()) {
59
// for debugging purposes
60
GUM_CONS_CPY
(
MultiDimCombineAndProjectDefault
);
61
}
62
63
// destructor
64
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
65
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::~
MultiDimCombineAndProjectDefault
() {
66
// for debugging purposes
67
GUM_DESTRUCTOR
(
MultiDimCombineAndProjectDefault
);
68
delete
_combination_
;
69
delete
_projection_
;
70
}
71
72
// virtual constructor
73
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
74
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>*
75
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
newFactory
()
const
{
76
return
new
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>(*
this
);
77
}
78
79
// combine and project
80
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
81
Set
<
const
TABLE
<
GUM_SCALAR
>* >
82
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
combineAndProject
(
83
Set
<
const
TABLE
<
GUM_SCALAR
>* >
table_set
,
84
Set
<
const
DiscreteVariable
* >
del_vars
) {
85
// when we remove a variable, we need to combine all the tables containing
86
// this variable in order to produce a new unique table containing this
87
// variable. Removing the variable is then performed by marginalizing it
88
// out of the table. In the combineAndProject algorithm, we wish to remove
89
// first variables that produce small tables. This should speed up the
90
// marginalizing process
91
Size
nb_vars
;
92
{
93
// determine the set of all the variables involved in the tables.
94
// this should help sizing correctly the hashtables
95
Set
<
const
DiscreteVariable
* >
all_vars
;
96
97
for
(
const
auto
ptrTab
:
table_set
) {
98
for
(
const
auto
ptrVar
:
ptrTab
->
variablesSequence
()) {
99
all_vars
.
insert
(
ptrVar
);
100
}
101
}
102
103
nb_vars
=
all_vars
.
size
();
104
}
105
106
// the tables containing a given variable to be deleted
107
HashTable
<
const
DiscreteVariable
*,
Set
<
const
TABLE
<
GUM_SCALAR
>* > >
tables_per_var
(
nb_vars
);
108
109
// for a given variable X to be deleted, the list of all the variables of
110
// the tables containing X (actually, we also count the number of tables
111
// containing the variable. This is more efficient for computing and
112
// updating the product_size priority queue (see below) when some tables
113
// are removed)
114
HashTable
<
const
DiscreteVariable
*,
HashTable
<
const
DiscreteVariable
*,
unsigned
int
> >
115
tables_vars_per_var
(
nb_vars
);
116
117
// initialize tables_vars_per_var and tables_per_var
118
{
119
Set
<
const
TABLE
<
GUM_SCALAR
>* >
empty_set
(
table_set
.
size
());
120
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>
empty_hash
(
nb_vars
);
121
122
for
(
const
auto
ptrVar
:
del_vars
) {
123
tables_per_var
.
insert
(
ptrVar
,
empty_set
);
124
tables_vars_per_var
.
insert
(
ptrVar
,
empty_hash
);
125
}
126
127
// update properly tables_per_var and tables_vars_per_var
128
for
(
const
auto
ptrTab
:
table_set
) {
129
const
Sequence
<
const
DiscreteVariable
* >&
vars
=
ptrTab
->
variablesSequence
();
130
131
for
(
const
auto
ptrVar
:
vars
) {
132
if
(
del_vars
.
contains
(
ptrVar
)) {
133
// add the table to the set of tables related to vars[i]
134
tables_per_var
[
ptrVar
].
insert
(
ptrTab
);
135
136
// add the variables of the table to tables_vars_per_var[vars[i]]
137
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
138
=
tables_vars_per_var
[
ptrVar
];
139
140
for
(
const
auto
xptrVar
:
vars
) {
141
try
{
142
++
iter_vars
[
xptrVar
];
143
}
catch
(
const
NotFound
&) {
iter_vars
.
insert
(
xptrVar
, 1); }
144
}
145
}
146
}
147
}
148
}
149
150
// the sizes of the tables produced when removing a given discrete variable
151
PriorityQueue
<
const
DiscreteVariable
*,
double
>
product_size
;
152
153
// initialize properly product_size
154
for
(
const
auto
&
elt
:
tables_vars_per_var
) {
155
double
size
= 1.0;
156
const
auto
ptrVar
=
elt
.
first
;
157
const
auto
&
hashvars
=
elt
.
second
;
// HashTable<DiscreteVariable*, int>
158
159
if
(
hashvars
.
size
()) {
160
for
(
const
auto
&
xelt
:
hashvars
) {
161
size
*= (
double
)
xelt
.
first
->
domainSize
();
162
}
163
164
product_size
.
insert
(
ptrVar
,
size
);
165
}
166
}
167
168
// create a set of the temporary tables created during the
169
// marginalization process (useful for deallocating temporary tables)
170
Set
<
const
TABLE
<
GUM_SCALAR
>* >
tmp_marginals
(
table_set
.
size
());
171
172
// now, remove all the variables in del_vars, starting from those that
173
// produce the smallest tables
174
while
(!
product_size
.
empty
()) {
175
// get the best variable to remove
176
const
DiscreteVariable
*
del_var
=
product_size
.
pop
();
177
del_vars
.
erase
(
del_var
);
178
179
// get the set of tables to combine
180
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
tables_to_combine
=
tables_per_var
[
del_var
];
181
182
// if there is no tables to combine, do nothing
183
if
(
tables_to_combine
.
size
() == 0)
continue
;
184
185
// compute the combination of all the tables: if there is only one table,
186
// there is nothing to do, else we shall use the MultiDimCombination
187
// to perform the combination
188
TABLE
<
GUM_SCALAR
>*
joint
;
189
190
bool
joint_to_delete
=
false
;
191
192
if
(
tables_to_combine
.
size
() == 1) {
193
joint
=
const_cast
<
TABLE
<
GUM_SCALAR
>* >(*(
tables_to_combine
.
begin
()));
194
joint_to_delete
=
false
;
195
}
else
{
196
joint
=
_combination_
->
combine
(
tables_to_combine
);
197
joint_to_delete
=
true
;
198
}
199
200
// compute the table resulting from marginalizing out del_var from joint
201
Set
<
const
DiscreteVariable
* >
del_one_var
;
202
del_one_var
<<
del_var
;
203
204
TABLE
<
GUM_SCALAR
>*
marginal
=
_projection_
->
project
(*
joint
,
del_one_var
);
205
206
// remove the temporary joint if needed
207
if
(
joint_to_delete
)
delete
joint
;
208
209
// update tables_vars_per_var : remove the variables of the TABLEs we
210
// combined from this hashtable
211
// update accordingly tables_per_vars : remove these TABLEs
212
// update accordingly product_size : when a variable is no more used by
213
// any TABLE, divide product_size by its domain size
214
215
for
(
const
auto
ptrTab
:
tables_to_combine
) {
216
const
Sequence
<
const
DiscreteVariable
* >&
table_vars
=
ptrTab
->
variablesSequence
();
217
const
Size
tab_vars_size
=
table_vars
.
size
();
218
219
for
(
Size
i
= 0;
i
<
tab_vars_size
; ++
i
) {
220
if
(
del_vars
.
contains
(
table_vars
[
i
])) {
221
// ok, here we have a variable that needed to be removed => update
222
// product_size, tables_per_var and tables_vars_per_var: here,
223
// the update corresponds to removing table PtrTab
224
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
table_vars_of_var_i
225
=
tables_vars_per_var
[
table_vars
[
i
]];
226
double
div_size
= 1.0;
227
228
for
(
Size
j
= 0;
j
<
tab_vars_size
; ++
j
) {
229
unsigned
int
k
= --
table_vars_of_var_i
[
table_vars
[
j
]];
230
231
if
(
k
== 0) {
232
div_size
*=
table_vars
[
j
]->
domainSize
();
233
table_vars_of_var_i
.
erase
(
table_vars
[
j
]);
234
}
235
}
236
237
tables_per_var
[
table_vars
[
i
]].
erase
(
ptrTab
);
238
239
if
(
div_size
!= 1.0) {
240
product_size
.
setPriority
(
table_vars
[
i
],
241
product_size
.
priority
(
table_vars
[
i
]) /
div_size
);
242
}
243
}
244
}
245
246
if
(
tmp_marginals
.
contains
(
ptrTab
)) {
247
delete
ptrTab
;
248
tmp_marginals
.
erase
(
ptrTab
);
249
}
250
251
table_set
.
erase
(
ptrTab
);
252
}
253
254
tables_per_var
.
erase
(
del_var
);
255
256
// add the new projected marginal to the list of TABLES
257
const
Sequence
<
const
DiscreteVariable
* >&
marginal_vars
=
marginal
->
variablesSequence
();
258
259
for
(
const
auto
mvar
:
marginal_vars
) {
260
if
(
del_vars
.
contains
(
mvar
)) {
261
// add the new marginal table to the set of tables of mvar
262
tables_per_var
[
mvar
].
insert
(
marginal
);
263
264
// add the variables of the table to tables_vars_per_var[mvar]
265
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
=
tables_vars_per_var
[
mvar
];
266
double
mult_size
= 1.0;
267
268
for
(
const
auto
var
:
marginal_vars
) {
269
try
{
270
++
iter_vars
[
var
];
271
}
catch
(
const
NotFound
&) {
272
iter_vars
.
insert
(
var
, 1);
273
mult_size
*= (
double
)
var
->
domainSize
();
274
}
275
}
276
277
if
(
mult_size
!= 1.0) {
278
product_size
.
setPriority
(
mvar
,
product_size
.
priority
(
mvar
) *
mult_size
);
279
}
280
}
281
}
282
283
table_set
.
insert
(
marginal
);
284
tmp_marginals
.
insert
(
marginal
);
285
}
286
287
// here, tmp_marginals contains all the newly created tables and
288
// table_set contains the list of the tables resulting from the
289
// marginalizing out of del_vars of the combination of the tables
290
// of table_set. Note in particular that it will contain all the
291
// potentials with no dimension (constants)
292
return
table_set
;
293
}
294
295
// changes the function used for combining two TABLES
296
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
297
INLINE
void
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
setCombineFunction
(
298
TABLE
<
GUM_SCALAR
>* (*
combine
)(
const
TABLE
<
GUM_SCALAR
>&,
const
TABLE
<
GUM_SCALAR
>&)) {
299
_combination_
->
setCombineFunction
(
combine
);
300
}
301
302
// returns the current combination function
303
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
304
INLINE
TABLE
<
GUM_SCALAR
>* (
305
*
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
306
TABLE
>::
combineFunction
())(
const
TABLE
<
GUM_SCALAR
>&,
307
const
TABLE
<
GUM_SCALAR
>&) {
308
return
_combination_
->
combineFunction
();
309
}
310
311
// changes the class that performs the combinations
312
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
313
INLINE
void
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
setCombinationClass
(
314
const
MultiDimCombination
<
GUM_SCALAR
,
TABLE
>&
comb_class
) {
315
delete
_combination_
;
316
_combination_
=
comb_class
.
newFactory
();
317
}
318
319
// changes the function used for projecting TABLES
320
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
321
INLINE
void
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
setProjectFunction
(
322
TABLE
<
GUM_SCALAR
>* (*
proj
)(
const
TABLE
<
GUM_SCALAR
>&,
323
const
Set
<
const
DiscreteVariable
* >&)) {
324
_projection_
->
setProjectFunction
(
proj
);
325
}
326
327
// returns the current projection function
328
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
329
INLINE
TABLE
<
GUM_SCALAR
>* (*
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
330
projectFunction
())(
const
TABLE
<
GUM_SCALAR
>&,
331
const
Set
<
const
DiscreteVariable
* >&) {
332
return
_projection_
->
projectFunction
();
333
}
334
335
// changes the class that performs the projections
336
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
337
INLINE
void
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
setProjectionClass
(
338
const
MultiDimProjection
<
GUM_SCALAR
,
TABLE
>&
proj_class
) {
339
delete
_projection_
;
340
_projection_
=
proj_class
.
newFactory
();
341
}
342
343
/** @brief returns a rough estimate of the number of operations that will be
344
* performed to compute the combination */
345
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
346
float
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
nbOperations
(
347
const
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >&
table_set
,
348
Set
<
const
DiscreteVariable
* >
del_vars
)
const
{
349
// when we remove a variable, we need to combine all the tables containing
350
// this variable in order to produce a new unique table containing this
351
// variable. Here, we do not have the tables but only their variables
352
// (dimensions), but the principle is identical. Removing a variable is then
353
// performed by marginalizing it out of the table or, equivalently, to
354
// remove it from the table's list of variables. In the
355
// combineAndProjectDefault algorithm, we wish to remove first variables
356
// that would produce small tables. This should speed up the whole
357
// marginalizing process.
358
359
Size
nb_vars
;
360
{
361
// determine the set of all the variables involved in the tables.
362
// this should help sizing correctly the hashtables
363
Set
<
const
DiscreteVariable
* >
all_vars
;
364
365
for
(
const
auto
ptrSeq
:
table_set
) {
366
for
(
const
auto
ptrVar
: *
ptrSeq
) {
367
all_vars
.
insert
(
ptrVar
);
368
}
369
}
370
371
nb_vars
=
all_vars
.
size
();
372
}
373
374
// the tables (actually their variables) containing a given variable
375
// to be deleted
376
HashTable
<
const
DiscreteVariable
*,
Set
<
const
Sequence
<
const
DiscreteVariable
* >* > >
377
tables_per_var
(
nb_vars
);
378
379
// for a given variable X to be deleted, the list of all the variables of
380
// the tables containing X (actually, we count the number of tables
381
// containing the variable. This is more efficient for computing and
382
// updating the product_size priority queue (see below) when some tables
383
// are removed)
384
HashTable
<
const
DiscreteVariable
*,
HashTable
<
const
DiscreteVariable
*,
unsigned
int
> >
385
tables_vars_per_var
(
nb_vars
);
386
387
// initialize tables_vars_per_var and tables_per_var
388
{
389
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
empty_set
(
table_set
.
size
());
390
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>
empty_hash
(
nb_vars
);
391
392
for
(
const
auto
ptrVar
:
del_vars
) {
393
tables_per_var
.
insert
(
ptrVar
,
empty_set
);
394
tables_vars_per_var
.
insert
(
ptrVar
,
empty_hash
);
395
}
396
397
// update properly tables_per_var and tables_vars_per_var
398
for
(
const
auto
ptrSeq
:
table_set
) {
399
const
Sequence
<
const
DiscreteVariable
* >&
vars
= *
ptrSeq
;
400
401
for
(
const
auto
ptrVar
:
vars
) {
402
if
(
del_vars
.
contains
(
ptrVar
)) {
403
// add the table's variables to the set of those related to ptrVar
404
tables_per_var
[
ptrVar
].
insert
(
ptrSeq
);
405
406
// add the variables of the table to tables_vars_per_var[ptrVar]
407
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
408
=
tables_vars_per_var
[
ptrVar
];
409
410
for
(
const
auto
xptrVar
:
vars
) {
411
try
{
412
++
iter_vars
[
xptrVar
];
413
}
catch
(
const
NotFound
&) {
iter_vars
.
insert
(
xptrVar
, 1); }
414
}
415
}
416
}
417
}
418
}
419
420
// the sizes of the tables produced when removing a given discrete variable
421
PriorityQueue
<
const
DiscreteVariable
*,
double
>
product_size
;
422
423
// initialize properly product_size
424
for
(
const
auto
&
elt
:
tables_vars_per_var
) {
425
double
size
= 1.0;
426
const
auto
ptrVar
=
elt
.
first
;
427
const
auto
hashvars
=
elt
.
second
;
// HashTable<DiscreteVariable*, int>
428
429
if
(
hashvars
.
size
()) {
430
for
(
const
auto
&
xelt
:
hashvars
) {
431
size
*= (
double
)
xelt
.
first
->
domainSize
();
432
}
433
434
product_size
.
insert
(
ptrVar
,
size
);
435
}
436
}
437
438
// the resulting number of operations
439
float
nb_operations
= 0;
440
441
// create a set of the temporary table's variables created during the
442
// marginalization process (useful for deallocating temporary tables)
443
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
tmp_marginals
(
table_set
.
size
());
444
445
// now, remove all the variables in del_vars, starting from those that
446
// produce the smallest tables
447
while
(!
product_size
.
empty
()) {
448
// get the best variable to remove
449
const
DiscreteVariable
*
del_var
=
product_size
.
pop
();
450
del_vars
.
erase
(
del_var
);
451
452
// get the set of tables to combine
453
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >&
tables_to_combine
454
=
tables_per_var
[
del_var
];
455
456
// if there is no tables to combine, do nothing
457
if
(
tables_to_combine
.
size
() == 0)
continue
;
458
459
// compute the combination of all the tables: if there is only one table,
460
// there is nothing to do, else we shall use the MultiDimCombination
461
// to perform the combination
462
Sequence
<
const
DiscreteVariable
* >*
joint
;
463
464
bool
joint_to_delete
=
false
;
465
466
if
(
tables_to_combine
.
size
() == 1) {
467
joint
468
=
const_cast
<
Sequence
<
const
DiscreteVariable
* >* >(*(
tables_to_combine
.
beginSafe
()));
469
joint_to_delete
=
false
;
470
}
else
{
471
// here, compute the union of all the variables of the tables to combine
472
joint
=
new
Sequence
<
const
DiscreteVariable
* >;
473
474
for
(
const
auto
ptrSeq
:
tables_to_combine
) {
475
for
(
const
auto
ptrVar
: *
ptrSeq
) {
476
if
(!
joint
->
exists
(
ptrVar
)) {
joint
->
insert
(
ptrVar
); }
477
}
478
}
479
480
joint_to_delete
=
true
;
481
482
// update the number of operations performed
483
nb_operations
+=
_combination_
->
nbOperations
(
tables_to_combine
);
484
}
485
486
// update the number of operations performed by marginalizing out del_var
487
Set
<
const
DiscreteVariable
* >
del_one_var
;
488
del_one_var
<<
del_var
;
489
490
nb_operations
+=
_projection_
->
nbOperations
(*
joint
,
del_one_var
);
491
492
// compute the table resulting from marginalizing out del_var from joint
493
Sequence
<
const
DiscreteVariable
* >*
marginal
;
494
495
if
(
joint_to_delete
) {
496
marginal
=
joint
;
497
}
else
{
498
marginal
=
new
Sequence
<
const
DiscreteVariable
* >(*
joint
);
499
}
500
501
marginal
->
erase
(
del_var
);
502
503
// update tables_vars_per_var : remove the variables of the TABLEs we
504
// combined from this hashtable
505
// update accordingly tables_per_vars : remove these TABLEs
506
// update accordingly product_size : when a variable is no more used by
507
// any TABLE, divide product_size by its domain size
508
509
for
(
const
auto
ptrSeq
:
tables_to_combine
) {
510
const
Sequence
<
const
DiscreteVariable
* >&
table_vars
= *
ptrSeq
;
511
const
Size
tab_vars_size
=
table_vars
.
size
();
512
513
for
(
Size
i
= 0;
i
<
tab_vars_size
; ++
i
) {
514
if
(
del_vars
.
contains
(
table_vars
[
i
])) {
515
// ok, here we have a variable that needed to be removed => update
516
// product_size, tables_per_var and tables_vars_per_var
517
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
table_vars_of_var_i
518
=
tables_vars_per_var
[
table_vars
[
i
]];
519
double
div_size
= 1.0;
520
521
for
(
Size
j
= 0;
j
<
tab_vars_size
; ++
j
) {
522
unsigned
int
k
= --
table_vars_of_var_i
[
table_vars
[
j
]];
523
524
if
(
k
== 0) {
525
div_size
*=
table_vars
[
j
]->
domainSize
();
526
table_vars_of_var_i
.
erase
(
table_vars
[
j
]);
527
}
528
}
529
530
tables_per_var
[
table_vars
[
i
]].
erase
(
ptrSeq
);
531
532
if
(
div_size
!= 1.0) {
533
product_size
.
setPriority
(
table_vars
[
i
],
534
product_size
.
priority
(
table_vars
[
i
]) /
div_size
);
535
}
536
}
537
}
538
539
if
(
tmp_marginals
.
contains
(
ptrSeq
)) {
540
delete
ptrSeq
;
541
tmp_marginals
.
erase
(
ptrSeq
);
542
}
543
}
544
545
tables_per_var
.
erase
(
del_var
);
546
547
// add the new projected marginal to the list of TABLES
548
for
(
const
auto
mvar
: *
marginal
) {
549
if
(
del_vars
.
contains
(
mvar
)) {
550
// add the new marginal table to the set of tables of var i
551
tables_per_var
[
mvar
].
insert
(
marginal
);
552
553
// add the variables of the table to tables_vars_per_var[vars[i]]
554
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
=
tables_vars_per_var
[
mvar
];
555
double
mult_size
= 1.0;
556
557
for
(
const
auto
var
: *
marginal
) {
558
try
{
559
++
iter_vars
[
var
];
560
}
catch
(
const
NotFound
&) {
561
iter_vars
.
insert
(
var
, 1);
562
mult_size
*= (
double
)
var
->
domainSize
();
563
}
564
}
565
566
if
(
mult_size
!= 1.0) {
567
product_size
.
setPriority
(
mvar
,
product_size
.
priority
(
mvar
) *
mult_size
);
568
}
569
}
570
}
571
572
tmp_marginals
.
insert
(
marginal
);
573
}
574
575
// here, tmp_marginals contains all the newly created tables
576
for
(
auto
iter
=
tmp_marginals
.
beginSafe
();
iter
!=
tmp_marginals
.
endSafe
(); ++
iter
) {
577
delete
*
iter
;
578
}
579
580
return
nb_operations
;
581
}
582
583
/** @brief returns a rough estimate of the number of operations that will be
584
* performed to compute the combination */
585
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
586
float
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
nbOperations
(
587
const
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
set
,
588
const
Set
<
const
DiscreteVariable
* >&
del_vars
)
const
{
589
// create the set of sets of discrete variables involved in the tables
590
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
var_set
(
set
.
size
());
591
592
for
(
const
auto
ptrTab
:
set
) {
593
var_set
<< &(
ptrTab
->
variablesSequence
());
594
}
595
596
return
nbOperations
(
var_set
,
del_vars
);
597
}
598
599
// returns the memory consumption used during the combinations and
600
// projections
601
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
602
std
::
pair
<
long
,
long
>
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
memoryUsage
(
603
const
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >&
table_set
,
604
Set
<
const
DiscreteVariable
* >
del_vars
)
const
{
605
// when we remove a variable, we need to combine all the tables containing
606
// this variable in order to produce a new unique table containing this
607
// variable. Here, we do not have the tables but only their variables
608
// (dimensions), but the principle is identical. Removing a variable is then
609
// performed by marginalizing it out of the table or, equivalently, to
610
// remove it from the table's list of variables. In the
611
// combineAndProjectDefault algorithm, we wish to remove first variables
612
// that would produce small tables. This should speed up the whole
613
// marginalizing process.
614
615
Size
nb_vars
;
616
{
617
// determine the set of all the variables involved in the tables.
618
// this should help sizing correctly the hashtables
619
Set
<
const
DiscreteVariable
* >
all_vars
;
620
621
for
(
const
auto
ptrSeq
:
table_set
) {
622
for
(
const
auto
ptrVar
: *
ptrSeq
) {
623
all_vars
.
insert
(
ptrVar
);
624
}
625
}
626
627
nb_vars
=
all_vars
.
size
();
628
}
629
630
// the tables (actually their variables) containing a given variable
631
HashTable
<
const
DiscreteVariable
*,
Set
<
const
Sequence
<
const
DiscreteVariable
* >* > >
632
tables_per_var
(
nb_vars
);
633
// for a given variable X to be deleted, the list of all the variables of
634
// the tables containing X (actually, we count the number of tables
635
// containing the variable. This is more efficient for computing and
636
// updating the product_size priority queue (see below) when some tables
637
// are removed)
638
HashTable
<
const
DiscreteVariable
*,
HashTable
<
const
DiscreteVariable
*,
unsigned
int
> >
639
tables_vars_per_var
(
nb_vars
);
640
641
// initialize tables_vars_per_var and tables_per_var
642
{
643
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
empty_set
(
table_set
.
size
());
644
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>
empty_hash
(
nb_vars
);
645
646
for
(
const
auto
ptrVar
:
del_vars
) {
647
tables_per_var
.
insert
(
ptrVar
,
empty_set
);
648
tables_vars_per_var
.
insert
(
ptrVar
,
empty_hash
);
649
}
650
651
// update properly tables_per_var and tables_vars_per_var
652
for
(
const
auto
ptrSeq
:
table_set
) {
653
const
Sequence
<
const
DiscreteVariable
* >&
vars
= *
ptrSeq
;
654
655
for
(
const
auto
ptrVar
:
vars
) {
656
if
(
del_vars
.
contains
(
ptrVar
)) {
657
// add the table's variables to the set of those related to ptrVar
658
tables_per_var
[
ptrVar
].
insert
(
ptrSeq
);
659
660
// add the variables of the table to tables_vars_per_var[ptrVar]
661
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
662
=
tables_vars_per_var
[
ptrVar
];
663
664
for
(
const
auto
xptrVar
:
vars
) {
665
try
{
666
++
iter_vars
[
xptrVar
];
667
}
catch
(
const
NotFound
&) {
iter_vars
.
insert
(
xptrVar
, 1); }
668
}
669
}
670
}
671
}
672
}
673
674
// the sizes of the tables produced when removing a given discrete variable
675
PriorityQueue
<
const
DiscreteVariable
*,
double
>
product_size
;
676
677
// initialize properly product_size
678
for
(
const
auto
&
elt
:
tables_vars_per_var
) {
679
double
size
= 1.0;
680
const
auto
ptrVar
=
elt
.
first
;
681
const
auto
hashvars
=
elt
.
second
;
// HashTable<DiscreteVariable*, int>
682
683
if
(
hashvars
.
size
()) {
684
for
(
const
auto
&
xelt
:
hashvars
) {
685
size
*= (
double
)
xelt
.
first
->
domainSize
();
686
}
687
688
product_size
.
insert
(
ptrVar
,
size
);
689
}
690
}
691
692
// the resulting memory consumtions
693
long
max_memory
= 0;
694
long
current_memory
= 0;
695
696
// create a set of the temporary table's variables created during the
697
// marginalization process (useful for deallocating temporary tables)
698
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
tmp_marginals
(
table_set
.
size
());
699
700
// now, remove all the variables in del_vars, starting from those that
701
// produce
702
// the smallest tables
703
while
(!
product_size
.
empty
()) {
704
// get the best variable to remove
705
const
DiscreteVariable
*
del_var
=
product_size
.
pop
();
706
del_vars
.
erase
(
del_var
);
707
708
// get the set of tables to combine
709
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >&
tables_to_combine
710
=
tables_per_var
[
del_var
];
711
712
// if there is no tables to combine, do nothing
713
if
(
tables_to_combine
.
size
() == 0)
continue
;
714
715
// compute the combination of all the tables: if there is only one table,
716
// there is nothing to do, else we shall use the MultiDimCombination
717
// to perform the combination
718
Sequence
<
const
DiscreteVariable
* >*
joint
;
719
720
bool
joint_to_delete
=
false
;
721
722
if
(
tables_to_combine
.
size
() == 1) {
723
joint
724
=
const_cast
<
Sequence
<
const
DiscreteVariable
* >* >(*(
tables_to_combine
.
beginSafe
()));
725
joint_to_delete
=
false
;
726
}
else
{
727
// here, compute the union of all the variables of the tables to combine
728
joint
=
new
Sequence
<
const
DiscreteVariable
* >;
729
730
for
(
const
auto
ptrSeq
:
tables_to_combine
) {
731
for
(
const
auto
ptrVar
: *
ptrSeq
) {
732
if
(!
joint
->
exists
(
ptrVar
)) {
joint
->
insert
(
ptrVar
); }
733
}
734
}
735
736
joint_to_delete
=
true
;
737
738
// update the number of operations performed
739
std
::
pair
<
long
,
long
>
comb_memory
=
_combination_
->
memoryUsage
(
tables_to_combine
);
740
741
if
((
std
::
numeric_limits
<
long
>::
max
() -
current_memory
<
comb_memory
.
first
)
742
|| (
std
::
numeric_limits
<
long
>::
max
() -
current_memory
<
comb_memory
.
second
)) {
743
GUM_ERROR
(
OutOfBounds
,
"memory usage out of long int range"
)
744
}
745
746
if
(
current_memory
+
comb_memory
.
first
>
max_memory
) {
747
max_memory
=
current_memory
+
comb_memory
.
first
;
748
}
749
750
current_memory
+=
comb_memory
.
second
;
751
}
752
753
// update the number of operations performed by marginalizing out del_var
754
Set
<
const
DiscreteVariable
* >
del_one_var
;
755
del_one_var
<<
del_var
;
756
757
std
::
pair
<
long
,
long
>
comb_memory
=
_projection_
->
memoryUsage
(*
joint
,
del_one_var
);
758
759
if
((
std
::
numeric_limits
<
long
>::
max
() -
current_memory
<
comb_memory
.
first
)
760
|| (
std
::
numeric_limits
<
long
>::
max
() -
current_memory
<
comb_memory
.
second
)) {
761
GUM_ERROR
(
OutOfBounds
,
"memory usage out of long int range"
)
762
}
763
764
if
(
current_memory
+
comb_memory
.
first
>
max_memory
) {
765
max_memory
=
current_memory
+
comb_memory
.
first
;
766
}
767
768
current_memory
+=
comb_memory
.
second
;
769
770
// compute the table resulting from marginalizing out del_var from joint
771
Sequence
<
const
DiscreteVariable
* >*
marginal
;
772
773
if
(
joint_to_delete
) {
774
marginal
=
joint
;
775
}
else
{
776
marginal
=
new
Sequence
<
const
DiscreteVariable
* >(*
joint
);
777
}
778
779
marginal
->
erase
(
del_var
);
780
781
// update tables_vars_per_var : remove the variables of the TABLEs we
782
// combined from this hashtable
783
// update accordingly tables_per_vars : remove these TABLEs
784
// update accordingly product_size : when a variable is no more used by
785
// any TABLE, divide product_size by its domain size
786
787
for
(
const
auto
ptrSeq
:
tables_to_combine
) {
788
const
Sequence
<
const
DiscreteVariable
* >&
table_vars
= *
ptrSeq
;
789
const
Size
tab_vars_size
=
table_vars
.
size
();
790
791
for
(
Size
i
= 0;
i
<
tab_vars_size
; ++
i
) {
792
if
(
del_vars
.
contains
(
table_vars
[
i
])) {
793
// ok, here we have a variable that needed to be removed => update
794
// product_size, tables_per_var and tables_vars_per_var
795
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
table_vars_of_var_i
796
=
tables_vars_per_var
[
table_vars
[
i
]];
797
double
div_size
= 1.0;
798
799
for
(
Size
j
= 0;
j
<
tab_vars_size
; ++
j
) {
800
Size
k
= --
table_vars_of_var_i
[
table_vars
[
j
]];
801
802
if
(
k
== 0) {
803
div_size
*=
table_vars
[
j
]->
domainSize
();
804
table_vars_of_var_i
.
erase
(
table_vars
[
j
]);
805
}
806
}
807
808
tables_per_var
[
table_vars
[
i
]].
erase
(
ptrSeq
);
809
810
if
(
div_size
!= 1) {
811
product_size
.
setPriority
(
table_vars
[
i
],
812
product_size
.
priority
(
table_vars
[
i
]) /
div_size
);
813
}
814
}
815
}
816
817
if
(
tmp_marginals
.
contains
(
ptrSeq
)) {
818
Size
del_size
= 1;
819
820
for
(
const
auto
ptrVar
: *
ptrSeq
) {
821
del_size
*=
ptrVar
->
domainSize
();
822
}
823
824
current_memory
-=
long
(
del_size
);
825
826
delete
ptrSeq
;
827
tmp_marginals
.
erase
(
ptrSeq
);
828
}
829
}
830
831
tables_per_var
.
erase
(
del_var
);
832
833
// add the new projected marginal to the list of TABLES
834
for
(
const
auto
mvar
: *
marginal
) {
835
if
(
del_vars
.
contains
(
mvar
)) {
836
// add the new marginal table to the set of tables of var i
837
tables_per_var
[
mvar
].
insert
(
marginal
);
838
839
// add the variables of the table to tables_vars_per_var[vars[i]]
840
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
=
tables_vars_per_var
[
mvar
];
841
double
mult_size
= 1.0;
842
843
for
(
const
auto
var
: *
marginal
) {
844
try
{
845
++
iter_vars
[
var
];
846
}
catch
(
const
NotFound
&) {
847
iter_vars
.
insert
(
var
, 1);
848
mult_size
*= (
double
)
var
->
domainSize
();
849
}
850
}
851
852
if
(
mult_size
!= 1) {
853
product_size
.
setPriority
(
mvar
,
product_size
.
priority
(
mvar
) *
mult_size
);
854
}
855
}
856
}
857
858
tmp_marginals
.
insert
(
marginal
);
859
}
860
861
// here, tmp_marginals contains all the newly created tables
862
for
(
auto
iter
=
tmp_marginals
.
beginSafe
();
iter
!=
tmp_marginals
.
endSafe
(); ++
iter
) {
863
delete
*
iter
;
864
}
865
866
return
std
::
pair
<
long
,
long
>(
max_memory
,
current_memory
);
867
}
868
869
// returns the memory consumption used during the combinations and
870
// projections
871
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
872
std
::
pair
<
long
,
long
>
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
memoryUsage
(
873
const
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
set
,
874
const
Set
<
const
DiscreteVariable
* >&
del_vars
)
const
{
875
// create the set of sets of discrete variables involved in the tables
876
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
var_set
(
set
.
size
());
877
878
for
(
const
auto
ptrTab
:
set
) {
879
var_set
<< &(
ptrTab
->
variablesSequence
());
880
}
881
882
return
memoryUsage
(
var_set
,
del_vars
);
883
}
884
885
}
/* namespace gum */
886
887
#
endif
/* DOXYGEN_SHOULD_SKIP_THIS */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643