32 #ifndef DOXYGEN_SHOULD_SKIP_THIS 35 # include <agrum/BN/learning/BNLearner.h> 37 # include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h> 42 template <
typename GUM_SCALAR >
43 BNLearner< GUM_SCALAR >::BNLearner(
44 const std::string& filename,
45 const std::vector< std::string >& missing_symbols) :
46 genericBNLearner(filename, missing_symbols) {
47 GUM_CONSTRUCTOR(BNLearner);
50 template <
typename GUM_SCALAR >
51 BNLearner< GUM_SCALAR >::BNLearner(
const DatabaseTable<>& db) :
52 genericBNLearner(db) {
53 GUM_CONSTRUCTOR(BNLearner);
56 template <
typename GUM_SCALAR >
57 BNLearner< GUM_SCALAR >::BNLearner(
58 const std::string& filename,
59 const gum::BayesNet< GUM_SCALAR >& bn,
60 const std::vector< std::string >& missing_symbols) :
61 genericBNLearner(filename, bn, missing_symbols) {
62 GUM_CONSTRUCTOR(BNLearner);
66 template <
typename GUM_SCALAR >
67 BNLearner< GUM_SCALAR >::BNLearner(
const BNLearner< GUM_SCALAR >& src) :
68 genericBNLearner(src) {
69 GUM_CONSTRUCTOR(BNLearner);
73 template <
typename GUM_SCALAR >
74 BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) :
75 genericBNLearner(src) {
76 GUM_CONSTRUCTOR(BNLearner);
80 template <
typename GUM_SCALAR >
81 BNLearner< GUM_SCALAR >::~BNLearner() {
82 GUM_DESTRUCTOR(BNLearner);
93 template <
typename GUM_SCALAR >
94 BNLearner< GUM_SCALAR >&
95 BNLearner< GUM_SCALAR >::operator=(
const BNLearner< GUM_SCALAR >& src) {
96 genericBNLearner::operator=(src);
101 template <
typename GUM_SCALAR >
102 BNLearner< GUM_SCALAR >&
103 BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) {
104 genericBNLearner::operator=(std::move(src));
// ---------------------------------------------------------------------------
// learnBN(): builds a BayesNet from the score database.
//  - Runs checkScoreAprioriCompatibility(); a non-empty message is printed to
//    std::cout with an "[aGrUM notification] " prefix (it is not thrown).
//  - Creates a ParamEstimator over score_database__.parser(). The second
//    argument `true` presumably means "take the score into account",
//    mirroring learnParameters' take_into_account_score — confirm.
//  - Returns the BN produced by Dag2BN__.createBN<GUM_SCALAR>().
// NOTE(review): this listing has gaps in its embedded line numbers (111,
// 115-118, 121, 123-126). Some statements, including the structure argument
// of the final createBN call (presumably the learned DAG), are not visible.
// ---------------------------------------------------------------------------
109 template <
typename GUM_SCALAR >
110 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
// Any incompatibility diagnostic is only reported on std::cout.
112 auto notification = checkScoreAprioriCompatibility();
113 if (notification !=
"") {
114 std::cout <<
"[aGrUM notification] " << notification << std::endl;
// The estimator is owned by a unique_ptr so it is released after createBN.
119 std::unique_ptr< ParamEstimator<> > param_estimator(
120 createParamEstimator__(score_database__.parser(),
true));
122 return Dag2BN__.createBN< GUM_SCALAR >(*(param_estimator.get()),
// ---------------------------------------------------------------------------
// learnParameters(dag, take_into_account_score): estimates the parameters of
// a BayesNet whose structure is `dag`, using the score database.
//  @param dag                     structure to parameterize; its node ids must
//                                 be < score_database__.names().size()
//  @param take_into_account_score forwarded to createParamEstimator__()
//  @return a default-constructed BayesNet if dag.size() == 0, otherwise the
//          BN built by Dag2BN__.createBN<GUM_SCALAR>()
//  @throws MissingVariableInDatabase (via GUM_ERROR) if some node id of `dag`
//          is >= score_database__.names().size()
//  @throws MissingValueInDatabase (via GUM_ERROR) when EMepsilon__ == 0.0 and
//          the score database has missing values. It is also thrown when the
//          apriori is DIRICHLET_FROM_DATABASE and its database has missing
//          values.
// There are two paths. With EMepsilon__ == 0.0, estimation runs directly on
// the data. Otherwise, EM runs with a bootstrap estimator plus an EM
// estimator.
// NOTE(review): this listing has gaps in its embedded line numbers (e.g. 138,
// 150-151, 153-158, 193, 201, and everything after 217). Some loop bodies,
// constructor arguments and the trailing arguments of the final createBN
// call are therefore not visible here.
// ---------------------------------------------------------------------------
127 template <
typename GUM_SCALAR >
128 BayesNet< GUM_SCALAR >
129 BNLearner< GUM_SCALAR >::learnParameters(
const DAG& dag,
130 bool take_into_account_score) {
132 if (dag.size() == 0)
return BayesNet< GUM_SCALAR >();
// Gather the DAG's node ids and sort them, so that ids.back() is the largest
// id. The loop body is not visible; presumably it is ids.push_back(node).
135 std::vector< NodeId > ids;
136 ids.reserve(dag.sizeNodes());
137 for (
const auto node: dag)
139 std::sort(ids.begin(), ids.end());
141 if (ids.back() >= score_database__.names().size()) {
// Build an error message listing every id absent from the database
// ("nodeID" becomes "nodeIDs" when there are several).
142 std::stringstream str;
143 str <<
"Learning parameters corresponding to the dag is impossible " 144 <<
"because the database does not contain the following nodeID";
145 std::vector< NodeId > bad_ids;
146 for (
const auto node: ids) {
147 if (node >= score_database__.names().size()) bad_ids.push_back(node);
149 if (bad_ids.size() > 1) str <<
's';
152 for (
const auto node: bad_ids) {
159 GUM_ERROR(MissingVariableInDatabase, str.str());
// EMepsilon__ == 0.0 means EM is not used, so the data must be complete.
165 if (EMepsilon__ == 0.0) {
167 if (score_database__.databaseTable().hasMissingValues()
168 || ((apriori_database__ !=
nullptr)
169 && (apriori_type__ == AprioriType::DIRICHLET_FROM_DATABASE)
170 && apriori_database__->databaseTable().hasMissingValues())) {
171 GUM_ERROR(MissingValueInDatabase,
172 "In general, the BNLearner is unable to cope with " 173 <<
"missing values in databases. To learn parameters in " 174 <<
"such situations, you should first use method " 179 DBRowGeneratorParser<> parser(score_database__.databaseTable().handler(),
180 DBRowGeneratorSet<>());
// Complete-data path: the parser above uses an empty generator set over the
// raw score database; estimate the CPTs of `dag` from it.
181 std::unique_ptr< ParamEstimator<> > param_estimator(
182 createParamEstimator__(parser, take_into_account_score));
184 return Dag2BN__.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
// EM path: only reached when EMepsilon__ != 0.0, since the branch above
// always returns or throws. The listener receives this learner and Dag2BN__,
// presumably to relay EM progress notifications — confirm.
187 BNLearnerListener listener(
this, Dag2BN__);
190 const auto& database = score_database__.databaseTable();
191 const std::size_t nb_vars = database.nbVariables();
// Column types, all DISCRETE. The size argument (listing line 193) is not
// visible; presumably it is nb_vars.
192 const std::vector< gum::learning::DBTranslatedValueType > col_types(
194 gum::learning::DBTranslatedValueType::DISCRETE);
// Bootstrap estimator: rows go through DBRowGenerator4CompleteRows
// (presumably yielding only fully observed rows) to get initial parameters.
197 DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
198 DBRowGeneratorSet<> genset_bootstrap;
199 genset_bootstrap.insertGenerator(generator_bootstrap);
200 DBRowGeneratorParser<> parser_bootstrap(database.handler(),
202 std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
203 createParamEstimator__(parser_bootstrap, take_into_account_score));
// EM estimator: DBRowGeneratorEM is created with a placeholder BN
// (dummy_bn). Presumably the BN being learned is plugged in during EM
// iterations — confirm.
206 BayesNet< GUM_SCALAR > dummy_bn;
207 DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
208 DBRowGenerator<>& gen_EM = generator_EM;
209 DBRowGeneratorSet<> genset_EM;
210 genset_EM.insertGenerator(gen_EM);
211 DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
212 std::unique_ptr< ParamEstimator<> > param_estimator_EM(
213 createParamEstimator__(parser_EM, take_into_account_score));
// EMepsilon__ becomes Dag2BN__'s epsilon (presumably the EM stopping
// threshold).
215 Dag2BN__.setEpsilon(EMepsilon__);
216 return Dag2BN__.createBN< GUM_SCALAR >(*(param_estimator_bootstrap.get()),
217 *(param_estimator_EM.get()),
224 template <
typename GUM_SCALAR >
225 BayesNet< GUM_SCALAR >
226 BNLearner< GUM_SCALAR >::learnParameters(
bool take_into_account_score) {
227 return learnParameters(initial_dag__, take_into_account_score);
231 template <
typename GUM_SCALAR >
232 NodeProperty< Sequence< std::string > >
233 BNLearner< GUM_SCALAR >::labelsFromBN__(
const std::string& filename,
234 const BayesNet< GUM_SCALAR >& src) {
235 std::ifstream in(filename, std::ifstream::in);
237 if ((in.rdstate() & std::ifstream::failbit) != 0) {
238 GUM_ERROR(gum::IOError,
"File " << filename <<
" not found");
241 CSVParser<> parser(in);
243 auto names = parser.current();
245 NodeProperty< Sequence< std::string > > modals;
247 for (gum::Idx col = 0; col < names.size(); col++) {
249 gum::NodeId graphId = src.idFromName(names[col]);
250 modals.insert(col, gum::Sequence< std::string >());
252 for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
253 modals[col].insert(src.variable(graphId).label(i));
254 }
catch (
const gum::NotFound&) {