aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
genericBNLearner_inl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/** @file
23
* @brief A pack of learning algorithms that can easily be used
24
*
25
* The pack currently contains K2, GreedyHillClimbing, MIIC, 3off2 and
* LocalSearchWithTabuList
27
*
28
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29
*/
30
31
// to help IDE parser
32
#
include
<
agrum
/
BN
/
learning
/
BNLearnUtils
/
genericBNLearner
.
h
>
33
#
include
<
agrum
/
tools
/
graphs
/
undiGraph
.
h
>
34
35
namespace
gum
{
36
37
namespace
learning
{
38
39
// returns the row filter
40
INLINE DBRowGeneratorParser<>&
genericBNLearner
::
Database
::
parser
() {
return
*
_parser_
; }
41
42
// returns the modalities of the variables
43
INLINE
const
std
::
vector
<
std
::
size_t
>&
genericBNLearner
::
Database
::
domainSizes
()
const
{
44
return
_domain_sizes_
;
45
}
46
47
// returns the names of the variables in the database
48
INLINE
const
std
::
vector
<
std
::
string
>&
genericBNLearner
::
Database
::
names
()
const
{
49
return
_database_
.
variableNames
();
50
}
51
52
/// assign new weight to the rows of the learning database
53
INLINE
void
genericBNLearner
::
Database
::
setDatabaseWeight
(
const
double
new_weight
) {
54
if
(
_database_
.
nbRows
() ==
std
::
size_t
(0))
return
;
55
const
double
weight
=
new_weight
/
double
(
_database_
.
nbRows
());
56
_database_
.
setAllRowsWeight
(
weight
);
57
}
58
59
// returns the node id corresponding to a variable name
60
INLINE
NodeId
genericBNLearner
::
Database
::
idFromName
(
const
std
::
string
&
var_name
)
const
{
61
try
{
62
const
auto
cols
=
_database_
.
columnsFromVariableName
(
var_name
);
63
return
_nodeId2cols_
.
first
(
cols
[0]);
64
}
catch
(...) {
65
GUM_ERROR
(
MissingVariableInDatabase
,
66
"Variable "
<<
var_name
<<
" could not be found in the database"
);
67
}
68
}
69
70
71
// returns the variable name corresponding to a given node id
72
INLINE
const
std
::
string
&
genericBNLearner
::
Database
::
nameFromId
(
NodeId
id
)
const
{
73
try
{
74
return
_database_
.
variableName
(
_nodeId2cols_
.
second
(
id
));
75
}
catch
(...) {
76
GUM_ERROR
(
MissingVariableInDatabase
,
77
"Variable of Id "
<<
id
<<
" could not be found in the database"
);
78
}
79
}
80
81
82
/// returns the internal database table
83
INLINE
const
DatabaseTable
<>&
genericBNLearner
::
Database
::
databaseTable
()
const
{
84
return
_database_
;
85
}
86
87
88
/// returns the set of missing symbols taken into account
89
INLINE
const
std
::
vector
<
std
::
string
>&
genericBNLearner
::
Database
::
missingSymbols
()
const
{
90
return
_database_
.
missingSymbols
();
91
}
92
93
94
/// returns the mapping between node ids and their columns in the database
95
INLINE
const
Bijection
<
NodeId
,
std
::
size_t
>&
96
genericBNLearner
::
Database
::
nodeId2Columns
()
const
{
97
return
_nodeId2cols_
;
98
}
99
100
101
/// returns the number of records in the database
102
INLINE
std
::
size_t
genericBNLearner
::
Database
::
nbRows
()
const
{
return
_database_
.
nbRows
(); }
103
104
105
/// returns the number of records in the database
106
INLINE
std
::
size_t
genericBNLearner
::
Database
::
size
()
const
{
return
_database_
.
size
(); }
107
108
109
/// sets the weight of the ith record
110
INLINE
void
genericBNLearner
::
Database
::
setWeight
(
const
std
::
size_t
i
,
const
double
weight
) {
111
_database_
.
setWeight
(
i
,
weight
);
112
}
113
114
115
/// returns the weight of the ith record
116
INLINE
double
genericBNLearner
::
Database
::
weight
(
const
std
::
size_t
i
)
const
{
117
return
_database_
.
weight
(
i
);
118
}
119
120
121
/// returns the weight of the whole database
122
INLINE
double
genericBNLearner
::
Database
::
weight
()
const
{
return
_database_
.
weight
(); }
123
124
125
// ===========================================================================
126
127
// returns the node id corresponding to a variable name
128
INLINE
NodeId
genericBNLearner
::
idFromName
(
const
std
::
string
&
var_name
)
const
{
129
return
scoreDatabase_
.
idFromName
(
var_name
);
130
}
131
132
// returns the variable name corresponding to a given node id
133
INLINE
const
std
::
string
&
genericBNLearner
::
nameFromId
(
NodeId
id
)
const
{
134
return
scoreDatabase_
.
nameFromId
(
id
);
135
}
136
137
/// assign new weight to the rows of the learning database
138
INLINE
void
genericBNLearner
::
setDatabaseWeight
(
const
double
new_weight
) {
139
scoreDatabase_
.
setDatabaseWeight
(
new_weight
);
140
}
141
142
/// assign new weight to the ith row of the learning database
143
INLINE
void
genericBNLearner
::
setRecordWeight
(
const
std
::
size_t
i
,
const
double
new_weight
) {
144
scoreDatabase_
.
setWeight
(
i
,
new_weight
);
145
}
146
147
/// returns the weight of the ith record
148
INLINE
double
genericBNLearner
::
recordWeight
(
const
std
::
size_t
i
)
const
{
149
return
scoreDatabase_
.
weight
(
i
);
150
}
151
152
/// returns the weight of the whole database
153
INLINE
double
genericBNLearner
::
databaseWeight
()
const
{
return
scoreDatabase_
.
weight
(); }
154
155
// sets an initial DAG structure
156
INLINE
void
genericBNLearner
::
setInitialDAG
(
const
DAG
&
dag
) {
initialDag_
=
dag
; }
157
158
// indicate that we wish to use an AIC score
159
INLINE
void
genericBNLearner
::
useScoreAIC
() {
160
scoreType_
=
ScoreType
::
AIC
;
161
checkScoreAprioriCompatibility
();
162
}
163
164
// indicate that we wish to use a BD score
165
INLINE
void
genericBNLearner
::
useScoreBD
() {
166
scoreType_
=
ScoreType
::
BD
;
167
checkScoreAprioriCompatibility
();
168
}
169
170
// indicate that we wish to use a BDeu score
171
INLINE
void
genericBNLearner
::
useScoreBDeu
() {
172
scoreType_
=
ScoreType
::
BDeu
;
173
checkScoreAprioriCompatibility
();
174
}
175
176
// indicate that we wish to use a BIC score
177
INLINE
void
genericBNLearner
::
useScoreBIC
() {
178
scoreType_
=
ScoreType
::
BIC
;
179
checkScoreAprioriCompatibility
();
180
}
181
182
// indicate that we wish to use a K2 score
183
INLINE
void
genericBNLearner
::
useScoreK2
() {
184
scoreType_
=
ScoreType
::
K2
;
185
checkScoreAprioriCompatibility
();
186
}
187
188
// indicate that we wish to use a Log2Likelihood score
189
INLINE
void
genericBNLearner
::
useScoreLog2Likelihood
() {
190
scoreType_
=
ScoreType
::
LOG2LIKELIHOOD
;
191
checkScoreAprioriCompatibility
();
192
}
193
194
// sets the max indegree
195
INLINE
void
genericBNLearner
::
setMaxIndegree
(
Size
max_indegree
) {
196
constraintIndegree_
.
setMaxIndegree
(
max_indegree
);
197
}
198
199
// indicate that we wish to use 3off2
200
INLINE
void
genericBNLearner
::
use3off2
() {
201
selectedAlgo_
=
AlgoType
::
MIIC_THREE_OFF_TWO
;
202
algoMiic3off2_
.
set3of2Behaviour
();
203
}
204
205
// indicate that we wish to use 3off2
206
INLINE
void
genericBNLearner
::
useMIIC
() {
207
selectedAlgo_
=
AlgoType
::
MIIC_THREE_OFF_TWO
;
208
algoMiic3off2_
.
setMiicBehaviour
();
209
}
210
211
/// indicate that we wish to use the NML correction for 3off2
212
INLINE
void
genericBNLearner
::
useNMLCorrection
() {
213
kmode3Off2_
=
CorrectedMutualInformation
<>::
KModeTypes
::
NML
;
214
}
215
216
/// indicate that we wish to use the MDL correction for 3off2
217
INLINE
void
genericBNLearner
::
useMDLCorrection
() {
218
kmode3Off2_
=
CorrectedMutualInformation
<>::
KModeTypes
::
MDL
;
219
}
220
221
/// indicate that we wish to use the NoCorr correction for 3off2
222
INLINE
void
genericBNLearner
::
useNoCorrection
() {
223
kmode3Off2_
=
CorrectedMutualInformation
<>::
KModeTypes
::
NoCorr
;
224
}
225
226
/// get the list of arcs hiding latent variables
227
INLINE
const
std
::
vector
<
Arc
>
genericBNLearner
::
latentVariables
()
const
{
228
return
algoMiic3off2_
.
latentVariables
();
229
}
230
231
// indicate that we wish to use a K2 algorithm
232
INLINE
void
genericBNLearner
::
useK2
(
const
Sequence
<
NodeId
>&
order
) {
233
selectedAlgo_
=
AlgoType
::
K2
;
234
algoK2_
.
setOrder
(
order
);
235
}
236
237
// indicate that we wish to use a K2 algorithm
238
INLINE
void
genericBNLearner
::
useK2
(
const
std
::
vector
<
NodeId
>&
order
) {
239
selectedAlgo_
=
AlgoType
::
K2
;
240
algoK2_
.
setOrder
(
order
);
241
}
242
243
// indicate that we wish to use a greedy hill climbing algorithm
244
INLINE
void
genericBNLearner
::
useGreedyHillClimbing
() {
245
selectedAlgo_
=
AlgoType
::
GREEDY_HILL_CLIMBING
;
246
}
247
248
// indicate that we wish to use a local search with tabu list
249
INLINE
void
genericBNLearner
::
useLocalSearchWithTabuList
(
Size
tabu_size
,
Size
nb_decrease
) {
250
selectedAlgo_
=
AlgoType
::
LOCAL_SEARCH_WITH_TABU_LIST
;
251
constraintTabuList_
.
setTabuListSize
(
tabu_size
);
252
localSearchWithTabuList_
.
setMaxNbDecreasingChanges
(
nb_decrease
);
253
}
254
255
/// use The EM algorithm to learn paramters
256
INLINE
void
genericBNLearner
::
useEM
(
const
double
epsilon
) {
epsilonEM_
=
epsilon
; }
257
258
259
/// @brief indicates whether the learning database contains missing values
INLINE bool genericBNLearner::hasMissingValues() const {
  const auto& table = scoreDatabase_.databaseTable();
  return table.hasMissingValues();
}
262
263
// assign a set of forbidden edges
264
INLINE
void
genericBNLearner
::
setPossibleEdges
(
const
EdgeSet
&
set
) {
265
constraintPossibleEdges_
.
setEdges
(
set
);
266
}
267
// assign a set of forbidden edges from an UndiGraph
268
INLINE
void
genericBNLearner
::
setPossibleSkeleton
(
const
gum
::
UndiGraph
&
g
) {
269
setPossibleEdges
(
g
.
edges
());
270
}
271
272
// assign a new possible edge
273
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
Edge
&
edge
) {
274
constraintPossibleEdges_
.
addEdge
(
edge
);
275
}
276
277
// remove a forbidden edge
278
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
Edge
&
edge
) {
279
constraintPossibleEdges_
.
eraseEdge
(
edge
);
280
}
281
282
// assign a new forbidden edge
283
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
NodeId
tail
,
const
NodeId
head
) {
284
addPossibleEdge
(
Edge
(
tail
,
head
));
285
}
286
287
// remove a forbidden edge
288
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
NodeId
tail
,
const
NodeId
head
) {
289
erasePossibleEdge
(
Edge
(
tail
,
head
));
290
}
291
292
// assign a new forbidden edge
293
INLINE
void
genericBNLearner
::
addPossibleEdge
(
const
std
::
string
&
tail
,
294
const
std
::
string
&
head
) {
295
addPossibleEdge
(
Edge
(
idFromName
(
tail
),
idFromName
(
head
)));
296
}
297
298
// remove a forbidden edge
299
INLINE
void
genericBNLearner
::
erasePossibleEdge
(
const
std
::
string
&
tail
,
300
const
std
::
string
&
head
) {
301
erasePossibleEdge
(
Edge
(
idFromName
(
tail
),
idFromName
(
head
)));
302
}
303
304
// assign a set of forbidden arcs
305
INLINE
void
genericBNLearner
::
setForbiddenArcs
(
const
ArcSet
&
set
) {
306
constraintForbiddenArcs_
.
setArcs
(
set
);
307
}
308
309
// assign a new forbidden arc
310
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
Arc
&
arc
) {
311
constraintForbiddenArcs_
.
addArc
(
arc
);
312
}
313
314
// remove a forbidden arc
315
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
Arc
&
arc
) {
316
constraintForbiddenArcs_
.
eraseArc
(
arc
);
317
}
318
319
// assign a new forbidden arc
320
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
NodeId
tail
,
const
NodeId
head
) {
321
addForbiddenArc
(
Arc
(
tail
,
head
));
322
}
323
324
// remove a forbidden arc
325
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
NodeId
tail
,
const
NodeId
head
) {
326
eraseForbiddenArc
(
Arc
(
tail
,
head
));
327
}
328
329
// assign a new forbidden arc
330
INLINE
void
genericBNLearner
::
addForbiddenArc
(
const
std
::
string
&
tail
,
331
const
std
::
string
&
head
) {
332
addForbiddenArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
333
}
334
335
// remove a forbidden arc
336
INLINE
void
genericBNLearner
::
eraseForbiddenArc
(
const
std
::
string
&
tail
,
337
const
std
::
string
&
head
) {
338
eraseForbiddenArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
339
}
340
341
// assign a set of forbidden arcs
342
INLINE
void
genericBNLearner
::
setMandatoryArcs
(
const
ArcSet
&
set
) {
343
constraintMandatoryArcs_
.
setArcs
(
set
);
344
}
345
346
// assign a new forbidden arc
347
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
Arc
&
arc
) {
348
constraintMandatoryArcs_
.
addArc
(
arc
);
349
}
350
351
// remove a forbidden arc
352
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
Arc
&
arc
) {
353
constraintMandatoryArcs_
.
eraseArc
(
arc
);
354
}
355
356
// assign a new forbidden arc
357
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
std
::
string
&
tail
,
358
const
std
::
string
&
head
) {
359
addMandatoryArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
360
}
361
362
// remove a forbidden arc
363
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
std
::
string
&
tail
,
364
const
std
::
string
&
head
) {
365
eraseMandatoryArc
(
Arc
(
idFromName
(
tail
),
idFromName
(
head
)));
366
}
367
368
// assign a new forbidden arc
369
INLINE
void
genericBNLearner
::
addMandatoryArc
(
const
NodeId
tail
,
const
NodeId
head
) {
370
addMandatoryArc
(
Arc
(
tail
,
head
));
371
}
372
373
// remove a forbidden arc
374
INLINE
void
genericBNLearner
::
eraseMandatoryArc
(
const
NodeId
tail
,
const
NodeId
head
) {
375
eraseMandatoryArc
(
Arc
(
tail
,
head
));
376
}
377
378
// sets a partial order on the nodes
379
INLINE
void
genericBNLearner
::
setSliceOrder
(
const
NodeProperty
<
NodeId
>&
slice_order
) {
380
constraintSliceOrder_
=
StructuralConstraintSliceOrder
(
slice_order
);
381
}
382
383
/// @brief sets a partial order on the nodes from slices of variable names
///
/// Every variable named in slices[k] is given rank k, then the resulting
/// node-id -> rank mapping is passed to setSliceOrder(NodeProperty).
/// @throws MissingVariableInDatabase if a name is unknown (from idFromName)
INLINE void
   genericBNLearner::setSliceOrder(const std::vector< std::vector< std::string > >& slices) {
  NodeProperty< NodeId > ranks;
  for (std::size_t k = 0; k < slices.size(); ++k) {
    const NodeId slice_rank = NodeId(k);
    for (const auto& var_name: slices[k])
      ranks.insert(idFromName(var_name), slice_rank);
  }
  setSliceOrder(ranks);
}
395
396
// sets the apriori weight
397
INLINE
void
genericBNLearner
::
_setAprioriWeight_
(
double
weight
) {
398
if
(
weight
< 0) {
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
) }
399
400
aprioriWeight_
=
weight
;
401
checkScoreAprioriCompatibility
();
402
}
403
404
// use the apriori smoothing
405
INLINE
void
genericBNLearner
::
useNoApriori
() {
406
aprioriType_
=
AprioriType
::
NO_APRIORI
;
407
checkScoreAprioriCompatibility
();
408
}
409
410
// use the apriori smoothing
411
INLINE
void
genericBNLearner
::
useAprioriSmoothing
(
double
weight
) {
412
if
(
weight
< 0) {
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
) }
413
414
aprioriType_
=
AprioriType
::
SMOOTHING
;
415
_setAprioriWeight_
(
weight
);
416
417
checkScoreAprioriCompatibility
();
418
}
419
420
// use the Dirichlet apriori
421
INLINE
void
genericBNLearner
::
useAprioriDirichlet
(
const
std
::
string
&
filename
,
double
weight
) {
422
if
(
weight
< 0) {
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
) }
423
424
aprioriDbname_
=
filename
;
425
aprioriType_
=
AprioriType
::
DIRICHLET_FROM_DATABASE
;
426
_setAprioriWeight_
(
weight
);
427
428
checkScoreAprioriCompatibility
();
429
}
430
431
432
// use the apriori BDeu
433
INLINE
void
genericBNLearner
::
useAprioriBDeu
(
double
weight
) {
434
if
(
weight
< 0) {
GUM_ERROR
(
OutOfBounds
,
"the weight of the apriori must be positive"
) }
435
436
aprioriType_
=
AprioriType
::
BDEU
;
437
_setAprioriWeight_
(
weight
);
438
439
checkScoreAprioriCompatibility
();
440
}
441
442
443
// returns the type (as a string) of a given apriori
444
INLINE
const
std
::
string
&
genericBNLearner
::
getAprioriType_
()
const
{
445
switch
(
aprioriType_
) {
446
case
AprioriType
::
NO_APRIORI
:
447
return
AprioriNoApriori
<>::
type
::
type
;
448
449
case
AprioriType
::
SMOOTHING
:
450
return
AprioriSmoothing
<>::
type
::
type
;
451
452
case
AprioriType
::
DIRICHLET_FROM_DATABASE
:
453
return
AprioriDirichletFromDatabase
<>::
type
::
type
;
454
455
case
AprioriType
::
BDEU
:
456
return
AprioriBDeu
<>::
type
::
type
;
457
458
default
:
459
GUM_ERROR
(
OperationNotAllowed
,
460
"genericBNLearner getAprioriType does "
461
"not support yet this apriori"
);
462
}
463
}
464
465
// returns the names of the variables in the database
466
INLINE
const
std
::
vector
<
std
::
string
>&
genericBNLearner
::
names
()
const
{
467
return
scoreDatabase_
.
names
();
468
}
469
470
// returns the modalities of the variables in the database
471
INLINE
const
std
::
vector
<
std
::
size_t
>&
genericBNLearner
::
domainSizes
()
const
{
472
return
scoreDatabase_
.
domainSizes
();
473
}
474
475
// returns the modalities of a variable in the database
476
INLINE
Size
genericBNLearner
::
domainSize
(
NodeId
var
)
const
{
477
return
scoreDatabase_
.
domainSizes
()[
var
];
478
}
479
// returns the modalities of a variables in the database
480
INLINE
Size
genericBNLearner
::
domainSize
(
const
std
::
string
&
var
)
const
{
481
return
scoreDatabase_
.
domainSizes
()[
idFromName
(
var
)];
482
}
483
484
/// returns the current database rows' ranges used for learning
485
INLINE
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> >&
486
genericBNLearner
::
databaseRanges
()
const
{
487
return
ranges_
;
488
}
489
490
/// reset the ranges to the one range corresponding to the whole database
491
INLINE
void
genericBNLearner
::
clearDatabaseRanges
() {
ranges_
.
clear
(); }
492
493
/// returns the database used by the BNLearner
494
INLINE
const
DatabaseTable
<>&
genericBNLearner
::
database
()
const
{
495
return
scoreDatabase_
.
databaseTable
();
496
}
497
498
/// @brief returns the number of columns (variables) of the learning database,
/// computed as the number of stored domain sizes
INLINE Size genericBNLearner::nbCols() const {
  const auto& sizes = scoreDatabase_.domainSizes();
  return sizes.size();
}
499
500
/// @brief returns the number of rows of the learning database, computed as
/// the size of its database table
INLINE Size genericBNLearner::nbRows() const {
  const auto& table = scoreDatabase_.databaseTable();
  return table.size();
}
501
}
/* namespace learning */
502
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31