aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
multiDimCombineAndProjectDefault_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief An efficient class for combining and projecting MultiDim tables
25
*
26
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
27
*/
28
29
#
ifndef
DOXYGEN_SHOULD_SKIP_THIS
30
31
#
include
<
limits
>
32
33
#
include
<
agrum
/
agrum
.
h
>
34
35
#
include
<
agrum
/
tools
/
multidim
/
utils
/
operators
/
multiDimCombineAndProjectDefault
.
h
>
36
37
namespace
gum
{
38
39
// default constructor
40
template
<
typename
GUM_SCALAR,
template
<
typename
>
class
TABLE >
41
MultiDimCombineAndProjectDefault< GUM_SCALAR, TABLE >::MultiDimCombineAndProjectDefault(
42
TABLE< GUM_SCALAR >* (*combine)(
const
TABLE< GUM_SCALAR >&,
const
TABLE< GUM_SCALAR >&),
43
TABLE< GUM_SCALAR >* (*project)(
const
TABLE< GUM_SCALAR >&,
44
const
Set<
const
DiscreteVariable* >&)) :
45
MultiDimCombineAndProject< GUM_SCALAR, TABLE >(),
46
_combination_(
new
MultiDimCombinationDefault< GUM_SCALAR, TABLE >(combine)),
47
_projection_(
new
MultiDimProjection< GUM_SCALAR, TABLE >(project)) {
48
// for debugging purposes
49
GUM_CONSTRUCTOR(MultiDimCombineAndProjectDefault);
50
}
51
52
// copy constructor
53
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
54
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
MultiDimCombineAndProjectDefault
(
55
const
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>&
from
) :
56
MultiDimCombineAndProject
<
GUM_SCALAR
,
TABLE
>(),
57
_combination_
(
from
.
_combination_
->
newFactory
()),
58
_projection_
(
from
.
_projection_
->
newFactory
()) {
59
// for debugging purposes
60
GUM_CONS_CPY
(
MultiDimCombineAndProjectDefault
);
61
}
62
63
// destructor
64
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
65
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::~
MultiDimCombineAndProjectDefault
() {
66
// for debugging purposes
67
GUM_DESTRUCTOR
(
MultiDimCombineAndProjectDefault
);
68
delete
_combination_
;
69
delete
_projection_
;
70
}
71
72
// virtual constructor
73
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
74
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>*
75
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
newFactory
()
const
{
76
return
new
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>(*
this
);
77
}
78
79
// combine and project
80
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
81
Set
<
const
TABLE
<
GUM_SCALAR
>* >
82
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
combineAndProject
(
83
Set
<
const
TABLE
<
GUM_SCALAR
>* >
table_set
,
84
Set
<
const
DiscreteVariable
* >
del_vars
) {
85
// when we remove a variable, we need to combine all the tables containing
86
// this variable in order to produce a new unique table containing this
87
// variable. Removing the variable is then performed by marginalizing it
88
// out of the table. In the combineAndProject algorithm, we wish to remove
89
// first variables that produce small tables. This should speed up the
90
// marginalizing process
91
Size
nb_vars
;
92
{
93
// determine the set of all the variables involved in the tables.
94
// this should help sizing correctly the hashtables
95
Set
<
const
DiscreteVariable
* >
all_vars
;
96
97
for
(
const
auto
ptrTab
:
table_set
) {
98
for
(
const
auto
ptrVar
:
ptrTab
->
variablesSequence
()) {
99
all_vars
.
insert
(
ptrVar
);
100
}
101
}
102
103
nb_vars
=
all_vars
.
size
();
104
}
105
106
// the tables containing a given variable to be deleted
107
HashTable
<
const
DiscreteVariable
*,
Set
<
const
TABLE
<
GUM_SCALAR
>* > >
tables_per_var
(
nb_vars
);
108
109
// for a given variable X to be deleted, the list of all the variables of
110
// the tables containing X (actually, we also count the number of tables
111
// containing the variable. This is more efficient for computing and
112
// updating the product_size priority queue (see below) when some tables
113
// are removed)
114
HashTable
<
const
DiscreteVariable
*,
HashTable
<
const
DiscreteVariable
*,
unsigned
int
> >
115
tables_vars_per_var
(
nb_vars
);
116
117
// initialize tables_vars_per_var and tables_per_var
118
{
119
Set
<
const
TABLE
<
GUM_SCALAR
>* >
empty_set
(
table_set
.
size
());
120
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>
empty_hash
(
nb_vars
);
121
122
for
(
const
auto
ptrVar
:
del_vars
) {
123
tables_per_var
.
insert
(
ptrVar
,
empty_set
);
124
tables_vars_per_var
.
insert
(
ptrVar
,
empty_hash
);
125
}
126
127
// update properly tables_per_var and tables_vars_per_var
128
for
(
const
auto
ptrTab
:
table_set
) {
129
const
Sequence
<
const
DiscreteVariable
* >&
vars
=
ptrTab
->
variablesSequence
();
130
131
for
(
const
auto
ptrVar
:
vars
) {
132
if
(
del_vars
.
contains
(
ptrVar
)) {
133
// add the table to the set of tables related to vars[i]
134
tables_per_var
[
ptrVar
].
insert
(
ptrTab
);
135
136
// add the variables of the table to tables_vars_per_var[vars[i]]
137
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
138
=
tables_vars_per_var
[
ptrVar
];
139
140
for
(
const
auto
xptrVar
:
vars
) {
141
try
{
142
++
iter_vars
[
xptrVar
];
143
}
catch
(
const
NotFound
&) {
iter_vars
.
insert
(
xptrVar
, 1); }
144
}
145
}
146
}
147
}
148
}
149
150
// the sizes of the tables produced when removing a given discrete variable
151
PriorityQueue
<
const
DiscreteVariable
*,
double
>
product_size
;
152
153
// initialize properly product_size
154
for
(
const
auto
&
elt
:
tables_vars_per_var
) {
155
double
size
= 1.0;
156
const
auto
ptrVar
=
elt
.
first
;
157
const
auto
&
hashvars
=
elt
.
second
;
// HashTable<DiscreteVariable*, int>
158
159
if
(
hashvars
.
size
()) {
160
for
(
const
auto
&
xelt
:
hashvars
) {
161
size
*= (
double
)
xelt
.
first
->
domainSize
();
162
}
163
164
product_size
.
insert
(
ptrVar
,
size
);
165
}
166
}
167
168
// create a set of the temporary tables created during the
169
// marginalization process (useful for deallocating temporary tables)
170
Set
<
const
TABLE
<
GUM_SCALAR
>* >
tmp_marginals
(
table_set
.
size
());
171
172
// now, remove all the variables in del_vars, starting from those that
173
// produce the smallest tables
174
while
(!
product_size
.
empty
()) {
175
// get the best variable to remove
176
const
DiscreteVariable
*
del_var
=
product_size
.
pop
();
177
del_vars
.
erase
(
del_var
);
178
179
// get the set of tables to combine
180
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
tables_to_combine
=
tables_per_var
[
del_var
];
181
182
// if there is no tables to combine, do nothing
183
if
(
tables_to_combine
.
size
() == 0)
continue
;
184
185
// compute the combination of all the tables: if there is only one table,
186
// there is nothing to do, else we shall use the MultiDimCombination
187
// to perform the combination
188
TABLE
<
GUM_SCALAR
>*
joint
;
189
190
bool
joint_to_delete
=
false
;
191
192
if
(
tables_to_combine
.
size
() == 1) {
193
joint
=
const_cast
<
TABLE
<
GUM_SCALAR
>* >(*(
tables_to_combine
.
begin
()));
194
joint_to_delete
=
false
;
195
}
else
{
196
joint
=
_combination_
->
combine
(
tables_to_combine
);
197
joint_to_delete
=
true
;
198
}
199
200
// compute the table resulting from marginalizing out del_var from joint
201
Set
<
const
DiscreteVariable
* >
del_one_var
;
202
del_one_var
<<
del_var
;
203
204
TABLE
<
GUM_SCALAR
>*
marginal
=
_projection_
->
project
(*
joint
,
del_one_var
);
205
206
// remove the temporary joint if needed
207
if
(
joint_to_delete
)
delete
joint
;
208
209
// update tables_vars_per_var : remove the variables of the TABLEs we
210
// combined from this hashtable
211
// update accordingly tables_per_vars : remove these TABLEs
212
// update accordingly product_size : when a variable is no more used by
213
// any TABLE, divide product_size by its domain size
214
215
for
(
const
auto
ptrTab
:
tables_to_combine
) {
216
const
Sequence
<
const
DiscreteVariable
* >&
table_vars
=
ptrTab
->
variablesSequence
();
217
const
Size
tab_vars_size
=
table_vars
.
size
();
218
219
for
(
Size
i
= 0;
i
<
tab_vars_size
; ++
i
) {
220
if
(
del_vars
.
contains
(
table_vars
[
i
])) {
221
// ok, here we have a variable that needed to be removed => update
222
// product_size, tables_per_var and tables_vars_per_var: here,
223
// the update corresponds to removing table PtrTab
224
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
table_vars_of_var_i
225
=
tables_vars_per_var
[
table_vars
[
i
]];
226
double
div_size
= 1.0;
227
228
for
(
Size
j
= 0;
j
<
tab_vars_size
; ++
j
) {
229
unsigned
int
k
= --
table_vars_of_var_i
[
table_vars
[
j
]];
230
231
if
(
k
== 0) {
232
div_size
*=
table_vars
[
j
]->
domainSize
();
233
table_vars_of_var_i
.
erase
(
table_vars
[
j
]);
234
}
235
}
236
237
tables_per_var
[
table_vars
[
i
]].
erase
(
ptrTab
);
238
239
if
(
div_size
!= 1.0) {
240
product_size
.
setPriority
(
table_vars
[
i
],
241
product_size
.
priority
(
table_vars
[
i
]) /
div_size
);
242
}
243
}
244
}
245
246
if
(
tmp_marginals
.
contains
(
ptrTab
)) {
247
delete
ptrTab
;
248
tmp_marginals
.
erase
(
ptrTab
);
249
}
250
251
table_set
.
erase
(
ptrTab
);
252
}
253
254
tables_per_var
.
erase
(
del_var
);
255
256
// add the new projected marginal to the list of TABLES
257
const
Sequence
<
const
DiscreteVariable
* >&
marginal_vars
=
marginal
->
variablesSequence
();
258
259
for
(
const
auto
mvar
:
marginal_vars
) {
260
if
(
del_vars
.
contains
(
mvar
)) {
261
// add the new marginal table to the set of tables of mvar
262
tables_per_var
[
mvar
].
insert
(
marginal
);
263
264
// add the variables of the table to tables_vars_per_var[mvar]
265
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
=
tables_vars_per_var
[
mvar
];
266
double
mult_size
= 1.0;
267
268
for
(
const
auto
var
:
marginal_vars
) {
269
try
{
270
++
iter_vars
[
var
];
271
}
catch
(
const
NotFound
&) {
272
iter_vars
.
insert
(
var
, 1);
273
mult_size
*= (
double
)
var
->
domainSize
();
274
}
275
}
276
277
if
(
mult_size
!= 1.0) {
278
product_size
.
setPriority
(
mvar
,
product_size
.
priority
(
mvar
) *
mult_size
);
279
}
280
}
281
}
282
283
table_set
.
insert
(
marginal
);
284
tmp_marginals
.
insert
(
marginal
);
285
}
286
287
// here, tmp_marginals contains all the newly created tables and
288
// table_set contains the list of the tables resulting from the
289
// marginalizing out of del_vars of the combination of the tables
290
// of table_set. Note in particular that it will contain all the
291
// potentials with no dimension (constants)
292
return
table_set
;
293
}
294
295
// changes the function used for combining two TABLES
296
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
297
INLINE
void
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
setCombineFunction
(
298
TABLE
<
GUM_SCALAR
>* (*
combine
)(
const
TABLE
<
GUM_SCALAR
>&,
const
TABLE
<
GUM_SCALAR
>&)) {
299
_combination_
->
setCombineFunction
(
combine
);
300
}
301
302
// returns the current combination function
303
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
304
INLINE
TABLE
<
GUM_SCALAR
>* (
305
*
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
combineFunction
())(
306
const
TABLE
<
GUM_SCALAR
>&,
307
const
TABLE
<
GUM_SCALAR
>&) {
308
return
_combination_
->
combineFunction
();
309
}
310
311
// changes the class that performs the combinations
312
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
313
INLINE
void
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
setCombinationClass
(
314
const
MultiDimCombination
<
GUM_SCALAR
,
TABLE
>&
comb_class
) {
315
delete
_combination_
;
316
_combination_
=
comb_class
.
newFactory
();
317
}
318
319
// changes the function used for projecting TABLES
320
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
321
INLINE
void
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
setProjectFunction
(
322
TABLE
<
GUM_SCALAR
>* (*
proj
)(
const
TABLE
<
GUM_SCALAR
>&,
323
const
Set
<
const
DiscreteVariable
* >&)) {
324
_projection_
->
setProjectFunction
(
proj
);
325
}
326
327
// returns the current projection function
328
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
329
INLINE
TABLE
<
GUM_SCALAR
>* (
330
*
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
projectFunction
())(
331
const
TABLE
<
GUM_SCALAR
>&,
332
const
Set
<
const
DiscreteVariable
* >&) {
333
return
_projection_
->
projectFunction
();
334
}
335
336
// changes the class that performs the projections
337
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
338
INLINE
void
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
setProjectionClass
(
339
const
MultiDimProjection
<
GUM_SCALAR
,
TABLE
>&
proj_class
) {
340
delete
_projection_
;
341
_projection_
=
proj_class
.
newFactory
();
342
}
343
344
/** @brief returns a rough estimate of the number of operations that will be
345
* performed to compute the combination */
346
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
347
float
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
nbOperations
(
348
const
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >&
table_set
,
349
Set
<
const
DiscreteVariable
* >
del_vars
)
const
{
350
// when we remove a variable, we need to combine all the tables containing
351
// this variable in order to produce a new unique table containing this
352
// variable. Here, we do not have the tables but only their variables
353
// (dimensions), but the principle is identical. Removing a variable is then
354
// performed by marginalizing it out of the table or, equivalently, to
355
// remove it from the table's list of variables. In the
356
// combineAndProjectDefault algorithm, we wish to remove first variables
357
// that would produce small tables. This should speed up the whole
358
// marginalizing process.
359
360
Size
nb_vars
;
361
{
362
// determine the set of all the variables involved in the tables.
363
// this should help sizing correctly the hashtables
364
Set
<
const
DiscreteVariable
* >
all_vars
;
365
366
for
(
const
auto
ptrSeq
:
table_set
) {
367
for
(
const
auto
ptrVar
: *
ptrSeq
) {
368
all_vars
.
insert
(
ptrVar
);
369
}
370
}
371
372
nb_vars
=
all_vars
.
size
();
373
}
374
375
// the tables (actually their variables) containing a given variable
376
// to be deleted
377
HashTable
<
const
DiscreteVariable
*,
Set
<
const
Sequence
<
const
DiscreteVariable
* >* > >
378
tables_per_var
(
nb_vars
);
379
380
// for a given variable X to be deleted, the list of all the variables of
381
// the tables containing X (actually, we count the number of tables
382
// containing the variable. This is more efficient for computing and
383
// updating the product_size priority queue (see below) when some tables
384
// are removed)
385
HashTable
<
const
DiscreteVariable
*,
HashTable
<
const
DiscreteVariable
*,
unsigned
int
> >
386
tables_vars_per_var
(
nb_vars
);
387
388
// initialize tables_vars_per_var and tables_per_var
389
{
390
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
empty_set
(
table_set
.
size
());
391
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>
empty_hash
(
nb_vars
);
392
393
for
(
const
auto
ptrVar
:
del_vars
) {
394
tables_per_var
.
insert
(
ptrVar
,
empty_set
);
395
tables_vars_per_var
.
insert
(
ptrVar
,
empty_hash
);
396
}
397
398
// update properly tables_per_var and tables_vars_per_var
399
for
(
const
auto
ptrSeq
:
table_set
) {
400
const
Sequence
<
const
DiscreteVariable
* >&
vars
= *
ptrSeq
;
401
402
for
(
const
auto
ptrVar
:
vars
) {
403
if
(
del_vars
.
contains
(
ptrVar
)) {
404
// add the table's variables to the set of those related to ptrVar
405
tables_per_var
[
ptrVar
].
insert
(
ptrSeq
);
406
407
// add the variables of the table to tables_vars_per_var[ptrVar]
408
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
409
=
tables_vars_per_var
[
ptrVar
];
410
411
for
(
const
auto
xptrVar
:
vars
) {
412
try
{
413
++
iter_vars
[
xptrVar
];
414
}
catch
(
const
NotFound
&) {
iter_vars
.
insert
(
xptrVar
, 1); }
415
}
416
}
417
}
418
}
419
}
420
421
// the sizes of the tables produced when removing a given discrete variable
422
PriorityQueue
<
const
DiscreteVariable
*,
double
>
product_size
;
423
424
// initialize properly product_size
425
for
(
const
auto
&
elt
:
tables_vars_per_var
) {
426
double
size
= 1.0;
427
const
auto
ptrVar
=
elt
.
first
;
428
const
auto
hashvars
=
elt
.
second
;
// HashTable<DiscreteVariable*, int>
429
430
if
(
hashvars
.
size
()) {
431
for
(
const
auto
&
xelt
:
hashvars
) {
432
size
*= (
double
)
xelt
.
first
->
domainSize
();
433
}
434
435
product_size
.
insert
(
ptrVar
,
size
);
436
}
437
}
438
439
// the resulting number of operations
440
float
nb_operations
= 0;
441
442
// create a set of the temporary table's variables created during the
443
// marginalization process (useful for deallocating temporary tables)
444
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
tmp_marginals
(
table_set
.
size
());
445
446
// now, remove all the variables in del_vars, starting from those that
447
// produce the smallest tables
448
while
(!
product_size
.
empty
()) {
449
// get the best variable to remove
450
const
DiscreteVariable
*
del_var
=
product_size
.
pop
();
451
del_vars
.
erase
(
del_var
);
452
453
// get the set of tables to combine
454
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >&
tables_to_combine
455
=
tables_per_var
[
del_var
];
456
457
// if there is no tables to combine, do nothing
458
if
(
tables_to_combine
.
size
() == 0)
continue
;
459
460
// compute the combination of all the tables: if there is only one table,
461
// there is nothing to do, else we shall use the MultiDimCombination
462
// to perform the combination
463
Sequence
<
const
DiscreteVariable
* >*
joint
;
464
465
bool
joint_to_delete
=
false
;
466
467
if
(
tables_to_combine
.
size
() == 1) {
468
joint
469
=
const_cast
<
Sequence
<
const
DiscreteVariable
* >* >(*(
tables_to_combine
.
beginSafe
()));
470
joint_to_delete
=
false
;
471
}
else
{
472
// here, compute the union of all the variables of the tables to combine
473
joint
=
new
Sequence
<
const
DiscreteVariable
* >;
474
475
for
(
const
auto
ptrSeq
:
tables_to_combine
) {
476
for
(
const
auto
ptrVar
: *
ptrSeq
) {
477
if
(!
joint
->
exists
(
ptrVar
)) {
joint
->
insert
(
ptrVar
); }
478
}
479
}
480
481
joint_to_delete
=
true
;
482
483
// update the number of operations performed
484
nb_operations
+=
_combination_
->
nbOperations
(
tables_to_combine
);
485
}
486
487
// update the number of operations performed by marginalizing out del_var
488
Set
<
const
DiscreteVariable
* >
del_one_var
;
489
del_one_var
<<
del_var
;
490
491
nb_operations
+=
_projection_
->
nbOperations
(*
joint
,
del_one_var
);
492
493
// compute the table resulting from marginalizing out del_var from joint
494
Sequence
<
const
DiscreteVariable
* >*
marginal
;
495
496
if
(
joint_to_delete
) {
497
marginal
=
joint
;
498
}
else
{
499
marginal
=
new
Sequence
<
const
DiscreteVariable
* >(*
joint
);
500
}
501
502
marginal
->
erase
(
del_var
);
503
504
// update tables_vars_per_var : remove the variables of the TABLEs we
505
// combined from this hashtable
506
// update accordingly tables_per_vars : remove these TABLEs
507
// update accordingly product_size : when a variable is no more used by
508
// any TABLE, divide product_size by its domain size
509
510
for
(
const
auto
ptrSeq
:
tables_to_combine
) {
511
const
Sequence
<
const
DiscreteVariable
* >&
table_vars
= *
ptrSeq
;
512
const
Size
tab_vars_size
=
table_vars
.
size
();
513
514
for
(
Size
i
= 0;
i
<
tab_vars_size
; ++
i
) {
515
if
(
del_vars
.
contains
(
table_vars
[
i
])) {
516
// ok, here we have a variable that needed to be removed => update
517
// product_size, tables_per_var and tables_vars_per_var
518
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
table_vars_of_var_i
519
=
tables_vars_per_var
[
table_vars
[
i
]];
520
double
div_size
= 1.0;
521
522
for
(
Size
j
= 0;
j
<
tab_vars_size
; ++
j
) {
523
unsigned
int
k
= --
table_vars_of_var_i
[
table_vars
[
j
]];
524
525
if
(
k
== 0) {
526
div_size
*=
table_vars
[
j
]->
domainSize
();
527
table_vars_of_var_i
.
erase
(
table_vars
[
j
]);
528
}
529
}
530
531
tables_per_var
[
table_vars
[
i
]].
erase
(
ptrSeq
);
532
533
if
(
div_size
!= 1.0) {
534
product_size
.
setPriority
(
table_vars
[
i
],
535
product_size
.
priority
(
table_vars
[
i
]) /
div_size
);
536
}
537
}
538
}
539
540
if
(
tmp_marginals
.
contains
(
ptrSeq
)) {
541
delete
ptrSeq
;
542
tmp_marginals
.
erase
(
ptrSeq
);
543
}
544
}
545
546
tables_per_var
.
erase
(
del_var
);
547
548
// add the new projected marginal to the list of TABLES
549
for
(
const
auto
mvar
: *
marginal
) {
550
if
(
del_vars
.
contains
(
mvar
)) {
551
// add the new marginal table to the set of tables of var i
552
tables_per_var
[
mvar
].
insert
(
marginal
);
553
554
// add the variables of the table to tables_vars_per_var[vars[i]]
555
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
=
tables_vars_per_var
[
mvar
];
556
double
mult_size
= 1.0;
557
558
for
(
const
auto
var
: *
marginal
) {
559
try
{
560
++
iter_vars
[
var
];
561
}
catch
(
const
NotFound
&) {
562
iter_vars
.
insert
(
var
, 1);
563
mult_size
*= (
double
)
var
->
domainSize
();
564
}
565
}
566
567
if
(
mult_size
!= 1.0) {
568
product_size
.
setPriority
(
mvar
,
product_size
.
priority
(
mvar
) *
mult_size
);
569
}
570
}
571
}
572
573
tmp_marginals
.
insert
(
marginal
);
574
}
575
576
// here, tmp_marginals contains all the newly created tables
577
for
(
auto
iter
=
tmp_marginals
.
beginSafe
();
iter
!=
tmp_marginals
.
endSafe
(); ++
iter
) {
578
delete
*
iter
;
579
}
580
581
return
nb_operations
;
582
}
583
584
/** @brief returns a rough estimate of the number of operations that will be
585
* performed to compute the combination */
586
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
587
float
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
nbOperations
(
588
const
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
set
,
589
const
Set
<
const
DiscreteVariable
* >&
del_vars
)
const
{
590
// create the set of sets of discrete variables involved in the tables
591
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
var_set
(
set
.
size
());
592
593
for
(
const
auto
ptrTab
:
set
) {
594
var_set
<< &(
ptrTab
->
variablesSequence
());
595
}
596
597
return
nbOperations
(
var_set
,
del_vars
);
598
}
599
600
// returns the memory consumption used during the combinations and
601
// projections
602
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
603
std
::
pair
<
long
,
long
>
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
memoryUsage
(
604
const
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >&
table_set
,
605
Set
<
const
DiscreteVariable
* >
del_vars
)
const
{
606
// when we remove a variable, we need to combine all the tables containing
607
// this variable in order to produce a new unique table containing this
608
// variable. Here, we do not have the tables but only their variables
609
// (dimensions), but the principle is identical. Removing a variable is then
610
// performed by marginalizing it out of the table or, equivalently, to
611
// remove it from the table's list of variables. In the
612
// combineAndProjectDefault algorithm, we wish to remove first variables
613
// that would produce small tables. This should speed up the whole
614
// marginalizing process.
615
616
Size
nb_vars
;
617
{
618
// determine the set of all the variables involved in the tables.
619
// this should help sizing correctly the hashtables
620
Set
<
const
DiscreteVariable
* >
all_vars
;
621
622
for
(
const
auto
ptrSeq
:
table_set
) {
623
for
(
const
auto
ptrVar
: *
ptrSeq
) {
624
all_vars
.
insert
(
ptrVar
);
625
}
626
}
627
628
nb_vars
=
all_vars
.
size
();
629
}
630
631
// the tables (actually their variables) containing a given variable
632
HashTable
<
const
DiscreteVariable
*,
Set
<
const
Sequence
<
const
DiscreteVariable
* >* > >
633
tables_per_var
(
nb_vars
);
634
// for a given variable X to be deleted, the list of all the variables of
635
// the tables containing X (actually, we count the number of tables
636
// containing the variable. This is more efficient for computing and
637
// updating the product_size priority queue (see below) when some tables
638
// are removed)
639
HashTable
<
const
DiscreteVariable
*,
HashTable
<
const
DiscreteVariable
*,
unsigned
int
> >
640
tables_vars_per_var
(
nb_vars
);
641
642
// initialize tables_vars_per_var and tables_per_var
643
{
644
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
empty_set
(
table_set
.
size
());
645
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>
empty_hash
(
nb_vars
);
646
647
for
(
const
auto
ptrVar
:
del_vars
) {
648
tables_per_var
.
insert
(
ptrVar
,
empty_set
);
649
tables_vars_per_var
.
insert
(
ptrVar
,
empty_hash
);
650
}
651
652
// update properly tables_per_var and tables_vars_per_var
653
for
(
const
auto
ptrSeq
:
table_set
) {
654
const
Sequence
<
const
DiscreteVariable
* >&
vars
= *
ptrSeq
;
655
656
for
(
const
auto
ptrVar
:
vars
) {
657
if
(
del_vars
.
contains
(
ptrVar
)) {
658
// add the table's variables to the set of those related to ptrVar
659
tables_per_var
[
ptrVar
].
insert
(
ptrSeq
);
660
661
// add the variables of the table to tables_vars_per_var[ptrVar]
662
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
663
=
tables_vars_per_var
[
ptrVar
];
664
665
for
(
const
auto
xptrVar
:
vars
) {
666
try
{
667
++
iter_vars
[
xptrVar
];
668
}
catch
(
const
NotFound
&) {
iter_vars
.
insert
(
xptrVar
, 1); }
669
}
670
}
671
}
672
}
673
}
674
675
// the sizes of the tables produced when removing a given discrete variable
676
PriorityQueue
<
const
DiscreteVariable
*,
double
>
product_size
;
677
678
// initialize properly product_size
679
for
(
const
auto
&
elt
:
tables_vars_per_var
) {
680
double
size
= 1.0;
681
const
auto
ptrVar
=
elt
.
first
;
682
const
auto
hashvars
=
elt
.
second
;
// HashTable<DiscreteVariable*, int>
683
684
if
(
hashvars
.
size
()) {
685
for
(
const
auto
&
xelt
:
hashvars
) {
686
size
*= (
double
)
xelt
.
first
->
domainSize
();
687
}
688
689
product_size
.
insert
(
ptrVar
,
size
);
690
}
691
}
692
693
// the resulting memory consumtions
694
long
max_memory
= 0;
695
long
current_memory
= 0;
696
697
// create a set of the temporary table's variables created during the
698
// marginalization process (useful for deallocating temporary tables)
699
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
tmp_marginals
(
table_set
.
size
());
700
701
// now, remove all the variables in del_vars, starting from those that
702
// produce
703
// the smallest tables
704
while
(!
product_size
.
empty
()) {
705
// get the best variable to remove
706
const
DiscreteVariable
*
del_var
=
product_size
.
pop
();
707
del_vars
.
erase
(
del_var
);
708
709
// get the set of tables to combine
710
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >&
tables_to_combine
711
=
tables_per_var
[
del_var
];
712
713
// if there is no tables to combine, do nothing
714
if
(
tables_to_combine
.
size
() == 0)
continue
;
715
716
// compute the combination of all the tables: if there is only one table,
717
// there is nothing to do, else we shall use the MultiDimCombination
718
// to perform the combination
719
Sequence
<
const
DiscreteVariable
* >*
joint
;
720
721
bool
joint_to_delete
=
false
;
722
723
if
(
tables_to_combine
.
size
() == 1) {
724
joint
725
=
const_cast
<
Sequence
<
const
DiscreteVariable
* >* >(*(
tables_to_combine
.
beginSafe
()));
726
joint_to_delete
=
false
;
727
}
else
{
728
// here, compute the union of all the variables of the tables to combine
729
joint
=
new
Sequence
<
const
DiscreteVariable
* >;
730
731
for
(
const
auto
ptrSeq
:
tables_to_combine
) {
732
for
(
const
auto
ptrVar
: *
ptrSeq
) {
733
if
(!
joint
->
exists
(
ptrVar
)) {
joint
->
insert
(
ptrVar
); }
734
}
735
}
736
737
joint_to_delete
=
true
;
738
739
// update the number of operations performed
740
std
::
pair
<
long
,
long
>
comb_memory
=
_combination_
->
memoryUsage
(
tables_to_combine
);
741
742
if
((
std
::
numeric_limits
<
long
>::
max
() -
current_memory
<
comb_memory
.
first
)
743
|| (
std
::
numeric_limits
<
long
>::
max
() -
current_memory
<
comb_memory
.
second
)) {
744
GUM_ERROR
(
OutOfBounds
,
"memory usage out of long int range"
)
745
}
746
747
if
(
current_memory
+
comb_memory
.
first
>
max_memory
) {
748
max_memory
=
current_memory
+
comb_memory
.
first
;
749
}
750
751
current_memory
+=
comb_memory
.
second
;
752
}
753
754
// update the number of operations performed by marginalizing out del_var
755
Set
<
const
DiscreteVariable
* >
del_one_var
;
756
del_one_var
<<
del_var
;
757
758
std
::
pair
<
long
,
long
>
comb_memory
=
_projection_
->
memoryUsage
(*
joint
,
del_one_var
);
759
760
if
((
std
::
numeric_limits
<
long
>::
max
() -
current_memory
<
comb_memory
.
first
)
761
|| (
std
::
numeric_limits
<
long
>::
max
() -
current_memory
<
comb_memory
.
second
)) {
762
GUM_ERROR
(
OutOfBounds
,
"memory usage out of long int range"
)
763
}
764
765
if
(
current_memory
+
comb_memory
.
first
>
max_memory
) {
766
max_memory
=
current_memory
+
comb_memory
.
first
;
767
}
768
769
current_memory
+=
comb_memory
.
second
;
770
771
// compute the table resulting from marginalizing out del_var from joint
772
Sequence
<
const
DiscreteVariable
* >*
marginal
;
773
774
if
(
joint_to_delete
) {
775
marginal
=
joint
;
776
}
else
{
777
marginal
=
new
Sequence
<
const
DiscreteVariable
* >(*
joint
);
778
}
779
780
marginal
->
erase
(
del_var
);
781
782
// update tables_vars_per_var : remove the variables of the TABLEs we
783
// combined from this hashtable
784
// update accordingly tables_per_vars : remove these TABLEs
785
// update accordingly product_size : when a variable is no more used by
786
// any TABLE, divide product_size by its domain size
787
788
for
(
const
auto
ptrSeq
:
tables_to_combine
) {
789
const
Sequence
<
const
DiscreteVariable
* >&
table_vars
= *
ptrSeq
;
790
const
Size
tab_vars_size
=
table_vars
.
size
();
791
792
for
(
Size
i
= 0;
i
<
tab_vars_size
; ++
i
) {
793
if
(
del_vars
.
contains
(
table_vars
[
i
])) {
794
// ok, here we have a variable that needed to be removed => update
795
// product_size, tables_per_var and tables_vars_per_var
796
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
table_vars_of_var_i
797
=
tables_vars_per_var
[
table_vars
[
i
]];
798
double
div_size
= 1.0;
799
800
for
(
Size
j
= 0;
j
<
tab_vars_size
; ++
j
) {
801
Size
k
= --
table_vars_of_var_i
[
table_vars
[
j
]];
802
803
if
(
k
== 0) {
804
div_size
*=
table_vars
[
j
]->
domainSize
();
805
table_vars_of_var_i
.
erase
(
table_vars
[
j
]);
806
}
807
}
808
809
tables_per_var
[
table_vars
[
i
]].
erase
(
ptrSeq
);
810
811
if
(
div_size
!= 1) {
812
product_size
.
setPriority
(
table_vars
[
i
],
813
product_size
.
priority
(
table_vars
[
i
]) /
div_size
);
814
}
815
}
816
}
817
818
if
(
tmp_marginals
.
contains
(
ptrSeq
)) {
819
Size
del_size
= 1;
820
821
for
(
const
auto
ptrVar
: *
ptrSeq
) {
822
del_size
*=
ptrVar
->
domainSize
();
823
}
824
825
current_memory
-=
long
(
del_size
);
826
827
delete
ptrSeq
;
828
tmp_marginals
.
erase
(
ptrSeq
);
829
}
830
}
831
832
tables_per_var
.
erase
(
del_var
);
833
834
// add the new projected marginal to the list of TABLES
835
for
(
const
auto
mvar
: *
marginal
) {
836
if
(
del_vars
.
contains
(
mvar
)) {
837
// add the new marginal table to the set of tables of var i
838
tables_per_var
[
mvar
].
insert
(
marginal
);
839
840
// add the variables of the table to tables_vars_per_var[vars[i]]
841
HashTable
<
const
DiscreteVariable
*,
unsigned
int
>&
iter_vars
=
tables_vars_per_var
[
mvar
];
842
double
mult_size
= 1.0;
843
844
for
(
const
auto
var
: *
marginal
) {
845
try
{
846
++
iter_vars
[
var
];
847
}
catch
(
const
NotFound
&) {
848
iter_vars
.
insert
(
var
, 1);
849
mult_size
*= (
double
)
var
->
domainSize
();
850
}
851
}
852
853
if
(
mult_size
!= 1) {
854
product_size
.
setPriority
(
mvar
,
product_size
.
priority
(
mvar
) *
mult_size
);
855
}
856
}
857
}
858
859
tmp_marginals
.
insert
(
marginal
);
860
}
861
862
// here, tmp_marginals contains all the newly created tables
863
for
(
auto
iter
=
tmp_marginals
.
beginSafe
();
iter
!=
tmp_marginals
.
endSafe
(); ++
iter
) {
864
delete
*
iter
;
865
}
866
867
return
std
::
pair
<
long
,
long
>(
max_memory
,
current_memory
);
868
}
869
870
// returns the memory consumption used during the combinations and
871
// projections
872
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
873
std
::
pair
<
long
,
long
>
MultiDimCombineAndProjectDefault
<
GUM_SCALAR
,
TABLE
>::
memoryUsage
(
874
const
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
set
,
875
const
Set
<
const
DiscreteVariable
* >&
del_vars
)
const
{
876
// create the set of sets of discrete variables involved in the tables
877
Set
<
const
Sequence
<
const
DiscreteVariable
* >* >
var_set
(
set
.
size
());
878
879
for
(
const
auto
ptrTab
:
set
) {
880
var_set
<< &(
ptrTab
->
variablesSequence
());
881
}
882
883
return
memoryUsage
(
var_set
,
del_vars
);
884
}
885
886
}
/* namespace gum */
887
888
#
endif
/* DOXYGEN_SHOULD_SKIP_THIS */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643