genericBNLearner.h
/**
 *
 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 * info_at_agrum_dot_org
 *
 * This library is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief A class providing a generic framework of learning algorithms that
 * can easily be used.
 *
 * The pack currently contains K2, GreedyHillClimbing, MIIC, 3off2 and
 * LocalSearchWithTabuList.
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */
#ifndef GUM_LEARNING_GENERIC_BN_LEARNER_H
#define GUM_LEARNING_GENERIC_BN_LEARNER_H

#include <sstream>
#include <memory>

#include <agrum/BN/BayesNet.h>
#include <agrum/agrum.h>
#include <agrum/tools/core/bijection.h>
#include <agrum/tools/core/sequence.h>
#include <agrum/tools/graphs/DAG.h>

#include <agrum/tools/database/DBTranslator4LabelizedVariable.h>
#include <agrum/tools/database/DBRowGeneratorParser.h>
#include <agrum/tools/database/DBInitializerFromCSV.h>
#include <agrum/tools/database/databaseTable.h>
#include <agrum/tools/database/DBRowGenerator4CompleteRows.h>
#include <agrum/tools/database/DBRowGeneratorEM.h>
#include <agrum/tools/database/DBRowGeneratorSet.h>

#include <agrum/BN/learning/scores_and_tests/scoreAIC.h>
#include <agrum/BN/learning/scores_and_tests/scoreBD.h>
#include <agrum/BN/learning/scores_and_tests/scoreBDeu.h>
#include <agrum/BN/learning/scores_and_tests/scoreBIC.h>
#include <agrum/BN/learning/scores_and_tests/scoreK2.h>
#include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>

#include <agrum/BN/learning/aprioris/aprioriDirichletFromDatabase.h>
#include <agrum/BN/learning/aprioris/aprioriNoApriori.h>
#include <agrum/BN/learning/aprioris/aprioriSmoothing.h>
#include <agrum/BN/learning/aprioris/aprioriBDeu.h>

#include <agrum/BN/learning/constraints/structuralConstraintDAG.h>
#include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h>
#include <agrum/BN/learning/constraints/structuralConstraintForbiddenArcs.h>
#include <agrum/BN/learning/constraints/structuralConstraintPossibleEdges.h>
#include <agrum/BN/learning/constraints/structuralConstraintIndegree.h>
#include <agrum/BN/learning/constraints/structuralConstraintMandatoryArcs.h>
#include <agrum/BN/learning/constraints/structuralConstraintSetStatic.h>
#include <agrum/BN/learning/constraints/structuralConstraintSliceOrder.h>
#include <agrum/BN/learning/constraints/structuralConstraintTabuList.h>

#include <agrum/BN/learning/structureUtils/graphChange.h>
#include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h>
#include <agrum/BN/learning/structureUtils/graphChangesGenerator4K2.h>
#include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h>

#include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
#include <agrum/BN/learning/paramUtils/paramEstimatorML.h>

#include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h>
#include <agrum/tools/core/approximations/approximationSchemeListener.h>

#include <agrum/BN/learning/K2.h>
#include <agrum/BN/learning/Miic.h>
#include <agrum/BN/learning/greedyHillClimbing.h>
#include <agrum/BN/learning/localSearchWithTabuList.h>

#include <agrum/tools/core/signal/signaler.h>

namespace gum {

  namespace learning {

    class BNLearnerListener;

    /** @class genericBNLearner
     * @brief A pack of learning algorithms that can easily be used
     *
     * The pack currently contains K2, GreedyHillClimbing and
     * LocalSearchWithTabuList, as well as 3off2/MIIC.
     * @ingroup learning_group
     */
    class genericBNLearner: public gum::IApproximationSchemeConfiguration {
      // private:
      public:
      /// an enumeration enabling one to easily select the score we wish to use
      enum class ScoreType
      {
        AIC,
        BD,
        BDeu,
        BIC,
        K2,
        LOG2LIKELIHOOD
      };

      /// an enumeration to select the type of parameter estimation we shall
      /// apply
      enum class ParamEstimatorType
      {
        ML
      };

      /// an enumeration to select the apriori
      enum class AprioriType
      {
        NO_APRIORI,
        SMOOTHING,
        DIRICHLET_FROM_DATABASE,
        BDEU
      };

      /// an enumeration to easily select the learning algorithm to use
      enum class AlgoType
      {
        K2,
        GREEDY_HILL_CLIMBING,
        LOCAL_SEARCH_WITH_TABU_LIST,
        MIIC,
        THREE_OFF_TWO
      };


      /// a helper to easily read databases
      class Database {
        public:
        // ########################################################################
        /// @name Constructors / Destructors
        // ########################################################################
        /// @{

        /// default constructor
        /** @param file the name of the CSV file containing the data
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data */
        explicit Database(const std::string&                file,
                          const std::vector< std::string >& missing_symbols);
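
        // A minimal usage sketch (the file name "data.csv" and the missing
        // symbols below are hypothetical):
        //   gum::learning::genericBNLearner::Database db("data.csv", {"?", "N/A"});
        //   const auto& names = db.names();   // column names found in the CSV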

        /// default constructor
        /** @param db an already initialized database table that is used to
         * fill the Database */
        explicit Database(const DatabaseTable<>& db);

        /// constructor for the aprioris
        /** We must ensure that the variables of the Database are identical to
         * those of the score database (else the counts used by the
         * scores might be erroneous). However, we allow the variables to be
         * ordered differently in the two databases: variables with the same
         * name in both databases are assumed to be the same.
         * @param filename the name of the CSV file containing the data
         * @param score_database the main database used for the learning
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data
         */
        Database(const std::string&                filename,
                 Database&                         score_database,
                 const std::vector< std::string >& missing_symbols);

        /// constructor with a BN providing the variables of interest
        /** @param filename the name of the CSV file containing the data
         * @param bn a Bayesian network indicating which variables of the CSV
         * file are used for learning
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data
         */
        template < typename GUM_SCALAR >
        Database(const std::string&                 filename,
                 const gum::BayesNet< GUM_SCALAR >& bn,
                 const std::vector< std::string >&  missing_symbols);

        /// copy constructor
        Database(const Database& from);

        /// move constructor
        Database(Database&& from);

        /// destructor
        ~Database();

        /// @}

        // ########################################################################
        /// @name Operators
        // ########################################################################
        /// @{

        /// copy operator
        Database& operator=(const Database& from);

        /// move operator
        Database& operator=(Database&& from);

        /// @}

        // ########################################################################
        /// @name Accessors / Modifiers
        // ########################################################################
        /// @{

        /// returns the parser for the database
        DBRowGeneratorParser<>& parser();

        /// returns the domain sizes of the variables
        const std::vector< std::size_t >& domainSizes() const;

        /// returns the names of the variables in the database
        const std::vector< std::string >& names() const;

        /// returns the node id corresponding to a variable name
        NodeId idFromName(const std::string& var_name) const;

        /// returns the variable name corresponding to a given node id
        const std::string& nameFromId(NodeId id) const;

        /// returns the internal database table
        const DatabaseTable<>& databaseTable() const;

        /** @brief assign a weight to all the rows of the database so
         * that the sum of their weights is equal to new_weight */
        void setDatabaseWeight(const double new_weight);

        /// returns the mapping between node ids and their columns in the database
        const Bijection< NodeId, std::size_t >& nodeId2Columns() const;

        /// returns the set of missing symbols taken into account
        const std::vector< std::string >& missingSymbols() const;

        /// returns the number of records in the database
        std::size_t nbRows() const;

        /// returns the number of records in the database
        std::size_t size() const;

        /// sets the weight of the ith record
        /** @throws OutOfBounds if i is outside the set of indices of the
         * records or if the weight is negative
         */
        void setWeight(const std::size_t i, const double weight);

        /// returns the weight of the ith record
        /** @throws OutOfBounds if i is outside the set of indices of the
         * records */
        double weight(const std::size_t i) const;

        /// returns the weight of the whole database
        double weight() const;


        /// @}

        protected:
        /// the database itself
        DatabaseTable<> _database_;

        /// the parser used for reading the database
        DBRowGeneratorParser<>* _parser_{nullptr};

        /// the domain sizes of the variables (useful to speed-up computations)
        std::vector< std::size_t > _domain_sizes_;

        /// a bijection assigning to each variable name its NodeId
        Bijection< NodeId, std::size_t > _nodeId2cols_;

/// the max number of threads authorized
#if defined(_OPENMP) && !defined(GUM_DEBUG_MODE)
        Size _max_threads_number_{getMaxNumberOfThreads()};
#else
        Size _max_threads_number_{1};
#endif /* _OPENMP && !GUM_DEBUG_MODE */

        /// the minimal number of rows to parse (on average) by thread
        Size _min_nb_rows_per_thread_{100};

        private:
        // returns the set of variables as a BN. This is convenient for
        // the constructors of apriori Databases
        template < typename GUM_SCALAR >
        BayesNet< GUM_SCALAR > _BNVars_() const;
      };

      /// sets the apriori weight
      void _setAprioriWeight_(double weight);

      public:
      // ##########################################################################
      /// @name Constructors / Destructors
      // ##########################################################################
      /// @{

      /// default constructor
      /**
       * reads the database file for the score / parameter estimation and
       * variable names
       */
      genericBNLearner(const std::string&                filename,
                       const std::vector< std::string >& missing_symbols);
      genericBNLearner(const DatabaseTable<>& db);
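
      // A minimal construction sketch (gum::learning::BNLearner< double > is
      // the usual concrete subclass; the file name and missing symbol are
      // hypothetical):
      //   gum::learning::BNLearner< double > learner("data.csv", {"?"});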

      /**
       * reads the database file for the score / parameter estimation and
       * variable names
       * @param filename the file to learn from
       * @param src a Bayesian network indicating which variables of the
       * database are used for learning, which modalities they have, and in
       * which order these modalities are stored in the nodes. For instance,
       * if a variable of src has domain {True, False, Big}, the corresponding
       * node of the learnt BN will have 3 modalities, the first one being
       * True, the second one being False, and the third being Big. The
       * modalities are considered as being exactly those of the variables of
       * the BN: as a consequence, if other values are found in the database,
       * an exception will be raised during learning.
       * @param missing_symbols the set of symbols in the CSV file that
       * correspond to missing data */
      template < typename GUM_SCALAR >
      genericBNLearner(const std::string&                 filename,
                       const gum::BayesNet< GUM_SCALAR >& src,
                       const std::vector< std::string >&  missing_symbols);

      /// copy constructor
      genericBNLearner(const genericBNLearner&);

      /// move constructor
      genericBNLearner(genericBNLearner&&);

      /// destructor
      virtual ~genericBNLearner();

      /// @}

      // ##########################################################################
      /// @name Operators
      // ##########################################################################
      /// @{

      /// copy operator
      genericBNLearner& operator=(const genericBNLearner&);

      /// move operator
      genericBNLearner& operator=(genericBNLearner&&);

      /// @}

      // ##########################################################################
      /// @name Accessors / Modifiers
      // ##########################################################################
      /// @{

      /// learn a structure from a file (must have read the db before)
      DAG learnDAG();

      /// learn a partial structure from a file (must have read the db before
      /// and must have selected MIIC or 3off2)
      MixedGraph learnMixedStructure();

      /// sets an initial DAG structure
      void setInitialDAG(const DAG&);

      /// returns the initial DAG structure
      DAG initialDAG();

      /// returns the names of the variables in the database
      const std::vector< std::string >& names() const;

      /// returns the domain sizes of the variables in the database
      const std::vector< std::size_t >& domainSizes() const;
      Size                              domainSize(NodeId var) const;
      Size                              domainSize(const std::string& var) const;

      /// returns the node id corresponding to a variable name
      /**
       * @throw MissingVariableInDatabase if a variable of the BN is not found
       * in the database.
       */
      NodeId idFromName(const std::string& var_name) const;

      /// returns the database used by the BNLearner
      const DatabaseTable<>& database() const;

      /** @brief assign a weight to all the rows of the learning database so
       * that the sum of their weights is equal to new_weight */
      void setDatabaseWeight(const double new_weight);

      /// sets the weight of the ith record of the database
      /** @throws OutOfBounds if i is outside the set of indices of the
       * records or if the weight is negative
       */
      void setRecordWeight(const std::size_t i, const double weight);

      /// returns the weight of the ith record
      /** @throws OutOfBounds if i is outside the set of indices of the
       * records */
      double recordWeight(const std::size_t i) const;

      /// returns the weight of the whole database
      double databaseWeight() const;

      /// returns the variable name corresponding to a given node id
      const std::string& nameFromId(NodeId id) const;

      /// use a new set of database rows' ranges to perform learning
      /** @param new_ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database
       * row indices. The subsequent learnings are then performed only on the
       * union of the rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when
       * performing cross validation tasks, in which part of the database
       * should be ignored. An empty set of ranges is equivalent to an
       * interval [X,Y) ranging over the whole database. */
      template < template < typename > class XALLOC >
      void useDatabaseRanges(
          const std::vector< std::pair< std::size_t, std::size_t >,
                             XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);
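
      // Sketch: restrict learning to rows [0,100) and [200,300) (indices are
      // hypothetical; an empty vector restores the whole database):
      //   std::vector< std::pair< std::size_t, std::size_t > > ranges
      //       = {{0, 100}, {200, 300}};
      //   learner.useDatabaseRanges(ranges);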

      /// reset the ranges to the one range corresponding to the whole database
      void clearDatabaseRanges();

      /// returns the current database rows' ranges used for learning
      /** @return The method returns a vector of pairs [Xi,Yi) of indices of
       * rows in the database. The learning is performed on this set of rows.
       * @warning an empty set of ranges means the whole database. */
      const std::vector< std::pair< std::size_t, std::size_t > >& databaseRanges() const;

      /// sets the ranges of rows to be used for cross-validation learning
      /** When applied on (x,k), the method indicates to the subsequent
       * learnings that they should be performed on the xth fold in a k-fold
       * cross-validation context. For instance, if a database has 1000 rows,
       * and if we perform a 10-fold cross-validation, then the first learning
       * fold (learning_fold=0) corresponds to the row interval [100,1000) and
       * the test dataset corresponds to [0,100). The second learning fold
       * (learning_fold=1) is [0,100) U [200,1000) and the corresponding test
       * dataset is [100,200).
       * @param learning_fold a number indicating the set of rows used for
       * learning. If N denotes the size of the database, and k_fold represents
       * the number of folds in the cross validation, then the set of rows
       * used for testing is [learning_fold * N / k_fold,
       * (learning_fold+1) * N / k_fold) and the learning database is the
       * complement in the database
       * @param k_fold the value of "k" in k-fold cross validation
       * @return a pair [x,y) of rows' indices that corresponds to the indices
       * of rows in the original database that constitute the test dataset
       * @throws OutOfBounds is raised if k_fold is equal to 0 or learning_fold
       * is greater than or equal to k_fold, or if k_fold is greater than
       * or equal to the size of the database. */
      std::pair< std::size_t, std::size_t > useCrossValidationFold(const std::size_t learning_fold,
                                                                   const std::size_t k_fold);
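
      // A sketch of 10-fold cross-validation, assuming a BNLearner named
      // `learner` built on this database:
      //   for (std::size_t fold = 0; fold < 10; ++fold) {
      //     // rows [test.first, test.second) form the test set of this fold
      //     const auto test = learner.useCrossValidationFold(fold, 10);
      //     const auto bn   = learner.learnBN();
      //     // ... evaluate bn on the test rows ...
      //   }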


      /**
       * Return the <statistic,pvalue> pair for the chi2 test in the database
       * @param id1 first variable
       * @param id2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double >
          chi2(const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing = {});

      /**
       * Return the <statistic,pvalue> pair for the chi2 test in the database
       * @param name1 first variable
       * @param name2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double > chi2(const std::string&                name1,
                                       const std::string&                name2,
                                       const std::vector< std::string >& knowing = {});
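
      // Sketch: test the independence of X and Y given Z from the data
      // (variable names are hypothetical):
      //   std::pair< double, double > r = learner.chi2("X", "Y", {"Z"});
      //   // r.first is the chi2 statistic, r.second its p-value; a small
      //   // p-value suggests rejecting the independence hypothesis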

      /**
       * Return the <statistic,pvalue> pair for the G2 test in the database
       * @param id1 first variable
       * @param id2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double >
          G2(const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing = {});

      /**
       * Return the <statistic,pvalue> pair for the G2 test in the database
       * @param name1 first variable
       * @param name2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double > G2(const std::string&                name1,
                                     const std::string&                name2,
                                     const std::vector< std::string >& knowing = {});

      /**
       * Return the log-likelihood of vars in the database, conditioned by
       * knowing, for the BNLearner
       * @param vars a vector of NodeIds
       * @param knowing an optional vector of conditioning NodeIds
       * @return a double
       */
      double logLikelihood(const std::vector< NodeId >& vars,
                           const std::vector< NodeId >& knowing = {});

      /**
       * Return the log-likelihood of vars in the database, conditioned by
       * knowing, for the BNLearner
       * @param vars a vector of variable names
       * @param knowing an optional vector of conditioning variable names
       * @return a double
       */
      double logLikelihood(const std::vector< std::string >& vars,
                           const std::vector< std::string >& knowing = {});
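
      // Sketch: log-likelihood of {X} conditioned on {Y} in the learning
      // database (names hypothetical; note that the result is a double):
      //   double ll = learner.logLikelihood({"X"}, {"Y"});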

      /**
       * Return the pseudo-counts of the NodeIds in vars as a raw array
       * @param vars a vector of NodeIds
       * @return a std::vector<double> containing the contingency table
       */
      std::vector< double > rawPseudoCount(const std::vector< NodeId >& vars);

      /**
       * Return the pseudo-counts of vars as a raw array
       * @param vars a vector of variable names
       * @return a std::vector<double> containing the contingency table
       */
      std::vector< double > rawPseudoCount(const std::vector< std::string >& vars);

      /**
       *
       * @return the number of columns in the database
       */
      Size nbCols() const;

      /**
       *
       * @return the number of rows in the database
       */
      Size nbRows() const;

      /** use the EM algorithm to learn parameters
       *
       * if epsilon=0, EM is not used
       */
      void useEM(const double epsilon);
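
      // Sketch: learn parameters with EM on a database containing missing
      // values (assumes "?" encodes missing data; combining EM with a prior
      // such as smoothing is a common choice, and `dag` is a hypothetical,
      // already-learnt structure):
      //   gum::learning::BNLearner< double > learner("data.csv", {"?"});
      //   learner.useAprioriSmoothing();
      //   learner.useEM(1e-4);   // stop EM when the criterion falls below 1e-4
      //   auto bn = learner.learnParameters(dag);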

      /// returns true if the learner's database has missing values
      bool hasMissingValues() const;

      /// @}

      // ##########################################################################
      /// @name Score selection
      // ##########################################################################
      /// @{

      /// indicate that we wish to use an AIC score
      void useScoreAIC();

      /// indicate that we wish to use a BD score
      void useScoreBD();

      /// indicate that we wish to use a BDeu score
      void useScoreBDeu();

      /// indicate that we wish to use a BIC score
      void useScoreBIC();

      /// indicate that we wish to use a K2 score
      void useScoreK2();

      /// indicate that we wish to use a Log2Likelihood score
      void useScoreLog2Likelihood();

      /// @}

      // ##########################################################################
      /// @name A priori selection / parameterization
      // ##########################################################################
      /// @{

      /// use no apriori
      void useNoApriori();

      /// use the BDeu apriori
      /** The BDeu apriori adds a weight to all the cells of the counting
       * tables. In other words, it is as if `weight` additional rows, in
       * which all values are equally probable, were added to the database. */
      void useAprioriBDeu(double weight = 1);

      /// use the apriori smoothing
      /** @param weight the weight to assign to the smoothing; if you do not
       * pass one, the current apriori weight of the genericBNLearner will
       * be used. */
      void useAprioriSmoothing(double weight = 1);

      /// use the Dirichlet apriori
      void useAprioriDirichlet(const std::string& filename, double weight = 1);
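
      // Sketch: combine a score and an apriori, then check their
      // compatibility (see checkScoreAprioriCompatibility() below):
      //   learner.useScoreBIC();
      //   learner.useAprioriSmoothing(1.0);
      //   std::string msg = learner.checkScoreAprioriCompatibility();
      //   if (!msg.empty()) { /* the combination is questionable */ }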


      /// checks whether the current score and apriori are compatible
      /** @returns a non-empty string explaining the incompatibility if the
       * apriori is not fully compatible with the score, and an empty string
       * otherwise. */
      std::string checkScoreAprioriCompatibility() const;
      /// @}

      // ##########################################################################
      /// @name Learning algorithm selection
      // ##########################################################################
      /// @{

      /// indicate that we wish to use a greedy hill climbing algorithm
      void useGreedyHillClimbing();

      /// indicate that we wish to use a local search with tabu list
      /** @param tabu_size indicate the size of the tabu list
       * @param nb_decrease indicate the maximum number of consecutive
       * score-decreasing changes that we allow to apply */
      void useLocalSearchWithTabuList(Size tabu_size = 100, Size nb_decrease = 2);

      /// indicate that we wish to use K2
      void useK2(const Sequence< NodeId >& order);

      /// indicate that we wish to use K2
      void useK2(const std::vector< NodeId >& order);

      /// indicate that we wish to use 3off2
      void use3off2();

      /// indicate that we wish to use MIIC
      void useMIIC();
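
      // Sketch: pick a structure learning algorithm before calling learnDAG()
      // (the node ids in the K2 order are hypothetical):
      //   learner.useGreedyHillClimbing();   // the default algorithm
      //   // or: learner.useK2(std::vector< NodeId >{0, 1, 2, 3});
      //   // or: learner.useMIIC();
      //   DAG dag = learner.learnDAG();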

      /// @}

      // ##########################################################################
      /// @name 3off2/MIIC parameterization and specific results
      // ##########################################################################
      /// @{
      /// indicate that we wish to use the NML correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when neither 3off2 nor MIIC is the selected algorithm
      void useNMLCorrection();
      /// indicate that we wish to use the MDL correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when neither 3off2 nor MIIC is the selected algorithm
      void useMDLCorrection();
      /// indicate that we wish to use no correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when neither 3off2 nor MIIC is the selected algorithm
      void useNoCorrection();

      /// get the list of arcs hiding latent variables
      /// @throws OperationNotAllowed when neither 3off2 nor MIIC is the selected algorithm
      const std::vector< Arc > latentVariables() const;

      /// @}
      // ##########################################################################
      /// @name Accessors / Modifiers for adding constraints on learning
      // ##########################################################################
      /// @{

      /// sets the max indegree
      void setMaxIndegree(Size max_indegree);

      /**
       * sets a partial order on the nodes
       * @param slice_order a NodeProperty giving the rank (priority) of nodes
       * in the partial order
       */
      void setSliceOrder(const NodeProperty< NodeId >& slice_order);

      /**
       * sets a partial order on the nodes
       * @param slices the list of lists of variable names
       */
      void setSliceOrder(const std::vector< std::vector< std::string > >& slices);
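
      // Sketch: a two-slice temporal order given by names (hypothetical
      // variables; arcs from a later slice to an earlier one are forbidden):
      //   learner.setSliceOrder({{"A_0", "B_0"}, {"A_1", "B_1"}});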

      /// assign a set of forbidden arcs
      void setForbiddenArcs(const ArcSet& set);

      /// @name assign a new forbidden arc
      /// @{
      void addForbiddenArc(const Arc& arc);
      void addForbiddenArc(const NodeId tail, const NodeId head);
      void addForbiddenArc(const std::string& tail, const std::string& head);
      /// @}
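
      // Sketch: structural constraints by variable names (hypothetical):
      //   learner.addForbiddenArc("smoking", "genetics");  // never add this arc
      //   learner.addMandatoryArc("smoking", "cancer");    // always keep this arc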

      /// @name remove a forbidden arc
      /// @{
      void eraseForbiddenArc(const Arc& arc);
      void eraseForbiddenArc(const NodeId tail, const NodeId head);
      void eraseForbiddenArc(const std::string& tail, const std::string& head);
      /// @}

      /// assign a set of mandatory arcs
      void setMandatoryArcs(const ArcSet& set);

      /// @name assign a new mandatory arc
      /// @{
      void addMandatoryArc(const Arc& arc);
      void addMandatoryArc(const NodeId tail, const NodeId head);
      void addMandatoryArc(const std::string& tail, const std::string& head);
      /// @}

      /// @name remove a mandatory arc
      /// @{
      void eraseMandatoryArc(const Arc& arc);
      void eraseMandatoryArc(const NodeId tail, const NodeId head);
      void eraseMandatoryArc(const std::string& tail, const std::string& head);
      /// @}

      /// assign a set of possible edges
      /// @warning Once at least one possible edge is defined, all other edges
      /// are not possible anymore
      /// @{
      void setPossibleEdges(const EdgeSet& set);
      void setPossibleSkeleton(const UndiGraph& skeleton);
      /// @}

      /// @name assign a new possible edge
      /// @warning Once at least one possible edge is defined, all other edges
      /// are not possible anymore
      /// @{
      void addPossibleEdge(const Edge& edge);
      void addPossibleEdge(const NodeId tail, const NodeId head);
      void addPossibleEdge(const std::string& tail, const std::string& head);
      /// @}

      /// @name remove a possible edge
      /// @{
      void erasePossibleEdge(const Edge& edge);
      void erasePossibleEdge(const NodeId tail, const NodeId head);
      void erasePossibleEdge(const std::string& tail, const std::string& head);
      /// @}

      /// @}

      protected:
      /// the score selected for learning
      ScoreType scoreType_{ScoreType::BDeu};

      /// the score used
      Score<>* score_{nullptr};

      /// the type of the parameter estimator
      ParamEstimatorType paramEstimatorType_{ParamEstimatorType::ML};

      /// epsilon for EM. If epsilon=0.0, EM is not used
      double epsilonEM_{0.0};

      /// the selected correction for 3off2 and MIIC
      CorrectedMutualInformation<>* mutualInfo_{nullptr};

      /// the a priori selected for the score and parameters
      AprioriType aprioriType_{AprioriType::NO_APRIORI};

      /// the apriori used
      Apriori<>* apriori_{nullptr};

      AprioriNoApriori<>* noApriori_{nullptr};

      /// the weight of the apriori
      double aprioriWeight_{1.0};

      /// the constraint for 2TBNs
      StructuralConstraintSliceOrder constraintSliceOrder_;

      /// the constraint for indegrees
      StructuralConstraintIndegree constraintIndegree_;

      /// the constraint for tabu lists
      StructuralConstraintTabuList constraintTabuList_;

      /// the constraint on forbidden arcs
      StructuralConstraintForbiddenArcs constraintForbiddenArcs_;

      /// the constraint on possible edges
      StructuralConstraintPossibleEdges constraintPossibleEdges_;

      /// the constraint on mandatory arcs
      StructuralConstraintMandatoryArcs constraintMandatoryArcs_;

      /// the selected learning algorithm
      AlgoType selectedAlgo_{AlgoType::GREEDY_HILL_CLIMBING};

      /// the K2 algorithm
      K2 algoK2_;

      /// the MIIC or 3off2 algorithm
      Miic algoMiic3off2_;

      /// the penalty used in 3off2
      typename CorrectedMutualInformation<>::KModeTypes kmode3Off2_{
          CorrectedMutualInformation<>::KModeTypes::MDL};

      /// the parametric EM
      DAG2BNLearner<> Dag2BN_;

      /// the greedy hill climbing algorithm
      GreedyHillClimbing greedyHillClimbing_;

      /// the local search with tabu list algorithm
      LocalSearchWithTabuList localSearchWithTabuList_;

      /// the database to be used by the scores and parameter estimators
      Database scoreDatabase_;

      /// the set of rows' ranges within the database in which learning is done
      std::vector< std::pair< std::size_t, std::size_t > > ranges_;

      /// the database used by the Dirichlet a priori
      Database* aprioriDatabase_{nullptr};

      /// the filename for the Dirichlet a priori, if any
      std::string aprioriDbname_;


      /// an initial DAG given to learners
      DAG initialDag_;

      /// the name of the database file
      std::string filename_;

      // the max number of consecutive score-decreasing changes allowed in the
      // local search with tabu list
      Size nbDecreasingChanges_{2};

      // the current algorithm as an approximationScheme
      const ApproximationScheme* currentAlgorithm_{nullptr};

      /// reads a file and returns the corresponding DatabaseTable
      static DatabaseTable<> readFile_(const std::string&                filename,
                                       const std::vector< std::string >& missing_symbols);

      /// checks whether the extension of a CSV filename is correct
      static void isCSVFileName_(const std::string& filename);

      /// create the apriori used for learning
      void createApriori_();

      /// create the score used for learning
      void createScore_();

      /// create the parameter estimator used for learning
      ParamEstimator<>* createParamEstimator_(DBRowGeneratorParser<>& parser,
                                              bool take_into_account_score = true);

      /// returns the DAG learnt
      DAG learnDag_();

      /// prepares the initial graph for 3off2 or MIIC
      MixedGraph prepareMiic3Off2_();

      /// returns the type (as a string) of a given apriori
      const std::string& getAprioriType_() const;

      /// create the Corrected Mutual Information instance for MIIC/3off2
      void createCorrectedMutualInformation_();


      public:
      // ##########################################################################
      /// @name redistribute signals AND implementation of interface
      /// IApproximationSchemeConfiguration
      // ##########################################################################
      // in order not to pollute the code of genericBNLearner, we directly
      // implement these very simple methods here.
      /// @{
      /// distribute signals
      INLINE void setCurrentApproximationScheme(const ApproximationScheme* approximationScheme) {
        currentAlgorithm_ = approximationScheme;
      }

      INLINE void distributeProgress(const ApproximationScheme* approximationScheme,
                                     Size                       pourcent,
                                     double                     error,
                                     double                     time) {
        setCurrentApproximationScheme(approximationScheme);

        if (onProgress.hasListener()) GUM_EMIT3(onProgress, pourcent, error, time);
      };

      /// distribute signals
      INLINE void distributeStop(const ApproximationScheme* approximationScheme,
                                 std::string                message) {
        setCurrentApproximationScheme(approximationScheme);

        if (onStop.hasListener()) GUM_EMIT1(onStop, message);
      };
      /// @}

      /// Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)|.
      /// If the criterion was disabled it will be enabled.
      /// @{
      /// @throw OutOfBounds if eps<0
      void setEpsilon(double eps) {
        algoK2_.approximationScheme().setEpsilon(eps);
        greedyHillClimbing_.setEpsilon(eps);
        localSearchWithTabuList_.setEpsilon(eps);
        Dag2BN_.setEpsilon(eps);
      };

      /// Get the value of epsilon
      double epsilon() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->epsilon();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on epsilon
      void disableEpsilon() {
        algoK2_.approximationScheme().disableEpsilon();
        greedyHillClimbing_.disableEpsilon();
        localSearchWithTabuList_.disableEpsilon();
        Dag2BN_.disableEpsilon();
      };

      /// Enable stopping criterion on epsilon
      void enableEpsilon() {
        algoK2_.approximationScheme().enableEpsilon();
        greedyHillClimbing_.enableEpsilon();
        localSearchWithTabuList_.enableEpsilon();
        Dag2BN_.enableEpsilon();
      };

      /// @return true if stopping criterion on epsilon is enabled, false
      /// otherwise
      bool isEnabledEpsilon() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledEpsilon();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// Given that we approximate f(t), stopping criterion on
      /// d/dt(|f(t+1)-f(t)|).
      /// If the criterion was disabled it will be enabled.
      /// @{
      /// @throw OutOfBounds if rate<0
      void setMinEpsilonRate(double rate) {
        algoK2_.approximationScheme().setMinEpsilonRate(rate);
        greedyHillClimbing_.setMinEpsilonRate(rate);
        localSearchWithTabuList_.setMinEpsilonRate(rate);
        Dag2BN_.setMinEpsilonRate(rate);
      };

      /// Get the value of the minimal epsilon rate
      double minEpsilonRate() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->minEpsilonRate();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on epsilon rate
      void disableMinEpsilonRate() {
        algoK2_.approximationScheme().disableMinEpsilonRate();
        greedyHillClimbing_.disableMinEpsilonRate();
        localSearchWithTabuList_.disableMinEpsilonRate();
        Dag2BN_.disableMinEpsilonRate();
      };

      /// Enable stopping criterion on epsilon rate
      void enableMinEpsilonRate() {
        algoK2_.approximationScheme().enableMinEpsilonRate();
        greedyHillClimbing_.enableMinEpsilonRate();
        localSearchWithTabuList_.enableMinEpsilonRate();
        Dag2BN_.enableMinEpsilonRate();
      };

      /// @return true if stopping criterion on epsilon rate is enabled, false
      /// otherwise
      bool isEnabledMinEpsilonRate() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMinEpsilonRate();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// stopping criterion on number of iterations.
      /// If the criterion was disabled it will be enabled.
      /// @{
      /// @param max The maximum number of iterations
      /// @throw OutOfBounds if max<=1
      void setMaxIter(Size max) {
        algoK2_.approximationScheme().setMaxIter(max);
        greedyHillClimbing_.setMaxIter(max);
        localSearchWithTabuList_.setMaxIter(max);
        Dag2BN_.setMaxIter(max);
      };

      /// @return the criterion on number of iterations
      Size maxIter() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->maxIter();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on max iterations
      void disableMaxIter() {
        algoK2_.approximationScheme().disableMaxIter();
        greedyHillClimbing_.disableMaxIter();
        localSearchWithTabuList_.disableMaxIter();
        Dag2BN_.disableMaxIter();
      };

      /// Enable stopping criterion on max iterations
      void enableMaxIter() {
        algoK2_.approximationScheme().enableMaxIter();
        greedyHillClimbing_.enableMaxIter();
        localSearchWithTabuList_.enableMaxIter();
        Dag2BN_.enableMaxIter();
      };

      /// @return true if stopping criterion on max iterations is enabled, false
      /// otherwise
      bool isEnabledMaxIter() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMaxIter();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// stopping criterion on timeout.
      /// If the criterion was disabled it will be enabled.
      /// @{
      /// @throw OutOfBounds if timeout<=0.0
      /** timeout is expressed in seconds (double).
       */
      void setMaxTime(double timeout) {
        algoK2_.approximationScheme().setMaxTime(timeout);
        greedyHillClimbing_.setMaxTime(timeout);
        localSearchWithTabuList_.setMaxTime(timeout);
        Dag2BN_.setMaxTime(timeout);
      }

      /// returns the timeout (in seconds)
      double maxTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->maxTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// get the current running time in seconds (double)
      double currentTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->currentTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on timeout
      void disableMaxTime() {
        algoK2_.approximationScheme().disableMaxTime();
        greedyHillClimbing_.disableMaxTime();
        localSearchWithTabuList_.disableMaxTime();
        Dag2BN_.disableMaxTime();
      };

      /// Enable stopping criterion on timeout
      void enableMaxTime() {
        algoK2_.approximationScheme().enableMaxTime();
        greedyHillClimbing_.enableMaxTime();
        localSearchWithTabuList_.enableMaxTime();
        Dag2BN_.enableMaxTime();
      };

      /// @return true if stopping criterion on timeout is enabled, false
      /// otherwise
      bool isEnabledMaxTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMaxTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// how many samples between two tests of the stopping criteria
      /// @{
      /// @throw OutOfBounds if p<1
      void setPeriodSize(Size p) {
        algoK2_.approximationScheme().setPeriodSize(p);
        greedyHillClimbing_.setPeriodSize(p);
        localSearchWithTabuList_.setPeriodSize(p);
        Dag2BN_.setPeriodSize(p);
      };

      Size periodSize() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->periodSize();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// verbosity
      /// @{
      void setVerbosity(bool v) {
        algoK2_.approximationScheme().setVerbosity(v);
        greedyHillClimbing_.setVerbosity(v);
        localSearchWithTabuList_.setVerbosity(v);
        Dag2BN_.setVerbosity(v);
      };

      bool verbosity() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->verbosity();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// history
      /// @{

      ApproximationSchemeSTATE stateApproximationScheme() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->stateApproximationScheme();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// @throw OperationNotAllowed if scheme not performed
      Size nbrIterations() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->nbrIterations();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// @throw OperationNotAllowed if scheme not performed or verbosity=false
      const std::vector< double >& history() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->history();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}
    };

  } /* namespace learning */

} /* namespace gum */

/// include the inlined functions if necessary
#ifndef GUM_NO_INLINE
#  include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h>
#endif /* GUM_NO_INLINE */

#include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h>

#endif /* GUM_LEARNING_GENERIC_BN_LEARNER_H */