aGrUM 0.20.2 - a C++ library for (probabilistic) graphical models
recordCounter_tpl.h
/**
 *
 *   Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/** @file
 * @brief The class that computes countings of observations from the database.
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */

#include <agrum/tools/stattests/recordCounter.h>

#ifndef DOXYGEN_SHOULD_SKIP_THIS

namespace gum {

  namespace learning {

    /// returns the allocator used by the record counter
    template < template < typename > class ALLOC >
    INLINE typename RecordCounter< ALLOC >::allocator_type
       RecordCounter< ALLOC >::getAllocator() const {
      return parsers__.get_allocator();
    }

    /// default constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(
       const DBRowGeneratorParser< ALLOC >& parser,
       const std::vector< std::pair< std::size_t, std::size_t >,
                          ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
       const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
          nodeId2columns,
       const typename RecordCounter< ALLOC >::allocator_type& alloc) :
        parsers__(alloc),
        ranges__(alloc), nodeId2columns__(nodeId2columns),
        last_DB_countings__(alloc), last_DB_ids__(alloc),
        last_nonDB_countings__(alloc), last_nonDB_ids__(alloc) {
      // check that the columns in nodeId2columns do belong to the database
      const std::size_t db_nb_cols = parser.database().nbVariables();
      for (auto iter = nodeId2columns.cbegin(); iter != nodeId2columns.cend();
           ++iter) {
        if (iter.second() >= db_nb_cols) {
          GUM_ERROR(OutOfBounds,
                    "the mapping between ids and database columns "
                       << "is incorrect because Column " << iter.second()
                       << " does not belong to the database.");
        }
      }

      // create the parsers. There should always be at least one parser
      if (max_nb_threads__ < std::size_t(1)) max_nb_threads__ = std::size_t(1);
      parsers__.reserve(max_nb_threads__);
      for (std::size_t i = std::size_t(0); i < max_nb_threads__; ++i)
        parsers__.push_back(parser);

      // check that the ranges are within the bounds of the database and
      // save them
      checkRanges__(ranges);
      ranges__.reserve(ranges.size());
      for (const auto& range: ranges)
        ranges__.push_back(range);

      // dispatch the ranges for the threads
      dispatchRangesToThreads__();

      GUM_CONSTRUCTOR(RecordCounter);
    }

    /// default constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(
       const DBRowGeneratorParser< ALLOC >& parser,
       const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
          nodeId2columns,
       const typename RecordCounter< ALLOC >::allocator_type& alloc) :
        RecordCounter< ALLOC >(
           parser,
           std::vector< std::pair< std::size_t, std::size_t >,
                        ALLOC< std::pair< std::size_t, std::size_t > > >(),
           nodeId2columns,
           alloc) {}

    /// copy constructor with a given allocator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(
       const RecordCounter< ALLOC >&                          from,
       const typename RecordCounter< ALLOC >::allocator_type& alloc) :
        parsers__(from.parsers__, alloc),
        ranges__(from.ranges__, alloc),
        thread_ranges__(from.thread_ranges__, alloc),
        nodeId2columns__(from.nodeId2columns__),
        last_DB_countings__(from.last_DB_countings__, alloc),
        last_DB_ids__(from.last_DB_ids__),
        last_nonDB_countings__(from.last_nonDB_countings__, alloc),
        last_nonDB_ids__(from.last_nonDB_ids__),
        max_nb_threads__(from.max_nb_threads__),
        min_nb_rows_per_thread__(from.min_nb_rows_per_thread__) {
      GUM_CONS_CPY(RecordCounter);
    }


    /// copy constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(const RecordCounter< ALLOC >& from) :
        RecordCounter< ALLOC >(from, from.getAllocator()) {}

    /// move constructor with a given allocator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(
       RecordCounter< ALLOC >&&                               from,
       const typename RecordCounter< ALLOC >::allocator_type& alloc) :
        parsers__(std::move(from.parsers__), alloc),
        ranges__(std::move(from.ranges__), alloc),
        thread_ranges__(std::move(from.thread_ranges__), alloc),
        nodeId2columns__(std::move(from.nodeId2columns__)),
        last_DB_countings__(std::move(from.last_DB_countings__), alloc),
        last_DB_ids__(std::move(from.last_DB_ids__)),
        last_nonDB_countings__(std::move(from.last_nonDB_countings__), alloc),
        last_nonDB_ids__(std::move(from.last_nonDB_ids__)),
        max_nb_threads__(from.max_nb_threads__),
        min_nb_rows_per_thread__(from.min_nb_rows_per_thread__) {
      GUM_CONS_MOV(RecordCounter);
    }


    /// move constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(RecordCounter< ALLOC >&& from) :
        RecordCounter< ALLOC >(std::move(from), from.getAllocator()) {}

    /// virtual copy constructor with a given allocator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >* RecordCounter< ALLOC >::clone(
       const typename RecordCounter< ALLOC >::allocator_type& alloc) const {
      ALLOC< RecordCounter< ALLOC > > allocator(alloc);
      RecordCounter< ALLOC >*         new_counter = allocator.allocate(1);
      try {
        allocator.construct(new_counter, *this, alloc);
      } catch (...) {
        allocator.deallocate(new_counter, 1);
        throw;
      }

      return new_counter;
    }


    /// virtual copy constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >* RecordCounter< ALLOC >::clone() const {
      return clone(this->getAllocator());
    }

    /// destructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::~RecordCounter() {
      GUM_DESTRUCTOR(RecordCounter);
    }

    /// copy operator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >&
       RecordCounter< ALLOC >::operator=(const RecordCounter< ALLOC >& from) {
      if (this != &from) {
        parsers__                = from.parsers__;
        ranges__                 = from.ranges__;
        thread_ranges__          = from.thread_ranges__;
        nodeId2columns__         = from.nodeId2columns__;
        last_DB_countings__      = from.last_DB_countings__;
        last_DB_ids__            = from.last_DB_ids__;
        last_nonDB_countings__   = from.last_nonDB_countings__;
        last_nonDB_ids__         = from.last_nonDB_ids__;
        max_nb_threads__         = from.max_nb_threads__;
        min_nb_rows_per_thread__ = from.min_nb_rows_per_thread__;
      }
      return *this;
    }


    /// move operator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >&
       RecordCounter< ALLOC >::operator=(RecordCounter< ALLOC >&& from) {
      if (this != &from) {
        parsers__                = std::move(from.parsers__);
        ranges__                 = std::move(from.ranges__);
        thread_ranges__          = std::move(from.thread_ranges__);
        nodeId2columns__         = std::move(from.nodeId2columns__);
        last_DB_countings__      = std::move(from.last_DB_countings__);
        last_DB_ids__            = std::move(from.last_DB_ids__);
        last_nonDB_countings__   = std::move(from.last_nonDB_countings__);
        last_nonDB_ids__         = std::move(from.last_nonDB_ids__);
        max_nb_threads__         = from.max_nb_threads__;
        min_nb_rows_per_thread__ = from.min_nb_rows_per_thread__;
      }
      return *this;
    }

    /// clears all the last database-parsed countings from memory
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::clear() {
      last_DB_countings__.clear();
      last_DB_ids__.clear();
      last_nonDB_countings__.clear();
      last_nonDB_ids__.clear();
    }

    /// changes the max number of threads used to parse the database
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::setMaxNbThreads(const std::size_t nb) const {
      if (nb == std::size_t(0) || !isOMP())
        max_nb_threads__ = std::size_t(1);
      else
        max_nb_threads__ = nb;
    }


    /// returns the number of threads used to parse the database
    template < template < typename > class ALLOC >
    INLINE std::size_t RecordCounter< ALLOC >::nbThreads() const {
      return max_nb_threads__;
    }

    // changes the minimum number of rows a thread should process in a
    // multithreading context
    template < template < typename > class ALLOC >
    void
       RecordCounter< ALLOC >::setMinNbRowsPerThread(const std::size_t nb) const {
      if (nb == std::size_t(0))
        min_nb_rows_per_thread__ = std::size_t(1);
      else
        min_nb_rows_per_thread__ = nb;
    }


    /// returns the minimum number of rows that each thread should process
    template < template < typename > class ALLOC >
    INLINE std::size_t RecordCounter< ALLOC >::minNbRowsPerThread() const {
      return min_nb_rows_per_thread__;
    }

    /// compute and raise the exception when some variables are continuous
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::raiseCheckException__(
       const std::vector< std::string, ALLOC< std::string > >& bad_vars) const {
      // generate the exception
      std::stringstream msg;
      msg << "Counts cannot be performed on continuous variables. ";
      msg << "Unfortunately the following variable";
      if (bad_vars.size() == 1)
        msg << " is continuous: " << bad_vars[0];
      else {
        msg << "s are continuous: ";
        bool deja = false;
        for (const auto& name: bad_vars) {
          if (deja)
            msg << ", ";
          else
            deja = true;
          msg << name;
        }
      }
      GUM_ERROR(TypeError, msg.str());
    }

    /// check that all the variables in an idset are discrete
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::checkDiscreteVariables__(
       const IdCondSet< ALLOC >& ids) const {
      const std::size_t             size     = ids.size();
      const DatabaseTable< ALLOC >& database = parsers__[0].data.database();

      if (nodeId2columns__.empty()) {
        // check all the ids
        for (std::size_t i = std::size_t(0); i < size; ++i) {
          if (database.variable(i).varType() == VarType::Continuous) {
            // here, var i does not correspond to a discrete variable.
            // we check whether there are other non discrete variables, so that
            // we can generate an exception mentioning all these variables
            std::vector< std::string, ALLOC< std::string > > bad_vars{
               database.variable(i).name()};
            for (++i; i < size; ++i) {
              if (database.variable(i).varType() == VarType::Continuous)
                bad_vars.push_back(database.variable(i).name());
            }
            raiseCheckException__(bad_vars);
          }
        }
      } else {
        // check all the ids
        for (std::size_t i = std::size_t(0); i < size; ++i) {
          // get the position of the variable in the database
          std::size_t pos = nodeId2columns__.second(ids[i]);

          if (database.variable(pos).varType() == VarType::Continuous) {
            // here, id does not correspond to a discrete variable.
            // we check whether there are other non discrete variables, so that
            // we can generate an exception mentioning all these variables
            std::vector< std::string, ALLOC< std::string > > bad_vars{
               database.variable(pos).name()};
            for (++i; i < size; ++i) {
              pos = nodeId2columns__.second(ids[i]);
              if (database.variable(pos).varType() == VarType::Continuous)
                bad_vars.push_back(database.variable(pos).name());
            }
            raiseCheckException__(bad_vars);
          }
        }
      }
    }

    /// returns the mapping from ids to column positions in the database
    template < template < typename > class ALLOC >
    INLINE const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
                 RecordCounter< ALLOC >::nodeId2Columns() const {
      return nodeId2columns__;
    }


    /// returns the database on which we perform the counts
    template < template < typename > class ALLOC >
    const DatabaseTable< ALLOC >& RecordCounter< ALLOC >::database() const {
      return parsers__[0].data.database();
    }

    /// returns the counts for a given set of nodes
    template < template < typename > class ALLOC >
    INLINE const std::vector< double, ALLOC< double > >&
       RecordCounter< ALLOC >::counts(const IdCondSet< ALLOC >& ids,
                                      const bool check_discrete_vars) {
      // if the idset is empty, return an empty vector
      if (ids.empty()) {
        last_nonDB_ids__.clear();
        last_nonDB_countings__.clear();
        return last_nonDB_countings__;
      }

      // check whether we can extract the vector we wish to return from
      // some already computed counting vector
      if (last_nonDB_ids__.contains(ids))
        return extractFromCountings__(ids,
                                      last_nonDB_ids__,
                                      last_nonDB_countings__);
      else if (last_DB_ids__.contains(ids))
        return extractFromCountings__(ids, last_DB_ids__, last_DB_countings__);
      else {
        if (check_discrete_vars) checkDiscreteVariables__(ids);
        return countFromDatabase__(ids);
      }
    }
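
    // As the offset computation in countFromDatabase__ below suggests, the
    // vector returned by counts() is indexed in mixed radix over the domain
    // sizes of the queried ids, with ids[0] as the fastest-changing variable.
    // Below is a minimal standalone sketch of decoding such an offset back
    // into one value per variable (standard library only, not part of the
    // aGrUM API; the function name is illustrative).
#  if 0
    // decode `offset` into one value per variable, given the domain sizes of
    // the variables in the same order as in the queried set of ids
    std::vector< std::size_t >
       decodeCountingOffset(std::size_t                       offset,
                            const std::vector< std::size_t >& domain_sizes) {
      std::vector< std::size_t > values(domain_sizes.size());
      for (std::size_t i = 0; i < domain_sizes.size(); ++i) {
        values[i] = offset % domain_sizes[i];   // ids[0] varies fastest
        offset /= domain_sizes[i];
      }
      return values;
    }
#  endif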

    // returns a mapping from the nodes ids to the columns of the database
    // for a given sequence of ids
    template < template < typename > class ALLOC >
    HashTable< NodeId, std::size_t > RecordCounter< ALLOC >::getNodeIds2Columns__(
       const IdCondSet< ALLOC >& ids) const {
      HashTable< NodeId, std::size_t > res(ids.size());
      if (nodeId2columns__.empty()) {
        for (const auto id: ids) {
          res.insert(id, std::size_t(id));
        }
      } else {
        for (const auto id: ids) {
          res.insert(id, nodeId2columns__.second(id));
        }
      }
      return res;
    }

    /// extracts some new countings from previously computed ones
    template < template < typename > class ALLOC >
    INLINE std::vector< double, ALLOC< double > >&
       RecordCounter< ALLOC >::extractFromCountings__(
          const IdCondSet< ALLOC >&                     subset_ids,
          const IdCondSet< ALLOC >&                     superset_ids,
          const std::vector< double, ALLOC< double > >& superset_vect) {
      // get a mapping between the node Ids and their columns in the database.
      // This should be stored into nodeId2columns__, except if the latter is
      // empty, in which case there is an identity mapping
      const auto nodeId2columns = getNodeIds2Columns__(superset_ids);

      // we first determine the size of the output vector, the domain of
      // each of its variables and their offsets in the output vector
      const auto& database         = parsers__[0].data.database();
      std::size_t result_vect_size = std::size_t(1);
      for (const auto id: subset_ids) {
        result_vect_size *= database.domainSize(nodeId2columns[id]);
      }

      // we create the output vector
      const std::size_t subset_ids_size = std::size_t(subset_ids.size());
      std::vector< double, ALLOC< double > > result_vect(result_vect_size, 0.0);


      // check if the subset_ids is the beginning of the sequence of superset_ids
      // if this is the case, then we can outer loop over the variables not in
      // subset_ids and, for each iteration of this loop, add a vector of size
      // result_size to result_vect
      bool subset_begin = true;
      for (std::size_t i = 0; i < subset_ids_size; ++i) {
        if (superset_ids.pos(subset_ids[i]) != i) {
          subset_begin = false;
          break;
        }
      }

      if (subset_begin) {
        const std::size_t superset_vect_size = superset_vect.size();
        std::size_t       i                  = std::size_t(0);
        while (i < superset_vect_size) {
          for (std::size_t j = std::size_t(0); j < result_vect_size; ++j, ++i) {
            result_vect[j] += superset_vect[i];
          }
        }

        // save the subset_ids and the result vector
        try {
          last_nonDB_ids__       = subset_ids;
          last_nonDB_countings__ = std::move(result_vect);
          return last_nonDB_countings__;
        } catch (...) {
          last_nonDB_ids__.clear();
          last_nonDB_countings__.clear();
          throw;
        }
      }


      // check if subset_ids is the end of the sequence of superset_ids.
      // In this case, as above, there are two simple loops to perform the
      // countings
      bool              subset_end        = true;
      const std::size_t superset_ids_size = std::size_t(superset_ids.size());
      for (std::size_t i = 0; i < subset_ids_size; ++i) {
        if (superset_ids.pos(subset_ids[i])
            != i + superset_ids_size - subset_ids_size) {
          subset_end = false;
          break;
        }
      }

      if (subset_end) {
        // determine the size of the vector corresponding to the variables
        // not belonging to subset_ids
        std::size_t vect_not_subset_size = std::size_t(1);
        for (std::size_t i = std::size_t(0);
             i < superset_ids_size - subset_ids_size;
             ++i)
          vect_not_subset_size
             *= database.domainSize(nodeId2columns[superset_ids[i]]);

        // perform the two loops
        std::size_t i = std::size_t(0);
        for (std::size_t j = std::size_t(0); j < result_vect_size; ++j) {
          for (std::size_t k = std::size_t(0); k < vect_not_subset_size;
               ++k, ++i) {
            result_vect[j] += superset_vect[i];
          }
        }

        // save the subset_ids and the result vector
        try {
          last_nonDB_ids__       = subset_ids;
          last_nonDB_countings__ = std::move(result_vect);
          return last_nonDB_countings__;
        } catch (...) {
          last_nonDB_ids__.clear();
          last_nonDB_countings__.clear();
          throw;
        }
      }


      // here subset_ids is a subset of superset_ids neither prefixing nor
      // postfixing it. So the computation is somewhat more complicated.

      // We will parse the superset_vect sequentially (using the ++ operator).
      // Sometimes, we will need to change the offset of the cell of result_vect
      // that will be affected, sometimes not. Vector before_incr will indicate
      // whether we need to change the offset (value = 0) or not (value different
      // from 0). Vector result_domain will indicate how this offset should be
      // computed. Here is an example of the values of these vectors. Assume that
      // superset_ids = <A,B,C,D,E> and subset_ids = <A,D,C>. Then, the three
      // vectors before_incr, result_domain and result_offset are indexed w.r.t.
      // A,C,D, i.e., w.r.t. the variables in subset_ids but ordered w.r.t.
      // superset_ids (this is convenient as we will parse superset_vect
      // sequentially). For a variable or a set of variables X, let M_X denote the
      // domain size of X. Then the contents of the three vectors are as follows:
      // before_incr = {0, M_B, 0} (this means that whenever we iterate over B's
      // values, the offset in result_vect does not change)
      // result_domain = { M_A, M_C, M_D } (i.e., the domain sizes of the variables
      // in subset_ids, ordered w.r.t. superset_ids)
      // result_offset = { 1, M_A*M_D, M_A } (this corresponds to the offsets
      // in result_vect of variables A, C and D)
      // Vector superset_order = { 0, 2, 1 }: this is a map from the indices of
      // the variables in subset_ids to the indices of these variables in the
      // three vectors described above. For instance, the "2" means that variable
      // D (which is at index 1 in subset_ids) is located at index 2 in vector
      // before_incr
      std::vector< std::size_t > before_incr(subset_ids_size);
      std::vector< std::size_t > result_domain(subset_ids_size);
      std::vector< std::size_t > result_offset(subset_ids_size);
      {
        std::size_t result_domain_size = std::size_t(1);
        std::size_t tmp_before_incr    = std::size_t(1);
        std::vector< std::size_t > superset_order(subset_ids_size);

        for (std::size_t h = std::size_t(0), j = std::size_t(0);
             j < subset_ids_size;
             ++h) {
          if (subset_ids.exists(superset_ids[h])) {
            before_incr[j] = tmp_before_incr - 1;
            superset_order[subset_ids.pos(superset_ids[h])] = j;
            tmp_before_incr = 1;
            ++j;
          } else {
            tmp_before_incr
               *= database.domainSize(nodeId2columns[superset_ids[h]]);
          }
        }

        // compute the offsets in the order of the superset_ids
        for (std::size_t i = 0; i < subset_ids.size(); ++i) {
          const std::size_t domain_size
             = database.domainSize(nodeId2columns[subset_ids[i]]);
          const std::size_t j = superset_order[i];
          result_domain[j]    = domain_size;
          result_offset[j]    = result_domain_size;
          result_domain_size *= domain_size;
        }
      }

      std::vector< std::size_t > result_value(result_domain);
      std::vector< std::size_t > current_incr(before_incr);
      std::vector< std::size_t > result_down(result_offset);

      for (std::size_t j = std::size_t(0); j < result_down.size(); ++j) {
        result_down[j] *= (result_domain[j] - 1);
      }

      // now we can loop over the superset_vect to fill result_vect
      const std::size_t superset_vect_size = superset_vect.size();
      std::size_t       the_result_offset  = std::size_t(0);
      for (std::size_t h = std::size_t(0); h < superset_vect_size; ++h) {
        result_vect[the_result_offset] += superset_vect[h];

        // update the offset of result_vect
        for (std::size_t k = 0; k < current_incr.size(); ++k) {
          // check if we need to modify result_offset
          if (current_incr[k]) {
            --current_incr[k];
            break;
          }

          current_incr[k] = before_incr[k];

          // here we shall modify result_offset
          --result_value[k];

          if (result_value[k]) {
            the_result_offset += result_offset[k];
            break;
          }

          result_value[k] = result_domain[k];
          the_result_offset -= result_down[k];
        }
      }

      // save the subset_ids and the result vector
      try {
        last_nonDB_ids__       = subset_ids;
        last_nonDB_countings__ = std::move(result_vect);
        return last_nonDB_countings__;
      } catch (...) {
        last_nonDB_ids__.clear();
        last_nonDB_countings__.clear();
        throw;
      }
    }
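
    // The simplest case handled above is when the subset ids form a prefix of
    // the superset ids: the subset counts are then obtained by summing
    // consecutive blocks of the superset counting vector. A minimal standalone
    // sketch of that case (standard library only, not part of the aGrUM API;
    // the function name is illustrative):
#  if 0
    std::vector< double >
       marginalizeOnPrefix(const std::vector< double >& superset_counts,
                           std::size_t                  subset_size) {
      // subset_size is the product of the domain sizes of the prefix variables;
      // superset_counts.size() is assumed to be a multiple of subset_size
      std::vector< double > subset_counts(subset_size, 0.0);
      for (std::size_t i = 0; i < superset_counts.size();) {
        for (std::size_t j = 0; j < subset_size; ++j, ++i)
          subset_counts[j] += superset_counts[i];
      }
      return subset_counts;
    }
#  endif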

    /// parse the database to produce new countings
    template < template < typename > class ALLOC >
    std::vector< double, ALLOC< double > >&
       RecordCounter< ALLOC >::countFromDatabase__(const IdCondSet< ALLOC >& ids) {
      // if the ids vector is empty or the database is empty, return an
      // empty vector
      const auto& database = parsers__[0].data.database();
      if (ids.empty() || database.empty() || thread_ranges__.empty()) {
        last_nonDB_countings__.clear();
        last_nonDB_ids__.clear();
        return last_nonDB_countings__;
      }

      // we translate the ids into their corresponding columns in the
      // DatabaseTable
      const auto nodeId2columns = getNodeIds2Columns__(ids);

      // we first determine the size of the counting vector, the domain of
      // each of its variables and their offsets in the output vector
      const std::size_t ids_size           = ids.size();
      std::size_t       counting_vect_size = std::size_t(1);
      std::vector< std::size_t, ALLOC< std::size_t > > domain_sizes(ids_size);
      std::vector< std::pair< std::size_t, std::size_t >,
                   ALLOC< std::pair< std::size_t, std::size_t > > >
         cols_offsets(ids_size);
      {
        std::size_t i = std::size_t(0);
        for (const auto id: ids) {
          const std::size_t domain_size = database.domainSize(nodeId2columns[id]);
          domain_sizes[i]        = domain_size;
          cols_offsets[i].first  = nodeId2columns[id];
          cols_offsets[i].second = counting_vect_size;
          counting_vect_size *= domain_size;
          ++i;
        }
      }

      // we sort the columns and offsets by increasing column index. This
      // may speed up threaded countings by improving the cacheline hits
      std::sort(cols_offsets.begin(),
                cols_offsets.end(),
                [](const std::pair< std::size_t, std::size_t >& a,
                   const std::pair< std::size_t, std::size_t >& b) -> bool {
                  return a.first < b.first;
                });

      // create parsers if needed
      const std::size_t nb_ranges = thread_ranges__.size();
      const std::size_t nb_threads
         = nb_ranges <= max_nb_threads__ ? nb_ranges : max_nb_threads__;
      while (parsers__.size() < nb_threads) {
        ThreadData< DBRowGeneratorParser< ALLOC > > new_parser(parsers__[0]);
        parsers__.push_back(std::move(new_parser));
      }

      // set the columns of interest for each parser. This specifies to the
      // parser which columns are used for the countings. This is important
      // for parsers like the EM parser that complete unobserved variables.
      std::vector< std::size_t, ALLOC< std::size_t > > cols_of_interest(ids_size);
      for (std::size_t i = std::size_t(0); i < ids_size; ++i) {
        cols_of_interest[i] = cols_offsets[i].first;
      }
      for (auto& parser: parsers__) {
        parser.data.setColumnsOfInterest(cols_of_interest);
      }

      // allocate all the counting vectors, including that which will add
      // all the results provided by the threads. We initialize once and
      // for all these vectors with zeroes
      std::vector< double, ALLOC< double > > counting_vect(counting_vect_size,
                                                           0.0);
      std::vector< ThreadData< std::vector< double, ALLOC< double > > >,
                   ALLOC< ThreadData< std::vector< double, ALLOC< double > > > > >
         thread_countings(
            nb_threads,
            ThreadData< std::vector< double, ALLOC< double > > >(counting_vect));

      // launch the threads
      // here we use openMP for launching the threads because, experimentally,
      // it seems to provide results that are twice as fast as the results
      // with the std::thread
      for (std::size_t i = std::size_t(0); i < nb_ranges; i += nb_threads) {
#  pragma omp parallel num_threads(int(nb_threads))
        {
          // get the number of the thread
          const std::size_t this_thread = getThreadNumber();
          if (this_thread + i < nb_ranges) {
            DBRowGeneratorParser< ALLOC >& parser = parsers__[this_thread].data;
            parser.setRange(thread_ranges__[this_thread + i].first,
                            thread_ranges__[this_thread + i].second);
            std::vector< double, ALLOC< double > >& countings
               = thread_countings[this_thread].data;

            // parse the database
            try {
              while (parser.hasRows()) {
                // get the observed rows
                const DBRow< DBTranslatedValue >& row = parser.row();

                // fill the counts for the current row
                std::size_t offset = std::size_t(0);
                for (std::size_t i = std::size_t(0); i < ids_size; ++i) {
                  offset += row[cols_offsets[i].first].discr_val
                            * cols_offsets[i].second;
                }

                countings[offset] += row.weight();
              }
            } catch (NotFound&) {}   // this exception is raised by the row filter
                                     // if the row generators create no output row
                                     // from the last rows of the database
          }
        }
      }


      // add the counts to counting_vect
      for (std::size_t k = std::size_t(0); k < nb_threads; ++k) {
        const auto& thread_counting = thread_countings[k].data;
        for (std::size_t r = std::size_t(0); r < counting_vect_size; ++r) {
          counting_vect[r] += thread_counting[r];
        }
      }

      // save the final results
      last_DB_ids__       = ids;
      last_DB_countings__ = std::move(counting_vect);

      return last_DB_countings__;
    }
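
    // The per-row work performed by each thread above boils down to: encode
    // the row's discrete values into a single mixed-radix offset and add the
    // row's weight to that cell. A minimal standalone sketch of this step
    // (standard library only, not part of the aGrUM API; the function name and
    // the in-memory `rows` representation are illustrative):
#  if 0
    // cols_offsets[i] = { column index of the i-th queried variable,
    //                     multiplier of that variable in the offset }
    std::vector< double > countRows(
       const std::vector< std::vector< std::size_t > >&            rows,
       const std::vector< double >&                                weights,
       const std::vector< std::pair< std::size_t, std::size_t > >& cols_offsets,
       std::size_t                                                  nb_cells) {
      std::vector< double > counts(nb_cells, 0.0);
      for (std::size_t r = 0; r < rows.size(); ++r) {
        std::size_t offset = 0;
        for (const auto& col_off: cols_offsets)
          offset += rows[r][col_off.first] * col_off.second;
        counts[offset] += weights[r];
      }
      return counts;
    }
#  endif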

    /// the method used by threads to produce countings by parsing the database
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::threadedCount__(
       const std::size_t              begin,
       const std::size_t              end,
       DBRowGeneratorParser< ALLOC >& parser,
       const std::vector< std::pair< std::size_t, std::size_t >,
                          ALLOC< std::pair< std::size_t, std::size_t > > >&
          cols_offsets,
       std::vector< double, ALLOC< double > >& countings) {
      parser.setRange(begin, end);

      try {
        const std::size_t nb_columns = cols_offsets.size();
        while (parser.hasRows()) {
          // get the observed filtered rows
          const DBRow< DBTranslatedValue >& row = parser.row();

          // fill the counts for the current row
          std::size_t offset = std::size_t(0);
          for (std::size_t i = std::size_t(0); i < nb_columns; ++i) {
            offset
               += row[cols_offsets[i].first].discr_val * cols_offsets[i].second;
          }

          countings[offset] += row.weight();
        }
      } catch (NotFound&) {}   // this exception is raised by the row filter if the
                               // row generators create no output row from the last
                               // rows of the database
    }

    /// checks that the ranges passed in argument are ok or raise an exception
    template < template < typename > class ALLOC >
    template < template < typename > class XALLOC >
    void RecordCounter< ALLOC >::checkRanges__(
       const std::vector< std::pair< std::size_t, std::size_t >,
                          XALLOC< std::pair< std::size_t, std::size_t > > >&
          new_ranges) const {
      const std::size_t dbsize = parsers__[0].data.database().nbRows();
      std::vector< std::pair< std::size_t, std::size_t >,
                   ALLOC< std::pair< std::size_t, std::size_t > > >
         incorrect_ranges;
      for (const auto& range: new_ranges) {
        if ((range.first >= range.second) || (range.second > dbsize)) {
          incorrect_ranges.push_back(range);
        }
      }
      if (!incorrect_ranges.empty()) {
        std::stringstream str;
        str << "It is impossible to set the ranges because the following one";
        if (incorrect_ranges.size() > 1)
          str << "s are incorrect: ";
        else
          str << " is incorrect: ";
        bool deja = false;
        for (const auto& range: incorrect_ranges) {
          if (deja)
            str << ", ";
          else
            deja = true;
          str << '[' << range.first << ';' << range.second << ')';
        }

        GUM_ERROR(OutOfBounds, str.str());
      }
    }

    /// sets the ranges within which each thread should perform its computations
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::dispatchRangesToThreads__() {
      thread_ranges__.clear();

      // ensure that ranges__ contains the ranges asked by the user
      bool add_range = false;
      if (ranges__.empty()) {
        const auto& database = parsers__[0].data.database();
        ranges__.push_back(
           std::pair< std::size_t, std::size_t >(std::size_t(0),
                                                 database.nbRows()));
        add_range = true;
      }

      // dispatch the ranges
      for (const auto& range: ranges__) {
        if (range.second > range.first) {
          const std::size_t range_size = range.second - range.first;
          std::size_t       nb_threads = range_size / min_nb_rows_per_thread__;
          if (nb_threads < 1)
            nb_threads = 1;
          else if (nb_threads > max_nb_threads__)
            nb_threads = max_nb_threads__;
          std::size_t nb_rows_par_thread = range_size / nb_threads;
          std::size_t rest_rows = range_size - nb_rows_par_thread * nb_threads;

          std::size_t begin_index = range.first;
          for (std::size_t i = std::size_t(0); i < nb_threads; ++i) {
            std::size_t end_index = begin_index + nb_rows_par_thread;
            if (rest_rows != std::size_t(0)) {
              ++end_index;
              --rest_rows;
            }
            thread_ranges__.push_back(
               std::pair< std::size_t, std::size_t >(begin_index, end_index));
            begin_index = end_index;
          }
        }
      }
      if (add_range) ranges__.clear();

      // sort ranges by decreasing range size, so that if the number of
      // ranges exceeds the number of threads allowed, we start a first round of
      // threads with the highest range, then another round with lower ranges,
      // and so on until all the ranges have been processed
      std::sort(thread_ranges__.begin(),
                thread_ranges__.end(),
                [](const std::pair< std::size_t, std::size_t >& a,
                   const std::pair< std::size_t, std::size_t >& b) -> bool {
                  return (a.second - a.first) > (b.second - b.first);
                });
    }
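
    // For one user range, the dispatch above bounds the number of threads by
    // both the minimal number of rows per thread and the maximal number of
    // threads, then spreads the remainder rows one by one over the first
    // subranges. A minimal standalone sketch of that splitting (standard
    // library only, not part of the aGrUM API; the function name is
    // illustrative):
#  if 0
    std::vector< std::pair< std::size_t, std::size_t > >
       splitRange(std::size_t begin,
                  std::size_t end,   // half-open range [begin, end)
                  std::size_t min_rows_per_thread,
                  std::size_t max_nb_threads) {
      const std::size_t range_size = end - begin;
      std::size_t       nb_threads = range_size / min_rows_per_thread;
      if (nb_threads < 1)
        nb_threads = 1;
      else if (nb_threads > max_nb_threads)
        nb_threads = max_nb_threads;

      const std::size_t rows_per_thread = range_size / nb_threads;
      std::size_t       rest_rows = range_size - rows_per_thread * nb_threads;

      std::vector< std::pair< std::size_t, std::size_t > > subranges;
      for (std::size_t i = 0; i < nb_threads; ++i) {
        std::size_t sub_end = begin + rows_per_thread;
        if (rest_rows != 0) {
          ++sub_end;
          --rest_rows;
        }
        subranges.emplace_back(begin, sub_end);   // one subrange per thread
        begin = sub_end;
      }
      return subranges;
    }
#  endif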

    /// sets new ranges to perform the countings
    template < template < typename > class ALLOC >
    template < template < typename > class XALLOC >
    void RecordCounter< ALLOC >::setRanges(
       const std::vector< std::pair< std::size_t, std::size_t >,
                          XALLOC< std::pair< std::size_t, std::size_t > > >&
          new_ranges) {
      // first, we check that all ranges are within the database's bounds
      checkRanges__(new_ranges);

      // since the ranges are OK, save them and clear the counting caches
      const std::size_t new_size = new_ranges.size();
      std::vector< std::pair< std::size_t, std::size_t >,
                   ALLOC< std::pair< std::size_t, std::size_t > > >
         ranges(new_size);
      for (std::size_t i = std::size_t(0); i < new_size; ++i) {
        ranges[i].first  = new_ranges[i].first;
        ranges[i].second = new_ranges[i].second;
      }

      clear();
      ranges__ = std::move(ranges);

      // dispatch the ranges to the threads
      dispatchRangesToThreads__();
    }


    /// reset the ranges to the one range corresponding to the whole database
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::clearRanges() {
      if (ranges__.empty()) return;
      clear();
      ranges__.clear();
      dispatchRangesToThreads__();
    }


    /// returns the current ranges
    template < template < typename > class ALLOC >
    INLINE const std::vector< std::pair< std::size_t, std::size_t >,
                              ALLOC< std::pair< std::size_t, std::size_t > > >&
       RecordCounter< ALLOC >::ranges() const {
      return ranges__;
    }

    /// assign a new Bayes net to all the counter's generators depending on a BN
    template < template < typename > class ALLOC >
    template < typename GUM_SCALAR >
    INLINE void
       RecordCounter< ALLOC >::setBayesNet(const BayesNet< GUM_SCALAR >& new_bn) {
      // remove the caches
      clear();

      // assign the new BN
      for (auto& xparser: parsers__) {
        xparser.data.setBayesNet(new_bn);
      }
    }

  }   /* namespace learning */

}   /* namespace gum */

#endif /* DOXYGEN_SHOULD_SKIP_THIS */