aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
genericBNLearner.cpp
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief A pack of learning algorithms that can easily be used
24  *
25  * The pack currently contains K2, GreedyHillClimbing, miic, 3off2 and
26  *LocalSearchWithTabuList
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 
31 #include <algorithm>
32 #include <iterator>
33 
34 #include <agrum/agrum.h>
35 #include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h>
36 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h>
37 #include <agrum/tools/stattests/indepTestChi2.h>
38 #include <agrum/tools/stattests/indepTestG2.h>
39 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>
40 #include <agrum/tools/stattests/pseudoCount.h>
41 
42 // include the inlined functions if necessary
43 #ifdef GUM_NO_INLINE
44 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h>
45 #endif /* GUM_NO_INLINE */
46 
47 namespace gum {
48 
49  namespace learning {
50 
51 
52  genericBNLearner::Database::Database(const DatabaseTable<>& db) :
53  database__(db) {
54  // get the variables names
55  const auto& var_names = database__.variableNames();
56  const std::size_t nb_vars = var_names.size();
57  for (auto dom: database__.domainSizes())
58  domain_sizes__.push_back(dom);
59  for (std::size_t i = 0; i < nb_vars; ++i) {
60  nodeId2cols__.insert(NodeId(i), i);
61  }
62 
63  // create the parser
64  parser__
65  = new DBRowGeneratorParser<>(database__.handler(), DBRowGeneratorSet<>());
66  }
67 
68 
69  genericBNLearner::Database::Database(
70  const std::string& filename,
71  const std::vector< std::string >& missing_symbols) :
72  Database(genericBNLearner::readFile__(filename, missing_symbols)) {}
73 
74 
75  genericBNLearner::Database::Database(
76  const std::string& CSV_filename,
77  Database& score_database,
78  const std::vector< std::string >& missing_symbols) {
79  // assign to each column name in the CSV file its column
80  genericBNLearner::checkFileName__(CSV_filename);
81  DBInitializerFromCSV<> initializer(CSV_filename);
82  const auto& apriori_names = initializer.variableNames();
83  std::size_t apriori_nb_vars = apriori_names.size();
84  HashTable< std::string, std::size_t > apriori_names2col(apriori_nb_vars);
85  for (std::size_t i = std::size_t(0); i < apriori_nb_vars; ++i)
86  apriori_names2col.insert(apriori_names[i], i);
87 
88  // check that there are at least as many variables in the a priori
89  // database as those in the score_database
90  if (apriori_nb_vars < score_database.database__.nbVariables()) {
91  GUM_ERROR(InvalidArgument,
92  "the a apriori database has fewer variables "
93  "than the observed database");
94  }
95 
96  // get the mapping from the columns of score_database to those of
97  // the CSV file
98  const std::vector< std::string >& score_names
99  = score_database.databaseTable().variableNames();
100  const std::size_t score_nb_vars = score_names.size();
101  HashTable< std::size_t, std::size_t > mapping(score_nb_vars);
102  for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
103  try {
104  mapping.insert(i, apriori_names2col[score_names[i]]);
105  } catch (Exception&) {
106  GUM_ERROR(MissingVariableInDatabase,
107  "Variable "
108  << score_names[i]
109  << " of the observed database does not belong to the "
110  << "apriori database");
111  }
112  }
113 
114  // create the translators for CSV database
115  for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
116  const Variable& var = score_database.databaseTable().variable(i);
117  database__.insertTranslator(var, mapping[i], missing_symbols);
118  }
119 
120  // fill the database
121  initializer.fillDatabase(database__);
122 
123  // get the domain sizes of the variables
124  for (auto dom: database__.domainSizes())
125  domain_sizes__.push_back(dom);
126 
127  // compute the mapping from node ids to column indices
128  nodeId2cols__ = score_database.nodeId2Columns();
129 
130  // create the parser
131  parser__
132  = new DBRowGeneratorParser<>(database__.handler(), DBRowGeneratorSet<>());
133  }
134 
135 
136  genericBNLearner::Database::Database(const Database& from) :
137  database__(from.database__), domain_sizes__(from.domain_sizes__),
138  nodeId2cols__(from.nodeId2cols__) {
139  // create the parser
140  parser__
141  = new DBRowGeneratorParser<>(database__.handler(), DBRowGeneratorSet<>());
142  }
143 
144 
145  genericBNLearner::Database::Database(Database&& from) :
146  database__(std::move(from.database__)),
147  domain_sizes__(std::move(from.domain_sizes__)),
148  nodeId2cols__(std::move(from.nodeId2cols__)) {
149  // create the parser
150  parser__
151  = new DBRowGeneratorParser<>(database__.handler(), DBRowGeneratorSet<>());
152  }
153 
154 
155  genericBNLearner::Database::~Database() { delete parser__; }
156 
157  genericBNLearner::Database&
158  genericBNLearner::Database::operator=(const Database& from) {
159  if (this != &from) {
160  delete parser__;
161  database__ = from.database__;
162  domain_sizes__ = from.domain_sizes__;
163  nodeId2cols__ = from.nodeId2cols__;
164 
165  // create the parser
166  parser__ = new DBRowGeneratorParser<>(database__.handler(),
167  DBRowGeneratorSet<>());
168  }
169 
170  return *this;
171  }
172 
173  genericBNLearner::Database&
174  genericBNLearner::Database::operator=(Database&& from) {
175  if (this != &from) {
176  delete parser__;
177  database__ = std::move(from.database__);
178  domain_sizes__ = std::move(from.domain_sizes__);
179  nodeId2cols__ = std::move(from.nodeId2cols__);
180 
181  // create the parser
182  parser__ = new DBRowGeneratorParser<>(database__.handler(),
183  DBRowGeneratorSet<>());
184  }
185 
186  return *this;
187  }
188 
189 
190  // ===========================================================================
191 
192  genericBNLearner::genericBNLearner(
193  const std::string& filename,
194  const std::vector< std::string >& missing_symbols) :
195  score_database__(filename, missing_symbols) {
196  no_apriori__ = new AprioriNoApriori<>(score_database__.databaseTable());
197 
198  // for debugging purposes
199  GUM_CONSTRUCTOR(genericBNLearner);
200  }
201 
202 
203  genericBNLearner::genericBNLearner(const DatabaseTable<>& db) :
204  score_database__(db) {
205  no_apriori__ = new AprioriNoApriori<>(score_database__.databaseTable());
206 
207  // for debugging purposes
208  GUM_CONSTRUCTOR(genericBNLearner);
209  }
210 
211 
212  genericBNLearner::genericBNLearner(const genericBNLearner& from) :
213  score_type__(from.score_type__),
214  param_estimator_type__(from.param_estimator_type__),
215  EMepsilon__(from.EMepsilon__), apriori_type__(from.apriori_type__),
216  apriori_weight__(from.apriori_weight__),
217  constraint_SliceOrder__(from.constraint_SliceOrder__),
218  constraint_Indegree__(from.constraint_Indegree__),
219  constraint_TabuList__(from.constraint_TabuList__),
220  constraint_ForbiddenArcs__(from.constraint_ForbiddenArcs__),
221  constraint_MandatoryArcs__(from.constraint_MandatoryArcs__),
222  selected_algo__(from.selected_algo__), K2__(from.K2__),
223  miic_3off2__(from.miic_3off2__), kmode_3off2__(from.kmode_3off2__),
224  greedy_hill_climbing__(from.greedy_hill_climbing__),
225  local_search_with_tabu_list__(from.local_search_with_tabu_list__),
226  score_database__(from.score_database__), ranges__(from.ranges__),
227  apriori_dbname__(from.apriori_dbname__),
228  initial_dag__(from.initial_dag__) {
229  no_apriori__ = new AprioriNoApriori<>(score_database__.databaseTable());
230 
231  // for debugging purposes
232  GUM_CONS_CPY(genericBNLearner);
233  }
234 
235  genericBNLearner::genericBNLearner(genericBNLearner&& from) :
236  score_type__(from.score_type__),
237  param_estimator_type__(from.param_estimator_type__),
238  EMepsilon__(from.EMepsilon__), apriori_type__(from.apriori_type__),
239  apriori_weight__(from.apriori_weight__),
240  constraint_SliceOrder__(std::move(from.constraint_SliceOrder__)),
241  constraint_Indegree__(std::move(from.constraint_Indegree__)),
242  constraint_TabuList__(std::move(from.constraint_TabuList__)),
243  constraint_ForbiddenArcs__(std::move(from.constraint_ForbiddenArcs__)),
244  constraint_MandatoryArcs__(std::move(from.constraint_MandatoryArcs__)),
245  selected_algo__(from.selected_algo__), K2__(std::move(from.K2__)),
246  miic_3off2__(std::move(from.miic_3off2__)),
247  kmode_3off2__(from.kmode_3off2__),
248  greedy_hill_climbing__(std::move(from.greedy_hill_climbing__)),
249  local_search_with_tabu_list__(
250  std::move(from.local_search_with_tabu_list__)),
251  score_database__(std::move(from.score_database__)),
252  ranges__(std::move(from.ranges__)),
253  apriori_dbname__(std::move(from.apriori_dbname__)),
254  initial_dag__(std::move(from.initial_dag__)) {
255  no_apriori__ = new AprioriNoApriori<>(score_database__.databaseTable());
256 
257  // for debugging purposes
258  GUM_CONS_MOV(genericBNLearner);
259  }
260 
261  genericBNLearner::~genericBNLearner() {
262  if (score__) delete score__;
263 
264  if (apriori__) delete apriori__;
265 
266  if (no_apriori__) delete no_apriori__;
267 
268  if (apriori_database__) delete apriori_database__;
269 
270  if (mutual_info__) delete mutual_info__;
271 
272  GUM_DESTRUCTOR(genericBNLearner);
273  }
274 
275  genericBNLearner& genericBNLearner::operator=(const genericBNLearner& from) {
276  if (this != &from) {
277  if (score__) {
278  delete score__;
279  score__ = nullptr;
280  }
281 
282  if (apriori__) {
283  delete apriori__;
284  apriori__ = nullptr;
285  }
286 
287  if (apriori_database__) {
288  delete apriori_database__;
289  apriori_database__ = nullptr;
290  }
291 
292  if (mutual_info__) {
293  delete mutual_info__;
294  mutual_info__ = nullptr;
295  }
296 
297  score_type__ = from.score_type__;
298  param_estimator_type__ = from.param_estimator_type__;
299  EMepsilon__ = from.EMepsilon__;
300  apriori_type__ = from.apriori_type__;
301  apriori_weight__ = from.apriori_weight__;
302  constraint_SliceOrder__ = from.constraint_SliceOrder__;
303  constraint_Indegree__ = from.constraint_Indegree__;
304  constraint_TabuList__ = from.constraint_TabuList__;
305  constraint_ForbiddenArcs__ = from.constraint_ForbiddenArcs__;
306  constraint_MandatoryArcs__ = from.constraint_MandatoryArcs__;
307  selected_algo__ = from.selected_algo__;
308  K2__ = from.K2__;
309  miic_3off2__ = from.miic_3off2__;
310  kmode_3off2__ = from.kmode_3off2__;
311  greedy_hill_climbing__ = from.greedy_hill_climbing__;
312  local_search_with_tabu_list__ = from.local_search_with_tabu_list__;
313  score_database__ = from.score_database__;
314  ranges__ = from.ranges__;
315  apriori_dbname__ = from.apriori_dbname__;
316  initial_dag__ = from.initial_dag__;
317  current_algorithm__ = nullptr;
318  }
319 
320  return *this;
321  }
322 
323  genericBNLearner& genericBNLearner::operator=(genericBNLearner&& from) {
324  if (this != &from) {
325  if (score__) {
326  delete score__;
327  score__ = nullptr;
328  }
329 
330  if (apriori__) {
331  delete apriori__;
332  apriori__ = nullptr;
333  }
334 
335  if (apriori_database__) {
336  delete apriori_database__;
337  apriori_database__ = nullptr;
338  }
339 
340  if (mutual_info__) {
341  delete mutual_info__;
342  mutual_info__ = nullptr;
343  }
344 
345  score_type__ = from.score_type__;
346  param_estimator_type__ = from.param_estimator_type__;
347  EMepsilon__ = from.EMepsilon__;
348  apriori_type__ = from.apriori_type__;
349  apriori_weight__ = from.apriori_weight__;
350  constraint_SliceOrder__ = std::move(from.constraint_SliceOrder__);
351  constraint_Indegree__ = std::move(from.constraint_Indegree__);
352  constraint_TabuList__ = std::move(from.constraint_TabuList__);
353  constraint_ForbiddenArcs__ = std::move(from.constraint_ForbiddenArcs__);
354  constraint_MandatoryArcs__ = std::move(from.constraint_MandatoryArcs__);
355  selected_algo__ = from.selected_algo__;
356  K2__ = from.K2__;
357  miic_3off2__ = std::move(from.miic_3off2__);
358  kmode_3off2__ = from.kmode_3off2__;
359  greedy_hill_climbing__ = std::move(from.greedy_hill_climbing__);
360  local_search_with_tabu_list__
361  = std::move(from.local_search_with_tabu_list__);
362  score_database__ = std::move(from.score_database__);
363  ranges__ = std::move(from.ranges__);
364  apriori_dbname__ = std::move(from.apriori_dbname__);
365  initial_dag__ = std::move(from.initial_dag__);
366  current_algorithm__ = nullptr;
367  }
368 
369  return *this;
370  }
371 
372 
373  DatabaseTable<> readFile(const std::string& filename) {
374  // get the extension of the file
375  Size filename_size = Size(filename.size());
376 
377  if (filename_size < 4) {
378  GUM_ERROR(FormatNotFound,
379  "genericBNLearner could not determine the "
380  "file type of the database");
381  }
382 
383  std::string extension = filename.substr(filename.size() - 4);
384  std::transform(extension.begin(),
385  extension.end(),
386  extension.begin(),
387  ::tolower);
388 
389  if (extension != ".csv") {
390  GUM_ERROR(OperationNotAllowed,
391  "genericBNLearner does not support yet this type "
392  "of database file");
393  }
394 
395  DBInitializerFromCSV<> initializer(filename);
396 
397  const auto& var_names = initializer.variableNames();
398  const std::size_t nb_vars = var_names.size();
399 
400  DBTranslatorSet<> translator_set;
401  DBTranslator4LabelizedVariable<> translator;
402  for (std::size_t i = 0; i < nb_vars; ++i) {
403  translator_set.insertTranslator(translator, i);
404  }
405 
406  DatabaseTable<> database(translator_set);
407  database.setVariableNames(initializer.variableNames());
408  initializer.fillDatabase(database);
409 
410  return database;
411  }
412 
413 
414  void genericBNLearner::checkFileName__(const std::string& filename) {
415  // get the extension of the file
416  Size filename_size = Size(filename.size());
417 
418  if (filename_size < 4) {
419  GUM_ERROR(FormatNotFound,
420  "genericBNLearner could not determine the "
421  "file type of the database");
422  }
423 
424  std::string extension = filename.substr(filename.size() - 4);
425  std::transform(extension.begin(),
426  extension.end(),
427  extension.begin(),
428  ::tolower);
429 
430  if (extension != ".csv") {
431  GUM_ERROR(
432  OperationNotAllowed,
433  "genericBNLearner does not support yet this type of database file");
434  }
435  }
436 
437 
438  DatabaseTable<> genericBNLearner::readFile__(
439  const std::string& filename,
440  const std::vector< std::string >& missing_symbols) {
441  // get the extension of the file
442  checkFileName__(filename);
443 
444  DBInitializerFromCSV<> initializer(filename);
445 
446  const auto& var_names = initializer.variableNames();
447  const std::size_t nb_vars = var_names.size();
448 
449  DBTranslatorSet<> translator_set;
450  DBTranslator4LabelizedVariable<> translator(missing_symbols);
451  for (std::size_t i = 0; i < nb_vars; ++i) {
452  translator_set.insertTranslator(translator, i);
453  }
454 
455  DatabaseTable<> database(missing_symbols, translator_set);
456  database.setVariableNames(initializer.variableNames());
457  initializer.fillDatabase(database);
458 
459  database.reorder();
460 
461  return database;
462  }
463 
464 
465  void genericBNLearner::createApriori__() {
466  // first, save the old apriori, to be delete if everything is ok
467  Apriori<>* old_apriori = apriori__;
468 
469  // create the new apriori
470  switch (apriori_type__) {
471  case AprioriType::NO_APRIORI:
472  apriori__ = new AprioriNoApriori<>(score_database__.databaseTable(),
473  score_database__.nodeId2Columns());
474  break;
475 
476  case AprioriType::SMOOTHING:
477  apriori__ = new AprioriSmoothing<>(score_database__.databaseTable(),
478  score_database__.nodeId2Columns());
479  break;
480 
481  case AprioriType::DIRICHLET_FROM_DATABASE:
482  if (apriori_database__ != nullptr) {
483  delete apriori_database__;
484  apriori_database__ = nullptr;
485  }
486 
487  apriori_database__ = new Database(apriori_dbname__,
488  score_database__,
489  score_database__.missingSymbols());
490 
491  apriori__ = new AprioriDirichletFromDatabase<>(
492  score_database__.databaseTable(),
493  apriori_database__->parser(),
494  apriori_database__->nodeId2Columns());
495  break;
496 
497  case AprioriType::BDEU:
498  apriori__ = new AprioriBDeu<>(score_database__.databaseTable(),
499  score_database__.nodeId2Columns());
500  break;
501 
502  default:
503  GUM_ERROR(OperationNotAllowed,
504  "The BNLearner does not support yet this apriori");
505  }
506 
507  // do not forget to assign a weight to the apriori
508  apriori__->setWeight(apriori_weight__);
509 
510  // remove the old apriori, if any
511  if (old_apriori != nullptr) delete old_apriori;
512  }
513 
514  void genericBNLearner::createScore__() {
515  // first, save the old score, to be delete if everything is ok
516  Score<>* old_score = score__;
517 
518  // create the new scoring function
519  switch (score_type__) {
520  case ScoreType::AIC:
521  score__ = new ScoreAIC<>(score_database__.parser(),
522  *apriori__,
523  ranges__,
524  score_database__.nodeId2Columns());
525  break;
526 
527  case ScoreType::BD:
528  score__ = new ScoreBD<>(score_database__.parser(),
529  *apriori__,
530  ranges__,
531  score_database__.nodeId2Columns());
532  break;
533 
534  case ScoreType::BDeu:
535  score__ = new ScoreBDeu<>(score_database__.parser(),
536  *apriori__,
537  ranges__,
538  score_database__.nodeId2Columns());
539  break;
540 
541  case ScoreType::BIC:
542  score__ = new ScoreBIC<>(score_database__.parser(),
543  *apriori__,
544  ranges__,
545  score_database__.nodeId2Columns());
546  break;
547 
548  case ScoreType::K2:
549  score__ = new ScoreK2<>(score_database__.parser(),
550  *apriori__,
551  ranges__,
552  score_database__.nodeId2Columns());
553  break;
554 
555  case ScoreType::LOG2LIKELIHOOD:
556  score__ = new ScoreLog2Likelihood<>(score_database__.parser(),
557  *apriori__,
558  ranges__,
559  score_database__.nodeId2Columns());
560  break;
561 
562  default:
563  GUM_ERROR(OperationNotAllowed,
564  "genericBNLearner does not support yet this score");
565  }
566 
567  // remove the old score, if any
568  if (old_score != nullptr) delete old_score;
569  }
570 
571  ParamEstimator<>*
572  genericBNLearner::createParamEstimator__(DBRowGeneratorParser<>& parser,
573  bool take_into_account_score) {
574  ParamEstimator<>* param_estimator = nullptr;
575 
576  // create the new estimator
577  switch (param_estimator_type__) {
578  case ParamEstimatorType::ML:
579  if (take_into_account_score && (score__ != nullptr)) {
580  param_estimator
581  = new ParamEstimatorML<>(parser,
582  *apriori__,
583  score__->internalApriori(),
584  ranges__,
585  score_database__.nodeId2Columns());
586  } else {
587  param_estimator
588  = new ParamEstimatorML<>(parser,
589  *apriori__,
590  *no_apriori__,
591  ranges__,
592  score_database__.nodeId2Columns());
593  }
594 
595  break;
596 
597  default:
598  GUM_ERROR(OperationNotAllowed,
599  "genericBNLearner does not support "
600  << "yet this parameter estimator");
601  }
602 
603  // assign the set of ranges
604  param_estimator->setRanges(ranges__);
605 
606  return param_estimator;
607  }
608 
609  /// prepares the initial graph for 3off2 or miic
610  MixedGraph genericBNLearner::prepare_miic_3off2__() {
611  // Initialize the mixed graph to the fully connected graph
612  MixedGraph mgraph;
613  for (Size i = 0; i < score_database__.databaseTable().nbVariables(); ++i) {
614  mgraph.addNodeWithId(i);
615  for (Size j = 0; j < i; ++j) {
616  mgraph.addEdge(j, i);
617  }
618  }
619 
620  // translating the constraints for 3off2 or miic
621  HashTable< std::pair< NodeId, NodeId >, char > initial_marks;
622  const ArcSet& mandatory_arcs = constraint_MandatoryArcs__.arcs();
623  for (const auto& arc: mandatory_arcs) {
624  initial_marks.insert({arc.tail(), arc.head()}, '>');
625  }
626 
627  const ArcSet& forbidden_arcs = constraint_ForbiddenArcs__.arcs();
628  for (const auto& arc: forbidden_arcs) {
629  initial_marks.insert({arc.tail(), arc.head()}, '-');
630  }
631  miic_3off2__.addConstraints(initial_marks);
632 
633  // create the mutual entropy object
634  // if (mutual_info__ == nullptr) { this->useNML(); }
635  createCorrectedMutualInformation__();
636 
637  return mgraph;
638  }
639 
640  MixedGraph genericBNLearner::learnMixedStructure() {
641  if (selected_algo__ != AlgoType::MIIC_THREE_OFF_TWO) {
642  GUM_ERROR(OperationNotAllowed, "Must be using the miic/3off2 algorithm");
643  }
644  // check that the database does not contain any missing value
645  if (score_database__.databaseTable().hasMissingValues()) {
646  GUM_ERROR(MissingValueInDatabase,
647  "For the moment, the BNLearner is unable to learn "
648  << "structures with missing values in databases");
649  }
650  BNLearnerListener listener(this, miic_3off2__);
651 
652  // create the mixedGraph_constraint_MandatoryArcs.arcs();
653  MixedGraph mgraph = this->prepare_miic_3off2__();
654 
655  return miic_3off2__.learnMixedStructure(*mutual_info__, mgraph);
656  }
657 
658  DAG genericBNLearner::learnDAG() {
659  // create the score and the apriori
660  createApriori__();
661  createScore__();
662 
663  return learnDAG__();
664  }
665 
666  void genericBNLearner::createCorrectedMutualInformation__() {
667  if (mutual_info__ != nullptr) delete mutual_info__;
668 
669  mutual_info__
670  = new CorrectedMutualInformation<>(score_database__.parser(),
671  *no_apriori__,
672  ranges__,
673  score_database__.nodeId2Columns());
674  switch (kmode_3off2__) {
675  case CorrectedMutualInformation<>::KModeTypes::MDL:
676  mutual_info__->useMDL();
677  break;
678 
679  case CorrectedMutualInformation<>::KModeTypes::NML:
680  mutual_info__->useNML();
681  break;
682 
683  case CorrectedMutualInformation<>::KModeTypes::NoCorr:
684  mutual_info__->useNoCorr();
685  break;
686 
687  default:
688  GUM_ERROR(NotImplementedYet,
689  "The BNLearner's corrected mutual information class does "
690  << "not support yet penalty mode " << int(kmode_3off2__));
691  }
692  }
693 
694  DAG genericBNLearner::learnDAG__() {
695  // check that the database does not contain any missing value
696  if (score_database__.databaseTable().hasMissingValues()
697  || ((apriori_database__ != nullptr)
698  && (apriori_type__ == AprioriType::DIRICHLET_FROM_DATABASE)
699  && apriori_database__->databaseTable().hasMissingValues())) {
700  GUM_ERROR(MissingValueInDatabase,
701  "For the moment, the BNLearner is unable to cope "
702  "with missing values in databases");
703  }
704  // add the mandatory arcs to the initial dag and remove the forbidden ones
705  // from the initial graph
706  DAG init_graph = initial_dag__;
707 
708  const ArcSet& mandatory_arcs = constraint_MandatoryArcs__.arcs();
709 
710  for (const auto& arc: mandatory_arcs) {
711  if (!init_graph.exists(arc.tail())) init_graph.addNodeWithId(arc.tail());
712 
713  if (!init_graph.exists(arc.head())) init_graph.addNodeWithId(arc.head());
714 
715  init_graph.addArc(arc.tail(), arc.head());
716  }
717 
718  const ArcSet& forbidden_arcs = constraint_ForbiddenArcs__.arcs();
719 
720  for (const auto& arc: forbidden_arcs) {
721  init_graph.eraseArc(arc);
722  }
723 
724  switch (selected_algo__) {
725  // ========================================================================
726  case AlgoType::MIIC_THREE_OFF_TWO: {
727  BNLearnerListener listener(this, miic_3off2__);
728  // create the mixedGraph and the corrected mutual information
729  MixedGraph mgraph = this->prepare_miic_3off2__();
730 
731  return miic_3off2__.learnStructure(*mutual_info__, mgraph);
732  }
733 
734  // ========================================================================
735  case AlgoType::GREEDY_HILL_CLIMBING: {
736  BNLearnerListener listener(this, greedy_hill_climbing__);
737  StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
738  StructuralConstraintForbiddenArcs,
739  StructuralConstraintPossibleEdges,
740  StructuralConstraintSliceOrder >
741  gen_constraint;
742  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
743  = constraint_MandatoryArcs__;
744  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
745  = constraint_ForbiddenArcs__;
746  static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
747  = constraint_PossibleEdges__;
748  static_cast< StructuralConstraintSliceOrder& >(gen_constraint)
749  = constraint_SliceOrder__;
750 
751  GraphChangesGenerator4DiGraph< decltype(gen_constraint) > op_set(
752  gen_constraint);
753 
754  StructuralConstraintSetStatic< StructuralConstraintIndegree,
755  StructuralConstraintDAG >
756  sel_constraint;
757  static_cast< StructuralConstraintIndegree& >(sel_constraint)
758  = constraint_Indegree__;
759 
760  GraphChangesSelector4DiGraph< decltype(sel_constraint),
761  decltype(op_set) >
762  selector(*score__, sel_constraint, op_set);
763 
764  return greedy_hill_climbing__.learnStructure(selector, init_graph);
765  }
766 
767  // ========================================================================
768  case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST: {
769  BNLearnerListener listener(this, local_search_with_tabu_list__);
770  StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
771  StructuralConstraintForbiddenArcs,
772  StructuralConstraintPossibleEdges,
773  StructuralConstraintSliceOrder >
774  gen_constraint;
775  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
776  = constraint_MandatoryArcs__;
777  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
778  = constraint_ForbiddenArcs__;
779  static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
780  = constraint_PossibleEdges__;
781  static_cast< StructuralConstraintSliceOrder& >(gen_constraint)
782  = constraint_SliceOrder__;
783 
784  GraphChangesGenerator4DiGraph< decltype(gen_constraint) > op_set(
785  gen_constraint);
786 
787  StructuralConstraintSetStatic< StructuralConstraintTabuList,
788  StructuralConstraintIndegree,
789  StructuralConstraintDAG >
790  sel_constraint;
791  static_cast< StructuralConstraintTabuList& >(sel_constraint)
792  = constraint_TabuList__;
793  static_cast< StructuralConstraintIndegree& >(sel_constraint)
794  = constraint_Indegree__;
795 
796  GraphChangesSelector4DiGraph< decltype(sel_constraint),
797  decltype(op_set) >
798  selector(*score__, sel_constraint, op_set);
799 
800  return local_search_with_tabu_list__.learnStructure(selector,
801  init_graph);
802  }
803 
804  // ========================================================================
805  case AlgoType::K2: {
806  BNLearnerListener listener(this, K2__.approximationScheme());
807  StructuralConstraintSetStatic< StructuralConstraintMandatoryArcs,
808  StructuralConstraintForbiddenArcs,
809  StructuralConstraintPossibleEdges >
810  gen_constraint;
811  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
812  = constraint_MandatoryArcs__;
813  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint)
814  = constraint_ForbiddenArcs__;
815  static_cast< StructuralConstraintPossibleEdges& >(gen_constraint)
816  = constraint_PossibleEdges__;
817 
818  GraphChangesGenerator4K2< decltype(gen_constraint) > op_set(
819  gen_constraint);
820 
821  // if some mandatory arcs are incompatible with the order, use a DAG
822  // constraint instead of a DiGraph constraint to avoid cycles
823  const ArcSet& mandatory_arcs
824  = static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
825  .arcs();
826  const Sequence< NodeId >& order = K2__.order();
827  bool order_compatible = true;
828 
829  for (const auto& arc: mandatory_arcs) {
830  if (order.pos(arc.tail()) >= order.pos(arc.head())) {
831  order_compatible = false;
832  break;
833  }
834  }
835 
836  if (order_compatible) {
837  StructuralConstraintSetStatic< StructuralConstraintIndegree,
838  StructuralConstraintDiGraph >
839  sel_constraint;
840  static_cast< StructuralConstraintIndegree& >(sel_constraint)
841  = constraint_Indegree__;
842 
843  GraphChangesSelector4DiGraph< decltype(sel_constraint),
844  decltype(op_set) >
845  selector(*score__, sel_constraint, op_set);
846 
847  return K2__.learnStructure(selector, init_graph);
848  } else {
849  StructuralConstraintSetStatic< StructuralConstraintIndegree,
850  StructuralConstraintDAG >
851  sel_constraint;
852  static_cast< StructuralConstraintIndegree& >(sel_constraint)
853  = constraint_Indegree__;
854 
855  GraphChangesSelector4DiGraph< decltype(sel_constraint),
856  decltype(op_set) >
857  selector(*score__, sel_constraint, op_set);
858 
859  return K2__.learnStructure(selector, init_graph);
860  }
861  }
862 
863  // ========================================================================
864  default:
865  GUM_ERROR(OperationNotAllowed,
866  "the learnDAG method has not been implemented for this "
867  "learning algorithm");
868  }
869  }
870 
871  std::string genericBNLearner::checkScoreAprioriCompatibility() {
872  const std::string& apriori = getAprioriType__();
873 
874  switch (score_type__) {
875  case ScoreType::AIC:
876  return ScoreAIC<>::isAprioriCompatible(apriori, apriori_weight__);
877 
878  case ScoreType::BD:
879  return ScoreBD<>::isAprioriCompatible(apriori, apriori_weight__);
880 
881  case ScoreType::BDeu:
882  return ScoreBDeu<>::isAprioriCompatible(apriori, apriori_weight__);
883 
884  case ScoreType::BIC:
885  return ScoreBIC<>::isAprioriCompatible(apriori, apriori_weight__);
886 
887  case ScoreType::K2:
888  return ScoreK2<>::isAprioriCompatible(apriori, apriori_weight__);
889 
890  case ScoreType::LOG2LIKELIHOOD:
891  return ScoreLog2Likelihood<>::isAprioriCompatible(apriori,
892  apriori_weight__);
893 
894  default:
895  return "genericBNLearner does not support yet this score";
896  }
897  }
898 
899 
900  /// sets the ranges of rows to be used for cross-validation learning
901  std::pair< std::size_t, std::size_t >
902  genericBNLearner::useCrossValidationFold(const std::size_t learning_fold,
903  const std::size_t k_fold) {
904  if (k_fold == 0) {
905  GUM_ERROR(OutOfBounds, "K-fold cross validation with k=0 is forbidden");
906  }
907 
908  if (learning_fold >= k_fold) {
909  GUM_ERROR(OutOfBounds,
910  "In " << k_fold << "-fold cross validation, the learning "
911  << "fold should be strictly lower than " << k_fold
912  << " but, here, it is equal to " << learning_fold);
913  }
914 
915  const std::size_t db_size = score_database__.databaseTable().nbRows();
916  if (k_fold >= db_size) {
917  GUM_ERROR(OutOfBounds,
918  "In " << k_fold << "-fold cross validation, the database's "
919  << "size should be strictly greater than " << k_fold
920  << " but, here, the database has only " << db_size
921  << "rows");
922  }
923 
924  // create the ranges of rows of the test database
925  const std::size_t foldSize = db_size / k_fold;
926  const std::size_t unfold_deb = learning_fold * foldSize;
927  const std::size_t unfold_end = unfold_deb + foldSize;
928 
929  ranges__.clear();
930  if (learning_fold == std::size_t(0)) {
931  ranges__.push_back(
932  std::pair< std::size_t, std::size_t >(unfold_end, db_size));
933  } else {
934  ranges__.push_back(
935  std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb));
936 
937  if (learning_fold != k_fold - 1) {
938  ranges__.push_back(
939  std::pair< std::size_t, std::size_t >(unfold_end, db_size));
940  }
941  }
942 
943  return std::pair< std::size_t, std::size_t >(unfold_deb, unfold_end);
944  }
945 
946 
947  std::pair< double, double >
948  genericBNLearner::chi2(const NodeId id1,
949  const NodeId id2,
950  const std::vector< NodeId >& knowing) {
951  createApriori__();
952  gum::learning::IndepTestChi2<> chi2score(score_database__.parser(),
953  *apriori__,
954  databaseRanges());
955 
956  return chi2score.statistics(id1, id2, knowing);
957  }
958 
959  std::pair< double, double >
960  genericBNLearner::chi2(const std::string& name1,
961  const std::string& name2,
962  const std::vector< std::string >& knowing) {
963  std::vector< NodeId > knowingIds;
964  std::transform(
965  knowing.begin(),
966  knowing.end(),
967  std::back_inserter(knowingIds),
968  [this](const std::string& c) -> NodeId { return this->idFromName(c); });
969  return chi2(idFromName(name1), idFromName(name2), knowingIds);
970  }
971 
972  std::pair< double, double >
973  genericBNLearner::G2(const NodeId id1,
974  const NodeId id2,
975  const std::vector< NodeId >& knowing) {
976  createApriori__();
977  gum::learning::IndepTestG2<> g2score(score_database__.parser(),
978  *apriori__,
979  databaseRanges());
980  return g2score.statistics(id1, id2, knowing);
981  }
982 
983  std::pair< double, double >
984  genericBNLearner::G2(const std::string& name1,
985  const std::string& name2,
986  const std::vector< std::string >& knowing) {
987  std::vector< NodeId > knowingIds;
988  std::transform(
989  knowing.begin(),
990  knowing.end(),
991  std::back_inserter(knowingIds),
992  [this](const std::string& c) -> NodeId { return this->idFromName(c); });
993  return G2(idFromName(name1), idFromName(name2), knowingIds);
994  }
995 
996  double genericBNLearner::logLikelihood(const std::vector< NodeId >& vars,
997  const std::vector< NodeId >& knowing) {
998  createApriori__();
999  gum::learning::ScoreLog2Likelihood<> ll2score(score_database__.parser(),
1000  *apriori__,
1001  databaseRanges());
1002 
1003  std::vector< NodeId > total(vars);
1004  total.insert(total.end(), knowing.begin(), knowing.end());
1005  double LLtotal = ll2score.score(IdCondSet<>(total, false, true));
1006  if (knowing.size() == (Size)0) {
1007  return LLtotal;
1008  } else {
1009  double LLknw = ll2score.score(IdCondSet<>(knowing, false, true));
1010  return LLtotal - LLknw;
1011  }
1012  }
1013 
1014  double
1015  genericBNLearner::logLikelihood(const std::vector< std::string >& vars,
1016  const std::vector< std::string >& knowing) {
1017  std::vector< NodeId > ids;
1018  std::vector< NodeId > knowingIds;
1019 
1020  auto mapper = [this](const std::string& c) -> NodeId {
1021  return this->idFromName(c);
1022  };
1023 
1024  std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
1025  std::transform(knowing.begin(),
1026  knowing.end(),
1027  std::back_inserter(knowingIds),
1028  mapper);
1029 
1030  return logLikelihood(ids, knowingIds);
1031  }
1032 
1033  std::vector< double >
1034  genericBNLearner::rawPseudoCount(const std::vector< NodeId >& vars) {
1035  Potential< double > res;
1036 
1037  createApriori__();
1038  gum::learning::PseudoCount<> count(score_database__.parser(),
1039  *apriori__,
1040  databaseRanges());
1041  return count.get(vars);
1042  }
1043 
1044 
1045  std::vector< double >
1046  genericBNLearner::rawPseudoCount(const std::vector< std::string >& vars) {
1047  std::vector< NodeId > ids;
1048 
1049  auto mapper = [this](const std::string& c) -> NodeId {
1050  return this->idFromName(c);
1051  };
1052 
1053  std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
1054 
1055  return rawPseudoCount(ids);
1056  }
1057 
1058  } /* namespace learning */
1059 
1060 } /* namespace gum */