aGrUM  0.14.2
genericBNLearner.cpp
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}@lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
29 #include <algorithm>
30 
31 #include <agrum/agrum.h>
36 
37 // include the inlined functions if necessary
38 #ifdef GUM_NO_INLINE
40 #endif /* GUM_NO_INLINE */
41 
42 namespace gum {
43 
44  namespace learning {
45 
46 
48  __database(db) {
49  // get the variables names
50  const auto& var_names = __database.variableNames();
51  const std::size_t nb_vars = var_names.size();
52  for (auto dom : __database.domainSizes())
53  __domain_sizes.push_back(dom);
54  for (std::size_t i = 0; i < nb_vars; ++i) {
56  }
57 
58  // create the parser
59  __parser =
61  }
62 
63 
65  const std::string& filename,
66  const std::vector< std::string >& missing_symbols) :
67  Database(genericBNLearner::__readFile(filename, missing_symbols)) {}
68 
69 
71  const std::string& CSV_filename,
72  Database& score_database,
73  const std::vector< std::string >& missing_symbols) {
74  // assign to each column name in the CSV file its column
76  DBInitializerFromCSV<> initializer(CSV_filename);
77  const auto& apriori_names = initializer.variableNames();
78  std::size_t apriori_nb_vars = apriori_names.size();
79  HashTable< std::string, std::size_t > apriori_names2col(apriori_nb_vars);
80  for (std::size_t i = std::size_t(0); i < apriori_nb_vars; ++i)
81  apriori_names2col.insert(apriori_names[i], i);
82 
83  // check that there are at least as many variables in the a priori
84  // database as those in the score_database
85  if (apriori_nb_vars < score_database.__database.nbVariables()) {
87  "the a apriori database has fewer variables "
88  "than the observed database");
89  }
90 
91  // get the mapping from the columns of score_database to those of
92  // the CSV file
93  const std::vector< std::string >& score_names =
94  score_database.databaseTable().variableNames();
95  const std::size_t score_nb_vars = score_names.size();
96  HashTable< std::size_t, std::size_t > mapping(score_nb_vars);
97  for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
98  try {
99  mapping.insert(i, apriori_names2col[score_names[i]]);
100  } catch (Exception&) {
102  "Variable "
103  << score_names[i]
104  << " of the observed database does not belong to the "
105  << "apriori database");
106  }
107  }
108 
109  // create the translators for CSV database
110  for (std::size_t i = std::size_t(0); i < score_nb_vars; ++i) {
111  const Variable& var = score_database.databaseTable().variable(i);
112  __database.insertTranslator(var, mapping[i], missing_symbols);
113  }
114 
115  // fill the database
116  initializer.fillDatabase(__database);
117 
118  // get the domain sizes of the variables
119  for (auto dom : __database.domainSizes())
120  __domain_sizes.push_back(dom);
121 
122  // compute the mapping from node ids to column indices
123  __nodeId2cols = score_database.nodeId2Columns();
124 
125  // create the parser
126  __parser =
128  }
129 
130 
134  // create the parser
135  __parser =
137  }
138 
139 
141  __database(std::move(from.__database)),
142  __domain_sizes(std::move(from.__domain_sizes)),
143  __nodeId2cols(std::move(from.__nodeId2cols)) {
144  // create the parser
145  __parser =
147  }
148 
149 
151 
153  operator=(const Database& from) {
154  if (this != &from) {
155  delete __parser;
156  __database = from.__database;
159 
160  // create the parser
161  __parser =
163  }
164 
165  return *this;
166  }
167 
170  if (this != &from) {
171  delete __parser;
172  __database = std::move(from.__database);
173  __domain_sizes = std::move(from.__domain_sizes);
174  __nodeId2cols = std::move(from.__nodeId2cols);
175 
176  // create the parser
177  __parser =
179  }
180 
181  return *this;
182  }
183 
184 
185  // ===========================================================================
186 
188  const std::string& filename,
189  const std::vector< std::string >& missing_symbols) :
190  __score_database(filename, missing_symbols) {
192 
193  // for debugging purposes
194  GUM_CONSTRUCTOR(genericBNLearner);
195  }
196 
197 
199  __score_database(db) {
201 
202  // for debugging purposes
203  GUM_CONSTRUCTOR(genericBNLearner);
204  }
205 
206 
225 
226  // for debugging purposes
227  GUM_CONS_CPY(genericBNLearner);
228  }
229 
240  __selected_algo(from.__selected_algo), __K2(std::move(from.__K2)),
241  __miic_3off2(std::move(from.__miic_3off2)),
245  std::move(from.__local_search_with_tabu_list)),
247  __ranges(std::move(from.__ranges)),
249  __initial_dag(std::move(from.__initial_dag)) {
251 
252  // for debugging purposes
253  GUM_CONS_MOV(genericBNLearner);
254  }
255 
257  if (__score) delete __score;
258 
259  if (__apriori) delete __apriori;
260 
261  if (__no_apriori) delete __no_apriori;
262 
264 
265  if (__mutual_info) delete __mutual_info;
266 
267  GUM_DESTRUCTOR(genericBNLearner);
268  }
269 
271  if (this != &from) {
272  if (__score) {
273  delete __score;
274  __score = nullptr;
275  }
276 
277  if (__apriori) {
278  delete __apriori;
279  __apriori = nullptr;
280  }
281 
282  if (__apriori_database) {
283  delete __apriori_database;
284  __apriori_database = nullptr;
285  }
286 
287  if (__mutual_info) {
288  delete __mutual_info;
289  __mutual_info = nullptr;
290  }
291 
292  __score_type = from.__score_type;
294  __EMepsilon = from.__EMepsilon;
303  __K2 = from.__K2;
304  __miic_3off2 = from.__miic_3off2;
309  __ranges = from.__ranges;
312  __current_algorithm = nullptr;
313  }
314 
315  return *this;
316  }
317 
319  if (this != &from) {
320  if (__score) {
321  delete __score;
322  __score = nullptr;
323  }
324 
325  if (__apriori) {
326  delete __apriori;
327  __apriori = nullptr;
328  }
329 
330  if (__apriori_database) {
331  delete __apriori_database;
332  __apriori_database = nullptr;
333  }
334 
335  if (__mutual_info) {
336  delete __mutual_info;
337  __mutual_info = nullptr;
338  }
339 
340  __score_type = from.__score_type;
341  __param_estimator_type = from.__param_estimator_type;
342  __EMepsilon = from.__EMepsilon;
343  __apriori_type = from.__apriori_type;
344  __apriori_weight = from.__apriori_weight;
345  __constraint_SliceOrder = std::move(from.__constraint_SliceOrder);
346  __constraint_Indegree = std::move(from.__constraint_Indegree);
347  __constraint_TabuList = std::move(from.__constraint_TabuList);
348  __constraint_ForbiddenArcs = std::move(from.__constraint_ForbiddenArcs);
349  __constraint_MandatoryArcs = std::move(from.__constraint_MandatoryArcs);
350  __selected_algo = from.__selected_algo;
351  __K2 = from.__K2;
352  __miic_3off2 = std::move(from.__miic_3off2);
353  __3off2_kmode = from.__3off2_kmode;
354  __greedy_hill_climbing = std::move(from.__greedy_hill_climbing);
356  std::move(from.__local_search_with_tabu_list);
357  __score_database = std::move(from.__score_database);
358  __ranges = std::move(from.__ranges);
359  __apriori_dbname = std::move(from.__apriori_dbname);
360  __initial_dag = std::move(from.__initial_dag);
361  __current_algorithm = nullptr;
362  }
363 
364  return *this;
365  }
366 
367 
368  DatabaseTable<> readFile(const std::string& filename) {
369  // get the extension of the file
370  Size filename_size = Size(filename.size());
371 
372  if (filename_size < 4) {
374  "genericBNLearner could not determine the "
375  "file type of the database");
376  }
377 
378  std::string extension = filename.substr(filename.size() - 4);
379  std::transform(
380  extension.begin(), extension.end(), extension.begin(), ::tolower);
381 
382  if (extension != ".csv") {
384  "genericBNLearner does not support yet this type "
385  "of database file");
386  }
387 
388  DBInitializerFromCSV<> initializer(filename);
389 
390  const auto& var_names = initializer.variableNames();
391  const std::size_t nb_vars = var_names.size();
392 
393  DBTranslatorSet<> translator_set;
395  for (std::size_t i = 0; i < nb_vars; ++i) {
396  translator_set.insertTranslator(translator, i);
397  }
398 
399  DatabaseTable<> database(translator_set);
400  database.setVariableNames(initializer.variableNames());
401  initializer.fillDatabase(database);
402 
403  return database;
404  }
405 
406 
407  void genericBNLearner::__checkFileName(const std::string& filename) {
408  // get the extension of the file
409  Size filename_size = Size(filename.size());
410 
411  if (filename_size < 4) {
413  "genericBNLearner could not determine the "
414  "file type of the database");
415  }
416 
417  std::string extension = filename.substr(filename.size() - 4);
418  std::transform(
419  extension.begin(), extension.end(), extension.begin(), ::tolower);
420 
421  if (extension != ".csv") {
422  GUM_ERROR(
424  "genericBNLearner does not support yet this type of database file");
425  }
426  }
427 
428 
430  const std::string& filename,
431  const std::vector< std::string >& missing_symbols) {
432  // get the extension of the file
433  __checkFileName(filename);
434 
435  DBInitializerFromCSV<> initializer(filename);
436 
437  const auto& var_names = initializer.variableNames();
438  const std::size_t nb_vars = var_names.size();
439 
440  DBTranslatorSet<> translator_set;
441  DBTranslator4LabelizedVariable<> translator(missing_symbols);
442  for (std::size_t i = 0; i < nb_vars; ++i) {
443  translator_set.insertTranslator(translator, i);
444  }
445 
446  DatabaseTable<> database(missing_symbols, translator_set);
447  database.setVariableNames(initializer.variableNames());
448  initializer.fillDatabase(database);
449 
450  database.reorder();
451 
452  return database;
453  }
454 
455 
457  // first, save the old apriori, to be deleted if everything is ok
458  Apriori<>* old_apriori = __apriori;
459 
460  // create the new apriori
461  switch (__apriori_type) {
465  break;
466 
470  break;
471 
473  if (__apriori_database != nullptr) {
474  delete __apriori_database;
475  __apriori_database = nullptr;
476  }
477 
481 
486  break;
487 
488  case AprioriType::BDEU:
489  __apriori = new AprioriBDeu<>(__score_database.databaseTable(),
491  break;
492 
493  default:
495  "The BNLearner does not support yet this apriori");
496  }
497 
498  // do not forget to assign a weight to the apriori
500 
501  // remove the old apriori, if any
502  if (old_apriori != nullptr) delete old_apriori;
503  }
504 
506  // first, save the old score, to be deleted if everything is ok
507  Score<>* old_score = __score;
508 
509  // create the new scoring function
510  switch (__score_type) {
511  case ScoreType::AIC:
513  *__apriori,
514  __ranges,
516  break;
517 
518  case ScoreType::BD:
520  *__apriori,
521  __ranges,
523  break;
524 
525  case ScoreType::BDeu:
527  *__apriori,
528  __ranges,
530  break;
531 
532  case ScoreType::BIC:
534  *__apriori,
535  __ranges,
537  break;
538 
539  case ScoreType::K2:
541  *__apriori,
542  __ranges,
544  break;
545 
548  *__apriori,
549  __ranges,
551  break;
552 
553  default:
555  "genericBNLearner does not support yet this score");
556  }
557 
558  // remove the old score, if any
559  if (old_score != nullptr) delete old_score;
560  }
561 
564  bool take_into_account_score) {
565  ParamEstimator<>* param_estimator = nullptr;
566 
567  // create the new estimator
568  switch (__param_estimator_type) {
570  if (take_into_account_score && (__score != nullptr)) {
571  param_estimator =
572  new ParamEstimatorML<>(parser,
573  *__apriori,
575  __ranges,
577  } else {
578  param_estimator =
579  new ParamEstimatorML<>(parser,
580  *__apriori,
581  *__no_apriori,
582  __ranges,
584  }
585 
586  break;
587 
588  default:
590  "genericBNLearner does not support "
591  << "yet this parameter estimator");
592  }
593 
594  // assign the set of ranges
595  param_estimator->setRanges(__ranges);
596 
597  return param_estimator;
598  }
599 
602  // Initialize the mixed graph to the fully connected graph
603  MixedGraph mgraph;
604  for (Size i = 0; i < __score_database.databaseTable().nbVariables(); ++i) {
605  mgraph.addNodeWithId(i);
606  for (Size j = 0; j < i; ++j) {
607  mgraph.addEdge(j, i);
608  }
609  }
610 
611  // translating the constraints for 3off2 or miic
612  HashTable< std::pair< NodeId, NodeId >, char > initial_marks;
613  const ArcSet& mandatory_arcs = __constraint_MandatoryArcs.arcs();
614  for (const auto& arc : mandatory_arcs) {
615  initial_marks.insert({arc.tail(), arc.head()}, '>');
616  }
617 
618  const ArcSet& forbidden_arcs = __constraint_ForbiddenArcs.arcs();
619  for (const auto& arc : forbidden_arcs) {
620  initial_marks.insert({arc.tail(), arc.head()}, '-');
621  }
622  __miic_3off2.addConstraints(initial_marks);
623 
624  // create the mutual entropy object
625  // if (__mutual_info == nullptr) { this->useNML(); }
627 
628  return mgraph;
629  }
630 
633  GUM_ERROR(OperationNotAllowed, "Must be using the miic/3off2 algorithm");
634  }
635  // check that the database does not contain any missing value
638  "For the moment, the BNLearner is unable to learn "
639  << "structures with missing values in databases");
640  }
641  BNLearnerListener listener(this, __miic_3off2);
642 
643  // create the mixed graph
644  MixedGraph mgraph = this->__prepare_miic_3off2();
645 
647  }
648 
650  // create the score and the apriori
651  __createApriori();
652  __createScore();
653 
654  return __learnDAG();
655  }
656 
658  if (__mutual_info != nullptr) delete __mutual_info;
659 
660  __mutual_info =
662  *__no_apriori,
663  __ranges,
665  switch (__3off2_kmode) {
668  break;
669 
672  break;
673 
676  break;
677 
678  default:
680  "The BNLearner's corrected mutual information class does "
681  << "not support yet penalty mode " << int(__3off2_kmode));
682  }
683  }
684 
686  // check that the database does not contain any missing value
688  || ((__apriori_database != nullptr)
692  "For the moment, the BNLearner is unable to cope "
693  "with missing values in databases");
694  }
695  // add the mandatory arcs to the initial dag and remove the forbidden ones
696  // from the initial graph
697  DAG init_graph = __initial_dag;
698 
699  const ArcSet& mandatory_arcs = __constraint_MandatoryArcs.arcs();
700 
701  for (const auto& arc : mandatory_arcs) {
702  if (!init_graph.exists(arc.tail())) init_graph.addNodeWithId(arc.tail());
703 
704  if (!init_graph.exists(arc.head())) init_graph.addNodeWithId(arc.head());
705 
706  init_graph.addArc(arc.tail(), arc.head());
707  }
708 
709  const ArcSet& forbidden_arcs = __constraint_ForbiddenArcs.arcs();
710 
711  for (const auto& arc : forbidden_arcs) {
712  init_graph.eraseArc(arc);
713  }
714 
715  switch (__selected_algo) {
716  // ========================================================================
718  BNLearnerListener listener(this, __miic_3off2);
719  // create the mixedGraph and the corrected mutual information
720  MixedGraph mgraph = this->__prepare_miic_3off2();
721 
722  return __miic_3off2.learnStructure(*__mutual_info, mgraph);
723  }
724 
725  // ========================================================================
731  gen_constraint;
732  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint) =
734  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint) =
736  static_cast< StructuralConstraintSliceOrder& >(gen_constraint) =
738 
740  gen_constraint);
741 
744  sel_constraint;
745  static_cast< StructuralConstraintIndegree& >(sel_constraint) =
747 
748  GraphChangesSelector4DiGraph< decltype(sel_constraint),
749  decltype(op_set) >
750  selector(*__score, sel_constraint, op_set);
751 
752  return __greedy_hill_climbing.learnStructure(selector, init_graph);
753  }
754 
755  // ========================================================================
761  gen_constraint;
762  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint) =
764  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint) =
766  static_cast< StructuralConstraintSliceOrder& >(gen_constraint) =
768 
770  gen_constraint);
771 
775  sel_constraint;
776  static_cast< StructuralConstraintTabuList& >(sel_constraint) =
778  static_cast< StructuralConstraintIndegree& >(sel_constraint) =
780 
781  GraphChangesSelector4DiGraph< decltype(sel_constraint),
782  decltype(op_set) >
783  selector(*__score, sel_constraint, op_set);
784 
786  init_graph);
787  }
788 
789  // ========================================================================
790  case AlgoType::K2: {
791  BNLearnerListener listener(this, __K2.approximationScheme());
794  gen_constraint;
795  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint) =
797  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint) =
799 
801  gen_constraint);
802 
803  // if some mandatory arcs are incompatible with the order, use a DAG
804  // constraint instead of a DiGraph constraint to avoid cycles
805  const ArcSet& mandatory_arcs =
806  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
807  .arcs();
808  const Sequence< NodeId >& order = __K2.order();
809  bool order_compatible = true;
810 
811  for (const auto& arc : mandatory_arcs) {
812  if (order.pos(arc.tail()) >= order.pos(arc.head())) {
813  order_compatible = false;
814  break;
815  }
816  }
817 
818  if (order_compatible) {
821  sel_constraint;
822  static_cast< StructuralConstraintIndegree& >(sel_constraint) =
824 
825  GraphChangesSelector4DiGraph< decltype(sel_constraint),
826  decltype(op_set) >
827  selector(*__score, sel_constraint, op_set);
828 
829  return __K2.learnStructure(selector, init_graph);
830  } else {
833  sel_constraint;
834  static_cast< StructuralConstraintIndegree& >(sel_constraint) =
836 
837  GraphChangesSelector4DiGraph< decltype(sel_constraint),
838  decltype(op_set) >
839  selector(*__score, sel_constraint, op_set);
840 
841  return __K2.learnStructure(selector, init_graph);
842  }
843  }
844 
845  // ========================================================================
846  default:
848  "the learnDAG method has not been implemented for this "
849  "learning algorithm");
850  }
851  }
852 
854  const std::string& apriori = __getAprioriType();
855 
856  switch (__score_type) {
857  case ScoreType::AIC:
859 
860  case ScoreType::BD:
862 
863  case ScoreType::BDeu:
865 
866  case ScoreType::BIC:
868 
869  case ScoreType::K2:
871 
875 
876  default: return "genericBNLearner does not support yet this score";
877  }
878  }
879 
880 
882  std::pair< std::size_t, std::size_t >
883  genericBNLearner::useCrossValidationFold(const std::size_t learning_fold,
884  const std::size_t k_fold) {
885  if (k_fold == 0) {
886  GUM_ERROR(OutOfBounds, "K-fold cross validation with k=0 is forbidden");
887  }
888 
889  if (learning_fold >= k_fold) {
891  "In " << k_fold << "-fold cross validation, the learning "
892  << "fold should be strictly lower than " << k_fold
893  << " but, here, it is equal to " << learning_fold);
894  }
895 
896  const std::size_t db_size = __score_database.databaseTable().nbRows();
897  if (k_fold >= db_size) {
899  "In " << k_fold << "-fold cross validation, the database's "
900  << "size should be strictly greater than " << k_fold
901  << " but, here, the database has only " << db_size
902  << "rows");
903  }
904 
905  // create the ranges of rows of the test database
906  const std::size_t foldSize = db_size / k_fold;
907  const std::size_t unfold_deb = learning_fold * foldSize;
908  const std::size_t unfold_end = unfold_deb + foldSize;
909 
910  __ranges.clear();
911  if (learning_fold == std::size_t(0)) {
912  __ranges.push_back(
913  std::pair< std::size_t, std::size_t >(unfold_end, db_size));
914  } else {
915  __ranges.push_back(
916  std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb));
917 
918  if (learning_fold != k_fold - 1) {
919  __ranges.push_back(
920  std::pair< std::size_t, std::size_t >(unfold_end, db_size));
921  }
922  }
923 
924  return std::pair< std::size_t, std::size_t >(unfold_deb, unfold_end);
925  }
926 
927 
928  std::pair< double, double > genericBNLearner::chi2(
929  const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing) {
930  __createApriori();
934  parser, *__apriori, databaseRanges());
935 
936  return chi2score.statistics(id1, id2, knowing);
937  }
938 
939  std::pair< double, double >
940  genericBNLearner::chi2(const std::string& name1,
941  const std::string& name2,
942  const std::vector< std::string >& knowing) {
943  std::vector< NodeId > knowingIds;
944  std::transform(
945  knowing.begin(),
946  knowing.end(),
947  std::back_inserter(knowingIds),
948  [this](const std::string& c) -> NodeId { return this->idFromName(c); });
949  return chi2(idFromName(name1), idFromName(name2), knowingIds);
950  }
951 
952  double genericBNLearner::logLikelihood(const std::vector< NodeId >& vars,
953  const std::vector< NodeId >& knowing) {
954  __createApriori();
958  parser, *__apriori, databaseRanges());
959 
960  std::vector< NodeId > total(vars);
961  total.insert(total.end(), knowing.begin(), knowing.end());
962  double LLtotal = ll2score.score(IdSet<>(total, false, true));
963  if (knowing.size() == (Size)0) {
964  return LLtotal;
965  } else {
966  double LLknw = ll2score.score(IdSet<>(knowing, false, true));
967  return LLtotal - LLknw;
968  }
969  }
970 
971  double
972  genericBNLearner::logLikelihood(const std::vector< std::string >& vars,
973  const std::vector< std::string >& knowing) {
974  std::vector< NodeId > ids;
975  std::vector< NodeId > knowingIds;
976 
977  auto mapper = [this](const std::string& c) -> NodeId {
978  return this->idFromName(c);
979  };
980 
981  std::transform(vars.begin(), vars.end(), std::back_inserter(ids), mapper);
982  std::transform(
983  knowing.begin(), knowing.end(), std::back_inserter(knowingIds), mapper);
984 
985  return logLikelihood(ids, knowingIds);
986  }
987 
988 
989  } /* namespace learning */
990 
991 } /* namespace gum */
void useNML()
use the kNML penalty function
AlgoType __selected_algo
the selected learning algorithm
the class for structural constraints limiting the number of parents of nodes in a directed graph ...
void insert(const T1 &first, const T2 &second)
Inserts a new association in the gum::Bijection.
const std::vector< std::string, ALLOC< std::string > > & variableNames()
returns the names of the variables in the input dataset
the class for computing BDeu scores
Definition: scoreBDeu.h:56
ApproximationScheme & approximationScheme()
returns the approximation policy of the learning algorithm
Score * __score
the score used
double score(const IdSet< ALLOC > &idset)
returns the score for a given IdSet
Base class for every random variable.
Definition: variable.h:63
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
Database __score_database
the database to be used by the scores and parameter estimators
Idx pos(const Key &key) const
Returns the position of the object passed in argument (if it exists).
Definition: sequence_tpl.h:515
virtual void setWeight(const double weight)
sets the weight of the a priori (kind of effective sample size)
the structural constraint for forbidding the creation of some arcs during structure learning ...
CorrectedMutualInformation ::KModeTypes __3off2_kmode
the penalty used in 3off2
const std::string & __getAprioriType() const
returns the type (as a string) of a given apriori
double __EMepsilon
epsilon for EM. If epsilon=0.0: no EM
std::pair< std::size_t, std::size_t > useCrossValidationFold(const std::size_t learning_fold, const std::size_t k_fold)
sets the ranges of rows to be used for cross-validation learning
The class computing n times the corrected mutual information, as used in the 3off2 algorithm...
const ArcSet & arcs() const
returns the set of mandatory arcs
static void __checkFileName(const std::string &filename)
checks whether the extension of a CSV filename is correct
The base class for all the scores used for learning (BIC, BDeu, etc)
Definition: score.h:49
A class for generic framework of learning algorithms that can easily be used.
const std::vector< std::string > & missingSymbols() const
returns the set of missing symbols taken into account
void setRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
sets new ranges to perform the countings used by the parameter estimator
DBVector< std::size_t > domainSizes() const
returns the domain sizes of all the variables in the database table
the class for computing Log2-likelihood scores
MixedGraph learnMixedStructure(CorrectedMutualInformation<> &I, MixedGraph graph)
learns the structure of an Essential Graph
Definition: Miic.cpp:110
void __createScore()
create the score used for learning
The class used to pack sets of generators.
the class for computing Bayesian Dirichlet (BD) log2 scores
Definition: scoreBD.h:62
StructuralConstraintSliceOrder __constraint_SliceOrder
the constraint for 2TBNs
Database & operator=(const Database &from)
copy operator
the structural constraint indicating that some arcs shall never be removed or reversed ...
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Miic __miic_3off2
the 3off2 algorithm
the class for computing Chi2 independence test scores
Definition: indepTestChi2.h:45
virtual void addEdge(const NodeId first, const NodeId second)
insert a new edge into the undirected graph
Definition: undiGraph_inl.h:32
ParamEstimatorType __param_estimator_type
the type of the parameter estimator
void addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints)
Set an ensemble of constraints for the orientation phase.
Definition: Miic.cpp:1064
STL namespace.
the class for computing Chi2 scores
A class for storing a pair of sets of NodeIds, the second one corresponding to a conditional set...
Definition: idSet.h:45
A class that redirects gum_signal from algorithms to the listeners of BNLearn.
the base class for all a priori
Definition: apriori.h:47
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
DatabaseTable __database
the database itself
the class for computing K2 scores (actually their log2 value)
Definition: scoreK2.h:58
MixedGraph __prepare_miic_3off2()
prepares the initial graph for 3off2 or miic
std::pair< double, double > chi2(const NodeId id1, const NodeId id2, const std::vector< NodeId > &knowing={})
Return the <statistic,pvalue> pair for the BNLearner.
bool exists(const NodeId id) const
alias for existsNode
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
double logLikelihood(const std::vector< NodeId > &vars, const std::vector< NodeId > &knowing={})
Return the loglikelihood of vars in the base, conditioned by knowing for the BNLearner.
AprioriType __apriori_type
the a priori selected for the score and parameters
NodeId idFromName(const std::string &var_name) const
returns the node id corresponding to a variable name
the internal apriori for the BDeu score (N' / (r_i * q_i)). BDeu is a BD score with a N'/(r_i * q_i) apr...
Definition: aprioriBDeu.h:51
The class for generic Hash Tables.
Definition: hashTable.h:676
the class for computing Log2-likelihood scores
std::pair< double, double > statistics(NodeId var1, NodeId var2, const std::vector< NodeId, ALLOC< NodeId > > &rhs_ids={})
get the pair <chi2 statistic,pvalue> for a test var1 indep var2 given rhs_ids
CorrectedMutualInformation * __mutual_info
the selected correction for 3off2 and miic
A Dirichlet a priori: computes its N'_ijk from a database.
DAG __initial_dag
an initial DAG given to learners
const Sequence< NodeId > & order() const noexcept
returns the current order
The mechanism to compute the next available graph changes for directed structure learning search algor...
StructuralConstraintMandatoryArcs __constraint_MandatoryArcs
the constraint on mandatory arcs
std::size_t nbVariables() const noexcept
returns the number of variables (columns) of the database
genericBNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols)
default constructor
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
Database * __apriori_database
the database used by the Dirichlet a priori
LocalSearchWithTabuList __local_search_with_tabu_list
the local search with tabu list algorithm
DatabaseTable readFile(const std::string &filename)
const ApproximationScheme * __current_algorithm
bool hasMissingValues() const
indicates whether the database contains some missing values
the "meta-programming" class for storing structural constraints. In aGrUM, there are two ways to store ...
ParamEstimator * __createParamEstimator(DBRowGeneratorParser<> &parser, bool take_into_account_score=true)
create the parameter estimator used for learning
void __createCorrectedMutualInformation()
create the Corrected Mutual Information instance for Miic/3off2
const ArcSet & arcs() const
returns the set of mandatory arcs
Apriori * __apriori
the apriori used
StructuralConstraintTabuList __constraint_TabuList
the constraint for tabu lists
std::string __apriori_dbname
the filename for the Dirichlet a priori, if any
void fillDatabase(DATABASE< ALLOC > &database, const bool retry_insertion=false)
fills the rows of the database table
std::size_t insertTranslator(const Translator< ALLOC > &translator, const std::size_t column, const bool unique_column=true)
inserts a new translator at the end of the translator set
DAG __learnDAG()
returns the DAG learnt
GreedyHillClimbing __greedy_hill_climbing
the greedy hill climbing algorithm
std::size_t nbRows() const noexcept
returns the number of records (rows) in the database
Base class for all aGrUM's exceptions.
Definition: exceptions.h:103
genericBNLearner & operator=(const genericBNLearner &)
copy operator
const DatabaseTable & databaseTable() const
returns the internal database table
The basic class for computing the next graph changes possible in a structure learning algorithm...
std::string checkScoreAprioriCompatibility()
checks whether the current score and apriori are compatible
the class for computing AIC scores
Definition: scoreAIC.h:49
the class for computing BIC scores
Definition: scoreBIC.h:49
virtual const Apriori< ALLOC > & internalApriori() const =0
returns the internal apriori of the score
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:40
DAG learnDAG()
learn a structure from a file (must have read the db before)
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
std::vector< std::pair< std::size_t, std::size_t > > __ranges
the set of rows' ranges within the database in which learning is done
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
const Variable & variable(const std::size_t k, const bool k_is_input_col=false) const
returns either the kth variable of the database table or the first one corresponding to the kth colum...
A listener that allows BNLearner to be used as a proxy for its inner algorithms.
const Bijection< NodeId, std::size_t > & nodeId2Columns() const
returns the mapping between node ids and their columns in the database
The class representing a tabular database as used by learning tasks.
MixedGraph learnMixedStructure()
learn a partial structure from a file (must have read the db before and must have selected miic or 3o...
StructuralConstraintIndegree __constraint_Indegree
the constraint for indegrees
ScoreType __score_type
the score selected for learning
const std::vector< std::pair< std::size_t, std::size_t > > & databaseRanges() const
returns the current database rows' ranges used for learning
std::size_t insertTranslator(const DBTranslator< ALLOC > &translator, const std::size_t input_column, const bool unique_column=true)
insert a new translator into the database table
virtual ~genericBNLearner()
destructor
DBRowGeneratorParser * __parser
the parser used for reading the database
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
A pack of learning algorithms that can easily be used.
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
Definition: K2_tpl.h:38
static DatabaseTable __readFile(const std::string &filename, const std::vector< std::string > &missing_symbols)
reads a file and returns a DatabaseTable
DBRowGeneratorParser & parser()
returns the parser for the database
virtual void setVariableNames(const std::vector< std::string, ALLOC< std::string > > &names, const bool from_external_object=true) final
sets the names of the variables
double __apriori_weight
the weight of the apriori
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
void __createApriori()
create the apriori used for learning
The class imposing a N-sized tabu list as a structural constraints for learning algorithms.
The class for initializing DatabaseTable and RawDatabaseTable instances from CSV files.
the smooth a priori: adds a weight w to all the counts
A pack of learning algorithms that can easily be used.
std::vector< std::size_t > __domain_sizes
the domain sizes of the variables (useful to speed-up computations)
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
Bijection< NodeId, std::size_t > __nodeId2cols
a bijection assigning to each variable name its NodeId
Database(const std::string &file, const std::vector< std::string > &missing_symbols)
default constructor
the class for packing together the translators used to preprocess the datasets
The databases' cell translators for labelized variables.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:45
const DatabaseTable & database() const
returns the database used by the BNLearner
StructuralConstraintForbiddenArcs __constraint_ForbiddenArcs
the constraint on forbidden arcs
void useMDL()
use the MDL penalty function
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
a helper to easily read databases
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
The base class for estimating parameters of CPTs.
The class for estimating parameters of CPTs using Maximum Likelihood.
const DBVector< std::string > & variableNames() const noexcept
returns the variable names for all the columns of the database
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
The basic class for computing the next graph changes possible in a structure learning algorithm...
Base class for dag.
Definition: DAG.h:99
The base class for structural constraints used by learning algorithms that learn a directed graph str...
Size NodeId
Type for node ids.
Definition: graphElements.h:97
void reorder(const std::size_t k, const bool k_is_input_col=false)
performs a reordering of the kth translator or of the first translator parsing the kth column of the ...
the no a priori class: corresponds to 0 weight-sample
iterator handler() const
returns a new unsafe handler pointing to the 1st record of the database
DAG learnStructure(CorrectedMutualInformation<> &I, MixedGraph graph)
learns the structure of a Bayesian network, i.e. a DAG, by first learning an Essential graph and then ...
Definition: Miic.cpp:984
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52
void useNoCorr()
use no correction/penalty function
The base class for structural constraints imposed by DAGs.
Base class for mixed graphs.
Definition: mixedGraph.h:124
the structural constraint imposing a partial order over nodes