// NOTE(review): this text appears to be a rendered source listing. Each line
// still carries its listing number, and skipped numbers (e.g. 38-41, 47-48)
// correspond to lines that are absent from this extract.
//
// The first physical line below fuses the DOXYGEN_SHOULD_SKIP_THIS guard, the
// two project includes and the start of the first constructor definition.
//
// Constructor taking a file name and a list of missing-value symbols: both
// arguments are forwarded unchanged to the genericBNLearner base constructor.
32 #ifndef DOXYGEN_SHOULD_SKIP_THIS 35 # include <agrum/BN/learning/BNLearner.h> 37 # include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h> 42 template <
typename GUM_SCALAR >
43 BNLearner< GUM_SCALAR >::BNLearner(
const std::string& filename,
44 const std::vector< std::string >& missing_symbols) :
45 genericBNLearner(filename, missing_symbols) {
// The only statement of the body: the GUM_CONSTRUCTOR macro invoked with the class name.
46 GUM_CONSTRUCTOR(BNLearner);
49 template <
typename GUM_SCALAR >
50 BNLearner< GUM_SCALAR >::BNLearner(
const DatabaseTable<>& db) : genericBNLearner(db) {
51 GUM_CONSTRUCTOR(BNLearner);
54 template <
typename GUM_SCALAR >
55 BNLearner< GUM_SCALAR >::BNLearner(
const std::string& filename,
56 const gum::BayesNet< GUM_SCALAR >& bn,
57 const std::vector< std::string >& missing_symbols) :
58 genericBNLearner(filename, bn, missing_symbols) {
59 GUM_CONSTRUCTOR(BNLearner);
63 template <
typename GUM_SCALAR >
64 BNLearner< GUM_SCALAR >::BNLearner(
const BNLearner< GUM_SCALAR >& src) : genericBNLearner(src) {
65 GUM_CONSTRUCTOR(BNLearner);
69 template <
typename GUM_SCALAR >
70 BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) : genericBNLearner(src) {
71 GUM_CONSTRUCTOR(BNLearner);
75 template <
typename GUM_SCALAR >
76 BNLearner< GUM_SCALAR >::~BNLearner() {
77 GUM_DESTRUCTOR(BNLearner);
88 template <
typename GUM_SCALAR >
89 BNLearner< GUM_SCALAR >&
90 BNLearner< GUM_SCALAR >::operator=(
const BNLearner< GUM_SCALAR >& src) {
91 genericBNLearner::operator=(src);
96 template <
typename GUM_SCALAR >
97 BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) {
98 genericBNLearner::operator=(std::move(src));
// Learns a complete Bayesian network: the DAG returned by learnDag_() and a
// freshly created parameter estimator are handed to Dag2BN_.createBN, whose
// result is returned.
// NOTE(review): listing lines 108-110 are absent from this extract, so
// statements may be missing between the notification and the estimator.
103 template <
typename GUM_SCALAR >
104 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
// Ask for a textual diagnostic about the score/apriori pairing; a non-empty
// answer is printed on std::cout, after which execution simply continues.
106 auto notification = checkScoreAprioriCompatibility();
107 if (notification !=
"") { std::cout <<
"[aGrUM notification] " << notification << std::endl; }
// The estimator returned by createParamEstimator_ is owned by a unique_ptr,
// so it is destroyed when this function exits.
111 std::unique_ptr< ParamEstimator<> > param_estimator(
112 createParamEstimator_(scoreDatabase_.parser(),
true));
114 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), learnDag_());
// Learns the parameters of a Bayesian network whose structure is the given DAG.
//
// dag                  : structure of the network to parameterize.
// takeIntoAccountScore : forwarded unchanged to createParamEstimator_.
//
// Returns a default-constructed BayesNet when the DAG has no node, otherwise
// the network built by Dag2BN_.createBN.
// Errors are raised through GUM_ERROR: MissingVariableInDatabase when a node
// id is not smaller than the number of database names, MissingValueInDatabase
// when epsilonEM_ == 0.0 and a database involved has missing values.
//
// NOTE(review): several listing lines are absent from this extract (loop
// bodies, branch delimiters, one constructor argument); see the notes below.
118 template <
typename GUM_SCALAR >
119 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(
const DAG& dag,
120 bool takeIntoAccountScore) {
// An empty DAG yields an empty network.
122 if (dag.size() == 0)
return BayesNet< GUM_SCALAR >();
// Gather the node ids of the DAG in increasing order so that ids.back() is
// the largest one. NOTE(review): the body of the loop (listing line 128) is
// not visible; presumably it appends each node to ids - confirm.
125 std::vector< NodeId > ids;
126 ids.reserve(dag.sizeNodes());
127 for (
const auto node: dag)
129 std::sort(ids.begin(), ids.end());
// The largest id must be a valid index into the database names; otherwise an
// error message listing the offending ids is assembled.
131 if (ids.back() >= scoreDatabase_.names().size()) {
132 std::stringstream str;
133 str <<
"Learning parameters corresponding to the dag is impossible " 134 <<
"because the database does not contain the following nodeID";
135 std::vector< NodeId > bad_ids;
136 for (
const auto node: ids) {
137 if (node >= scoreDatabase_.names().size()) bad_ids.push_back(node);
// Pluralize "nodeID" when more than one id is out of range.
139 if (bad_ids.size() > 1) str <<
's';
// NOTE(review): the body of this loop (listing lines 143-148) is not visible;
// presumably it writes each bad id into str - confirm.
142 for (
const auto node: bad_ids) {
149 GUM_ERROR(MissingVariableInDatabase, str.str())
// epsilonEM_ == 0.0 selects the path that does not use EM.
155 if (epsilonEM_ == 0.0) {
// Missing values are rejected on this path, whether they are in the score
// database or, for a DIRICHLET_FROM_DATABASE apriori with a non-null apriori
// database, in that apriori database.
157 if (scoreDatabase_.databaseTable().hasMissingValues()
158 || ((aprioriDatabase_ !=
nullptr)
159 && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
160 && aprioriDatabase_->databaseTable().hasMissingValues())) {
// NOTE(review): the end of this message (listing lines 165-168) is not
// visible; the next physical line fuses it with the parser declaration.
161 GUM_ERROR(MissingValueInDatabase,
162 "In general, the BNLearner is unable to cope with " 163 <<
"missing values in databases. To learn parameters in " 164 <<
"such situations, you should first use method " 169 DBRowGeneratorParser<> parser(scoreDatabase_.databaseTable().handler(),
170 DBRowGeneratorSet<>());
// Estimator built on a parser that uses an empty generator set.
171 std::unique_ptr< ParamEstimator<> > param_estimator(
172 createParamEstimator_(parser, takeIntoAccountScore));
174 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
// NOTE(review): the branch delimiters (listing lines 175-176) are not
// visible; the code below is presumably the epsilonEM_ != 0.0 (EM) path.
177 BNLearnerListener listener(
this, Dag2BN_);
180 const auto& database = scoreDatabase_.databaseTable();
181 const std::size_t nb_vars = database.nbVariables();
// Every entry of col_types is DISCRETE. NOTE(review): the first constructor
// argument (listing line 183) is not visible; presumably nb_vars - confirm.
182 const std::vector< gum::learning::DBTranslatedValueType > col_types(
184 gum::learning::DBTranslatedValueType::DISCRETE);
// First estimator ("bootstrap"): its parser uses a
// DBRowGenerator4CompleteRows generator.
187 DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
188 DBRowGeneratorSet<> genset_bootstrap;
189 genset_bootstrap.insertGenerator(generator_bootstrap);
190 DBRowGeneratorParser<> parser_bootstrap(database.handler(), genset_bootstrap);
191 std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
192 createParamEstimator_(parser_bootstrap, takeIntoAccountScore));
// Second estimator (EM): its parser uses a DBRowGeneratorEM generator that is
// constructed with a default-constructed network.
195 BayesNet< GUM_SCALAR > dummy_bn;
196 DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
// The generator is inserted through a reference to its DBRowGenerator<> base.
197 DBRowGenerator<>& gen_EM = generator_EM;
198 DBRowGeneratorSet<> genset_EM;
199 genset_EM.insertGenerator(gen_EM);
200 DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
201 std::unique_ptr< ParamEstimator<> > param_estimator_EM(
202 createParamEstimator_(parser_EM, takeIntoAccountScore));
// Dag2BN_ receives epsilonEM_ and then both estimators.
// NOTE(review): the remaining createBN argument(s) (listing line 207) are
// not visible.
204 Dag2BN_.setEpsilon(epsilonEM_);
205 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator_bootstrap.get()),
206 *(param_estimator_EM.get()),
213 template <
typename GUM_SCALAR >
214 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(
bool take_into_account_score) {
215 return learnParameters(initialDag_, take_into_account_score);
// Builds, for the columns of a CSV file, the sequence of labels of the
// variable of `src` that bears the column's name.
//
// filename : file opened for reading and parsed with CSVParser.
// src      : Bayesian network in which each column name is looked up.
//
// Raises gum::IOError through GUM_ERROR when the stream's failbit is set
// right after the file is opened.
// NOTE(review): the `try` keyword, the body of the catch handler and the
// return statement are on listing lines absent from this extract.
219 template <
typename GUM_SCALAR >
220 NodeProperty< Sequence< std::string > >
221 BNLearner< GUM_SCALAR >::_labelsFromBN_(
const std::string& filename,
222 const BayesNet< GUM_SCALAR >& src) {
223 std::ifstream in(filename, std::ifstream::in);
// Any failure to open the file is reported as "not found".
225 if ((in.rdstate() & std::ifstream::failbit) != 0) {
226 GUM_ERROR(gum::IOError,
"File " << filename <<
" not found")
229 CSVParser<> parser(in);
// NOTE(review): presumably the header row holding the column names; listing
// line 230 (likely the call that reads it) is not visible - confirm.
231 auto names = parser.current();
// Result container, keyed by column index.
233 NodeProperty< Sequence< std::string > > modals;
235 for (gum::Idx col = 0; col < names.size(); col++) {
// Find the variable of `src` named like the column, then store its labels
// under the column index, in the order of the variable's label indices.
237 gum::NodeId graphId = src.idFromName(names[col]);
238 modals.insert(col, gum::Sequence< std::string >());
240 for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
241 modals[col].insert(src.variable(graphId).label(i));
242 }
// gum::NotFound is caught here; what the handler does is not visible.
catch (
const gum::NotFound&) {
// Renders the (key, value, comment) tuples returned by state() as text, one
// tuple per iteration of the second loop.
// NOTE(review): the declarations of `maxkey` and `s`, any line terminator
// written after each tuple, and the return statement are on listing lines
// absent from this extract.
251 template <
typename GUM_SCALAR >
252 std::string BNLearner< GUM_SCALAR >::toString()
const {
253 const auto st = state();
// First pass: find the length of the longest key, used as the column width.
256 for (
const auto& tuple: st)
257 if (std::get< 0 >(tuple).length() > maxkey) maxkey = std::get< 0 >(tuple).length();
// Second pass: write the key left-aligned and padded to maxkey, then " : "
// and the value; the comment follows in parentheses only when it is not
// empty.
260 for (
const auto& tuple: st) {
261 s << std::setiosflags(std::ios::left) << std::setw(maxkey) << std::get< 0 >(tuple) <<
" : " 262 << std::get< 1 >(tuple);
263 if (std::get< 2 >(tuple) !=
"") s <<
" (" << std::get< 2 >(tuple) <<
")";
// Collects a description of the learner's current configuration as a list of
// (key, value, comment) string tuples, appended to `vals` in the order below:
// database, algorithm, score or correction, apriori, weights, EM, constraints
// and initial DAG.
// NOTE(review): many listing lines are absent from this extract - the
// declarations/assignments of `key`, `comment` and `res`, several `case`
// labels, the `break`/`default` lines, the separators written inside the
// constraint loops and the return statement. Comments below only describe
// what is visible.
269 template <
typename GUM_SCALAR >
270 std::vector< std::tuple< std::string, std::string, std::string > >
271 BNLearner< GUM_SCALAR >::state()
const {
272 std::vector< std::tuple< std::string, std::string, std::string > > vals;
// --- database description ---------------------------------------------
276 const auto& db = database();
278 vals.emplace_back(
"Filename", filename_,
"");
// Size is rendered as "(rows,cols)".
// NOTE(review): the last argument of this call (listing line 281) is not visible.
279 vals.emplace_back(
"Size",
280 "(" + std::to_string(nbRows()) +
"," + std::to_string(nbCols()) +
")",
// Variables are rendered as "name[domainSize]", separated by ", ".
283 std::string vars =
"";
284 for (NodeId i = 0; i < db.nbVariables(); i++) {
285 if (i > 0) vars +=
", ";
286 vars += nameFromId(i) +
"[" + std::to_string(db.domainSize(i)) +
"]";
288 vals.emplace_back(
"Variables", vars,
"");
289 vals.emplace_back(
"Missing values", hasMissingValues() ?
"True" :
"False",
"");
// --- selected learning algorithm --------------------------------------
292 switch (selectedAlgo_) {
293 case AlgoType::GREEDY_HILL_CLIMBING:
294 vals.emplace_back(key,
"Greedy Hill Climbing",
"");
// K2 additionally reports its variable order, as names separated by ", ".
// NOTE(review): `vars` is appended to again here; whether it is cleared
// first (listing line 299) is not visible.
297 vals.emplace_back(key,
"K2",
"");
298 const auto& k2order = algoK2_.order();
300 for (NodeId i = 0; i < k2order.size(); i++) {
301 if (i > 0) vars +=
", ";
302 vars += nameFromId(k2order.atPos(i));
304 vals.emplace_back(
"K2 order", vars,
"");
// NOTE(review): the value reported as "Tabu list size" is
// nbDecreasingChanges_ - confirm this is the intended member.
306 case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST:
307 vals.emplace_back(key,
"Local Search with Tabu List",
"");
308 vals.emplace_back(
"Tabu list size", std::to_string(nbDecreasingChanges_),
"");
310 case AlgoType::THREE_OFF_TWO:
311 vals.emplace_back(key,
"3off2",
"");
314 vals.emplace_back(key,
"MIIC",
"");
317 vals.emplace_back(key,
"(unknown)",
"?");
// --- score (only for algorithms other than MIIC and 3off2) -------------
321 if (selectedAlgo_ != AlgoType::MIIC && selectedAlgo_ != AlgoType::THREE_OFF_TWO) {
323 switch (scoreType_) {
325 vals.emplace_back(key,
"K2",
"");
328 vals.emplace_back(key,
"AIC",
"");
331 vals.emplace_back(key,
"BIC",
"");
334 vals.emplace_back(key,
"BD",
"");
336 case ScoreType::BDeu:
337 vals.emplace_back(key,
"BDeu",
"");
339 case ScoreType::LOG2LIKELIHOOD:
340 vals.emplace_back(key,
"Log2Likelihood",
"");
343 vals.emplace_back(key,
"(unknown)",
"?");
// --- correction mode, selected by kmode3Off2_ ---------------------------
// NOTE(review): presumably the else branch of the test above (MIIC / 3off2);
// the branch delimiters are not visible - confirm.
348 switch (kmode3Off2_) {
349 case CorrectedMutualInformation<>::KModeTypes::MDL:
350 vals.emplace_back(key,
"MDL",
"");
352 case CorrectedMutualInformation<>::KModeTypes::NML:
353 vals.emplace_back(key,
"NML",
"");
355 case CorrectedMutualInformation<>::KModeTypes::NoCorr:
356 vals.emplace_back(key,
"No correction",
"");
359 vals.emplace_back(key,
"(unknown)",
"?");
// --- apriori ------------------------------------------------------------
// The score/apriori compatibility message is attached as the comment of the
// apriori entry.
366 comment = checkScoreAprioriCompatibility();
367 switch (aprioriType_) {
368 case AprioriType::NO_APRIORI:
369 vals.emplace_back(key,
"-", comment);
// A Dirichlet apriori also reports the name of its database.
371 case AprioriType::DIRICHLET_FROM_DATABASE:
372 vals.emplace_back(key,
"Dirichlet", comment);
373 vals.emplace_back(
"Dirichlet database", aprioriDbname_,
"");
375 case AprioriType::BDEU:
376 vals.emplace_back(key,
"BDEU", comment);
378 case AprioriType::SMOOTHING:
379 vals.emplace_back(key,
"Smoothing", comment);
382 vals.emplace_back(key,
"(unknown)",
"?");
// The prior weight is reported whenever an apriori is in use.
386 if (aprioriType_ != AprioriType::NO_APRIORI)
387 vals.emplace_back(
"Prior weight", std::to_string(aprioriWeight_),
"");
// The database weight is reported only when it differs from the row count
// (exact floating-point comparison).
389 if (databaseWeight() !=
double(nbRows())) {
390 vals.emplace_back(
"Database weight", std::to_string(databaseWeight()),
"");
// --- EM: reported when epsilonEM_ is strictly positive ------------------
393 if (epsilonEM_ > 0.0) {
// The epsilon entry carries a remark when the database has no missing value.
395 if (!hasMissingValues()) comment =
"But no missing values in this database";
396 vals.emplace_back(
"EM",
"True",
"");
397 vals.emplace_back(
"EM epsilon", std::to_string(epsilonEM_), comment);
// --- structural constraints ---------------------------------------------
// The max in-degree is reported only when it is below the maximum value of Size.
402 if (constraintIndegree_.maxIndegree() < std::numeric_limits< Size >::max()) {
403 vals.emplace_back(
"Constraint Max InDegree",
404 std::to_string(constraintIndegree_.maxIndegree()),
405 "Used only for score-based algorithms.");
// Forbidden arcs, each rendered as "tail->head".
407 if (!constraintForbiddenArcs_.arcs().empty()) {
410 for (
const auto& arc: constraintForbiddenArcs_.arcs()) {
415 res += nameFromId(arc.tail()) +
"->" + nameFromId(arc.head());
418 vals.emplace_back(
"Constraint Forbidden Arcs", res,
"");
// Mandatory arcs, each rendered as "tail->head".
420 if (!constraintMandatoryArcs_.arcs().empty()) {
423 for (
const auto& arc: constraintMandatoryArcs_.arcs()) {
428 res += nameFromId(arc.tail()) +
"->" + nameFromId(arc.head());
431 vals.emplace_back(
"Constraint Mandatory Arcs", res,
"");
// Possible edges, each rendered as "first--second".
// NOTE(review): the middle argument of the emplace_back below (listing line
// 445) is not visible.
433 if (!constraintPossibleEdges_.edges().empty()) {
436 for (
const auto& edge: constraintPossibleEdges_.edges()) {
441 res += nameFromId(edge.first()) +
"--" + nameFromId(edge.second());
444 vals.emplace_back(
"Constraint Possible Edges",
446 "Used only for score-based algorithms.");
// Slice order, each entry rendered as "name:slice".
448 if (!constraintSliceOrder_.sliceOrder().empty()) {
451 const auto& order = constraintSliceOrder_.sliceOrder();
452 for (
const auto& p: order) {
457 res += nameFromId(p.first) +
":" + std::to_string(p.second);
460 vals.emplace_back(
"Constraint Slice Order", res,
"Used only for score-based algorithms.");
// --- initial DAG: its toDot() text is carried in the comment field ------
462 if (initialDag_.size() != 0) {
463 vals.emplace_back(
"Initial DAG",
"True", initialDag_.toDot());
// Stream insertion operator for BNLearner: writes learner.toString() to the
// given output stream.
// NOTE(review): the return statement and the closing brace are not visible
// in this extract; presumably `output` is returned - confirm.
469 template <
typename GUM_SCALAR >
470 INLINE std::ostream& operator<<(std::ostream& output,
const BNLearner< GUM_SCALAR >& learner) {
471 output << learner.toString();