aGrUM
0.21.0
a C++ library for (probabilistic) graphical models
genericBNLearner_inl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/** @file
23
* @brief A pack of learning algorithms that can easily be used
24
*
25
* The pack currently contains K2, GreedyHillClimbing, 3off2 and
26
*LocalSearchWithTabuList
27
*
28
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29
*/
30
31
// to help IDE parser
32
#
include
<
agrum
/
BN
/
learning
/
BNLearnUtils
/
genericBNLearner
.
h
>
33
#
include
<
agrum
/
tools
/
graphs
/
undiGraph
.
h
>
34
35
namespace
gum
{
36
37
namespace
learning
{
38
39
// returns the row filter
40
INLINE DBRowGeneratorParser<>&
genericBNLearner
::
Database
::
parser
() {
return
*
_parser_
; }
41
42
// returns the modalities of the variables
43
INLINE
const
std
::
vector
<
std
::
size_t
>&
genericBNLearner
::
Database
::
domainSizes
()
const
{
44
return
_domain_sizes_
;
45
}
46
47
// returns the names of the variables in the database
48
INLINE
const
std
::
vector
<
std
::
string
>&
genericBNLearner
::
Database
::
names
()
const
{
49
return
_database_
.
variableNames
();
50
}
51
52
/// assign new weight to the rows of the learning database
53
INLINE
void
genericBNLearner
::
Database
::
setDatabaseWeight
(
const
double
new_weight
) {
54
if
(
_database_
.
nbRows
() ==
std
::
size_t
(0))
return
;
55
const
double
weight
=
new_weight
/
double
(
_database_
.
nbRows
());
56
_database_
.
setAllRowsWeight
(
weight
);
57
}
58
59
// returns the node id corresponding to a variable name
60
INLINE
NodeId
genericBNLearner
::
Database
::
idFromName
(
const
std
::
string
&
var_name
)
const
{
61
try
{
62
const
auto
cols
=
_database_
.
columnsFromVariableName
(
var_name
);
63
return
_nodeId2cols_
.
first
(
cols
[0]);
64
}
catch
(...) {
65
GUM_ERROR
(
MissingVariableInDatabase
,
66
"Variable "
<<
var_name
<<
" could not be found in the database"
);
67
}
68
}
69
70
71
// returns the variable name corresponding to a given node id
72
INLINE
const
std
::
string
&
genericBNLearner
::
Database
::
nameFromId
(
NodeId
id
)
const
{
73
try
{
74
return
_database_
.
variableName
(
_nodeId2cols_
.
second
(
id
));
75
}
catch
(...) {
76
GUM_ERROR
(
MissingVariableInDatabase
,
77
"Variable of Id "
<<
id
<<
" could not be found in the database"
);
78
}
79
}
80
81
82
/// returns the internal database table
83
INLINE
const
DatabaseTable
<>&
genericBNLearner
::
Database
::
databaseTable
()
const
{
84
return
_database_
;
85
}
86
87
88
/// returns the set of missing symbols taken into account
89
INLINE
const
std
::
vector
<
std
::
string
>&
genericBNLearner
::
Database
::
missingSymbols
()
const
{
90
return
_database_
.
missingSymbols
();
91
}
92
93
94
/// returns the mapping between node ids and their columns in the database
95
INLINE
const
Bijection
<
NodeId
,
std
::
size_t
>&
96
genericBNLearner
::
Database
::
nodeId2Columns
()
const
{
97
return
_nodeId2cols_
;
98
}
99
100
101
/// returns the number of records in the database
102
INLINE
std
::
size_t
genericBNLearner
::
Database
::
nbRows
()
const
{
return
_database_
.
nbRows
(); }
103
104
105
/// returns the number of records in the database
106
INLINE
std
::
size_t
genericBNLearner
::
Database
::
size
()
const
{
return
_database_
.
size
(); }
107
108
109
/// sets the weight of the ith record
110
INLINE
void
genericBNLearner
::
Database
::
setWeight
(
const
std
::
size_t
i
,
const
double
weight
) {
111
_database_
.
setWeight
(
i
,
weight
);
112
}
113
114
115
/// returns the weight of the ith record
116
INLINE
double
genericBNLearner
::
Database
::
weight
(
const
std
::
size_t
i
)
const
{
117
return
_database_
.
weight
(
i
);
118
}
119
120
121
/// returns the weight of the whole database
122
INLINE
double
genericBNLearner
::
Database
::
weight
()
const
{
return
_database_
.
weight
(); }
123
124
125
// ===========================================================================
126
127
// returns the node id corresponding to a variable name
128
INLINE
NodeId
genericBNLearner
::
idFromName
(
const
std
::
string
&
var_name
)
const
{
129
return
scoreDatabase_
.
idFromName
(
var_name
);
130
}
131
132
// returns the variable name corresponding to a given node id
133
INLINE
const
std
::
string
&
genericBNLearner
::
nameFromId
(
NodeId
id
)
const
{
134
return
scoreDatabase_
.
nameFromId
(
id
);
135
}
136
137
/// assign new weight to the rows of the learning database
138
INLINE
void
genericBNLearner
::
setDatabaseWeight
(
const
double
new_weight
) {
139
scoreDatabase_
.
setDatabaseWeight
(
new_weight
);
140
}
141
142
/// assign new weight to the ith row of the learning database
143
INLINE
void
genericBNLearner
::
setRecordWeight
(
const
std
::
size_t
i
,
const
double
new_weight
) {
144
scoreDatabase_
.
setWeight
(
i
,
new_weight
);
145
}
146
147
/// returns the weight of the ith record
148
INLINE
double
genericBNLearner
::
recordWeight
(
const
std
::
size_t
i
)
const
{
149
return
scoreDatabase_
.
weight
(
i
);
150
}
151
152
/// returns the weight of the whole database
153
INLINE
double
genericBNLearner
::
databaseWeight
()
const
{
return
scoreDatabase_
.
weight
(); }
154
155
// sets an initial DAG structure
156
INLINE
void
genericBNLearner
::
setInitialDAG
(
const
DAG
&
dag
) {
initialDag_
=
dag
; }
157
158
/// Gives the initial DAG structure.
/// @return a copy of initialDag_.
INLINE DAG genericBNLearner::initialDAG() {
  return initialDag_;
}
159
160
// indicate that we wish to use an AIC score
161
INLINE
void
genericBNLearner
::
useScoreAIC
() {
162
scoreType_
=
ScoreType
::
AIC
;
163
checkScoreAprioriCompatibility
();
164
}
165
166
// indicate that we wish to use a BD score
167
INLINE
void
genericBNLearner
::
useScoreBD
() {
168
scoreType_
=
ScoreType
::
BD
;
169
checkScoreAprioriCompatibility
();
170
}
171
172
// indicate that we wish to use a BDeu score
173
INLINE
void
genericBNLearner
::
useScoreBDeu
() {
174
scoreType_
=
ScoreType
::
BDeu
;
175
checkScoreAprioriCompatibility
();
176
}
177
178
// indicate that we wish to use a BIC score
179
INLINE
void
genericBNLearner
::
useScoreBIC
() {
180
scoreType_
=
ScoreType
::
BIC
;
181
checkScoreAprioriCompatibility
();
182
}
183
184
// indicate that we wish to use a K2 score
185
INLINE
void
genericBNLearner
::
useScoreK2
() {
186
scoreType_
=
ScoreType
::
K2
;
187
checkScoreAprioriCompatibility
();
188
}
189
190
// indicate that we wish to use a Log2Likelihood score
191
INLINE
void
genericBNLearner
::
useScoreLog2Likelihood
() {
192
scoreType_
=
ScoreType
::
LOG2LIKELIHOOD
;
193
checkScoreAprioriCompatibility
();
194
}
195
196
// sets the max indegree
197
INLINE
void
genericBNLearner
::
setMaxIndegree
(
Size
max_indegree
) {
198
constraintIndegree_
.
setMaxIndegree
(
max_indegree
);
199
}
200
201
// indicate that we wish to use 3off2
202
INLINE
void
genericBNLearner
::
use3off2
() {
203
selectedAlgo_
=
AlgoType
::
THREE_OFF_TWO
;
204
algoMiic3off2_
.
set3of2Behaviour
();
205
}
206
207
// indicate that we wish to use 3off2
208
INLINE
void
genericBNLearner
::
useMIIC
() {
209
selectedAlgo_
=
AlgoType
::
MIIC
;
210
algoMiic3off2_
.
setMiicBehaviour
();
211
}
212
213
/// indicate that we wish to use the NML correction for 3off2
214
INLINE
void
genericBNLearner
::
useNMLCorrection
() {
215
kmode3Off2_
=
CorrectedMutualInformation
<>::
KModeTypes
::
NML
;
216
}
217
218
/// indicate that we wish to use the MDL correction for 3off2
219
INLINE
void
genericBNLearner
::
useMDLCorrection
() {
220
kmode3Off2_
=
CorrectedMutualInformation
<>::
KModeTypes
::
MDL
;
221
}
222
223
/// indicate that we wish to use the NoCorr correction for 3off2
224
INLINE
void
genericBNLearner
::
useNoCorrection
() {
225
kmode3Off2_
=
CorrectedMutualInformation
<>::
KModeTypes
::
NoCorr
;
226
}
227
228
/// get the list of arcs hiding latent variables
229
INLINE
const
std
::
vector
<
Arc
>
genericBNLearner
::
latentVariables
()
const
{
230
return
algoMiic3off2_
.
latentVariables
();
231
}
232
233
// indicate that we wish to use a K2 algorithm
234
INLINE
void
genericBNLearner
::
useK2
(
const
Sequence
<
NodeId
>&
order
) {
235
selectedAlgo_
=
AlgoType
::
K2
;
236
algoK2_
.
setOrder
(
order
);
237
}
238
239
// indicate that we wish to use a K2 algorithm
240
INLINE
void
genericBNLearner
::
useK2
(
const
std
::
vector
<
NodeId
>&
order
) {
241
selectedAlgo_
=
AlgoType
::
K2
;
242
algoK2_
.
setOrder
(
order
);
243
}
244
245
// indicate that we wish to use a greedy hill climbing algorithm
246
INLINE
void
genericBNLearner
::
useGreedyHillClimbing
() {
247
selectedAlgo_
=
AlgoType
::
GREEDY_HILL_CLIMBING
;
248
}
249
250
// indicate that we wish to use a local search with tabu list
251
INLINE
void
genericBNLearner
::
useLocalSearchWithTabuList
(
Size
tabu_size
,
Size
nb_decrease
) {
252
selectedAlgo_
=
AlgoType
::
LOCAL_SEARCH_WITH_TABU_LIST
;
253
nbDecreasingChanges_
=
nb_decrease
;
254
constraintTabuList_
.
setTabuListSize
(
tabu_size
);
255
localSearchWithTabuList_
.
setMaxNbDecreasingChanges
(
nb_decrease
);
256
}
257
258
/// use The EM algorithm to learn paramters
259
INLINE
void
genericBNLearner
::
useEM
(
const
double
epsilon
) {
epsilonEM_
=
epsilon
; }
260
261
262
/// Tells whether the table of the score database has missing values.
/// @return the answer of DatabaseTable<>::hasMissingValues() on that table.
INLINE bool genericBNLearner::hasMissingValues() const {
  const DatabaseTable<>& table = scoreDatabase_.databaseTable();
  return table.hasMissingValues();
}
265
266
// assign a set of forbidden edges
267
INLINE
void
genericBNLearner
::
setPossibleEdges
(
const
EdgeSet
&
set
) {
268
constraintPossibleEdges_
.
setEdges
(
set
);
269
}
270
// assign a set of forbidden edges from an UndiGraph
271
INLINE
void
genericBNLearner
::
setPossibleSkeleton
(
const
gum
::
UndiGraph
&
g
) {
272
setPossibleEdges
(
g
.
edges
());
273
}
274
275
// assign a new possible edge
276
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
Edge
&
edge
) {
277
constraintPossibleEdges_
.
addEdge
(
edge
);
278
}
279
280
// remove a forbidden edge
281
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
Edge
&
edge
) {
282
constraintPossibleEdges_
.
eraseEdge
(
edge
);
283
}
284
285
// assign a new forbidden edge
286
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
NodeId
tail
,
const
NodeId
head
) {
287
addPossibleEdge
(
Edge
(
tail
,
head
));
288
}
289
290
// remove a forbidden edge
291
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
NodeId
tail
,
const
NodeId
head
) {
292
erasePossibleEdge
(
Edge
(
tail
,
head
));
293
}
294
295
// assign a new forbidden edge
296
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
std
::
string
&
tail
,
297
const
std
::
string
&
head
) {
298
addPossibleEdge
(
Edge
(
idFromName
(
tail
),
idFromName
(
head
)));
299
}
300
301
// remove a forbidden edge
302
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
std
::
string
&
tail
,
303
const
std
::
string
&
head
) {
304
erasePossibleEdge
(
Edge
(
idFromName
(
tail
),
idFromName
(
head
)));
305
}
306
307
// assign a set of forbidden arcs
308
INLINE
void
genericBNLearner
::
setForbiddenArcs
(
const
ArcSet
&
set
) {
309
constraintForbiddenArcs_
.
setArcs
(
set
);
310
}
311
312
// assign a new forbidden arc
313
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
Arc
&
arc
) {
314
constraintForbiddenArcs_
.
addArc
(
arc
);
315
}
316
317
// remove a forbidden arc
318
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
Arc
&
arc
) {
319
constraintForbiddenArcs_
.
eraseArc
(
arc
);
320
}
321
322
// assign a new forbidden arc
323
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
NodeId
tail
,
const
NodeId
head
) {
324
addForbiddenArc
(
Arc
(
tail
,
head
));
325
}
326
327
// remove a forbidden arc
328
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
NodeId
tail
,
const
NodeId
head
) {
329
eraseForbiddenArc
(
Arc
(
tail
,
head
));
330
}
331
332
// assign a new forbidden arc
333
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
std
::
string
&
tail
,
334
const
std
::
string
&
head
) {
335
addForbiddenArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
336
}
337
338
// remove a forbidden arc
339
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
std
::
string
&
tail
,
340
const
std
::
string
&
head
) {
341
eraseForbiddenArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
342
}
343
344
// assign a set of forbidden arcs
345
INLINE
void
genericBNLearner
::
setMandatoryArcs
(
const
ArcSet
&
set
) {
346
constraintMandatoryArcs_
.
setArcs
(
set
);
347
}
348
349
// assign a new forbidden arc
350
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
Arc
&
arc
) {
351
constraintMandatoryArcs_
.
addArc
(
arc
);
352
}
353
354
// remove a forbidden arc
355
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
Arc
&
arc
) {
356
constraintMandatoryArcs_
.
eraseArc
(
arc
);
357
}
358
359
// assign a new forbidden arc
360
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
std
::
string
&
tail
,
361
const
std
::
string
&
head
) {
362
addMandatoryArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
363
}
364
365
// remove a forbidden arc
366
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
std
::
string
&
tail
,
367
const
std
::
string
&
head
) {
368
eraseMandatoryArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
369
}
370
371
// assign a new forbidden arc
372
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
NodeId
tail
,
const
NodeId
head
) {
373
addMandatoryArc
(
Arc
(
tail
,
head
));
374
}
375
376
// remove a forbidden arc
377
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
NodeId
tail
,
const
NodeId
head
) {
378
eraseMandatoryArc
(
Arc
(
tail
,
head
));
379
}
380
381
// sets a partial order on the nodes
382
INLINE
void
genericBNLearner
::
setSliceOrder
(
const
NodeProperty
<
NodeId
>&
slice_order
) {
383
constraintSliceOrder_
=
StructuralConstraintSliceOrder
(
slice_order
);
384
}
385
386
/// Sets a partial order on the nodes from slices of variable names.
///
/// Every name of the k-th slice (k starting at 0) is converted into a node
/// id with idFromName() and associated with rank k; the resulting property
/// is then handed to the NodeProperty overload of setSliceOrder().
/// @param slices the slices, listed in increasing rank.
INLINE void
   genericBNLearner::setSliceOrder(const std::vector< std::vector< std::string > >& slices) {
  NodeProperty< NodeId > rank_of_node;

  NodeId rank = 0;
  for (auto slice = slices.begin(); slice != slices.end(); ++slice, ++rank) {
    for (auto name = slice->begin(); name != slice->end(); ++name) {
      rank_of_node.insert(idFromName(*name), rank);
    }
  }

  setSliceOrder(rank_of_node);
}
398
399
// sets the apriori weight
400
INLINE
void
genericBNLearner
::
_setAprioriWeight_
(
double
weight
) {
401
if
(
weight
< 0) {
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
) }
402
403
aprioriWeight_
=
weight
;
404
checkScoreAprioriCompatibility
();
405
}
406
407
// use the apriori smoothing
408
INLINE
void
genericBNLearner
::
useNoApriori
() {
409
aprioriType_
=
AprioriType
::
NO_APRIORI
;
410
checkScoreAprioriCompatibility
();
411
}
412
413
// use the apriori smoothing
414
INLINE
void
genericBNLearner
::
useAprioriSmoothing
(
double
weight
) {
415
if
(
weight
< 0) {
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
) }
416
417
aprioriType_
=
AprioriType
::
SMOOTHING
;
418
_setAprioriWeight_
(
weight
);
419
420
checkScoreAprioriCompatibility
();
421
}
422
423
// use the Dirichlet apriori
424
INLINE
void
genericBNLearner
::
useAprioriDirichlet
(
const
std
::
string
&
filename
,
double
weight
) {
425
if
(
weight
< 0) {
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
) }
426
427
aprioriDbname_
=
filename
;
428
aprioriType_
=
AprioriType
::
DIRICHLET_FROM_DATABASE
;
429
_setAprioriWeight_
(
weight
);
430
431
checkScoreAprioriCompatibility
();
432
}
433
434
435
// use the apriori BDeu
436
INLINE
void
genericBNLearner
::
useAprioriBDeu
(
double
weight
) {
437
if
(
weight
< 0) {
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
) }
438
439
aprioriType_
=
AprioriType
::
BDEU
;
440
_setAprioriWeight_
(
weight
);
441
442
checkScoreAprioriCompatibility
();
443
}
444
445
446
// returns the type (as a string) of a given apriori
447
INLINE
const
std
::
string
&
genericBNLearner
::
getAprioriType_
()
const
{
448
switch
(
aprioriType_
) {
449
case
AprioriType
::
NO_APRIORI
:
450
return
AprioriNoApriori
<>::
type
::
type
;
451
452
case
AprioriType
::
SMOOTHING
:
453
return
AprioriSmoothing
<>::
type
::
type
;
454
455
case
AprioriType
::
DIRICHLET_FROM_DATABASE
:
456
return
AprioriDirichletFromDatabase
<>::
type
::
type
;
457
458
case
AprioriType
::
BDEU
:
459
return
AprioriBDeu
<>::
type
::
type
;
460
461
default
:
462
GUM_ERROR
(
OperationNotAllowed
,
463
"genericBNLearner getAprioriType does "
464
"not support yet this apriori"
);
465
}
466
}
467
468
// returns the names of the variables in the database
469
INLINE
const
std
::
vector
<
std
::
string
>&
genericBNLearner
::
names
()
const
{
470
return
scoreDatabase_
.
names
();
471
}
472
473
// returns the modalities of the variables in the database
474
INLINE
const
std
::
vector
<
std
::
size_t
>&
genericBNLearner
::
domainSizes
()
const
{
475
return
scoreDatabase_
.
domainSizes
();
476
}
477
478
// returns the modalities of a variable in the database
479
INLINE
Size
genericBNLearner
::
domainSize
(
NodeId
var
)
const
{
480
return
scoreDatabase_
.
domainSizes
()[
var
];
481
}
482
// returns the modalities of a variables in the database
483
INLINE
Size
genericBNLearner
::
domainSize
(
const
std
::
string
&
var
)
const
{
484
return
scoreDatabase_
.
domainSizes
()[
idFromName
(
var
)];
485
}
486
487
/// returns the current database rows' ranges used for learning
488
INLINE
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> >&
489
genericBNLearner
::
databaseRanges
()
const
{
490
return
ranges_
;
491
}
492
493
/// reset the ranges to the one range corresponding to the whole database
494
INLINE
void
genericBNLearner
::
clearDatabaseRanges
() {
ranges_
.
clear
(); }
495
496
/// returns the database used by the BNLearner
497
INLINE
const
DatabaseTable
<>&
genericBNLearner
::
database
()
const
{
498
return
scoreDatabase_
.
databaseTable
();
499
}
500
501
/// Gives the number of columns, i.e. the number of entries of
/// scoreDatabase_.domainSizes().
INLINE Size genericBNLearner::nbCols() const {
  const std::vector< std::size_t >& sizes = scoreDatabase_.domainSizes();
  return sizes.size();
}
502
503
/// Gives the number of rows, taken as the size() of the table of the
/// score database.
INLINE Size genericBNLearner::nbRows() const {
  const DatabaseTable<>& table = scoreDatabase_.databaseTable();
  return table.size();
}
504
}
/* namespace learning */
505
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31