32 #ifndef GUM_LEARNING_GENERIC_BN_LEARNER_H 33 #define GUM_LEARNING_GENERIC_BN_LEARNER_H 38 #include <agrum/BN/BayesNet.h> 39 #include <agrum/agrum.h> 40 #include <agrum/tools/core/bijection.h> 41 #include <agrum/tools/core/sequence.h> 42 #include <agrum/tools/graphs/DAG.h> 44 #include <agrum/tools/database/DBTranslator4LabelizedVariable.h> 45 #include <agrum/tools/database/DBRowGeneratorParser.h> 46 #include <agrum/tools/database/DBInitializerFromCSV.h> 47 #include <agrum/tools/database/databaseTable.h> 48 #include <agrum/tools/database/DBRowGeneratorParser.h> 49 #include <agrum/tools/database/DBRowGenerator4CompleteRows.h> 50 #include <agrum/tools/database/DBRowGeneratorEM.h> 51 #include <agrum/tools/database/DBRowGeneratorSet.h> 53 #include <agrum/BN/learning/scores_and_tests/scoreAIC.h> 54 #include <agrum/BN/learning/scores_and_tests/scoreBD.h> 55 #include <agrum/BN/learning/scores_and_tests/scoreBDeu.h> 56 #include <agrum/BN/learning/scores_and_tests/scoreBIC.h> 57 #include <agrum/BN/learning/scores_and_tests/scoreK2.h> 58 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h> 60 #include <agrum/BN/learning/aprioris/aprioriDirichletFromDatabase.h> 61 #include <agrum/BN/learning/aprioris/aprioriNoApriori.h> 62 #include <agrum/BN/learning/aprioris/aprioriSmoothing.h> 63 #include <agrum/BN/learning/aprioris/aprioriBDeu.h> 65 #include <agrum/BN/learning/constraints/structuralConstraintDAG.h> 66 #include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h> 67 #include <agrum/BN/learning/constraints/structuralConstraintForbiddenArcs.h> 68 #include <agrum/BN/learning/constraints/structuralConstraintPossibleEdges.h> 69 #include <agrum/BN/learning/constraints/structuralConstraintIndegree.h> 70 #include <agrum/BN/learning/constraints/structuralConstraintMandatoryArcs.h> 71 #include <agrum/BN/learning/constraints/structuralConstraintSetStatic.h> 72 #include <agrum/BN/learning/constraints/structuralConstraintSliceOrder.h> 73 #include 
<agrum/BN/learning/constraints/structuralConstraintTabuList.h> 75 #include <agrum/BN/learning/structureUtils/graphChange.h> 76 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h> 77 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4K2.h> 78 #include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h> 80 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h> 81 #include <agrum/BN/learning/paramUtils/paramEstimatorML.h> 83 #include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h> 84 #include <agrum/tools/core/approximations/approximationSchemeListener.h> 86 #include <agrum/BN/learning/K2.h> 87 #include <agrum/BN/learning/Miic.h> 88 #include <agrum/BN/learning/greedyHillClimbing.h> 89 #include <agrum/BN/learning/localSearchWithTabuList.h> 91 #include <agrum/tools/core/signal/signaler.h> 97 class BNLearnerListener;
// Learner class declared by this header; it publicly derives from
// gum::IApproximationSchemeConfiguration.
// NOTE(review): this listing is lossy — enum bodies, access specifiers and
// closing braces are not visible, so only the surviving tokens are annotated.
106 class genericBNLearner:
public gum::IApproximationSchemeConfiguration {
// Scoped enum naming the kind of parameter estimator (ParamEstimatorType::ML
// is the only enumerator referenced in the visible code).
122 enum class ParamEstimatorType
// Scoped enum naming the kind of a priori (AprioriType::NO_APRIORI is used as
// a member default in the visible code).
126 enum class AprioriType
// Presumably an AprioriType enumerator — the enclosing braces are not visible; confirm.
130 DIRICHLET_FROM_DATABASE,
// Enumerators of the algorithm selector (AlgoType::GREEDY_HILL_CLIMBING is the
// default value of selected_algo__); the enum header line is not visible here.
138 GREEDY_HILL_CLIMBING,
139 LOCAL_SEARCH_WITH_TABU_LIST,
// Constructors and assignment operators of the Database helper type.
// NOTE(review): the line opening the Database class is not visible in this
// listing; all entries below are declarations only.
// Builds a Database from a file name and a list of missing-value symbols
// (presumably a CSV file, given the DBInitializerFromCSV include — confirm).
156 explicit Database(
const std::string& file,
157 const std::vector< std::string >& missing_symbols);
// Builds a Database from an existing DatabaseTable<>.
162 explicit Database(
const DatabaseTable<>& db);
// Overload that additionally receives another Database (score_database) by
// non-const reference; how the two relate is defined in the implementation —
// TODO confirm.
175 Database(
const std::string& filename,
176 Database& score_database,
177 const std::vector< std::string >& missing_symbols);
// Templated overload that additionally receives a Bayesian network of scalar
// type GUM_SCALAR; presumably the network supplies the variable definitions —
// verify against the _tpl.h implementation.
186 template <
typename GUM_SCALAR >
187 Database(
const std::string& filename,
188 const gum::BayesNet< GUM_SCALAR >& bn,
189 const std::vector< std::string >& missing_symbols);
// Copy constructor.
192 Database(
const Database& from);
// Move constructor.
195 Database(Database&& from);
// Copy assignment.
208 Database& operator=(
const Database& from);
// Move assignment.
211 Database& operator=(Database&& from);
// Accessors of the Database helper (declarations only).
// Gives mutable access to a DBRowGeneratorParser<>.
221 DBRowGeneratorParser<>& parser();
// Read-only vector of domain sizes (one std::size_t per entry).
224 const std::vector< std::size_t >& domainSizes()
const;
// Read-only vector of names.
227 const std::vector< std::string >& names()
const;
// Maps a variable name to a NodeId.
230 NodeId idFromName(
const std::string& var_name)
const;
// Maps a NodeId back to a name.
233 const std::string& nameFromId(NodeId id)
const;
// Read-only access to a DatabaseTable<>.
236 const DatabaseTable<>& databaseTable()
const;
// Sets a weight for the database as a whole; the exact semantics are defined
// in the implementation — TODO confirm.
240 void setDatabaseWeight(
const double new_weight);
// Read-only bijection between NodeId and std::size_t (column indices, judging
// by the function name — confirm).
243 const Bijection< NodeId, std::size_t >& nodeId2Columns()
const;
// Read-only vector of the strings used as missing-value symbols.
246 const std::vector< std::string >& missingSymbols()
const;
// Row count and size of the database, both as std::size_t.
249 std::size_t nbRows()
const;
252 std::size_t size()
const;
// Sets the weight attached to index i (presumably a row index — confirm).
258 void setWeight(
const std::size_t i,
const double weight);
// Weight attached to index i.
263 double weight(
const std::size_t i)
const;
// Overall weight (no index argument).
266 double weight()
const;
// Data members and private helpers of the Database helper.
// NOTE(review): access specifiers are not visible in this listing.
// The table holding the data.
273 DatabaseTable<> database__;
// Raw pointer to a row-generator parser, null by default. Ownership cannot be
// determined from this header — TODO confirm who deletes it.
276 DBRowGeneratorParser<>* parser__{
nullptr};
// Domain sizes, one std::size_t per entry.
279 std::vector< std::size_t > domain_sizes__;
// Bijection between NodeId and std::size_t.
282 Bijection< NodeId, std::size_t > nodeId2cols__;
// Thread-count member: initialised from getMaxNumberOfThreads() when _OPENMP is
// defined and GUM_DEBUG_MODE is not; the other initialiser below is 1.
// NOTE(review): the #else/#endif lines are missing from this listing and the
// directive is fused with the declaration on one line.
285 #if defined(_OPENMP) && !defined(GUM_DEBUG_MODE) 286 Size max_threads_number__{getMaxNumberOfThreads()};
288 Size max_threads_number__{1};
// Defaults to 100; presumably the minimum row count handed to one thread — confirm.
292 Size min_nb_rows_per_thread__{100};
// Templated helper returning a BayesNet< GUM_SCALAR > by value (declaration only).
297 template <
typename GUM_SCALAR >
298 BayesNet< GUM_SCALAR > BNVars__()
const;
// Sets the weight of the a priori (declaration only).
302 void setAprioriWeight__(
double weight);
// Constructors, destructor and assignment operators of genericBNLearner
// (declarations only).
// Builds a learner from a file name and a list of missing-value symbols.
315 genericBNLearner(
const std::string& filename,
316 const std::vector< std::string >& missing_symbols);
// Builds a learner from an existing DatabaseTable<>.
// NOTE(review): single-argument and not marked explicit, so it also acts as an
// implicit conversion from DatabaseTable<> — confirm that this is intended.
317 genericBNLearner(
const DatabaseTable<>& db);
// Templated overload that additionally receives a Bayesian network `src` of
// scalar type GUM_SCALAR; its role is defined in the implementation — TODO confirm.
338 template <
typename GUM_SCALAR >
339 genericBNLearner(
const std::string& filename,
340 const gum::BayesNet< GUM_SCALAR >& src,
341 const std::vector< std::string >& missing_symbols);
// Copy constructor.
344 genericBNLearner(
const genericBNLearner&);
// Move constructor.
347 genericBNLearner(genericBNLearner&&);
// Virtual destructor.
350 virtual ~genericBNLearner();
// Copy assignment.
360 genericBNLearner& operator=(
const genericBNLearner&);
// Move assignment.
363 genericBNLearner& operator=(genericBNLearner&&);
// Structure learning entry point and database accessors (declarations only).
// Returns a MixedGraph by value.
377 MixedGraph learnMixedStructure();
// Supplies a DAG; judging by the name it is the starting point of the search — confirm.
380 void setInitialDAG(
const DAG&);
// Read-only vector of names.
383 const std::vector< std::string >& names()
const;
// Read-only vector of domain sizes, plus per-variable lookups by NodeId or by name.
386 const std::vector< std::size_t >& domainSizes()
const;
387 Size domainSize(NodeId var)
const;
388 Size domainSize(
const std::string& var)
const;
// Maps a variable name to a NodeId.
395 NodeId idFromName(
const std::string& var_name)
const;
// Read-only access to a DatabaseTable<>.
398 const DatabaseTable<>& database()
const;
// Sets a weight for the database as a whole.
402 void setDatabaseWeight(
const double new_weight);
// Sets the weight attached to index i (presumably a record index, given the name).
408 void setRecordWeight(
const std::size_t i,
const double weight);
// Weight attached to index i.
413 double recordWeight(
const std::size_t i)
const;
// Overall database weight.
416 double databaseWeight()
const;
// Maps a NodeId back to a name.
419 const std::string& nameFromId(NodeId id)
const;
// Takes a vector of (std::size_t, std::size_t) pairs using allocator template XALLOC.
// NOTE(review): the parameter name and the closing `);` are missing from this listing.
428 template <
template <
typename >
class XALLOC >
429 void useDatabaseRanges(
430 const std::vector< std::pair< std::size_t, std::size_t >,
431 XALLOC< std::pair< std::size_t, std::size_t > > >&
// Counterpart of useDatabaseRanges taking no argument.
435 void clearDatabaseRanges();
// Read-only vector of (std::size_t, std::size_t) pairs.
441 const std::vector< std::pair< std::size_t, std::size_t > >&
442 databaseRanges()
const;
// Cross-validation and statistics declarations.
// NOTE(review): several parameter lines and default arguments were dropped by
// the extraction; the comments below describe only what survives.
// Takes a fold index and a fold count, returns a pair of std::size_t.
465 std::pair< std::size_t, std::size_t >
466 useCrossValidationFold(
const std::size_t learning_fold,
467 const std::size_t k_fold);
// chi2 overload on NodeIds with an optional conditioning set `knowing`
// (defaults to empty); returns a pair of doubles whose meaning is defined in
// the implementation. A line between id1 and knowing appears to be missing
// (presumably a second NodeId) — confirm against the original header.
477 std::pair<
double,
double > chi2(
const NodeId id1,
479 const std::vector< NodeId >& knowing = {});
// chi2 overload on variable names; the end of the declaration is not visible.
487 std::pair<
double,
double > chi2(
const std::string& name1,
488 const std::string& name2,
489 const std::vector< std::string >& knowing
// G2 overload on NodeIds; same shape and same caveat as chi2 above.
499 std::pair<
double,
double > G2(
const NodeId id1,
501 const std::vector< NodeId >& knowing = {});
// G2 overload on variable names; the end of the declaration is not visible.
509 std::pair<
double,
double > G2(
const std::string& name1,
510 const std::string& name2,
511 const std::vector< std::string >& knowing
// logLikelihood for `vars` with an optional conditioning set `knowing`
// (defaults to empty), by NodeId.
521 double logLikelihood(
const std::vector< NodeId >& vars,
522 const std::vector< NodeId >& knowing = {});
// Same, with variables given by name.
531 double logLikelihood(
const std::vector< std::string >& vars,
532 const std::vector< std::string >& knowing = {});
// Returns a vector of doubles for the given variables, by NodeId.
539 std::vector<
double > rawPseudoCount(
const std::vector< NodeId >& vars);
// Same, with variables given by name.
546 std::vector<
double > rawPseudoCount(
const std::vector< std::string >& vars);
// Selection of the EM setting, score, a priori and learning algorithm
// (declarations only).
// Takes an epsilon; presumably enables EM parameter learning with that
// threshold — confirm against the implementation.
563 void useEM(
const double epsilon);
// Reports whether missing values are present.
566 bool hasMissingValues()
const;
// Selects the log2-likelihood score.
591 void useScoreLog2Likelihood();
// A priori selectors; each accepts a weight that defaults to 1.
607 void useAprioriBDeu(
double weight = 1);
613 void useAprioriSmoothing(
double weight = 1);
// Dirichlet a priori built from the file `filename`.
616 void useAprioriDirichlet(
const std::string& filename,
double weight = 1);
// Returns a string describing score / a priori compatibility; the convention
// for "compatible" is not visible here — TODO confirm.
622 std::string checkScoreAprioriCompatibility();
// Algorithm selectors.
631 void useGreedyHillClimbing();
// Defaults: tabu_size = 100, nb_decrease = 2.
637 void useLocalSearchWithTabuList(Size tabu_size = 100, Size nb_decrease = 2);
// K2 with an explicit node order, as a Sequence or as a std::vector.
640 void useK2(
const Sequence< NodeId >& order);
643 void useK2(
const std::vector< NodeId >& order);
// NOTE(review): returns a const-qualified value, which prevents callers from
// moving out of the result; dropping the top-level const would be the usual fix.
669 const std::vector< Arc > latentVariables()
const;
// Structural-constraint setters (declarations only). Arcs and edges can be
// given as objects, as (tail, head) NodeId pairs or as (tail, head) names.
// Sets the maximum in-degree.
678 void setMaxIndegree(Size max_indegree);
// Slice order, either as a NodeId -> NodeId property or as a list of slices,
// each slice being a list of variable names.
685 void setSliceOrder(
const NodeProperty< NodeId >& slice_order);
691 void setSliceOrder(
const std::vector< std::vector< std::string > >& slices);
// Forbidden arcs: assign the whole set, then add / erase single arcs.
694 void setForbiddenArcs(
const ArcSet& set);
698 void addForbiddenArc(
const Arc& arc);
699 void addForbiddenArc(
const NodeId tail,
const NodeId head);
700 void addForbiddenArc(
const std::string& tail,
const std::string& head);
705 void eraseForbiddenArc(
const Arc& arc);
706 void eraseForbiddenArc(
const NodeId tail,
const NodeId head);
707 void eraseForbiddenArc(
const std::string& tail,
const std::string& head);
// Mandatory arcs: assign the whole set, then add / erase single arcs.
711 void setMandatoryArcs(
const ArcSet& set);
715 void addMandatoryArc(
const Arc& arc);
716 void addMandatoryArc(
const NodeId tail,
const NodeId head);
717 void addMandatoryArc(
const std::string& tail,
const std::string& head);
722 void eraseMandatoryArc(
const Arc& arc);
723 void eraseMandatoryArc(
const NodeId tail,
const NodeId head);
724 void eraseMandatoryArc(
const std::string& tail,
const std::string& head);
// Possible edges: assign from an EdgeSet or from an undirected skeleton, then
// add / erase single edges.
731 void setPossibleEdges(
const EdgeSet& set);
732 void setPossibleSkeleton(
const UndiGraph& skeleton);
739 void addPossibleEdge(
const Edge& edge);
740 void addPossibleEdge(
const NodeId tail,
const NodeId head);
741 void addPossibleEdge(
const std::string& tail,
const std::string& head);
746 void erasePossibleEdge(
const Edge& edge);
747 void erasePossibleEdge(
const NodeId tail,
const NodeId head);
748 void erasePossibleEdge(
const std::string& tail,
const std::string& head);
// Data members of genericBNLearner.
// NOTE(review): the raw pointers below all default to nullptr; their ownership
// cannot be determined from this header — TODO confirm who deletes them.
// Selected score; defaults to BDeu.
755 ScoreType score_type__{ScoreType::BDeu};
// Score object.
758 Score<>* score__{
nullptr};
// Selected parameter estimator; defaults to ML.
761 ParamEstimatorType param_estimator_type__{ParamEstimatorType::ML};
// Defaults to 0.0; presumably the epsilon passed to useEM — confirm.
764 double EMepsilon__{0.0};
// Corrected mutual information object.
767 CorrectedMutualInformation<>* mutual_info__{
nullptr};
// Selected a priori; defaults to NO_APRIORI.
770 AprioriType apriori_type__{AprioriType::NO_APRIORI};
// A priori object.
773 Apriori<>* apriori__{
nullptr};
// Dedicated "no a priori" object.
775 AprioriNoApriori<>* no_apriori__{
nullptr};
// A priori weight, defaults to 1.
// NOTE(review): the initialiser is a float literal (1.0f) on a double member;
// the value is exact, but 1.0 would match the declared type.
778 double apriori_weight__{1.0f};
// Structural constraints, one member per constraint kind.
781 StructuralConstraintSliceOrder constraint_SliceOrder__;
784 StructuralConstraintIndegree constraint_Indegree__;
787 StructuralConstraintTabuList constraint_TabuList__;
790 StructuralConstraintForbiddenArcs constraint_ForbiddenArcs__;
793 StructuralConstraintPossibleEdges constraint_PossibleEdges__;
796 StructuralConstraintMandatoryArcs constraint_MandatoryArcs__;
// Selected learning algorithm; defaults to greedy hill climbing.
799 AlgoType selected_algo__{AlgoType::GREEDY_HILL_CLIMBING};
// K-mode used with the corrected mutual information; defaults to MDL.
808 typename CorrectedMutualInformation<>::KModeTypes kmode_3off2__{
809 CorrectedMutualInformation<>::KModeTypes::MDL};
// Learner that turns a DAG into a Bayesian network.
812 DAG2BNLearner<> Dag2BN__;
// Greedy hill climbing algorithm object.
815 GreedyHillClimbing greedy_hill_climbing__;
// Local search with tabu list algorithm object.
818 LocalSearchWithTabuList local_search_with_tabu_list__;
// Database used for scoring.
821 Database score_database__;
// Pairs of std::size_t; presumably the row ranges set by useDatabaseRanges — confirm.
824 std::vector< std::pair< std::size_t, std::size_t > > ranges__;
// Database used by the a priori.
827 Database* apriori_database__{
nullptr};
// Name associated with the a priori database.
830 std::string apriori_dbname__;
// Approximation scheme queried by the getter methods of this class; assigned
// by setCurrentApproximationScheme.
836 const ApproximationScheme* current_algorithm__{
nullptr};
// Internal helper declarations.
// NOTE(review): access specifiers are not visible in this listing.
// Static helper returning a DatabaseTable<> by value from a file name and a
// list of missing-value symbols.
839 static DatabaseTable<>
840 readFile__(
const std::string& filename,
841 const std::vector< std::string >& missing_symbols);
// Static helper taking a file name; what is checked is defined in the
// implementation — TODO confirm.
844 static void checkFileName__(
const std::string& filename);
// Factory-style helpers for the a priori and the score (declarations only).
847 void createApriori__();
850 void createScore__();
// Returns a raw ParamEstimator<> pointer; ownership of the result is not
// visible here — TODO confirm.
// NOTE(review): the end of this declaration (default value and closing `);`)
// is missing from this listing.
853 ParamEstimator<>* createParamEstimator__(DBRowGeneratorParser<>& parser,
854 bool take_into_account_score
// Returns a MixedGraph by value.
861 MixedGraph prepare_miic_3off2__();
// Returns a read-only string naming the a priori type.
864 const std::string& getAprioriType__()
const;
// Factory-style helper for the corrected mutual information (declaration only).
867 void createCorrectedMutualInformation__();
880 INLINE
void setCurrentApproximationScheme(
881 const ApproximationScheme* approximationScheme) {
882 current_algorithm__ = approximationScheme;
// Progress callback: registers the reporting scheme as the current one, then
// emits onProgress (with pourcent, error and time) if the signal has a listener.
// NOTE(review): the return type, the remaining parameters (pourcent, error and
// time are used below but not declared in the visible lines), the opening brace
// and the closing brace are missing from this listing.
886 distributeProgress(
const ApproximationScheme* approximationScheme,
890 setCurrentApproximationScheme(approximationScheme);
// Emit only when a listener is attached.
892 if (onProgress.hasListener()) GUM_EMIT3(onProgress, pourcent, error, time);
896 INLINE
void distributeStop(
const ApproximationScheme* approximationScheme,
897 std::string message) {
898 setCurrentApproximationScheme(approximationScheme);
900 if (onStop.hasListener()) GUM_EMIT1(onStop, message);
908 void setEpsilon(
double eps) {
909 K2__.approximationScheme().setEpsilon(eps);
910 greedy_hill_climbing__.setEpsilon(eps);
911 local_search_with_tabu_list__.setEpsilon(eps);
912 Dag2BN__.setEpsilon(eps);
916 double epsilon()
const {
917 if (current_algorithm__ !=
nullptr)
918 return current_algorithm__->epsilon();
920 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
924 void disableEpsilon() {
925 K2__.approximationScheme().disableEpsilon();
926 greedy_hill_climbing__.disableEpsilon();
927 local_search_with_tabu_list__.disableEpsilon();
928 Dag2BN__.disableEpsilon();
932 void enableEpsilon() {
933 K2__.approximationScheme().enableEpsilon();
934 greedy_hill_climbing__.enableEpsilon();
935 local_search_with_tabu_list__.enableEpsilon();
936 Dag2BN__.enableEpsilon();
941 bool isEnabledEpsilon()
const {
942 if (current_algorithm__ !=
nullptr)
943 return current_algorithm__->isEnabledEpsilon();
945 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
954 void setMinEpsilonRate(
double rate) {
955 K2__.approximationScheme().setMinEpsilonRate(rate);
956 greedy_hill_climbing__.setMinEpsilonRate(rate);
957 local_search_with_tabu_list__.setMinEpsilonRate(rate);
958 Dag2BN__.setMinEpsilonRate(rate);
962 double minEpsilonRate()
const {
963 if (current_algorithm__ !=
nullptr)
964 return current_algorithm__->minEpsilonRate();
966 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
970 void disableMinEpsilonRate() {
971 K2__.approximationScheme().disableMinEpsilonRate();
972 greedy_hill_climbing__.disableMinEpsilonRate();
973 local_search_with_tabu_list__.disableMinEpsilonRate();
974 Dag2BN__.disableMinEpsilonRate();
977 void enableMinEpsilonRate() {
978 K2__.approximationScheme().enableMinEpsilonRate();
979 greedy_hill_climbing__.enableMinEpsilonRate();
980 local_search_with_tabu_list__.enableMinEpsilonRate();
981 Dag2BN__.enableMinEpsilonRate();
985 bool isEnabledMinEpsilonRate()
const {
986 if (current_algorithm__ !=
nullptr)
987 return current_algorithm__->isEnabledMinEpsilonRate();
989 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
998 void setMaxIter(Size max) {
999 K2__.approximationScheme().setMaxIter(max);
1000 greedy_hill_climbing__.setMaxIter(max);
1001 local_search_with_tabu_list__.setMaxIter(max);
1002 Dag2BN__.setMaxIter(max);
1006 Size maxIter()
const {
1007 if (current_algorithm__ !=
nullptr)
1008 return current_algorithm__->maxIter();
1010 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1014 void disableMaxIter() {
1015 K2__.approximationScheme().disableMaxIter();
1016 greedy_hill_climbing__.disableMaxIter();
1017 local_search_with_tabu_list__.disableMaxIter();
1018 Dag2BN__.disableMaxIter();
1021 void enableMaxIter() {
1022 K2__.approximationScheme().enableMaxIter();
1023 greedy_hill_climbing__.enableMaxIter();
1024 local_search_with_tabu_list__.enableMaxIter();
1025 Dag2BN__.enableMaxIter();
1029 bool isEnabledMaxIter()
const {
1030 if (current_algorithm__ !=
nullptr)
1031 return current_algorithm__->isEnabledMaxIter();
1033 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1043 void setMaxTime(
double timeout) {
1044 K2__.approximationScheme().setMaxTime(timeout);
1045 greedy_hill_climbing__.setMaxTime(timeout);
1046 local_search_with_tabu_list__.setMaxTime(timeout);
1047 Dag2BN__.setMaxTime(timeout);
1051 double maxTime()
const {
1052 if (current_algorithm__ !=
nullptr)
1053 return current_algorithm__->maxTime();
1055 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1059 double currentTime()
const {
1060 if (current_algorithm__ !=
nullptr)
1061 return current_algorithm__->currentTime();
1063 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1067 void disableMaxTime() {
1068 K2__.approximationScheme().disableMaxTime();
1069 greedy_hill_climbing__.disableMaxTime();
1070 local_search_with_tabu_list__.disableMaxTime();
1071 Dag2BN__.disableMaxTime();
1073 void enableMaxTime() {
1074 K2__.approximationScheme().enableMaxTime();
1075 greedy_hill_climbing__.enableMaxTime();
1076 local_search_with_tabu_list__.enableMaxTime();
1077 Dag2BN__.enableMaxTime();
1081 bool isEnabledMaxTime()
const {
1082 if (current_algorithm__ !=
nullptr)
1083 return current_algorithm__->isEnabledMaxTime();
1085 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1092 void setPeriodSize(Size p) {
1093 K2__.approximationScheme().setPeriodSize(p);
1094 greedy_hill_climbing__.setPeriodSize(p);
1095 local_search_with_tabu_list__.setPeriodSize(p);
1096 Dag2BN__.setPeriodSize(p);
1099 Size periodSize()
const {
1100 if (current_algorithm__ !=
nullptr)
1101 return current_algorithm__->periodSize();
1103 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1109 void setVerbosity(
bool v) {
1110 K2__.approximationScheme().setVerbosity(v);
1111 greedy_hill_climbing__.setVerbosity(v);
1112 local_search_with_tabu_list__.setVerbosity(v);
1113 Dag2BN__.setVerbosity(v);
1116 bool verbosity()
const {
1117 if (current_algorithm__ !=
nullptr)
1118 return current_algorithm__->verbosity();
1120 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1127 ApproximationSchemeSTATE stateApproximationScheme()
const {
1128 if (current_algorithm__ !=
nullptr)
1129 return current_algorithm__->stateApproximationScheme();
1131 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1135 Size nbrIterations()
const {
1136 if (current_algorithm__ !=
nullptr)
1137 return current_algorithm__->nbrIterations();
1139 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1143 const std::vector<
double >& history()
const {
1144 if (current_algorithm__ !=
nullptr)
1145 return current_algorithm__->history();
1147 GUM_ERROR(FatalError,
"No chosen algorithm for learning");
1157 #ifndef GUM_NO_INLINE 1158 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h> 1161 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h>