aGrUM 0.20.3, a C++ library for (probabilistic) graphical models
recordCounter_tpl.h
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/** @file
 * @brief The class that computes countings of observations from the database.
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */
#include <agrum/tools/stattests/recordCounter.h>


#ifndef DOXYGEN_SHOULD_SKIP_THIS

namespace gum {

  namespace learning {

    /// returns the allocator used by the counter
    template < template < typename > class ALLOC >
    INLINE typename RecordCounter< ALLOC >::allocator_type
       RecordCounter< ALLOC >::getAllocator() const {
      return _parsers_.get_allocator();
    }

    /// default constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(
       const DBRowGeneratorParser< ALLOC >& parser,
       const std::vector< std::pair< std::size_t, std::size_t >,
                          ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
       const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&        nodeId2columns,
       const typename RecordCounter< ALLOC >::allocator_type&               alloc) :
        _parsers_(alloc),
        _ranges_(alloc), _nodeId2columns_(nodeId2columns), _last_DB_countings_(alloc),
        _last_DB_ids_(alloc), _last_nonDB_countings_(alloc), _last_nonDB_ids_(alloc) {
      // check that the columns in nodeId2columns do belong to the database
      const std::size_t db_nb_cols = parser.database().nbVariables();
      for (auto iter = nodeId2columns.cbegin(); iter != nodeId2columns.cend(); ++iter) {
        if (iter.second() >= db_nb_cols) {
          GUM_ERROR(OutOfBounds,
                    "the mapping between ids and database columns "
                       << "is incorrect because Column " << iter.second()
                       << " does not belong to the database.");
        }
      }

      // create the parsers. There should always be at least one parser
      if (_max_nb_threads_ < std::size_t(1)) _max_nb_threads_ = std::size_t(1);
      _parsers_.reserve(_max_nb_threads_);
      for (std::size_t i = std::size_t(0); i < _max_nb_threads_; ++i)
        _parsers_.push_back(parser);

      // check that the ranges are within the bounds of the database and
      // save them
      _checkRanges_(ranges);
      _ranges_.reserve(ranges.size());
      for (const auto& range: ranges)
        _ranges_.push_back(range);

      // dispatch the ranges for the threads
      _dispatchRangesToThreads_();

      GUM_CONSTRUCTOR(RecordCounter);
    }

    /// default constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(
       const DBRowGeneratorParser< ALLOC >&                          parser,
       const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns,
       const typename RecordCounter< ALLOC >::allocator_type&        alloc) :
        RecordCounter< ALLOC >(parser,
                               std::vector< std::pair< std::size_t, std::size_t >,
                                            ALLOC< std::pair< std::size_t, std::size_t > > >(),
                               nodeId2columns,
                               alloc) {}

    /// copy constructor with a given allocator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(
       const RecordCounter< ALLOC >&                          from,
       const typename RecordCounter< ALLOC >::allocator_type& alloc) :
        _parsers_(from._parsers_, alloc),
        _ranges_(from._ranges_, alloc), _thread_ranges_(from._thread_ranges_, alloc),
        _nodeId2columns_(from._nodeId2columns_),
        _last_DB_countings_(from._last_DB_countings_, alloc), _last_DB_ids_(from._last_DB_ids_),
        _last_nonDB_countings_(from._last_nonDB_countings_, alloc),
        _last_nonDB_ids_(from._last_nonDB_ids_), _max_nb_threads_(from._max_nb_threads_),
        _min_nb_rows_per_thread_(from._min_nb_rows_per_thread_) {
      GUM_CONS_CPY(RecordCounter);
    }

    /// copy constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(const RecordCounter< ALLOC >& from) :
        RecordCounter< ALLOC >(from, from.getAllocator()) {}

    /// move constructor with a given allocator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(
       RecordCounter< ALLOC >&&                               from,
       const typename RecordCounter< ALLOC >::allocator_type& alloc) :
        _parsers_(std::move(from._parsers_), alloc),
        _ranges_(std::move(from._ranges_), alloc),
        _thread_ranges_(std::move(from._thread_ranges_), alloc),
        _nodeId2columns_(std::move(from._nodeId2columns_)),
        _last_DB_countings_(std::move(from._last_DB_countings_), alloc),
        _last_DB_ids_(std::move(from._last_DB_ids_)),
        _last_nonDB_countings_(std::move(from._last_nonDB_countings_), alloc),
        _last_nonDB_ids_(std::move(from._last_nonDB_ids_)), _max_nb_threads_(from._max_nb_threads_),
        _min_nb_rows_per_thread_(from._min_nb_rows_per_thread_) {
      GUM_CONS_MOV(RecordCounter);
    }

    /// move constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::RecordCounter(RecordCounter< ALLOC >&& from) :
        RecordCounter< ALLOC >(std::move(from), from.getAllocator()) {}

    /// virtual copy constructor with a given allocator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >* RecordCounter< ALLOC >::clone(
       const typename RecordCounter< ALLOC >::allocator_type& alloc) const {
      ALLOC< RecordCounter< ALLOC > > allocator(alloc);
      RecordCounter< ALLOC >*         new_counter = allocator.allocate(1);
      try {
        allocator.construct(new_counter, *this, alloc);
      } catch (...) {
        allocator.deallocate(new_counter, 1);
        throw;
      }

      return new_counter;
    }

    /// virtual copy constructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >* RecordCounter< ALLOC >::clone() const {
      return clone(this->getAllocator());
    }

    /// destructor
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >::~RecordCounter() {
      GUM_DESTRUCTOR(RecordCounter);
    }

    /// copy operator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >& RecordCounter< ALLOC >::operator=(const RecordCounter< ALLOC >& from) {
      if (this != &from) {
        _parsers_                = from._parsers_;
        _ranges_                 = from._ranges_;
        _thread_ranges_          = from._thread_ranges_;
        _nodeId2columns_         = from._nodeId2columns_;
        _last_DB_countings_      = from._last_DB_countings_;
        _last_DB_ids_            = from._last_DB_ids_;
        _last_nonDB_countings_   = from._last_nonDB_countings_;
        _last_nonDB_ids_         = from._last_nonDB_ids_;
        _max_nb_threads_         = from._max_nb_threads_;
        _min_nb_rows_per_thread_ = from._min_nb_rows_per_thread_;
      }
      return *this;
    }

    /// move operator
    template < template < typename > class ALLOC >
    RecordCounter< ALLOC >& RecordCounter< ALLOC >::operator=(RecordCounter< ALLOC >&& from) {
      if (this != &from) {
        _parsers_                = std::move(from._parsers_);
        _ranges_                 = std::move(from._ranges_);
        _thread_ranges_          = std::move(from._thread_ranges_);
        _nodeId2columns_         = std::move(from._nodeId2columns_);
        _last_DB_countings_      = std::move(from._last_DB_countings_);
        _last_DB_ids_            = std::move(from._last_DB_ids_);
        _last_nonDB_countings_   = std::move(from._last_nonDB_countings_);
        _last_nonDB_ids_         = std::move(from._last_nonDB_ids_);
        _max_nb_threads_         = from._max_nb_threads_;
        _min_nb_rows_per_thread_ = from._min_nb_rows_per_thread_;
      }
      return *this;
    }

    /// clears all the last database-parsed countings from memory
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::clear() {
      _last_DB_countings_.clear();
      _last_DB_ids_.clear();
      _last_nonDB_countings_.clear();
      _last_nonDB_ids_.clear();
    }

    /// changes the max number of threads used to parse the database
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::setMaxNbThreads(const std::size_t nb) const {
      if (nb == std::size_t(0) || !isOMP())
        _max_nb_threads_ = std::size_t(1);
      else
        _max_nb_threads_ = nb;
    }

    /// returns the number of threads used to parse the database
    template < template < typename > class ALLOC >
    INLINE std::size_t RecordCounter< ALLOC >::nbThreads() const {
      return _max_nb_threads_;
    }

    // changes the minimum number of rows a thread should process in a
    // multithreading context
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::setMinNbRowsPerThread(const std::size_t nb) const {
      if (nb == std::size_t(0))
        _min_nb_rows_per_thread_ = std::size_t(1);
      else
        _min_nb_rows_per_thread_ = nb;
    }

    /// returns the minimum number of rows that each thread should process
    template < template < typename > class ALLOC >
    INLINE std::size_t RecordCounter< ALLOC >::minNbRowsPerThread() const {
      return _min_nb_rows_per_thread_;
    }

    /// compute and raise the exception when some variables are continuous
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::_raiseCheckException_(
       const std::vector< std::string, ALLOC< std::string > >& bad_vars) const {
      // generate the exception
      std::stringstream msg;
      msg << "Counts cannot be performed on continuous variables. ";
      msg << "Unfortunately the following variable";
      if (bad_vars.size() == 1)
        msg << " is continuous: " << bad_vars[0];
      else {
        msg << "s are continuous: ";
        bool deja = false;
        for (const auto& name: bad_vars) {
          if (deja)
            msg << ", ";
          else
            deja = true;
          msg << name;
        }
      }
      GUM_ERROR(TypeError, msg.str())
    }

    /// check that all the variables in an idset are discrete
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::_checkDiscreteVariables_(const IdCondSet< ALLOC >& ids) const {
      const std::size_t             size     = ids.size();
      const DatabaseTable< ALLOC >& database = _parsers_[0].data.database();

      if (_nodeId2columns_.empty()) {
        // check all the ids
        for (std::size_t i = std::size_t(0); i < size; ++i) {
          if (database.variable(i).varType() == VarType::Continuous) {
            // here, var i does not correspond to a discrete variable.
            // we check whether there are other non discrete variables, so that
            // we can generate an exception mentioning all these variables
            std::vector< std::string, ALLOC< std::string > > bad_vars{database.variable(i).name()};
            for (++i; i < size; ++i) {
              if (database.variable(i).varType() == VarType::Continuous)
                bad_vars.push_back(database.variable(i).name());
            }
            _raiseCheckException_(bad_vars);
          }
        }
      } else {
        // check all the ids
        for (std::size_t i = std::size_t(0); i < size; ++i) {
          // get the position of the variable in the database
          std::size_t pos = _nodeId2columns_.second(ids[i]);

          if (database.variable(pos).varType() == VarType::Continuous) {
            // here, id does not correspond to a discrete variable.
            // we check whether there are other non discrete variables, so that
            // we can generate an exception mentioning all these variables
            std::vector< std::string, ALLOC< std::string > > bad_vars{
               database.variable(pos).name()};
            for (++i; i < size; ++i) {
              pos = _nodeId2columns_.second(ids[i]);
              if (database.variable(pos).varType() == VarType::Continuous)
                bad_vars.push_back(database.variable(pos).name());
            }
            _raiseCheckException_(bad_vars);
          }
        }
      }
    }

    /// returns the mapping from ids to column positions in the database
    template < template < typename > class ALLOC >
    INLINE const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
       RecordCounter< ALLOC >::nodeId2Columns() const {
      return _nodeId2columns_;
    }

    /// returns the database on which we perform the counts
    template < template < typename > class ALLOC >
    const DatabaseTable< ALLOC >& RecordCounter< ALLOC >::database() const {
      return _parsers_[0].data.database();
    }

    /// returns the counts for a given set of nodes
    template < template < typename > class ALLOC >
    INLINE const std::vector< double, ALLOC< double > >&
       RecordCounter< ALLOC >::counts(const IdCondSet< ALLOC >& ids,
                                      const bool                check_discrete_vars) {
      // if the idset is empty, return an empty vector
      if (ids.empty()) {
        _last_nonDB_ids_.clear();
        _last_nonDB_countings_.clear();
        return _last_nonDB_countings_;
      }

      // check whether we can extract the vector we wish to return from
      // some already computed counting vector
      if (_last_nonDB_ids_.contains(ids))
        return _extractFromCountings_(ids, _last_nonDB_ids_, _last_nonDB_countings_);
      else if (_last_DB_ids_.contains(ids))
        return _extractFromCountings_(ids, _last_DB_ids_, _last_DB_countings_);
      else {
        if (check_discrete_vars) _checkDiscreteVariables_(ids);
        return _countFromDatabase_(ids);
      }
    }

    // returns a mapping from the node ids to the columns of the database
    // for a given sequence of ids
    template < template < typename > class ALLOC >
    HashTable< NodeId, std::size_t >
       RecordCounter< ALLOC >::_getNodeIds2Columns_(const IdCondSet< ALLOC >& ids) const {
      HashTable< NodeId, std::size_t > res(ids.size());
      if (_nodeId2columns_.empty()) {
        for (const auto id: ids) {
          res.insert(id, std::size_t(id));
        }
      } else {
        for (const auto id: ids) {
          res.insert(id, _nodeId2columns_.second(id));
        }
      }
      return res;
    }

    /// extracts some new countings from previously computed ones
    template < template < typename > class ALLOC >
    INLINE std::vector< double, ALLOC< double > >& RecordCounter< ALLOC >::_extractFromCountings_(
       const IdCondSet< ALLOC >&                     subset_ids,
       const IdCondSet< ALLOC >&                     superset_ids,
       const std::vector< double, ALLOC< double > >& superset_vect) {
      // get a mapping between the node Ids and their columns in the database.
      // This should be stored into _nodeId2columns_, except if the latter is
      // empty, in which case there is an identity mapping
      const auto nodeId2columns = _getNodeIds2Columns_(superset_ids);

      // we first determine the size of the output vector, the domain of
      // each of its variables and their offsets in the output vector
      const auto& database         = _parsers_[0].data.database();
      std::size_t result_vect_size = std::size_t(1);
      for (const auto id: subset_ids) {
        result_vect_size *= database.domainSize(nodeId2columns[id]);
      }

      // we create the output vector
      const std::size_t subset_ids_size = std::size_t(subset_ids.size());
      std::vector< double, ALLOC< double > > result_vect(result_vect_size, 0.0);


      // check if the subset_ids is the beginning of the sequence of superset_ids;
      // if this is the case, then we can outer loop over the variables not in
      // subset_ids and, for each iteration of this loop, add a vector of size
      // result_vect_size to result_vect
      bool subset_begin = true;
      for (std::size_t i = 0; i < subset_ids_size; ++i) {
        if (superset_ids.pos(subset_ids[i]) != i) {
          subset_begin = false;
          break;
        }
      }

      if (subset_begin) {
        const std::size_t superset_vect_size = superset_vect.size();
        std::size_t       i                  = std::size_t(0);
        while (i < superset_vect_size) {
          for (std::size_t j = std::size_t(0); j < result_vect_size; ++j, ++i) {
            result_vect[j] += superset_vect[i];
          }
        }

        // save the subset_ids and the result vector
        try {
          _last_nonDB_ids_       = subset_ids;
          _last_nonDB_countings_ = std::move(result_vect);
          return _last_nonDB_countings_;
        } catch (...) {
          _last_nonDB_ids_.clear();
          _last_nonDB_countings_.clear();
          throw;
        }
      }


      // check if subset_ids is the end of the sequence of superset_ids.
      // In this case, as above, there are two simple loops to perform the
      // countings
      bool              subset_end        = true;
      const std::size_t superset_ids_size = std::size_t(superset_ids.size());
      for (std::size_t i = 0; i < subset_ids_size; ++i) {
        if (superset_ids.pos(subset_ids[i]) != i + superset_ids_size - subset_ids_size) {
          subset_end = false;
          break;
        }
      }

      if (subset_end) {
        // determine the size of the vector corresponding to the variables
        // not belonging to subset_ids
        std::size_t vect_not_subset_size = std::size_t(1);
        for (std::size_t i = std::size_t(0); i < superset_ids_size - subset_ids_size; ++i)
          vect_not_subset_size *= database.domainSize(nodeId2columns[superset_ids[i]]);

        // perform the two loops
        std::size_t i = std::size_t(0);
        for (std::size_t j = std::size_t(0); j < result_vect_size; ++j) {
          for (std::size_t k = std::size_t(0); k < vect_not_subset_size; ++k, ++i) {
            result_vect[j] += superset_vect[i];
          }
        }

        // save the subset_ids and the result vector
        try {
          _last_nonDB_ids_       = subset_ids;
          _last_nonDB_countings_ = std::move(result_vect);
          return _last_nonDB_countings_;
        } catch (...) {
          _last_nonDB_ids_.clear();
          _last_nonDB_countings_.clear();
          throw;
        }
      }


      // here subset_ids is a subset of superset_ids neither prefixing nor
      // postfixing it. So the computation is somewhat more complicated.

      // We will parse the superset_vect sequentially (using the ++ operator).
      // Sometimes, we will need to change the offset of the cell of result_vect
      // that will be affected, sometimes not. Vector before_incr will indicate
      // whether we need to change the offset (value = 0) or not (value different
      // from 0). Vectors result_domain and result_offset will indicate how this
      // offset should be computed. Here is an example of the values of these
      // vectors. Assume that superset_ids = <A,B,C,D,E> and subset_ids = <A,D,C>.
      // Then, the three vectors before_incr, result_domain and result_offset are
      // indexed w.r.t. A,C,D, i.e., w.r.t. the variables in subset_ids but
      // ordered w.r.t. superset_ids (this is convenient as we will parse
      // superset_vect sequentially). For a variable or a set of variables X, let
      // M_X denote the domain size of X. Then the contents of the three vectors
      // are as follows:
      // before_incr = {0, M_B, 0} (this means that whenever we iterate over B's
      // values, the offset in result_vect does not change)
      // result_domain = { M_A, M_C, M_D } (i.e., the domain sizes of the
      // variables in subset_ids, ordered w.r.t. superset_ids)
      // result_offset = { 1, M_A*M_D, M_A } (this corresponds to the offsets
      // in result_vect of variables A, C and D)
      // Vector superset_order = { 0, 2, 1 }: this is a map from the indices of
      // the variables in subset_ids to the indices of these variables in the
      // three vectors described above. For instance, the "2" means that variable
      // D (which is at index 1 in subset_ids) is located at index 2 in vector
      // before_incr
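      // To make the example above concrete, take arbitrary domain sizes
      // M_A = 2, M_B = 3, M_C = 4 and M_D = 5 (figures chosen only for
      // illustration). Substituting them into the expressions above gives
      //   before_incr    = {0, 3, 0}
      //   result_domain  = {2, 4, 5}
      //   result_offset  = {1, 10, 2}   (10 = M_A*M_D, 2 = M_A)
      //   superset_order = {0, 2, 1}
      // and result_vect then holds M_A * M_C * M_D = 40 cells.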
      std::vector< std::size_t > before_incr(subset_ids_size);
      std::vector< std::size_t > result_domain(subset_ids_size);
      std::vector< std::size_t > result_offset(subset_ids_size);
      {
        std::size_t result_domain_size = std::size_t(1);
        std::size_t tmp_before_incr    = std::size_t(1);
        std::vector< std::size_t > superset_order(subset_ids_size);

        for (std::size_t h = std::size_t(0), j = std::size_t(0); j < subset_ids_size; ++h) {
          if (subset_ids.exists(superset_ids[h])) {
            before_incr[j] = tmp_before_incr - 1;
            superset_order[subset_ids.pos(superset_ids[h])] = j;
            tmp_before_incr = 1;
            ++j;
          } else {
            tmp_before_incr *= database.domainSize(nodeId2columns[superset_ids[h]]);
          }
        }

        // compute the offsets in the order of the superset_ids
        for (std::size_t i = 0; i < subset_ids.size(); ++i) {
          const std::size_t domain_size = database.domainSize(nodeId2columns[subset_ids[i]]);
          const std::size_t j           = superset_order[i];
          result_domain[j]              = domain_size;
          result_offset[j]              = result_domain_size;
          result_domain_size *= domain_size;
        }
      }

      std::vector< std::size_t > result_value(result_domain);
      std::vector< std::size_t > current_incr(before_incr);
      std::vector< std::size_t > result_down(result_offset);

      for (std::size_t j = std::size_t(0); j < result_down.size(); ++j) {
        result_down[j] *= (result_domain[j] - 1);
      }

      // now we can loop over the superset_vect to fill result_vect
      const std::size_t superset_vect_size = superset_vect.size();
      std::size_t       the_result_offset  = std::size_t(0);
      for (std::size_t h = std::size_t(0); h < superset_vect_size; ++h) {
        result_vect[the_result_offset] += superset_vect[h];

        // update the offset of result_vect
        for (std::size_t k = 0; k < current_incr.size(); ++k) {
          // check if we need to modify result_offset
          if (current_incr[k]) {
            --current_incr[k];
            break;
          }

          current_incr[k] = before_incr[k];

          // here we shall modify result_offset
          --result_value[k];

          if (result_value[k]) {
            the_result_offset += result_offset[k];
            break;
          }

          result_value[k] = result_domain[k];
          the_result_offset -= result_down[k];
        }
      }

      // save the subset_ids and the result vector
      try {
        _last_nonDB_ids_       = subset_ids;
        _last_nonDB_countings_ = std::move(result_vect);
        return _last_nonDB_countings_;
      } catch (...) {
        _last_nonDB_ids_.clear();
        _last_nonDB_countings_.clear();
        throw;
      }
    }

    /// parse the database to produce new countings
    template < template < typename > class ALLOC >
    std::vector< double, ALLOC< double > >&
       RecordCounter< ALLOC >::_countFromDatabase_(const IdCondSet< ALLOC >& ids) {
      // if the ids vector is empty or the database is empty, return an
      // empty vector
      const auto& database = _parsers_[0].data.database();
      if (ids.empty() || database.empty() || _thread_ranges_.empty()) {
        _last_nonDB_countings_.clear();
        _last_nonDB_ids_.clear();
        return _last_nonDB_countings_;
      }

      // we translate the ids into their corresponding columns in the
      // DatabaseTable
      const auto nodeId2columns = _getNodeIds2Columns_(ids);

      // we first determine the size of the counting vector, the domain of
      // each of its variables and their offsets in the output vector
      const std::size_t ids_size           = ids.size();
      std::size_t       counting_vect_size = std::size_t(1);
      std::vector< std::size_t, ALLOC< std::size_t > > domain_sizes(ids_size);
      std::vector< std::pair< std::size_t, std::size_t >,
                   ALLOC< std::pair< std::size_t, std::size_t > > >
         cols_offsets(ids_size);
      {
        std::size_t i = std::size_t(0);
        for (const auto id: ids) {
          const std::size_t domain_size = database.domainSize(nodeId2columns[id]);
          domain_sizes[i]        = domain_size;
          cols_offsets[i].first  = nodeId2columns[id];
          cols_offsets[i].second = counting_vect_size;
          counting_vect_size *= domain_size;
          ++i;
        }
      }

      // we sort the columns and offsets by increasing column index. This
      // may speed up threaded countings by improving the cacheline hits
      std::sort(cols_offsets.begin(),
                cols_offsets.end(),
                [](const std::pair< std::size_t, std::size_t >& a,
                   const std::pair< std::size_t, std::size_t >& b) -> bool {
                  return a.first < b.first;
                });

      // create parsers if needed
      const std::size_t nb_ranges  = _thread_ranges_.size();
      const std::size_t nb_threads = nb_ranges <= _max_nb_threads_ ? nb_ranges : _max_nb_threads_;
      while (_parsers_.size() < nb_threads) {
        ThreadData< DBRowGeneratorParser< ALLOC > > new_parser(_parsers_[0]);
        _parsers_.push_back(std::move(new_parser));
      }

      // set the columns of interest for each parser. This specifies to the
      // parser which columns are used for the countings. This is important
      // for parsers like the EM parser that complete unobserved variables.
      std::vector< std::size_t, ALLOC< std::size_t > > cols_of_interest(ids_size);
      for (std::size_t i = std::size_t(0); i < ids_size; ++i) {
        cols_of_interest[i] = cols_offsets[i].first;
      }
      for (auto& parser: _parsers_) {
        parser.data.setColumnsOfInterest(cols_of_interest);
      }

      // allocate all the counting vectors, including that which will add
      // all the results provided by the threads. We initialize once and
      // for all these vectors with zeroes
      std::vector< double, ALLOC< double > > counting_vect(counting_vect_size, 0.0);
      std::vector< ThreadData< std::vector< double, ALLOC< double > > >,
                   ALLOC< ThreadData< std::vector< double, ALLOC< double > > > > >
         thread_countings(nb_threads,
                          ThreadData< std::vector< double, ALLOC< double > > >(counting_vect));

      // launch the threads
      // here we use openMP for launching the threads because, experimentally,
      // it seems to provide results that are twice as fast as the results
      // with the std::thread
      for (std::size_t i = std::size_t(0); i < nb_ranges; i += nb_threads) {
#  pragma omp parallel num_threads(int(nb_threads))
        {
          // get the number of the thread
          const std::size_t this_thread = getThreadNumber();
          if (this_thread + i < nb_ranges) {
            DBRowGeneratorParser< ALLOC >& parser = _parsers_[this_thread].data;
            parser.setRange(_thread_ranges_[this_thread + i].first,
                            _thread_ranges_[this_thread + i].second);
            std::vector< double, ALLOC< double > >& countings = thread_countings[this_thread].data;

            // parse the database
            try {
              while (parser.hasRows()) {
                // get the observed rows
                const DBRow< DBTranslatedValue >& row = parser.row();

                // fill the counts for the current row
                std::size_t offset = std::size_t(0);
                for (std::size_t i = std::size_t(0); i < ids_size; ++i) {
                  offset += row[cols_offsets[i].first].discr_val * cols_offsets[i].second;
                }

                countings[offset] += row.weight();
              }
            } catch (NotFound&) {}   // this exception is raised by the row filter
                                     // if the row generators create no output row
                                     // from the last rows of the database
          }
        }
      }


      // add the counts to counting_vect
      for (std::size_t k = std::size_t(0); k < nb_threads; ++k) {
        const auto& thread_counting = thread_countings[k].data;
        for (std::size_t r = std::size_t(0); r < counting_vect_size; ++r) {
          counting_vect[r] += thread_counting[r];
        }
      }

      // save the final results
      _last_DB_ids_       = ids;
      _last_DB_countings_ = std::move(counting_vect);

      return _last_DB_countings_;
    }

    /// the method used by threads to produce countings by parsing the database
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::_threadedCount_(
       const std::size_t              begin,
       const std::size_t              end,
       DBRowGeneratorParser< ALLOC >& parser,
       const std::vector< std::pair< std::size_t, std::size_t >,
                          ALLOC< std::pair< std::size_t, std::size_t > > >& cols_offsets,
       std::vector< double, ALLOC< double > >&                              countings) {
      parser.setRange(begin, end);

      try {
        const std::size_t nb_columns = cols_offsets.size();
        while (parser.hasRows()) {
          // get the observed filtered rows
          const DBRow< DBTranslatedValue >& row = parser.row();

          // fill the counts for the current row
          std::size_t offset = std::size_t(0);
          for (std::size_t i = std::size_t(0); i < nb_columns; ++i) {
            offset += row[cols_offsets[i].first].discr_val * cols_offsets[i].second;
          }

          countings[offset] += row.weight();
        }
      } catch (NotFound&) {}   // this exception is raised by the row filter if the
                               // row generators create no output row from the last
                               // rows of the database
    }

    /// checks that the ranges passed in argument are ok or raise an exception
    template < template < typename > class ALLOC >
    template < template < typename > class XALLOC >
    void RecordCounter< ALLOC >::_checkRanges_(
       const std::vector< std::pair< std::size_t, std::size_t >,
                          XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges) const {
      const std::size_t dbsize = _parsers_[0].data.database().nbRows();
      std::vector< std::pair< std::size_t, std::size_t >,
                   ALLOC< std::pair< std::size_t, std::size_t > > >
         incorrect_ranges;
      for (const auto& range: new_ranges) {
        if ((range.first >= range.second) || (range.second > dbsize)) {
          incorrect_ranges.push_back(range);
        }
      }
      if (!incorrect_ranges.empty()) {
        std::stringstream str;
        str << "It is impossible to set the ranges because the following one";
        if (incorrect_ranges.size() > 1)
          str << "s are incorrect: ";
        else
          str << " is incorrect: ";
        bool deja = false;
        for (const auto& range: incorrect_ranges) {
          if (deja)
            str << ", ";
          else
            deja = true;
          str << '[' << range.first << ';' << range.second << ')';
        }

        GUM_ERROR(OutOfBounds, str.str())
      }
    }

    /// sets the ranges within which each thread should perform its computations
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::_dispatchRangesToThreads_() {
      _thread_ranges_.clear();

      // ensure that _ranges_ contains the ranges asked by the user
      bool add_range = false;
      if (_ranges_.empty()) {
        const auto& database = _parsers_[0].data.database();
        _ranges_.push_back(
           std::pair< std::size_t, std::size_t >(std::size_t(0), database.nbRows()));
        add_range = true;
      }

      // dispatch the ranges
      for (const auto& range: _ranges_) {
        if (range.second > range.first) {
          const std::size_t range_size = range.second - range.first;
          std::size_t       nb_threads = range_size / _min_nb_rows_per_thread_;
          if (nb_threads < 1)
            nb_threads = 1;
          else if (nb_threads > _max_nb_threads_)
            nb_threads = _max_nb_threads_;
          std::size_t nb_rows_par_thread = range_size / nb_threads;
          std::size_t rest_rows          = range_size - nb_rows_par_thread * nb_threads;

          std::size_t begin_index = range.first;
          for (std::size_t i = std::size_t(0); i < nb_threads; ++i) {
            std::size_t end_index = begin_index + nb_rows_par_thread;
            if (rest_rows != std::size_t(0)) {
              ++end_index;
              --rest_rows;
            }
            _thread_ranges_.push_back(
               std::pair< std::size_t, std::size_t >(begin_index, end_index));
            begin_index = end_index;
          }
        }
      }
      if (add_range) _ranges_.clear();

      // sort ranges by decreasing range size, so that if the number of
      // ranges exceeds the number of threads allowed, we start a first round of
      // threads with the highest range, then another round with lower ranges,
      // and so on until all the ranges have been processed
      std::sort(_thread_ranges_.begin(),
                _thread_ranges_.end(),
                [](const std::pair< std::size_t, std::size_t >& a,
                   const std::pair< std::size_t, std::size_t >& b) -> bool {
                  return (a.second - a.first) > (b.second - b.first);
                });
    }

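    // For instance (figures chosen only for illustration): a single range
    // [0, 1000) with _min_nb_rows_per_thread_ = 100 and _max_nb_threads_ = 4
    // yields nb_threads = min(1000 / 100, 4) = 4 and nb_rows_par_thread = 250,
    // so _thread_ranges_ becomes { [0,250), [250,500), [500,750), [750,1000) }.
    // With 1001 rows, rest_rows = 1 and the first sub-range absorbs the extra
    // row: [0,251), [251,501), [501,751), [751,1001).
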
    /// sets new ranges to perform the countings
    template < template < typename > class ALLOC >
    template < template < typename > class XALLOC >
    void RecordCounter< ALLOC >::setRanges(
       const std::vector< std::pair< std::size_t, std::size_t >,
                          XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges) {
      // first, we check that all ranges are within the database's bounds
      _checkRanges_(new_ranges);

      // since the ranges are OK, save them and clear the counting caches
      const std::size_t new_size = new_ranges.size();
      std::vector< std::pair< std::size_t, std::size_t >,
                   ALLOC< std::pair< std::size_t, std::size_t > > >
         ranges(new_size);
      for (std::size_t i = std::size_t(0); i < new_size; ++i) {
        ranges[i].first  = new_ranges[i].first;
        ranges[i].second = new_ranges[i].second;
      }

      clear();
      _ranges_ = std::move(ranges);

      // dispatch the ranges to the threads
      _dispatchRangesToThreads_();
    }

    /// reset the ranges to the one range corresponding to the whole database
    template < template < typename > class ALLOC >
    void RecordCounter< ALLOC >::clearRanges() {
      if (_ranges_.empty()) return;
      clear();
      _ranges_.clear();
      _dispatchRangesToThreads_();
    }

    /// returns the current ranges
    template < template < typename > class ALLOC >
    INLINE const std::vector< std::pair< std::size_t, std::size_t >,
                              ALLOC< std::pair< std::size_t, std::size_t > > >&
       RecordCounter< ALLOC >::ranges() const {
      return _ranges_;
    }

    /// assign a new Bayes net to all the counter's generators depending on a BN
    template < template < typename > class ALLOC >
    template < typename GUM_SCALAR >
    INLINE void RecordCounter< ALLOC >::setBayesNet(const BayesNet< GUM_SCALAR >& new_bn) {
      // remove the caches
      clear();

      // assign the new BN
      for (auto& xparser: _parsers_) {
        xparser.data.setBayesNet(new_bn);
      }
    }

  } /* namespace learning */

} /* namespace gum */


#endif /* DOXYGEN_SHOULD_SKIP_THIS */
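
A minimal usage sketch for the class defined above. It assumes, as the second constructor suggests, that the nodeId2columns and allocator parameters have default values in recordCounter.h, that ALLOC defaults to std::allocator, and that a DBRowGeneratorParser and an IdCondSet have already been built elsewhere; only members defined in this file are exercised.

// sketch only: sums the counts over the joint configurations of a set of node ids
double totalWeight(const gum::learning::DBRowGeneratorParser<>& parser,
                   const gum::learning::IdCondSet<>&            ids) {
  // no explicit ranges and no nodeId2columns bijection: the whole database is
  // parsed and node ids are taken as column indices
  gum::learning::RecordCounter<> counter(parser);
  counter.setMaxNbThreads(4);            // at most 4 threads parse the database
  counter.setMinNbRowsPerThread(100);    // each thread handles at least 100 rows

  // one double per joint configuration of ids; the result is cached so that a
  // later call on a subset of ids is served by _extractFromCountings_
  const std::vector< double >& counts = counter.counts(ids);

  double total = 0.0;
  for (const double c: counts)
    total += c;                          // countings accumulate the row weights
  return total;
}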