32 #ifndef GUM_LEARNING_GENERIC_BN_LEARNER_H 33 #define GUM_LEARNING_GENERIC_BN_LEARNER_H 38 #include <agrum/BN/BayesNet.h> 39 #include <agrum/agrum.h> 40 #include <agrum/tools/core/bijection.h> 41 #include <agrum/tools/core/sequence.h> 42 #include <agrum/tools/graphs/DAG.h> 44 #include <agrum/tools/database/DBTranslator4LabelizedVariable.h> 45 #include <agrum/tools/database/DBRowGeneratorParser.h> 46 #include <agrum/tools/database/DBInitializerFromCSV.h> 47 #include <agrum/tools/database/databaseTable.h> 48 #include <agrum/tools/database/DBRowGeneratorParser.h> 49 #include <agrum/tools/database/DBRowGenerator4CompleteRows.h> 50 #include <agrum/tools/database/DBRowGeneratorEM.h> 51 #include <agrum/tools/database/DBRowGeneratorSet.h> 53 #include <agrum/BN/learning/scores_and_tests/scoreAIC.h> 54 #include <agrum/BN/learning/scores_and_tests/scoreBD.h> 55 #include <agrum/BN/learning/scores_and_tests/scoreBDeu.h> 56 #include <agrum/BN/learning/scores_and_tests/scoreBIC.h> 57 #include <agrum/BN/learning/scores_and_tests/scoreK2.h> 58 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h> 60 #include <agrum/BN/learning/aprioris/aprioriDirichletFromDatabase.h> 61 #include <agrum/BN/learning/aprioris/aprioriNoApriori.h> 62 #include <agrum/BN/learning/aprioris/aprioriSmoothing.h> 63 #include <agrum/BN/learning/aprioris/aprioriBDeu.h> 65 #include <agrum/BN/learning/constraints/structuralConstraintDAG.h> 66 #include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h> 67 #include <agrum/BN/learning/constraints/structuralConstraintForbiddenArcs.h> 68 #include <agrum/BN/learning/constraints/structuralConstraintPossibleEdges.h> 69 #include <agrum/BN/learning/constraints/structuralConstraintIndegree.h> 70 #include <agrum/BN/learning/constraints/structuralConstraintMandatoryArcs.h> 71 #include <agrum/BN/learning/constraints/structuralConstraintSetStatic.h> 72 #include <agrum/BN/learning/constraints/structuralConstraintSliceOrder.h> 73 #include 
<agrum/BN/learning/constraints/structuralConstraintTabuList.h> 75 #include <agrum/BN/learning/structureUtils/graphChange.h> 76 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h> 77 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4K2.h> 78 #include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h> 80 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h> 81 #include <agrum/BN/learning/paramUtils/paramEstimatorML.h> 83 #include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h> 84 #include <agrum/tools/core/approximations/approximationSchemeListener.h> 86 #include <agrum/BN/learning/K2.h> 87 #include <agrum/BN/learning/Miic.h> 88 #include <agrum/BN/learning/greedyHillClimbing.h> 89 #include <agrum/BN/learning/localSearchWithTabuList.h> 91 #include <agrum/tools/core/signal/signaler.h> 97 class BNLearnerListener;
/// Common base for the Bayesian-network learners of this library.
/// It implements gum::IApproximationSchemeConfiguration, so the usual
/// stopping criteria (epsilon, min epsilon rate, max iterations, max time,
/// period size, verbosity) can be configured directly on the learner.
/// NOTE(review): this copy of the header is stripped (comments, braces and
/// some lines are missing); the declarations below are fragments.
106 class genericBNLearner:
public gum::IApproximationSchemeConfiguration {
/// Kind of parameter estimator (the member default is ParamEstimatorType::ML).
/// NOTE(review): enumerator list not visible in this copy.
122 enum class ParamEstimatorType
/// Kind of a priori used when counting (the member default is
/// AprioriType::NO_APRIORI); DIRICHLET_FROM_DATABASE below presumably
/// belongs to this enum — verify.
126 enum class AprioriType
130 DIRICHLET_FROM_DATABASE,
/// Structure-learning algorithm enumerators (AlgoType; the member default
/// is AlgoType::GREEDY_HILL_CLIMBING). Other enumerators are not visible.
138 GREEDY_HILL_CLIMBING,
139 LOCAL_SEARCH_WITH_TABU_LIST,
/// Builds the learning database by reading `file`; cells equal to one of
/// `missing_symbols` are presumably treated as missing values (verify).
156 explicit Database(
const std::string& file,
157 const std::vector< std::string >& missing_symbols);
/// Builds the learning database from an already loaded DatabaseTable.
162 explicit Database(
const DatabaseTable<>& db);
/// Reads `filename` using `score_database` as a reference; presumably so
/// that both databases share the same variables/translators (verify).
175 Database(
const std::string& filename,
176 Database& score_database,
177 const std::vector< std::string >& missing_symbols);
/// Reads `filename`, taking the variables from `bn` (assumption: this fixes
/// the variables' domains instead of inferring them from the file).
186 template <
typename GUM_SCALAR >
187 Database(
const std::string& filename,
188 const gum::BayesNet< GUM_SCALAR >& bn,
189 const std::vector< std::string >& missing_symbols);
/// Copy / move construction.
192 Database(
const Database& from);
195 Database(Database&& from);
/// Copy / move assignment.
208 Database& operator=(
const Database& from);
211 Database& operator=(Database&& from);
/// Returns the row parser (stored in _parser_).
221 DBRowGeneratorParser<>& parser();
/// Domain size of each variable.
224 const std::vector< std::size_t >& domainSizes()
const;
/// Names of the variables.
227 const std::vector< std::string >& names()
const;
/// Maps a variable name to its node id.
230 NodeId idFromName(
const std::string& var_name)
const;
/// Maps a node id to its variable name.
233 const std::string& nameFromId(NodeId id)
const;
/// The underlying DatabaseTable.
236 const DatabaseTable<>& databaseTable()
const;
/// Sets the weight of the database as a whole (assumption: by rescaling the
/// per-row weights — verify in the implementation).
240 void setDatabaseWeight(
const double new_weight);
/// Correspondence between node ids and database columns.
243 const Bijection< NodeId, std::size_t >& nodeId2Columns()
const;
/// Symbols that denote missing values.
246 const std::vector< std::string >& missingSymbols()
const;
/// Number of rows.
/// NOTE(review): size() below looks redundant with nbRows(); confirm
/// whether they differ.
249 std::size_t nbRows()
const;
252 std::size_t size()
const;
/// Sets / gets the weight of row `i`.
258 void setWeight(
const std::size_t i,
const double weight);
263 double weight(
const std::size_t i)
const;
/// Weight of the whole database (presumably the sum of row weights).
266 double weight()
const;
/// Data members.
273 DatabaseTable<> _database_;
/// NOTE(review): raw pointer; ownership is handled in the (not visible)
/// implementation — a std::unique_ptr would make it explicit.
276 DBRowGeneratorParser<>* _parser_{
nullptr};
279 std::vector< std::size_t > _domain_sizes_;
282 Bijection< NodeId, std::size_t > _nodeId2cols_;
/// Number of threads used when parsing: getMaxNumberOfThreads() when OpenMP
/// is enabled outside debug mode, otherwise 1. NOTE(review): the matching
/// #else/#endif lines are missing from this copy.
285 #if defined(_OPENMP) && !defined(GUM_DEBUG_MODE) 286 Size _max_threads_number_{getMaxNumberOfThreads()};
288 Size _max_threads_number_{1};
/// Minimum number of rows handed to each thread.
292 Size _min_nb_rows_per_thread_{100};
/// Returns a BayesNet built from the database's variables (presumably with
/// no arcs — verify).
297 template <
typename GUM_SCALAR >
298 BayesNet< GUM_SCALAR > _BNVars_()
const;
/// Sets the weight of the a priori (presumably stored in aprioriWeight_).
302 void _setAprioriWeight_(
double weight);
/// Learner reading its data from `filename`, with the given missing-value
/// symbols.
315 genericBNLearner(
const std::string& filename,
316 const std::vector< std::string >& missing_symbols);
/// Learner over an already loaded database.
/// NOTE(review): this single-argument constructor is not `explicit`, so a
/// DatabaseTable converts implicitly to a genericBNLearner; consider adding
/// `explicit` (the out-of-line definition is unaffected).
317 genericBNLearner(
const DatabaseTable<>& db);
/// Learner reading `filename`, with variables taken from `src`
/// (assumption: this fixes their domains — verify).
338 template <
typename GUM_SCALAR >
339 genericBNLearner(
const std::string& filename,
340 const gum::BayesNet< GUM_SCALAR >& src,
341 const std::vector< std::string >& missing_symbols);
/// Copy / move construction.
/// NOTE(review): the move operations are not noexcept; standard containers
/// will then copy rather than move on reallocation.
344 genericBNLearner(
const genericBNLearner&);
347 genericBNLearner(genericBNLearner&&);
/// Virtual destructor (the class is meant to be derived from).
350 virtual ~genericBNLearner();
/// Copy / move assignment.
360 genericBNLearner& operator=(
const genericBNLearner&);
363 genericBNLearner& operator=(genericBNLearner&&);
/// Learns a mixed graph (arcs and edges) from the data. Presumably this
/// runs the MIIC/3off2 pipeline (cf. prepareMiic3Off2_) — verify.
377 MixedGraph learnMixedStructure();
/// Sets the DAG from which structure learning starts.
380 void setInitialDAG(
const DAG&);
/// Names of the variables in the database.
383 const std::vector< std::string >& names()
const;
/// Domain sizes of all the variables, and of a single variable (by id or by
/// name).
386 const std::vector< std::size_t >& domainSizes()
const;
387 Size domainSize(NodeId var)
const;
388 Size domainSize(
const std::string& var)
const;
/// Node id of the variable called `var_name`.
395 NodeId idFromName(
const std::string& var_name)
const;
/// The learning database.
398 const DatabaseTable<>& database()
const;
/// Weight of the whole database.
402 void setDatabaseWeight(
const double new_weight);
/// Weight of record `i` (setter / getter).
408 void setRecordWeight(
const std::size_t i,
const double weight);
413 double recordWeight(
const std::size_t i)
const;
/// Current weight of the whole database.
416 double databaseWeight()
const;
/// Name of the variable with node id `id`.
419 const std::string& nameFromId(NodeId id)
const;
/// Restricts learning to the given row ranges. Each pair is presumably a
/// half-open [begin, end) row interval — verify. The XALLOC template
/// parameter lets any allocator be used for the vector.
428 template <
template <
typename >
class XALLOC >
429 void useDatabaseRanges(
430 const std::vector< std::pair< std::size_t, std::size_t >,
431 XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);
/// Removes any range restriction.
434 void clearDatabaseRanges();
/// The current row ranges (presumably the ranges_ member).
440 const std::vector< std::pair< std::size_t, std::size_t > >& databaseRanges()
const;
/// Splits the database into `k_fold` folds and learns on all folds but
/// `learning_fold`. The returned pair presumably delimits the held-out rows
/// — verify in the implementation.
463 std::pair< std::size_t, std::size_t > useCrossValidationFold(
const std::size_t learning_fold,
464 const std::size_t k_fold);
/// Chi-squared independence test of id1 and id2 given `knowing`.
/// The pair presumably holds (statistic, p-value) — verify.
474 std::pair<
double,
double >
475 chi2(
const NodeId id1,
const NodeId id2,
const std::vector< NodeId >& knowing = {});
/// Same test, with the variables given by name.
483 std::pair<
double,
double > chi2(
const std::string& name1,
484 const std::string& name2,
485 const std::vector< std::string >& knowing = {});
/// G² independence test of id1 and id2 given `knowing`
/// (pair presumably (statistic, p-value)).
494 std::pair<
double,
double >
495 G2(
const NodeId id1,
const NodeId id2,
const std::vector< NodeId >& knowing = {});
/// Same test, with the variables given by name.
503 std::pair<
double,
double > G2(
const std::string& name1,
504 const std::string& name2,
505 const std::vector< std::string >& knowing = {});
/// Log-likelihood of `vars` given `knowing` (by ids).
514 double logLikelihood(
const std::vector< NodeId >& vars,
515 const std::vector< NodeId >& knowing = {});
/// Log-likelihood of `vars` given `knowing` (by names).
524 double logLikelihood(
const std::vector< std::string >& vars,
525 const std::vector< std::string >& knowing = {});
/// Raw pseudo-counts over the joint configurations of `vars` (by ids).
532 std::vector<
double > rawPseudoCount(
const std::vector< NodeId >& vars);
/// Same, with the variables given by name.
539 std::vector<
double > rawPseudoCount(
const std::vector< std::string >& vars);
/// Enables EM with the given convergence threshold (presumably stored in
/// epsilonEM_, whose default is 0.0).
556 void useEM(
const double epsilon);
/// Whether the database contains missing values.
559 bool hasMissingValues()
const;
/// Selects the log2-likelihood score.
/// NOTE(review): the other score selectors are not visible in this copy.
584 void useScoreLog2Likelihood();
/// Selects a BDeu a priori with the given weight.
600 void useAprioriBDeu(
double weight = 1);
/// Selects a smoothing a priori with the given weight.
606 void useAprioriSmoothing(
double weight = 1);
/// Selects a Dirichlet a priori whose counts come from the database in
/// `filename`, with the given weight.
609 void useAprioriDirichlet(
const std::string& filename,
double weight = 1);
/// Checks whether the selected score and a priori fit together; the returned
/// string is presumably empty when compatible, and a warning otherwise.
615 std::string checkScoreAprioriCompatibility();
/// Selects greedy hill climbing as the structure-learning algorithm.
624 void useGreedyHillClimbing();
/// Selects local search with a tabu list of `tabu_size` entries;
/// `nb_decrease` presumably bounds the number of consecutive score
/// decreases allowed (verify).
630 void useLocalSearchWithTabuList(Size tabu_size = 100, Size nb_decrease = 2);
/// Selects K2 with the given node order (two container overloads).
633 void useK2(
const Sequence< NodeId >& order);
636 void useK2(
const std::vector< NodeId >& order);
/// Chooses the correction used by the corrected mutual information
/// (presumably sets kmode3Off2_, whose default is KModeTypes::MDL).
652 void useNMLCorrection();
655 void useMDLCorrection();
658 void useNoCorrection();
/// Arcs identified as involving latent variables (assumption: filled by
/// the MIIC/3off2 learning step).
/// NOTE(review): returning `const std::vector<Arc>` by value prevents the
/// caller from moving the result; plain `std::vector<Arc>` would be better.
662 const std::vector< Arc > latentVariables()
const;
/// Caps the number of parents of every node (presumably via
/// constraintIndegree_).
671 void setMaxIndegree(Size max_indegree);
/// Sets a slice order over the nodes (node id -> slice index), presumably
/// stored in constraintSliceOrder_.
678 void setSliceOrder(
const NodeProperty< NodeId >& slice_order);
/// Same, with the variables given by name: presumably one inner vector per
/// slice, in order.
684 void setSliceOrder(
const std::vector< std::vector< std::string > >& slices);
/// Forbidden arcs: replace the whole set, or add/erase single arcs (given
/// as an Arc, by node ids, or by variable names).
687 void setForbiddenArcs(
const ArcSet& set);
691 void addForbiddenArc(
const Arc& arc);
692 void addForbiddenArc(
const NodeId tail,
const NodeId head);
693 void addForbiddenArc(
const std::string& tail,
const std::string& head);
698 void eraseForbiddenArc(
const Arc& arc);
699 void eraseForbiddenArc(
const NodeId tail,
const NodeId head);
700 void eraseForbiddenArc(
const std::string& tail,
const std::string& head);
/// Mandatory arcs: same API shape as the forbidden arcs above.
704 void setMandatoryArcs(
const ArcSet& set);
708 void addMandatoryArc(
const Arc& arc);
709 void addMandatoryArc(
const NodeId tail,
const NodeId head);
710 void addMandatoryArc(
const std::string& tail,
const std::string& head);
715 void eraseMandatoryArc(
const Arc& arc);
716 void eraseMandatoryArc(
const NodeId tail,
const NodeId head);
717 void eraseMandatoryArc(
const std::string& tail,
const std::string& head);
/// Possible edges: the only edges structure learning may use. They can be
/// given as an EdgeSet or taken from the edges of an undirected skeleton.
724 void setPossibleEdges(
const EdgeSet& set);
725 void setPossibleSkeleton(
const UndiGraph& skeleton);
/// Add/erase a single possible edge.
/// NOTE(review): edges look undirected (cf. UndiGraph skeleton), so the
/// parameter names tail/head presumably carry no ordering meaning.
732 void addPossibleEdge(
const Edge& edge);
733 void addPossibleEdge(
const NodeId tail,
const NodeId head);
734 void addPossibleEdge(
const std::string& tail,
const std::string& head);
739 void erasePossibleEdge(
const Edge& edge);
740 void erasePossibleEdge(
const NodeId tail,
const NodeId head);
741 void erasePossibleEdge(
const std::string& tail,
const std::string& head);
/// Selected score (BDeu by default) and the score object.
/// NOTE(review): score_, mutualInfo_, apriori_, noApriori_ and
/// aprioriDatabase_ are raw pointers; their ownership is handled in code not
/// visible here — std::unique_ptr would make it explicit.
748 ScoreType scoreType_{ScoreType::BDeu};
751 Score<>* score_{
nullptr};
/// Selected parameter estimator (ML by default).
754 ParamEstimatorType paramEstimatorType_{ParamEstimatorType::ML};
/// EM threshold (0.0 by default; presumably set by useEM).
757 double epsilonEM_{0.0};
/// Corrected mutual information object (presumably used by MIIC/3off2).
760 CorrectedMutualInformation<>* mutualInfo_{
nullptr};
/// Selected a priori (none by default) and the a priori objects.
763 AprioriType aprioriType_{AprioriType::NO_APRIORI};
766 Apriori<>* apriori_{
nullptr};
768 AprioriNoApriori<>* noApriori_{
nullptr};
/// Weight of the a priori.
/// NOTE(review): `1.0f` is a float literal initializing a double; the value
/// is exact, but `1.0` would be the idiomatic literal.
771 double aprioriWeight_{1.0f};
/// Structural constraints applied during structure learning.
774 StructuralConstraintSliceOrder constraintSliceOrder_;
777 StructuralConstraintIndegree constraintIndegree_;
780 StructuralConstraintTabuList constraintTabuList_;
783 StructuralConstraintForbiddenArcs constraintForbiddenArcs_;
786 StructuralConstraintPossibleEdges constraintPossibleEdges_;
789 StructuralConstraintMandatoryArcs constraintMandatoryArcs_;
/// Selected structure-learning algorithm (greedy hill climbing by default).
/// NOTE(review): algoK2_, used by the inline forwarding methods, is not
/// declared in the visible part of this header.
792 AlgoType selectedAlgo_{AlgoType::GREEDY_HILL_CLIMBING};
/// Correction used by the corrected mutual information (MDL by default).
801 typename CorrectedMutualInformation<>::KModeTypes kmode3Off2_{
802 CorrectedMutualInformation<>::KModeTypes::MDL};
/// Parameter learner and the structure-learning algorithms that receive
/// the approximation-scheme settings.
805 DAG2BNLearner<> Dag2BN_;
808 GreedyHillClimbing greedyHillClimbing_;
811 LocalSearchWithTabuList localSearchWithTabuList_;
/// Database used for scoring.
814 Database scoreDatabase_;
/// Row ranges learning is restricted to (presumably empty means all rows).
817 std::vector< std::pair< std::size_t, std::size_t > > ranges_;
/// Database for the Dirichlet a priori, and its file name.
820 Database* aprioriDatabase_{
nullptr};
823 std::string aprioriDbname_;
/// Scheme whose settings/state the getters report; nullptr until
/// setCurrentApproximationScheme has been called.
829 const ApproximationScheme* currentAlgorithm_{
nullptr};
/// Reads `filename` into a DatabaseTable, with the given missing-value
/// symbols.
832 static DatabaseTable<> readFile_(
const std::string& filename,
833 const std::vector< std::string >& missing_symbols);
/// Validates `filename` (presumably its extension/format; errors are
/// raised by the implementation).
836 static void checkFileName_(
const std::string& filename);
/// Builds the a priori object matching aprioriType_ (presumably into
/// apriori_).
839 void createApriori_();
/// Builds a parameter estimator reading through `parser`; if
/// `take_into_account_score` is true, the score's settings are presumably
/// taken into account as well. NOTE(review): returns a raw pointer; callers
/// presumably own the result — verify.
845 ParamEstimator<>* createParamEstimator_(DBRowGeneratorParser<>& parser,
846 bool take_into_account_score =
true);
/// Prepares and runs the MIIC/3off2 step, returning a mixed graph.
852 MixedGraph prepareMiic3Off2_();
/// Name of the selected a priori type.
855 const std::string& getAprioriType_()
const;
/// Builds the corrected mutual information object (presumably into
/// mutualInfo_).
858 void createCorrectedMutualInformation_();
/// Records which approximation scheme the IApproximationSchemeConfiguration
/// getters (epsilon(), maxIter(), history(), ...) report on.
871 INLINE
void setCurrentApproximationScheme(
const ApproximationScheme* approximationScheme) {
872 currentAlgorithm_ = approximationScheme;
/// Progress callback: registers `approximationScheme` as the current scheme,
/// then forwards (pourcent, error, time) to onProgress if it has a listener.
/// NOTE(review): the trailing parameters (pourcent, error, time) and the
/// opening brace are missing from this copy.
875 INLINE
void distributeProgress(
const ApproximationScheme* approximationScheme,
879 setCurrentApproximationScheme(approximationScheme);
881 if (onProgress.hasListener()) GUM_EMIT3(onProgress, pourcent, error, time);
/// Stop callback: registers `approximationScheme` as the current scheme,
/// then forwards `message` to onStop if it has a listener.
885 INLINE
void distributeStop(
const ApproximationScheme* approximationScheme,
886 std::string message) {
887 setCurrentApproximationScheme(approximationScheme);
889 if (onStop.hasListener()) GUM_EMIT1(onStop, message);
/// Sets the epsilon stopping threshold on K2's approximation scheme, greedy
/// hill climbing, local search with tabu list and the DAG2BN parameter
/// learner.
897 void setEpsilon(
double eps) {
898 algoK2_.approximationScheme().setEpsilon(eps);
899 greedyHillClimbing_.setEpsilon(eps);
900 localSearchWithTabuList_.setEpsilon(eps);
901 Dag2BN_.setEpsilon(eps);
/// Epsilon of the current scheme; GUM_ERROR(FatalError, ...) if none is set.
/// NOTE(review): currentAlgorithm_ starts as nullptr and is only set through
/// setCurrentApproximationScheme, so this getter fails until a progress/stop
/// notification has arrived, even after setEpsilon() was called.
905 double epsilon()
const {
906 if (currentAlgorithm_ !=
nullptr)
907 return currentAlgorithm_->epsilon();
909 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Disables the epsilon criterion on the same four objects as setEpsilon.
913 void disableEpsilon() {
914 algoK2_.approximationScheme().disableEpsilon();
915 greedyHillClimbing_.disableEpsilon();
916 localSearchWithTabuList_.disableEpsilon();
917 Dag2BN_.disableEpsilon();
/// Enables the epsilon criterion on the same four objects.
921 void enableEpsilon() {
922 algoK2_.approximationScheme().enableEpsilon();
923 greedyHillClimbing_.enableEpsilon();
924 localSearchWithTabuList_.enableEpsilon();
925 Dag2BN_.enableEpsilon();
/// Whether the current scheme uses the epsilon criterion; GUM_ERROR if no
/// scheme is set.
930 bool isEnabledEpsilon()
const {
931 if (currentAlgorithm_ !=
nullptr)
932 return currentAlgorithm_->isEnabledEpsilon();
934 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets the minimal epsilon rate on K2's approximation scheme, greedy hill
/// climbing, local search with tabu list and DAG2BN.
943 void setMinEpsilonRate(
double rate) {
944 algoK2_.approximationScheme().setMinEpsilonRate(rate);
945 greedyHillClimbing_.setMinEpsilonRate(rate);
946 localSearchWithTabuList_.setMinEpsilonRate(rate);
947 Dag2BN_.setMinEpsilonRate(rate);
/// Minimal epsilon rate of the current scheme; GUM_ERROR(FatalError, ...)
/// if no scheme has been registered yet.
951 double minEpsilonRate()
const {
952 if (currentAlgorithm_ !=
nullptr)
953 return currentAlgorithm_->minEpsilonRate();
955 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Disables the min-epsilon-rate criterion on the same four objects.
959 void disableMinEpsilonRate() {
960 algoK2_.approximationScheme().disableMinEpsilonRate();
961 greedyHillClimbing_.disableMinEpsilonRate();
962 localSearchWithTabuList_.disableMinEpsilonRate();
963 Dag2BN_.disableMinEpsilonRate();
/// Enables the min-epsilon-rate criterion on the same four objects.
966 void enableMinEpsilonRate() {
967 algoK2_.approximationScheme().enableMinEpsilonRate();
968 greedyHillClimbing_.enableMinEpsilonRate();
969 localSearchWithTabuList_.enableMinEpsilonRate();
970 Dag2BN_.enableMinEpsilonRate();
/// Whether the current scheme uses the min-epsilon-rate criterion;
/// GUM_ERROR if no scheme is set.
974 bool isEnabledMinEpsilonRate()
const {
975 if (currentAlgorithm_ !=
nullptr)
976 return currentAlgorithm_->isEnabledMinEpsilonRate();
978 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets the maximal number of iterations on K2's approximation scheme,
/// greedy hill climbing, local search with tabu list and DAG2BN.
987 void setMaxIter(Size max) {
988 algoK2_.approximationScheme().setMaxIter(max);
989 greedyHillClimbing_.setMaxIter(max);
990 localSearchWithTabuList_.setMaxIter(max);
991 Dag2BN_.setMaxIter(max);
/// Maximal number of iterations of the current scheme; GUM_ERROR if no
/// scheme has been registered yet.
995 Size maxIter()
const {
996 if (currentAlgorithm_ !=
nullptr)
997 return currentAlgorithm_->maxIter();
999 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Disables the max-iteration criterion on the same four objects.
1003 void disableMaxIter() {
1004 algoK2_.approximationScheme().disableMaxIter();
1005 greedyHillClimbing_.disableMaxIter();
1006 localSearchWithTabuList_.disableMaxIter();
1007 Dag2BN_.disableMaxIter();
/// Enables the max-iteration criterion on the same four objects.
1010 void enableMaxIter() {
1011 algoK2_.approximationScheme().enableMaxIter();
1012 greedyHillClimbing_.enableMaxIter();
1013 localSearchWithTabuList_.enableMaxIter();
1014 Dag2BN_.enableMaxIter();
/// Whether the current scheme uses the max-iteration criterion; GUM_ERROR
/// if no scheme is set.
1018 bool isEnabledMaxIter()
const {
1019 if (currentAlgorithm_ !=
nullptr)
1020 return currentAlgorithm_->isEnabledMaxIter();
1022 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets the timeout on K2's approximation scheme, greedy hill climbing,
/// local search with tabu list and DAG2BN (unit presumably seconds —
/// verify against ApproximationScheme).
1032 void setMaxTime(
double timeout) {
1033 algoK2_.approximationScheme().setMaxTime(timeout);
1034 greedyHillClimbing_.setMaxTime(timeout);
1035 localSearchWithTabuList_.setMaxTime(timeout);
1036 Dag2BN_.setMaxTime(timeout);
/// Timeout of the current scheme; GUM_ERROR if no scheme is set.
1040 double maxTime()
const {
1041 if (currentAlgorithm_ !=
nullptr)
1042 return currentAlgorithm_->maxTime();
1044 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Current running time reported by the current scheme; GUM_ERROR if no
/// scheme is set.
1048 double currentTime()
const {
1049 if (currentAlgorithm_ !=
nullptr)
1050 return currentAlgorithm_->currentTime();
1052 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Disables the timeout criterion on the same four objects.
1056 void disableMaxTime() {
1057 algoK2_.approximationScheme().disableMaxTime();
1058 greedyHillClimbing_.disableMaxTime();
1059 localSearchWithTabuList_.disableMaxTime();
1060 Dag2BN_.disableMaxTime();
/// Enables the timeout criterion on the same four objects.
1062 void enableMaxTime() {
1063 algoK2_.approximationScheme().enableMaxTime();
1064 greedyHillClimbing_.enableMaxTime();
1065 localSearchWithTabuList_.enableMaxTime();
1066 Dag2BN_.enableMaxTime();
/// Whether the current scheme uses the timeout criterion; GUM_ERROR if no
/// scheme is set.
1070 bool isEnabledMaxTime()
const {
1071 if (currentAlgorithm_ !=
nullptr)
1072 return currentAlgorithm_->isEnabledMaxTime();
1074 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets the period size (presumably the number of iterations between two
/// stopping-criterion checks) on K2's approximation scheme, greedy hill
/// climbing, local search with tabu list and DAG2BN.
1081 void setPeriodSize(Size p) {
1082 algoK2_.approximationScheme().setPeriodSize(p);
1083 greedyHillClimbing_.setPeriodSize(p);
1084 localSearchWithTabuList_.setPeriodSize(p);
1085 Dag2BN_.setPeriodSize(p);
/// Period size of the current scheme; GUM_ERROR if no scheme is set.
1088 Size periodSize()
const {
1089 if (currentAlgorithm_ !=
nullptr)
1090 return currentAlgorithm_->periodSize();
1092 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets verbosity on the same four objects.
1098 void setVerbosity(
bool v) {
1099 algoK2_.approximationScheme().setVerbosity(v);
1100 greedyHillClimbing_.setVerbosity(v);
1101 localSearchWithTabuList_.setVerbosity(v);
1102 Dag2BN_.setVerbosity(v);
/// Verbosity of the current scheme; GUM_ERROR if no scheme is set.
1105 bool verbosity()
const {
1106 if (currentAlgorithm_ !=
nullptr)
1107 return currentAlgorithm_->verbosity();
1109 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// State of the current scheme; GUM_ERROR(FatalError, ...) if no scheme
/// has been registered yet.
1116 ApproximationSchemeSTATE stateApproximationScheme()
const {
1117 if (currentAlgorithm_ !=
nullptr)
1118 return currentAlgorithm_->stateApproximationScheme();
1120 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Number of iterations performed by the current scheme; GUM_ERROR if no
/// scheme is set.
1124 Size nbrIterations()
const {
1125 if (currentAlgorithm_ !=
nullptr)
1126 return currentAlgorithm_->nbrIterations();
1128 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// History of the current scheme; GUM_ERROR if no scheme is set.
/// NOTE(review): this returns a reference obtained from
/// currentAlgorithm_->history(); that is only safe if that call itself
/// returns a reference (not a temporary), and the reference dies with the
/// scheme object — verify.
1132 const std::vector<
double >& history()
const {
1133 if (currentAlgorithm_ !=
nullptr)
1134 return currentAlgorithm_->history();
1136 GUM_ERROR(FatalError,
"No chosen algorithm for learning") 1146 #ifndef GUM_NO_INLINE 1147 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h> 1150 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h>
1146 #ifndef GUM_NO_INLINE 1147 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h> 1150 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h>