32 #ifndef GUM_LEARNING_GENERIC_BN_LEARNER_H 33 #define GUM_LEARNING_GENERIC_BN_LEARNER_H 38 #include <agrum/BN/BayesNet.h> 39 #include <agrum/agrum.h> 40 #include <agrum/tools/core/bijection.h> 41 #include <agrum/tools/core/sequence.h> 42 #include <agrum/tools/graphs/DAG.h> 44 #include <agrum/tools/database/DBTranslator4LabelizedVariable.h> 45 #include <agrum/tools/database/DBRowGeneratorParser.h> 46 #include <agrum/tools/database/DBInitializerFromCSV.h> 47 #include <agrum/tools/database/databaseTable.h> 48 #include <agrum/tools/database/DBRowGeneratorParser.h> 49 #include <agrum/tools/database/DBRowGenerator4CompleteRows.h> 50 #include <agrum/tools/database/DBRowGeneratorEM.h> 51 #include <agrum/tools/database/DBRowGeneratorSet.h> 53 #include <agrum/BN/learning/scores_and_tests/scoreAIC.h> 54 #include <agrum/BN/learning/scores_and_tests/scoreBD.h> 55 #include <agrum/BN/learning/scores_and_tests/scoreBDeu.h> 56 #include <agrum/BN/learning/scores_and_tests/scoreBIC.h> 57 #include <agrum/BN/learning/scores_and_tests/scoreK2.h> 58 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h> 60 #include <agrum/BN/learning/aprioris/aprioriDirichletFromDatabase.h> 61 #include <agrum/BN/learning/aprioris/aprioriNoApriori.h> 62 #include <agrum/BN/learning/aprioris/aprioriSmoothing.h> 63 #include <agrum/BN/learning/aprioris/aprioriBDeu.h> 65 #include <agrum/BN/learning/constraints/structuralConstraintDAG.h> 66 #include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h> 67 #include <agrum/BN/learning/constraints/structuralConstraintForbiddenArcs.h> 68 #include <agrum/BN/learning/constraints/structuralConstraintPossibleEdges.h> 69 #include <agrum/BN/learning/constraints/structuralConstraintIndegree.h> 70 #include <agrum/BN/learning/constraints/structuralConstraintMandatoryArcs.h> 71 #include <agrum/BN/learning/constraints/structuralConstraintSetStatic.h> 72 #include <agrum/BN/learning/constraints/structuralConstraintSliceOrder.h> 73 #include 
<agrum/BN/learning/constraints/structuralConstraintTabuList.h> 75 #include <agrum/BN/learning/structureUtils/graphChange.h> 76 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h> 77 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4K2.h> 78 #include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h> 80 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h> 81 #include <agrum/BN/learning/paramUtils/paramEstimatorML.h> 83 #include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h> 84 #include <agrum/tools/core/approximations/approximationSchemeListener.h> 86 #include <agrum/BN/learning/K2.h> 87 #include <agrum/BN/learning/Miic.h> 88 #include <agrum/BN/learning/greedyHillClimbing.h> 89 #include <agrum/BN/learning/localSearchWithTabuList.h> 91 #include <agrum/tools/core/signal/signaler.h> 97 class BNLearnerListener;
106 class genericBNLearner:
public gum::IApproximationSchemeConfiguration {
122 enum class ParamEstimatorType
128 enum class AprioriType
132 DIRICHLET_FROM_DATABASE,
140 GREEDY_HILL_CLIMBING,
141 LOCAL_SEARCH_WITH_TABU_LIST,
159 explicit Database(
const std::string& file,
160 const std::vector< std::string >& missing_symbols);
165 explicit Database(
const DatabaseTable<>& db);
178 Database(
const std::string& filename,
179 Database& score_database,
180 const std::vector< std::string >& missing_symbols);
189 template <
typename GUM_SCALAR >
190 Database(
const std::string& filename,
191 const gum::BayesNet< GUM_SCALAR >& bn,
192 const std::vector< std::string >& missing_symbols);
195 Database(
const Database& from);
198 Database(Database&& from);
211 Database& operator=(
const Database& from);
214 Database& operator=(Database&& from);
224 DBRowGeneratorParser<>& parser();
227 const std::vector< std::size_t >& domainSizes()
const;
230 const std::vector< std::string >& names()
const;
233 NodeId idFromName(
const std::string& var_name)
const;
236 const std::string& nameFromId(NodeId id)
const;
239 const DatabaseTable<>& databaseTable()
const;
243 void setDatabaseWeight(
const double new_weight);
246 const Bijection< NodeId, std::size_t >& nodeId2Columns()
const;
249 const std::vector< std::string >& missingSymbols()
const;
252 std::size_t nbRows()
const;
255 std::size_t size()
const;
261 void setWeight(
const std::size_t i,
const double weight);
266 double weight(
const std::size_t i)
const;
269 double weight()
const;
276 DatabaseTable<> _database_;
279 DBRowGeneratorParser<>* _parser_{
nullptr};
282 std::vector< std::size_t > _domain_sizes_;
285 Bijection< NodeId, std::size_t > _nodeId2cols_;
288 #if defined(_OPENMP) && !defined(GUM_DEBUG_MODE) 289 Size _max_threads_number_{getMaxNumberOfThreads()};
291 Size _max_threads_number_{1};
295 Size _min_nb_rows_per_thread_{100};
300 template <
typename GUM_SCALAR >
301 BayesNet< GUM_SCALAR > _BNVars_()
const;
305 void _setAprioriWeight_(
double weight);
318 genericBNLearner(
const std::string& filename,
319 const std::vector< std::string >& missing_symbols);
320 genericBNLearner(
const DatabaseTable<>& db);
341 template <
typename GUM_SCALAR >
342 genericBNLearner(
const std::string& filename,
343 const gum::BayesNet< GUM_SCALAR >& src,
344 const std::vector< std::string >& missing_symbols);
347 genericBNLearner(
const genericBNLearner&);
350 genericBNLearner(genericBNLearner&&);
353 virtual ~genericBNLearner();
363 genericBNLearner& operator=(
const genericBNLearner&);
366 genericBNLearner& operator=(genericBNLearner&&);
380 MixedGraph learnMixedStructure();
383 void setInitialDAG(
const DAG&);
389 const std::vector< std::string >& names()
const;
392 const std::vector< std::size_t >& domainSizes()
const;
393 Size domainSize(NodeId var)
const;
394 Size domainSize(
const std::string& var)
const;
401 NodeId idFromName(
const std::string& var_name)
const;
404 const DatabaseTable<>& database()
const;
408 void setDatabaseWeight(
const double new_weight);
414 void setRecordWeight(
const std::size_t i,
const double weight);
419 double recordWeight(
const std::size_t i)
const;
422 double databaseWeight()
const;
425 const std::string& nameFromId(NodeId id)
const;
434 template <
template <
typename >
class XALLOC >
435 void useDatabaseRanges(
436 const std::vector< std::pair< std::size_t, std::size_t >,
437 XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);
440 void clearDatabaseRanges();
446 const std::vector< std::pair< std::size_t, std::size_t > >& databaseRanges()
const;
469 std::pair< std::size_t, std::size_t > useCrossValidationFold(
const std::size_t learning_fold,
470 const std::size_t k_fold);
480 std::pair<
double,
double >
481 chi2(
const NodeId id1,
const NodeId id2,
const std::vector< NodeId >& knowing = {});
489 std::pair<
double,
double > chi2(
const std::string& name1,
490 const std::string& name2,
491 const std::vector< std::string >& knowing = {});
500 std::pair<
double,
double >
501 G2(
const NodeId id1,
const NodeId id2,
const std::vector< NodeId >& knowing = {});
509 std::pair<
double,
double > G2(
const std::string& name1,
510 const std::string& name2,
511 const std::vector< std::string >& knowing = {});
520 double logLikelihood(
const std::vector< NodeId >& vars,
521 const std::vector< NodeId >& knowing = {});
530 double logLikelihood(
const std::vector< std::string >& vars,
531 const std::vector< std::string >& knowing = {});
538 std::vector<
double > rawPseudoCount(
const std::vector< NodeId >& vars);
545 std::vector<
double > rawPseudoCount(
const std::vector< std::string >& vars);
562 void useEM(
const double epsilon);
565 bool hasMissingValues()
const;
590 void useScoreLog2Likelihood();
606 void useAprioriBDeu(
double weight = 1);
612 void useAprioriSmoothing(
double weight = 1);
615 void useAprioriDirichlet(
const std::string& filename,
double weight = 1);
621 std::string checkScoreAprioriCompatibility()
const;
630 void useGreedyHillClimbing();
636 void useLocalSearchWithTabuList(Size tabu_size = 100, Size nb_decrease = 2);
639 void useK2(
const Sequence< NodeId >& order);
642 void useK2(
const std::vector< NodeId >& order);
658 void useNMLCorrection();
661 void useMDLCorrection();
664 void useNoCorrection();
668 const std::vector< Arc > latentVariables()
const;
677 void setMaxIndegree(Size max_indegree);
684 void setSliceOrder(
const NodeProperty< NodeId >& slice_order);
690 void setSliceOrder(
const std::vector< std::vector< std::string > >& slices);
693 void setForbiddenArcs(
const ArcSet& set);
697 void addForbiddenArc(
const Arc& arc);
698 void addForbiddenArc(
const NodeId tail,
const NodeId head);
699 void addForbiddenArc(
const std::string& tail,
const std::string& head);
704 void eraseForbiddenArc(
const Arc& arc);
705 void eraseForbiddenArc(
const NodeId tail,
const NodeId head);
706 void eraseForbiddenArc(
const std::string& tail,
const std::string& head);
710 void setMandatoryArcs(
const ArcSet& set);
714 void addMandatoryArc(
const Arc& arc);
715 void addMandatoryArc(
const NodeId tail,
const NodeId head);
716 void addMandatoryArc(
const std::string& tail,
const std::string& head);
721 void eraseMandatoryArc(
const Arc& arc);
722 void eraseMandatoryArc(
const NodeId tail,
const NodeId head);
723 void eraseMandatoryArc(
const std::string& tail,
const std::string& head);
730 void setPossibleEdges(
const EdgeSet& set);
731 void setPossibleSkeleton(
const UndiGraph& skeleton);
738 void addPossibleEdge(
const Edge& edge);
739 void addPossibleEdge(
const NodeId tail,
const NodeId head);
740 void addPossibleEdge(
const std::string& tail,
const std::string& head);
745 void erasePossibleEdge(
const Edge& edge);
746 void erasePossibleEdge(
const NodeId tail,
const NodeId head);
747 void erasePossibleEdge(
const std::string& tail,
const std::string& head);
/// Score type; defaults to ScoreType::BDeu.
754 ScoreType scoreType_{ScoreType::BDeu};
// NOTE(review): raw pointer, null until assigned; its creation/deletion is
// not visible in this header — confirm ownership in the implementation.
757 Score<>* score_{
nullptr};
/// Parameter estimator type; defaults to ParamEstimatorType::ML.
760 ParamEstimatorType paramEstimatorType_{ParamEstimatorType::ML};
/// EM threshold, 0.0 by default (presumably meaning "EM not requested"
/// until useEM() is called — TODO confirm).
763 double epsilonEM_{0.0};
// NOTE(review): raw pointer, null by default; ownership not visible here.
766 CorrectedMutualInformation<>* mutualInfo_{
nullptr};
/// A priori type; defaults to AprioriType::NO_APRIORI.
769 AprioriType aprioriType_{AprioriType::NO_APRIORI};
// NOTE(review): raw pointers, null by default; ownership not visible here.
772 Apriori<>* apriori_{
nullptr};
774 AprioriNoApriori<>* noApriori_{
nullptr};
/// Weight of the a priori; defaults to 1.
/// Initialised with a double literal to match the member's type (and the
/// style of epsilonEM_{0.0}); 1.0f converted to exactly the same value.
double aprioriWeight_{1.0};
/// Structural constraints (slice order, indegree, tabu list, forbidden arcs,
/// possible edges, mandatory arcs), each held by value.
780 StructuralConstraintSliceOrder constraintSliceOrder_;
783 StructuralConstraintIndegree constraintIndegree_;
786 StructuralConstraintTabuList constraintTabuList_;
789 StructuralConstraintForbiddenArcs constraintForbiddenArcs_;
792 StructuralConstraintPossibleEdges constraintPossibleEdges_;
795 StructuralConstraintMandatoryArcs constraintMandatoryArcs_;
/// Selected learning algorithm; defaults to AlgoType::GREEDY_HILL_CLIMBING.
798 AlgoType selectedAlgo_{AlgoType::GREEDY_HILL_CLIMBING};
/// Correction mode for the corrected mutual information; defaults to MDL.
807 typename CorrectedMutualInformation<>::KModeTypes kmode3Off2_{
808 CorrectedMutualInformation<>::KModeTypes::MDL};
/// Algorithm objects held by value; the approximation-scheme setters of this
/// class (setEpsilon, setMaxIter, ...) forward their settings to them.
811 DAG2BNLearner<> Dag2BN_;
814 GreedyHillClimbing greedyHillClimbing_;
817 LocalSearchWithTabuList localSearchWithTabuList_;
/// Database used for scoring, held by value.
820 Database scoreDatabase_;
/// Row ranges of the database; empty by default.
823 std::vector< std::pair< std::size_t, std::size_t > > ranges_;
// NOTE(review): raw pointer, null by default; ownership not visible here —
// confirm where it is allocated and released.
826 Database* aprioriDatabase_{
nullptr};
/// Name of the database used for the a priori.
829 std::string aprioriDbname_;
/// Name of the learning data file.
836 std::string filename_;
/// Number of decreasing changes; defaults to 2.
839 Size nbDecreasingChanges_{2};
/// Approximation scheme queried by the getters (epsilon(), maxIter(), ...);
/// while it is null those getters signal FatalError.
842 const ApproximationScheme* currentAlgorithm_{
nullptr};
845 static DatabaseTable<> readFile_(
const std::string& filename,
846 const std::vector< std::string >& missing_symbols);
849 static void isCSVFileName_(
const std::string& filename);
852 void createApriori_();
858 ParamEstimator<>* createParamEstimator_(DBRowGeneratorParser<>& parser,
859 bool take_into_account_score =
true);
865 MixedGraph prepareMiic3Off2_();
868 const std::string& getAprioriType_()
const;
871 void createCorrectedMutualInformation_();
/// Records approximationScheme as the scheme whose state the getters
/// (epsilon(), maxIter(), history(), ...) report.
/// The pointer is only stored (NOTE(review): its lifetime is not managed
/// here — confirm the scheme outlives its use by those getters).
884 INLINE
void setCurrentApproximationScheme(
const ApproximationScheme* approximationScheme) {
885 currentAlgorithm_ = approximationScheme;
/// Progress relay: records approximationScheme as the current scheme, then
/// re-emits (pourcent, error, time) on this learner's onProgress signal when
/// at least one listener is attached.
888 INLINE
void distributeProgress(
const ApproximationScheme* approximationScheme,
892 setCurrentApproximationScheme(approximationScheme);
// Emit only when someone listens.
894 if (onProgress.hasListener()) GUM_EMIT3(onProgress, pourcent, error, time);
/// Stop relay: records approximationScheme as the current scheme, then
/// re-emits message on this learner's onStop signal when at least one
/// listener is attached.
898 INLINE
void distributeStop(
const ApproximationScheme* approximationScheme,
899 std::string message) {
900 setCurrentApproximationScheme(approximationScheme);
// Emit only when someone listens.
902 if (onStop.hasListener()) GUM_EMIT1(onStop, message);
/// Sets epsilon on every algorithm held by the learner: the K2 approximation
/// scheme, greedy hill climbing, local search with tabu list and Dag2BN_.
/// NOTE(review): no Miic/3off2 object is updated here — confirm whether that
/// is intentional.
910 void setEpsilon(
double eps) {
911 algoK2_.approximationScheme().setEpsilon(eps);
912 greedyHillClimbing_.setEpsilon(eps);
913 localSearchWithTabuList_.setEpsilon(eps);
914 Dag2BN_.setEpsilon(eps);
/// Returns the epsilon of the current approximation scheme (currentAlgorithm_).
/// Signals FatalError through GUM_ERROR when no scheme has been recorded.
918 double epsilon()
const {
919 if (currentAlgorithm_ !=
nullptr)
920 return currentAlgorithm_->epsilon();
// Reached only when no approximation scheme has been recorded yet.
922 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Disables the epsilon criterion on K2, greedy hill climbing, local search
/// with tabu list and Dag2BN_.
926 void disableEpsilon() {
927 algoK2_.approximationScheme().disableEpsilon();
928 greedyHillClimbing_.disableEpsilon();
929 localSearchWithTabuList_.disableEpsilon();
930 Dag2BN_.disableEpsilon();
/// Enables the epsilon criterion on K2, greedy hill climbing, local search
/// with tabu list and Dag2BN_.
934 void enableEpsilon() {
935 algoK2_.approximationScheme().enableEpsilon();
936 greedyHillClimbing_.enableEpsilon();
937 localSearchWithTabuList_.enableEpsilon();
938 Dag2BN_.enableEpsilon();
/// Reports whether the current approximation scheme (currentAlgorithm_) has
/// the epsilon criterion enabled. Signals FatalError through GUM_ERROR when
/// no scheme has been recorded.
943 bool isEnabledEpsilon()
const {
944 if (currentAlgorithm_ !=
nullptr)
945 return currentAlgorithm_->isEnabledEpsilon();
// Reached only when no approximation scheme has been recorded yet.
947 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets the minimal epsilon rate on K2, greedy hill climbing, local search
/// with tabu list and Dag2BN_.
956 void setMinEpsilonRate(
double rate) {
957 algoK2_.approximationScheme().setMinEpsilonRate(rate);
958 greedyHillClimbing_.setMinEpsilonRate(rate);
959 localSearchWithTabuList_.setMinEpsilonRate(rate);
960 Dag2BN_.setMinEpsilonRate(rate);
/// Returns the minimal epsilon rate of the current approximation scheme.
/// Signals FatalError through GUM_ERROR when no scheme has been recorded.
964 double minEpsilonRate()
const {
965 if (currentAlgorithm_ !=
nullptr)
966 return currentAlgorithm_->minEpsilonRate();
// Reached only when no approximation scheme has been recorded yet.
968 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Disables the minimal-epsilon-rate criterion on K2, greedy hill climbing,
/// local search with tabu list and Dag2BN_.
972 void disableMinEpsilonRate() {
973 algoK2_.approximationScheme().disableMinEpsilonRate();
974 greedyHillClimbing_.disableMinEpsilonRate();
975 localSearchWithTabuList_.disableMinEpsilonRate();
976 Dag2BN_.disableMinEpsilonRate();
/// Enables the minimal-epsilon-rate criterion on K2, greedy hill climbing,
/// local search with tabu list and Dag2BN_.
979 void enableMinEpsilonRate() {
980 algoK2_.approximationScheme().enableMinEpsilonRate();
981 greedyHillClimbing_.enableMinEpsilonRate();
982 localSearchWithTabuList_.enableMinEpsilonRate();
983 Dag2BN_.enableMinEpsilonRate();
/// Reports whether the current approximation scheme has the
/// minimal-epsilon-rate criterion enabled. Signals FatalError through
/// GUM_ERROR when no scheme has been recorded.
987 bool isEnabledMinEpsilonRate()
const {
988 if (currentAlgorithm_ !=
nullptr)
989 return currentAlgorithm_->isEnabledMinEpsilonRate();
// Reached only when no approximation scheme has been recorded yet.
991 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets the maximal number of iterations on K2, greedy hill climbing, local
/// search with tabu list and Dag2BN_.
1000 void setMaxIter(Size max) {
1001 algoK2_.approximationScheme().setMaxIter(max);
1002 greedyHillClimbing_.setMaxIter(max);
1003 localSearchWithTabuList_.setMaxIter(max);
1004 Dag2BN_.setMaxIter(max);
/// Returns the maximal number of iterations of the current approximation
/// scheme. Signals FatalError through GUM_ERROR when no scheme has been
/// recorded.
1008 Size maxIter()
const {
1009 if (currentAlgorithm_ !=
nullptr)
1010 return currentAlgorithm_->maxIter();
// Reached only when no approximation scheme has been recorded yet.
1012 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Disables the max-iterations criterion on K2, greedy hill climbing, local
/// search with tabu list and Dag2BN_.
1016 void disableMaxIter() {
1017 algoK2_.approximationScheme().disableMaxIter();
1018 greedyHillClimbing_.disableMaxIter();
1019 localSearchWithTabuList_.disableMaxIter();
1020 Dag2BN_.disableMaxIter();
/// Enables the max-iterations criterion on K2, greedy hill climbing, local
/// search with tabu list and Dag2BN_.
1023 void enableMaxIter() {
1024 algoK2_.approximationScheme().enableMaxIter();
1025 greedyHillClimbing_.enableMaxIter();
1026 localSearchWithTabuList_.enableMaxIter();
1027 Dag2BN_.enableMaxIter();
/// Reports whether the current approximation scheme has the max-iterations
/// criterion enabled. Signals FatalError through GUM_ERROR when no scheme
/// has been recorded.
1031 bool isEnabledMaxIter()
const {
1032 if (currentAlgorithm_ !=
nullptr)
1033 return currentAlgorithm_->isEnabledMaxIter();
// Reached only when no approximation scheme has been recorded yet.
1035 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets the time limit on K2, greedy hill climbing, local search with tabu
/// list and Dag2BN_ (unit not visible here — presumably seconds, TODO confirm).
1045 void setMaxTime(
double timeout) {
1046 algoK2_.approximationScheme().setMaxTime(timeout);
1047 greedyHillClimbing_.setMaxTime(timeout);
1048 localSearchWithTabuList_.setMaxTime(timeout);
1049 Dag2BN_.setMaxTime(timeout);
/// Returns the time limit of the current approximation scheme.
/// Signals FatalError through GUM_ERROR when no scheme has been recorded.
1053 double maxTime()
const {
1054 if (currentAlgorithm_ !=
nullptr)
1055 return currentAlgorithm_->maxTime();
// Reached only when no approximation scheme has been recorded yet.
1057 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Returns the current time reported by the current approximation scheme.
/// Signals FatalError through GUM_ERROR when no scheme has been recorded.
1061 double currentTime()
const {
1062 if (currentAlgorithm_ !=
nullptr)
1063 return currentAlgorithm_->currentTime();
// Reached only when no approximation scheme has been recorded yet.
1065 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Disables the time-limit criterion on K2, greedy hill climbing, local
/// search with tabu list and Dag2BN_.
1069 void disableMaxTime() {
1070 algoK2_.approximationScheme().disableMaxTime();
1071 greedyHillClimbing_.disableMaxTime();
1072 localSearchWithTabuList_.disableMaxTime();
1073 Dag2BN_.disableMaxTime();
/// Enables the time-limit criterion on K2, greedy hill climbing, local
/// search with tabu list and Dag2BN_.
1075 void enableMaxTime() {
1076 algoK2_.approximationScheme().enableMaxTime();
1077 greedyHillClimbing_.enableMaxTime();
1078 localSearchWithTabuList_.enableMaxTime();
1079 Dag2BN_.enableMaxTime();
/// Reports whether the current approximation scheme has the time-limit
/// criterion enabled. Signals FatalError through GUM_ERROR when no scheme
/// has been recorded.
1083 bool isEnabledMaxTime()
const {
1084 if (currentAlgorithm_ !=
nullptr)
1085 return currentAlgorithm_->isEnabledMaxTime();
// Reached only when no approximation scheme has been recorded yet.
1087 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets the period size on K2, greedy hill climbing, local search with tabu
/// list and Dag2BN_.
1094 void setPeriodSize(Size p) {
1095 algoK2_.approximationScheme().setPeriodSize(p);
1096 greedyHillClimbing_.setPeriodSize(p);
1097 localSearchWithTabuList_.setPeriodSize(p);
1098 Dag2BN_.setPeriodSize(p);
/// Returns the period size of the current approximation scheme.
/// Signals FatalError through GUM_ERROR when no scheme has been recorded.
1101 Size periodSize()
const {
1102 if (currentAlgorithm_ !=
nullptr)
1103 return currentAlgorithm_->periodSize();
// Reached only when no approximation scheme has been recorded yet.
1105 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Sets verbosity on K2, greedy hill climbing, local search with tabu list
/// and Dag2BN_.
1111 void setVerbosity(
bool v) {
1112 algoK2_.approximationScheme().setVerbosity(v);
1113 greedyHillClimbing_.setVerbosity(v);
1114 localSearchWithTabuList_.setVerbosity(v);
1115 Dag2BN_.setVerbosity(v);
/// Returns the verbosity flag of the current approximation scheme.
/// Signals FatalError through GUM_ERROR when no scheme has been recorded.
1118 bool verbosity()
const {
1119 if (currentAlgorithm_ !=
nullptr)
1120 return currentAlgorithm_->verbosity();
// Reached only when no approximation scheme has been recorded yet.
1122 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Returns the state of the current approximation scheme.
/// Signals FatalError through GUM_ERROR when no scheme has been recorded.
1129 ApproximationSchemeSTATE stateApproximationScheme()
const {
1130 if (currentAlgorithm_ !=
nullptr)
1131 return currentAlgorithm_->stateApproximationScheme();
// Reached only when no approximation scheme has been recorded yet.
1133 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Returns the number of iterations reported by the current approximation
/// scheme. Signals FatalError through GUM_ERROR when no scheme has been
/// recorded.
1137 Size nbrIterations()
const {
1138 if (currentAlgorithm_ !=
nullptr)
1139 return currentAlgorithm_->nbrIterations();
// Reached only when no approximation scheme has been recorded yet.
1141 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
/// Returns a const reference to the history vector of the current
/// approximation scheme; the reference points into that scheme's storage.
/// Signals FatalError through GUM_ERROR when no scheme has been recorded.
1145 const std::vector<
double >& history()
const {
1146 if (currentAlgorithm_ !=
nullptr)
1147 return currentAlgorithm_->history();
// Reached only when no approximation scheme has been recorded yet.
1149 GUM_ERROR(FatalError,
"No chosen algorithm for learning")
1159 #ifndef GUM_NO_INLINE 1160 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h> 1163 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h>