aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
genericBNLearner_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
20
21
22
#
include
<
algorithm
>
23
24
#
include
<
agrum
/
BN
/
learning
/
BNLearnUtils
/
genericBNLearner
.
h
>
25
26
namespace
gum
{
27
28
namespace
learning
{
29
30
template
<
typename
GUM_SCALAR
>
31
genericBNLearner
::
Database
::
Database
(
const
std
::
string
&
filename
,
32
const
BayesNet
<
GUM_SCALAR
>&
bn
,
33
const
std
::
vector
<
std
::
string
>&
missing_symbols
) {
34
// assign to each column name in the database its position
35
genericBNLearner
::
checkFileName_
(
filename
);
36
DBInitializerFromCSV
<>
initializer
(
filename
);
37
const
auto
&
xvar_names
=
initializer
.
variableNames
();
38
std
::
size_t
nb_vars
=
xvar_names
.
size
();
39
HashTable
<
std
::
string
,
std
::
size_t
>
var_names
(
nb_vars
);
40
for
(
std
::
size_t
i
=
std
::
size_t
(0);
i
<
nb_vars
; ++
i
)
41
var_names
.
insert
(
xvar_names
[
i
],
i
);
42
43
// we use the bn to insert the translators into the database table
44
std
::
vector
<
NodeId
>
nodes
;
45
nodes
.
reserve
(
bn
.
dag
().
sizeNodes
());
46
for
(
const
auto
node
:
bn
.
dag
())
47
nodes
.
push_back
(
node
);
48
std
::
sort
(
nodes
.
begin
(),
nodes
.
end
());
49
std
::
size_t
i
=
std
::
size_t
(0);
50
for
(
auto
node
:
nodes
) {
51
const
Variable
&
var
=
bn
.
variable
(
node
);
52
try
{
53
_database_
.
insertTranslator
(
var
,
var_names
[
var
.
name
()],
missing_symbols
);
54
}
catch
(
NotFound
&) {
55
GUM_ERROR
(
MissingVariableInDatabase
,
"Variable '"
<<
var
.
name
() <<
"' is missing"
)
56
}
57
_nodeId2cols_
.
insert
(
NodeId
(
node
),
i
++);
58
}
59
60
// fill the database
61
initializer
.
fillDatabase
(
_database_
);
62
63
// get the domain sizes of the variables
64
for
(
auto
dom
:
_database_
.
domainSizes
())
65
_domain_sizes_
.
push_back
(
dom
);
66
67
// create the parser
68
_parser_
=
new
DBRowGeneratorParser
<>(
_database_
.
handler
(),
DBRowGeneratorSet
<>());
69
}
70
71
72
template
<
typename
GUM_SCALAR
>
73
BayesNet
<
GUM_SCALAR
>
genericBNLearner
::
Database
::
_BNVars_
()
const
{
74
BayesNet
<
GUM_SCALAR
>
bn
;
75
const
std
::
size_t
nb_vars
=
_database_
.
nbVariables
();
76
for
(
std
::
size_t
i
= 0;
i
<
nb_vars
; ++
i
) {
77
const
DiscreteVariable
&
var
78
=
dynamic_cast
<
const
DiscreteVariable
& >(
_database_
.
variable
(
i
));
79
bn
.
add
(
var
);
80
}
81
return
bn
;
82
}
83
84
85
template
<
typename
GUM_SCALAR
>
86
genericBNLearner
::
genericBNLearner
(
const
std
::
string
&
filename
,
87
const
gum
::
BayesNet
<
GUM_SCALAR
>&
bn
,
88
const
std
::
vector
<
std
::
string
>&
missing_symbols
) :
89
scoreDatabase_
(
filename
,
bn
,
missing_symbols
) {
90
noApriori_
=
new
AprioriNoApriori
<>(
scoreDatabase_
.
databaseTable
());
91
GUM_CONSTRUCTOR
(
genericBNLearner
);
92
}
93
94
95
/// use a new set of database rows' ranges to perform learning
96
template
<
template
<
typename
>
class
XALLOC
>
97
void
genericBNLearner
::
useDatabaseRanges
(
98
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
99
XALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
new_ranges
) {
100
// use a score to detect whether the ranges are ok
101
ScoreLog2Likelihood
<>
score
(
scoreDatabase_
.
parser
(), *
noApriori_
);
102
score
.
setRanges
(
new_ranges
);
103
ranges_
=
score
.
ranges
();
104
}
105
}
// namespace learning
106
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31