34 #include <agrum/agrum.h> 35 #include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h> 36 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h> 37 #include <agrum/tools/stattests/indepTestChi2.h> 38 #include <agrum/tools/stattests/indepTestG2.h> 39 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h> 40 #include <agrum/tools/stattests/pseudoCount.h> 44 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h> 52 genericBNLearner::Database::Database(
const DatabaseTable<>& db) : _database_(db) {
// Constructor from an existing DatabaseTable: _database_ is initialized from
// `db`, then the per-variable caches and the row parser are derived from it.
54 const auto& var_names = _database_.variableNames();
55 const std::size_t nb_vars = var_names.size();
// cache the domain size of every variable of the table
56 for (
auto dom: _database_.domainSizes())
57 _domain_sizes_.push_back(dom);
// identity mapping: node id i is stored in column i
58 for (std::size_t i = 0; i < nb_vars; ++i) {
59 _nodeId2cols_.insert(NodeId(i), i);
// the parser is heap-allocated here and released by ~Database()
63 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
67 genericBNLearner::Database::Database(
const std::string& filename,
68 const std::vector< std::string >& missing_symbols) :
69 Database(genericBNLearner::readFile_(filename, missing_symbols)) {}
72 genericBNLearner::Database::Database(
const std::string& CSV_filename,
73 Database& score_database,
74 const std::vector< std::string >& missing_symbols) {
76 genericBNLearner::checkFileName_(CSV_filename);
77 DBInitializerFromCSV<> initializer(CSV_filename);
78 const auto& apriori_names = initializer.variableNames();
79 std::size_t apriori_nb_vars = apriori_names.size();
80 HashTable< std::string, std::size_t > apriori_names2col(apriori_nb_vars);
81 for (std::size_t i = std::size_t(0); i < apriori_nb_vars; ++i)
82 apriori_names2col.insert(apriori_names[i], i);
86 if (apriori_nb_vars < score_database._database_.nbVariables()) {
87 GUM_ERROR(InvalidArgument,
88 "the a apriori database has fewer variables " 89 "than the observed database");
94 const std::vector< std::string >& score_names
95 = score_database.databaseTable().variableNames();
96 const std::size_t score_nb_vars = score_names.size();
97 HashTable< std::size_t, std::size_t > mapping(score_nb_vars);
98 for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
100 mapping.insert(i, apriori_names2col[score_names[i]]);
101 }
catch (Exception&) {
102 GUM_ERROR(MissingVariableInDatabase,
103 "Variable " << score_names[i]
104 <<
" of the observed database does not belong to the " 105 <<
"apriori database");
110 for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
111 const Variable& var = score_database.databaseTable().variable(i);
112 _database_.insertTranslator(var, mapping[i], missing_symbols);
116 initializer.fillDatabase(_database_);
119 for (
auto dom: _database_.domainSizes())
120 _domain_sizes_.push_back(dom);
123 _nodeId2cols_ = score_database.nodeId2Columns();
126 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
130 genericBNLearner::Database::Database(
const Database& from) :
131 _database_(from._database_), _domain_sizes_(from._domain_sizes_),
132 _nodeId2cols_(from._nodeId2cols_) {
134 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
138 genericBNLearner::Database::Database(Database&& from) :
139 _database_(std::move(from._database_)), _domain_sizes_(std::move(from._domain_sizes_)),
140 _nodeId2cols_(std::move(from._nodeId2cols_)) {
142 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
146 genericBNLearner::Database::~Database() {
delete _parser_; }
148 genericBNLearner::Database& genericBNLearner::Database::operator=(
const Database& from) {
151 _database_ = from._database_;
152 _domain_sizes_ = from._domain_sizes_;
153 _nodeId2cols_ = from._nodeId2cols_;
156 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
162 genericBNLearner::Database& genericBNLearner::Database::operator=(Database&& from) {
165 _database_ = std::move(from._database_);
166 _domain_sizes_ = std::move(from._domain_sizes_);
167 _nodeId2cols_ = std::move(from._nodeId2cols_);
170 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
179 genericBNLearner::genericBNLearner(
const std::string& filename,
180 const std::vector< std::string >& missing_symbols) :
181 scoreDatabase_(filename, missing_symbols) {
182 noApriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable());
184 GUM_CONSTRUCTOR(genericBNLearner);
188 genericBNLearner::genericBNLearner(
const DatabaseTable<>& db) : scoreDatabase_(db) {
189 noApriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable());
191 GUM_CONSTRUCTOR(genericBNLearner);
195 genericBNLearner::genericBNLearner(
const genericBNLearner& from) :
196 scoreType_(from.scoreType_), paramEstimatorType_(from.paramEstimatorType_),
197 epsilonEM_(from.epsilonEM_), aprioriType_(from.aprioriType_),
198 aprioriWeight_(from.aprioriWeight_), constraintSliceOrder_(from.constraintSliceOrder_),
199 constraintIndegree_(from.constraintIndegree_),
200 constraintTabuList_(from.constraintTabuList_),
201 constraintForbiddenArcs_(from.constraintForbiddenArcs_),
202 constraintMandatoryArcs_(from.constraintMandatoryArcs_), selectedAlgo_(from.selectedAlgo_),
203 algoK2_(from.algoK2_), algoMiic3off2_(from.algoMiic3off2_), kmode3Off2_(from.kmode3Off2_),
204 greedyHillClimbing_(from.greedyHillClimbing_),
205 localSearchWithTabuList_(from.localSearchWithTabuList_),
206 scoreDatabase_(from.scoreDatabase_), ranges_(from.ranges_),
207 aprioriDbname_(from.aprioriDbname_), initialDag_(from.initialDag_) {
208 noApriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable());
210 GUM_CONS_CPY(genericBNLearner);
213 genericBNLearner::genericBNLearner(genericBNLearner&& from) :
214 scoreType_(from.scoreType_), paramEstimatorType_(from.paramEstimatorType_),
215 epsilonEM_(from.epsilonEM_), aprioriType_(from.aprioriType_),
216 aprioriWeight_(from.aprioriWeight_),
217 constraintSliceOrder_(std::move(from.constraintSliceOrder_)),
218 constraintIndegree_(std::move(from.constraintIndegree_)),
219 constraintTabuList_(std::move(from.constraintTabuList_)),
220 constraintForbiddenArcs_(std::move(from.constraintForbiddenArcs_)),
221 constraintMandatoryArcs_(std::move(from.constraintMandatoryArcs_)),
222 selectedAlgo_(from.selectedAlgo_), algoK2_(std::move(from.algoK2_)),
223 algoMiic3off2_(std::move(from.algoMiic3off2_)), kmode3Off2_(from.kmode3Off2_),
224 greedyHillClimbing_(std::move(from.greedyHillClimbing_)),
225 localSearchWithTabuList_(std::move(from.localSearchWithTabuList_)),
226 scoreDatabase_(std::move(from.scoreDatabase_)), ranges_(std::move(from.ranges_)),
227 aprioriDbname_(std::move(from.aprioriDbname_)), initialDag_(std::move(from.initialDag_)) {
228 noApriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable());
230 GUM_CONS_MOV(genericBNLearner)
233 genericBNLearner::~genericBNLearner() {
234 if (score_)
delete score_;
236 if (apriori_)
delete apriori_;
238 if (noApriori_)
delete noApriori_;
240 if (aprioriDatabase_)
delete aprioriDatabase_;
242 if (mutualInfo_)
delete mutualInfo_;
244 GUM_DESTRUCTOR(genericBNLearner);
247 genericBNLearner& genericBNLearner::operator=(
const genericBNLearner& from) {
259 if (aprioriDatabase_) {
260 delete aprioriDatabase_;
261 aprioriDatabase_ =
nullptr;
266 mutualInfo_ =
nullptr;
269 scoreType_ = from.scoreType_;
270 paramEstimatorType_ = from.paramEstimatorType_;
271 epsilonEM_ = from.epsilonEM_;
272 aprioriType_ = from.aprioriType_;
273 aprioriWeight_ = from.aprioriWeight_;
274 constraintSliceOrder_ = from.constraintSliceOrder_;
275 constraintIndegree_ = from.constraintIndegree_;
276 constraintTabuList_ = from.constraintTabuList_;
277 constraintForbiddenArcs_ = from.constraintForbiddenArcs_;
278 constraintMandatoryArcs_ = from.constraintMandatoryArcs_;
279 selectedAlgo_ = from.selectedAlgo_;
280 algoK2_ = from.algoK2_;
281 algoMiic3off2_ = from.algoMiic3off2_;
282 kmode3Off2_ = from.kmode3Off2_;
283 greedyHillClimbing_ = from.greedyHillClimbing_;
284 localSearchWithTabuList_ = from.localSearchWithTabuList_;
285 scoreDatabase_ = from.scoreDatabase_;
286 ranges_ = from.ranges_;
287 aprioriDbname_ = from.aprioriDbname_;
288 initialDag_ = from.initialDag_;
289 currentAlgorithm_ =
nullptr;
295 genericBNLearner& genericBNLearner::operator=(genericBNLearner&& from) {
307 if (aprioriDatabase_) {
308 delete aprioriDatabase_;
309 aprioriDatabase_ =
nullptr;
314 mutualInfo_ =
nullptr;
317 scoreType_ = from.scoreType_;
318 paramEstimatorType_ = from.paramEstimatorType_;
319 epsilonEM_ = from.epsilonEM_;
320 aprioriType_ = from.aprioriType_;
321 aprioriWeight_ = from.aprioriWeight_;
322 constraintSliceOrder_ = std::move(from.constraintSliceOrder_);
323 constraintIndegree_ = std::move(from.constraintIndegree_);
324 constraintTabuList_ = std::move(from.constraintTabuList_);
325 constraintForbiddenArcs_ = std::move(from.constraintForbiddenArcs_);
326 constraintMandatoryArcs_ = std::move(from.constraintMandatoryArcs_);
327 selectedAlgo_ = from.selectedAlgo_;
328 algoK2_ = from.algoK2_;
329 algoMiic3off2_ = std::move(from.algoMiic3off2_);
330 kmode3Off2_ = from.kmode3Off2_;
331 greedyHillClimbing_ = std::move(from.greedyHillClimbing_);
332 localSearchWithTabuList_ = std::move(from.localSearchWithTabuList_);
333 scoreDatabase_ = std::move(from.scoreDatabase_);
334 ranges_ = std::move(from.ranges_);
335 aprioriDbname_ = std::move(from.aprioriDbname_);
336 initialDag_ = std::move(from.initialDag_);
337 currentAlgorithm_ =
nullptr;
344 DatabaseTable<> readFile(
const std::string& filename) {
346 Size filename_size = Size(filename.size());
348 if (filename_size < 4) {
349 GUM_ERROR(FormatNotFound,
350 "genericBNLearner could not determine the " 351 "file type of the database");
354 std::string extension = filename.substr(filename.size() - 4);
355 std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
357 if (extension !=
".csv") {
358 GUM_ERROR(OperationNotAllowed,
359 "genericBNLearner does not support yet this type " 363 DBInitializerFromCSV<> initializer(filename);
365 const auto& var_names = initializer.variableNames();
366 const std::size_t nb_vars = var_names.size();
368 DBTranslatorSet<> translator_set;
369 DBTranslator4LabelizedVariable<> translator;
370 for (std::size_t i = 0; i < nb_vars; ++i) {
371 translator_set.insertTranslator(translator, i);
374 DatabaseTable<> database(translator_set);
375 database.setVariableNames(initializer.variableNames());
376 initializer.fillDatabase(database);
382 void genericBNLearner::checkFileName_(
const std::string& filename) {
384 Size filename_size = Size(filename.size());
386 if (filename_size < 4) {
387 GUM_ERROR(FormatNotFound,
388 "genericBNLearner could not determine the " 389 "file type of the database");
392 std::string extension = filename.substr(filename.size() - 4);
393 std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
395 if (extension !=
".csv") {
396 GUM_ERROR(OperationNotAllowed,
397 "genericBNLearner does not support yet this type of database file");
// Reads a database file into a DatabaseTable, using `missing_symbols` as the
// strings that denote missing values.
// NOTE(review): the end of this function (what follows fillDatabase, and the
// return statement) is not visible in this excerpt.
402 DatabaseTable<> genericBNLearner::readFile_(
const std::string& filename,
403 const std::vector< std::string >& missing_symbols) {
// rejects names whose extension is not supported (throws on failure)
405 checkFileName_(filename);
407 DBInitializerFromCSV<> initializer(filename);
409 const auto& var_names = initializer.variableNames();
410 const std::size_t nb_vars = var_names.size();
// the same labelized-variable translator, aware of the missing symbols, is
// registered for every column of the file
412 DBTranslatorSet<> translator_set;
413 DBTranslator4LabelizedVariable<> translator(missing_symbols);
414 for (std::size_t i = 0; i < nb_vars; ++i) {
415 translator_set.insertTranslator(translator, i);
// build the table, name its variables after the file's columns and load the rows
418 DatabaseTable<> database(missing_symbols, translator_set);
419 database.setVariableNames(initializer.variableNames());
420 initializer.fillDatabase(database);
// (Re)creates apriori_ according to aprioriType_, sets its weight to
// aprioriWeight_, and deletes the previously held a priori afterwards.
428 void genericBNLearner::createApriori_() {
// the old object is only deleted at the very end of the function
430 Apriori<>* old_apriori = apriori_;
433 switch (aprioriType_) {
434 case AprioriType::NO_APRIORI:
435 apriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable(),
436 scoreDatabase_.nodeId2Columns());
439 case AprioriType::SMOOTHING:
440 apriori_ =
new AprioriSmoothing<>(scoreDatabase_.databaseTable(),
441 scoreDatabase_.nodeId2Columns());
444 case AprioriType::DIRICHLET_FROM_DATABASE:
// any previously loaded a priori database is dropped before a new one is
// read from aprioriDbname_, aligned with the score database
445 if (aprioriDatabase_ !=
nullptr) {
446 delete aprioriDatabase_;
447 aprioriDatabase_ =
nullptr;
451 =
new Database(aprioriDbname_, scoreDatabase_, scoreDatabase_.missingSymbols());
// this a priori reads its data through the a priori database's parser and
// uses that database's node-to-column mapping
453 apriori_ =
new AprioriDirichletFromDatabase<>(scoreDatabase_.databaseTable(),
454 aprioriDatabase_->parser(),
455 aprioriDatabase_->nodeId2Columns());
458 case AprioriType::BDEU:
460 =
new AprioriBDeu<>(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
// error raised for a priori types that are not supported
464 GUM_ERROR(OperationNotAllowed,
"The BNLearner does not support yet this apriori")
468 apriori_->setWeight(aprioriWeight_);
471 if (old_apriori !=
nullptr)
delete old_apriori;
// (Re)creates score_ according to scoreType_ and deletes the previously held
// score afterwards.
// NOTE(review): several case labels and the middle constructor arguments are
// not visible in this excerpt; each visible branch builds its score from the
// score database's parser and node-to-column mapping.
474 void genericBNLearner::createScore_() {
// the old object is only deleted at the very end of the function
476 Score<>* old_score = score_;
479 switch (scoreType_) {
481 score_ =
new ScoreAIC<>(scoreDatabase_.parser(),
484 scoreDatabase_.nodeId2Columns());
488 score_ =
new ScoreBD<>(scoreDatabase_.parser(),
491 scoreDatabase_.nodeId2Columns());
494 case ScoreType::BDeu:
495 score_ =
new ScoreBDeu<>(scoreDatabase_.parser(),
498 scoreDatabase_.nodeId2Columns());
502 score_ =
new ScoreBIC<>(scoreDatabase_.parser(),
505 scoreDatabase_.nodeId2Columns());
509 score_ =
new ScoreK2<>(scoreDatabase_.parser(),
512 scoreDatabase_.nodeId2Columns());
515 case ScoreType::LOG2LIKELIHOOD:
516 score_ =
new ScoreLog2Likelihood<>(scoreDatabase_.parser(),
519 scoreDatabase_.nodeId2Columns());
// error raised for score types that are not supported
523 GUM_ERROR(OperationNotAllowed,
"genericBNLearner does not support yet this score")
527 if (old_score !=
nullptr)
delete old_score;
// Allocates a parameter estimator according to paramEstimatorType_, applies
// ranges_ to it and returns it as a raw pointer.
// NOTE(review): the caller presumably takes ownership of the returned
// object; nothing in this function deletes it. Confirm at the call sites.
// NOTE(review): some constructor arguments are not visible in this excerpt.
530 ParamEstimator<>* genericBNLearner::createParamEstimator_(DBRowGeneratorParser<>& parser,
531 bool take_into_account_score) {
532 ParamEstimator<>* param_estimator =
nullptr;
535 switch (paramEstimatorType_) {
536 case ParamEstimatorType::ML:
// when requested and a score exists, the estimator is given the score's
// internal a priori
537 if (take_into_account_score && (score_ !=
nullptr)) {
538 param_estimator =
new ParamEstimatorML<>(parser,
540 score_->internalApriori(),
542 scoreDatabase_.nodeId2Columns());
544 param_estimator =
new ParamEstimatorML<>(parser,
548 scoreDatabase_.nodeId2Columns());
// error raised for estimator types that are not supported
554 GUM_ERROR(OperationNotAllowed,
555 "genericBNLearner does not support " 556 <<
"yet this parameter estimator");
// restrict the estimator to the learner's current row ranges
560 param_estimator->setRanges(ranges_);
562 return param_estimator;
566 MixedGraph genericBNLearner::prepareMiic3Off2_() {
569 for (Size i = 0; i < scoreDatabase_.databaseTable().nbVariables(); ++i) {
570 mgraph.addNodeWithId(i);
571 for (Size j = 0; j < i; ++j) {
572 mgraph.addEdge(j, i);
577 HashTable< std::pair< NodeId, NodeId >,
char > initial_marks;
578 const ArcSet& mandatory_arcs = constraintMandatoryArcs_.arcs();
579 for (
const auto& arc: mandatory_arcs) {
580 initial_marks.insert({arc.tail(), arc.head()},
'>');
583 const ArcSet& forbidden_arcs = constraintForbiddenArcs_.arcs();
584 for (
const auto& arc: forbidden_arcs) {
585 initial_marks.insert({arc.tail(), arc.head()},
'-');
587 algoMiic3off2_.addConstraints(initial_marks);
591 createCorrectedMutualInformation_();
596 MixedGraph genericBNLearner::learnMixedStructure() {
597 if (selectedAlgo_ != AlgoType::MIIC_THREE_OFF_TWO) {
598 GUM_ERROR(OperationNotAllowed,
"Must be using the miic/3off2 algorithm")
601 if (scoreDatabase_.databaseTable().hasMissingValues()) {
602 GUM_ERROR(MissingValueInDatabase,
603 "For the moment, the BNLearner is unable to learn " 604 <<
"structures with missing values in databases");
606 BNLearnerListener listener(
this, algoMiic3off2_);
609 MixedGraph mgraph =
this->prepareMiic3Off2_();
611 return algoMiic3off2_.learnMixedStructure(*mutualInfo_, mgraph);
// Learns and returns a DAG (see the return type).
// NOTE(review): only the signature is visible in this excerpt; the body is
// not shown, so its exact steps must be confirmed against the full file.
614 DAG genericBNLearner::learnDAG() {
// Replaces mutualInfo_ with a new CorrectedMutualInformation built over the
// score database, then selects its correction mode from kmode3Off2_.
// NOTE(review): the middle constructor arguments are not visible in this
// excerpt.
622 void genericBNLearner::createCorrectedMutualInformation_() {
// the previous object, if any, is released first
623 if (mutualInfo_ !=
nullptr)
delete mutualInfo_;
625 mutualInfo_ =
new CorrectedMutualInformation<>(scoreDatabase_.parser(),
628 scoreDatabase_.nodeId2Columns());
// one branch per supported correction mode
629 switch (kmode3Off2_) {
630 case CorrectedMutualInformation<>::KModeTypes::MDL:
631 mutualInfo_->useMDL();
634 case CorrectedMutualInformation<>::KModeTypes::NML:
635 mutualInfo_->useNML();
638 case CorrectedMutualInformation<>::KModeTypes::NoCorr:
639 mutualInfo_->useNoCorr();
// error raised for correction modes that are not handled above
643 GUM_ERROR(NotImplementedYet,
644 "The BNLearner's corrected mutual information class does " 645 <<
"not implement yet this correction : " <<
int(kmode3Off2_));
// Learns a DAG with the algorithm designated by selectedAlgo_, starting from
// initialDag_ amended by the mandatory and forbidden arc constraints.
// NOTE(review): some declarators, selector arguments and case labels are not
// visible in this excerpt; the comments below only describe the visible lines.
649 DAG genericBNLearner::learnDag_() {
// refuse to learn if the score database has missing values, or if a Dirichlet
// a priori database is in use and has missing values
651 if (scoreDatabase_.databaseTable().hasMissingValues()
652 || ((aprioriDatabase_ !=
nullptr)
653 && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
654 && aprioriDatabase_->databaseTable().hasMissingValues())) {
655 GUM_ERROR(MissingValueInDatabase,
656 "For the moment, the BNLearner is unable to cope " 657 "with missing values in databases");
// work on a copy of the initial DAG
661 DAG init_graph = initialDag_;
// every mandatory arc is added to the initial graph, creating its end nodes
// when they do not exist yet
663 const ArcSet& mandatory_arcs = constraintMandatoryArcs_.arcs();
665 for (
const auto& arc: mandatory_arcs) {
666 if (!init_graph.exists(arc.tail())) init_graph.addNodeWithId(arc.tail());
668 if (!init_graph.exists(arc.head())) init_graph.addNodeWithId(arc.head());
670 init_graph.addArc(arc.tail(), arc.head());
// every forbidden arc is removed from the initial graph
673 const ArcSet& forbidden_arcs = constraintForbiddenArcs_.arcs();
675 for (
const auto& arc: forbidden_arcs) {
676 init_graph.eraseArc(arc);
679 switch (selectedAlgo_) {
// miic/3off2: relies on the corrected mutual information rather than on the
// initial graph built above
681 case AlgoType::MIIC_THREE_OFF_TWO: {
682 BNLearnerListener listener(
this, algoMiic3off2_);
684 MixedGraph mgraph =
this->prepareMiic3Off2_();
686 return algoMiic3off2_.learnStructure(*mutualInfo_, mgraph);
// greedy hill climbing: the generator constraint combines the mandatory arcs,
// forbidden arcs, possible edges and slice order
690 case AlgoType::GREEDY_HILL_CLIMBING: {
691 BNLearnerListener listener(
this, greedyHillClimbing_);
692 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
693 StructuralConstraintForbiddenArcs,
694 StructuralConstraintPossibleEdges,
695 StructuralConstraintSliceOrder >
// each part of the combined constraint is assigned from the learner's own
// constraint of the same type
697 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
698 = constraintMandatoryArcs_;
699 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
700 = constraintForbiddenArcs_;
701 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
702 = constraintPossibleEdges_;
703 static_cast< StructuralConstraintSliceOrder& >(gen_constraint) = constraintSliceOrder_;
705 GraphChangesGenerator4DiGraph<
decltype(gen_constraint) > op_set(gen_constraint);
// the selector constraint combines the indegree limit with the DAG constraint
707 StructuralConstraintSetStatic< StructuralConstraintIndegree, StructuralConstraintDAG >
709 static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
711 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
decltype(op_set) > selector(
716 return greedyHillClimbing_.learnStructure(selector, init_graph);
// local search with tabu list: same generator constraint as above; the
// selector constraint additionally carries the tabu list
720 case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST: {
721 BNLearnerListener listener(
this, localSearchWithTabuList_);
722 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
723 StructuralConstraintForbiddenArcs,
724 StructuralConstraintPossibleEdges,
725 StructuralConstraintSliceOrder >
727 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
728 = constraintMandatoryArcs_;
729 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
730 = constraintForbiddenArcs_;
731 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
732 = constraintPossibleEdges_;
733 static_cast< StructuralConstraintSliceOrder& >(gen_constraint) = constraintSliceOrder_;
735 GraphChangesGenerator4DiGraph<
decltype(gen_constraint) > op_set(gen_constraint);
737 StructuralConstraintSetStatic< StructuralConstraintTabuList,
738 StructuralConstraintIndegree,
739 StructuralConstraintDAG >
741 static_cast< StructuralConstraintTabuList& >(sel_constraint) = constraintTabuList_;
742 static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
744 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
decltype(op_set) > selector(
749 return localSearchWithTabuList_.learnStructure(selector, init_graph);
// K2 (the case label is not visible here): the generator constraint has no
// slice order and the generator is the K2-specific one
754 BNLearnerListener listener(
this, algoK2_.approximationScheme());
755 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
756 StructuralConstraintForbiddenArcs,
757 StructuralConstraintPossibleEdges >
759 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
760 = constraintMandatoryArcs_;
761 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
762 = constraintForbiddenArcs_;
763 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
764 = constraintPossibleEdges_;
766 GraphChangesGenerator4K2<
decltype(gen_constraint) > op_set(gen_constraint);
// the mandatory arcs are compatible with the K2 order when each arc's tail is
// positioned strictly before its head in that order
770 const ArcSet& mandatory_arcs
771 =
static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint).arcs();
772 const Sequence< NodeId >& order = algoK2_.order();
773 bool order_compatible =
true;
775 for (
const auto& arc: mandatory_arcs) {
776 if (order.pos(arc.tail()) >= order.pos(arc.head())) {
777 order_compatible =
false;
// compatible order: the selector uses the plain DiGraph constraint
782 if (order_compatible) {
783 StructuralConstraintSetStatic< StructuralConstraintIndegree,
784 StructuralConstraintDiGraph >
786 static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
788 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
decltype(op_set) > selector(
793 return algoK2_.learnStructure(selector, init_graph);
// second variant (presumably the incompatible-order branch; its opening line
// is not visible): the selector uses the DAG constraint instead
795 StructuralConstraintSetStatic< StructuralConstraintIndegree, StructuralConstraintDAG >
797 static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
799 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
decltype(op_set) > selector(
804 return algoK2_.learnStructure(selector, init_graph);
// error raised for algorithms that are not handled above
810 GUM_ERROR(OperationNotAllowed,
811 "the learnDAG method has not been implemented for this " 812 "learning algorithm");
// Asks the score class matching scoreType_ whether it is compatible with the
// current a priori type and weight, and returns that class's answer as a
// string.
// NOTE(review): most case labels are not visible in this excerpt.
816 std::string genericBNLearner::checkScoreAprioriCompatibility() {
817 const std::string& apriori = getAprioriType_();
819 switch (scoreType_) {
821 return ScoreAIC<>::isAprioriCompatible(apriori, aprioriWeight_);
824 return ScoreBD<>::isAprioriCompatible(apriori, aprioriWeight_);
826 case ScoreType::BDeu:
827 return ScoreBDeu<>::isAprioriCompatible(apriori, aprioriWeight_);
830 return ScoreBIC<>::isAprioriCompatible(apriori, aprioriWeight_);
833 return ScoreK2<>::isAprioriCompatible(apriori, aprioriWeight_);
835 case ScoreType::LOG2LIKELIHOOD:
836 return ScoreLog2Likelihood<>::isAprioriCompatible(apriori, aprioriWeight_);
// message returned for score types that are not handled above
839 return "genericBNLearner does not support yet this score";
845 std::pair< std::size_t, std::size_t >
846 genericBNLearner::useCrossValidationFold(
const std::size_t learning_fold,
847 const std::size_t k_fold) {
848 if (k_fold == 0) { GUM_ERROR(OutOfBounds,
"K-fold cross validation with k=0 is forbidden") }
850 if (learning_fold >= k_fold) {
851 GUM_ERROR(OutOfBounds,
852 "In " << k_fold <<
"-fold cross validation, the learning " 853 <<
"fold should be strictly lower than " << k_fold
854 <<
" but, here, it is equal to " << learning_fold);
857 const std::size_t db_size = scoreDatabase_.databaseTable().nbRows();
858 if (k_fold >= db_size) {
859 GUM_ERROR(OutOfBounds,
860 "In " << k_fold <<
"-fold cross validation, the database's " 861 <<
"size should be strictly greater than " << k_fold
862 <<
" but, here, the database has only " << db_size <<
"rows");
866 const std::size_t foldSize = db_size / k_fold;
867 const std::size_t unfold_deb = learning_fold * foldSize;
868 const std::size_t unfold_end = unfold_deb + foldSize;
871 if (learning_fold == std::size_t(0)) {
872 ranges_.push_back(std::pair< std::size_t, std::size_t >(unfold_end, db_size));
874 ranges_.push_back(std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb));
876 if (learning_fold != k_fold - 1) {
877 ranges_.push_back(std::pair< std::size_t, std::size_t >(unfold_end, db_size));
881 return std::pair< std::size_t, std::size_t >(unfold_deb, unfold_end);
// Returns the result of IndepTestChi2<>::statistics for id1 and id2 given the
// `knowing` set, computed over the score database, as a pair of doubles.
// NOTE(review): part of the parameter list and of the IndepTestChi2
// constructor arguments is not visible in this excerpt.
885 std::pair<
double,
double > genericBNLearner::chi2(
const NodeId id1,
887 const std::vector< NodeId >& knowing) {
889 gum::learning::IndepTestChi2<> chi2score(scoreDatabase_.parser(),
893 return chi2score.statistics(id1, id2, knowing);
896 std::pair<
double,
double > genericBNLearner::chi2(
const std::string& name1,
897 const std::string& name2,
898 const std::vector< std::string >& knowing) {
899 std::vector< NodeId > knowingIds;
900 std::transform(knowing.begin(),
902 std::back_inserter(knowingIds),
903 [
this](
const std::string& c) -> NodeId {
return this->idFromName(c); });
904 return chi2(idFromName(name1), idFromName(name2), knowingIds);
// Returns the result of IndepTestG2<>::statistics for id1 and id2 given the
// `knowing` set, as a pair of doubles.
// NOTE(review): part of the parameter list and one line of the body are not
// visible in this excerpt.
907 std::pair<
double,
double > genericBNLearner::G2(
const NodeId id1,
909 const std::vector< NodeId >& knowing) {
// the test is built from the score database's parser, the current a priori
// and the current database ranges
911 gum::learning::IndepTestG2<> g2score(scoreDatabase_.parser(), *apriori_, databaseRanges());
912 return g2score.statistics(id1, id2, knowing);
915 std::pair<
double,
double > genericBNLearner::G2(
const std::string& name1,
916 const std::string& name2,
917 const std::vector< std::string >& knowing) {
918 std::vector< NodeId > knowingIds;
919 std::transform(knowing.begin(),
921 std::back_inserter(knowingIds),
922 [
this](
const std::string& c) -> NodeId {
return this->idFromName(c); });
923 return G2(idFromName(name1), idFromName(name2), knowingIds);
// Log-likelihood of `vars` given `knowing`, from ScoreLog2Likelihood scores
// computed over the score database.
// NOTE(review): part of the ScoreLog2Likelihood constructor arguments and
// the branch taken when `knowing` is empty are not visible in this excerpt.
926 double genericBNLearner::logLikelihood(
const std::vector< NodeId >& vars,
927 const std::vector< NodeId >& knowing) {
929 gum::learning::ScoreLog2Likelihood<> ll2score(scoreDatabase_.parser(),
// score of the joint set: vars followed by the conditioning variables
933 std::vector< NodeId > total(vars);
934 total.insert(total.end(), knowing.begin(), knowing.end());
935 double LLtotal = ll2score.score(IdCondSet<>(total,
false,
true));
936 if (knowing.size() == (Size)0) {
// conditional value: score of the joint set minus score of the conditioning set
939 double LLknw = ll2score.score(IdCondSet<>(knowing,
false,
true));
940 return LLtotal - LLknw;
944 double genericBNLearner::logLikelihood(
const std::vector< std::string >& vars,
945 const std::vector< std::string >& knowing) {
946 std::vector< NodeId > ids;
947 std::vector< NodeId > knowingIds;
949 auto mapper = [
this](
const std::string& c) -> NodeId {
950 return this->idFromName(c);
953 std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
954 std::transform(knowing.begin(), knowing.end(), std::back_inserter(knowingIds), mapper);
956 return logLikelihood(ids, knowingIds);
// Returns the result of PseudoCount<>::get for `vars`, as a vector of doubles.
// NOTE(review): one line of the body is not visible in this excerpt.
959 std::vector<
double > genericBNLearner::rawPseudoCount(
const std::vector< NodeId >& vars) {
// NOTE(review): `res` is not used by any visible line of this function; it
// looks like a leftover. Confirm against the full file before removing it.
960 Potential<
double > res;
// the counter is built from the score database's parser, the current a priori
// and the current database ranges
963 gum::learning::PseudoCount<> count(scoreDatabase_.parser(), *apriori_, databaseRanges());
964 return count.get(vars);
968 std::vector<
double > genericBNLearner::rawPseudoCount(
const std::vector< std::string >& vars) {
969 std::vector< NodeId > ids;
971 auto mapper = [
this](
const std::string& c) -> NodeId {
972 return this->idFromName(c);
975 std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
977 return rawPseudoCount(ids);