34 #include <agrum/agrum.h> 35 #include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h> 36 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h> 37 #include <agrum/tools/stattests/indepTestChi2.h> 38 #include <agrum/tools/stattests/indepTestG2.h> 39 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h> 40 #include <agrum/tools/stattests/pseudoCount.h> 44 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h> 52 genericBNLearner::Database::Database(
const DatabaseTable<>& db) : _database_(db) {
// Constructor from an already built table: _database_ is initialised from
// `db`, then the cached domain sizes, the NodeId -> column map and the row
// parser are derived from that stored table.
54 const auto& var_names = _database_.variableNames();
55 const std::size_t nb_vars = var_names.size();
// cache the domain size of every variable of the stored table
56 for (
auto dom: _database_.domainSizes())
57 _domain_sizes_.push_back(dom);
// node id i is associated with column i
58 for (std::size_t i = 0; i < nb_vars; ++i) {
59 _nodeId2cols_.insert(NodeId(i), i);
// the parser is heap-allocated on the stored table's handler, with an empty
// set of row generators
63 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
// Constructor from a file name: the file is read by
// genericBNLearner::readFile_(filename, missing_symbols) and the resulting
// table is forwarded to the constructor taking a DatabaseTable<>.
67 genericBNLearner::Database::Database(
const std::string& filename,
68 const std::vector< std::string >& missing_symbols) :
69 Database(genericBNLearner::readFile_(filename, missing_symbols)) {}
// Constructor for a database read from the CSV file `CSV_filename` whose
// columns are aligned on those of `score_database`: for each variable of
// score_database, the CSV column bearing the same name is looked up and a
// translator built from the score database's variable is attached to it.
// Raises (through GUM_ERROR) InvalidArgument when the CSV file has fewer
// variables than score_database, and MissingVariableInDatabase when a
// variable name of score_database is not found among the CSV variable names.
72 genericBNLearner::Database::Database(
const std::string& CSV_filename,
73 Database& score_database,
74 const std::vector< std::string >& missing_symbols) {
// reject file names that isCSVFileName_ does not accept
76 genericBNLearner::isCSVFileName_(CSV_filename);
77 DBInitializerFromCSV<> initializer(CSV_filename);
78 const auto& apriori_names = initializer.variableNames();
79 std::size_t apriori_nb_vars = apriori_names.size();
// index the CSV variable names: name -> column in the CSV file
80 HashTable< std::string, std::size_t > apriori_names2col(apriori_nb_vars);
81 for (std::size_t i = std::size_t(0); i < apriori_nb_vars; ++i)
82 apriori_names2col.insert(apriori_names[i], i);
// NOTE(review): the message below reads "the a apriori database" (stray "a")
86 if (apriori_nb_vars < score_database._database_.nbVariables()) {
87 GUM_ERROR(InvalidArgument,
88 "the a apriori database has fewer variables " 89 "than the observed database");
// mapping: column i of the score database -> column of the CSV file holding
// the variable with the same name
94 const std::vector< std::string >& score_names
95 = score_database.databaseTable().variableNames();
96 const std::size_t score_nb_vars = score_names.size();
97 HashTable< std::size_t, std::size_t > mapping(score_nb_vars);
98 for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
100 mapping.insert(i, apriori_names2col[score_names[i]]);
101 }
catch (Exception&) {
102 GUM_ERROR(MissingVariableInDatabase,
103 "Variable " << score_names[i]
104 <<
" of the observed database does not belong to the " 105 <<
"apriori database");
// one translator per variable of the score database, reading the mapped CSV
// column and using the given missing symbols
110 for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
111 const Variable& var = score_database.databaseTable().variable(i);
112 _database_.insertTranslator(var, mapping[i], missing_symbols);
// load the CSV content into the table
116 initializer.fillDatabase(_database_);
// cache the domain sizes of the filled table
119 for (
auto dom: _database_.domainSizes())
120 _domain_sizes_.push_back(dom);
// the NodeId -> column map is taken over from the score database
123 _nodeId2cols_ = score_database.nodeId2Columns();
// parser on the new table's handler, with an empty set of row generators
126 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
// Copy constructor: the table, the domain sizes and the NodeId -> column map
// are copied from `from`; the parser is not copied but newly allocated on
// this object's own table handler.
130 genericBNLearner::Database::Database(
const Database& from) :
131 _database_(from._database_), _domain_sizes_(from._domain_sizes_),
132 _nodeId2cols_(from._nodeId2cols_) {
134 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
// Move constructor: the table, the domain sizes and the NodeId -> column map
// are moved out of `from`; a new parser is allocated on this object's table
// handler (the visible statements do not touch from._parser_).
138 genericBNLearner::Database::Database(Database&& from) :
139 _database_(std::move(from._database_)), _domain_sizes_(std::move(from._domain_sizes_)),
140 _nodeId2cols_(std::move(from._nodeId2cols_)) {
142 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
146 genericBNLearner::Database::~Database() {
delete _parser_; }
// Copy assignment: copies the table, the domain sizes and the NodeId ->
// column map from `from`, then allocates a parser on this object's table.
// NOTE(review): the lines visible here show neither a self-assignment guard
// nor the release of the previous _parser_ before it is overwritten --
// confirm both against the complete file.
148 genericBNLearner::Database& genericBNLearner::Database::operator=(
const Database& from) {
151 _database_ = from._database_;
152 _domain_sizes_ = from._domain_sizes_;
153 _nodeId2cols_ = from._nodeId2cols_;
156 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
// Move assignment: moves the table, the domain sizes and the NodeId ->
// column map out of `from`, then allocates a parser on this object's table.
// NOTE(review): the lines visible here show neither a self-assignment guard
// nor the release of the previous _parser_ before it is overwritten --
// confirm both against the complete file.
162 genericBNLearner::Database& genericBNLearner::Database::operator=(Database&& from) {
165 _database_ = std::move(from._database_);
166 _domain_sizes_ = std::move(from._domain_sizes_);
167 _nodeId2cols_ = std::move(from._nodeId2cols_);
170 _parser_ =
new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
// Constructor from a database file: scoreDatabase_ is built from `filename`
// and `missing_symbols`, the file name is remembered in filename_, and a
// "no apriori" object is allocated on the score database's table.
179 genericBNLearner::genericBNLearner(
const std::string& filename,
180 const std::vector< std::string >& missing_symbols) :
181 scoreDatabase_(filename, missing_symbols) {
182 filename_ = filename;
183 noApriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable());
185 GUM_CONSTRUCTOR(genericBNLearner);
// Constructor from an already loaded table: scoreDatabase_ is built from
// `db` and a "no apriori" object is allocated on the score database's table.
// filename_ is not assigned by the visible statements.
189 genericBNLearner::genericBNLearner(
const DatabaseTable<>& db) : scoreDatabase_(db) {
191 noApriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable());
193 GUM_CONSTRUCTOR(genericBNLearner);
// Copy constructor: copies the learner's settings (score/estimator/apriori
// types and weight, structural constraints, selected algorithm and the
// algorithm objects, score database, ranges, apriori database name, initial
// DAG, file name, nbDecreasingChanges_) from `from`.
// noApriori_ is not copied: a new AprioriNoApriori is allocated on this
// object's own score database table.
// NOTE(review): score_, apriori_, aprioriDatabase_ and mutualInfo_ do not
// appear in the initializer list -- presumably they rely on in-class default
// initialisers; verify in the header.
197 genericBNLearner::genericBNLearner(
const genericBNLearner& from) :
198 scoreType_(from.scoreType_), paramEstimatorType_(from.paramEstimatorType_),
199 epsilonEM_(from.epsilonEM_), aprioriType_(from.aprioriType_),
200 aprioriWeight_(from.aprioriWeight_), constraintSliceOrder_(from.constraintSliceOrder_),
201 constraintIndegree_(from.constraintIndegree_),
202 constraintTabuList_(from.constraintTabuList_),
203 constraintForbiddenArcs_(from.constraintForbiddenArcs_),
204 constraintMandatoryArcs_(from.constraintMandatoryArcs_), selectedAlgo_(from.selectedAlgo_),
205 algoK2_(from.algoK2_), algoMiic3off2_(from.algoMiic3off2_), kmode3Off2_(from.kmode3Off2_),
206 greedyHillClimbing_(from.greedyHillClimbing_),
207 localSearchWithTabuList_(from.localSearchWithTabuList_),
208 scoreDatabase_(from.scoreDatabase_), ranges_(from.ranges_),
209 aprioriDbname_(from.aprioriDbname_), initialDag_(from.initialDag_),
210 filename_(from.filename_), nbDecreasingChanges_(from.nbDecreasingChanges_) {
211 noApriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable());
213 GUM_CONS_CPY(genericBNLearner);
// Move constructor: scoreType_, paramEstimatorType_, epsilonEM_,
// aprioriType_, aprioriWeight_, selectedAlgo_ and kmode3Off2_ are copied;
// the constraints, the algorithm objects, the score database, the ranges,
// the apriori database name, the initial DAG, the file name and
// nbDecreasingChanges_ are moved out of `from`.
// noApriori_ is newly allocated on this object's own score database table.
// NOTE(review): score_, apriori_, aprioriDatabase_ and mutualInfo_ do not
// appear in the initializer list -- presumably they rely on in-class default
// initialisers; verify in the header.
216 genericBNLearner::genericBNLearner(genericBNLearner&& from) :
217 scoreType_(from.scoreType_), paramEstimatorType_(from.paramEstimatorType_),
218 epsilonEM_(from.epsilonEM_), aprioriType_(from.aprioriType_),
219 aprioriWeight_(from.aprioriWeight_),
220 constraintSliceOrder_(std::move(from.constraintSliceOrder_)),
221 constraintIndegree_(std::move(from.constraintIndegree_)),
222 constraintTabuList_(std::move(from.constraintTabuList_)),
223 constraintForbiddenArcs_(std::move(from.constraintForbiddenArcs_)),
224 constraintMandatoryArcs_(std::move(from.constraintMandatoryArcs_)),
225 selectedAlgo_(from.selectedAlgo_), algoK2_(std::move(from.algoK2_)),
226 algoMiic3off2_(std::move(from.algoMiic3off2_)), kmode3Off2_(from.kmode3Off2_),
227 greedyHillClimbing_(std::move(from.greedyHillClimbing_)),
228 localSearchWithTabuList_(std::move(from.localSearchWithTabuList_)),
229 scoreDatabase_(std::move(from.scoreDatabase_)), ranges_(std::move(from.ranges_)),
230 aprioriDbname_(std::move(from.aprioriDbname_)), initialDag_(std::move(from.initialDag_)),
231 filename_(std::move(from.filename_)),
232 nbDecreasingChanges_(std::move(from.nbDecreasingChanges_)) {
233 noApriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable());
235 GUM_CONS_MOV(genericBNLearner)
238 genericBNLearner::~genericBNLearner() {
239 if (score_)
delete score_;
241 if (apriori_)
delete apriori_;
243 if (noApriori_)
delete noApriori_;
245 if (aprioriDatabase_)
delete aprioriDatabase_;
247 if (mutualInfo_)
delete mutualInfo_;
249 GUM_DESTRUCTOR(genericBNLearner);
// Copy assignment. The statements visible here delete aprioriDatabase_ (when
// non-null) and reset it, reset mutualInfo_ to nullptr, copy the learner's
// settings, constraints, algorithm objects, score database, ranges, apriori
// database name, initial DAG, file name and nbDecreasingChanges_ from
// `from`, and finally reset currentAlgorithm_ to nullptr.
// NOTE(review): the self-assignment guard, the release of score_, apriori_
// and mutualInfo_, and the return statement are not visible in this excerpt
// -- confirm them in the complete file.
252 genericBNLearner& genericBNLearner::operator=(
const genericBNLearner& from) {
264 if (aprioriDatabase_) {
265 delete aprioriDatabase_;
266 aprioriDatabase_ =
nullptr;
271 mutualInfo_ =
nullptr;
// member-wise copy of the configuration
274 scoreType_ = from.scoreType_;
275 paramEstimatorType_ = from.paramEstimatorType_;
276 epsilonEM_ = from.epsilonEM_;
277 aprioriType_ = from.aprioriType_;
278 aprioriWeight_ = from.aprioriWeight_;
279 constraintSliceOrder_ = from.constraintSliceOrder_;
280 constraintIndegree_ = from.constraintIndegree_;
281 constraintTabuList_ = from.constraintTabuList_;
282 constraintForbiddenArcs_ = from.constraintForbiddenArcs_;
283 constraintMandatoryArcs_ = from.constraintMandatoryArcs_;
284 selectedAlgo_ = from.selectedAlgo_;
285 algoK2_ = from.algoK2_;
286 algoMiic3off2_ = from.algoMiic3off2_;
287 kmode3Off2_ = from.kmode3Off2_;
288 greedyHillClimbing_ = from.greedyHillClimbing_;
289 localSearchWithTabuList_ = from.localSearchWithTabuList_;
290 scoreDatabase_ = from.scoreDatabase_;
291 ranges_ = from.ranges_;
292 aprioriDbname_ = from.aprioriDbname_;
293 initialDag_ = from.initialDag_;
294 filename_ = from.filename_;
295 nbDecreasingChanges_ = from.nbDecreasingChanges_;
// the algorithm in progress is not inherited from `from`
296 currentAlgorithm_ =
nullptr;
// Move assignment. The statements visible here delete aprioriDatabase_ (when
// non-null) and reset it, reset mutualInfo_ to nullptr, copy the scalar
// settings, move the constraints, algorithm objects, score database, ranges,
// names, initial DAG and nbDecreasingChanges_ out of `from`, and finally
// reset currentAlgorithm_ to nullptr.
// NOTE(review): the self-assignment guard, the release of score_, apriori_
// and mutualInfo_, and the return statement are not visible in this excerpt
// -- confirm them in the complete file.
302 genericBNLearner& genericBNLearner::operator=(genericBNLearner&& from) {
314 if (aprioriDatabase_) {
315 delete aprioriDatabase_;
316 aprioriDatabase_ =
nullptr;
321 mutualInfo_ =
nullptr;
324 scoreType_ = from.scoreType_;
325 paramEstimatorType_ = from.paramEstimatorType_;
326 epsilonEM_ = from.epsilonEM_;
327 aprioriType_ = from.aprioriType_;
328 aprioriWeight_ = from.aprioriWeight_;
329 constraintSliceOrder_ = std::move(from.constraintSliceOrder_);
330 constraintIndegree_ = std::move(from.constraintIndegree_);
331 constraintTabuList_ = std::move(from.constraintTabuList_);
332 constraintForbiddenArcs_ = std::move(from.constraintForbiddenArcs_);
333 constraintMandatoryArcs_ = std::move(from.constraintMandatoryArcs_);
334 selectedAlgo_ = from.selectedAlgo_;
// NOTE(review): algoK2_ is copy-assigned here whereas the move constructor
// moves it -- looks like an oversight; confirm whether std::move was intended
335 algoK2_ = from.algoK2_;
336 algoMiic3off2_ = std::move(from.algoMiic3off2_);
337 kmode3Off2_ = from.kmode3Off2_;
338 greedyHillClimbing_ = std::move(from.greedyHillClimbing_);
339 localSearchWithTabuList_ = std::move(from.localSearchWithTabuList_);
340 scoreDatabase_ = std::move(from.scoreDatabase_);
341 ranges_ = std::move(from.ranges_);
342 aprioriDbname_ = std::move(from.aprioriDbname_);
343 filename_ = std::move(from.filename_);
344 initialDag_ = std::move(from.initialDag_);
345 nbDecreasingChanges_ = std::move(from.nbDecreasingChanges_);
// the algorithm in progress is not inherited from `from`
346 currentAlgorithm_ =
nullptr;
// Free function building a DatabaseTable<> from the CSV file `filename`.
// The name must have at least 4 characters (FormatNotFound otherwise) and
// its last 4 characters, lower-cased, must be ".csv" (OperationNotAllowed
// otherwise). One DBTranslator4LabelizedVariable is attached to each
// variable reported by the CSV initializer, the table takes over the
// initializer's variable names, and the initializer then fills the table.
// Unlike genericBNLearner::readFile_, no missing symbols are passed to the
// translator or to the table.
353 DatabaseTable<> readFile(
const std::string& filename) {
355 Size filename_size = Size(filename.size());
357 if (filename_size < 4) {
358 GUM_ERROR(FormatNotFound,
359 "genericBNLearner could not determine the " 360 "file type of the database '" 364 std::string extension = filename.substr(filename.size() - 4);
365 std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
367 if (extension !=
".csv") {
368 GUM_ERROR(OperationNotAllowed,
369 "genericBNLearner does not support yet this type ('" << extension
374 DBInitializerFromCSV<> initializer(filename);
376 const auto& var_names = initializer.variableNames();
377 const std::size_t nb_vars = var_names.size();
// the same default-constructed translator is inserted for every column
379 DBTranslatorSet<> translator_set;
380 DBTranslator4LabelizedVariable<> translator;
381 for (std::size_t i = 0; i < nb_vars; ++i) {
382 translator_set.insertTranslator(translator, i);
385 DatabaseTable<> database(translator_set);
386 database.setVariableNames(initializer.variableNames());
387 initializer.fillDatabase(database);
393 void genericBNLearner::isCSVFileName_(
const std::string& filename) {
395 Size filename_size = Size(filename.size());
397 if (filename_size < 4) {
398 GUM_ERROR(FormatNotFound,
399 "genericBNLearner could not determine the " 400 "file type of the database");
403 std::string extension = filename.substr(filename.size() - 4);
404 std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
406 if (extension !=
".csv") {
407 GUM_ERROR(OperationNotAllowed,
408 "genericBNLearner does not support yet this type of database file");
// Builds a DatabaseTable<> from the CSV file `filename`.
// The file name is first validated by isCSVFileName_. One
// DBTranslator4LabelizedVariable constructed with `missing_symbols` is then
// attached to each variable reported by the CSV initializer; the table is
// created with the same missing symbols, takes over the initializer's
// variable names, and is filled by the initializer.
413 DatabaseTable<> genericBNLearner::readFile_(
const std::string& filename,
414 const std::vector< std::string >& missing_symbols) {
416 isCSVFileName_(filename);
418 DBInitializerFromCSV<> initializer(filename);
420 const auto& var_names = initializer.variableNames();
421 const std::size_t nb_vars = var_names.size();
// the same translator is inserted for every column
423 DBTranslatorSet<> translator_set;
424 DBTranslator4LabelizedVariable<> translator(missing_symbols);
425 for (std::size_t i = 0; i < nb_vars; ++i) {
426 translator_set.insertTranslator(translator, i);
429 DatabaseTable<> database(missing_symbols, translator_set);
430 database.setVariableNames(initializer.variableNames());
431 initializer.fillDatabase(database);
// Rebuilds apriori_ according to aprioriType_.
// The previous apriori is kept in old_apriori and deleted only at the end,
// after the new one has been allocated and given the weight aprioriWeight_.
// An apriori type that is not handled raises OperationNotAllowed (GUM_ERROR).
439 void genericBNLearner::createApriori_() {
441 Apriori<>* old_apriori = apriori_;
444 switch (aprioriType_) {
445 case AprioriType::NO_APRIORI:
446 apriori_ =
new AprioriNoApriori<>(scoreDatabase_.databaseTable(),
447 scoreDatabase_.nodeId2Columns());
450 case AprioriType::SMOOTHING:
451 apriori_ =
new AprioriSmoothing<>(scoreDatabase_.databaseTable(),
452 scoreDatabase_.nodeId2Columns());
// Dirichlet apriori: the apriori database is re-read from aprioriDbname_
// (any previous one is deleted first) and aligned on the score database
455 case AprioriType::DIRICHLET_FROM_DATABASE:
456 if (aprioriDatabase_ !=
nullptr) {
457 delete aprioriDatabase_;
458 aprioriDatabase_ =
nullptr;
462 =
new Database(aprioriDbname_, scoreDatabase_, scoreDatabase_.missingSymbols());
464 apriori_ =
new AprioriDirichletFromDatabase<>(scoreDatabase_.databaseTable(),
465 aprioriDatabase_->parser(),
466 aprioriDatabase_->nodeId2Columns());
469 case AprioriType::BDEU:
471 =
new AprioriBDeu<>(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
475 GUM_ERROR(OperationNotAllowed,
"The BNLearner does not support yet this apriori")
// the weight is applied to whichever apriori has just been created
479 apriori_->setWeight(aprioriWeight_);
// the former apriori is released last
482 if (old_apriori !=
nullptr)
delete old_apriori;
// Rebuilds score_ according to scoreType_.
// The visible branches allocate a ScoreAIC, ScoreBD, ScoreBDeu, ScoreBIC,
// ScoreK2 or ScoreLog2Likelihood on the score database's parser and
// NodeId -> column map (the other constructor arguments and some `case`
// labels are not visible in this excerpt). An unhandled score type raises
// OperationNotAllowed (GUM_ERROR). The previous score is kept in old_score
// and deleted only at the end.
485 void genericBNLearner::createScore_() {
487 Score<>* old_score = score_;
490 switch (scoreType_) {
492 score_ =
new ScoreAIC<>(scoreDatabase_.parser(),
495 scoreDatabase_.nodeId2Columns());
499 score_ =
new ScoreBD<>(scoreDatabase_.parser(),
502 scoreDatabase_.nodeId2Columns());
505 case ScoreType::BDeu:
506 score_ =
new ScoreBDeu<>(scoreDatabase_.parser(),
509 scoreDatabase_.nodeId2Columns());
513 score_ =
new ScoreBIC<>(scoreDatabase_.parser(),
516 scoreDatabase_.nodeId2Columns());
520 score_ =
new ScoreK2<>(scoreDatabase_.parser(),
523 scoreDatabase_.nodeId2Columns());
526 case ScoreType::LOG2LIKELIHOOD:
527 score_ =
new ScoreLog2Likelihood<>(scoreDatabase_.parser(),
530 scoreDatabase_.nodeId2Columns());
534 GUM_ERROR(OperationNotAllowed,
"genericBNLearner does not support yet this score")
// the former score is released last
538 if (old_score !=
nullptr)
delete old_score;
// Allocates and returns a parameter estimator reading its data through
// `parser`, according to paramEstimatorType_.
// For ParamEstimatorType::ML, when take_into_account_score is true and a
// score exists, the estimator is given score_->internalApriori(); otherwise
// another ParamEstimatorML is built (its remaining arguments are not visible
// in this excerpt). An unhandled estimator type raises OperationNotAllowed
// (GUM_ERROR). The estimator's ranges are set to ranges_ before returning.
// NOTE(review): the result is a raw pointer obtained with `new`; presumably
// the caller is responsible for deleting it -- verify against the callers.
541 ParamEstimator<>* genericBNLearner::createParamEstimator_(DBRowGeneratorParser<>& parser,
542 bool take_into_account_score) {
543 ParamEstimator<>* param_estimator =
nullptr;
546 switch (paramEstimatorType_) {
547 case ParamEstimatorType::ML:
548 if (take_into_account_score && (score_ !=
nullptr)) {
549 param_estimator =
new ParamEstimatorML<>(parser,
551 score_->internalApriori(),
553 scoreDatabase_.nodeId2Columns());
555 param_estimator =
new ParamEstimatorML<>(parser,
559 scoreDatabase_.nodeId2Columns());
565 GUM_ERROR(OperationNotAllowed,
566 "genericBNLearner does not support " 567 <<
"yet this parameter estimator");
// restrict the estimator to the row ranges currently selected
571 param_estimator->setRanges(ranges_);
573 return param_estimator;
// Prepares a run of the miic/3off2 algorithm.
// A graph is built with one node per variable of the score database and an
// edge between every pair of nodes; the mandatory arcs are registered with
// the mark '>' and the forbidden arcs with the mark '-', both keyed by the
// pair (tail, head), and handed to algoMiic3off2_ as constraints; finally
// the corrected mutual information object is (re)created.
577 MixedGraph genericBNLearner::prepareMiic3Off2_() {
// node i is linked to every node j < i, which yields a complete graph
580 for (Size i = 0; i < scoreDatabase_.databaseTable().nbVariables(); ++i) {
581 mgraph.addNodeWithId(i);
582 for (Size j = 0; j < i; ++j) {
583 mgraph.addEdge(j, i);
588 HashTable< std::pair< NodeId, NodeId >,
char > initial_marks;
589 const ArcSet& mandatory_arcs = constraintMandatoryArcs_.arcs();
590 for (
const auto& arc: mandatory_arcs) {
591 initial_marks.insert({arc.tail(), arc.head()},
'>');
594 const ArcSet& forbidden_arcs = constraintForbiddenArcs_.arcs();
595 for (
const auto& arc: forbidden_arcs) {
596 initial_marks.insert({arc.tail(), arc.head()},
'-');
598 algoMiic3off2_.addConstraints(initial_marks);
// mutualInfo_ must exist before the learning algorithm dereferences it
602 createCorrectedMutualInformation_();
// Learns a mixed graph with the miic/3off2 algorithm.
// Raises OperationNotAllowed when the selected algorithm is neither MIIC nor
// THREE_OFF_TWO, and MissingValueInDatabase when the score database contains
// missing values. Otherwise a listener is attached to algoMiic3off2_, the
// initial graph is prepared by prepareMiic3Off2_, and the result of
// algoMiic3off2_.learnMixedStructure on *mutualInfo_ and that graph is
// returned.
607 MixedGraph genericBNLearner::learnMixedStructure() {
608 if (selectedAlgo_ != AlgoType::MIIC && selectedAlgo_ != AlgoType::THREE_OFF_TWO) {
609 GUM_ERROR(OperationNotAllowed,
"Must be using the miic/3off2 algorithm")
612 if (scoreDatabase_.databaseTable().hasMissingValues()) {
613 GUM_ERROR(MissingValueInDatabase,
614 "For the moment, the BNLearner is unable to learn " 615 <<
"structures with missing values in databases");
617 BNLearnerListener listener(
this, algoMiic3off2_);
620 MixedGraph mgraph =
this->prepareMiic3Off2_();
622 return algoMiic3off2_.learnMixedStructure(*mutualInfo_, mgraph);
625 DAG genericBNLearner::learnDAG() {
// Replaces mutualInfo_ by a new CorrectedMutualInformation built on the
// score database (the previous object, if any, is deleted first), then
// selects its correction according to kmode3Off2_: MDL, NML or no
// correction. Any other mode raises NotImplementedYet (GUM_ERROR).
633 void genericBNLearner::createCorrectedMutualInformation_() {
634 if (mutualInfo_ !=
nullptr)
delete mutualInfo_;
636 mutualInfo_ =
new CorrectedMutualInformation<>(scoreDatabase_.parser(),
639 scoreDatabase_.nodeId2Columns());
640 switch (kmode3Off2_) {
641 case CorrectedMutualInformation<>::KModeTypes::MDL:
642 mutualInfo_->useMDL();
645 case CorrectedMutualInformation<>::KModeTypes::NML:
646 mutualInfo_->useNML();
649 case CorrectedMutualInformation<>::KModeTypes::NoCorr:
650 mutualInfo_->useNoCorr();
654 GUM_ERROR(NotImplementedYet,
655 "The BNLearner's corrected mutual information class does " 656 <<
"not implement yet this correction : " <<
int(kmode3Off2_));
// Learns a DAG with the algorithm designated by selectedAlgo_.
//
// Raises MissingValueInDatabase when the score database has missing values,
// or when a Dirichlet apriori database is in use and has missing values.
// The starting graph is initialDag_ to which the mandatory arcs are added
// (creating their end nodes when absent) and from which the forbidden arcs
// are erased. Each branch attaches a BNLearnerListener to the algorithm it
// runs, copies the learner's constraints into statically typed constraint
// sets, and returns the algorithm's learnStructure result. An algorithm that
// is not handled raises OperationNotAllowed.
// NOTE(review): several short lines (declarations of gen_constraint /
// sel_constraint, selector arguments, some `case` labels) are not visible in
// this excerpt.
660 DAG genericBNLearner::learnDag_() {
662 if (scoreDatabase_.databaseTable().hasMissingValues()
663 || ((aprioriDatabase_ !=
nullptr)
664 && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
665 && aprioriDatabase_->databaseTable().hasMissingValues())) {
666 GUM_ERROR(MissingValueInDatabase,
667 "For the moment, the BNLearner is unable to cope " 668 "with missing values in databases");
// starting graph: a copy of the initial DAG, completed with the mandatory
// arcs and stripped of the forbidden ones
672 DAG init_graph = initialDag_;
674 const ArcSet& mandatory_arcs = constraintMandatoryArcs_.arcs();
676 for (
const auto& arc: mandatory_arcs) {
677 if (!init_graph.exists(arc.tail())) init_graph.addNodeWithId(arc.tail());
679 if (!init_graph.exists(arc.head())) init_graph.addNodeWithId(arc.head());
681 init_graph.addArc(arc.tail(), arc.head());
684 const ArcSet& forbidden_arcs = constraintForbiddenArcs_.arcs();
686 for (
const auto& arc: forbidden_arcs) {
687 init_graph.eraseArc(arc);
690 switch (selectedAlgo_) {
// miic/3off2: works from the graph built by prepareMiic3Off2_ rather than
// from init_graph
693 case AlgoType::THREE_OFF_TWO: {
694 BNLearnerListener listener(
this, algoMiic3off2_);
696 MixedGraph mgraph =
this->prepareMiic3Off2_();
698 return algoMiic3off2_.learnStructure(*mutualInfo_, mgraph);
// greedy hill climbing: the generator is constrained by the mandatory arcs,
// forbidden arcs, possible edges and slice order; the selector by the
// indegree and DAG constraints
702 case AlgoType::GREEDY_HILL_CLIMBING: {
703 BNLearnerListener listener(
this, greedyHillClimbing_);
704 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
705 StructuralConstraintForbiddenArcs,
706 StructuralConstraintPossibleEdges,
707 StructuralConstraintSliceOrder >
709 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
710 = constraintMandatoryArcs_;
711 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
712 = constraintForbiddenArcs_;
713 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
714 = constraintPossibleEdges_;
715 static_cast< StructuralConstraintSliceOrder& >(gen_constraint) = constraintSliceOrder_;
717 GraphChangesGenerator4DiGraph<
decltype(gen_constraint) > op_set(gen_constraint);
719 StructuralConstraintSetStatic< StructuralConstraintIndegree, StructuralConstraintDAG >
721 static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
723 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
decltype(op_set) > selector(
728 return greedyHillClimbing_.learnStructure(selector, init_graph);
// local search with tabu list: same generator constraints as above; the
// selector additionally carries the tabu list constraint
732 case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST: {
733 BNLearnerListener listener(
this, localSearchWithTabuList_);
734 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
735 StructuralConstraintForbiddenArcs,
736 StructuralConstraintPossibleEdges,
737 StructuralConstraintSliceOrder >
739 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
740 = constraintMandatoryArcs_;
741 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
742 = constraintForbiddenArcs_;
743 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
744 = constraintPossibleEdges_;
745 static_cast< StructuralConstraintSliceOrder& >(gen_constraint) = constraintSliceOrder_;
747 GraphChangesGenerator4DiGraph<
decltype(gen_constraint) > op_set(gen_constraint);
749 StructuralConstraintSetStatic< StructuralConstraintTabuList,
750 StructuralConstraintIndegree,
751 StructuralConstraintDAG >
753 static_cast< StructuralConstraintTabuList& >(sel_constraint) = constraintTabuList_;
754 static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
756 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
decltype(op_set) > selector(
761 return localSearchWithTabuList_.learnStructure(selector, init_graph);
// section running algoK2_ (its `case` label is not visible in this excerpt):
// the generator is constrained by the mandatory arcs, forbidden arcs and
// possible edges
766 BNLearnerListener listener(
this, algoK2_.approximationScheme());
767 StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
768 StructuralConstraintForbiddenArcs,
769 StructuralConstraintPossibleEdges >
771 static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
772 = constraintMandatoryArcs_;
773 static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
774 = constraintForbiddenArcs_;
775 static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
776 = constraintPossibleEdges_;
778 GraphChangesGenerator4K2<
decltype(gen_constraint) > op_set(gen_constraint);
// the K2 order is compatible with the mandatory arcs when every arc's tail
// comes strictly before its head in the order
782 const ArcSet& mandatory_arcs
783 =
static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint).arcs();
784 const Sequence< NodeId >& order = algoK2_.order();
785 bool order_compatible =
true;
787 for (
const auto& arc: mandatory_arcs) {
788 if (order.pos(arc.tail()) >= order.pos(arc.head())) {
789 order_compatible =
false;
// compatible order: the selector uses the DiGraph constraint instead of the
// DAG constraint
794 if (order_compatible) {
795 StructuralConstraintSetStatic< StructuralConstraintIndegree,
796 StructuralConstraintDiGraph >
798 static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
800 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
decltype(op_set) > selector(
805 return algoK2_.learnStructure(selector, init_graph);
// otherwise the selector uses the DAG constraint
807 StructuralConstraintSetStatic< StructuralConstraintIndegree, StructuralConstraintDAG >
809 static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
811 GraphChangesSelector4DiGraph<
decltype(sel_constraint),
decltype(op_set) > selector(
816 return algoK2_.learnStructure(selector, init_graph);
822 GUM_ERROR(OperationNotAllowed,
823 "the learnDAG method has not been implemented for this " 824 "learning algorithm");
// Asks the score class matching scoreType_ whether it is compatible with the
// current apriori, identified by the string returned by getAprioriType_()
// and by aprioriWeight_, and returns that class's isAprioriCompatible
// result. When the score type is not handled, a fixed message saying so is
// returned instead. (Some `case` labels are not visible in this excerpt.)
828 std::string genericBNLearner::checkScoreAprioriCompatibility()
const {
829 const std::string& apriori = getAprioriType_();
831 switch (scoreType_) {
833 return ScoreAIC<>::isAprioriCompatible(apriori, aprioriWeight_);
836 return ScoreBD<>::isAprioriCompatible(apriori, aprioriWeight_);
838 case ScoreType::BDeu:
839 return ScoreBDeu<>::isAprioriCompatible(apriori, aprioriWeight_);
842 return ScoreBIC<>::isAprioriCompatible(apriori, aprioriWeight_);
845 return ScoreK2<>::isAprioriCompatible(apriori, aprioriWeight_);
847 case ScoreType::LOG2LIKELIHOOD:
848 return ScoreLog2Likelihood<>::isAprioriCompatible(apriori, aprioriWeight_);
851 return "genericBNLearner does not support yet this score";
// Selects the rows used for learning in a k-fold cross validation.
//
// The database is cut into folds of db_size / k_fold rows (integer
// division). Fold number `learning_fold` spans the rows
// [unfold_deb, unfold_end); the row ranges located before and after that
// fold are appended to ranges_, and the pair (unfold_deb, unfold_end) is
// returned.
// Raises OutOfBounds (GUM_ERROR) when k_fold is 0, when learning_fold is not
// strictly lower than k_fold, or when k_fold is not strictly lower than the
// number of rows of the score database.
// NOTE(review): the visible statements only append to ranges_; whether
// ranges_ is emptied beforehand cannot be told from this excerpt -- confirm.
857 std::pair< std::size_t, std::size_t >
858 genericBNLearner::useCrossValidationFold(
const std::size_t learning_fold,
859 const std::size_t k_fold) {
860 if (k_fold == 0) { GUM_ERROR(OutOfBounds,
"K-fold cross validation with k=0 is forbidden") }
862 if (learning_fold >= k_fold) {
863 GUM_ERROR(OutOfBounds,
864 "In " << k_fold <<
"-fold cross validation, the learning " 865 <<
"fold should be strictly lower than " << k_fold
866 <<
" but, here, it is equal to " << learning_fold);
// NOTE(review): the message below prints the row count immediately followed
// by "rows", without a separating space
869 const std::size_t db_size = scoreDatabase_.databaseTable().nbRows();
870 if (k_fold >= db_size) {
871 GUM_ERROR(OutOfBounds,
872 "In " << k_fold <<
"-fold cross validation, the database's " 873 <<
"size should be strictly greater than " << k_fold
874 <<
" but, here, the database has only " << db_size <<
"rows");
// bounds of the fold designated by learning_fold
878 const std::size_t foldSize = db_size / k_fold;
879 const std::size_t unfold_deb = learning_fold * foldSize;
880 const std::size_t unfold_end = unfold_deb + foldSize;
// first fold: only the rows located after it are kept
883 if (learning_fold == std::size_t(0)) {
884 ranges_.push_back(std::pair< std::size_t, std::size_t >(unfold_end, db_size));
// rows located before the fold
886 ranges_.push_back(std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb));
// rows located after the fold, unless it is the last one
888 if (learning_fold != k_fold - 1) {
889 ranges_.push_back(std::pair< std::size_t, std::size_t >(unfold_end, db_size));
893 return std::pair< std::size_t, std::size_t >(unfold_deb, unfold_end);
// Chi2 independence test on node ids: builds an IndepTestChi2 on the score
// database's parser and returns the pair of doubles produced by its
// statistics(id1, id2, knowing) call.
// NOTE(review): the declaration of the id2 parameter and the remaining
// constructor arguments of chi2score are not visible in this excerpt.
897 std::pair<
double,
double > genericBNLearner::chi2(
const NodeId id1,
899 const std::vector< NodeId >& knowing) {
901 gum::learning::IndepTestChi2<> chi2score(scoreDatabase_.parser(),
905 return chi2score.statistics(id1, id2, knowing);
908 std::pair<
double,
double > genericBNLearner::chi2(
const std::string& name1,
909 const std::string& name2,
910 const std::vector< std::string >& knowing) {
911 std::vector< NodeId > knowingIds;
912 std::transform(knowing.begin(),
914 std::back_inserter(knowingIds),
915 [
this](
const std::string& c) -> NodeId {
return this->idFromName(c); });
916 return chi2(idFromName(name1), idFromName(name2), knowingIds);
// G2 independence test on node ids: builds an IndepTestG2 on the score
// database's parser, the current apriori (*apriori_) and the database ranges,
// and returns the pair of doubles produced by its statistics(id1, id2,
// knowing) call.
// NOTE(review): the declaration of the id2 parameter is not visible in this
// excerpt; *apriori_ is dereferenced, so apriori_ must be non-null at this
// point -- confirm how it is set up beforehand.
919 std::pair<
double,
double > genericBNLearner::G2(
const NodeId id1,
921 const std::vector< NodeId >& knowing) {
923 gum::learning::IndepTestG2<> g2score(scoreDatabase_.parser(), *apriori_, databaseRanges());
924 return g2score.statistics(id1, id2, knowing);
927 std::pair<
double,
double > genericBNLearner::G2(
const std::string& name1,
928 const std::string& name2,
929 const std::vector< std::string >& knowing) {
930 std::vector< NodeId > knowingIds;
931 std::transform(knowing.begin(),
933 std::back_inserter(knowingIds),
934 [
this](
const std::string& c) -> NodeId {
return this->idFromName(c); });
935 return G2(idFromName(name1), idFromName(name2), knowingIds);
// Log-likelihood of `vars`, given `knowing`, on node ids.
// A ScoreLog2Likelihood is built on the score database's parser. LLtotal is
// its score on the id set made of `vars` followed by `knowing`; when
// `knowing` is not empty, the score LLknw of `knowing` alone is subtracted
// from it and the difference is returned.
// NOTE(review): the remaining constructor arguments of ll2score and the body
// of the `knowing.size() == 0` branch are not visible in this excerpt.
938 double genericBNLearner::logLikelihood(
const std::vector< NodeId >& vars,
939 const std::vector< NodeId >& knowing) {
941 gum::learning::ScoreLog2Likelihood<> ll2score(scoreDatabase_.parser(),
// total = vars followed by knowing
945 std::vector< NodeId > total(vars);
946 total.insert(total.end(), knowing.begin(), knowing.end());
947 double LLtotal = ll2score.score(IdCondSet<>(total,
false,
true));
948 if (knowing.size() == (Size)0) {
951 double LLknw = ll2score.score(IdCondSet<>(knowing,
false,
true));
952 return LLtotal - LLknw;
956 double genericBNLearner::logLikelihood(
const std::vector< std::string >& vars,
957 const std::vector< std::string >& knowing) {
958 std::vector< NodeId > ids;
959 std::vector< NodeId > knowingIds;
961 auto mapper = [
this](
const std::string& c) -> NodeId {
962 return this->idFromName(c);
965 std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
966 std::transform(knowing.begin(), knowing.end(), std::back_inserter(knowingIds), mapper);
968 return logLikelihood(ids, knowingIds);
// Pseudo-counts of the variables `vars`, given as node ids: builds a
// PseudoCount on the score database's parser, the current apriori
// (*apriori_) and the database ranges, and returns the vector of doubles
// produced by its get(vars) call.
// NOTE(review): the local `res` is not used by any statement visible in this
// excerpt; *apriori_ is dereferenced, so apriori_ must be non-null at this
// point -- confirm how it is set up beforehand.
971 std::vector<
double > genericBNLearner::rawPseudoCount(
const std::vector< NodeId >& vars) {
972 Potential<
double > res;
975 gum::learning::PseudoCount<> count(scoreDatabase_.parser(), *apriori_, databaseRanges());
976 return count.get(vars);
980 std::vector<
double > genericBNLearner::rawPseudoCount(
const std::vector< std::string >& vars) {
981 std::vector< NodeId > ids;
983 auto mapper = [
this](
const std::string& c) -> NodeId {
984 return this->idFromName(c);
987 std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
989 return rawPseudoCount(ids);