aGrUM  0.16.0
BNLearner_tpl.h
Go to the documentation of this file.
1 
31 #include <fstream>
32 
33 #ifndef DOXYGEN_SHOULD_SKIP_THIS
34 
35 // to help IDE parser
37 
39 
40 namespace gum {
41 
42  namespace learning {
43  template < typename GUM_SCALAR >
45  const std::string& filename,
46  const std::vector< std::string >& missing_symbols) :
47  genericBNLearner(filename, missing_symbols) {
48  GUM_CONSTRUCTOR(BNLearner);
49  }
50 
51  template < typename GUM_SCALAR >
52  BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable<>& db) :
53  genericBNLearner(db) {
54  GUM_CONSTRUCTOR(BNLearner);
55  }
56 
// Constructor forwarding `filename`, `bn` and `missing_symbols` to the
// genericBNLearner base constructor, then invoking GUM_CONSTRUCTOR(BNLearner).
//
// NOTE(review): listing lines 58 and 60 are absent from this extract; they
// presumably hold the `BNLearner< GUM_SCALAR >::BNLearner(` declarator and the
// declaration of the `bn` parameter -- restore them from the original header
// before compiling.
57  template < typename GUM_SCALAR >
59  const std::string& filename,
61  const std::vector< std::string >& missing_symbols) :
62  genericBNLearner(filename, bn, missing_symbols) {
63  GUM_CONSTRUCTOR(BNLearner);
64  }
65 
67  template < typename GUM_SCALAR >
68  BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src) :
69  genericBNLearner(src) {
70  GUM_CONSTRUCTOR(BNLearner);
71  }
72 
74  template < typename GUM_SCALAR >
75  BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) :
76  genericBNLearner(src) {
77  GUM_CONSTRUCTOR(BNLearner);
78  }
79 
81  template < typename GUM_SCALAR >
83  GUM_DESTRUCTOR(BNLearner);
84  }
85 
87 
88  // ##########################################################################
90  // ##########################################################################
92 
94  template < typename GUM_SCALAR >
95  BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::
96  operator=(const BNLearner< GUM_SCALAR >& src) {
98  return *this;
99  }
100 
102  template < typename GUM_SCALAR >
103  BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::
104  operator=(BNLearner< GUM_SCALAR >&& src) {
105  genericBNLearner::operator=(std::move(src));
106  return *this;
107  }
108 
110  template < typename GUM_SCALAR >
111  BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
112  // create the score, the apriori and the estimator
113  auto notification = checkScoreAprioriCompatibility();
114  if (notification != "") {
115  std::cout << "[aGrUM notification] " << notification << std::endl;
116  }
117  __createApriori();
118  __createScore();
119 
120  std::unique_ptr< ParamEstimator<> > param_estimator(
122 
123  return __Dag2BN.createBN< GUM_SCALAR >(*(param_estimator.get()),
124  __learnDAG());
125  }
126 
// Learns the parameters of a Bayesian network whose structure `dag` is given.
//
// Outline of the visible code:
//  - an empty `dag` yields a default-constructed BayesNet;
//  - the node ids of `dag` are collected and sorted; if the largest one is not
//    smaller than the number of names of __score_database, a
//    MissingVariableInDatabase error listing every offending id is raised;
//  - the apriori is created;
//  - when __EMepsilon == 0.0, a missing-value check may raise
//    MissingValueInDatabase, then a single estimator is created and handed to
//    __Dag2BN.createBN together with `dag`;
//  - otherwise ("EM" branch) two estimators are created, one fed by a
//    DBRowGenerator4CompleteRows and one fed by a DBRowGeneratorEM, and both
//    are handed to __Dag2BN.createBN together with `dag`.
//
// `take_into_account_score` is forwarded unchanged to every call of
// __createParamEstimator.
//
// NOTE(review): several listing lines are absent from this extract (130, 168,
// 170, 171, 194 and 215); each gap is flagged below. Restore them from the
// original header before compiling.
128  template < typename GUM_SCALAR >
129  BayesNet< GUM_SCALAR >
// NOTE(review): listing line 130 is missing here -- presumably the
// `BNLearner< GUM_SCALAR >::learnParameters(` declarator introducing `dag`.
131  bool take_into_account_score) {
132  // if the dag contains no node, return an empty BN
133  if (dag.size() == 0) return BayesNet< GUM_SCALAR >();
134 
135  // check that the dag corresponds to the database
136  std::vector< NodeId > ids;
137  ids.reserve(dag.sizeNodes());
138  for (const auto node : dag)
139  ids.push_back(node);
// once sorted, ids.back() is the largest id; `ids` cannot be empty here since
// an empty dag has already returned above
140  std::sort(ids.begin(), ids.end());
141 
142  if (ids.back() >= __score_database.names().size()) {
// build the error message: the ids that are not smaller than the number of
// names of the database are listed, separated by ", "
143  std::stringstream str;
144  str << "Learning parameters corresponding to the dag is impossible "
145  << "because the database does not contain the following nodeID";
146  std::vector< NodeId > bad_ids;
147  for (const auto node : ids) {
148  if (node >= __score_database.names().size()) bad_ids.push_back(node);
149  }
// plural form when several ids are reported
150  if (bad_ids.size() > 1) str << 's';
151  str << ": ";
// `deja` ("already"): false only for the first id, so that the separator is
// written before every id but the first
152  bool deja = false;
153  for (const auto node : bad_ids) {
154  if (deja)
155  str << ", ";
156  else
157  deja = true;
158  str << node;
159  }
160  GUM_ERROR(MissingVariableInDatabase, str.str());
161  }
162 
163  // create the apriori
164  __createApriori();
165 
// __EMepsilon == 0.0 selects the branch that does not use EM
166  if (__EMepsilon == 0.0) {
167  // check that the database does not contain any missing value
// NOTE(review): listing lines 168, 170 and 171 are missing -- the `if (` that
// opens this condition and the remaining operands; only the operand testing
// that __apriori_database is not null is visible.
169  || ((__apriori_database != nullptr)
172  GUM_ERROR(MissingValueInDatabase,
173  "In general, the BNLearner is unable to cope with "
174  << "missing values in databases. To learn parameters in "
175  << "such situations, you should first use method "
176  << "useEM()");
177  }
178 
179  // create the usual estimator
// the parser is given a default-constructed (empty) generator set
180  DBRowGeneratorParser<> parser(__score_database.databaseTable().handler(),
181  DBRowGeneratorSet<>());
// the unique_ptr owns the estimator returned by __createParamEstimator
182  std::unique_ptr< ParamEstimator<> > param_estimator(
183  __createParamEstimator(parser, take_into_account_score));
184 
185  return __Dag2BN.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
186  } else {
187  // EM !
// NOTE(review): the listener is never referenced again in this function;
// presumably it only needs to stay alive while __Dag2BN runs -- confirm.
188  BNLearnerListener listener(this, __Dag2BN);
189 
190  // get the column types
191  const auto& database = __score_database.databaseTable();
192  const std::size_t nb_vars = database.nbVariables();
193  const std::vector< gum::learning::DBTranslatedValueType > col_types(
// NOTE(review): listing line 194 is missing -- the constructor arguments of
// `col_types` (presumably sized with nb_vars) and the closing `);`.
195 
196  // create the bootstrap estimator
197  DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
198  DBRowGeneratorSet<> genset_bootstrap;
199  genset_bootstrap.insertGenerator(generator_bootstrap);
200  DBRowGeneratorParser<> parser_bootstrap(database.handler(),
201  genset_bootstrap);
202  std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
203  __createParamEstimator(parser_bootstrap, take_into_account_score));
204 
205  // create the EM estimator
// the EM generator is constructed from the column types and an empty network
206  BayesNet< GUM_SCALAR > dummy_bn;
207  DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
208  DBRowGenerator<>& gen_EM = generator_EM; // fix for g++-4.8
209  DBRowGeneratorSet<> genset_EM;
210  genset_EM.insertGenerator(gen_EM);
211  DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
212  std::unique_ptr< ParamEstimator<> > param_estimator_EM(
213  __createParamEstimator(parser_EM, take_into_account_score));
214 
// NOTE(review): listing line 215 is missing here -- presumably the statement
// passing __EMepsilon to __Dag2BN (a setEpsilon(double) is listed among the
// referenced members); confirm against the original header.
216  return __Dag2BN.createBN< GUM_SCALAR >(
217  *(param_estimator_bootstrap.get()), *(param_estimator_EM.get()), dag);
218  }
219  }
220 
221 
223  template < typename GUM_SCALAR >
224  BayesNet< GUM_SCALAR >
225  BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) {
226  return learnParameters(__initial_dag, take_into_account_score);
227  }
228 
229 
230  template < typename GUM_SCALAR >
231  NodeProperty< Sequence< std::string > >
232  BNLearner< GUM_SCALAR >::__labelsFromBN(const std::string& filename,
233  const BayesNet< GUM_SCALAR >& src) {
234  std::ifstream in(filename, std::ifstream::in);
235 
236  if ((in.rdstate() & std::ifstream::failbit) != 0) {
237  GUM_ERROR(gum::IOError, "File " << filename << " not found");
238  }
239 
240  CSVParser<> parser(in);
241  parser.next();
242  auto names = parser.current();
243 
244  NodeProperty< Sequence< std::string > > modals;
245 
246  for (gum::Idx col = 0; col < names.size(); col++) {
247  try {
248  gum::NodeId graphId = src.idFromName(names[col]);
249  modals.insert(col, gum::Sequence< std::string >());
250 
251  for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
252  modals[col].insert(src.variable(graphId).label(i));
253  } catch (const gum::NotFound&) {
254  // no problem : a column which is not in the BN...
255  }
256  }
257 
258  return modals;
259  }
260 
261  } /* namespace learning */
262 
263 } /* namespace gum */
264 
265 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
Class representing a Bayesian Network.
Definition: BayesNet.h:78
static BayesNet< GUM_SCALAR > createBN(ParamEstimator< ALLOC > &estimator, const DAG &dag)
create a BN from a DAG using a one pass generator (typically ML)
Database __score_database
the database to be used by the scores and parameter estimators
double __EMepsilon
epsilon for EM. If epsilon = 0.0: no EM
void __createScore()
create the score used for learning
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
const std::vector< std::string > & names() const
returns the names of the variables in the database
AprioriType __apriori_type
the a priori selected for the score and parameters
BayesNet< GUM_SCALAR > learnParameters(const DAG &dag, bool take_into_account_score=true)
learns a BN (its parameters) when its structure is known
NodeProperty< Sequence< std::string > > __labelsFromBN(const std::string &filename, const BayesNet< GUM_SCALAR > &src)
read the first line of a file to find column names
DAG __initial_dag
an initial DAG given to learners
std::size_t nbVariables() const noexcept
returns the number of variables (columns) of the database
genericBNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols)
default constructor
Database * __apriori_database
the database used by the Dirichlet a priori
bool hasMissingValues() const
indicates whether the database contains some missing values
ParamEstimator * __createParamEstimator(DBRowGeneratorParser<> &parser, bool take_into_account_score=true)
create the parameter estimator used for learning
DAG __learnDAG()
returns the DAG learnt
genericBNLearner & operator=(const genericBNLearner &)
copy operator
virtual ~BNLearner()
destructor
const DatabaseTable & databaseTable() const
returns the internal database table
std::string checkScoreAprioriCompatibility()
checks whether the current score and apriori are compatible
DAG2BNLearner __Dag2BN
the parametric EM
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
DBRowGeneratorParser & parser()
returns the parser for the database
BayesNet< GUM_SCALAR > learnBN()
learn a Bayes Net from a file (must have read the db before)
void __createApriori()
create the apriori used for learning
BNLearner & operator=(const BNLearner &)
copy operator
void setEpsilon(double eps)
Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)|.
Size Idx
Type for indexes.
Definition: types.h:53
BNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols={"?"})
default constructor
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:48
const DatabaseTable & database() const
returns the database used by the BNLearner
const std::vector< std::string > & names() const
returns the names of the variables in the database
Size NodeId
Type for node ids.
Definition: graphElements.h:98
iterator handler() const
returns a new unsafe handler pointing to the 1st record of the database
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55