aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
genericBNLearner.h
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief A generic framework of learning algorithms that can easily
25  * be used.
26  *
27  * The pack currently contains K2, GreedyHillClimbing, MIIC, 3off2 and
28  * LocalSearchWithTabuList.
29  *
30  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
31  */
32 #ifndef GUM_LEARNING_GENERIC_BN_LEARNER_H
33 #define GUM_LEARNING_GENERIC_BN_LEARNER_H
34 
35 #include <sstream>
36 #include <memory>
37 
38 #include <agrum/BN/BayesNet.h>
39 #include <agrum/agrum.h>
40 #include <agrum/tools/core/bijection.h>
41 #include <agrum/tools/core/sequence.h>
42 #include <agrum/tools/graphs/DAG.h>
43 
44 #include <agrum/tools/database/DBTranslator4LabelizedVariable.h>
45 #include <agrum/tools/database/DBRowGeneratorParser.h>
46 #include <agrum/tools/database/DBInitializerFromCSV.h>
47 #include <agrum/tools/database/databaseTable.h>
49 #include <agrum/tools/database/DBRowGenerator4CompleteRows.h>
50 #include <agrum/tools/database/DBRowGeneratorEM.h>
51 #include <agrum/tools/database/DBRowGeneratorSet.h>
52 
53 #include <agrum/BN/learning/scores_and_tests/scoreAIC.h>
54 #include <agrum/BN/learning/scores_and_tests/scoreBD.h>
55 #include <agrum/BN/learning/scores_and_tests/scoreBDeu.h>
56 #include <agrum/BN/learning/scores_and_tests/scoreBIC.h>
57 #include <agrum/BN/learning/scores_and_tests/scoreK2.h>
58 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>
59 
60 #include <agrum/BN/learning/aprioris/aprioriDirichletFromDatabase.h>
61 #include <agrum/BN/learning/aprioris/aprioriNoApriori.h>
62 #include <agrum/BN/learning/aprioris/aprioriSmoothing.h>
63 #include <agrum/BN/learning/aprioris/aprioriBDeu.h>
64 
65 #include <agrum/BN/learning/constraints/structuralConstraintDAG.h>
66 #include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h>
67 #include <agrum/BN/learning/constraints/structuralConstraintForbiddenArcs.h>
68 #include <agrum/BN/learning/constraints/structuralConstraintPossibleEdges.h>
69 #include <agrum/BN/learning/constraints/structuralConstraintIndegree.h>
70 #include <agrum/BN/learning/constraints/structuralConstraintMandatoryArcs.h>
71 #include <agrum/BN/learning/constraints/structuralConstraintSetStatic.h>
72 #include <agrum/BN/learning/constraints/structuralConstraintSliceOrder.h>
73 #include <agrum/BN/learning/constraints/structuralConstraintTabuList.h>
74 
75 #include <agrum/BN/learning/structureUtils/graphChange.h>
76 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h>
77 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4K2.h>
78 #include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h>
79 
80 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
81 #include <agrum/BN/learning/paramUtils/paramEstimatorML.h>
82 
83 #include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h>
84 #include <agrum/tools/core/approximations/approximationSchemeListener.h>
85 
86 #include <agrum/BN/learning/K2.h>
87 #include <agrum/BN/learning/Miic.h>
88 #include <agrum/BN/learning/greedyHillClimbing.h>
89 #include <agrum/BN/learning/localSearchWithTabuList.h>
90 
91 #include <agrum/tools/core/signal/signaler.h>
92 
93 namespace gum {
94 
95  namespace learning {
96 
97  class BNLearnerListener;
98 
99  /** @class genericBNLearner
100  * @brief A pack of learning algorithms that can easily be used
101  *
102  * The pack currently contains K2, GreedyHillClimbing and
103  * LocalSearchWithTabuList, as well as 3off2/MIIC
104  * @ingroup learning_group
105  */
106  class genericBNLearner: public gum::IApproximationSchemeConfiguration {
107  // private:
108  public:
109  /// an enumeration enabling to select easily the score we wish to use
110  enum class ScoreType
111  {
112  AIC,
113  BD,
114  BDeu,
115  BIC,
116  K2,
117  LOG2LIKELIHOOD
118  };
119 
120  /// an enumeration to select the type of parameter estimation we shall
121  /// apply
122  enum class ParamEstimatorType
123  { ML };
124 
125  /// an enumeration to select the apriori
126  enum class AprioriType
127  {
128  NO_APRIORI,
129  SMOOTHING,
130  DIRICHLET_FROM_DATABASE,
131  BDEU
132  };
133 
134  /// an enumeration to select easily the learning algorithm to use
135  enum class AlgoType
136  {
137  K2,
138  GREEDY_HILL_CLIMBING,
139  LOCAL_SEARCH_WITH_TABU_LIST,
140  MIIC_THREE_OFF_TWO
141  };
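      /* A minimal usage sketch showing how these selectors are driven in
       * practice through the derived BNLearner class. The CSV file name,
       * the missing symbol and the particular choices below are purely
       * illustrative assumptions:
       * @code
       * #include <agrum/BN/learning/BNLearner.h>
       *
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * learner.useScoreBIC();              // ScoreType::BIC
       * learner.useAprioriSmoothing(1.0);   // AprioriType::SMOOTHING
       * learner.useGreedyHillClimbing();    // AlgoType::GREEDY_HILL_CLIMBING
       * auto bn = learner.learnBN();        // parameters: ParamEstimatorType::ML
       * @endcode
       */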
142 
143 
144  /// a helper to easily read databases
145  class Database {
146  public:
147  // ########################################################################
148  /// @name Constructors / Destructors
149  // ########################################################################
150  /// @{
151 
152  /// default constructor
153  /** @param file the name of the CSV file containing the data
154  * @param missing_symbols the set of symbols in the CSV file that
155  * correspond to missing data */
156  explicit Database(const std::string& file,
157  const std::vector< std::string >& missing_symbols);
158 
159  /// constructor from an already initialized database table
160  /** @param db an already initialized database table that is used to
161  * fill the Database */
162  explicit Database(const DatabaseTable<>& db);
163 
164  /// constructor for the aprioris
165  /** We must ensure that the variables of the Database are identical to
166  * those of the score database (else the counts used by the
167  * scores might be erroneous). However, we allow the variables to be
168  * ordered differently in the two databases: variables with the same
169  * name in both databases are supposed to be the same.
170  * @param filename the name of the CSV file containing the data
171  * @param score_database the main database used for the learning
172  * @param missing_symbols the set of symbols in the CSV file that
173  * correspond to missing data
174  */
175  Database(const std::string& filename,
176  Database& score_database,
177  const std::vector< std::string >& missing_symbols);
178 
179  /// constructor with a BN providing the variables of interest
180  * @param filename the name of the CSV file containing the data
181  * @param bn a Bayesian network indicating which variables of the CSV
182  * file are used for learning
183  * @param missing_symbols the set of symbols in the CSV file that
184  * correspond to missing data
185  */
186  template < typename GUM_SCALAR >
187  Database(const std::string& filename,
188  const gum::BayesNet< GUM_SCALAR >& bn,
189  const std::vector< std::string >& missing_symbols);
190 
191  /// copy constructor
192  Database(const Database& from);
193 
194  /// move constructor
195  Database(Database&& from);
196 
197  /// destructor
198  ~Database();
199 
200  /// @}
201 
202  // ########################################################################
203  /// @name Operators
204  // ########################################################################
205  /// @{
206 
207  /// copy operator
208  Database& operator=(const Database& from);
209 
210  /// move operator
211  Database& operator=(Database&& from);
212 
213  /// @}
214 
215  // ########################################################################
216  /// @name Accessors / Modifiers
217  // ########################################################################
218  /// @{
219 
220  /// returns the parser for the database
221  DBRowGeneratorParser<>& parser();
222 
223  /// returns the domain sizes of the variables
224  const std::vector< std::size_t >& domainSizes() const;
225 
226  /// returns the names of the variables in the database
227  const std::vector< std::string >& names() const;
228 
229  /// returns the node id corresponding to a variable name
230  NodeId idFromName(const std::string& var_name) const;
231 
232  /// returns the variable name corresponding to a given node id
233  const std::string& nameFromId(NodeId id) const;
234 
235  /// returns the internal database table
236  const DatabaseTable<>& databaseTable() const;
237 
238  /** @brief assign a weight to all the rows of the database so
239  * that the sum of their weights is equal to new_weight */
240  void setDatabaseWeight(const double new_weight);
241 
242  /// returns the mapping between node ids and their columns in the database
243  const Bijection< NodeId, std::size_t >& nodeId2Columns() const;
244 
245  /// returns the set of missing symbols taken into account
246  const std::vector< std::string >& missingSymbols() const;
247 
248  /// returns the number of records in the database
249  std::size_t nbRows() const;
250 
251  /// returns the number of records in the database
252  std::size_t size() const;
253 
254  /// sets the weight of the ith record
255  /** @throws OutOfBounds if i is outside the set of indices of the
256  * records or if the weight is negative
257  */
258  void setWeight(const std::size_t i, const double weight);
259 
260  /// returns the weight of the ith record
261  /** @throws OutOfBounds if i is outside the set of indices of the
262  * records */
263  double weight(const std::size_t i) const;
264 
265  /// returns the weight of the whole database
266  double weight() const;
267 
268 
269  /// @}
270 
271  protected:
272  /// the database itself
273  DatabaseTable<> database__;
274 
275  /// the parser used for reading the database
276  DBRowGeneratorParser<>* parser__{nullptr};
277 
278  /// the domain sizes of the variables (useful to speed-up computations)
279  std::vector< std::size_t > domain_sizes__;
280 
281  /// a bijection assigning to each variable name its NodeId
282  Bijection< NodeId, std::size_t > nodeId2cols__;
283 
284 /// the max number of threads authorized
285 #if defined(_OPENMP) && !defined(GUM_DEBUG_MODE)
286  Size max_threads_number__{getMaxNumberOfThreads()};
287 #else
288  Size max_threads_number__{1};
289 #endif /* defined(_OPENMP) && !defined(GUM_DEBUG_MODE) */
290 
291  /// the minimal number of rows to parse (on average) by thread
292  Size min_nb_rows_per_thread__{100};
293 
294  private:
295  // returns the set of variables as a BN. This is convenient for
296  // the constructors of apriori Databases
297  template < typename GUM_SCALAR >
298  BayesNet< GUM_SCALAR > BNVars__() const;
299  };
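      /* A short, hedged sketch of the Database helper used on its own; the
       * file name and missing symbols are illustrative assumptions:
       * @code
       * #include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h>
       *
       * gum::learning::genericBNLearner::Database db("data.csv", {"?", "N/A"});
       * const auto& names = db.names();             // column (variable) names
       * std::size_t n     = db.nbRows();            // number of records
       * gum::NodeId id    = db.idFromName(names[0]);
       * @endcode
       */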
300 
301  /// sets the apriori weight
302  void setAprioriWeight__(double weight);
303 
304  public:
305  // ##########################################################################
306  /// @name Constructors / Destructors
307  // ##########################################################################
308  /// @{
309 
310  /// default constructor
311  /**
312  * read the database file for the score / parameter estimation and var
313  * names
314  */
315  genericBNLearner(const std::string& filename,
316  const std::vector< std::string >& missing_symbols);
317  genericBNLearner(const DatabaseTable<>& db);
318 
319  /**
320  * read the database file for the score / parameter estimation and var
321  * names
322  * @param filename The file to learn from.
323  * @param modalities indicate for some nodes (not necessarily all the
324  * nodes of the BN) which modalities they should have and in which order
325  * these modalities should be stored into the nodes. For instance, if
326  * modalities = { 1 -> {True, False, Big} }, then the node of id 1 in the
327  * BN will have 3 modalities, the first one being True, the second one
328  * being False, and the third being Big.
329  * @param parse_database if true, the modalities specified by the user
330  * will be considered as a superset of the modalities of the variables. A
331  * parsing of the database will make it possible to determine which ones are
332  * really necessary and will keep them in the order specified by the user
333  * (NodeProperty modalities). If parse_database is set to false (the
334  * default), then the modalities specified by the user will be considered
335  * as being exactly those of the variables of the BN (as a consequence,
336  * if we find other values in the database, an exception will be raised
337  * during learning). */
338  template < typename GUM_SCALAR >
339  genericBNLearner(const std::string& filename,
340  const gum::BayesNet< GUM_SCALAR >& src,
341  const std::vector< std::string >& missing_symbols);
342 
343  /// copy constructor
344  genericBNLearner(const genericBNLearner&);
345 
346  /// move constructor
347  genericBNLearner(genericBNLearner&&);
348 
349  /// destructor
350  virtual ~genericBNLearner();
351 
352  /// @}
353 
354  // ##########################################################################
355  /// @name Operators
356  // ##########################################################################
357  /// @{
358 
359  /// copy operator
360  genericBNLearner& operator=(const genericBNLearner&);
361 
362  /// move operator
363  genericBNLearner& operator=(genericBNLearner&&);
364 
365  /// @}
366 
367  // ##########################################################################
368  /// @name Accessors / Modifiers
369  // ##########################################################################
370  /// @{
371 
372  /// learn a structure from a file (must have read the db before)
373  DAG learnDAG();
374 
375  /// learn a partial structure from a file (must have read the db before and
376  /// must have selected miic or 3off2)
377  MixedGraph learnMixedStructure();
378 
379  /// sets an initial DAG structure
380  void setInitialDAG(const DAG&);
381 
382  /// returns the names of the variables in the database
383  const std::vector< std::string >& names() const;
384 
385  /// returns the domain sizes of the variables in the database
386  const std::vector< std::size_t >& domainSizes() const;
387  Size domainSize(NodeId var) const;
388  Size domainSize(const std::string& var) const;
389 
390  /// returns the node id corresponding to a variable name
391  /**
392  * @throw MissingVariableInDatabase if a variable of the BN is not found
393  * in the database.
394  */
395  NodeId idFromName(const std::string& var_name) const;
396 
397  /// returns the database used by the BNLearner
398  const DatabaseTable<>& database() const;
399 
400  /** @brief assign a weight to all the rows of the learning database so
401  * that the sum of their weights is equal to new_weight */
402  void setDatabaseWeight(const double new_weight);
403 
404  /// sets the weight of the ith record of the database
405  /** @throws OutOfBounds if i is outside the set of indices of the
406  * records or if the weight is negative
407  */
408  void setRecordWeight(const std::size_t i, const double weight);
409 
410  /// returns the weight of the ith record
411  /** @throws OutOfBounds if i is outside the set of indices of the
412  * records */
413  double recordWeight(const std::size_t i) const;
414 
415  /// returns the weight of the whole database
416  double databaseWeight() const;
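      /* A hedged sketch of the row-weighting API; the file name and the
       * numbers are arbitrary illustrations:
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * learner.setDatabaseWeight(1000.0);  // row weights now sum to 1000
       * learner.setRecordWeight(0, 2.0);    // then give the first record weight 2
       * double w     = learner.recordWeight(0);    // 2.0
       * double total = learner.databaseWeight();   // current total weight
       * @endcode
       */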
417 
418  /// returns the variable name corresponding to a given node id
419  const std::string& nameFromId(NodeId id) const;
420 
421  /// use a new set of database rows' ranges to perform learning
422  /** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database row
423  * indices. The subsequent learnings are then performed only on the union
424  * of the rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when
425  * performing cross validation tasks, in which part of the database should
426  * be ignored. An empty set of ranges is equivalent to an interval [X,Y)
427  * ranging over the whole database. */
428  template < template < typename > class XALLOC >
429  void useDatabaseRanges(
430  const std::vector< std::pair< std::size_t, std::size_t >,
431  XALLOC< std::pair< std::size_t, std::size_t > > >&
432  new_ranges);
433 
434  /// reset the ranges to the one range corresponding to the whole database
435  void clearDatabaseRanges();
436 
437  /// returns the current database rows' ranges used for learning
438  /** @return The method returns a vector of pairs [Xi,Yi) of indices of
439  * rows in the database. The learning is performed on this set of rows.
440  * @warning an empty set of ranges means the whole database. */
441  const std::vector< std::pair< std::size_t, std::size_t > >&
442  databaseRanges() const;
443 
444  /// sets the ranges of rows to be used for cross-validation learning
445  /** When applied on (x,k), the method indicates to the subsequent learnings
446  * that they should be performed on the xth fold in a k-fold
447  * cross-validation context. For instance, if a database has 1000 rows,
448  * and if we perform a 10-fold cross-validation, then the first learning
449  * fold (learning_fold=0) corresponds to rows interval [100,1000) and the
450  * test dataset corresponds to [0,100). The second learning fold
451  * (learning_fold=1) is [0,100) U [200,1000) and the corresponding test
452  * dataset is [100,200).
453  * @param learning_fold a number indicating the set of rows used for
454  * learning. If N denotes the size of the database, and k_fold represents
455  * the number of folds in the cross validation, then the set of rows
456  * used for testing is [learning_fold * N / k_fold,
457  * (learning_fold+1) * N / k_fold) and the learning database is the
458  * complement in the database
459  * @param k_fold the value of "k" in k-fold cross validation
460  * @return a pair [x,y) of rows' indices that corresponds to the indices
461  * of rows in the original database that constitute the test dataset
462  * @throws OutOfBounds is raised if k_fold is equal to 0 or learning_fold
463  * is greater than or equal to k_fold, or if k_fold is greater than
464  * or equal to the size of the database. */
465  std::pair< std::size_t, std::size_t >
466  useCrossValidationFold(const std::size_t learning_fold,
467  const std::size_t k_fold);
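      /* A hedged sketch of a k-fold cross-validation loop built on this
       * method (learnBN() comes from the derived BNLearner class; the file
       * name is an illustrative assumption):
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * const std::size_t k = 10;
       * for (std::size_t fold = 0; fold < k; ++fold) {
       *   const auto test_range = learner.useCrossValidationFold(fold, k);
       *   const auto bn         = learner.learnBN();
       *   // evaluate bn on the rows [test_range.first, test_range.second)
       * }
       * learner.clearDatabaseRanges();   // back to the whole database
       * @endcode
       */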
468 
469 
470  /**
471  * Return the <statistic,pvalue> pair for the chi2 test in the database
472  * @param id1 first variable
473  * @param id2 second variable
474  * @param knowing list of observed variables
475  * @return a std::pair<double,double>
476  */
477  std::pair< double, double > chi2(const NodeId id1,
478  const NodeId id2,
479  const std::vector< NodeId >& knowing = {});
480  /**
481  * Return the <statistic,pvalue> pair for the chi2 test in the database
482  * @param name1 name of the first variable
483  * @param name2 name of the second variable
484  * @param knowing list of observed variables
485  * @return a std::pair<double,double>
486  */
487  std::pair< double, double > chi2(const std::string& name1,
488  const std::string& name2,
489  const std::vector< std::string >& knowing
490  = {});
491 
492  /**
493  * Return the <statistic,pvalue> pair for the G2 test in the database
494  * @param id1 first variable
495  * @param id2 second variable
496  * @param knowing list of observed variables
497  * @return a std::pair<double,double>
498  */
499  std::pair< double, double > G2(const NodeId id1,
500  const NodeId id2,
501  const std::vector< NodeId >& knowing = {});
502  /**
503  * Return the <statistic,pvalue> pair for the G2 test in the database
504  * @param name1 name of the first variable
505  * @param name2 name of the second variable
506  * @param knowing list of observed variables
507  * @return a std::pair<double,double>
508  */
509  std::pair< double, double > G2(const std::string& name1,
510  const std::string& name2,
511  const std::vector< std::string >& knowing
512  = {});
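      /* For instance, a hedged independence-test sketch (the file and the
       * variable names A, B, C are illustrative assumptions):
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * std::pair< double, double > res = learner.chi2("A", "B", {"C"});
       * if (res.second > 0.05) {
       *   // the chi2 p-value does not let us reject the independence
       *   // of A and B given C at the 5% level
       * }
       * @endcode
       */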
513 
514  /**
515  * Return the log-likelihood of vars in the database, conditioned on
516  * knowing, for the BNLearner
517  * @param vars a vector of NodeIds
518  * @param knowing an optional vector of conditioning NodeIds
519  * @return a double
520  */
521  double logLikelihood(const std::vector< NodeId >& vars,
522  const std::vector< NodeId >& knowing = {});
523 
524  /**
525  * Return the log-likelihood of vars in the database, conditioned on
526  * knowing, for the BNLearner
527  * @param vars a vector of variable names
528  * @param knowing an optional vector of conditioning variable names
529  * @return a double
530  */
531  double logLikelihood(const std::vector< std::string >& vars,
532  const std::vector< std::string >& knowing = {});
533 
534  /**
535  * Return the pseudo-counts of the NodeId vars in the database in a raw array
536  * @param vars a vector of NodeIds
537  * @return a std::vector<double> containing the contingency table
538  */
539  std::vector< double > rawPseudoCount(const std::vector< NodeId >& vars);
540 
541  /**
542  * Return the pseudo-counts of vars in the database in a raw array
543  * @param vars a vector of variable names
544  * @return a std::vector<double> containing the contingency table
545  */
546  std::vector< double > rawPseudoCount(const std::vector< std::string >& vars);
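      /* For instance, for two binary variables A and B (illustrative
       * names), the call below returns the 2x2 = 4 entries of their
       * contingency table:
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * std::vector< double > counts = learner.rawPseudoCount({"A", "B"});
       * // counts.size() == 4: one (pseudo-)count per joint configuration
       * @endcode
       */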
547  /**
548  *
549  * @return the number of columns in the database
550  */
551  Size nbCols() const;
552 
553  /**
554  *
555  * @return the number of rows in the database
556  */
557  Size nbRows() const;
558 
559  /** use the EM algorithm to learn parameters
560  *
561  * if epsilon=0, EM is not used
562  */
563  void useEM(const double epsilon);
564 
565  /// returns true if the learner's database has missing values
566  bool hasMissingValues() const;
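      /* A hedged sketch of EM-based parameter learning on a database with
       * missing values (the threshold 1e-4 is an arbitrary choice):
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * if (learner.hasMissingValues()) {
       *   learner.useEM(1e-4);             // EM stopping threshold (epsilon)
       *   learner.useAprioriSmoothing();   // avoid zero counts during EM
       * }
       * auto bn = learner.learnBN();
       * @endcode
       */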
567 
568  /// @}
569 
570  // ##########################################################################
571  /// @name Score selection
572  // ##########################################################################
573  /// @{
574 
575  /// indicate that we wish to use an AIC score
576  void useScoreAIC();
577 
578  /// indicate that we wish to use a BD score
579  void useScoreBD();
580 
581  /// indicate that we wish to use a BDeu score
582  void useScoreBDeu();
583 
584  /// indicate that we wish to use a BIC score
585  void useScoreBIC();
586 
587  /// indicate that we wish to use a K2 score
588  void useScoreK2();
589 
590  /// indicate that we wish to use a Log2Likelihood score
591  void useScoreLog2Likelihood();
592 
593  /// @}
594 
595  // ##########################################################################
596  /// @name A priori selection / parameterization
597  // ##########################################################################
598  /// @{
599 
600  /// use no apriori
601  void useNoApriori();
602 
603  /// use the BDeu apriori
604  /** The BDeu apriori adds a weight to all the cells of the counting
605  * tables. In other words, it is equivalent to adding rows with
606  * equally probable values to the database. */
607  void useAprioriBDeu(double weight = 1);
608 
609  /// use the apriori smoothing
610  /** @param weight pass a weight as an argument if you wish to assign a
611  * weight to the smoothing, else the current weight of the
612  * genericBNLearner will be used. */
613  void useAprioriSmoothing(double weight = 1);
614 
615  /// use the Dirichlet apriori
616  void useAprioriDirichlet(const std::string& filename, double weight = 1);
617 
618 
619  /// checks whether the current score and apriori are compatible
620  /** @returns a non-empty string (an explanatory message) if the apriori
621  * is not compatible with the score. */
622  std::string checkScoreAprioriCompatibility();
623  /// @}
624 
625  // ##########################################################################
626  /// @name Learning algorithm selection
627  // ##########################################################################
628  /// @{
629 
630  /// indicate that we wish to use a greedy hill climbing algorithm
631  void useGreedyHillClimbing();
632 
633  /// indicate that we wish to use a local search with tabu list
634  /** @param tabu_size indicate the size of the tabu list
635  * @param nb_decrease indicate the max number of consecutive changes
636  * decreasing the score that we allow to apply */
637  void useLocalSearchWithTabuList(Size tabu_size = 100, Size nb_decrease = 2);
638 
639  /// indicate that we wish to use K2
640  void useK2(const Sequence< NodeId >& order);
641 
642  /// indicate that we wish to use K2
643  void useK2(const std::vector< NodeId >& order);
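      /* K2 needs an order on the variables: only arcs from earlier to
       * later nodes in the order are considered. A hedged sketch with
       * illustrative file and variable names:
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * std::vector< gum::NodeId > order{learner.idFromName("season"),
       *                                  learner.idFromName("rain"),
       *                                  learner.idFromName("traffic")};
       * learner.useK2(order);
       * gum::DAG dag = learner.learnDAG();
       * @endcode
       */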
644 
645  /// indicate that we wish to use 3off2
646  void use3off2();
647 
648  /// indicate that we wish to use MIIC
649  void useMIIC();
650 
651  /// @}
652 
653  // ##########################################################################
654  /// @name 3off2/MIIC parameterization and specific results
655  // ##########################################################################
656  /// @{
657  /// indicate that we wish to use the NML correction for 3off2
658  /// @throws OperationNotAllowed when 3off2 is not the selected algorithm
659  void useNML();
660  /// indicate that we wish to use the MDL correction for 3off2
661  /// @throws OperationNotAllowed when 3off2 is not the selected algorithm
662  void useMDL();
663  /// indicate that we wish to use the NoCorr correction for 3off2
664  /// @throws OperationNotAllowed when 3off2 is not the selected algorithm
665  void useNoCorr();
666 
667  /// get the list of arcs hiding latent variables
668  /// @throws OperationNotAllowed when 3off2 is not the selected algorithm
669  const std::vector< Arc > latentVariables() const;
670 
671  /// @}
672  // ##########################################################################
673  /// @name Accessors / Modifiers for adding constraints on learning
674  // ##########################################################################
675  /// @{
676 
677  /// sets the max indegree
678  void setMaxIndegree(Size max_indegree);
679 
680  /**
681  * sets a partial order on the nodes
682  * @param slice_order a NodeProperty given the rank (priority) of nodes in
683  * the partial order
684  */
685  void setSliceOrder(const NodeProperty< NodeId >& slice_order);
686 
687  /**
688  * sets a partial order on the nodes
689  * @param slices the list of list of variable names
690  */
691  void setSliceOrder(const std::vector< std::vector< std::string > >& slices);
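      /* For instance, a hedged slice-order sketch in the spirit of 2-TBN
       * learning, where arcs may not go from a later slice to an earlier
       * one (the variable names are illustrative):
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * learner.setSliceOrder({{"A0", "B0"},    // slice 0
       *                        {"A1", "B1"}});  // slice 1
       * @endcode
       */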
692 
693  /// assign a set of forbidden arcs
694  void setForbiddenArcs(const ArcSet& set);
695 
696  /// @name assign a new forbidden arc
697  /// @{
698  void addForbiddenArc(const Arc& arc);
699  void addForbiddenArc(const NodeId tail, const NodeId head);
700  void addForbiddenArc(const std::string& tail, const std::string& head);
701  /// @}
702 
703  /// @name remove a forbidden arc
704  /// @{
705  void eraseForbiddenArc(const Arc& arc);
706  void eraseForbiddenArc(const NodeId tail, const NodeId head);
707  void eraseForbiddenArc(const std::string& tail, const std::string& head);
708  ///@}
709 
710  /// assign a set of mandatory arcs
711  void setMandatoryArcs(const ArcSet& set);
712 
713  /// @name assign a new mandatory arc
714  ///@{
715  void addMandatoryArc(const Arc& arc);
716  void addMandatoryArc(const NodeId tail, const NodeId head);
717  void addMandatoryArc(const std::string& tail, const std::string& head);
718  ///@}
719 
720  /// @name remove a mandatory arc
721  ///@{
722  void eraseMandatoryArc(const Arc& arc);
723  void eraseMandatoryArc(const NodeId tail, const NodeId head);
724  void eraseMandatoryArc(const std::string& tail, const std::string& head);
725  /// @}
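      /* A hedged sketch combining mandatory and forbidden arcs (the file
       * and variable names are illustrative assumptions):
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * learner.addMandatoryArc("smoking", "cancer");  // must be in the DAG
       * learner.addForbiddenArc("cancer", "smoking");  // may never be in it
       * gum::DAG dag = learner.learnDAG();
       * @endcode
       */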
726 
727  /// assign a set of possible edges
728  /// @warning Once at least one possible edge is defined, all other edges are
729  /// not possible anymore
730  /// @{
731  void setPossibleEdges(const EdgeSet& set);
732  void setPossibleSkeleton(const UndiGraph& skeleton);
733  /// @}
734 
735  /// @name assign a new possible edge
736  /// @warning Once at least one possible edge is defined, all other edges are
737  /// not possible anymore
738  /// @{
739  void addPossibleEdge(const Edge& edge);
740  void addPossibleEdge(const NodeId tail, const NodeId head);
741  void addPossibleEdge(const std::string& tail, const std::string& head);
742  /// @}
743 
744  /// @name remove a possible edge
745  /// @{
746  void erasePossibleEdge(const Edge& edge);
747  void erasePossibleEdge(const NodeId tail, const NodeId head);
748  void erasePossibleEdge(const std::string& tail, const std::string& head);
749  ///@}
750 
751  ///@}
752 
753  protected:
754  /// the score selected for learning
755  ScoreType score_type__{ScoreType::BDeu};
756 
757  /// the score used
758  Score<>* score__{nullptr};
759 
760  /// the type of the parameter estimator
761  ParamEstimatorType param_estimator_type__{ParamEstimatorType::ML};
762 
763  /// epsilon for EM. If epsilon=0.0, EM is not used
764  double EMepsilon__{0.0};
765 
766  /// the selected correction for 3off2 and miic
767  CorrectedMutualInformation<>* mutual_info__{nullptr};
768 
769  /// the a priori selected for the score and parameters
770  AprioriType apriori_type__{AprioriType::NO_APRIORI};
771 
772  /// the apriori used
773  Apriori<>* apriori__{nullptr};
774 
775  AprioriNoApriori<>* no_apriori__{nullptr};
776 
777  /// the weight of the apriori
778  double apriori_weight__{1.0};
779 
780  /// the constraint for 2TBNs
781  StructuralConstraintSliceOrder constraint_SliceOrder__;
782 
783  /// the constraint for indegrees
784  StructuralConstraintIndegree constraint_Indegree__;
785 
786  /// the constraint for tabu lists
787  StructuralConstraintTabuList constraint_TabuList__;
788 
789  /// the constraint on forbidden arcs
790  StructuralConstraintForbiddenArcs constraint_ForbiddenArcs__;
791 
792  /// the constraint on possible Edges
793  StructuralConstraintPossibleEdges constraint_PossibleEdges__;
794 
795  /// the constraint on mandatory arcs
796  StructuralConstraintMandatoryArcs constraint_MandatoryArcs__;
797 
798  /// the selected learning algorithm
799  AlgoType selected_algo__{AlgoType::GREEDY_HILL_CLIMBING};
800 
801  /// the K2 algorithm
802  K2 K2__;
803 
804  /// the 3off2 algorithm
805  Miic miic_3off2__;
806 
807  /// the penalty used in 3off2
808  typename CorrectedMutualInformation<>::KModeTypes kmode_3off2__{
809  CorrectedMutualInformation<>::KModeTypes::MDL};
810 
811  /// the parametric EM
812  DAG2BNLearner<> Dag2BN__;
813 
814  /// the greedy hill climbing algorithm
815  GreedyHillClimbing greedy_hill_climbing__;
816 
817  /// the local search with tabu list algorithm
818  LocalSearchWithTabuList local_search_with_tabu_list__;
819 
820  /// the database to be used by the scores and parameter estimators
821  Database score_database__;
822 
823  /// the set of rows' ranges within the database in which learning is done
824  std::vector< std::pair< std::size_t, std::size_t > > ranges__;
825 
826  /// the database used by the Dirichlet a priori
827  Database* apriori_database__{nullptr};
828 
829  /// the filename for the Dirichlet a priori, if any
830  std::string apriori_dbname__;
831 
832  /// an initial DAG given to learners
833  DAG initial_dag__;
834 
835  // the current algorithm as an approximationScheme
836  const ApproximationScheme* current_algorithm__{nullptr};
837 
838  /// reads a file and returns a DatabaseTable
839  static DatabaseTable<>
840  readFile__(const std::string& filename,
841  const std::vector< std::string >& missing_symbols);
842 
843  /// checks whether the extension of a CSV filename is correct
844  static void checkFileName__(const std::string& filename);
845 
846  /// create the apriori used for learning
847  void createApriori__();
848 
849  /// create the score used for learning
850  void createScore__();
851 
852  /// create the parameter estimator used for learning
853  ParamEstimator<>* createParamEstimator__(DBRowGeneratorParser<>& parser,
854  bool take_into_account_score
855  = true);
856 
857  /// returns the DAG learnt
858  DAG learnDAG__();
859 
860  /// prepares the initial graph for 3off2 or miic
861  MixedGraph prepare_miic_3off2__();
862 
863  /// returns the type (as a string) of a given apriori
864  const std::string& getAprioriType__() const;
865 
866  /// create the Corrected Mutual Information instance for Miic/3off2
867  void createCorrectedMutualInformation__();
868 
869 
870  public:
871  // ##########################################################################
872  /// @name redistribute signals AND implementation of the interface
873  /// IApproximationSchemeConfiguration
874  // ##########################################################################
875  // in order not to pollute the proper code of genericBNLearner, we
876  // directly implement those very simple methods here.
877  /// @{
878  /// distribute signals
880  INLINE void setCurrentApproximationScheme(
881  const ApproximationScheme* approximationScheme) {
882  current_algorithm__ = approximationScheme;
883  }
884 
885  INLINE void
886  distributeProgress(const ApproximationScheme* approximationScheme,
887  Size pourcent,
888  double error,
889  double time) {
890  setCurrentApproximationScheme(approximationScheme);
891 
892  if (onProgress.hasListener()) GUM_EMIT3(onProgress, pourcent, error, time);
893  };
894 
895  /// distribute signals
896  INLINE void distributeStop(const ApproximationScheme* approximationScheme,
897  std::string message) {
898  setCurrentApproximationScheme(approximationScheme);
899 
900  if (onStop.hasListener()) GUM_EMIT1(onStop, message);
901  };
902  /// @}
903 
904  /// Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)|
905  /// If the criterion was disabled it will be enabled
906  /// @{
907  /// @throw OutOfLowerBound if eps<0
908  void setEpsilon(double eps) {
909  K2__.approximationScheme().setEpsilon(eps);
910  greedy_hill_climbing__.setEpsilon(eps);
911  local_search_with_tabu_list__.setEpsilon(eps);
912  Dag2BN__.setEpsilon(eps);
913  };
914 
915  /// Get the value of epsilon
916  double epsilon() const {
917  if (current_algorithm__ != nullptr)
918  return current_algorithm__->epsilon();
919  else
920  GUM_ERROR(FatalError, "No chosen algorithm for learning");
921  };
922 
923  /// Disable stopping criterion on epsilon
924  void disableEpsilon() {
925  K2__.approximationScheme().disableEpsilon();
926  greedy_hill_climbing__.disableEpsilon();
927  local_search_with_tabu_list__.disableEpsilon();
928  Dag2BN__.disableEpsilon();
929  };
930 
931  /// Enable stopping criterion on epsilon
932  void enableEpsilon() {
933  K2__.approximationScheme().enableEpsilon();
934  greedy_hill_climbing__.enableEpsilon();
935  local_search_with_tabu_list__.enableEpsilon();
936  Dag2BN__.enableEpsilon();
937  };
938 
939  /// @return true if stopping criterion on epsilon is enabled, false
940  /// otherwise
941  bool isEnabledEpsilon() const {
942  if (current_algorithm__ != nullptr)
943  return current_algorithm__->isEnabledEpsilon();
944  else
945  GUM_ERROR(FatalError, "No chosen algorithm for learning");
946  };
947  /// @}
948 
949  /// Given that we approximate f(t), stopping criterion on
950  /// d/dt(|f(t+1)-f(t)|)
951  /// If the criterion was disabled it will be enabled
952  /// @{
953  /// @throw OutOfLowerBound if rate<0
954  void setMinEpsilonRate(double rate) {
955  K2__.approximationScheme().setMinEpsilonRate(rate);
956  greedy_hill_climbing__.setMinEpsilonRate(rate);
957  local_search_with_tabu_list__.setMinEpsilonRate(rate);
958  Dag2BN__.setMinEpsilonRate(rate);
959  };
960 
961  /// Get the value of the minimal epsilon rate
962  double minEpsilonRate() const {
963  if (current_algorithm__ != nullptr)
964  return current_algorithm__->minEpsilonRate();
965  else
966  GUM_ERROR(FatalError, "No chosen algorithm for learning");
967  };
968 
969  /// Disable stopping criterion on epsilon rate
970  void disableMinEpsilonRate() {
971  K2__.approximationScheme().disableMinEpsilonRate();
972  greedy_hill_climbing__.disableMinEpsilonRate();
973  local_search_with_tabu_list__.disableMinEpsilonRate();
974  Dag2BN__.disableMinEpsilonRate();
975  };
976  /// Enable stopping criterion on epsilon rate
977  void enableMinEpsilonRate() {
978  K2__.approximationScheme().enableMinEpsilonRate();
979  greedy_hill_climbing__.enableMinEpsilonRate();
980  local_search_with_tabu_list__.enableMinEpsilonRate();
981  Dag2BN__.enableMinEpsilonRate();
982  };
983  /// @return true if stopping criterion on epsilon rate is enabled, false
984  /// otherwise
985  bool isEnabledMinEpsilonRate() const {
986  if (current_algorithm__ != nullptr)
987  return current_algorithm__->isEnabledMinEpsilonRate();
988  else
989  GUM_ERROR(FatalError, "No chosen algorithm for learning");
990  };
991  /// @}
992 
993  /// stopping criterion on number of iterations
994  /// @{
995  /// If the criterion was disabled it will be enabled
996  /// @param max The maximum number of iterations
997  /// @throw OutOfLowerBound if max<=1
998  void setMaxIter(Size max) {
999  K2__.approximationScheme().setMaxIter(max);
1000  greedy_hill_climbing__.setMaxIter(max);
1001  local_search_with_tabu_list__.setMaxIter(max);
1002  Dag2BN__.setMaxIter(max);
1003  };
1004 
1005  /// @return the criterion on number of iterations
1006  Size maxIter() const {
1007  if (current_algorithm__ != nullptr)
1008  return current_algorithm__->maxIter();
1009  else
1010  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1011  };
1012 
1013  /// Disable stopping criterion on max iterations
1014  void disableMaxIter() {
1015  K2__.approximationScheme().disableMaxIter();
1016  greedy_hill_climbing__.disableMaxIter();
1017  local_search_with_tabu_list__.disableMaxIter();
1018  Dag2BN__.disableMaxIter();
1019  };
1020  /// Enable stopping criterion on max iterations
1021  void enableMaxIter() {
1022  K2__.approximationScheme().enableMaxIter();
1023  greedy_hill_climbing__.enableMaxIter();
1024  local_search_with_tabu_list__.enableMaxIter();
1025  Dag2BN__.enableMaxIter();
1026  };
1027  /// @return true if stopping criterion on max iterations is enabled, false
1028  /// otherwise
1029  bool isEnabledMaxIter() const {
1030  if (current_algorithm__ != nullptr)
1031  return current_algorithm__->isEnabledMaxIter();
1032  else
1033  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1034  };
1035  /// @}
1036 
1037  /// stopping criterion on timeout
1038  /// If the criterion was disabled it will be enabled
1039  /// @{
1040  /// @throw OutOfLowerBound if timeout<=0.0
1041  /** timeout is expressed in seconds (double).
1042  */
1043  void setMaxTime(double timeout) {
1044  K2__.approximationScheme().setMaxTime(timeout);
1045  greedy_hill_climbing__.setMaxTime(timeout);
1046  local_search_with_tabu_list__.setMaxTime(timeout);
1047  Dag2BN__.setMaxTime(timeout);
1048  }
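      /* These criteria are forwarded to all the embedded algorithms at
       * once; a hedged configuration sketch (the file name and the bounds
       * are arbitrary):
       * @code
       * gum::learning::BNLearner< double > learner("data.csv", {"?"});
       * learner.setEpsilon(1e-6);   // stop when |f(t+1)-f(t)| < 1e-6
       * learner.setMaxIter(100);    // or after 100 iterations
       * learner.setMaxTime(10.0);   // or after 10 seconds
       * @endcode
       */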
1049 
1050  /// returns the timeout (in seconds)
1051  double maxTime() const {
1052  if (current_algorithm__ != nullptr)
1053  return current_algorithm__->maxTime();
1054  else
1055  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1056  };
1057 
1058  /// get the current running time in seconds (double)
1059  double currentTime() const {
1060  if (current_algorithm__ != nullptr)
1061  return current_algorithm__->currentTime();
1062  else
1063  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1064  };
1065 
1066  /// Disable stopping criterion on timeout
1067  void disableMaxTime() {
1068  K2__.approximationScheme().disableMaxTime();
1069  greedy_hill_climbing__.disableMaxTime();
1070  local_search_with_tabu_list__.disableMaxTime();
1071  Dag2BN__.disableMaxTime();
1072  };
  /// Enable stopping criterion on timeout
1073  void enableMaxTime() {
1074  K2__.approximationScheme().enableMaxTime();
1075  greedy_hill_climbing__.enableMaxTime();
1076  local_search_with_tabu_list__.enableMaxTime();
1077  Dag2BN__.enableMaxTime();
1078  };
1079  /// @return true if stopping criterion on timeout is enabled, false
1080  /// otherwise
1081  bool isEnabledMaxTime() const {
1082  if (current_algorithm__ != nullptr)
1083  return current_algorithm__->isEnabledMaxTime();
1084  else
1085  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1086  };
1087  /// @}
1088 
1089  /// how many samples between two tests of the stopping criteria
1090  /// @{
1091  /// @throw OutOfLowerBound if p<1
1092  void setPeriodSize(Size p) {
1093  K2__.approximationScheme().setPeriodSize(p);
1094  greedy_hill_climbing__.setPeriodSize(p);
1095  local_search_with_tabu_list__.setPeriodSize(p);
1096  Dag2BN__.setPeriodSize(p);
1097  };
1098 
1099  Size periodSize() const {
1100  if (current_algorithm__ != nullptr)
1101  return current_algorithm__->periodSize();
1102  else
1103  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1104  };
1105  /// @}
1106 
1107  /// verbosity
1108  /// @{
1109  void setVerbosity(bool v) {
1110  K2__.approximationScheme().setVerbosity(v);
1111  greedy_hill_climbing__.setVerbosity(v);
1112  local_search_with_tabu_list__.setVerbosity(v);
1113  Dag2BN__.setVerbosity(v);
1114  };
1115 
1116  bool verbosity() const {
1117  if (current_algorithm__ != nullptr)
1118  return current_algorithm__->verbosity();
1119  else
1120  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1121  };
1122  /// @}
1123 
1124  /// history
1125  /// @{
1126 
1127  ApproximationSchemeSTATE stateApproximationScheme() const {
1128  if (current_algorithm__ != nullptr)
1129  return current_algorithm__->stateApproximationScheme();
1130  else
1131  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1132  };
1133 
1134  /// @throw OperationNotAllowed if scheme not performed
1135  Size nbrIterations() const {
1136  if (current_algorithm__ != nullptr)
1137  return current_algorithm__->nbrIterations();
1138  else
1139  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1140  };
1141 
1142  /// @throw OperationNotAllowed if scheme not performed or verbosity=false
1143  const std::vector< double >& history() const {
1144  if (current_algorithm__ != nullptr)
1145  return current_algorithm__->history();
1146  else
1147  GUM_ERROR(FatalError, "No chosen algorithm for learning");
1148  };
1149  /// @}
1150  };
1151 
1152  } /* namespace learning */
1153 
1154 } /* namespace gum */
1155 
1156 /// include the inlined functions if necessary
1157 #ifndef GUM_NO_INLINE
1158 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h>
1159 #endif /* GUM_NO_INLINE */
1160 
1161 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h>
1162 
1163 #endif /* GUM_LEARNING_GENERIC_BN_LEARNER_H */