34 #include <agrum/agrum.h> 35 #include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h> 36 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h> 37 #include <agrum/tools/stattests/indepTestChi2.h> 38 #include <agrum/tools/stattests/indepTestG2.h> 39 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h> 40 #include <agrum/tools/stattests/pseudoCount.h> 44 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h> 52 genericBNLearner::Database::Database(
const DatabaseTable<>& db) :
55 const auto& var_names = database__.variableNames();
56 const std::size_t nb_vars = var_names.size();
57 for (
auto dom: database__.domainSizes())
58 domain_sizes__.push_back(dom);
59 for (std::size_t i = 0; i < nb_vars; ++i) {
60 nodeId2cols__.insert(NodeId(i), i);
65 =
new DBRowGeneratorParser<>(database__.handler(), DBRowGeneratorSet<>());
69 genericBNLearner::Database::Database(
70 const std::string& filename,
71 const std::vector< std::string >& missing_symbols) :
72 Database(genericBNLearner::readFile__(filename, missing_symbols)) {}
75 genericBNLearner::Database::Database(
76 const std::string& CSV_filename,
77 Database& score_database,
78 const std::vector< std::string >& missing_symbols) {
80 genericBNLearner::checkFileName__(CSV_filename);
81 DBInitializerFromCSV<> initializer(CSV_filename);
82 const auto& apriori_names = initializer.variableNames();
83 std::size_t apriori_nb_vars = apriori_names.size();
84 HashTable< std::string, std::size_t > apriori_names2col(apriori_nb_vars);
85 for (std::size_t i = std::size_t(0); i < apriori_nb_vars; ++i)
86 apriori_names2col.insert(apriori_names[i], i);
90 if (apriori_nb_vars < score_database.database__.nbVariables()) {
91 GUM_ERROR(InvalidArgument,
92 "the a apriori database has fewer variables " 93 "than the observed database");
98 const std::vector< std::string >& score_names
99 = score_database.databaseTable().variableNames();
100 const std::size_t score_nb_vars = score_names.size();
101 HashTable< std::size_t, std::size_t > mapping(score_nb_vars);
102 for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
104 mapping.insert(i, apriori_names2col[score_names[i]]);
105 }
catch (Exception&) {
106 GUM_ERROR(MissingVariableInDatabase,
109 <<
" of the observed database does not belong to the " 110 <<
"apriori database");
115 for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
116 const Variable& var = score_database.databaseTable().variable(i);
117 database__.insertTranslator(var, mapping[i], missing_symbols);
121 initializer.fillDatabase(database__);
124 for (
auto dom: database__.domainSizes())
125 domain_sizes__.push_back(dom);
128 nodeId2cols__ = score_database.nodeId2Columns();
132 =
new DBRowGeneratorParser<>(database__.handler(), DBRowGeneratorSet<>());
136 genericBNLearner::Database::Database(
const Database& from) :
137 database__(from.database__), domain_sizes__(from.domain_sizes__),
138 nodeId2cols__(from.nodeId2cols__) {
141 =
new DBRowGeneratorParser<>(database__.handler(), DBRowGeneratorSet<>());
145 genericBNLearner::Database::Database(Database&& from) :
146 database__(std::move(from.database__)),
147 domain_sizes__(std::move(from.domain_sizes__)),
148 nodeId2cols__(std::move(from.nodeId2cols__)) {
151 =
new DBRowGeneratorParser<>(database__.handler(), DBRowGeneratorSet<>());
155 genericBNLearner::Database::~Database() {
delete parser__; }
157 genericBNLearner::Database&
158 genericBNLearner::Database::operator=(
const Database& from) {
161 database__ = from.database__;
162 domain_sizes__ = from.domain_sizes__;
163 nodeId2cols__ = from.nodeId2cols__;
166 parser__ =
new DBRowGeneratorParser<>(database__.handler(),
167 DBRowGeneratorSet<>());
173 genericBNLearner::Database&
174 genericBNLearner::Database::operator=(Database&& from) {
177 database__ = std::move(from.database__);
178 domain_sizes__ = std::move(from.domain_sizes__);
179 nodeId2cols__ = std::move(from.nodeId2cols__);
182 parser__ =
new DBRowGeneratorParser<>(database__.handler(),
183 DBRowGeneratorSet<>());
192 genericBNLearner::genericBNLearner(
193 const std::string& filename,
194 const std::vector< std::string >& missing_symbols) :
195 score_database__(filename, missing_symbols) {
196 no_apriori__ =
new AprioriNoApriori<>(score_database__.databaseTable());
199 GUM_CONSTRUCTOR(genericBNLearner);
203 genericBNLearner::genericBNLearner(
const DatabaseTable<>& db) :
204 score_database__(db) {
205 no_apriori__ =
new AprioriNoApriori<>(score_database__.databaseTable());
208 GUM_CONSTRUCTOR(genericBNLearner);
212 genericBNLearner::genericBNLearner(
const genericBNLearner& from) :
213 score_type__(from.score_type__),
214 param_estimator_type__(from.param_estimator_type__),
215 EMepsilon__(from.EMepsilon__), apriori_type__(from.apriori_type__),
216 apriori_weight__(from.apriori_weight__),
217 constraint_SliceOrder__(from.constraint_SliceOrder__),
218 constraint_Indegree__(from.constraint_Indegree__),
219 constraint_TabuList__(from.constraint_TabuList__),
220 constraint_ForbiddenArcs__(from.constraint_ForbiddenArcs__),
221 constraint_MandatoryArcs__(from.constraint_MandatoryArcs__),
222 selected_algo__(from.selected_algo__), K2__(from.K2__),
223 miic_3off2__(from.miic_3off2__), kmode_3off2__(from.kmode_3off2__),
224 greedy_hill_climbing__(from.greedy_hill_climbing__),
225 local_search_with_tabu_list__(from.local_search_with_tabu_list__),
226 score_database__(from.score_database__), ranges__(from.ranges__),
227 apriori_dbname__(from.apriori_dbname__),
228 initial_dag__(from.initial_dag__) {
229 no_apriori__ =
new AprioriNoApriori<>(score_database__.databaseTable());
232 GUM_CONS_CPY(genericBNLearner);
235 genericBNLearner::genericBNLearner(genericBNLearner&& from) :
236 score_type__(from.score_type__),
237 param_estimator_type__(from.param_estimator_type__),
238 EMepsilon__(from.EMepsilon__), apriori_type__(from.apriori_type__),
239 apriori_weight__(from.apriori_weight__),
240 constraint_SliceOrder__(std::move(from.constraint_SliceOrder__)),
241 constraint_Indegree__(std::move(from.constraint_Indegree__)),
242 constraint_TabuList__(std::move(from.constraint_TabuList__)),
243 constraint_ForbiddenArcs__(std::move(from.constraint_ForbiddenArcs__)),
244 constraint_MandatoryArcs__(std::move(from.constraint_MandatoryArcs__)),
245 selected_algo__(from.selected_algo__), K2__(std::move(from.K2__)),
246 miic_3off2__(std::move(from.miic_3off2__)),
247 kmode_3off2__(from.kmode_3off2__),
248 greedy_hill_climbing__(std::move(from.greedy_hill_climbing__)),
249 local_search_with_tabu_list__(
250 std::move(from.local_search_with_tabu_list__)),
251 score_database__(std::move(from.score_database__)),
252 ranges__(std::move(from.ranges__)),
253 apriori_dbname__(std::move(from.apriori_dbname__)),
254 initial_dag__(std::move(from.initial_dag__)) {
255 no_apriori__ =
new AprioriNoApriori<>(score_database__.databaseTable());
258 GUM_CONS_MOV(genericBNLearner);
261 genericBNLearner::~genericBNLearner() {
262 if (score__)
delete score__;
264 if (apriori__)
delete apriori__;
266 if (no_apriori__)
delete no_apriori__;
268 if (apriori_database__)
delete apriori_database__;
270 if (mutual_info__)
delete mutual_info__;
272 GUM_DESTRUCTOR(genericBNLearner);
275 genericBNLearner& genericBNLearner::operator=(
const genericBNLearner& from) {
287 if (apriori_database__) {
288 delete apriori_database__;
289 apriori_database__ =
nullptr;
293 delete mutual_info__;
294 mutual_info__ =
nullptr;
297 score_type__ = from.score_type__;
298 param_estimator_type__ = from.param_estimator_type__;
299 EMepsilon__ = from.EMepsilon__;
300 apriori_type__ = from.apriori_type__;
301 apriori_weight__ = from.apriori_weight__;
302 constraint_SliceOrder__ = from.constraint_SliceOrder__;
303 constraint_Indegree__ = from.constraint_Indegree__;
304 constraint_TabuList__ = from.constraint_TabuList__;
305 constraint_ForbiddenArcs__ = from.constraint_ForbiddenArcs__;
306 constraint_MandatoryArcs__ = from.constraint_MandatoryArcs__;
307 selected_algo__ = from.selected_algo__;
309 miic_3off2__ = from.miic_3off2__;
310 kmode_3off2__ = from.kmode_3off2__;
311 greedy_hill_climbing__ = from.greedy_hill_climbing__;
312 local_search_with_tabu_list__ = from.local_search_with_tabu_list__;
313 score_database__ = from.score_database__;
314 ranges__ = from.ranges__;
315 apriori_dbname__ = from.apriori_dbname__;
316 initial_dag__ = from.initial_dag__;
317 current_algorithm__ =
nullptr;
323 genericBNLearner& genericBNLearner::operator=(genericBNLearner&& from) {
335 if (apriori_database__) {
336 delete apriori_database__;
337 apriori_database__ =
nullptr;
341 delete mutual_info__;
342 mutual_info__ =
nullptr;
345 score_type__ = from.score_type__;
346 param_estimator_type__ = from.param_estimator_type__;
347 EMepsilon__ = from.EMepsilon__;
348 apriori_type__ = from.apriori_type__;
349 apriori_weight__ = from.apriori_weight__;
350 constraint_SliceOrder__ = std::move(from.constraint_SliceOrder__);
351 constraint_Indegree__ = std::move(from.constraint_Indegree__);
352 constraint_TabuList__ = std::move(from.constraint_TabuList__);
353 constraint_ForbiddenArcs__ = std::move(from.constraint_ForbiddenArcs__);
354 constraint_MandatoryArcs__ = std::move(from.constraint_MandatoryArcs__);
355 selected_algo__ = from.selected_algo__;
357 miic_3off2__ = std::move(from.miic_3off2__);
358 kmode_3off2__ = from.kmode_3off2__;
359 greedy_hill_climbing__ = std::move(from.greedy_hill_climbing__);
360 local_search_with_tabu_list__
361 = std::move(from.local_search_with_tabu_list__);
362 score_database__ = std::move(from.score_database__);
363 ranges__ = std::move(from.ranges__);
364 apriori_dbname__ = std::move(from.apriori_dbname__);
365 initial_dag__ = std::move(from.initial_dag__);
366 current_algorithm__ =
nullptr;
373 DatabaseTable<> readFile(
const std::string& filename) {
375 Size filename_size = Size(filename.size());
377 if (filename_size < 4) {
378 GUM_ERROR(FormatNotFound,
379 "genericBNLearner could not determine the " 380 "file type of the database");
383 std::string extension = filename.substr(filename.size() - 4);
384 std::transform(extension.begin(),
389 if (extension !=
".csv") {
390 GUM_ERROR(OperationNotAllowed,
391 "genericBNLearner does not support yet this type " 395 DBInitializerFromCSV<> initializer(filename);
397 const auto& var_names = initializer.variableNames();
398 const std::size_t nb_vars = var_names.size();
400 DBTranslatorSet<> translator_set;
401 DBTranslator4LabelizedVariable<> translator;
402 for (std::size_t i = 0; i < nb_vars; ++i) {
403 translator_set.insertTranslator(translator, i);
406 DatabaseTable<> database(translator_set);
407 database.setVariableNames(initializer.variableNames());
408 initializer.fillDatabase(database);
414 void genericBNLearner::checkFileName__(
const std::string& filename) {
416 Size filename_size = Size(filename.size());
418 if (filename_size < 4) {
419 GUM_ERROR(FormatNotFound,
420 "genericBNLearner could not determine the " 421 "file type of the database");
424 std::string extension = filename.substr(filename.size() - 4);
425 std::transform(extension.begin(),
430 if (extension !=
".csv") {
433 "genericBNLearner does not support yet this type of database file");
438 DatabaseTable<> genericBNLearner::readFile__(
439 const std::string& filename,
440 const std::vector< std::string >& missing_symbols) {
442 checkFileName__(filename);
444 DBInitializerFromCSV<> initializer(filename);
446 const auto& var_names = initializer.variableNames();
447 const std::size_t nb_vars = var_names.size();
449 DBTranslatorSet<> translator_set;
450 DBTranslator4LabelizedVariable<> translator(missing_symbols);
451 for (std::size_t i = 0; i < nb_vars; ++i) {
452 translator_set.insertTranslator(translator, i);
455 DatabaseTable<> database(missing_symbols, translator_set);
456 database.setVariableNames(initializer.variableNames());
457 initializer.fillDatabase(database);
465 void genericBNLearner::createApriori__() {
467 Apriori<>* old_apriori = apriori__;
470 switch (apriori_type__) {
471 case AprioriType::NO_APRIORI:
472 apriori__ =
new AprioriNoApriori<>(score_database__.databaseTable(),
473 score_database__.nodeId2Columns());
476 case AprioriType::SMOOTHING:
477 apriori__ =
new AprioriSmoothing<>(score_database__.databaseTable(),
478 score_database__.nodeId2Columns());
481 case AprioriType::DIRICHLET_FROM_DATABASE:
482 if (apriori_database__ !=
nullptr) {
483 delete apriori_database__;
484 apriori_database__ =
nullptr;
487 apriori_database__ =
new Database(apriori_dbname__,
489 score_database__.missingSymbols());
491 apriori__ =
new AprioriDirichletFromDatabase<>(
492 score_database__.databaseTable(),
493 apriori_database__->parser(),
494 apriori_database__->nodeId2Columns());
497 case AprioriType::BDEU:
498 apriori__ =
new AprioriBDeu<>(score_database__.databaseTable(),
499 score_database__.nodeId2Columns());
503 GUM_ERROR(OperationNotAllowed,
504 "The BNLearner does not support yet this apriori");
508 apriori__->setWeight(apriori_weight__);
511 if (old_apriori !=
nullptr)
delete old_apriori;
514 void genericBNLearner::createScore__() {
516 Score<>* old_score = score__;
519 switch (score_type__) {
521 score__ =
new ScoreAIC<>(score_database__.parser(),
524 score_database__.nodeId2Columns());
528 score__ =
new ScoreBD<>(score_database__.parser(),
531 score_database__.nodeId2Columns());
534 case ScoreType::BDeu:
535 score__ =
new ScoreBDeu<>(score_database__.parser(),
538 score_database__.nodeId2Columns());
542 score__ =
new ScoreBIC<>(score_database__.parser(),
545 score_database__.nodeId2Columns());
549 score__ =
new ScoreK2<>(score_database__.parser(),
552 score_database__.nodeId2Columns());
555 case ScoreType::LOG2LIKELIHOOD:
556 score__ =
new ScoreLog2Likelihood<>(score_database__.parser(),
559 score_database__.nodeId2Columns());
563 GUM_ERROR(OperationNotAllowed,
564 "genericBNLearner does not support yet this score");
568 if (old_score !=
nullptr)
delete old_score;
572 genericBNLearner::createParamEstimator__(DBRowGeneratorParser<>& parser,
573 bool take_into_account_score) {
574 ParamEstimator<>* param_estimator =
nullptr;
577 switch (param_estimator_type__) {
578 case ParamEstimatorType::ML:
579 if (take_into_account_score && (score__ !=
nullptr)) {
581 =
new ParamEstimatorML<>(parser,
583 score__->internalApriori(),
585 score_database__.nodeId2Columns());
588 =
new ParamEstimatorML<>(parser,
592 score_database__.nodeId2Columns());
598 GUM_ERROR(OperationNotAllowed,
599 "genericBNLearner does not support " 600 <<
"yet this parameter estimator");
604 param_estimator->setRanges(ranges__);
606 return param_estimator;
610 MixedGraph genericBNLearner::prepare_miic_3off2__() {
613 for (Size i = 0; i < score_database__.databaseTable().nbVariables(); ++i) {
614 mgraph.addNodeWithId(i);
615 for (Size j = 0; j < i; ++j) {
616 mgraph.addEdge(j, i);
621 HashTable< std::pair< NodeId, NodeId >,
char > initial_marks;
622 const ArcSet& mandatory_arcs = constraint_MandatoryArcs__.arcs();
623 for (
const auto& arc: mandatory_arcs) {
624 initial_marks.insert({arc.tail(), arc.head()},
'>');
627 const ArcSet& forbidden_arcs = constraint_ForbiddenArcs__.arcs();
628 for (
const auto& arc: forbidden_arcs) {
629 initial_marks.insert({arc.tail(), arc.head()},
'-');
631 miic_3off2__.addConstraints(initial_marks);
635 createCorrectedMutualInformation__();
640 MixedGraph genericBNLearner::learnMixedStructure() {
641 if (selected_algo__ != AlgoType::MIIC_THREE_OFF_TWO) {
642 GUM_ERROR(OperationNotAllowed,
"Must be using the miic/3off2 algorithm");
645 if (score_database__.databaseTable().hasMissingValues()) {
646 GUM_ERROR(MissingValueInDatabase,
647 "For the moment, the BNLearner is unable to learn " 648 <<
"structures with missing values in databases");
650 BNLearnerListener listener(
this, miic_3off2__);
653 MixedGraph mgraph =
this->prepare_miic_3off2__();
655 return miic_3off2__.learnMixedStructure(*mutual_info__, mgraph);
658 DAG genericBNLearner::learnDAG() {
666 void genericBNLearner::createCorrectedMutualInformation__() {
667 if (mutual_info__ !=
nullptr)
delete mutual_info__;
670 =
new CorrectedMutualInformation<>(score_database__.parser(),
673 score_database__.nodeId2Columns());
674 switch (kmode_3off2__) {
675 case CorrectedMutualInformation<>::KModeTypes::MDL:
676 mutual_info__->useMDL();
679 case CorrectedMutualInformation<>::KModeTypes::NML:
680 mutual_info__->useNML();
683 case CorrectedMutualInformation<>::KModeTypes::NoCorr:
684 mutual_info__->useNoCorr();
688 GUM_ERROR(NotImplementedYet,
689 "The BNLearner's corrected mutual information class does " 690 <<
"not support yet penalty mode " <<
int(kmode_3off2__));
694 DAG genericBNLearner::learnDAG__() {
696 if (score_database__.databaseTable().hasMissingValues()
697 || ((apriori_database__ !=
nullptr)
698 && (apriori_type__ == AprioriType::DIRICHLET_FROM_DATABASE)
699 && apriori_database__->databaseTable().hasMissingValues())) {
700 GUM_ERROR(MissingValueInDatabase,
701 "For the moment, the BNLearner is unable to cope " 702 "with missing values in databases");
706 DAG init_graph = initial_dag__;
708 const ArcSet& mandatory_arcs = constraint_MandatoryArcs__.arcs();
710 for (
const auto& arc: mandatory_arcs) {
711 if (!init_graph.exists(arc.tail())) init_graph.addNodeWithId(arc.tail());
713 if (!init_graph.exists(arc.head())) init_graph.addNodeWithId(arc.head());
715 init_graph.addArc(arc.tail(), arc.head());
718 const ArcSet& forbidden_arcs = constraint_ForbiddenArcs__.arcs();
720 for (
const auto& arc: forbidden_arcs) {
721 init_graph.eraseArc(arc);
724 switch (selected_algo__) {
726 case AlgoType::MIIC_THREE_OFF_TWO: {
727 BNLearnerListener listener(
this, miic_3off2__);
729 MixedGraph mgraph =
this->prepare_miic_3off2__();
731 return miic_3off2__.learnStructure(*mutual_info__, mgraph);
735 case AlgoType::GREEDY_HILL_CLIMBING: {
736 BNLearnerListener listener(
this, greedy_hill_climbing__);
737 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
738 StructuralConstraintForbiddenArcs,
739 StructuralConstraintPossibleEdges,
740 StructuralConstraintSliceOrder >
742 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
743 = constraint_MandatoryArcs__;
744 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
745 = constraint_ForbiddenArcs__;
746 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
747 = constraint_PossibleEdges__;
748 static_cast< StructuralConstraintSliceOrder& >(gen_constraint)
749 = constraint_SliceOrder__;
751 GraphChangesGenerator4DiGraph<
decltype(gen_constraint) > op_set(
754 StructuralConstraintSetStatic< StructuralConstraintIndegree,
755 StructuralConstraintDAG >
757 static_cast< StructuralConstraintIndegree& >(sel_constraint)
758 = constraint_Indegree__;
760 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
762 selector(*score__, sel_constraint, op_set);
764 return greedy_hill_climbing__.learnStructure(selector, init_graph);
768 case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST: {
769 BNLearnerListener listener(
this, local_search_with_tabu_list__);
770 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
771 StructuralConstraintForbiddenArcs,
772 StructuralConstraintPossibleEdges,
773 StructuralConstraintSliceOrder >
775 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
776 = constraint_MandatoryArcs__;
777 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
778 = constraint_ForbiddenArcs__;
779 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
780 = constraint_PossibleEdges__;
781 static_cast< StructuralConstraintSliceOrder& >(gen_constraint)
782 = constraint_SliceOrder__;
784 GraphChangesGenerator4DiGraph<
decltype(gen_constraint) > op_set(
787 StructuralConstraintSetStatic< StructuralConstraintTabuList,
788 StructuralConstraintIndegree,
789 StructuralConstraintDAG >
791 static_cast< StructuralConstraintTabuList& >(sel_constraint)
792 = constraint_TabuList__;
793 static_cast< StructuralConstraintIndegree& >(sel_constraint)
794 = constraint_Indegree__;
796 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
798 selector(*score__, sel_constraint, op_set);
800 return local_search_with_tabu_list__.learnStructure(selector,
806 BNLearnerListener listener(
this, K2__.approximationScheme());
807 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
808 StructuralConstraintForbiddenArcs,
809 StructuralConstraintPossibleEdges >
811 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
812 = constraint_MandatoryArcs__;
813 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
814 = constraint_ForbiddenArcs__;
815 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
816 = constraint_PossibleEdges__;
818 GraphChangesGenerator4K2<
decltype(gen_constraint) > op_set(
823 const ArcSet& mandatory_arcs
824 =
static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
826 const Sequence< NodeId >& order = K2__.order();
827 bool order_compatible =
true;
829 for (
const auto& arc: mandatory_arcs) {
830 if (order.pos(arc.tail()) >= order.pos(arc.head())) {
831 order_compatible =
false;
836 if (order_compatible) {
837 StructuralConstraintSetStatic< StructuralConstraintIndegree,
838 StructuralConstraintDiGraph >
840 static_cast< StructuralConstraintIndegree& >(sel_constraint)
841 = constraint_Indegree__;
843 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
845 selector(*score__, sel_constraint, op_set);
847 return K2__.learnStructure(selector, init_graph);
849 StructuralConstraintSetStatic< StructuralConstraintIndegree,
850 StructuralConstraintDAG >
852 static_cast< StructuralConstraintIndegree& >(sel_constraint)
853 = constraint_Indegree__;
855 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
857 selector(*score__, sel_constraint, op_set);
859 return K2__.learnStructure(selector, init_graph);
865 GUM_ERROR(OperationNotAllowed,
866 "the learnDAG method has not been implemented for this " 867 "learning algorithm");
871 std::string genericBNLearner::checkScoreAprioriCompatibility() {
872 const std::string& apriori = getAprioriType__();
874 switch (score_type__) {
876 return ScoreAIC<>::isAprioriCompatible(apriori, apriori_weight__);
879 return ScoreBD<>::isAprioriCompatible(apriori, apriori_weight__);
881 case ScoreType::BDeu:
882 return ScoreBDeu<>::isAprioriCompatible(apriori, apriori_weight__);
885 return ScoreBIC<>::isAprioriCompatible(apriori, apriori_weight__);
888 return ScoreK2<>::isAprioriCompatible(apriori, apriori_weight__);
890 case ScoreType::LOG2LIKELIHOOD:
891 return ScoreLog2Likelihood<>::isAprioriCompatible(apriori,
895 return "genericBNLearner does not support yet this score";
901 std::pair< std::size_t, std::size_t >
902 genericBNLearner::useCrossValidationFold(
const std::size_t learning_fold,
903 const std::size_t k_fold) {
905 GUM_ERROR(OutOfBounds,
"K-fold cross validation with k=0 is forbidden");
908 if (learning_fold >= k_fold) {
909 GUM_ERROR(OutOfBounds,
910 "In " << k_fold <<
"-fold cross validation, the learning " 911 <<
"fold should be strictly lower than " << k_fold
912 <<
" but, here, it is equal to " << learning_fold);
915 const std::size_t db_size = score_database__.databaseTable().nbRows();
916 if (k_fold >= db_size) {
917 GUM_ERROR(OutOfBounds,
918 "In " << k_fold <<
"-fold cross validation, the database's " 919 <<
"size should be strictly greater than " << k_fold
920 <<
" but, here, the database has only " << db_size
925 const std::size_t foldSize = db_size / k_fold;
926 const std::size_t unfold_deb = learning_fold * foldSize;
927 const std::size_t unfold_end = unfold_deb + foldSize;
930 if (learning_fold == std::size_t(0)) {
932 std::pair< std::size_t, std::size_t >(unfold_end, db_size));
935 std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb));
937 if (learning_fold != k_fold - 1) {
939 std::pair< std::size_t, std::size_t >(unfold_end, db_size));
943 return std::pair< std::size_t, std::size_t >(unfold_deb, unfold_end);
947 std::pair<
double,
double >
948 genericBNLearner::chi2(
const NodeId id1,
950 const std::vector< NodeId >& knowing) {
952 gum::learning::IndepTestChi2<> chi2score(score_database__.parser(),
956 return chi2score.statistics(id1, id2, knowing);
959 std::pair<
double,
double >
960 genericBNLearner::chi2(
const std::string& name1,
961 const std::string& name2,
962 const std::vector< std::string >& knowing) {
963 std::vector< NodeId > knowingIds;
967 std::back_inserter(knowingIds),
968 [
this](
const std::string& c) -> NodeId {
return this->idFromName(c); });
969 return chi2(idFromName(name1), idFromName(name2), knowingIds);
972 std::pair<
double,
double >
973 genericBNLearner::G2(
const NodeId id1,
975 const std::vector< NodeId >& knowing) {
977 gum::learning::IndepTestG2<> g2score(score_database__.parser(),
980 return g2score.statistics(id1, id2, knowing);
983 std::pair<
double,
double >
984 genericBNLearner::G2(
const std::string& name1,
985 const std::string& name2,
986 const std::vector< std::string >& knowing) {
987 std::vector< NodeId > knowingIds;
991 std::back_inserter(knowingIds),
992 [
this](
const std::string& c) -> NodeId {
return this->idFromName(c); });
993 return G2(idFromName(name1), idFromName(name2), knowingIds);
996 double genericBNLearner::logLikelihood(
const std::vector< NodeId >& vars,
997 const std::vector< NodeId >& knowing) {
999 gum::learning::ScoreLog2Likelihood<> ll2score(score_database__.parser(),
1003 std::vector< NodeId > total(vars);
1004 total.insert(total.end(), knowing.begin(), knowing.end());
1005 double LLtotal = ll2score.score(IdCondSet<>(total,
false,
true));
1006 if (knowing.size() == (Size)0) {
1009 double LLknw = ll2score.score(IdCondSet<>(knowing,
false,
true));
1010 return LLtotal - LLknw;
1015 genericBNLearner::logLikelihood(
const std::vector< std::string >& vars,
1016 const std::vector< std::string >& knowing) {
1017 std::vector< NodeId > ids;
1018 std::vector< NodeId > knowingIds;
1020 auto mapper = [
this](
const std::string& c) -> NodeId {
1021 return this->idFromName(c);
1024 std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
1025 std::transform(knowing.begin(),
1027 std::back_inserter(knowingIds),
1030 return logLikelihood(ids, knowingIds);
1033 std::vector<
double >
1034 genericBNLearner::rawPseudoCount(
const std::vector< NodeId >& vars) {
1035 Potential<
double > res;
1038 gum::learning::PseudoCount<> count(score_database__.parser(),
1041 return count.get(vars);
1045 std::vector<
double >
1046 genericBNLearner::rawPseudoCount(
const std::vector< std::string >& vars) {
1047 std::vector< NodeId > ids;
1049 auto mapper = [
this](
const std::string& c) -> NodeId {
1050 return this->idFromName(c);
1053 std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
1055 return rawPseudoCount(ids);