genericBNLearner.h
/**
 *
 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 * info_at_agrum_dot_org
 *
 * This library is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief A generic framework of learning algorithms that can easily be used.
 *
 * The pack currently contains K2, GreedyHillClimbing, MIIC, 3off2 and
 * LocalSearchWithTabuList.
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */
#ifndef GUM_LEARNING_GENERIC_BN_LEARNER_H
#define GUM_LEARNING_GENERIC_BN_LEARNER_H

#include <sstream>
#include <memory>

#include <agrum/BN/BayesNet.h>
#include <agrum/agrum.h>
#include <agrum/tools/core/bijection.h>
#include <agrum/tools/core/sequence.h>
#include <agrum/tools/graphs/DAG.h>

#include <agrum/tools/database/DBTranslator4LabelizedVariable.h>
#include <agrum/tools/database/DBRowGeneratorParser.h>
#include <agrum/tools/database/DBInitializerFromCSV.h>
#include <agrum/tools/database/databaseTable.h>
#include <agrum/tools/database/DBRowGenerator4CompleteRows.h>
#include <agrum/tools/database/DBRowGeneratorEM.h>
#include <agrum/tools/database/DBRowGeneratorSet.h>
#include <agrum/BN/learning/scores_and_tests/scoreAIC.h>
#include <agrum/BN/learning/scores_and_tests/scoreBD.h>
#include <agrum/BN/learning/scores_and_tests/scoreBDeu.h>
#include <agrum/BN/learning/scores_and_tests/scoreBIC.h>
#include <agrum/BN/learning/scores_and_tests/scoreK2.h>
#include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>

#include <agrum/BN/learning/aprioris/aprioriDirichletFromDatabase.h>
#include <agrum/BN/learning/aprioris/aprioriNoApriori.h>
#include <agrum/BN/learning/aprioris/aprioriSmoothing.h>
#include <agrum/BN/learning/aprioris/aprioriBDeu.h>

#include <agrum/BN/learning/constraints/structuralConstraintDAG.h>
#include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h>
#include <agrum/BN/learning/constraints/structuralConstraintForbiddenArcs.h>
#include <agrum/BN/learning/constraints/structuralConstraintPossibleEdges.h>
#include <agrum/BN/learning/constraints/structuralConstraintIndegree.h>
#include <agrum/BN/learning/constraints/structuralConstraintMandatoryArcs.h>
#include <agrum/BN/learning/constraints/structuralConstraintSetStatic.h>
#include <agrum/BN/learning/constraints/structuralConstraintSliceOrder.h>
#include <agrum/BN/learning/constraints/structuralConstraintTabuList.h>

#include <agrum/BN/learning/structureUtils/graphChange.h>
#include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h>
#include <agrum/BN/learning/structureUtils/graphChangesGenerator4K2.h>
#include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h>

#include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
#include <agrum/BN/learning/paramUtils/paramEstimatorML.h>

#include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h>
#include <agrum/tools/core/approximations/approximationSchemeListener.h>

#include <agrum/BN/learning/K2.h>
#include <agrum/BN/learning/Miic.h>
#include <agrum/BN/learning/greedyHillClimbing.h>
#include <agrum/BN/learning/localSearchWithTabuList.h>

#include <agrum/tools/core/signal/signaler.h>

namespace gum {

  namespace learning {

    class BNLearnerListener;

    /** @class genericBNLearner
     * @brief A pack of learning algorithms that can easily be used
     *
     * The pack currently contains K2, GreedyHillClimbing and
     * LocalSearchWithTabuList, as well as 3off2/MIIC.
     * @ingroup learning_group
     */
    class genericBNLearner: public gum::IApproximationSchemeConfiguration {
      // private:
      public:
      /// an enumeration for easily selecting the score we wish to use
      enum class ScoreType
      {
        AIC,
        BD,
        BDeu,
        BIC,
        K2,
        LOG2LIKELIHOOD
      };

      /// an enumeration to select the type of parameter estimation we shall
      /// apply
      enum class ParamEstimatorType
      { ML };

      /// an enumeration to select the apriori
      enum class AprioriType
      {
        NO_APRIORI,
        SMOOTHING,
        DIRICHLET_FROM_DATABASE,
        BDEU
      };

      /// an enumeration for easily selecting the learning algorithm to use
      enum class AlgoType
      {
        K2,
        GREEDY_HILL_CLIMBING,
        LOCAL_SEARCH_WITH_TABU_LIST,
        MIIC_THREE_OFF_TWO
      };
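
      /* A minimal usage sketch (illustrative only: the file name "data.csv",
       * the missing-value symbol "?" and the chosen score/algorithm are
       * placeholders). The template subclass BNLearner<> is normally
       * instantiated rather than genericBNLearner itself:
       *
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * learner.useScoreBIC();              // select the score
       * learner.useGreedyHillClimbing();    // select the search algorithm
       * gum::DAG dag = learner.learnDAG();  // learn the structure
       * @endcode
       */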


      /// a helper to easily read databases
      class Database {
        public:
        // ########################################################################
        /// @name Constructors / Destructors
        // ########################################################################
        /// @{

        /// default constructor
        /** @param file the name of the CSV file containing the data
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data */
        explicit Database(const std::string&                file,
                          const std::vector< std::string >& missing_symbols);

        /// default constructor
        /** @param db an already initialized database table that is used to
         * fill the Database */
        explicit Database(const DatabaseTable<>& db);

        /// constructor for the aprioris
        /** We must ensure that the variables of the Database are identical to
         * those of the score database (else the countings used by the
         * scores might be erroneous). However, we allow the variables to be
         * ordered differently in the two databases: variables with the same
         * name in both databases are supposed to be the same.
         * @param filename the name of the CSV file containing the data
         * @param score_database the main database used for the learning
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data
         */
        Database(const std::string&                filename,
                 Database&                         score_database,
                 const std::vector< std::string >& missing_symbols);

        /// constructor with a BN providing the variables of interest
        /** @param filename the name of the CSV file containing the data
         * @param bn a Bayesian network indicating which variables of the CSV
         * file are used for learning
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data
         */
        template < typename GUM_SCALAR >
        Database(const std::string&                 filename,
                 const gum::BayesNet< GUM_SCALAR >& bn,
                 const std::vector< std::string >&  missing_symbols);

        /// copy constructor
        Database(const Database& from);

        /// move constructor
        Database(Database&& from);

        /// destructor
        ~Database();

        /// @}

        // ########################################################################
        /// @name Operators
        // ########################################################################
        /// @{

        /// copy operator
        Database& operator=(const Database& from);

        /// move operator
        Database& operator=(Database&& from);

        /// @}

        // ########################################################################
        /// @name Accessors / Modifiers
        // ########################################################################
        /// @{

        /// returns the parser for the database
        DBRowGeneratorParser<>& parser();

        /// returns the domain sizes of the variables
        const std::vector< std::size_t >& domainSizes() const;

        /// returns the names of the variables in the database
        const std::vector< std::string >& names() const;

        /// returns the node id corresponding to a variable name
        NodeId idFromName(const std::string& var_name) const;

        /// returns the variable name corresponding to a given node id
        const std::string& nameFromId(NodeId id) const;

        /// returns the internal database table
        const DatabaseTable<>& databaseTable() const;

        /** @brief assign a weight to all the rows of the database so
         * that the sum of their weights is equal to new_weight */
        void setDatabaseWeight(const double new_weight);

        /// returns the mapping between node ids and their columns in the database
        const Bijection< NodeId, std::size_t >& nodeId2Columns() const;

        /// returns the set of missing symbols taken into account
        const std::vector< std::string >& missingSymbols() const;

        /// returns the number of records in the database
        std::size_t nbRows() const;

        /// returns the number of records in the database
        std::size_t size() const;

        /// sets the weight of the ith record
        /** @throws OutOfBounds if i is outside the set of indices of the
         * records or if the weight is negative
         */
        void setWeight(const std::size_t i, const double weight);

        /// returns the weight of the ith record
        /** @throws OutOfBounds if i is outside the set of indices of the
         * records */
        double weight(const std::size_t i) const;

        /// returns the weight of the whole database
        double weight() const;


        /// @}

        protected:
        /// the database itself
        DatabaseTable<> _database_;

        /// the parser used for reading the database
        DBRowGeneratorParser<>* _parser_{nullptr};

        /// the domain sizes of the variables (useful to speed-up computations)
        std::vector< std::size_t > _domain_sizes_;

        /// a bijection assigning to each variable's NodeId its column in the database
        Bijection< NodeId, std::size_t > _nodeId2cols_;

/// the max number of threads authorized
#if defined(_OPENMP) && !defined(GUM_DEBUG_MODE)
        Size _max_threads_number_{getMaxNumberOfThreads()};
#else
        Size _max_threads_number_{1};
#endif /* GUM_DEBUG_MODE */

        /// the minimal number of rows to parse (on average) per thread
        Size _min_nb_rows_per_thread_{100};

        private:
        // returns the set of variables as a BN. This is convenient for
        // the constructors of apriori Databases
        template < typename GUM_SCALAR >
        BayesNet< GUM_SCALAR > _BNVars_() const;
      };

      /// sets the apriori weight
      void _setAprioriWeight_(double weight);

      public:
      // ##########################################################################
      /// @name Constructors / Destructors
      // ##########################################################################
      /// @{

      /// default constructor
      /**
       * read the database file for the score / parameter estimation and var
       * names
       */
      genericBNLearner(const std::string&                filename,
                       const std::vector< std::string >& missing_symbols);
      genericBNLearner(const DatabaseTable<>& db);

      /**
       * read the database file for the score / parameter estimation and var
       * names
       * @param filename The file to learn from.
       * @param modalities indicate for some nodes (not necessarily all the
       * nodes of the BN) which modalities they should have and in which order
       * these modalities should be stored into the nodes. For instance, if
       * modalities = { 1 -> {True, False, Big} }, then the node of id 1 in the
       * BN will have 3 modalities, the first one being True, the second one
       * being False, and the third one being Big.
       * @param parse_database if true, the modalities specified by the user
       * will be considered as a superset of the modalities of the variables.
       * Parsing the database then makes it possible to determine which ones
       * are really necessary, and those are kept in the order specified by the
       * user (NodeProperty modalities). If parse_database is set to false (the
       * default), then the modalities specified by the user will be considered
       * as being exactly those of the variables of the BN (as a consequence,
       * if we find other values in the database, an exception will be raised
       * during learning). */
      template < typename GUM_SCALAR >
      genericBNLearner(const std::string&                 filename,
                       const gum::BayesNet< GUM_SCALAR >& src,
                       const std::vector< std::string >&  missing_symbols);

      /// copy constructor
      genericBNLearner(const genericBNLearner&);

      /// move constructor
      genericBNLearner(genericBNLearner&&);

      /// destructor
      virtual ~genericBNLearner();

      /// @}

      // ##########################################################################
      /// @name Operators
      // ##########################################################################
      /// @{

      /// copy operator
      genericBNLearner& operator=(const genericBNLearner&);

      /// move operator
      genericBNLearner& operator=(genericBNLearner&&);

      /// @}

      // ##########################################################################
      /// @name Accessors / Modifiers
      // ##########################################################################
      /// @{

      /// learn a structure from a file (must have read the db before)
      DAG learnDAG();

      /// learn a partial structure from a file (must have read the db before
      /// and must have selected MIIC or 3off2)
      MixedGraph learnMixedStructure();

      /// sets an initial DAG structure
      void setInitialDAG(const DAG&);

      /// returns the names of the variables in the database
      const std::vector< std::string >& names() const;

      /// returns the domain sizes of the variables in the database
      const std::vector< std::size_t >& domainSizes() const;
      Size                              domainSize(NodeId var) const;
      Size                              domainSize(const std::string& var) const;

      /// returns the node id corresponding to a variable name
      /**
       * @throw MissingVariableInDatabase if a variable of the BN is not found
       * in the database.
       */
      NodeId idFromName(const std::string& var_name) const;

      /// returns the database used by the BNLearner
      const DatabaseTable<>& database() const;

      /** @brief assign a weight to all the rows of the learning database so
       * that the sum of their weights is equal to new_weight */
      void setDatabaseWeight(const double new_weight);

      /// sets the weight of the ith record of the database
      /** @throws OutOfBounds if i is outside the set of indices of the
       * records or if the weight is negative
       */
      void setRecordWeight(const std::size_t i, const double weight);

      /// returns the weight of the ith record
      /** @throws OutOfBounds if i is outside the set of indices of the
       * records */
      double recordWeight(const std::size_t i) const;

      /// returns the weight of the whole database
      double databaseWeight() const;

      /// returns the variable name corresponding to a given node id
      const std::string& nameFromId(NodeId id) const;

      /// use a new set of database rows' ranges to perform learning
      /** @param new_ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's
       * rows indices. The subsequent learnings are then performed only on the
       * union of the rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when
       * performing cross validation tasks, in which part of the database
       * should be ignored. An empty set of ranges is equivalent to an
       * interval [X,Y) ranging over the whole database. */
      template < template < typename > class XALLOC >
      void useDatabaseRanges(
         const std::vector< std::pair< std::size_t, std::size_t >,
                            XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);

      /// reset the ranges to the one range corresponding to the whole database
      void clearDatabaseRanges();

      /// returns the current database rows' ranges used for learning
      /** @return The method returns a vector of pairs [Xi,Yi) of indices of
       * rows in the database. The learning is performed on this set of rows.
       * @warning an empty set of ranges means the whole database. */
      const std::vector< std::pair< std::size_t, std::size_t > >& databaseRanges() const;

      /// sets the ranges of rows to be used for cross-validation learning
      /** When applied on (x,k), the method indicates to the subsequent
       * learnings that they should be performed on the xth fold in a k-fold
       * cross-validation context. For instance, if a database has 1000 rows,
       * and if we perform a 10-fold cross-validation, then the first learning
       * fold (learning_fold=0) corresponds to rows interval [100,1000) and the
       * test dataset corresponds to [0,100). The second learning fold
       * (learning_fold=1) is [0,100) U [200,1000) and the corresponding test
       * dataset is [100,200).
       * @param learning_fold a number indicating the set of rows used for
       * learning. If N denotes the size of the database, and k_fold represents
       * the number of folds in the cross validation, then the set of rows
       * used for testing is [learning_fold * N / k_fold,
       * (learning_fold+1) * N / k_fold) and the learning database is the
       * complement in the database
       * @param k_fold the value of "k" in k-fold cross validation
       * @return a pair [x,y) of rows' indices that corresponds to the indices
       * of rows in the original database that constitute the test dataset
       * @throws OutOfBounds is raised if k_fold is equal to 0 or learning_fold
       * is greater than or equal to k_fold, or if k_fold is greater than
       * or equal to the size of the database. */
      std::pair< std::size_t, std::size_t > useCrossValidationFold(const std::size_t learning_fold,
                                                                   const std::size_t k_fold);
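
      /* A cross-validation sketch (illustrative; assumes a learner built from
       * a CSV file with a score and an algorithm already selected):
       *
       * @code
       * const std::size_t k = 10;
       * for (std::size_t fold = 0; fold < k; ++fold) {
       *   // restrict learning to the fold's training rows; the returned
       *   // pair [x,y) is the test range within the original database
       *   std::pair< std::size_t, std::size_t > test_range
       *      = learner.useCrossValidationFold(fold, k);
       *   gum::DAG dag = learner.learnDAG();
       *   // ... evaluate dag on rows [test_range.first, test_range.second)
       * }
       * learner.clearDatabaseRanges();   // back to the whole database
       * @endcode
       */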

      /**
       * Return the <statistic,pvalue> pair for the chi2 test in the database
       * @param id1 first variable
       * @param id2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double >
         chi2(const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing = {});
      /**
       * Return the <statistic,pvalue> pair for the chi2 test in the database
       * @param name1 first variable
       * @param name2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double > chi2(const std::string&                name1,
                                       const std::string&                name2,
                                       const std::vector< std::string >& knowing = {});

      /**
       * Return the <statistic,pvalue> pair for the G2 test in the database
       * @param id1 first variable
       * @param id2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double >
         G2(const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing = {});
      /**
       * Return the <statistic,pvalue> pair for the G2 test in the database
       * @param name1 first variable
       * @param name2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double > G2(const std::string&                name1,
                                     const std::string&                name2,
                                     const std::vector< std::string >& knowing = {});
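
      /* An independence-test sketch (illustrative; the variable names
       * "smoking", "cancer" and "age" are placeholders for columns of the
       * CSV file):
       *
       * @code
       * std::pair< double, double > res
       *    = learner.chi2("smoking", "cancer", {"age"});
       * if (res.second > 0.05) {
       *   // p-value above 0.05: no significant dependence given age
       * }
       * @endcode
       */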

      /**
       * Return the log-likelihood of vars in the database, conditioned by
       * knowing, for the BNLearner
       * @param vars a vector of NodeIds
       * @param knowing an optional vector of conditioning NodeIds
       * @return a double
       */
      double logLikelihood(const std::vector< NodeId >& vars,
                           const std::vector< NodeId >& knowing = {});

      /**
       * Return the log-likelihood of vars in the database, conditioned by
       * knowing, for the BNLearner
       * @param vars a vector of variable names
       * @param knowing an optional vector of conditioning variable names
       * @return a double
       */
      double logLikelihood(const std::vector< std::string >& vars,
                           const std::vector< std::string >& knowing = {});
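
      /* A log-likelihood sketch (illustrative; "A" and "B" are placeholder
       * variable names):
       *
       * @code
       * // log-likelihood of A given B, computed from the database counts
       * double llAgivenB = learner.logLikelihood({"A"}, {"B"});
       * @endcode
       */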

      /**
       * Return the pseudo-counts of NodeIds vars in the database in a raw array
       * @param vars a vector of NodeIds
       * @return a std::vector<double> containing the contingency table
       */
      std::vector< double > rawPseudoCount(const std::vector< NodeId >& vars);

      /**
       * Return the pseudo-counts of vars in the database in a raw array
       * @param vars a vector of names
       * @return a std::vector<double> containing the contingency table
       */
      std::vector< double > rawPseudoCount(const std::vector< std::string >& vars);
      /**
       *
       * @return the number of columns in the database
       */
      Size nbCols() const;

      /**
       *
       * @return the number of rows in the database
       */
      Size nbRows() const;

      /** use the EM algorithm to learn parameters
       *
       * if epsilon=0, EM is not used
       */
      void useEM(const double epsilon);
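
      /* An EM sketch for databases with missing values (illustrative; the
       * file name "data.csv" and the missing-value symbol "?" are
       * placeholders, and combining EM with a smoothing apriori is one
       * possible choice, not a requirement):
       *
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * if (learner.hasMissingValues()) {
       *   learner.useEM(1e-4);             // stop EM when |f(t+1)-f(t)| < 1e-4
       *   learner.useAprioriSmoothing();   // avoid zero counts during EM
       * }
       * @endcode
       */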

      /// returns true if the learner's database has missing values
      bool hasMissingValues() const;

      /// @}

      // ##########################################################################
      /// @name Score selection
      // ##########################################################################
      /// @{

      /// indicate that we wish to use an AIC score
      void useScoreAIC();

      /// indicate that we wish to use a BD score
      void useScoreBD();

      /// indicate that we wish to use a BDeu score
      void useScoreBDeu();

      /// indicate that we wish to use a BIC score
      void useScoreBIC();

      /// indicate that we wish to use a K2 score
      void useScoreK2();

      /// indicate that we wish to use a Log2Likelihood score
      void useScoreLog2Likelihood();

      /// @}

      // ##########################################################################
      /// @name A priori selection / parameterization
      // ##########################################################################
      /// @{

      /// use no apriori
      void useNoApriori();

      /// use the BDeu apriori
      /** The BDeu apriori adds weight to all the cells of the counting
       * tables. In other words, it is equivalent to adding weight rows in the
       * database with equally probable values. */
      void useAprioriBDeu(double weight = 1);

      /// use the apriori smoothing
      /** @param weight pass in argument a weight if you wish to assign a weight
       * to the smoothing, else the current weight of the genericBNLearner will
       * be used. */
      void useAprioriSmoothing(double weight = 1);

      /// use the Dirichlet apriori
      void useAprioriDirichlet(const std::string& filename, double weight = 1);


      /// checks whether the current score and apriori are compatible
      /** @returns a non-empty warning message if the apriori is not fully
       * compatible with the score. */
      std::string checkScoreAprioriCompatibility();
      /// @}
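
      /* An apriori/score sketch (illustrative): a smoothing apriori of weight
       * 0.1 combined with a BIC score, followed by a compatibility check:
       *
       * @code
       * learner.useScoreBIC();
       * learner.useAprioriSmoothing(0.1);
       * std::string warning = learner.checkScoreAprioriCompatibility();
       * if (!warning.empty()) { std::cout << warning << std::endl; }
       * @endcode
       */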

      // ##########################################################################
      /// @name Learning algorithm selection
      // ##########################################################################
      /// @{

      /// indicate that we wish to use a greedy hill climbing algorithm
      void useGreedyHillClimbing();

      /// indicate that we wish to use a local search with tabu list
      /** @param tabu_size indicate the size of the tabu list
       * @param nb_decrease indicate the max number of consecutive
       * score-decreasing changes that we allow to apply */
      void useLocalSearchWithTabuList(Size tabu_size = 100, Size nb_decrease = 2);

      /// indicate that we wish to use K2
      void useK2(const Sequence< NodeId >& order);

      /// indicate that we wish to use K2
      void useK2(const std::vector< NodeId >& order);

      /// indicate that we wish to use 3off2
      void use3off2();

      /// indicate that we wish to use MIIC
      void useMIIC();

      /// @}
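
      /* A K2 sketch (illustrative): K2 requires a prior ordering of the
       * variables; here the ordering is given by the placeholder node ids
       * 0, 1, 2:
       *
       * @code
       * learner.useK2(std::vector< gum::NodeId >{0, 1, 2});
       * gum::DAG dag = learner.learnDAG();
       * @endcode
       */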

      // ##########################################################################
      /// @name 3off2/MIIC parameterization and specific results
      // ##########################################################################
      /// @{
      /// indicate that we wish to use the NML correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when 3off2 is not the selected algorithm
      void useNMLCorrection();
      /// indicate that we wish to use the MDL correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when 3off2 is not the selected algorithm
      void useMDLCorrection();
      /// indicate that we wish to use no correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when 3off2 is not the selected algorithm
      void useNoCorrection();

      /// get the list of arcs hiding latent variables
      /// @throws OperationNotAllowed when 3off2 or MIIC is not the selected
      /// algorithm
      const std::vector< Arc > latentVariables() const;

      /// @}
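
      /* A MIIC sketch (illustrative): learn a partial (mixed) structure and
       * inspect the arcs that may hide latent variables:
       *
       * @code
       * learner.useMIIC();
       * learner.useNMLCorrection();
       * gum::MixedGraph essential = learner.learnMixedStructure();
       * for (const auto& arc: learner.latentVariables()) {
       *   // arc.tail() -> arc.head() may be confounded by a latent variable
       * }
       * @endcode
       */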
      // ##########################################################################
      /// @name Accessors / Modifiers for adding constraints on learning
      // ##########################################################################
      /// @{

      /// sets the max indegree
      void setMaxIndegree(Size max_indegree);

      /**
       * sets a partial order on the nodes
       * @param slice_order a NodeProperty giving the rank (priority) of nodes
       * in the partial order
       */
      void setSliceOrder(const NodeProperty< NodeId >& slice_order);

      /**
       * sets a partial order on the nodes
       * @param slices the list of lists of variable names
       */
      void setSliceOrder(const std::vector< std::vector< std::string > >& slices);
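
      /* A slice-order sketch (illustrative; the names are placeholders): arcs
       * may only go from a slice to the same or a later slice, which is
       * typical for dynamic (2-TBN) models:
       *
       * @code
       * learner.setSliceOrder({{"A0", "B0"},    // slice 0
       *                        {"A1", "B1"}});  // slice 1
       * @endcode
       */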

      /// assign a set of forbidden arcs
      void setForbiddenArcs(const ArcSet& set);

      /// @name assign a new forbidden arc
      /// @{
      void addForbiddenArc(const Arc& arc);
      void addForbiddenArc(const NodeId tail, const NodeId head);
      void addForbiddenArc(const std::string& tail, const std::string& head);
      /// @}

      /// @name remove a forbidden arc
      /// @{
      void eraseForbiddenArc(const Arc& arc);
      void eraseForbiddenArc(const NodeId tail, const NodeId head);
      void eraseForbiddenArc(const std::string& tail, const std::string& head);
      ///@}

      /// assign a set of mandatory arcs
      void setMandatoryArcs(const ArcSet& set);

      /// @name assign a new mandatory arc
      ///@{
      void addMandatoryArc(const Arc& arc);
      void addMandatoryArc(const NodeId tail, const NodeId head);
      void addMandatoryArc(const std::string& tail, const std::string& head);
      ///@}

      /// @name remove a mandatory arc
      ///@{
      void eraseMandatoryArc(const Arc& arc);
      void eraseMandatoryArc(const NodeId tail, const NodeId head);
      void eraseMandatoryArc(const std::string& tail, const std::string& head);
      /// @}

      /// assign a set of possible edges
      /// @warning Once at least one possible edge is defined, all other edges
      /// are not possible anymore
      /// @{
      void setPossibleEdges(const EdgeSet& set);
      void setPossibleSkeleton(const UndiGraph& skeleton);
      /// @}

      /// @name assign a new possible edge
      /// @warning Once at least one possible edge is defined, all other edges
      /// are not possible anymore
      /// @{
      void addPossibleEdge(const Edge& edge);
      void addPossibleEdge(const NodeId tail, const NodeId head);
      void addPossibleEdge(const std::string& tail, const std::string& head);
      /// @}

      /// @name remove a possible edge
      /// @{
      void erasePossibleEdge(const Edge& edge);
      void erasePossibleEdge(const NodeId tail, const NodeId head);
      void erasePossibleEdge(const std::string& tail, const std::string& head);
      ///@}

      ///@}
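
      /* A structural-constraints sketch (illustrative; "A", "B" and "C" are
       * placeholder variable names):
       *
       * @code
       * learner.setMaxIndegree(3);           // at most 3 parents per node
       * learner.addMandatoryArc("A", "B");   // A -> B must be in the DAG
       * learner.addForbiddenArc("C", "A");   // C -> A may not be in the DAG
       * gum::DAG dag = learner.learnDAG();
       * @endcode
       */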

      protected:
      /// the score selected for learning
      ScoreType scoreType_{ScoreType::BDeu};

      /// the score used
      Score<>* score_{nullptr};

      /// the type of the parameter estimator
      ParamEstimatorType paramEstimatorType_{ParamEstimatorType::ML};

      /// epsilon for EM. If epsilon=0.0, no EM
      double epsilonEM_{0.0};

      /// the selected correction for 3off2 and MIIC
      CorrectedMutualInformation<>* mutualInfo_{nullptr};

      /// the a priori selected for the score and parameters
      AprioriType aprioriType_{AprioriType::NO_APRIORI};

      /// the apriori used
      Apriori<>* apriori_{nullptr};

      AprioriNoApriori<>* noApriori_{nullptr};

      /// the weight of the apriori
      double aprioriWeight_{1.0};

      /// the constraint for 2TBNs
      StructuralConstraintSliceOrder constraintSliceOrder_;

      /// the constraint for indegrees
      StructuralConstraintIndegree constraintIndegree_;

      /// the constraint for tabu lists
      StructuralConstraintTabuList constraintTabuList_;

      /// the constraint on forbidden arcs
      StructuralConstraintForbiddenArcs constraintForbiddenArcs_;

      /// the constraint on possible edges
      StructuralConstraintPossibleEdges constraintPossibleEdges_;

      /// the constraint on mandatory arcs
      StructuralConstraintMandatoryArcs constraintMandatoryArcs_;

      /// the selected learning algorithm
      AlgoType selectedAlgo_{AlgoType::GREEDY_HILL_CLIMBING};

      /// the K2 algorithm
      K2 algoK2_;

      /// the MIIC or 3off2 algorithm
      Miic algoMiic3off2_;

      /// the penalty used in 3off2
      typename CorrectedMutualInformation<>::KModeTypes kmode3Off2_{
         CorrectedMutualInformation<>::KModeTypes::MDL};

      /// the parametric EM
      DAG2BNLearner<> Dag2BN_;

      /// the greedy hill climbing algorithm
      GreedyHillClimbing greedyHillClimbing_;

      /// the local search with tabu list algorithm
      LocalSearchWithTabuList localSearchWithTabuList_;

      /// the database to be used by the scores and parameter estimators
      Database scoreDatabase_;

      /// the set of rows' ranges within the database in which learning is done
      std::vector< std::pair< std::size_t, std::size_t > > ranges_;

      /// the database used by the Dirichlet a priori
      Database* aprioriDatabase_{nullptr};

      /// the filename for the Dirichlet a priori, if any
      std::string aprioriDbname_;

      /// an initial DAG given to learners
      DAG initialDag_;

      // the current algorithm as an approximationScheme
      const ApproximationScheme* currentAlgorithm_{nullptr};

      /// reads a file and returns a databaseVectInRam
      static DatabaseTable<> readFile_(const std::string&                filename,
                                       const std::vector< std::string >& missing_symbols);

      /// checks whether the extension of a CSV filename is correct
      static void checkFileName_(const std::string& filename);

      /// create the apriori used for learning
      void createApriori_();

      /// create the score used for learning
      void createScore_();

      /// create the parameter estimator used for learning
      ParamEstimator<>* createParamEstimator_(DBRowGeneratorParser<>& parser,
                                              bool take_into_account_score = true);

      /// returns the DAG learnt
      DAG learnDag_();

      /// prepares the initial graph for 3off2 or MIIC
      MixedGraph prepareMiic3Off2_();

      /// returns the type (as a string) of a given apriori
      const std::string& getAprioriType_() const;

      /// create the Corrected Mutual Information instance for MIIC/3off2
      void createCorrectedMutualInformation_();


      public:
      // ##########################################################################
      /// @name redistribute signals AND implementation of interface
      /// IApproximationSchemeConfiguration
      // ##########################################################################
      // in order not to pollute the proper code of genericBNLearner, we
      // directly implement those very simple methods here.
      /// @{
      /// distribute signals
      INLINE void setCurrentApproximationScheme(const ApproximationScheme* approximationScheme) {
        currentAlgorithm_ = approximationScheme;
      }

      INLINE void distributeProgress(const ApproximationScheme* approximationScheme,
                                     Size                       pourcent,
                                     double                     error,
                                     double                     time) {
        setCurrentApproximationScheme(approximationScheme);

        if (onProgress.hasListener()) GUM_EMIT3(onProgress, pourcent, error, time);
      };

      /// distribute signals
      INLINE void distributeStop(const ApproximationScheme* approximationScheme,
                                 std::string                message) {
        setCurrentApproximationScheme(approximationScheme);

        if (onStop.hasListener()) GUM_EMIT1(onStop, message);
      };
      /// @}

      /// Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)|
      /// If the criterion was disabled it will be enabled
      /// @{
      /// @throw OutOfLowerBound if eps<0
      void setEpsilon(double eps) {
        algoK2_.approximationScheme().setEpsilon(eps);
        greedyHillClimbing_.setEpsilon(eps);
        localSearchWithTabuList_.setEpsilon(eps);
        Dag2BN_.setEpsilon(eps);
      };

      /// Get the value of epsilon
      double epsilon() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->epsilon();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on epsilon
      void disableEpsilon() {
        algoK2_.approximationScheme().disableEpsilon();
        greedyHillClimbing_.disableEpsilon();
        localSearchWithTabuList_.disableEpsilon();
        Dag2BN_.disableEpsilon();
      };

      /// Enable stopping criterion on epsilon
      void enableEpsilon() {
        algoK2_.approximationScheme().enableEpsilon();
        greedyHillClimbing_.enableEpsilon();
        localSearchWithTabuList_.enableEpsilon();
        Dag2BN_.enableEpsilon();
      };

      /// @return true if stopping criterion on epsilon is enabled, false
      /// otherwise
      bool isEnabledEpsilon() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledEpsilon();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// Given that we approximate f(t), stopping criterion on
      /// d/dt(|f(t+1)-f(t)|)
      /// If the criterion was disabled it will be enabled
      /// @{
      /// @throw OutOfLowerBound if rate<0
      void setMinEpsilonRate(double rate) {
        algoK2_.approximationScheme().setMinEpsilonRate(rate);
        greedyHillClimbing_.setMinEpsilonRate(rate);
        localSearchWithTabuList_.setMinEpsilonRate(rate);
        Dag2BN_.setMinEpsilonRate(rate);
      };

      /// Get the value of the minimal epsilon rate
      double minEpsilonRate() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->minEpsilonRate();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on epsilon rate
      void disableMinEpsilonRate() {
        algoK2_.approximationScheme().disableMinEpsilonRate();
        greedyHillClimbing_.disableMinEpsilonRate();
        localSearchWithTabuList_.disableMinEpsilonRate();
        Dag2BN_.disableMinEpsilonRate();
      };
      /// Enable stopping criterion on epsilon rate
      void enableMinEpsilonRate() {
        algoK2_.approximationScheme().enableMinEpsilonRate();
        greedyHillClimbing_.enableMinEpsilonRate();
        localSearchWithTabuList_.enableMinEpsilonRate();
        Dag2BN_.enableMinEpsilonRate();
      };
      /// @return true if stopping criterion on epsilon rate is enabled, false
      /// otherwise
      bool isEnabledMinEpsilonRate() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMinEpsilonRate();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// stopping criterion on number of iterations
      /// @{
      /// If the criterion was disabled it will be enabled
      /// @param max The maximum number of iterations
      /// @throw OutOfLowerBound if max<=1
      void setMaxIter(Size max) {
        algoK2_.approximationScheme().setMaxIter(max);
        greedyHillClimbing_.setMaxIter(max);
        localSearchWithTabuList_.setMaxIter(max);
        Dag2BN_.setMaxIter(max);
      };

      /// @return the criterion on number of iterations
      Size maxIter() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->maxIter();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on max iterations
      void disableMaxIter() {
        algoK2_.approximationScheme().disableMaxIter();
        greedyHillClimbing_.disableMaxIter();
        localSearchWithTabuList_.disableMaxIter();
        Dag2BN_.disableMaxIter();
      };
      /// Enable stopping criterion on max iterations
      void enableMaxIter() {
        algoK2_.approximationScheme().enableMaxIter();
        greedyHillClimbing_.enableMaxIter();
        localSearchWithTabuList_.enableMaxIter();
        Dag2BN_.enableMaxIter();
      };
      /// @return true if stopping criterion on max iterations is enabled, false
      /// otherwise
      bool isEnabledMaxIter() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMaxIter();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// stopping criterion on timeout
      /// If the criterion was disabled it will be enabled
      /// @{
      /// @throw OutOfLowerBound if timeout<=0.0
      /** timeout is expressed in seconds (double).
       */
      void setMaxTime(double timeout) {
        algoK2_.approximationScheme().setMaxTime(timeout);
        greedyHillClimbing_.setMaxTime(timeout);
        localSearchWithTabuList_.setMaxTime(timeout);
        Dag2BN_.setMaxTime(timeout);
      }

      /// returns the timeout (in seconds)
      double maxTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->maxTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// get the current running time in seconds (double)
      double currentTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->currentTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on timeout
      void disableMaxTime() {
        algoK2_.approximationScheme().disableMaxTime();
        greedyHillClimbing_.disableMaxTime();
        localSearchWithTabuList_.disableMaxTime();
        Dag2BN_.disableMaxTime();
      };
      /// Enable stopping criterion on timeout
      void enableMaxTime() {
        algoK2_.approximationScheme().enableMaxTime();
        greedyHillClimbing_.enableMaxTime();
        localSearchWithTabuList_.enableMaxTime();
        Dag2BN_.enableMaxTime();
      };
      /// @return true if stopping criterion on timeout is enabled, false
      /// otherwise
      bool isEnabledMaxTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMaxTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}
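
      /* A stopping-criteria sketch (illustrative): bound the search both by
       * iteration count and by wall-clock time; the tightest enabled
       * criterion stops the algorithm first:
       *
       * @code
       * learner.setMaxIter(500);   // at most 500 iterations
       * learner.setMaxTime(10.0);  // at most 10 seconds
       * learner.setEpsilon(1e-6);  // or until the score stabilizes
       * @endcode
       */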

      /// how many samples between two checks of the stopping criteria
      /// @{
      /// @throw OutOfLowerBound if p<1
      void setPeriodSize(Size p) {
        algoK2_.approximationScheme().setPeriodSize(p);
        greedyHillClimbing_.setPeriodSize(p);
        localSearchWithTabuList_.setPeriodSize(p);
        Dag2BN_.setPeriodSize(p);
      };

      /// @return the period between two checks of the stopping criteria
      Size periodSize() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->periodSize();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// verbosity
      /// @{
      void setVerbosity(bool v) {
        algoK2_.approximationScheme().setVerbosity(v);
        greedyHillClimbing_.setVerbosity(v);
        localSearchWithTabuList_.setVerbosity(v);
        Dag2BN_.setVerbosity(v);
      };

      bool verbosity() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->verbosity();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// history
      /// @{

      ApproximationSchemeSTATE stateApproximationScheme() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->stateApproximationScheme();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// @throw OperationNotAllowed if scheme not performed
      Size nbrIterations() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->nbrIterations();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// @throw OperationNotAllowed if scheme not performed or verbosity=false
      const std::vector< double >& history() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->history();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}
    };

  } /* namespace learning */

} /* namespace gum */

/// include the inlined functions if necessary
#ifndef GUM_NO_INLINE
#  include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h>
#endif /* GUM_NO_INLINE */

#include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h>

#endif /* GUM_LEARNING_GENERIC_BN_LEARNER_H */