aGrUM  0.21.0
a C++ library for (probabilistic) graphical models
genericBNLearner.cpp
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief A pack of learning algorithms that can easily be used
24  *
 * The pack currently contains K2, GreedyHillClimbing, miic, 3off2 and
 * LocalSearchWithTabuList
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 
#include <algorithm>
#include <iterator>
#include <memory>

#include <agrum/agrum.h>
#include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h>
#include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h>
#include <agrum/tools/stattests/indepTestChi2.h>
#include <agrum/tools/stattests/indepTestG2.h>
#include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>
#include <agrum/tools/stattests/pseudoCount.h>
41 
42 // include the inlined functions if necessary
43 #ifdef GUM_NO_INLINE
44 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h>
45 #endif /* GUM_NO_INLINE */
46 
47 namespace gum {
48 
49  namespace learning {
50 
51 
52  genericBNLearner::Database::Database(const DatabaseTable<>& db) : _database_(db) {
53  // get the variables names
54  const auto& var_names = _database_.variableNames();
55  const std::size_t nb_vars = var_names.size();
56  for (auto dom: _database_.domainSizes())
57  _domain_sizes_.push_back(dom);
58  for (std::size_t i = 0; i < nb_vars; ++i) {
59  _nodeId2cols_.insert(NodeId(i), i);
60  }
61 
62  // create the parser
63  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
64  }
65 
66 
67  genericBNLearner::Database::Database(const std::string& filename,
68  const std::vector< std::string >& missing_symbols) :
69  Database(genericBNLearner::readFile_(filename, missing_symbols)) {}
70 
71 
72  genericBNLearner::Database::Database(const std::string& CSV_filename,
73  Database& score_database,
74  const std::vector< std::string >& missing_symbols) {
75  // assign to each column name in the CSV file its column
76  genericBNLearner::isCSVFileName_(CSV_filename);
77  DBInitializerFromCSV<> initializer(CSV_filename);
78  const auto& apriori_names = initializer.variableNames();
79  std::size_t apriori_nb_vars = apriori_names.size();
80  HashTable< std::string, std::size_t > apriori_names2col(apriori_nb_vars);
81  for (std::size_t i = std::size_t(0); i < apriori_nb_vars; ++i)
82  apriori_names2col.insert(apriori_names[i], i);
83 
84  // check that there are at least as many variables in the a priori
85  // database as those in the score_database
86  if (apriori_nb_vars < score_database._database_.nbVariables()) {
87  GUM_ERROR(InvalidArgument,
88  "the a apriori database has fewer variables "
89  "than the observed database");
90  }
91 
92  // get the mapping from the columns of score_database to those of
93  // the CSV file
94  const std::vector< std::string >& score_names
95  = score_database.databaseTable().variableNames();
96  const std::size_t score_nb_vars = score_names.size();
97  HashTable< std::size_t, std::size_t > mapping(score_nb_vars);
98  for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
99  try {
100  mapping.insert(i, apriori_names2col[score_names[i]]);
101  } catch (Exception&) {
102  GUM_ERROR(MissingVariableInDatabase,
103  "Variable " << score_names[i]
104  << " of the observed database does not belong to the "
105  << "apriori database");
106  }
107  }
108 
109  // create the translators for CSV database
110  for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
111  const Variable& var = score_database.databaseTable().variable(i);
112  _database_.insertTranslator(var, mapping[i], missing_symbols);
113  }
114 
115  // fill the database
116  initializer.fillDatabase(_database_);
117 
118  // get the domain sizes of the variables
119  for (auto dom: _database_.domainSizes())
120  _domain_sizes_.push_back(dom);
121 
122  // compute the mapping from node ids to column indices
123  _nodeId2cols_ = score_database.nodeId2Columns();
124 
125  // create the parser
126  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
127  }
128 
129 
130  genericBNLearner::Database::Database(const Database& from) :
131  _database_(from._database_), _domain_sizes_(from._domain_sizes_),
132  _nodeId2cols_(from._nodeId2cols_) {
133  // create the parser
134  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
135  }
136 
137 
138  genericBNLearner::Database::Database(Database&& from) :
139  _database_(std::move(from._database_)), _domain_sizes_(std::move(from._domain_sizes_)),
140  _nodeId2cols_(std::move(from._nodeId2cols_)) {
141  // create the parser
142  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
143  }
144 
145 
146  genericBNLearner::Database::~Database() { delete _parser_; }
147 
148  genericBNLearner::Database& genericBNLearner::Database::operator=(const Database& from) {
149  if (this != &from) {
150  delete _parser_;
151  _database_ = from._database_;
152  _domain_sizes_ = from._domain_sizes_;
153  _nodeId2cols_ = from._nodeId2cols_;
154 
155  // create the parser
156  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
157  }
158 
159  return *this;
160  }
161 
162  genericBNLearner::Database& genericBNLearner::Database::operator=(Database&& from) {
163  if (this != &from) {
164  delete _parser_;
165  _database_ = std::move(from._database_);
166  _domain_sizes_ = std::move(from._domain_sizes_);
167  _nodeId2cols_ = std::move(from._nodeId2cols_);
168 
169  // create the parser
170  _parser_ = new DBRowGeneratorParser<>(_database_.handler(), DBRowGeneratorSet<>());
171  }
172 
173  return *this;
174  }
175 
176 
177  // ===========================================================================
178 
179  genericBNLearner::genericBNLearner(const std::string& filename,
180  const std::vector< std::string >& missing_symbols) :
181  scoreDatabase_(filename, missing_symbols) {
182  filename_ = filename;
183  noApriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable());
184 
185  GUM_CONSTRUCTOR(genericBNLearner);
186  }
187 
188 
189  genericBNLearner::genericBNLearner(const DatabaseTable<>& db) : scoreDatabase_(db) {
190  filename_ = "-";
191  noApriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable());
192 
193  GUM_CONSTRUCTOR(genericBNLearner);
194  }
195 
196 
197  genericBNLearner::genericBNLearner(const genericBNLearner& from) :
198  scoreType_(from.scoreType_), paramEstimatorType_(from.paramEstimatorType_),
199  epsilonEM_(from.epsilonEM_), aprioriType_(from.aprioriType_),
200  aprioriWeight_(from.aprioriWeight_), constraintSliceOrder_(from.constraintSliceOrder_),
201  constraintIndegree_(from.constraintIndegree_),
202  constraintTabuList_(from.constraintTabuList_),
203  constraintForbiddenArcs_(from.constraintForbiddenArcs_),
204  constraintMandatoryArcs_(from.constraintMandatoryArcs_), selectedAlgo_(from.selectedAlgo_),
205  algoK2_(from.algoK2_), algoMiic3off2_(from.algoMiic3off2_), kmode3Off2_(from.kmode3Off2_),
206  greedyHillClimbing_(from.greedyHillClimbing_),
207  localSearchWithTabuList_(from.localSearchWithTabuList_),
208  scoreDatabase_(from.scoreDatabase_), ranges_(from.ranges_),
209  aprioriDbname_(from.aprioriDbname_), initialDag_(from.initialDag_),
210  filename_(from.filename_), nbDecreasingChanges_(from.nbDecreasingChanges_) {
211  noApriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable());
212 
213  GUM_CONS_CPY(genericBNLearner);
214  }
215 
216  genericBNLearner::genericBNLearner(genericBNLearner&& from) :
217  scoreType_(from.scoreType_), paramEstimatorType_(from.paramEstimatorType_),
218  epsilonEM_(from.epsilonEM_), aprioriType_(from.aprioriType_),
219  aprioriWeight_(from.aprioriWeight_),
220  constraintSliceOrder_(std::move(from.constraintSliceOrder_)),
221  constraintIndegree_(std::move(from.constraintIndegree_)),
222  constraintTabuList_(std::move(from.constraintTabuList_)),
223  constraintForbiddenArcs_(std::move(from.constraintForbiddenArcs_)),
224  constraintMandatoryArcs_(std::move(from.constraintMandatoryArcs_)),
225  selectedAlgo_(from.selectedAlgo_), algoK2_(std::move(from.algoK2_)),
226  algoMiic3off2_(std::move(from.algoMiic3off2_)), kmode3Off2_(from.kmode3Off2_),
227  greedyHillClimbing_(std::move(from.greedyHillClimbing_)),
228  localSearchWithTabuList_(std::move(from.localSearchWithTabuList_)),
229  scoreDatabase_(std::move(from.scoreDatabase_)), ranges_(std::move(from.ranges_)),
230  aprioriDbname_(std::move(from.aprioriDbname_)), initialDag_(std::move(from.initialDag_)),
231  filename_(std::move(from.filename_)),
232  nbDecreasingChanges_(std::move(from.nbDecreasingChanges_)) {
233  noApriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable());
234 
235  GUM_CONS_MOV(genericBNLearner)
236  }
237 
238  genericBNLearner::~genericBNLearner() {
239  if (score_) delete score_;
240 
241  if (apriori_) delete apriori_;
242 
243  if (noApriori_) delete noApriori_;
244 
245  if (aprioriDatabase_) delete aprioriDatabase_;
246 
247  if (mutualInfo_) delete mutualInfo_;
248 
249  GUM_DESTRUCTOR(genericBNLearner);
250  }
251 
252  genericBNLearner& genericBNLearner::operator=(const genericBNLearner& from) {
253  if (this != &from) {
254  if (score_) {
255  delete score_;
256  score_ = nullptr;
257  }
258 
259  if (apriori_) {
260  delete apriori_;
261  apriori_ = nullptr;
262  }
263 
264  if (aprioriDatabase_) {
265  delete aprioriDatabase_;
266  aprioriDatabase_ = nullptr;
267  }
268 
269  if (mutualInfo_) {
270  delete mutualInfo_;
271  mutualInfo_ = nullptr;
272  }
273 
274  scoreType_ = from.scoreType_;
275  paramEstimatorType_ = from.paramEstimatorType_;
276  epsilonEM_ = from.epsilonEM_;
277  aprioriType_ = from.aprioriType_;
278  aprioriWeight_ = from.aprioriWeight_;
279  constraintSliceOrder_ = from.constraintSliceOrder_;
280  constraintIndegree_ = from.constraintIndegree_;
281  constraintTabuList_ = from.constraintTabuList_;
282  constraintForbiddenArcs_ = from.constraintForbiddenArcs_;
283  constraintMandatoryArcs_ = from.constraintMandatoryArcs_;
284  selectedAlgo_ = from.selectedAlgo_;
285  algoK2_ = from.algoK2_;
286  algoMiic3off2_ = from.algoMiic3off2_;
287  kmode3Off2_ = from.kmode3Off2_;
288  greedyHillClimbing_ = from.greedyHillClimbing_;
289  localSearchWithTabuList_ = from.localSearchWithTabuList_;
290  scoreDatabase_ = from.scoreDatabase_;
291  ranges_ = from.ranges_;
292  aprioriDbname_ = from.aprioriDbname_;
293  initialDag_ = from.initialDag_;
294  filename_ = from.filename_;
295  nbDecreasingChanges_ = from.nbDecreasingChanges_;
296  currentAlgorithm_ = nullptr;
297  }
298 
299  return *this;
300  }
301 
302  genericBNLearner& genericBNLearner::operator=(genericBNLearner&& from) {
303  if (this != &from) {
304  if (score_) {
305  delete score_;
306  score_ = nullptr;
307  }
308 
309  if (apriori_) {
310  delete apriori_;
311  apriori_ = nullptr;
312  }
313 
314  if (aprioriDatabase_) {
315  delete aprioriDatabase_;
316  aprioriDatabase_ = nullptr;
317  }
318 
319  if (mutualInfo_) {
320  delete mutualInfo_;
321  mutualInfo_ = nullptr;
322  }
323 
324  scoreType_ = from.scoreType_;
325  paramEstimatorType_ = from.paramEstimatorType_;
326  epsilonEM_ = from.epsilonEM_;
327  aprioriType_ = from.aprioriType_;
328  aprioriWeight_ = from.aprioriWeight_;
329  constraintSliceOrder_ = std::move(from.constraintSliceOrder_);
330  constraintIndegree_ = std::move(from.constraintIndegree_);
331  constraintTabuList_ = std::move(from.constraintTabuList_);
332  constraintForbiddenArcs_ = std::move(from.constraintForbiddenArcs_);
333  constraintMandatoryArcs_ = std::move(from.constraintMandatoryArcs_);
334  selectedAlgo_ = from.selectedAlgo_;
335  algoK2_ = from.algoK2_;
336  algoMiic3off2_ = std::move(from.algoMiic3off2_);
337  kmode3Off2_ = from.kmode3Off2_;
338  greedyHillClimbing_ = std::move(from.greedyHillClimbing_);
339  localSearchWithTabuList_ = std::move(from.localSearchWithTabuList_);
340  scoreDatabase_ = std::move(from.scoreDatabase_);
341  ranges_ = std::move(from.ranges_);
342  aprioriDbname_ = std::move(from.aprioriDbname_);
343  filename_ = std::move(from.filename_);
344  initialDag_ = std::move(from.initialDag_);
345  nbDecreasingChanges_ = std::move(from.nbDecreasingChanges_);
346  currentAlgorithm_ = nullptr;
347  }
348 
349  return *this;
350  }
351 
352 
353  DatabaseTable<> readFile(const std::string& filename) {
354  // get the extension of the file
355  Size filename_size = Size(filename.size());
356 
357  if (filename_size < 4) {
358  GUM_ERROR(FormatNotFound,
359  "genericBNLearner could not determine the "
360  "file type of the database '"
361  << filename << "'");
362  }
363 
364  std::string extension = filename.substr(filename.size() - 4);
365  std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
366 
367  if (extension != ".csv") {
368  GUM_ERROR(OperationNotAllowed,
369  "genericBNLearner does not support yet this type ('" << extension
370  << "')"
371  "of database file");
372  }
373 
374  DBInitializerFromCSV<> initializer(filename);
375 
376  const auto& var_names = initializer.variableNames();
377  const std::size_t nb_vars = var_names.size();
378 
379  DBTranslatorSet<> translator_set;
380  DBTranslator4LabelizedVariable<> translator;
381  for (std::size_t i = 0; i < nb_vars; ++i) {
382  translator_set.insertTranslator(translator, i);
383  }
384 
385  DatabaseTable<> database(translator_set);
386  database.setVariableNames(initializer.variableNames());
387  initializer.fillDatabase(database);
388 
389  return database;
390  }
391 
392 
393  void genericBNLearner::isCSVFileName_(const std::string& filename) {
394  // get the extension of the file
395  Size filename_size = Size(filename.size());
396 
397  if (filename_size < 4) {
398  GUM_ERROR(FormatNotFound,
399  "genericBNLearner could not determine the "
400  "file type of the database");
401  }
402 
403  std::string extension = filename.substr(filename.size() - 4);
404  std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
405 
406  if (extension != ".csv") {
407  GUM_ERROR(OperationNotAllowed,
408  "genericBNLearner does not support yet this type of database file");
409  }
410  }
411 
412 
413  DatabaseTable<> genericBNLearner::readFile_(const std::string& filename,
414  const std::vector< std::string >& missing_symbols) {
415  // get the extension of the file
416  isCSVFileName_(filename);
417 
418  DBInitializerFromCSV<> initializer(filename);
419 
420  const auto& var_names = initializer.variableNames();
421  const std::size_t nb_vars = var_names.size();
422 
423  DBTranslatorSet<> translator_set;
424  DBTranslator4LabelizedVariable<> translator(missing_symbols);
425  for (std::size_t i = 0; i < nb_vars; ++i) {
426  translator_set.insertTranslator(translator, i);
427  }
428 
429  DatabaseTable<> database(missing_symbols, translator_set);
430  database.setVariableNames(initializer.variableNames());
431  initializer.fillDatabase(database);
432 
433  database.reorder();
434 
435  return database;
436  }
437 
438 
439  void genericBNLearner::createApriori_() {
440  // first, save the old apriori, to be delete if everything is ok
441  Apriori<>* old_apriori = apriori_;
442 
443  // create the new apriori
444  switch (aprioriType_) {
445  case AprioriType::NO_APRIORI:
446  apriori_ = new AprioriNoApriori<>(scoreDatabase_.databaseTable(),
447  scoreDatabase_.nodeId2Columns());
448  break;
449 
450  case AprioriType::SMOOTHING:
451  apriori_ = new AprioriSmoothing<>(scoreDatabase_.databaseTable(),
452  scoreDatabase_.nodeId2Columns());
453  break;
454 
455  case AprioriType::DIRICHLET_FROM_DATABASE:
456  if (aprioriDatabase_ != nullptr) {
457  delete aprioriDatabase_;
458  aprioriDatabase_ = nullptr;
459  }
460 
461  aprioriDatabase_
462  = new Database(aprioriDbname_, scoreDatabase_, scoreDatabase_.missingSymbols());
463 
464  apriori_ = new AprioriDirichletFromDatabase<>(scoreDatabase_.databaseTable(),
465  aprioriDatabase_->parser(),
466  aprioriDatabase_->nodeId2Columns());
467  break;
468 
469  case AprioriType::BDEU:
470  apriori_
471  = new AprioriBDeu<>(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
472  break;
473 
474  default:
475  GUM_ERROR(OperationNotAllowed, "The BNLearner does not support yet this apriori")
476  }
477 
478  // do not forget to assign a weight to the apriori
479  apriori_->setWeight(aprioriWeight_);
480 
481  // remove the old apriori, if any
482  if (old_apriori != nullptr) delete old_apriori;
483  }
484 
485  void genericBNLearner::createScore_() {
486  // first, save the old score, to be delete if everything is ok
487  Score<>* old_score = score_;
488 
489  // create the new scoring function
490  switch (scoreType_) {
491  case ScoreType::AIC:
492  score_ = new ScoreAIC<>(scoreDatabase_.parser(),
493  *apriori_,
494  ranges_,
495  scoreDatabase_.nodeId2Columns());
496  break;
497 
498  case ScoreType::BD:
499  score_ = new ScoreBD<>(scoreDatabase_.parser(),
500  *apriori_,
501  ranges_,
502  scoreDatabase_.nodeId2Columns());
503  break;
504 
505  case ScoreType::BDeu:
506  score_ = new ScoreBDeu<>(scoreDatabase_.parser(),
507  *apriori_,
508  ranges_,
509  scoreDatabase_.nodeId2Columns());
510  break;
511 
512  case ScoreType::BIC:
513  score_ = new ScoreBIC<>(scoreDatabase_.parser(),
514  *apriori_,
515  ranges_,
516  scoreDatabase_.nodeId2Columns());
517  break;
518 
519  case ScoreType::K2:
520  score_ = new ScoreK2<>(scoreDatabase_.parser(),
521  *apriori_,
522  ranges_,
523  scoreDatabase_.nodeId2Columns());
524  break;
525 
526  case ScoreType::LOG2LIKELIHOOD:
527  score_ = new ScoreLog2Likelihood<>(scoreDatabase_.parser(),
528  *apriori_,
529  ranges_,
530  scoreDatabase_.nodeId2Columns());
531  break;
532 
533  default:
534  GUM_ERROR(OperationNotAllowed, "genericBNLearner does not support yet this score")
535  }
536 
537  // remove the old score, if any
538  if (old_score != nullptr) delete old_score;
539  }
540 
541  ParamEstimator<>* genericBNLearner::createParamEstimator_(DBRowGeneratorParser<>& parser,
542  bool take_into_account_score) {
543  ParamEstimator<>* param_estimator = nullptr;
544 
545  // create the new estimator
546  switch (paramEstimatorType_) {
547  case ParamEstimatorType::ML:
548  if (take_into_account_score && (score_ != nullptr)) {
549  param_estimator = new ParamEstimatorML<>(parser,
550  *apriori_,
551  score_->internalApriori(),
552  ranges_,
553  scoreDatabase_.nodeId2Columns());
554  } else {
555  param_estimator = new ParamEstimatorML<>(parser,
556  *apriori_,
557  *noApriori_,
558  ranges_,
559  scoreDatabase_.nodeId2Columns());
560  }
561 
562  break;
563 
564  default:
565  GUM_ERROR(OperationNotAllowed,
566  "genericBNLearner does not support "
567  << "yet this parameter estimator");
568  }
569 
570  // assign the set of ranges
571  param_estimator->setRanges(ranges_);
572 
573  return param_estimator;
574  }
575 
576  /// prepares the initial graph for 3off2 or miic
577  MixedGraph genericBNLearner::prepareMiic3Off2_() {
578  // Initialize the mixed graph to the fully connected graph
579  MixedGraph mgraph;
580  for (Size i = 0; i < scoreDatabase_.databaseTable().nbVariables(); ++i) {
581  mgraph.addNodeWithId(i);
582  for (Size j = 0; j < i; ++j) {
583  mgraph.addEdge(j, i);
584  }
585  }
586 
587  // translating the constraints for 3off2 or miic
588  HashTable< std::pair< NodeId, NodeId >, char > initial_marks;
589  const ArcSet& mandatory_arcs = constraintMandatoryArcs_.arcs();
590  for (const auto& arc: mandatory_arcs) {
591  initial_marks.insert({arc.tail(), arc.head()}, '>');
592  }
593 
594  const ArcSet& forbidden_arcs = constraintForbiddenArcs_.arcs();
595  for (const auto& arc: forbidden_arcs) {
596  initial_marks.insert({arc.tail(), arc.head()}, '-');
597  }
598  algoMiic3off2_.addConstraints(initial_marks);
599 
600  // create the mutual entropy object
601  // if ( _mutual_info_ == nullptr) { this->useNMLCorrection(); }
602  createCorrectedMutualInformation_();
603 
604  return mgraph;
605  }
606 
607  MixedGraph genericBNLearner::learnMixedStructure() {
608  if (selectedAlgo_ != AlgoType::MIIC && selectedAlgo_ != AlgoType::THREE_OFF_TWO) {
609  GUM_ERROR(OperationNotAllowed, "Must be using the miic/3off2 algorithm")
610  }
611  // check that the database does not contain any missing value
612  if (scoreDatabase_.databaseTable().hasMissingValues()) {
613  GUM_ERROR(MissingValueInDatabase,
614  "For the moment, the BNLearner is unable to learn "
615  << "structures with missing values in databases");
616  }
617  BNLearnerListener listener(this, algoMiic3off2_);
618 
619  // create the mixedGraph_constraint_MandatoryArcs.arcs();
620  MixedGraph mgraph = this->prepareMiic3Off2_();
621 
622  return algoMiic3off2_.learnMixedStructure(*mutualInfo_, mgraph);
623  }
624 
625  DAG genericBNLearner::learnDAG() {
626  // create the score and the apriori
627  createApriori_();
628  createScore_();
629 
630  return learnDag_();
631  }
632 
633  void genericBNLearner::createCorrectedMutualInformation_() {
634  if (mutualInfo_ != nullptr) delete mutualInfo_;
635 
636  mutualInfo_ = new CorrectedMutualInformation<>(scoreDatabase_.parser(),
637  *noApriori_,
638  ranges_,
639  scoreDatabase_.nodeId2Columns());
640  switch (kmode3Off2_) {
641  case CorrectedMutualInformation<>::KModeTypes::MDL:
642  mutualInfo_->useMDL();
643  break;
644 
645  case CorrectedMutualInformation<>::KModeTypes::NML:
646  mutualInfo_->useNML();
647  break;
648 
649  case CorrectedMutualInformation<>::KModeTypes::NoCorr:
650  mutualInfo_->useNoCorr();
651  break;
652 
653  default:
654  GUM_ERROR(NotImplementedYet,
655  "The BNLearner's corrected mutual information class does "
656  << "not implement yet this correction : " << int(kmode3Off2_));
657  }
658  }
659 
    /// Learns a DAG with the currently selected algorithm (selectedAlgo_).
    ///
    /// The greedy hill climbing, tabu-list local search and K2 cases
    /// dereference score_, so the score (and its apriori) must have been created
    /// beforehand; learnDAG() does this. The miic/3off2 case instead builds its
    /// own corrected mutual information through prepareMiic3Off2_().
    ///
    /// @throws MissingValueInDatabase if the learning database, or the apriori
    ///         database when a DIRICHLET_FROM_DATABASE apriori is selected,
    ///         contains missing values
    /// @throws OperationNotAllowed if selectedAlgo_ is not handled below
    DAG genericBNLearner::learnDag_() {
      // check that the database does not contain any missing value
      if (scoreDatabase_.databaseTable().hasMissingValues()
          || ((aprioriDatabase_ != nullptr)
              && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
              && aprioriDatabase_->databaseTable().hasMissingValues())) {
        GUM_ERROR(MissingValueInDatabase,
                  "For the moment, the BNLearner is unable to cope "
                  "with missing values in databases");
      }
      // add the mandatory arcs to the initial dag and remove the forbidden ones
      // from the initial graph
      DAG init_graph = initialDag_;

      const ArcSet& mandatory_arcs = constraintMandatoryArcs_.arcs();

      for (const auto& arc: mandatory_arcs) {
        // the nodes of a mandatory arc may be absent from the initial DAG
        if (!init_graph.exists(arc.tail())) init_graph.addNodeWithId(arc.tail());

        if (!init_graph.exists(arc.head())) init_graph.addNodeWithId(arc.head());

        init_graph.addArc(arc.tail(), arc.head());
      }

      const ArcSet& forbidden_arcs = constraintForbiddenArcs_.arcs();

      for (const auto& arc: forbidden_arcs) {
        init_graph.eraseArc(arc);
      }

      switch (selectedAlgo_) {
        // ========================================================================
        // miic / 3off2: start from the complete mixed graph built by
        // prepareMiic3Off2_(); init_graph is not used by this case
        case AlgoType::MIIC:
        case AlgoType::THREE_OFF_TWO: {
          // scoped listener linking this learner to the algorithm until return
          BNLearnerListener listener(this, algoMiic3off2_);
          // create the mixedGraph and the corrected mutual information
          MixedGraph mgraph = this->prepareMiic3Off2_();

          return algoMiic3off2_.learnStructure(*mutualInfo_, mgraph);
        }

        // ========================================================================
        // greedy hill climbing: the generator constraints restrict the candidate
        // graph changes; the selector constraints (indegree, acyclicity) filter
        // the changes that may be applied
        case AlgoType::GREEDY_HILL_CLIMBING: {
          BNLearnerListener listener(this, greedyHillClimbing_);
          StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
                                         StructuralConstraintForbiddenArcs,
                                         StructuralConstraintPossibleEdges,
                                         StructuralConstraintSliceOrder >
             gen_constraint;
          // copy each user constraint into its slot of the static constraint set
          static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
             = constraintMandatoryArcs_;
          static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
             = constraintForbiddenArcs_;
          static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
             = constraintPossibleEdges_;
          static_cast< StructuralConstraintSliceOrder& >(gen_constraint) = constraintSliceOrder_;

          GraphChangesGenerator4DiGraph< decltype(gen_constraint) > op_set(gen_constraint);

          StructuralConstraintSetStatic< StructuralConstraintIndegree, StructuralConstraintDAG >
             sel_constraint;
          static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;

          GraphChangesSelector4DiGraph< decltype(sel_constraint), decltype(op_set) > selector(
             *score_,
             sel_constraint,
             op_set);

          return greedyHillClimbing_.learnStructure(selector, init_graph);
        }

        // ========================================================================
        // local search with tabu list: same setup as greedy hill climbing, plus
        // the tabu-list constraint in the selector
        case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST: {
          BNLearnerListener listener(this, localSearchWithTabuList_);
          StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
                                         StructuralConstraintForbiddenArcs,
                                         StructuralConstraintPossibleEdges,
                                         StructuralConstraintSliceOrder >
             gen_constraint;
          static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
             = constraintMandatoryArcs_;
          static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
             = constraintForbiddenArcs_;
          static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
             = constraintPossibleEdges_;
          static_cast< StructuralConstraintSliceOrder& >(gen_constraint) = constraintSliceOrder_;

          GraphChangesGenerator4DiGraph< decltype(gen_constraint) > op_set(gen_constraint);

          StructuralConstraintSetStatic< StructuralConstraintTabuList,
                                         StructuralConstraintIndegree,
                                         StructuralConstraintDAG >
             sel_constraint;
          static_cast< StructuralConstraintTabuList& >(sel_constraint) = constraintTabuList_;
          static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;

          GraphChangesSelector4DiGraph< decltype(sel_constraint), decltype(op_set) > selector(
             *score_,
             sel_constraint,
             op_set);

          return localSearchWithTabuList_.learnStructure(selector, init_graph);
        }

        // ========================================================================
        // K2: the slice-order constraint is not used here (K2 relies on
        // algoK2_.order() instead)
        case AlgoType::K2: {
          BNLearnerListener listener(this, algoK2_.approximationScheme());
          StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
                                         StructuralConstraintForbiddenArcs,
                                         StructuralConstraintPossibleEdges >
             gen_constraint;
          static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
             = constraintMandatoryArcs_;
          static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
             = constraintForbiddenArcs_;
          static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
             = constraintPossibleEdges_;

          GraphChangesGenerator4K2< decltype(gen_constraint) > op_set(gen_constraint);

          // if some mandatory arcs are incompatible with the order, use a DAG
          // constraint instead of a DiGraph constraint to avoid cycles
          const ArcSet& mandatory_arcs
             = static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint).arcs();
          const Sequence< NodeId >& order = algoK2_.order();
          bool order_compatible = true;

          // NOTE(review): order.pos() is assumed to throw if a node of a
          // mandatory arc is missing from the K2 order -- confirm
          for (const auto& arc: mandatory_arcs) {
            if (order.pos(arc.tail()) >= order.pos(arc.head())) {
              order_compatible = false;
              break;
            }
          }

          if (order_compatible) {
            StructuralConstraintSetStatic< StructuralConstraintIndegree,
                                           StructuralConstraintDiGraph >
               sel_constraint;
            static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;

            GraphChangesSelector4DiGraph< decltype(sel_constraint), decltype(op_set) > selector(
               *score_,
               sel_constraint,
               op_set);

            return algoK2_.learnStructure(selector, init_graph);
          } else {
            StructuralConstraintSetStatic< StructuralConstraintIndegree, StructuralConstraintDAG >
               sel_constraint;
            static_cast< StructuralConstraintIndegree& >(sel_constraint) = constraintIndegree_;

            GraphChangesSelector4DiGraph< decltype(sel_constraint), decltype(op_set) > selector(
               *score_,
               sel_constraint,
               op_set);

            return algoK2_.learnStructure(selector, init_graph);
          }
        }

        // ========================================================================
        default:
          GUM_ERROR(OperationNotAllowed,
                    "the learnDAG method has not been implemented for this "
                    "learning algorithm");
      }
    }
827 
828  std::string genericBNLearner::checkScoreAprioriCompatibility() const {
829  const std::string& apriori = getAprioriType_();
830 
831  switch (scoreType_) {
832  case ScoreType::AIC:
833  return ScoreAIC<>::isAprioriCompatible(apriori, aprioriWeight_);
834 
835  case ScoreType::BD:
836  return ScoreBD<>::isAprioriCompatible(apriori, aprioriWeight_);
837 
838  case ScoreType::BDeu:
839  return ScoreBDeu<>::isAprioriCompatible(apriori, aprioriWeight_);
840 
841  case ScoreType::BIC:
842  return ScoreBIC<>::isAprioriCompatible(apriori, aprioriWeight_);
843 
844  case ScoreType::K2:
845  return ScoreK2<>::isAprioriCompatible(apriori, aprioriWeight_);
846 
847  case ScoreType::LOG2LIKELIHOOD:
848  return ScoreLog2Likelihood<>::isAprioriCompatible(apriori, aprioriWeight_);
849 
850  default:
851  return "genericBNLearner does not support yet this score";
852  }
853  }
854 
855 
856  /// sets the ranges of rows to be used for cross-validation learning
857  std::pair< std::size_t, std::size_t >
858  genericBNLearner::useCrossValidationFold(const std::size_t learning_fold,
859  const std::size_t k_fold) {
860  if (k_fold == 0) { GUM_ERROR(OutOfBounds, "K-fold cross validation with k=0 is forbidden") }
861 
862  if (learning_fold >= k_fold) {
863  GUM_ERROR(OutOfBounds,
864  "In " << k_fold << "-fold cross validation, the learning "
865  << "fold should be strictly lower than " << k_fold
866  << " but, here, it is equal to " << learning_fold);
867  }
868 
869  const std::size_t db_size = scoreDatabase_.databaseTable().nbRows();
870  if (k_fold >= db_size) {
871  GUM_ERROR(OutOfBounds,
872  "In " << k_fold << "-fold cross validation, the database's "
873  << "size should be strictly greater than " << k_fold
874  << " but, here, the database has only " << db_size << "rows");
875  }
876 
877  // create the ranges of rows of the test database
878  const std::size_t foldSize = db_size / k_fold;
879  const std::size_t unfold_deb = learning_fold * foldSize;
880  const std::size_t unfold_end = unfold_deb + foldSize;
881 
882  ranges_.clear();
883  if (learning_fold == std::size_t(0)) {
884  ranges_.push_back(std::pair< std::size_t, std::size_t >(unfold_end, db_size));
885  } else {
886  ranges_.push_back(std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb));
887 
888  if (learning_fold != k_fold - 1) {
889  ranges_.push_back(std::pair< std::size_t, std::size_t >(unfold_end, db_size));
890  }
891  }
892 
893  return std::pair< std::size_t, std::size_t >(unfold_deb, unfold_end);
894  }
895 
896 
897  std::pair< double, double > genericBNLearner::chi2(const NodeId id1,
898  const NodeId id2,
899  const std::vector< NodeId >& knowing) {
900  createApriori_();
901  gum::learning::IndepTestChi2<> chi2score(scoreDatabase_.parser(),
902  *apriori_,
903  databaseRanges());
904 
905  return chi2score.statistics(id1, id2, knowing);
906  }
907 
908  std::pair< double, double > genericBNLearner::chi2(const std::string& name1,
909  const std::string& name2,
910  const std::vector< std::string >& knowing) {
911  std::vector< NodeId > knowingIds;
912  std::transform(knowing.begin(),
913  knowing.end(),
914  std::back_inserter(knowingIds),
915  [this](const std::string& c) -> NodeId { return this->idFromName(c); });
916  return chi2(idFromName(name1), idFromName(name2), knowingIds);
917  }
918 
919  std::pair< double, double > genericBNLearner::G2(const NodeId id1,
920  const NodeId id2,
921  const std::vector< NodeId >& knowing) {
922  createApriori_();
923  gum::learning::IndepTestG2<> g2score(scoreDatabase_.parser(), *apriori_, databaseRanges());
924  return g2score.statistics(id1, id2, knowing);
925  }
926 
927  std::pair< double, double > genericBNLearner::G2(const std::string& name1,
928  const std::string& name2,
929  const std::vector< std::string >& knowing) {
930  std::vector< NodeId > knowingIds;
931  std::transform(knowing.begin(),
932  knowing.end(),
933  std::back_inserter(knowingIds),
934  [this](const std::string& c) -> NodeId { return this->idFromName(c); });
935  return G2(idFromName(name1), idFromName(name2), knowingIds);
936  }
937 
938  double genericBNLearner::logLikelihood(const std::vector< NodeId >& vars,
939  const std::vector< NodeId >& knowing) {
940  createApriori_();
941  gum::learning::ScoreLog2Likelihood<> ll2score(scoreDatabase_.parser(),
942  *apriori_,
943  databaseRanges());
944 
945  std::vector< NodeId > total(vars);
946  total.insert(total.end(), knowing.begin(), knowing.end());
947  double LLtotal = ll2score.score(IdCondSet<>(total, false, true));
948  if (knowing.size() == (Size)0) {
949  return LLtotal;
950  } else {
951  double LLknw = ll2score.score(IdCondSet<>(knowing, false, true));
952  return LLtotal - LLknw;
953  }
954  }
955 
956  double genericBNLearner::logLikelihood(const std::vector< std::string >& vars,
957  const std::vector< std::string >& knowing) {
958  std::vector< NodeId > ids;
959  std::vector< NodeId > knowingIds;
960 
961  auto mapper = [this](const std::string& c) -> NodeId {
962  return this->idFromName(c);
963  };
964 
965  std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
966  std::transform(knowing.begin(), knowing.end(), std::back_inserter(knowingIds), mapper);
967 
968  return logLikelihood(ids, knowingIds);
969  }
970 
971  std::vector< double > genericBNLearner::rawPseudoCount(const std::vector< NodeId >& vars) {
972  Potential< double > res;
973 
974  createApriori_();
975  gum::learning::PseudoCount<> count(scoreDatabase_.parser(), *apriori_, databaseRanges());
976  return count.get(vars);
977  }
978 
979 
980  std::vector< double > genericBNLearner::rawPseudoCount(const std::vector< std::string >& vars) {
981  std::vector< NodeId > ids;
982 
983  auto mapper = [this](const std::string& c) -> NodeId {
984  return this->idFromName(c);
985  };
986 
987  std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
988 
989  return rawPseudoCount(ids);
990  }
991 
992  } /* namespace learning */
993 
994 } /* namespace gum */