aGrUM  0.13.3
genericBNLearner.cpp
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}@lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
29 #include <algorithm>
30 
31 #include <agrum/agrum.h>
34 
35 // include the inlined functions if necessary
36 #ifdef GUM_NO_INLINE
38 #endif /* GUM_NO_INLINE */
39 
40 namespace gum {
41 
42  namespace learning {
43 
44 
46  __database(db) {
47  // get the variables names
48  const auto& var_names = __database.variableNames();
49  const std::size_t nb_vars = var_names.size();
50  __modalities.resize(nb_vars);
51  const auto domainSizes = __database.domainSizes();
52  for (std::size_t i = 0; i < nb_vars; ++i) {
53  __name2nodeId.insert(var_names[i], NodeId(i));
54  __modalities[i] = Size(domainSizes[i]);
55  }
56 
57  // create the parser
58  __parser =
60  }
61 
62 
64  const std::string& filename,
65  const std::vector< std::string >& missing_symbols) :
66  Database(genericBNLearner::__readFile(filename, missing_symbols)) {}
67 
68 
69  /*
70  genericBNLearner::Database::Database(
71  std::string filename,
72  const NodeProperty< Sequence< std::string > >& modalities,
73  bool check_database)
74  : __database(genericBNLearner::__readFile(filename)) {
75 
76  // #### TODO: change the domain sizes of the variables
77 
78  // get the variables names
79  const auto& var_names = __database.variableNames ();
80  const std::size_t nb_vars = var_names.size ();
81  for ( std::size_t i = 0; i < nb_vars; ++i )
82  __name2nodeId.insert ( var_names[i], i );
83 
84  // get the domain sizes of the variables
85  __modalities.resize ( nb_vars );
86  for ( std::size_t i = 0; i < nb_vars; ++i ) {
87  const DiscreteVariable& var =
88  static_cast<const DiscreteVariable&> __database.variable ( i );
89  __modalities[i] = var.domainSize ();
90  }
91 
92  // create the parser
93  __parser = new DBRowGeneratorParser<> ( __database.handler (),
94  DBRowGeneratorSet<> () );
95  }
96  */
97 
98 
100  const std::string& filename,
101  Database& apriori_database,
102  const std::vector< std::string >& missing_symbols) :
103  __database(genericBNLearner::__readFile(filename, missing_symbols)) {
104  // check that there are at least as many variables in the a priori
105  // database as those in the score_database
106  if (__database.nbVariables() < apriori_database.__database.nbVariables()) {
108  "the a priori seems to have fewer variables "
109  "than the observed database");
110  }
111 
112  const std::vector< std::string >& apriori_vars =
113  apriori_database.__database.variableNames();
114  const std::vector< std::string >& score_vars = __database.variableNames();
115 
116  Size size = Size(apriori_vars.size());
117  for (Idx i = 0; i < size; ++i) {
118  if (apriori_vars[i] != score_vars[i]) {
120  "some a priori variables do not match "
121  "their counterpart in the score database");
122  }
123  }
124 
125  /*
126  ##### TODO: see what is the point of passing in argument score_database
127 
128  __raw_translators = score_database.__raw_translators;
129  auto raw_filter =
130  make_DB_row_filter(__database, __raw_translators, __generators);
131  __raw_translators = raw_filter.translatorSet();
132  score_database.__raw_translators = raw_filter.translatorSet();
133  */
134  }
135 
136 
140  // create the parser
141  __parser =
143  }
144 
145 
147  __database(std::move(from.__database)),
148  __modalities(std::move(from.__modalities)),
149  __name2nodeId(std::move(from.__name2nodeId)) {
150  // create the parser
151  __parser =
153  }
154 
155 
157 
159  operator=(const Database& from) {
160  if (this != &from) {
161  delete __parser;
162  __database = from.__database;
163  __modalities = from.__modalities;
165 
166  // create the parser
167  __parser =
169  }
170 
171  return *this;
172  }
173 
176  if (this != &from) {
177  delete __parser;
178  __database = std::move(from.__database);
179  __modalities = std::move(from.__modalities);
180  __name2nodeId = std::move(from.__name2nodeId);
181 
182  // create the parser
183  __parser =
185  }
186 
187  return *this;
188  }
189 
190 
191  // ===========================================================================
192 
194  const std::string& filename,
195  const std::vector< std::string >& missing_symbols) :
196  __score_database(filename, missing_symbols) {
197  // for debugging purposes
198  GUM_CONSTRUCTOR(genericBNLearner);
199  }
200 
201 
203  __score_database(db) {
204  // for debugging purposes
205  GUM_CONSTRUCTOR(genericBNLearner);
206  }
207 
208 
209  /*
210 
211  genericBNLearner::genericBNLearner(
212  const std::string& filename,
213  const NodeProperty< Sequence< std::string > >& modalities,
214  bool parse_database)
215  : __score_database(filename, modalities, parse_database)
216  , __user_modalities(modalities)
217  , __modalities_parse_db(parse_database) {
218  // for debugging purposes
219  GUM_CONSTRUCTOR(genericBNLearner);
220  }
221 
222  */
223 
224 
244  // for debugging purposes
245  GUM_CONS_CPY(genericBNLearner);
246  }
247 
258  __selected_algo(from.__selected_algo), __K2(std::move(from.__K2)),
259  __miic_3off2(std::move(from.__miic_3off2)),
262  std::move(from.__local_search_with_tabu_list)),
267  __initial_dag(std::move(from.__initial_dag)) {
268  // for debugging purposes
269  GUM_CONS_MOV(genericBNLearner);
270  }
271 
273  if (__score) delete __score;
274 
276 
277  if (__apriori) delete __apriori;
278 
280 
281  if (__mutual_info) delete __mutual_info;
282 
283  GUM_DESTRUCTOR(genericBNLearner);
284  }
285 
287  if (this != &from) {
288  if (__score) {
289  delete __score;
290  __score = nullptr;
291  }
292 
293  if (__param_estimator) {
294  delete __param_estimator;
295  __param_estimator = nullptr;
296  }
297 
298  if (__apriori) {
299  delete __apriori;
300  __apriori = nullptr;
301  }
302 
303  if (__apriori_database) {
304  delete __apriori_database;
305  __apriori_database = nullptr;
306  }
307 
308  if (__mutual_info) {
309  delete __mutual_info;
310  __mutual_info = nullptr;
311  }
312 
313  __score_type = from.__score_type;
323  __K2 = from.__K2;
324  __miic_3off2 = from.__miic_3off2;
332  __current_algorithm = nullptr;
333  }
334 
335  return *this;
336  }
337 
339  if (this != &from) {
340  if (__score) {
341  delete __score;
342  __score = nullptr;
343  }
344 
345  if (__param_estimator) {
346  delete __param_estimator;
347  __param_estimator = nullptr;
348  }
349 
350  if (__apriori) {
351  delete __apriori;
352  __apriori = nullptr;
353  }
354 
355  if (__apriori_database) {
356  delete __apriori_database;
357  __apriori_database = nullptr;
358  }
359 
360  if (__mutual_info) {
361  delete __mutual_info;
362  __mutual_info = nullptr;
363  }
364 
365  __score_type = from.__score_type;
366  __param_estimator_type = from.__param_estimator_type;
367  __apriori_type = from.__apriori_type;
368  __apriori_weight = from.__apriori_weight;
369  __constraint_SliceOrder = std::move(from.__constraint_SliceOrder);
370  __constraint_Indegree = std::move(from.__constraint_Indegree);
371  __constraint_TabuList = std::move(from.__constraint_TabuList);
372  __constraint_ForbiddenArcs = std::move(from.__constraint_ForbiddenArcs);
373  __constraint_MandatoryArcs = std::move(from.__constraint_MandatoryArcs);
374  __selected_algo = from.__selected_algo;
375  __K2 = from.__K2;
376  __miic_3off2 = std::move(from.__miic_3off2);
377  __greedy_hill_climbing = std::move(from.__greedy_hill_climbing);
379  std::move(from.__local_search_with_tabu_list);
380  __score_database = std::move(from.__score_database);
381  __user_modalities = std::move(from.__user_modalities);
382  __modalities_parse_db = from.__modalities_parse_db;
383  __apriori_dbname = std::move(from.__apriori_dbname);
384  __initial_dag = std::move(from.__initial_dag);
385  __current_algorithm = nullptr;
386  }
387 
388  return *this;
389  }
390 
391 
// Loads a CSV file into a DatabaseTable<>. As written here the function is
// defined without a class qualifier, i.e. it is not a genericBNLearner member.
//
// Steps visible in this listing: check that the file name ends with ".csv"
// (compared after lower-casing the last four characters), construct a
// DBInitializerFromCSV<> on the file name, insert the same `translator`
// object once per column into a DBTranslatorSet<>, build the table from that
// set, copy the variable names from the initializer and fill the rows.
//
// Contrary to genericBNLearner::__readFile further down, this function takes
// no missing-symbol list, does not test hasMissingValues() and does not call
// reorder().
//
// @param filename name of the database file; only its last four characters
//        are inspected before it is handed to the CSV initializer
// @return the filled DatabaseTable<>
//
// NOTE(review): this is a rendered listing in which source lines 397, 407 and
// 418 are absent (see the gaps in the gutter numbers). They presumably open
// the two error calls and declare `translator` -- confirm against the
// original file.
392  DatabaseTable<> readFile(const std::string& filename) {
393  // get the extension of the file
394  Size filename_size = Size(filename.size());
395 
     // a name shorter than 4 characters cannot hold a ".csv" suffix; rejecting
     // it here is presumably also what keeps `filename.size() - 4` below from
     // wrapping around (assumes the omitted error call throws -- TODO confirm)
396  if (filename_size < 4) {
398  "genericBNLearner could not determine the "
399  "file type of the database");
400  }
401 
     // lower-case the 4-character suffix so that ".CSV", ".Csv", ... match
     // NOTE(review): ::tolower is applied to plain char values; the C library
     // requires the argument to be representable as unsigned char, so
     // non-ASCII bytes on signed-char platforms are a potential problem.
402  std::string extension = filename.substr(filename.size() - 4);
403  std::transform(
404  extension.begin(), extension.end(), extension.begin(), ::tolower);
405 
406  if (extension != ".csv") {
408  "genericBNLearner does not support yet this type "
409  "of database file");
410  }
411 
412  DBInitializerFromCSV<> initializer(filename);
413 
414  const auto& var_names = initializer.variableNames();
415  const std::size_t nb_vars = var_names.size();
416 
     // one translator per column: column i of the file gets a copy of
     // `translator` (declared on the line missing from this listing)
417  DBTranslatorSet<> translator_set;
419  for (std::size_t i = 0; i < nb_vars; ++i) {
420  translator_set.insertTranslator(translator, i);
421  }
422 
     // build the table, name its columns after the CSV variables, then let the
     // initializer fill the rows
423  DatabaseTable<> database(translator_set);
424  database.setVariableNames(initializer.variableNames());
425  initializer.fillDatabase(database);
426 
427  return database;
428  }
429 
430 
// Checks that `filename` carries a ".csv" extension (case-insensitive).
//
// The function returns nothing: it only reports a problem through the error
// macro, first when the name has fewer than 4 characters, then when its last
// four characters, once lower-cased, differ from ".csv".
//
// @param filename the database file name to validate; the file itself is not
//        opened here
//
// NOTE(review): source lines 436 and 447 are absent from this rendered
// listing (see the gutter numbers). They presumably hold the opening of the
// first GUM_ERROR call and the exception type of the second -- confirm
// against the original file.
431  void genericBNLearner::__checkFileName(const std::string& filename) {
432  // get the extension of the file
433  Size filename_size = Size(filename.size());
434 
     // too short to contain a 4-character extension; this test also precedes
     // the `filename.size() - 4` computation below, which would wrap around
     // for such names
435  if (filename_size < 4) {
437  "genericBNLearner could not determine the "
438  "file type of the database");
439  }
440 
     // lower-case the 4-character suffix so the comparison ignores case
     // NOTE(review): ::tolower is applied to plain char values; the C library
     // requires the argument to be representable as unsigned char.
441  std::string extension = filename.substr(filename.size() - 4);
442  std::transform(
443  extension.begin(), extension.end(), extension.begin(), ::tolower);
444 
     // only CSV databases are accepted
445  if (extension != ".csv") {
446  GUM_ERROR(
448  "genericBNLearner does not support yet this type of database file");
449  }
450  }
451 
452 
454  const std::string& filename,
455  const std::vector< std::string >& missing_symbols) {
456  // get the extension of the file
457  Size filename_size = Size(filename.size());
458 
459  if (filename_size < 4) {
461  "genericBNLearner could not determine the "
462  "file type of the database");
463  }
464 
465  std::string extension = filename.substr(filename.size() - 4);
466  std::transform(
467  extension.begin(), extension.end(), extension.begin(), ::tolower);
468 
469  if (extension != ".csv") {
470  GUM_ERROR(
472  "genericBNLearner does not support yet this type of database file");
473  }
474 
475 
476  DBInitializerFromCSV<> initializer(filename);
477 
478  const auto& var_names = initializer.variableNames();
479  const std::size_t nb_vars = var_names.size();
480 
481  DBTranslatorSet<> translator_set;
482  DBTranslator4LabelizedVariable<> translator(missing_symbols);
483  for (std::size_t i = 0; i < nb_vars; ++i) {
484  translator_set.insertTranslator(translator, i);
485  }
486 
487  DatabaseTable<> database(missing_symbols, translator_set);
488  database.setVariableNames(initializer.variableNames());
489  initializer.fillDatabase(database);
490 
491  // check that the database does not contain any missing value
492  if (database.hasMissingValues())
494  "For the moment, the BNLearaner is unable to cope "
495  "with missing values in databases");
496 
497  database.reorder();
498 
499  return database;
500  }
501 
502 
504  // first, save the old apriori, to be deleted if everything is ok
505  Apriori<>* old_apriori = __apriori;
506 
507  // create the new apriori
508  switch (__apriori_type) {
510 
512 
514  if (__apriori_database != nullptr) {
515  delete __apriori_database;
516  __apriori_database = nullptr;
517  }
518 
519  if (__user_modalities.empty()) {
523  } else {
524  GUM_ERROR(OperationNotAllowed, "not implemented");
525  //__apriori_database =
526  // new Database(__apriori_dbname, __score_database,
527  // __user_modalities);
528  }
529 
532  break;
533 
534  default:
536  "genericBNLearner does not support yet this apriori");
537  }
538 
539  // do not forget to assign a weight to the apriori
541 
542  // remove the old apriori, if any
543  if (old_apriori != nullptr) delete old_apriori;
544  }
545 
547  // first, save the old score, to be deleted if everything is ok
548  Score<>* old_score = __score;
549 
550  // create the new scoring function
551  switch (__score_type) {
552  case ScoreType::AIC:
553  __score = new ScoreAIC<>(
555  break;
556 
557  case ScoreType::BD:
558  __score = new ScoreBD<>(
560  break;
561 
562  case ScoreType::BDeu:
563  __score = new ScoreBDeu<>(
565  break;
566 
567  case ScoreType::BIC:
568  __score = new ScoreBIC<>(
570  break;
571 
572  case ScoreType::K2:
573  __score = new ScoreK2<>(
575  break;
576 
580  break;
581 
582  default:
584  "genericBNLearner does not support yet this score");
585  }
586 
587  // remove the old score, if any
588  if (old_score != nullptr) delete old_score;
589  }
590 
// Creates the parameter estimator selected by __param_estimator_type and
// disposes of the previously installed one.
//
// The previous estimator pointer is saved first and deleted only after the
// switch has completed, so it is still alive if the switch reports an error.
//
// @param take_into_account_score when true and a score object exists
//        (__score != nullptr), the first branch of the visible case is taken;
//        otherwise the second one. Both branches pass *__apriori, so
//        __apriori is dereferenced here and must already be set.
//
// NOTE(review): this rendered listing omits the case label, the two
// estimator constructions (only their `*__apriori` arguments remain) and the
// opening of the error call in `default:` -- see the gaps in the gutter
// numbers and confirm the details against the original file.
591  void genericBNLearner::__createParamEstimator(bool take_into_account_score) {
592  // first, save the old estimator, to be deleted if everything is ok
593  ParamEstimator<>* old_estimator = __param_estimator;
594 
595  // create the new estimator
596  switch (__param_estimator_type) {
598  if (take_into_account_score && (__score != nullptr)) {
602  *__apriori,
604  } else {
608  *__apriori);
609  }
610 
611  break;
612 
     // any other estimator type is reported as unsupported
613  default:
615  "genericBNLearner does not support "
616  "yet this parameter estimator");
617  }
618 
619  // remove the old estimator, if any
620  if (old_estimator != nullptr) delete old_estimator;
621  }
622 
625  // Initialize the mixed graph to the fully connected graph
626  MixedGraph mgraph;
627  for (Size i = 0; i < __score_database.modalities().size(); ++i) {
628  mgraph.addNodeWithId(i);
629  for (Size j = 0; j < i; ++j) {
630  mgraph.addEdge(j, i);
631  }
632  }
633 
634  // translating the constraints for 3off2 or miic
635  HashTable< std::pair< Idx, Idx >, char > initial_marks;
636  const ArcSet& mandatory_arcs = __constraint_MandatoryArcs.arcs();
637  for (const auto& arc : mandatory_arcs) {
638  initial_marks.insert({arc.tail(), arc.head()}, '>');
639  }
640 
641  const ArcSet& forbidden_arcs = __constraint_ForbiddenArcs.arcs();
642  for (const auto& arc : forbidden_arcs) {
643  initial_marks.insert({arc.tail(), arc.head()}, '-');
644  }
645  __miic_3off2.addConstraints(initial_marks);
646  // create the mutual entropy object
647  if (__mutual_info == nullptr) { this->useNML(); }
648 
649  return mgraph;
650  }
651 
654  GUM_ERROR(OperationNotAllowed, "Must be using the miic/3off2 algorithm");
655  }
656  BNLearnerListener listener(this, __miic_3off2);
657  // create the initial mixed graph
658  MixedGraph mgraph = this->__prepare_miic_3off2();
660  }
661 
663  // create the score and the apriori
664  __createApriori();
665  __createScore();
666 
667  return __learnDAG();
668  }
669 
671  // add the mandatory arcs to the initial dag and remove the forbidden ones
672  // from the initial graph
673  DAG init_graph = __initial_dag;
674 
675  const ArcSet& mandatory_arcs = __constraint_MandatoryArcs.arcs();
676 
677  for (const auto& arc : mandatory_arcs) {
678  if (!init_graph.exists(arc.tail())) init_graph.addNodeWithId(arc.tail());
679 
680  if (!init_graph.exists(arc.head())) init_graph.addNodeWithId(arc.head());
681 
682  init_graph.addArc(arc.tail(), arc.head());
683  }
684 
685  const ArcSet& forbidden_arcs = __constraint_ForbiddenArcs.arcs();
686 
687  for (const auto& arc : forbidden_arcs) {
688  init_graph.eraseArc(arc);
689  }
690 
691  switch (__selected_algo) {
692  // ========================================================================
694  BNLearnerListener listener(this, __miic_3off2);
695  // create the mixedGraph
696  MixedGraph mgraph = this->__prepare_miic_3off2();
697 
698  return __miic_3off2.learnStructure(*__mutual_info, mgraph);
699  }
700  // ========================================================================
706  gen_constraint;
707  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint) =
709  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint) =
711  static_cast< StructuralConstraintSliceOrder& >(gen_constraint) =
713 
715  gen_constraint);
716 
719  sel_constraint;
720  static_cast< StructuralConstraintIndegree& >(sel_constraint) =
722 
724  decltype(sel_constraint),
725  decltype(op_set) >
726  selector(*__score, sel_constraint, op_set);
727 
729  selector, __score_database.modalities(), init_graph);
730  }
731 
732  // ========================================================================
738  gen_constraint;
739  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint) =
741  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint) =
743  static_cast< StructuralConstraintSliceOrder& >(gen_constraint) =
745 
747  gen_constraint);
748 
752  sel_constraint;
753  static_cast< StructuralConstraintTabuList& >(sel_constraint) =
755  static_cast< StructuralConstraintIndegree& >(sel_constraint) =
757 
759  decltype(sel_constraint),
760  decltype(op_set) >
761  selector(*__score, sel_constraint, op_set);
762 
764  selector, __score_database.modalities(), init_graph);
765  }
766 
767  // ========================================================================
768  case AlgoType::K2: {
769  BNLearnerListener listener(this, __K2.approximationScheme());
772  gen_constraint;
773  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint) =
775  static_cast< StructuralConstraintForbiddenArcs& >(gen_constraint) =
777 
779  gen_constraint);
780 
781  // if some mandatory arcs are incompatible with the order, use a DAG
782  // constraint instead of a DiGraph constraint to avoid cycles
783  const ArcSet& mandatory_arcs =
784  static_cast< StructuralConstraintMandatoryArcs& >(gen_constraint)
785  .arcs();
786  const Sequence< NodeId >& order = __K2.order();
787  bool order_compatible = true;
788 
789  for (const auto& arc : mandatory_arcs) {
790  if (order.pos(arc.tail()) >= order.pos(arc.head())) {
791  order_compatible = false;
792  break;
793  }
794  }
795 
796  if (order_compatible) {
799  sel_constraint;
800  static_cast< StructuralConstraintIndegree& >(sel_constraint) =
802 
804  decltype(sel_constraint),
805  decltype(op_set) >
806  selector(*__score, sel_constraint, op_set);
807 
808  return __K2.learnStructure(
809  selector, __score_database.modalities(), init_graph);
810  } else {
813  sel_constraint;
814  static_cast< StructuralConstraintIndegree& >(sel_constraint) =
816 
818  decltype(sel_constraint),
819  decltype(op_set) >
820  selector(*__score, sel_constraint, op_set);
821 
822  return __K2.learnStructure(
823  selector, __score_database.modalities(), init_graph);
824  }
825  }
826 
827  // ========================================================================
828  default:
830  "the learnDAG method has not been implemented for this "
831  "learning algorithm");
832  }
833  }
834 
836  const std::string& apriori = __getAprioriType();
837 
838  switch (__score_type) {
839  case ScoreType::AIC:
841 
842  case ScoreType::BD:
844 
845  case ScoreType::BDeu:
847 
848  case ScoreType::BIC:
850 
851  case ScoreType::K2:
853 
857 
858  default: return "genericBNLearner does not support yet this score";
859  }
860  }
861 
862  } /* namespace learning */
863 
864 } /* namespace gum */
AlgoType __selected_algo
the selected learning algorithm
the class for structural constraints limiting the number of parents of nodes in a directed graph ...
void insert(const T1 &first, const T2 &second)
Inserts a new association in the gum::Bijection.
const std::vector< std::string, ALLOC< std::string > > & variableNames()
returns the names of the variables in the input dataset
The class for computing BDeu scores (actually their log2 value)
Definition: scoreBDeu.h:79
ApproximationScheme & approximationScheme()
returns the approximation policy of the learning algorithm
unsigned long Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:50
Score * __score
the score used
void addConstraints(HashTable< std::pair< Idx, Idx >, char > constraints)
Set an ensemble of constraints for the orientation phase.
Definition: Miic.cpp:1034
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
Database __score_database
the database to be used by the scores and parameter estimators
Idx pos(const Key &key) const
Returns the position of the object passed in argument (if it exists).
Definition: sequence_tpl.h:515
the structural constraint for forbidding the creation of some arcs during structure learning ...
unsigned int NodeId
Type for node ids.
Definition: graphElements.h:97
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
const std::string & __getAprioriType() const
returns the type (as a string) of a given apriori
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, const std::vector< Size > &modal, DAG initial_dag=DAG())
learns the structure of a Bayes net
const ArcSet & arcs() const
returns the set of mandatory arcs
static void __checkFileName(const std::string &filename)
checks whether the extension of a CSV filename is correct
The base class for all the scores used for learning (BIC, BDeu, etc.). The class should be used as follo...
Definition: score.h:73
A class for generic framework of learning algorithms that can easily be used.
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
const std::vector< std::string > & missingSymbols() const
returns the set of missing symbols taken into account
DBVector< std::size_t > domainSizes() const
returns the domain sizes of all the variables in the database table
MixedGraph learnMixedStructure(CorrectedMutualInformation<> &I, MixedGraph graph)
learns the structure of an Essential Graph
Definition: Miic.cpp:103
void __createScore()
create the score used for learning
The class used to pack sets of generators.
The class for computing Bayesian Dirichlet (BD) log2 scores.
Definition: scoreBD.h:70
StructuralConstraintSliceOrder __constraint_SliceOrder
the constraint for 2TBNs
Database & operator=(const Database &from)
copy operator
the structural constraint indicating that some arcs shall never be removed or reversed ...
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Miic __miic_3off2
the 3off2 algorithm
std::vector< Size > & modalities() noexcept
returns the modalities of the variables
virtual void addEdge(const NodeId first, const NodeId second)
insert a new edge into the undirected graph
Definition: undiGraph_inl.h:32
ParamEstimatorType __param_estimator_type
the type of the parameter estimator
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
ParamEstimator * __param_estimator
the parameter estimator to use
STL namespace.
A class that redirects gum_signal from algorithms to the listeners of BNLearn.
the base class for all apriori
Definition: apriori.h:45
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
DatabaseTable __database
the database itself
The class for computing K2 scores (actually their log2 value)
Definition: scoreK2.h:79
MixedGraph __prepare_miic_3off2()
prepares the initial graph for 3off2 or miic
bool exists(const NodeId id) const
alias for existsNode
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
virtual void setWeight(double weight)
sets the weight of the a priori (kind of effective sample size)
AprioriType __apriori_type
the a priori selected for the score and parameters
The class for generic Hash Tables.
Definition: hashTable.h:676
the class for computing log2-likelihood scores
CorrectedMutualInformation * __mutual_info
the selected correction for 3off2 and miic
A Dirichlet a priori: computes its N'_ijk from a database.
DAG __initial_dag
an initial DAG given to learners
void useNML()
indicate that we wish to use the NML correction for 3off2
const Sequence< NodeId > & order() const noexcept
returns the current order
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, const std::vector< Size > &modal, DAG initial_dag=DAG())
learns the structure of a Bayes net
The mechanism to compute the next available graph changes for directed structure learning search algor...
StructuralConstraintMandatoryArcs __constraint_MandatoryArcs
the constraint on mandatory arcs
std::size_t nbVariables() const noexcept
returns the number of variables (columns) of the database
genericBNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols)
default constructor
Database * __apriori_database
the database used by the Dirichlet a priori
LocalSearchWithTabuList __local_search_with_tabu_list
the local search with tabu list algorithm
DatabaseTable readFile(const std::string &filename)
const ApproximationScheme * __current_algorithm
bool hasMissingValues() const
indicates whether the database contains some missing values
the "meta-programming" class for storing structural constraints. In aGrUM, there are two ways to store ...
const ArcSet & arcs() const
returns the set of mandatory arcs
Apriori * __apriori
the apriori used
StructuralConstraintTabuList __constraint_TabuList
the constraint for tabu lists
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, const std::vector< Size > &modal, DAG initial_dag=DAG())
learns the structure of a Bayes net
Definition: K2_tpl.h:38
std::string __apriori_dbname
the filename for the Dirichlet a priori, if any
void fillDatabase(DATABASE< ALLOC > &database, const bool retry_insertion=false)
fills the rows of the database table
bool __modalities_parse_db
indicates whether we shall parse the database to update __user_modalities
std::size_t insertTranslator(const Translator< ALLOC > &translator, const std::size_t column, const bool unique_column=true)
inserts a new translator at the end of the translator set
NodeProperty< Sequence< std::string > > __user_modalities
indicates the values the user specified for the translators
DAG __learnDAG()
returns the DAG learnt
Bijection< std::string, NodeId > __name2nodeId
a hashtable assigning to each variable name its NodeId
GreedyHillClimbing __greedy_hill_climbing
the greedy hill climbing algorithm
genericBNLearner & operator=(const genericBNLearner &)
copy operator
The basic class for computing the next graph changes possible in a structure learning algorithm...
std::string checkScoreAprioriCompatibility()
checks whether the current score and apriori are compatible
the class for computing AIC scores
Definition: scoreAIC.h:68
The class for computing BIC scores.
Definition: scoreBIC.h:66
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:40
DAG learnDAG()
learn a structure from a file (must have read the db before)
A listener that allows BNLearner to be used as a proxy for its inner algorithms.
The class representing a tabular database as used by learning tasks.
MixedGraph learnMixedStructure()
learn a partial structure from a file (must have read the db before and must have selected miic or 3o...
StructuralConstraintIndegree __constraint_Indegree
the constraint for indegrees
ScoreType __score_type
the score selected for learning
virtual ~genericBNLearner()
destructor
DBRowGeneratorParser * __parser
the parser used for reading the database
void __createParamEstimator(bool take_into_account_score=true)
create the parameter estimator used for learning
A pack of learning algorithms that can easily be used.
static DatabaseTable __readFile(const std::string &filename, const std::vector< std::string > &missing_symbols)
reads a file and returns a DatabaseTable
DBRowGeneratorParser & parser()
returns the parser for the database
virtual void setVariableNames(const std::vector< std::string, ALLOC< std::string > > &names, const bool from_external_object=true) final
sets the names of the variables
double __apriori_weight
the weight of the apriori
void __createApriori()
create the apriori used for learning
The class imposing a N-sized tabu list as a structural constraints for learning algorithms.
The class for initializing DatabaseTable and RawDatabaseTable instances from CSV files.
the smooth a priori: adds a weight w to all the countings
A pack of learning algorithms that can easily be used.
Database(const std::string &file, const std::vector< std::string > &missing_symbols)
default constructor
the class for packing together the translators used to preprocess the datasets
The databases' cell translators for labelized variables.
std::vector< Size > __modalities
the modalities of the variables
StructuralConstraintForbiddenArcs __constraint_ForbiddenArcs
the constraint on forbidden arcs
a helper to easily read databases
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
The base class for estimating parameters of CPTs. The class should be used as follows: first...
The class for estimating parameters of CPTs using Maximum Likelihood. The class should be used as follo...
const DBVector< std::string > & variableNames() const noexcept
returns the variable names for all the columns of the database
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
unsigned long Idx
Type for indexes.
Definition: types.h:43
The basic class for computing the next graph changes possible in a structure learning algorithm...
Base class for dag.
Definition: DAG.h:98
The base class for structural constraints used by learning algorithms that learn a directed graph str...
void reorder(const std::size_t k, const bool k_is_input_col=false)
performs a reordering of the kth translator or of the first translator parsing the kth column of the ...
the no a priori class: corresponds to 0 weight-sample
virtual std::string isAprioriCompatible() const final
indicates whether the apriori is compatible (meaningful) with the score
iterator handler() const
returns a new unsafe handler pointing to the 1st record of the database
DAG learnStructure(CorrectedMutualInformation<> &I, MixedGraph graph)
learns the structure of a Bayesian network, i.e. a DAG, by first learning an Essential graph and then ...
Definition: Miic.cpp:944
#define GUM_ERROR(type, msg)
Definition: exceptions.h:66
virtual const ScoreInternalApriori< IdSetAlloc, CountAlloc > & internalApriori() const noexcept=0
returns the internal apriori of the score
The base class for structural constraints imposed by DAGs.
Base class for mixed graphs.
Definition: mixedGraph.h:124
the structural constraint imposing a partial order over nodes