aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
genericBNLearner_tpl.h
Go to the documentation of this file.
/**
 *
 *  Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *  info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */
20
21
22
#
include
<
algorithm
>
23
24
#
include
<
agrum
/
BN
/
learning
/
BNLearnUtils
/
genericBNLearner
.
h
>
25
26
namespace
gum
{
27
28
namespace
learning
{
29
30
template
<
typename
GUM_SCALAR
>
31
genericBNLearner
::
Database
::
Database
(
32
const
std
::
string
&
filename
,
33
const
BayesNet
<
GUM_SCALAR
>&
bn
,
34
const
std
::
vector
<
std
::
string
>&
missing_symbols
) {
35
// assign to each column name in the database its position
36
genericBNLearner
::
checkFileName__
(
filename
);
37
DBInitializerFromCSV
<>
initializer
(
filename
);
38
const
auto
&
xvar_names
=
initializer
.
variableNames
();
39
std
::
size_t
nb_vars
=
xvar_names
.
size
();
40
HashTable
<
std
::
string
,
std
::
size_t
>
var_names
(
nb_vars
);
41
for
(
std
::
size_t
i
=
std
::
size_t
(0);
i
<
nb_vars
; ++
i
)
42
var_names
.
insert
(
xvar_names
[
i
],
i
);
43
44
// we use the bn to insert the translators into the database table
45
std
::
vector
<
NodeId
>
nodes
;
46
nodes
.
reserve
(
bn
.
dag
().
sizeNodes
());
47
for
(
const
auto
node
:
bn
.
dag
())
48
nodes
.
push_back
(
node
);
49
std
::
sort
(
nodes
.
begin
(),
nodes
.
end
());
50
std
::
size_t
i
=
std
::
size_t
(0);
51
for
(
auto
node
:
nodes
) {
52
const
Variable
&
var
=
bn
.
variable
(
node
);
53
try
{
54
database__
.
insertTranslator
(
var
,
var_names
[
var
.
name
()],
missing_symbols
);
55
}
catch
(
NotFound
&) {
56
GUM_ERROR
(
MissingVariableInDatabase
,
57
"Variable '"
<<
var
.
name
() <<
"' is missing"
);
58
}
59
nodeId2cols__
.
insert
(
NodeId
(
node
),
i
++);
60
}
61
62
// fill the database
63
initializer
.
fillDatabase
(
database__
);
64
65
// get the domain sizes of the variables
66
for
(
auto
dom
:
database__
.
domainSizes
())
67
domain_sizes__
.
push_back
(
dom
);
68
69
// create the parser
70
parser__
71
=
new
DBRowGeneratorParser
<>(
database__
.
handler
(),
DBRowGeneratorSet
<>());
72
}
73
74
75
template
<
typename
GUM_SCALAR
>
76
BayesNet
<
GUM_SCALAR
>
genericBNLearner
::
Database
::
BNVars__
()
const
{
77
BayesNet
<
GUM_SCALAR
>
bn
;
78
const
std
::
size_t
nb_vars
=
database__
.
nbVariables
();
79
for
(
std
::
size_t
i
= 0;
i
<
nb_vars
; ++
i
) {
80
const
DiscreteVariable
&
var
81
=
dynamic_cast
<
const
DiscreteVariable
& >(
database__
.
variable
(
i
));
82
bn
.
add
(
var
);
83
}
84
return
bn
;
85
}
86
87
88
template
<
typename
GUM_SCALAR
>
89
genericBNLearner
::
genericBNLearner
(
90
const
std
::
string
&
filename
,
91
const
gum
::
BayesNet
<
GUM_SCALAR
>&
bn
,
92
const
std
::
vector
<
std
::
string
>&
missing_symbols
) :
93
score_database__
(
filename
,
bn
,
missing_symbols
) {
94
no_apriori__
=
new
AprioriNoApriori
<>(
score_database__
.
databaseTable
());
95
GUM_CONSTRUCTOR
(
genericBNLearner
);
96
}
97
98
99
/// use a new set of database rows' ranges to perform learning
100
template
<
template
<
typename
>
class
XALLOC
>
101
void
genericBNLearner
::
useDatabaseRanges
(
102
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
103
XALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
104
new_ranges
) {
105
// use a score to detect whether the ranges are ok
106
ScoreLog2Likelihood
<>
score
(
score_database__
.
parser
(), *
no_apriori__
);
107
score
.
setRanges
(
new_ranges
);
108
ranges__
=
score
.
ranges
();
109
}
110
}
// namespace learning
111
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31