aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
genericBNLearner_inl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/** @file
 * @brief A pack of learning algorithms that can easily be used
 *
 * The pack currently contains K2, GreedyHillClimbing, 3off2/MIIC and
 * LocalSearchWithTabuList.
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */
30
31
// to help IDE parser
32
#
include
<
agrum
/
BN
/
learning
/
BNLearnUtils
/
genericBNLearner
.
h
>
33
#
include
<
agrum
/
tools
/
graphs
/
undiGraph
.
h
>
34
35
namespace
gum
{
36
37
namespace
learning
{
38
39
// returns the row filter
40
INLINE DBRowGeneratorParser<>&
genericBNLearner
::
Database
::
parser
() {
41
return
*
parser__
;
42
}
43
44
// returns the modalities of the variables
45
INLINE
const
std
::
vector
<
std
::
size_t
>&
46
genericBNLearner
::
Database
::
domainSizes
()
const
{
47
return
domain_sizes__
;
48
}
49
50
// returns the names of the variables in the database
51
INLINE
const
std
::
vector
<
std
::
string
>&
52
genericBNLearner
::
Database
::
names
()
const
{
53
return
database__
.
variableNames
();
54
}
55
56
/// assign new weight to the rows of the learning database
57
INLINE
void
58
genericBNLearner
::
Database
::
setDatabaseWeight
(
const
double
new_weight
) {
59
if
(
database__
.
nbRows
() ==
std
::
size_t
(0))
return
;
60
const
double
weight
=
new_weight
/
double
(
database__
.
nbRows
());
61
database__
.
setAllRowsWeight
(
weight
);
62
}
63
64
// returns the node id corresponding to a variable name
65
INLINE
NodeId
66
genericBNLearner
::
Database
::
idFromName
(
const
std
::
string
&
var_name
)
const
{
67
try
{
68
const
auto
cols
=
database__
.
columnsFromVariableName
(
var_name
);
69
return
nodeId2cols__
.
first
(
cols
[0]);
70
}
catch
(...) {
71
GUM_ERROR
(
MissingVariableInDatabase
,
72
"Variable "
<<
var_name
73
<<
" could not be found in the database"
);
74
}
75
}
76
77
78
// returns the variable name corresponding to a given node id
79
INLINE
const
std
::
string
&
80
genericBNLearner
::
Database
::
nameFromId
(
NodeId
id
)
const
{
81
try
{
82
return
database__
.
variableName
(
nodeId2cols__
.
second
(
id
));
83
}
catch
(...) {
84
GUM_ERROR
(
MissingVariableInDatabase
,
85
"Variable of Id "
<<
id
86
<<
" could not be found in the database"
);
87
}
88
}
89
90
91
/// returns the internal database table
92
INLINE
const
DatabaseTable
<>&
93
genericBNLearner
::
Database
::
databaseTable
()
const
{
94
return
database__
;
95
}
96
97
98
/// returns the set of missing symbols taken into account
99
INLINE
const
std
::
vector
<
std
::
string
>&
100
genericBNLearner
::
Database
::
missingSymbols
()
const
{
101
return
database__
.
missingSymbols
();
102
}
103
104
105
/// returns the mapping between node ids and their columns in the database
106
INLINE
const
Bijection
<
NodeId
,
std
::
size_t
>&
107
genericBNLearner
::
Database
::
nodeId2Columns
()
const
{
108
return
nodeId2cols__
;
109
}
110
111
112
/// returns the number of records in the database
113
INLINE
std
::
size_t
genericBNLearner
::
Database
::
nbRows
()
const
{
114
return
database__
.
nbRows
();
115
}
116
117
118
/// returns the number of records in the database
119
INLINE
std
::
size_t
genericBNLearner
::
Database
::
size
()
const
{
120
return
database__
.
size
();
121
}
122
123
124
/// sets the weight of the ith record
125
INLINE
void
genericBNLearner
::
Database
::
setWeight
(
const
std
::
size_t
i
,
126
const
double
weight
) {
127
database__
.
setWeight
(
i
,
weight
);
128
}
129
130
131
/// returns the weight of the ith record
132
INLINE
double
genericBNLearner
::
Database
::
weight
(
const
std
::
size_t
i
)
const
{
133
return
database__
.
weight
(
i
);
134
}
135
136
137
/// returns the weight of the whole database
138
INLINE
double
genericBNLearner
::
Database
::
weight
()
const
{
139
return
database__
.
weight
();
140
}
141
142
143
// ===========================================================================
144
145
// returns the node id corresponding to a variable name
146
INLINE
NodeId
genericBNLearner
::
idFromName
(
const
std
::
string
&
var_name
)
const
{
147
return
score_database__
.
idFromName
(
var_name
);
148
}
149
150
// returns the variable name corresponding to a given node id
151
INLINE
const
std
::
string
&
genericBNLearner
::
nameFromId
(
NodeId
id
)
const
{
152
return
score_database__
.
nameFromId
(
id
);
153
}
154
155
/// assign new weight to the rows of the learning database
156
INLINE
void
genericBNLearner
::
setDatabaseWeight
(
const
double
new_weight
) {
157
score_database__
.
setDatabaseWeight
(
new_weight
);
158
}
159
160
/// assign new weight to the ith row of the learning database
161
INLINE
void
genericBNLearner
::
setRecordWeight
(
const
std
::
size_t
i
,
162
const
double
new_weight
) {
163
score_database__
.
setWeight
(
i
,
new_weight
);
164
}
165
166
/// returns the weight of the ith record
167
INLINE
double
genericBNLearner
::
recordWeight
(
const
std
::
size_t
i
)
const
{
168
return
score_database__
.
weight
(
i
);
169
}
170
171
/// returns the weight of the whole database
172
INLINE
double
genericBNLearner
::
databaseWeight
()
const
{
173
return
score_database__
.
weight
();
174
}
175
176
// sets an initial DAG structure
177
INLINE
void
genericBNLearner
::
setInitialDAG
(
const
DAG
&
dag
) {
178
initial_dag__
=
dag
;
179
}
180
181
// indicate that we wish to use an AIC score
182
INLINE
void
genericBNLearner
::
useScoreAIC
() {
183
score_type__
=
ScoreType
::
AIC
;
184
checkScoreAprioriCompatibility
();
185
}
186
187
// indicate that we wish to use a BD score
188
INLINE
void
genericBNLearner
::
useScoreBD
() {
189
score_type__
=
ScoreType
::
BD
;
190
checkScoreAprioriCompatibility
();
191
}
192
193
// indicate that we wish to use a BDeu score
194
INLINE
void
genericBNLearner
::
useScoreBDeu
() {
195
score_type__
=
ScoreType
::
BDeu
;
196
checkScoreAprioriCompatibility
();
197
}
198
199
// indicate that we wish to use a BIC score
200
INLINE
void
genericBNLearner
::
useScoreBIC
() {
201
score_type__
=
ScoreType
::
BIC
;
202
checkScoreAprioriCompatibility
();
203
}
204
205
// indicate that we wish to use a K2 score
206
INLINE
void
genericBNLearner
::
useScoreK2
() {
207
score_type__
=
ScoreType
::
K2
;
208
checkScoreAprioriCompatibility
();
209
}
210
211
// indicate that we wish to use a Log2Likelihood score
212
INLINE
void
genericBNLearner
::
useScoreLog2Likelihood
() {
213
score_type__
=
ScoreType
::
LOG2LIKELIHOOD
;
214
checkScoreAprioriCompatibility
();
215
}
216
217
// sets the max indegree
218
INLINE
void
genericBNLearner
::
setMaxIndegree
(
Size
max_indegree
) {
219
constraint_Indegree__
.
setMaxIndegree
(
max_indegree
);
220
}
221
222
// indicate that we wish to use 3off2
223
INLINE
void
genericBNLearner
::
use3off2
() {
224
selected_algo__
=
AlgoType
::
MIIC_THREE_OFF_TWO
;
225
miic_3off2__
.
set3off2Behaviour
();
226
}
227
228
// indicate that we wish to use 3off2
229
INLINE
void
genericBNLearner
::
useMIIC
() {
230
selected_algo__
=
AlgoType
::
MIIC_THREE_OFF_TWO
;
231
miic_3off2__
.
setMiicBehaviour
();
232
}
233
234
/// indicate that we wish to use the NML correction for 3off2
235
INLINE
void
genericBNLearner
::
useNML
() {
236
if
(
selected_algo__
!=
AlgoType
::
MIIC_THREE_OFF_TWO
) {
237
GUM_ERROR
(
OperationNotAllowed
,
238
"You must use the 3off2 algorithm before selecting "
239
<<
"the NML score"
);
240
}
241
kmode_3off2__
=
CorrectedMutualInformation
<>::
KModeTypes
::
NML
;
242
}
243
244
/// indicate that we wish to use the MDL correction for 3off2
245
INLINE
void
genericBNLearner
::
useMDL
() {
246
if
(
selected_algo__
!=
AlgoType
::
MIIC_THREE_OFF_TWO
) {
247
GUM_ERROR
(
OperationNotAllowed
,
248
"You must use the 3off2 algorithm before selecting "
249
<<
"the MDL score"
);
250
}
251
kmode_3off2__
=
CorrectedMutualInformation
<>::
KModeTypes
::
MDL
;
252
}
253
254
/// indicate that we wish to use the NoCorr correction for 3off2
255
INLINE
void
genericBNLearner
::
useNoCorr
() {
256
if
(
selected_algo__
!=
AlgoType
::
MIIC_THREE_OFF_TWO
) {
257
GUM_ERROR
(
OperationNotAllowed
,
258
"You must use the 3off2 algorithm before selecting "
259
<<
"the NoCorr score"
);
260
}
261
kmode_3off2__
=
CorrectedMutualInformation
<>::
KModeTypes
::
NoCorr
;
262
}
263
264
/// get the list of arcs hiding latent variables
265
INLINE
const
std
::
vector
<
Arc
>
genericBNLearner
::
latentVariables
()
const
{
266
if
(
selected_algo__
!=
AlgoType
::
MIIC_THREE_OFF_TWO
) {
267
GUM_ERROR
(
OperationNotAllowed
,
268
"You must use the 3off2 algorithm before selecting "
269
<<
"the latentVariables method"
);
270
}
271
return
miic_3off2__
.
latentVariables
();
272
}
273
274
// indicate that we wish to use a K2 algorithm
275
INLINE
void
genericBNLearner
::
useK2
(
const
Sequence
<
NodeId
>&
order
) {
276
selected_algo__
=
AlgoType
::
K2
;
277
K2__
.
setOrder
(
order
);
278
}
279
280
// indicate that we wish to use a K2 algorithm
281
INLINE
void
genericBNLearner
::
useK2
(
const
std
::
vector
<
NodeId
>&
order
) {
282
selected_algo__
=
AlgoType
::
K2
;
283
K2__
.
setOrder
(
order
);
284
}
285
286
// indicate that we wish to use a greedy hill climbing algorithm
287
INLINE
void
genericBNLearner
::
useGreedyHillClimbing
() {
288
selected_algo__
=
AlgoType
::
GREEDY_HILL_CLIMBING
;
289
}
290
291
// indicate that we wish to use a local search with tabu list
292
INLINE
void
genericBNLearner
::
useLocalSearchWithTabuList
(
Size
tabu_size
,
293
Size
nb_decrease
) {
294
selected_algo__
=
AlgoType
::
LOCAL_SEARCH_WITH_TABU_LIST
;
295
constraint_TabuList__
.
setTabuListSize
(
tabu_size
);
296
local_search_with_tabu_list__
.
setMaxNbDecreasingChanges
(
nb_decrease
);
297
}
298
299
/// use The EM algorithm to learn paramters
300
INLINE
void
genericBNLearner
::
useEM
(
const
double
epsilon
) {
301
EMepsilon__
=
epsilon
;
302
}
303
304
305
/// Tells whether the score database's table contains missing values.
INLINE bool genericBNLearner::hasMissingValues() const {
  const auto& table = score_database__.databaseTable();
  return table.hasMissingValues();
}
308
309
// assign a set of forbidden edges
310
INLINE
void
genericBNLearner
::
setPossibleEdges
(
const
EdgeSet
&
set
) {
311
constraint_PossibleEdges__
.
setEdges
(
set
);
312
}
313
// assign a set of forbidden edges from an UndiGraph
314
INLINE
void
genericBNLearner
::
setPossibleSkeleton
(
const
gum
::
UndiGraph
&
g
) {
315
setPossibleEdges
(
g
.
edges
());
316
}
317
318
// assign a new possible edge
319
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
Edge
&
edge
) {
320
constraint_PossibleEdges__
.
addEdge
(
edge
);
321
}
322
323
// remove a forbidden edge
324
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
Edge
&
edge
) {
325
constraint_PossibleEdges__
.
eraseEdge
(
edge
);
326
}
327
328
// assign a new forbidden edge
329
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
NodeId
tail
,
330
const
NodeId
head
) {
331
addPossibleEdge
(
Edge
(
tail
,
head
));
332
}
333
334
// remove a forbidden edge
335
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
NodeId
tail
,
336
const
NodeId
head
) {
337
erasePossibleEdge
(
Edge
(
tail
,
head
));
338
}
339
340
// assign a new forbidden edge
341
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
std
::
string
&
tail
,
342
const
std
::
string
&
head
) {
343
addPossibleEdge
(
Edge
(
idFromName
(
tail
),
idFromName
(
head
)));
344
}
345
346
// remove a forbidden edge
347
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
std
::
string
&
tail
,
348
const
std
::
string
&
head
) {
349
erasePossibleEdge
(
Edge
(
idFromName
(
tail
),
idFromName
(
head
)));
350
}
351
352
// assign a set of forbidden arcs
353
INLINE
void
genericBNLearner
::
setForbiddenArcs
(
const
ArcSet
&
set
) {
354
constraint_ForbiddenArcs__
.
setArcs
(
set
);
355
}
356
357
// assign a new forbidden arc
358
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
Arc
&
arc
) {
359
constraint_ForbiddenArcs__
.
addArc
(
arc
);
360
}
361
362
// remove a forbidden arc
363
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
Arc
&
arc
) {
364
constraint_ForbiddenArcs__
.
eraseArc
(
arc
);
365
}
366
367
// assign a new forbidden arc
368
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
NodeId
tail
,
369
const
NodeId
head
) {
370
addForbiddenArc
(
Arc
(
tail
,
head
));
371
}
372
373
// remove a forbidden arc
374
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
NodeId
tail
,
375
const
NodeId
head
) {
376
eraseForbiddenArc
(
Arc
(
tail
,
head
));
377
}
378
379
// assign a new forbidden arc
380
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
std
::
string
&
tail
,
381
const
std
::
string
&
head
) {
382
addForbiddenArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
383
}
384
385
// remove a forbidden arc
386
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
std
::
string
&
tail
,
387
const
std
::
string
&
head
) {
388
eraseForbiddenArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
389
}
390
391
// assign a set of forbidden arcs
392
INLINE
void
genericBNLearner
::
setMandatoryArcs
(
const
ArcSet
&
set
) {
393
constraint_MandatoryArcs__
.
setArcs
(
set
);
394
}
395
396
// assign a new forbidden arc
397
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
Arc
&
arc
) {
398
constraint_MandatoryArcs__
.
addArc
(
arc
);
399
}
400
401
// remove a forbidden arc
402
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
Arc
&
arc
) {
403
constraint_MandatoryArcs__
.
eraseArc
(
arc
);
404
}
405
406
// assign a new forbidden arc
407
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
std
::
string
&
tail
,
408
const
std
::
string
&
head
) {
409
addMandatoryArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
410
}
411
412
// remove a forbidden arc
413
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
std
::
string
&
tail
,
414
const
std
::
string
&
head
) {
415
eraseMandatoryArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
416
}
417
418
// assign a new forbidden arc
419
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
NodeId
tail
,
420
const
NodeId
head
) {
421
addMandatoryArc
(
Arc
(
tail
,
head
));
422
}
423
424
// remove a forbidden arc
425
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
NodeId
tail
,
426
const
NodeId
head
) {
427
eraseMandatoryArc
(
Arc
(
tail
,
head
));
428
}
429
430
// sets a partial order on the nodes
431
INLINE
void
432
genericBNLearner
::
setSliceOrder
(
const
NodeProperty
<
NodeId
>&
slice_order
) {
433
constraint_SliceOrder__
=
StructuralConstraintSliceOrder
(
slice_order
);
434
}
435
436
/// Sets a partial order on the nodes from groups of variable names: every
/// variable listed in slices[k] receives rank k. The resulting map is then
/// handed to the NodeProperty overload.
/// @throws MissingVariableInDatabase if a name is unknown (via idFromName).
/// NOTE(review): a variable listed in several slices is inserted twice into
/// the NodeProperty, which presumably throws — confirm.
INLINE void genericBNLearner::setSliceOrder(
   const std::vector< std::vector< std::string > >& slices) {
  NodeProperty< NodeId > slice_order;
  for (std::size_t k = 0; k < slices.size(); ++k) {
    for (const auto& name: slices[k]) {
      slice_order.insert(idFromName(name), NodeId(k));
    }
  }
  setSliceOrder(slice_order);
}
448
449
// sets the apriori weight
450
INLINE
void
genericBNLearner
::
setAprioriWeight__
(
double
weight
) {
451
if
(
weight
< 0) {
452
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
);
453
}
454
455
apriori_weight__
=
weight
;
456
checkScoreAprioriCompatibility
();
457
}
458
459
// use the apriori smoothing
460
INLINE
void
genericBNLearner
::
useNoApriori
() {
461
apriori_type__
=
AprioriType
::
NO_APRIORI
;
462
checkScoreAprioriCompatibility
();
463
}
464
465
// use the apriori smoothing
466
INLINE
void
genericBNLearner
::
useAprioriSmoothing
(
double
weight
) {
467
if
(
weight
< 0) {
468
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
);
469
}
470
471
apriori_type__
=
AprioriType
::
SMOOTHING
;
472
setAprioriWeight__
(
weight
);
473
474
checkScoreAprioriCompatibility
();
475
}
476
477
// use the Dirichlet apriori
478
INLINE
void
genericBNLearner
::
useAprioriDirichlet
(
const
std
::
string
&
filename
,
479
double
weight
) {
480
if
(
weight
< 0) {
481
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
);
482
}
483
484
apriori_dbname__
=
filename
;
485
apriori_type__
=
AprioriType
::
DIRICHLET_FROM_DATABASE
;
486
setAprioriWeight__
(
weight
);
487
488
checkScoreAprioriCompatibility
();
489
}
490
491
492
// use the apriori BDeu
493
INLINE
void
genericBNLearner
::
useAprioriBDeu
(
double
weight
) {
494
if
(
weight
< 0) {
495
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
);
496
}
497
498
apriori_type__
=
AprioriType
::
BDEU
;
499
setAprioriWeight__
(
weight
);
500
501
checkScoreAprioriCompatibility
();
502
}
503
504
505
// returns the type (as a string) of a given apriori
506
INLINE
const
std
::
string
&
genericBNLearner
::
getAprioriType__
()
const
{
507
switch
(
apriori_type__
) {
508
case
AprioriType
::
NO_APRIORI
:
509
return
AprioriNoApriori
<>::
type
::
type
;
510
511
case
AprioriType
::
SMOOTHING
:
512
return
AprioriSmoothing
<>::
type
::
type
;
513
514
case
AprioriType
::
DIRICHLET_FROM_DATABASE
:
515
return
AprioriDirichletFromDatabase
<>::
type
::
type
;
516
517
case
AprioriType
::
BDEU
:
518
return
AprioriBDeu
<>::
type
::
type
;
519
520
default
:
521
GUM_ERROR
(
OperationNotAllowed
,
522
"genericBNLearner getAprioriType does "
523
"not support yet this apriori"
);
524
}
525
}
526
527
// returns the names of the variables in the database
528
INLINE
const
std
::
vector
<
std
::
string
>&
genericBNLearner
::
names
()
const
{
529
return
score_database__
.
names
();
530
}
531
532
// returns the modalities of the variables in the database
533
INLINE
const
std
::
vector
<
std
::
size_t
>&
534
genericBNLearner
::
domainSizes
()
const
{
535
return
score_database__
.
domainSizes
();
536
}
537
538
// returns the modalities of a variable in the database
539
INLINE
Size
genericBNLearner
::
domainSize
(
NodeId
var
)
const
{
540
return
score_database__
.
domainSizes
()[
var
];
541
}
542
// returns the modalities of a variables in the database
543
INLINE
Size
genericBNLearner
::
domainSize
(
const
std
::
string
&
var
)
const
{
544
return
score_database__
.
domainSizes
()[
idFromName
(
var
)];
545
}
546
547
/// returns the current database rows' ranges used for learning
548
INLINE
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> >&
549
genericBNLearner
::
databaseRanges
()
const
{
550
return
ranges__
;
551
}
552
553
/// reset the ranges to the one range corresponding to the whole database
554
INLINE
void
genericBNLearner
::
clearDatabaseRanges
() {
ranges__
.
clear
(); }
555
556
/// returns the database used by the BNLearner
557
INLINE
const
DatabaseTable
<>&
genericBNLearner
::
database
()
const
{
558
return
score_database__
.
databaseTable
();
559
}
560
561
/// Returns the number of columns (variables) of the score database, i.e., the
/// number of entries of domainSizes().
INLINE Size genericBNLearner::nbCols() const {
  const auto& sizes = score_database__.domainSizes();
  return sizes.size();
}
564
565
/// Returns the number of rows of the score database's table (its size()).
INLINE Size genericBNLearner::nbRows() const {
  const auto& table = score_database__.databaseTable();
  return table.size();
}
568
}
/* namespace learning */
569
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31