aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
genericBNLearner.cpp
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief A pack of learning algorithms that can easily be used
24  *
25  * The pack currently contains K2, GreedyHillClimbing, miic, 3off2 and
26  *LocalSearchWithTabuList
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 
31 #include <algorithm>
32 #include <iterator>
33 
34 #include <agrum/agrum.h>
35 #include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h>
36 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h>
37 #include <agrum/tools/stattests/indepTestChi2.h>
38 #include <agrum/tools/stattests/indepTestG2.h>
39 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>
40 #include <agrum/tools/stattests/pseudoCount.h>
41 
42 // include the inlined functions if necessary
43 #ifdef GUM_NO_INLINE
44 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h>
45 #endif /* GUM_NO_INLINE */
46 
47 namespace gum {
48 
49  namespace learning {
50 
51 
52  genericBNLearner::Database::Database(const DatabaseTable<>& db) : _database_(db) {
53  // get the variables names
54  const auto& var_names = _database_.variableNames();
55  const std::size_t nb_vars = var_names.size();
56  for (auto dom: _database_.domainSizes())
57  _domain_sizes_.push_back(dom);
58  for (std::size_t i = 0; i < nb_vars; ++i) {
59  _nodeId2cols_.insert(NodeId(i), i);
60  }
61 
62  // create the parser
63  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
64  }
65 
66 
67  genericBNLearner::Database::Database(const std::string& filename,
68  const std::vector< std::string >& missing_symbols) :
69  Database(genericBNLearner::readFile_(filename, missing_symbols)) {}
70 
71 
72  genericBNLearner::Database::Database(const std::string& CSV_filename,
73  Database& score_database,
74  const std::vector< std::string >& missing_symbols) {
75  // assign to each column name in the CSV file its column
76  genericBNLearner::checkFileName_(CSV_filename);
77  DBInitializerFromCSV<> initializer(CSV_filename);
78  const auto& apriori_names = initializer.variableNames();
79  std::size_t apriori_nb_vars = apriori_names.size();
80  HashTable< std::string, std::size_t > apriori_names2col(apriori_nb_vars);
81  for (std::size_t i = std::size_t(0); i < apriori_nb_vars; ++i)
82  apriori_names2col.insert(apriori_names[i], i);
83 
84  // check that there are at least as many variables in the a priori
85  // database as those in the score_database
86  if (apriori_nb_vars < score_database._database_.nbVariables()) {
87  GUM_ERROR(InvalidArgument,
88  "the a apriori database has fewer variables "
89  "than the observed database");
90  }
91 
92  // get the mapping from the columns of score_database to those of
93  // the CSV file
94  const std::vector< std::string >& score_names
95  = score_database.databaseTable().variableNames();
96  const std::size_t score_nb_vars = score_names.size();
97  HashTable< std::size_t, std::size_t > mapping(score_nb_vars);
98  for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
99  try {
100  mapping.insert(i, apriori_names2col[score_names[i]]);
101  } catch (Exception&) {
102  GUM_ERROR(MissingVariableInDatabase,
103  "Variable " << score_names[i]
104  << " of the observed database does not belong to the "
105  << "apriori database");
106  }
107  }
108 
109  // create the translators for CSV database
110  for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
111  const Variable& var = score_database.databaseTable().variable(i);
112  _database_.insertTranslator(var, mapping[i], missing_symbols);
113  }
114 
115  // fill the database
116  initializer.fillDatabase(_database_);
117 
118  // get the domain sizes of the variables
119  for (auto dom: _database_.domainSizes())
120  _domain_sizes_.push_back(dom);
121 
122  // compute the mapping from node ids to column indices
123  _nodeId2cols_ = score_database.nodeId2Columns();
124 
125  // create the parser
126  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
127  }
128 
129 
130  genericBNLearner::Database::Database(const Database& from) :
131  _database_(from._database_), _domain_sizes_(from._domain_sizes_),
132  _nodeId2cols_(from._nodeId2cols_) {
133  // create the parser
134  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
135  }
136 
137 
138  genericBNLearner::Database::Database(Database&& from) :
139  _database_(std::move(from._database_)), _domain_sizes_(std::move(from._domain_sizes_)),
140  _nodeId2cols_(std::move(from._nodeId2cols_)) {
141  // create the parser
142  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
143  }
144 
145 
146  genericBNLearner::Database::~Database() { delete _parser_; }
147 
148  genericBNLearner::Database& genericBNLearner::Database::operator=(const Database& from) {
149  if (this != &from) {
150  delete _parser_;
151  _database_ = from._database_;
152  _domain_sizes_ = from._domain_sizes_;
153  _nodeId2cols_ = from._nodeId2cols_;
154 
155  // create the parser
156  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
157  }
158 
159  return *this;
160  }
161 
162  genericBNLearner::Database& genericBNLearner::Database::operator=(Database&& from) {
163  if (this != &from) {
164  delete _parser_;
165  _database_ = std::move(from._database_);
166  _domain_sizes_ = std::move(from._domain_sizes_);
167  _nodeId2cols_ = std::move(from._nodeId2cols_);
168 
169  // create the parser
170  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
171  }
172 
173  return *this;
174  }
175 
176 
177  // ===========================================================================
178 
179  genericBNLearner::genericBNLearner(const std::string& filename,
180  const std::vector< std::string >& missing_symbols) :
181  scoreDatabase_(filename, missing_symbols) {
182  noApriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable());
183 
184  GUM_CONSTRUCTOR(genericBNLearner);
185  }
186 
187 
188  genericBNLearner::genericBNLearner(const DatabaseTable<>& db) : scoreDatabase_(db) {
189  noApriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable());
190 
191  GUM_CONSTRUCTOR(genericBNLearner);
192  }
193 
194 
195  genericBNLearner::genericBNLearner(const genericBNLearner& from) :
196  scoreType_(from.scoreType_), paramEstimatorType_(from.paramEstimatorType_),
197  epsilonEM_(from.epsilonEM_), aprioriType_(from.aprioriType_),
198  aprioriWeight_(from.aprioriWeight_), constraintSliceOrder_(from.constraintSliceOrder_),
199  constraintIndegree_(from.constraintIndegree_),
200  constraintTabuList_(from.constraintTabuList_),
201  constraintForbiddenArcs_(from.constraintForbiddenArcs_),
202  constraintMandatoryArcs_(from.constraintMandatoryArcs_), selectedAlgo_(from.selectedAlgo_),
203  algoK2_(from.algoK2_), algoMiic3off2_(from.algoMiic3off2_), kmode3Off2_(from.kmode3Off2_),
204  greedyHillClimbing_(from.greedyHillClimbing_),
205  localSearchWithTabuList_(from.localSearchWithTabuList_),
206  scoreDatabase_(from.scoreDatabase_), ranges_(from.ranges_),
207  aprioriDbname_(from.aprioriDbname_), initialDag_(from.initialDag_) {
208  noApriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable());
209 
210  GUM_CONS_CPY(genericBNLearner);
211  }
212 
213  genericBNLearner::genericBNLearner(genericBNLearner&& from) :
214  scoreType_(from.scoreType_), paramEstimatorType_(from.paramEstimatorType_),
215  epsilonEM_(from.epsilonEM_), aprioriType_(from.aprioriType_),
216  aprioriWeight_(from.aprioriWeight_),
217  constraintSliceOrder_(std::move(from.constraintSliceOrder_)),
218  constraintIndegree_(std::move(from.constraintIndegree_)),
219  constraintTabuList_(std::move(from.constraintTabuList_)),
220  constraintForbiddenArcs_(std::move(from.constraintForbiddenArcs_)),
221  constraintMandatoryArcs_(std::move(from.constraintMandatoryArcs_)),
222  selectedAlgo_(from.selectedAlgo_), algoK2_(std::move(from.algoK2_)),
223  algoMiic3off2_(std::move(from.algoMiic3off2_)), kmode3Off2_(from.kmode3Off2_),
224  greedyHillClimbing_(std::move(from.greedyHillClimbing_)),
225  localSearchWithTabuList_(std::move(from.localSearchWithTabuList_)),
226  scoreDatabase_(std::move(from.scoreDatabase_)), ranges_(std::move(from.ranges_)),
227  aprioriDbname_(std::move(from.aprioriDbname_)), initialDag_(std::move(from.initialDag_)) {
228  noApriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable());
229 
230  GUM_CONS_MOV(genericBNLearner)
231  }
232 
233  genericBNLearner::~genericBNLearner() {
234  if (score_) delete score_;
235 
236  if (apriori_) delete apriori_;
237 
238  if (noApriori_) delete noApriori_;
239 
240  if (aprioriDatabase_) delete aprioriDatabase_;
241 
242  if (mutualInfo_) delete mutualInfo_;
243 
244  GUM_DESTRUCTOR(genericBNLearner);
245  }
246 
247  genericBNLearner& genericBNLearner::operator=(const genericBNLearner& from) {
248  if (this != &from) {
249  if (score_) {
250  delete score_;
251  score_ = nullptr;
252  }
253 
254  if (apriori_) {
255  delete apriori_;
256  apriori_ = nullptr;
257  }
258 
259  if (aprioriDatabase_) {
260  delete aprioriDatabase_;
261  aprioriDatabase_ = nullptr;
262  }
263 
264  if (mutualInfo_) {
265  delete mutualInfo_;
266  mutualInfo_ = nullptr;
267  }
268 
269  scoreType_ = from.scoreType_;
270  paramEstimatorType_ = from.paramEstimatorType_;
271  epsilonEM_ = from.epsilonEM_;
272  aprioriType_ = from.aprioriType_;
273  aprioriWeight_ = from.aprioriWeight_;
274  constraintSliceOrder_ = from.constraintSliceOrder_;
275  constraintIndegree_ = from.constraintIndegree_;
276  constraintTabuList_ = from.constraintTabuList_;
277  constraintForbiddenArcs_ = from.constraintForbiddenArcs_;
278  constraintMandatoryArcs_ = from.constraintMandatoryArcs_;
279  selectedAlgo_ = from.selectedAlgo_;
280  algoK2_ = from.algoK2_;
281  algoMiic3off2_ = from.algoMiic3off2_;
282  kmode3Off2_ = from.kmode3Off2_;
283  greedyHillClimbing_ = from.greedyHillClimbing_;
284  localSearchWithTabuList_ = from.localSearchWithTabuList_;
285  scoreDatabase_ = from.scoreDatabase_;
286  ranges_ = from.ranges_;
287  aprioriDbname_ = from.aprioriDbname_;
288  initialDag_ = from.initialDag_;
289  currentAlgorithm_ = nullptr;
290  }
291 
292  return *this;
293  }
294 
295  genericBNLearner& genericBNLearner::operator=(genericBNLearner&& from) {
296  if (this != &from) {
297  if (score_) {
298  delete score_;
299  score_ = nullptr;
300  }
301 
302  if (apriori_) {
303  delete apriori_;
304  apriori_ = nullptr;
305  }
306 
307  if (aprioriDatabase_) {
308  delete aprioriDatabase_;
309  aprioriDatabase_ = nullptr;
310  }
311 
312  if (mutualInfo_) {
313  delete mutualInfo_;
314  mutualInfo_ = nullptr;
315  }
316 
317  scoreType_ = from.scoreType_;
318  paramEstimatorType_ = from.paramEstimatorType_;
319  epsilonEM_ = from.epsilonEM_;
320  aprioriType_ = from.aprioriType_;
321  aprioriWeight_ = from.aprioriWeight_;
322  constraintSliceOrder_ = std::move(from.constraintSliceOrder_);
323  constraintIndegree_ = std::move(from.constraintIndegree_);
324  constraintTabuList_ = std::move(from.constraintTabuList_);
325  constraintForbiddenArcs_ = std::move(from.constraintForbiddenArcs_);
326  constraintMandatoryArcs_ = std::move(from.constraintMandatoryArcs_);
327  selectedAlgo_ = from.selectedAlgo_;
328  algoK2_ = from.algoK2_;
329  algoMiic3off2_ = std::move(from.algoMiic3off2_);
330  kmode3Off2_ = from.kmode3Off2_;
331  greedyHillClimbing_ = std::move(from.greedyHillClimbing_);
332  localSearchWithTabuList_ = std::move(from.localSearchWithTabuList_);
333  scoreDatabase_ = std::move(from.scoreDatabase_);
334  ranges_ = std::move(from.ranges_);
335  aprioriDbname_ = std::move(from.aprioriDbname_);
336  initialDag_ = std::move(from.initialDag_);
337  currentAlgorithm_ = nullptr;
338  }
339 
340  return *this;
341  }
342 
343 
344  DatabaseTable<> readFile(const std::string& filename) {
345  // get the extension of the file
346  Size filename_size = Size(filename.size());
347 
348  if (filename_size < 4) {
349  GUM_ERROR(FormatNotFound,
350  "genericBNLearner could not determine the "
351  "file type of the database");
352  }
353 
354  std::string extension = filename.substr(filename.size() - 4);
355  std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
356 
357  if (extension != ".csv") {
358  GUM_ERROR(OperationNotAllowed,
359  "genericBNLearner does not support yet this type "
360  "of database file");
361  }
362 
363  DBInitializerFromCSV<> initializer(filename);
364 
365  const auto& var_names = initializer.variableNames();
366  const std::size_t nb_vars = var_names.size();
367 
368  DBTranslatorSet<> translator_set;
369  DBTranslator4LabelizedVariable<> translator;
370  for (std::size_t i = 0; i < nb_vars; ++i) {
371  translator_set.insertTranslator(translator, i);
372  }
373 
374  DatabaseTable<> database(translator_set);
375  database.setVariableNames(initializer.variableNames());
376  initializer.fillDatabase(database);
377 
378  return database;
379  }
380 
381 
382  void genericBNLearner::checkFileName_(const std::string& filename) {
383  // get the extension of the file
384  Size filename_size = Size(filename.size());
385 
386  if (filename_size < 4) {
387  GUM_ERROR(FormatNotFound,
388  "genericBNLearner could not determine the "
389  "file type of the database");
390  }
391 
392  std::string extension = filename.substr(filename.size() - 4);
393  std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
394 
395  if (extension != ".csv") {
396  GUM_ERROR(OperationNotAllowed,
397  "genericBNLearner does not support yet this type of database file");
398  }
399  }
400 
401 
402  DatabaseTable<> genericBNLearner::readFile_(const std::string& filename,
403  const std::vector< std::string >& missing_symbols) {
404  // get the extension of the file
405  checkFileName_(filename);
406 
407  DBInitializerFromCSV<> initializer(filename);
408 
409  const auto& var_names = initializer.variableNames();
410  const std::size_t nb_vars = var_names.size();
411 
412  DBTranslatorSet<> translator_set;
413  DBTranslator4LabelizedVariable<> translator(missing_symbols);
414  for (std::size_t i = 0; i < nb_vars; ++i) {
415  translator_set.insertTranslator(translator, i);
416  }
417 
418  DatabaseTable<> database(missing_symbols, translator_set);
419  database.setVariableNames(initializer.variableNames());
420  initializer.fillDatabase(database);
421 
422  database.reorder();
423 
424  return database;
425  }
426 
427 
428  void genericBNLearner::createApriori_() {
429  // first, save the old apriori, to be delete if everything is ok
430  Apriori<>* old_apriori = apriori_;
431 
432  // create the new apriori
433  switch (aprioriType_) {
434  case AprioriType::NO_APRIORI:
435  apriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable(),
436  scoreDatabase_.nodeId2Columns());
437  break;
438 
439  case AprioriType::SMOOTHING:
440  apriori_ = new AprioriSmoothing<>(scoreDatabase_.databaseTable(),
441  scoreDatabase_.nodeId2Columns());
442  break;
443 
444  case AprioriType::DIRICHLET_FROM_DATABASE:
445  if (aprioriDatabase_ != nullptr) {
446  delete aprioriDatabase_;
447  aprioriDatabase_ = nullptr;
448  }
449 
450  aprioriDatabase_
451  = new Database(aprioriDbname_, scoreDatabase_, scoreDatabase_.missingSymbols());
452 
453  apriori_ = new AprioriDirichletFromDatabase<>(scoreDatabase_.databaseTable(),
454  aprioriDatabase_->parser(),
455  aprioriDatabase_->nodeId2Columns());
456  break;
457 
458  case AprioriType::BDEU:
459  apriori_
460  = new AprioriBDeu<>(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
461  break;
462 
463  default:
464  GUM_ERROR(OperationNotAllowed, "The BNLearner does not support yet this apriori")
465  }
466 
467  // do not forget to assign a weight to the apriori
468  apriori_->setWeight(aprioriWeight_);
469 
470  // remove the old apriori, if any
471  if (old_apriori != nullptr) delete old_apriori;
472  }
473 
474  void genericBNLearner::createScore_() {
475  // first, save the old score, to be delete if everything is ok
476  Score<>* old_score = score_;
477 
478  // create the new scoring function
479  switch (scoreType_) {
480  case ScoreType::AIC:
481  score_ = new ScoreAIC<>(scoreDatabase_.parser(),
482  *apriori_,
483  ranges_,
484  scoreDatabase_.nodeId2Columns());
485  break;
486 
487  case ScoreType::BD:
488  score_ = new ScoreBD<>(scoreDatabase_.parser(),
489  *apriori_,
490  ranges_,
491  scoreDatabase_.nodeId2Columns());
492  break;
493 
494  case ScoreType::BDeu:
495  score_ = new ScoreBDeu<>(scoreDatabase_.parser(),
496  *apriori_,
497  ranges_,
498  scoreDatabase_.nodeId2Columns());
499  break;
500 
501  case ScoreType::BIC:
502  score_ = new ScoreBIC<>(scoreDatabase_.parser(),
503  *apriori_,
504  ranges_,
505  scoreDatabase_.nodeId2Columns());
506  break;
507 
508  case ScoreType::K2:
509  score_ = new ScoreK2<>(scoreDatabase_.parser(),
510  *apriori_,
511  ranges_,
512  scoreDatabase_.nodeId2Columns());
513  break;
514 
515  case ScoreType::LOG2LIKELIHOOD:
516  score_ = new ScoreLog2Likelihood<>(scoreDatabase_.parser(),
517  *apriori_,
518  ranges_,
519  scoreDatabase_.nodeId2Columns());
520  break;
521 
522  default:
523  GUM_ERROR(OperationNotAllowed, "genericBNLearner does not support yet this score")
524  }
525 
526  // remove the old score, if any
527  if (old_score != nullptr) delete old_score;
528  }
529 
530  ParamEstimator<>* genericBNLearner::createParamEstimator_(DBRowGeneratorParser<>& parser,
531  bool take_into_account_score) {
532  ParamEstimator<>* param_estimator = nullptr;
533 
534  // create the new estimator
535  switch (paramEstimatorType_) {
536  case ParamEstimatorType::ML:
537  if (take_into_account_score && (score_ != nullptr)) {
538  param_estimator = new ParamEstimatorML<>(parser,
539  *apriori_,
540  score_->internalApriori(),
541  ranges_,
542  scoreDatabase_.nodeId2Columns());
543  } else {
544  param_estimator = new ParamEstimatorML<>(parser,
545  *apriori_,
546  *noApriori_,
547  ranges_,
548  scoreDatabase_.nodeId2Columns());
549  }
550 
551  break;
552 
553  default:
554  GUM_ERROR(OperationNotAllowed,
555  "genericBNLearner does not support "
556  << "yet this parameter estimator");
557  }
558 
559  // assign the set of ranges
560  param_estimator->setRanges(ranges_);
561 
562  return param_estimator;
563  }
564 
565  /// prepares the initial graph for 3off2 or miic
566  MixedGraph genericBNLearner::prepareMiic3Off2_() {
567  // Initialize the mixed graph to the fully connected graph
568  MixedGraph mgraph;
569  for (Size i = 0; i < scoreDatabase_.databaseTable().nbVariables(); ++i) {
570  mgraph.addNodeWithId(i);
571  for (Size j = 0; j < i; ++j) {
572  mgraph.addEdge(j, i);
573  }
574  }
575 
576  // translating the constraints for 3off2 or miic
577  HashTable< std::pair< NodeId, NodeId >, char > initial_marks;
578  const ArcSet& mandatory_arcs = constraintMandatoryArcs_.arcs();
579  for (const auto& arc: mandatory_arcs) {
580  initial_marks.insert({arc.tail(), arc.head()}, '>');
581  }
582 
583  const ArcSet& forbidden_arcs = constraintForbiddenArcs_.arcs();
584  for (const auto& arc: forbidden_arcs) {
585  initial_marks.insert({arc.tail(), arc.head()}, '-');
586  }
587  algoMiic3off2_.addConstraints(initial_marks);
588 
589  // create the mutual entropy object
590  // if ( _mutual_info_ == nullptr) { this->useNMLCorrection(); }
591  createCorrectedMutualInformation_();
592 
593  return mgraph;
594  }
595 
596  MixedGraph genericBNLearner::learnMixedStructure() {
597  if (selectedAlgo_ != AlgoType::MIIC_THREE_OFF_TWO) {
598  GUM_ERROR(OperationNotAllowed, "Must be using the miic/3off2 algorithm")
599  }
600  // check that the database does not contain any missing value
601  if (scoreDatabase_.databaseTable().hasMissingValues()) {
602  GUM_ERROR(MissingValueInDatabase,
603  "For the moment, the BNLearner is unable to learn "
604  << "structures with missing values in databases");
605  }
606  BNLearnerListener listener(this, algoMiic3off2_);
607 
608  // create the mixedGraph_constraint_MandatoryArcs.arcs();
609  MixedGraph mgraph = this->prepareMiic3Off2_();
610 
611  return algoMiic3off2_.learnMixedStructure(*mutualInfo_, mgraph);
612  }
613 
614  DAG genericBNLearner::learnDAG() {
615  // create the score and the apriori
616  createApriori_();
617  createScore_();
618 
619  return learnDag_();
620  }
621 
622  void genericBNLearner::createCorrectedMutualInformation_() {
623  if (mutualInfo_ != nullptr) delete mutualInfo_;
624 
625  mutualInfo_ = new CorrectedMutualInformation<>(scoreDatabase_.parser(),
626  *noApriori_,
627  ranges_,
628  scoreDatabase_.nodeId2Columns());
629  switch (kmode3Off2_) {
630  case CorrectedMutualInformation<>::KModeTypes::MDL:
631  mutualInfo_->useMDL();
632  break;
633 
634  case CorrectedMutualInformation<>::KModeTypes::NML:
635  mutualInfo_->useNML();
636  break;
637 
638  case CorrectedMutualInformation<>::KModeTypes::NoCorr:
639  mutualInfo_->useNoCorr();
640  break;
641 
642  default:
643  GUM_ERROR(NotImplementedYet,
644  "The BNLearner's corrected mutual information class does "
645  << "not implement yet this correction : " << int(kmode3Off2_));
646  }
647  }
648 
649  DAG genericBNLearner::learnDag_() {
650  // check that the database does not contain any missing value
651  if (scoreDatabase_.databaseTable().hasMissingValues()
652  || ((aprioriDatabase_ != nullptr)
653  && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
654  && aprioriDatabase_->databaseTable().hasMissingValues())) {
655  GUM_ERROR(MissingValueInDatabase,
656  "For the moment, the BNLearner is unable to cope "
657  "with missing values in databases");
658  }
659  // add the mandatory arcs to the initial dag and remove the forbidden ones
660  // from the initial graph
661  DAG init_graph = initialDag_;
662 
663  const ArcSet& mandatory_arcs = constraintMandatoryArcs_.arcs();
664 
665  for (const auto& arc: mandatory_arcs) {
666  if (!init_graph.exists(arc.tail())) init_graph.addNodeWithId(arc.tail());
667 
668  if (!init_graph.exists(arc.head())) init_graph.addNodeWithId(arc.head());
669 
670  init_graph.addArc(arc.tail(), arc.head());
671  }
672 
673  const ArcSet& forbidden_arcs = constraintForbiddenArcs_.arcs();
674 
675  for (const auto& arc: forbidden_arcs) {
676  init_graph.eraseArc(arc);
677  }
678 
679  switch (selectedAlgo_) {
680  // ========================================================================
681  case AlgoType::MIIC_THREE_OFF_TWO: {
682  BNLearnerListener listener(this, algoMiic3off2_);
683  // create the mixedGraph and the corrected mutual information
684  MixedGraph mgraph = this->prepareMiic3Off2_();
685 
686  return algoMiic3off2_.learnStructure(*mutualInfo_, mgraph);
687  }
688 
689  // ========================================================================
690  case AlgoType::GREEDY_HILL_CLIMBING: {
691  BNLearnerListener listener(this, greedyHillClimbing_);
692  StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
693  StructuralConstraintForbiddenArcs,
694  StructuralConstraintPossibleEdges,
695  StructuralConstraintSliceOrder >
696  gen_constraint;
697  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
698  = constraintMandatoryArcs_;
699  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
700  = constraintForbiddenArcs_;
701  static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
702  = constraintPossibleEdges_;
703  static_cast< StructuralConstraintSliceOrder& >(gen_constraint) = constraintSliceOrder_;
704 
705  GraphChangesGenerator4DiGraph< decltype(gen_constraint) > op_set(gen_constraint);
706 
707  StructuralConstraintSetStatic< StructuralConstraintIndegree, StructuralConstraintDAG >
708  sel_constraint;
709  static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
710 
711  GraphChangesSelector4DiGraph< decltype(sel_constraint), decltype(op_set) > selector(
712  *score_,
713  sel_constraint,
714  op_set);
715 
716  return greedyHillClimbing_.learnStructure(selector, init_graph);
717  }
718 
719  // ========================================================================
720  case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST: {
721  BNLearnerListener listener(this, localSearchWithTabuList_);
722  StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
723  StructuralConstraintForbiddenArcs,
724  StructuralConstraintPossibleEdges,
725  StructuralConstraintSliceOrder >
726  gen_constraint;
727  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
728  = constraintMandatoryArcs_;
729  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
730  = constraintForbiddenArcs_;
731  static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
732  = constraintPossibleEdges_;
733  static_cast< StructuralConstraintSliceOrder& >(gen_constraint) = constraintSliceOrder_;
734 
735  GraphChangesGenerator4DiGraph< decltype(gen_constraint) > op_set(gen_constraint);
736 
737  StructuralConstraintSetStatic< StructuralConstraintTabuList,
738  StructuralConstraintIndegree,
739  StructuralConstraintDAG >
740  sel_constraint;
741  static_cast< StructuralConstraintTabuList& >(sel_constraint) = constraintTabuList_;
742  static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
743 
744  GraphChangesSelector4DiGraph< decltype(sel_constraint), decltype(op_set) > selector(
745  *score_,
746  sel_constraint,
747  op_set);
748 
749  return localSearchWithTabuList_.learnStructure(selector, init_graph);
750  }
751 
752  // ========================================================================
753  case AlgoType::K2: {
754  BNLearnerListener listener(this, algoK2_.approximationScheme());
755  StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
756  StructuralConstraintForbiddenArcs,
757  StructuralConstraintPossibleEdges >
758  gen_constraint;
759  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
760  = constraintMandatoryArcs_;
761  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
762  = constraintForbiddenArcs_;
763  static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
764  = constraintPossibleEdges_;
765 
766  GraphChangesGenerator4K2< decltype(gen_constraint) > op_set(gen_constraint);
767 
768  // if some mandatory arcs are incompatible with the order, use a DAG
769  // constraint instead of a DiGraph constraint to avoid cycles
770  const ArcSet& mandatory_arcs
771  = static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint).arcs();
772  const Sequence< NodeId >& order = algoK2_.order();
773  bool order_compatible = true;
774 
775  for (const auto& arc: mandatory_arcs) {
776  if (order.pos(arc.tail()) >= order.pos(arc.head())) {
777  order_compatible = false;
778  break;
779  }
780  }
781 
782  if (order_compatible) {
783  StructuralConstraintSetStatic< StructuralConstraintIndegree,
784  StructuralConstraintDiGraph >
785  sel_constraint;
786  static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
787 
788  GraphChangesSelector4DiGraph< decltype(sel_constraint), decltype(op_set) > selector(
789  *score_,
790  sel_constraint,
791  op_set);
792 
793  return algoK2_.learnStructure(selector, init_graph);
794  } else {
795  StructuralConstraintSetStatic< StructuralConstraintIndegree, StructuralConstraintDAG >
796  sel_constraint;
797  static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;
798 
799  GraphChangesSelector4DiGraph< decltype(sel_constraint), decltype(op_set) > selector(
800  *score_,
801  sel_constraint,
802  op_set);
803 
804  return algoK2_.learnStructure(selector, init_graph);
805  }
806  }
807 
808  // ========================================================================
809  default:
810  GUM_ERROR(OperationNotAllowed,
811  "the learnDAG method has not been implemented for this "
812  "learning algorithm");
813  }
814  }
815 
816  std::string genericBNLearner::checkScoreAprioriCompatibility() {
817  const std::string& apriori = getAprioriType_();
818 
819  switch (scoreType_) {
820  case ScoreType::AIC:
821  return ScoreAIC<>::isAprioriCompatible(apriori, aprioriWeight_);
822 
823  case ScoreType::BD:
824  return ScoreBD<>::isAprioriCompatible(apriori, aprioriWeight_);
825 
826  case ScoreType::BDeu:
827  return ScoreBDeu<>::isAprioriCompatible(apriori, aprioriWeight_);
828 
829  case ScoreType::BIC:
830  return ScoreBIC<>::isAprioriCompatible(apriori, aprioriWeight_);
831 
832  case ScoreType::K2:
833  return ScoreK2<>::isAprioriCompatible(apriori, aprioriWeight_);
834 
835  case ScoreType::LOG2LIKELIHOOD:
836  return ScoreLog2Likelihood<>::isAprioriCompatible(apriori, aprioriWeight_);
837 
838  default:
839  return "genericBNLearner does not support yet this score";
840  }
841  }
842 
843 
844  /// sets the ranges of rows to be used for cross-validation learning
845  std::pair< std::size_t, std::size_t >
846  genericBNLearner::useCrossValidationFold(const std::size_t learning_fold,
847  const std::size_t k_fold) {
848  if (k_fold == 0) { GUM_ERROR(OutOfBounds, "K-fold cross validation with k=0 is forbidden") }
849 
850  if (learning_fold >= k_fold) {
851  GUM_ERROR(OutOfBounds,
852  "In " << k_fold << "-fold cross validation, the learning "
853  << "fold should be strictly lower than " << k_fold
854  << " but, here, it is equal to " << learning_fold);
855  }
856 
857  const std::size_t db_size = scoreDatabase_.databaseTable().nbRows();
858  if (k_fold >= db_size) {
859  GUM_ERROR(OutOfBounds,
860  "In " << k_fold << "-fold cross validation, the database's "
861  << "size should be strictly greater than " << k_fold
862  << " but, here, the database has only " << db_size << "rows");
863  }
864 
865  // create the ranges of rows of the test database
866  const std::size_t foldSize = db_size / k_fold;
867  const std::size_t unfold_deb = learning_fold * foldSize;
868  const std::size_t unfold_end = unfold_deb + foldSize;
869 
870  ranges_.clear();
871  if (learning_fold == std::size_t(0)) {
872  ranges_.push_back(std::pair< std::size_t, std::size_t >(unfold_end, db_size));
873  } else {
874  ranges_.push_back(std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb));
875 
876  if (learning_fold != k_fold - 1) {
877  ranges_.push_back(std::pair< std::size_t, std::size_t >(unfold_end, db_size));
878  }
879  }
880 
881  return std::pair< std::size_t, std::size_t >(unfold_deb, unfold_end);
882  }
883 
884 
885  std::pair< double, double > genericBNLearner::chi2(const NodeId id1,
886  const NodeId id2,
887  const std::vector< NodeId >& knowing) {
888  createApriori_();
889  gum::learning::IndepTestChi2<> chi2score(scoreDatabase_.parser(),
890  *apriori_,
891  databaseRanges());
892 
893  return chi2score.statistics(id1, id2, knowing);
894  }
895 
896  std::pair< double, double > genericBNLearner::chi2(const std::string& name1,
897  const std::string& name2,
898  const std::vector< std::string >& knowing) {
899  std::vector< NodeId > knowingIds;
900  std::transform(knowing.begin(),
901  knowing.end(),
902  std::back_inserter(knowingIds),
903  [this](const std::string& c) -> NodeId { return this->idFromName(c); });
904  return chi2(idFromName(name1), idFromName(name2), knowingIds);
905  }
906 
907  std::pair< double, double > genericBNLearner::G2(const NodeId id1,
908  const NodeId id2,
909  const std::vector< NodeId >& knowing) {
910  createApriori_();
911  gum::learning::IndepTestG2<> g2score(scoreDatabase_.parser(), *apriori_, databaseRanges());
912  return g2score.statistics(id1, id2, knowing);
913  }
914 
915  std::pair< double, double > genericBNLearner::G2(const std::string& name1,
916  const std::string& name2,
917  const std::vector< std::string >& knowing) {
918  std::vector< NodeId > knowingIds;
919  std::transform(knowing.begin(),
920  knowing.end(),
921  std::back_inserter(knowingIds),
922  [this](const std::string& c) -> NodeId { return this->idFromName(c); });
923  return G2(idFromName(name1), idFromName(name2), knowingIds);
924  }
925 
926  double genericBNLearner::logLikelihood(const std::vector< NodeId >& vars,
927  const std::vector< NodeId >& knowing) {
928  createApriori_();
929  gum::learning::ScoreLog2Likelihood<> ll2score(scoreDatabase_.parser(),
930  *apriori_,
931  databaseRanges());
932 
933  std::vector< NodeId > total(vars);
934  total.insert(total.end(), knowing.begin(), knowing.end());
935  double LLtotal = ll2score.score(IdCondSet<>(total, false, true));
936  if (knowing.size() == (Size)0) {
937  return LLtotal;
938  } else {
939  double LLknw = ll2score.score(IdCondSet<>(knowing, false, true));
940  return LLtotal - LLknw;
941  }
942  }
943 
944  double genericBNLearner::logLikelihood(const std::vector< std::string >& vars,
945  const std::vector< std::string >& knowing) {
946  std::vector< NodeId > ids;
947  std::vector< NodeId > knowingIds;
948 
949  auto mapper = [this](const std::string& c) -> NodeId {
950  return this->idFromName(c);
951  };
952 
953  std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
954  std::transform(knowing.begin(), knowing.end(), std::back_inserter(knowingIds), mapper);
955 
956  return logLikelihood(ids, knowingIds);
957  }
958 
959  std::vector< double > genericBNLearner::rawPseudoCount(const std::vector< NodeId >& vars) {
960  Potential< double > res;
961 
962  createApriori_();
963  gum::learning::PseudoCount<> count(scoreDatabase_.parser(), *apriori_, databaseRanges());
964  return count.get(vars);
965  }
966 
967 
968  std::vector< double > genericBNLearner::rawPseudoCount(const std::vector< std::string >& vars) {
969  std::vector< NodeId > ids;
970 
971  auto mapper = [this](const std::string& c) -> NodeId {
972  return this->idFromName(c);
973  };
974 
975  std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
976 
977  return rawPseudoCount(ids);
978  }
979 
980  } /* namespace learning */
981 
982 } /* namespace gum */