aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
paramEstimator_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/** @file
23
* @brief the base class for estimating parameters of CPTs
24
*
25
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
26
*/
27
28
#
ifndef
DOXYGEN_SHOULD_SKIP_THIS
29
30
namespace
gum
{
31
32
namespace
learning
{
33
34
/// returns the allocator used by the score
35
template
<
template
<
typename
>
class
ALLOC
>
36
INLINE
typename
ParamEstimator
<
ALLOC
>::
allocator_type
37
ParamEstimator
<
ALLOC
>::
getAllocator
()
const
{
38
return
counter_
.
getAllocator
();
39
}
40
41
42
/// default constructor
43
template
<
template
<
typename
>
class
ALLOC
>
44
ParamEstimator
<
ALLOC
>::
ParamEstimator
(
45
const
DBRowGeneratorParser
<
ALLOC
>&
parser
,
46
const
Apriori
<
ALLOC
>&
external_apriori
,
47
const
Apriori
<
ALLOC
>&
score_internal_apriori
,
48
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
49
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
ranges
,
50
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
51
nodeId2columns
,
52
const
typename
ParamEstimator
<
ALLOC
>::
allocator_type
&
alloc
) :
53
counter_
(
parser
,
ranges
,
nodeId2columns
,
alloc
) {
54
// copy the a prioris
55
external_apriori_
=
external_apriori
.
clone
(
alloc
);
56
try
{
57
score_internal_apriori_
=
score_internal_apriori
.
clone
();
58
}
catch
(...) {
59
ALLOC
<
Apriori
<
ALLOC
> >
allocator
(
alloc
);
60
allocator
.
destroy
(
external_apriori_
);
61
allocator
.
deallocate
(
external_apriori_
, 1);
62
throw
;
63
}
64
65
GUM_CONSTRUCTOR
(
ParamEstimator
);
66
}
67
68
69
/// default constructor
70
template
<
template
<
typename
>
class
ALLOC
>
71
ParamEstimator
<
ALLOC
>::
ParamEstimator
(
72
const
DBRowGeneratorParser
<
ALLOC
>&
parser
,
73
const
Apriori
<
ALLOC
>&
external_apriori
,
74
const
Apriori
<
ALLOC
>&
score_internal_apriori
,
75
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
76
nodeId2columns
,
77
const
typename
ParamEstimator
<
ALLOC
>::
allocator_type
&
alloc
) :
78
counter_
(
parser
,
nodeId2columns
,
alloc
) {
79
// copy the a prioris
80
external_apriori_
=
external_apriori
.
clone
(
alloc
);
81
try
{
82
score_internal_apriori_
=
score_internal_apriori
.
clone
();
83
}
catch
(...) {
84
ALLOC
<
Apriori
<
ALLOC
> >
allocator
(
alloc
);
85
allocator
.
destroy
(
external_apriori_
);
86
allocator
.
deallocate
(
external_apriori_
, 1);
87
throw
;
88
}
89
90
GUM_CONSTRUCTOR
(
ParamEstimator
);
91
}
92
93
94
/// copy constructor with a given allocator
95
template
<
template
<
typename
>
class
ALLOC
>
96
INLINE
ParamEstimator
<
ALLOC
>::
ParamEstimator
(
97
const
ParamEstimator
<
ALLOC
>&
from
,
98
const
typename
ParamEstimator
<
ALLOC
>::
allocator_type
&
alloc
) :
99
external_apriori_
(
from
.
external_apriori_
->
clone
(
alloc
)),
100
score_internal_apriori_
(
from
.
score_internal_apriori_
->
clone
(
alloc
)),
101
counter_
(
from
.
counter_
,
alloc
) {
102
GUM_CONS_CPY
(
ParamEstimator
);
103
}
104
105
106
/// copy constructor
107
template
<
template
<
typename
>
class
ALLOC
>
108
INLINE
ParamEstimator
<
ALLOC
>::
ParamEstimator
(
109
const
ParamEstimator
<
ALLOC
>&
from
) :
110
ParamEstimator
<
ALLOC
>(
from
,
from
.
getAllocator
()) {}
111
112
113
/// move constructor with a given allocator
114
template
<
template
<
typename
>
class
ALLOC
>
115
INLINE
ParamEstimator
<
ALLOC
>::
ParamEstimator
(
116
ParamEstimator
<
ALLOC
>&&
from
,
117
const
typename
ParamEstimator
<
ALLOC
>::
allocator_type
&
alloc
) :
118
external_apriori_
(
from
.
external_apriori_
),
119
score_internal_apriori_
(
from
.
score_internal_apriori_
),
120
counter_
(
std
::
move
(
from
.
counter_
),
alloc
) {
121
from
.
external_apriori_
=
nullptr
;
122
from
.
score_internal_apriori_
=
nullptr
;
123
GUM_CONS_MOV
(
ParamEstimator
);
124
}
125
126
127
/// move constructor
128
template
<
template
<
typename
>
class
ALLOC
>
129
INLINE
130
ParamEstimator
<
ALLOC
>::
ParamEstimator
(
ParamEstimator
<
ALLOC
>&&
from
) :
131
ParamEstimator
<
ALLOC
>(
std
::
move
(
from
),
from
.
getAllocator
()) {}
132
133
134
/// destructor
135
template
<
template
<
typename
>
class
ALLOC
>
136
ParamEstimator
<
ALLOC
>::~
ParamEstimator
() {
137
ALLOC
<
Apriori
<
ALLOC
> >
allocator
(
this
->
getAllocator
());
138
if
(
external_apriori_
!=
nullptr
) {
139
allocator
.
destroy
(
external_apriori_
);
140
allocator
.
deallocate
(
external_apriori_
, 1);
141
}
142
143
if
(
score_internal_apriori_
!=
nullptr
) {
144
allocator
.
destroy
(
score_internal_apriori_
);
145
allocator
.
deallocate
(
score_internal_apriori_
, 1);
146
}
147
148
GUM_DESTRUCTOR
(
ParamEstimator
);
149
}
150
151
152
/// copy operator
153
template
<
template
<
typename
>
class
ALLOC
>
154
ParamEstimator
<
ALLOC
>&
155
ParamEstimator
<
ALLOC
>::
operator
=(
const
ParamEstimator
<
ALLOC
>&
from
) {
156
if
(
this
!= &
from
) {
157
ALLOC
<
Apriori
<
ALLOC
> >
allocator
(
this
->
getAllocator
());
158
if
(
external_apriori_
!=
nullptr
) {
159
allocator
.
destroy
(
external_apriori_
);
160
allocator
.
deallocate
(
external_apriori_
, 1);
161
external_apriori_
=
nullptr
;
162
}
163
external_apriori_
=
from
.
external_apriori_
->
clone
(
this
->
getAllocator
());
164
165
if
(
score_internal_apriori_
!=
nullptr
) {
166
allocator
.
destroy
(
score_internal_apriori_
);
167
allocator
.
deallocate
(
score_internal_apriori_
, 1);
168
external_apriori_
=
nullptr
;
169
}
170
score_internal_apriori_
171
=
from
.
score_internal_apriori_
->
clone
(
this
->
getAllocator
());
172
173
counter_
=
from
.
counter_
;
174
}
175
return
*
this
;
176
}
177
178
179
/// move operator
180
template
<
template
<
typename
>
class
ALLOC
>
181
ParamEstimator
<
ALLOC
>&
182
ParamEstimator
<
ALLOC
>::
operator
=(
ParamEstimator
<
ALLOC
>&&
from
) {
183
if
(
this
!= &
from
) {
184
external_apriori_
=
from
.
external_apriori_
;
185
score_internal_apriori_
=
from
.
score_internal_apriori_
;
186
counter_
=
std
::
move
(
from
.
counter_
);
187
from
.
external_apriori_
=
nullptr
;
188
from
.
score_internal_apriori_
=
nullptr
;
189
}
190
return
*
this
;
191
}
192
193
194
/// clears all the data structures from memory
195
template
<
template
<
typename
>
class
ALLOC
>
196
INLINE
void
ParamEstimator
<
ALLOC
>::
clear
() {
197
counter_
.
clear
();
198
}
199
200
201
/// changes the max number of threads used to parse the database
202
template
<
template
<
typename
>
class
ALLOC
>
203
void
ParamEstimator
<
ALLOC
>::
setMaxNbThreads
(
std
::
size_t
nb
)
const
{
204
counter_
.
setMaxNbThreads
(
nb
);
205
}
206
207
208
/// returns the number of threads used to parse the database
209
template
<
template
<
typename
>
class
ALLOC
>
210
std
::
size_t
ParamEstimator
<
ALLOC
>::
nbThreads
()
const
{
211
return
counter_
.
nbThreads
();
212
}
213
214
215
/** @brief changes the number min of rows a thread should process in a
216
* multithreading context */
217
template
<
template
<
typename
>
class
ALLOC
>
218
INLINE
void
219
ParamEstimator
<
ALLOC
>::
setMinNbRowsPerThread
(
const
std
::
size_t
nb
)
const
{
220
counter_
.
setMinNbRowsPerThread
(
nb
);
221
}
222
223
224
/// returns the minimum of rows that each thread should process
225
template
<
template
<
typename
>
class
ALLOC
>
226
INLINE
std
::
size_t
ParamEstimator
<
ALLOC
>::
minNbRowsPerThread
()
const
{
227
return
counter_
.
minNbRowsPerThread
();
228
}
229
230
231
/// sets new ranges to perform the countings used by the score
232
/** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
233
* indices. The countings are then performed only on the union of the
234
* rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
235
* cross validation tasks, in which part of the database should be ignored.
236
* An empty set of ranges is equivalent to an interval [X,Y) ranging over
237
* the whole database. */
238
template
<
template
<
typename
>
class
ALLOC
>
239
template
<
template
<
typename
>
class
XALLOC
>
240
void
ParamEstimator
<
ALLOC
>::
setRanges
(
241
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
242
XALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
243
new_ranges
) {
244
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
245
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >
246
old_ranges
=
ranges
();
247
counter_
.
setRanges
(
new_ranges
);
248
if
(
old_ranges
!=
ranges
())
clear
();
249
}
250
251
252
/// reset the ranges to the one range corresponding to the whole database
253
template
<
template
<
typename
>
class
ALLOC
>
254
void
ParamEstimator
<
ALLOC
>::
clearRanges
() {
255
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
256
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >
257
old_ranges
=
ranges
();
258
counter_
.
clearRanges
();
259
if
(
old_ranges
!=
ranges
())
clear
();
260
}
261
262
263
/// returns the current ranges
264
template
<
template
<
typename
>
class
ALLOC
>
265
INLINE
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
266
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
267
ParamEstimator
<
ALLOC
>::
ranges
()
const
{
268
return
counter_
.
ranges
();
269
}
270
271
272
/// returns the CPT's parameters corresponding to a given target node
273
template
<
template
<
typename
>
class
ALLOC
>
274
INLINE
std
::
vector
<
double
,
ALLOC
<
double
> >
275
ParamEstimator
<
ALLOC
>::
parameters
(
const
NodeId
target_node
) {
276
return
parameters
(
target_node
,
empty_nodevect_
);
277
}
278
279
280
// check the coherency between the parameters passed to setParameters functions
281
template
<
template
<
typename
>
class
ALLOC
>
282
template
<
typename
GUM_SCALAR
>
283
void
ParamEstimator
<
ALLOC
>::
checkParameters__
(
284
const
NodeId
target_node
,
285
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_nodes
,
286
Potential
<
GUM_SCALAR
>&
pot
) {
287
// check that the nodes passed in arguments correspond to those of pot
288
const
Sequence
<
const
DiscreteVariable
* >&
vars
=
pot
.
variablesSequence
();
289
if
(
vars
.
size
() == 0) {
290
GUM_ERROR
(
SizeError
,
"the potential contains no variable"
);
291
}
292
293
const
auto
&
database
=
counter_
.
database
();
294
const
auto
&
node2cols
=
counter_
.
nodeId2Columns
();
295
if
(
node2cols
.
empty
()) {
296
if
(
database
.
domainSize
(
target_node
) !=
vars
[0]->
domainSize
()) {
297
GUM_ERROR
(
SizeError
,
298
"Variable "
299
<<
vars
[0]->
name
() <<
"of the potential to be filled "
300
<<
"has a domain size of "
<<
vars
[0]->
domainSize
()
301
<<
", which is different from that of node "
<<
target_node
302
<<
" which is equal to "
303
<<
database
.
domainSize
(
target_node
));
304
}
305
for
(
std
::
size_t
i
= 1;
i
<
vars
.
size
(); ++
i
) {
306
if
(
database
.
domainSize
(
conditioning_nodes
[
i
- 1])
307
!=
vars
[
i
]->
domainSize
()) {
308
GUM_ERROR
(
SizeError
,
309
"Variable "
310
<<
vars
[
i
]->
name
() <<
"of the potential to be filled "
311
<<
"has a domain size of "
<<
vars
[
i
]->
domainSize
()
312
<<
", which is different from that of node "
313
<<
conditioning_nodes
[
i
- 1] <<
" which is equal to "
314
<<
database
.
domainSize
(
conditioning_nodes
[
i
- 1]));
315
}
316
}
317
}
else
{
318
std
::
size_t
col
=
node2cols
.
second
(
target_node
);
319
if
(
database
.
domainSize
(
col
) !=
vars
[0]->
domainSize
()) {
320
GUM_ERROR
(
SizeError
,
321
"Variable "
322
<<
vars
[0]->
name
() <<
"of the potential to be filled "
323
<<
"has a domain size of "
<<
vars
[0]->
domainSize
()
324
<<
", which is different from that of node "
<<
target_node
325
<<
" which is equal to "
<<
database
.
domainSize
(
col
));
326
}
327
for
(
std
::
size_t
i
= 1;
i
<
vars
.
size
(); ++
i
) {
328
col
=
node2cols
.
second
(
conditioning_nodes
[
i
- 1]);
329
if
(
database
.
domainSize
(
col
) !=
vars
[
i
]->
domainSize
()) {
330
GUM_ERROR
(
SizeError
,
331
"Variable "
332
<<
vars
[
i
]->
name
() <<
"of the potential to be filled "
333
<<
"has a domain size of "
<<
vars
[
i
]->
domainSize
()
334
<<
", which is different from that of node "
335
<<
conditioning_nodes
[
i
- 1] <<
" which is equal to "
336
<<
database
.
domainSize
(
col
));
337
}
338
}
339
}
340
}
341
342
343
/// sets the CPT's parameters corresponding to a given nodeset
344
template
<
template
<
typename
>
class
ALLOC
>
345
template
<
typename
GUM_SCALAR
>
346
INLINE
typename
std
::
enable_if
< !
std
::
is_same
<
GUM_SCALAR
,
double
>::
value
,
347
void
>::
type
348
ParamEstimator
<
ALLOC
>::
setParameters__
(
349
const
NodeId
target_node
,
350
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_nodes
,
351
Potential
<
GUM_SCALAR
>&
pot
) {
352
checkParameters__
(
target_node
,
conditioning_nodes
,
pot
);
353
354
const
std
::
vector
<
double
,
ALLOC
<
double
> >
params
(
355
parameters
(
target_node
,
conditioning_nodes
));
356
357
// transform the vector of double into a vector of GUM_SCALAR
358
const
std
::
size_t
size
=
params
.
size
();
359
std
::
vector
<
GUM_SCALAR
,
ALLOC
<
GUM_SCALAR
> >
xparams
(
size
);
360
for
(
std
::
size_t
i
=
std
::
size_t
(0);
i
<
size
; ++
i
)
361
xparams
[
i
] =
GUM_SCALAR
(
params
[
i
]);
362
363
pot
.
fillWith
(
xparams
);
364
}
365
366
367
/// sets the CPT's parameters corresponding to a given nodeset
368
template
<
template
<
typename
>
class
ALLOC
>
369
template
<
typename
GUM_SCALAR
>
370
INLINE
typename
std
::
enable_if
<
std
::
is_same
<
GUM_SCALAR
,
double
>::
value
,
371
void
>::
type
372
ParamEstimator
<
ALLOC
>::
setParameters__
(
373
const
NodeId
target_node
,
374
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_nodes
,
375
Potential
<
GUM_SCALAR
>&
pot
) {
376
checkParameters__
(
target_node
,
conditioning_nodes
,
pot
);
377
378
const
std
::
vector
<
double
,
ALLOC
<
double
> >
params
(
379
parameters
(
target_node
,
conditioning_nodes
));
380
pot
.
fillWith
(
params
);
381
}
382
383
384
/// sets the CPT's parameters corresponding to a given nodeset
385
template
<
template
<
typename
>
class
ALLOC
>
386
template
<
typename
GUM_SCALAR
>
387
INLINE
void
ParamEstimator
<
ALLOC
>::
setParameters
(
388
const
NodeId
target_node
,
389
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_nodes
,
390
Potential
<
GUM_SCALAR
>&
pot
) {
391
setParameters__
(
target_node
,
conditioning_nodes
,
pot
);
392
}
393
394
395
/// returns the mapping from ids to column positions in the database
396
template
<
template
<
typename
>
class
ALLOC
>
397
INLINE
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
398
ParamEstimator
<
ALLOC
>::
nodeId2Columns
()
const
{
399
return
counter_
.
nodeId2Columns
();
400
}
401
402
403
/// returns the database on which we perform the counts
404
template
<
template
<
typename
>
class
ALLOC
>
405
INLINE
const
DatabaseTable
<
ALLOC
>&
406
ParamEstimator
<
ALLOC
>::
database
()
const
{
407
return
counter_
.
database
();
408
}
409
410
411
/// assign a new Bayes net to all the counter's generators depending on a BN
412
template
<
template
<
typename
>
class
ALLOC
>
413
template
<
typename
GUM_SCALAR
>
414
INLINE
void
415
ParamEstimator
<
ALLOC
>::
setBayesNet
(
const
BayesNet
<
GUM_SCALAR
>&
new_bn
) {
416
counter_
.
setBayesNet
(
new_bn
);
417
}
418
419
420
}
/* namespace learning */
421
422
}
/* namespace gum */
423
424
#
endif
/* DOXYGEN_SHOULD_SKIP_THIS */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31