aGrUM  0.14.2
BNLearner_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}@lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
28 #include <fstream>
29 
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 
32 // to help IDE parser
34 
36 
37 namespace gum {
38 
39  namespace learning {
40  template < typename GUM_SCALAR >
42  const std::string& filename,
43  const std::vector< std::string >& missing_symbols) :
44  genericBNLearner(filename, missing_symbols) {
45  GUM_CONSTRUCTOR(BNLearner);
46  }
47 
48  template < typename GUM_SCALAR >
49  BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable<>& db) :
50  genericBNLearner(db) {
51  GUM_CONSTRUCTOR(BNLearner);
52  }
53 
// Builds a learner from a database file and an extra argument `bn`; the three
// arguments are forwarded unchanged to the genericBNLearner base.
// NOTE(review): presumably `bn` is a BayesNet< GUM_SCALAR > supplying the
// variables' labels -- TODO confirm.
// NOTE(review): the doxygen extract dropped source lines 55 and 57 (the
// qualified constructor name and, presumably, the declaration of `bn`);
// restore them from the original BNLearner_tpl.h before editing further.
54  template < typename GUM_SCALAR >
56  const std::string& filename,
58  const std::vector< std::string >& missing_symbols) :
59  genericBNLearner(filename, bn, missing_symbols) {
60  GUM_CONSTRUCTOR(BNLearner);
61  }
62 
64  template < typename GUM_SCALAR >
65  BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src) :
66  genericBNLearner(src) {
67  GUM_CONSTRUCTOR(BNLearner);
68  }
69 
71  template < typename GUM_SCALAR >
72  BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) :
73  genericBNLearner(src) {
74  GUM_CONSTRUCTOR(BNLearner);
75  }
76 
78  template < typename GUM_SCALAR >
80  GUM_DESTRUCTOR(BNLearner);
81  }
82 
84 
85  // ##########################################################################
87  // ##########################################################################
89 
91  template < typename GUM_SCALAR >
92  BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::
93  operator=(const BNLearner< GUM_SCALAR >& src) {
95  return *this;
96  }
97 
99  template < typename GUM_SCALAR >
100  BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::
101  operator=(BNLearner< GUM_SCALAR >&& src) {
102  genericBNLearner::operator=(std::move(src));
103  return *this;
104  }
105 
// Learns a complete Bayes net: the structure is obtained from __learnDAG() and
// its parameters are then estimated by __Dag2BN.createBN() using the parameter
// estimator built below.
107  template < typename GUM_SCALAR >
108  BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
109  // create the score, the apriori and the estimator
// A score/apriori incompatibility is only reported on std::cout: learning
// proceeds anyway.
110  auto notification = checkScoreAprioriCompatibility();
111  if (notification != "") {
112  std::cout << "[aGrUM notification] " << notification << std::endl;
113  }
114  __createApriori();
115  __createScore();
116 
// NOTE(review): doxygen source line 118 (the constructor argument of
// param_estimator, presumably a __createParamEstimator(...) call) is missing
// from this extract; confirm against the original BNLearner_tpl.h.
117  std::unique_ptr< ParamEstimator<> > param_estimator(
119 
120  return __Dag2BN.createBN< GUM_SCALAR >(*(param_estimator.get()),
121  __learnDAG());
122  }
123 
// Learns the parameters of a Bayes net whose structure is `dag`.
// - An empty dag yields an empty BayesNet.
// - Throws MissingVariableInDatabase if a node id of `dag` is not smaller than
//   the number of column names of the score database.
// - If __EMepsilon == 0.0, a plain estimator is used, and
//   MissingValueInDatabase is raised when missing values are detected.
// - Otherwise the parameters are estimated by EM, bootstrapped by an
//   estimator built on DBRowGenerator4CompleteRows (presumably restricted to
//   complete rows -- TODO confirm).
// NOTE(review): doxygen source lines 127, 165, 167-168, 191 and 212 are
// missing from this extract. They presumably hold, respectively:
// - the qualified name and the `dag` parameter;
// - the missing-value test;
// - the col_types initializer;
// - the statement preceding the EM return.
// Confirm them against the original BNLearner_tpl.h.
125  template < typename GUM_SCALAR >
126  BayesNet< GUM_SCALAR >
128  bool take_into_account_score) {
129  // if the dag contains no node, return an empty BN
130  if (dag.size() == 0) return BayesNet< GUM_SCALAR >();
131 
132  // check that the dag corresponds to the database
133  std::vector< NodeId > ids;
134  ids.reserve(dag.sizeNodes());
135  for (const auto node : dag)
136  ids.push_back(node);
137  std::sort(ids.begin(), ids.end());
138 
// ids is sorted and non-empty (empty dags returned above), so back() is the
// largest node id: one comparison tells whether any id is out of range.
139  if (ids.back() >= __score_database.names().size()) {
140  std::stringstream str;
141  str << "Learning parameters corresponding to the dag is impossible "
142  << "because the database does not contain the following nodeID";
143  std::vector< NodeId > bad_ids;
144  for (const auto node : ids) {
145  if (node >= __score_database.names().size()) bad_ids.push_back(node);
146  }
147  if (bad_ids.size() > 1) str << 's';
148  str << ": ";
// deja (French for "already") becomes true once the first id has been
// written, so that every following id is preceded by ", ".
149  bool deja = false;
150  for (const auto node : bad_ids) {
151  if (deja)
152  str << ", ";
153  else
154  deja = true;
155  str << node;
156  }
157  GUM_ERROR(MissingVariableInDatabase, str.str());
158  }
159 
160  // create the apriori
161  __createApriori();
162 
// __EMepsilon == 0.0 means "no EM": estimate parameters in a single pass.
163  if (__EMepsilon == 0.0) {
164  // check that the database does not contain any missing value
// NOTE(review): the opening of this test (lines 165, 167-168) is missing;
// it presumably uses hasMissingValues() on the score database and, when
// __apriori_database is not null, on the apriori database -- confirm.
166  || ((__apriori_database != nullptr)
169  GUM_ERROR(MissingValueInDatabase,
170  "In general, the BNLearner is unable to cope with "
171  << "missing values in databases. To learn parameters in "
172  << "such situations, you should first use method "
173  << "useEM()");
174  }
175 
176  // create the usual estimator
177  DBRowGeneratorParser<> parser(__score_database.databaseTable().handler(),
178  DBRowGeneratorSet<>());
179  std::unique_ptr< ParamEstimator<> > param_estimator(
180  __createParamEstimator(parser, take_into_account_score));
181 
182  return __Dag2BN.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
183  } else {
184  // EM !
// The listener stays alive for the whole EM run below. NOTE(review): it
// presumably relays __Dag2BN's progress through this learner (see
// BNLearnerListener) -- confirm.
185  BNLearnerListener listener(this, __Dag2BN);
186 
187  // get the column types
188  const auto& database = __score_database.databaseTable();
189  const std::size_t nb_vars = database.nbVariables();
190  const std::vector< gum::learning::DBTranslatedValueType > col_types(
192 
193  // create the bootstrap estimator
194  DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
195  DBRowGeneratorSet<> genset_bootstrap;
196  genset_bootstrap.insertGenerator(generator_bootstrap);
197  DBRowGeneratorParser<> parser_bootstrap(database.handler(),
198  genset_bootstrap);
199  std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
200  __createParamEstimator(parser_bootstrap, take_into_account_score));
201 
202  // create the EM estimator
// dummy_bn is an empty placeholder BN handed to the EM row generator.
// NOTE(review): presumably replaced by the current estimate during the EM
// iterations -- confirm.
203  BayesNet< GUM_SCALAR > dummy_bn;
204  DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
205  DBRowGenerator<>& gen_EM = generator_EM; // fix for g++-4.8
206  DBRowGeneratorSet<> genset_EM;
207  genset_EM.insertGenerator(gen_EM);
208  DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
209  std::unique_ptr< ParamEstimator<> > param_estimator_EM(
210  __createParamEstimator(parser_EM, take_into_account_score));
211 
// NOTE(review): line 212 is missing; it presumably passes __EMepsilon to
// __Dag2BN (setEpsilon) -- confirm.
213  return __Dag2BN.createBN< GUM_SCALAR >(
214  *(param_estimator_bootstrap.get()), *(param_estimator_EM.get()), dag);
215  }
216  }
217 
218 
220  template < typename GUM_SCALAR >
221  BayesNet< GUM_SCALAR >
222  BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) {
223  return learnParameters(__initial_dag, take_into_account_score);
224  }
225 
226 
227  template < typename GUM_SCALAR >
228  NodeProperty< Sequence< std::string > >
229  BNLearner< GUM_SCALAR >::__labelsFromBN(const std::string& filename,
230  const BayesNet< GUM_SCALAR >& src) {
231  std::ifstream in(filename, std::ifstream::in);
232 
233  if ((in.rdstate() & std::ifstream::failbit) != 0) {
234  GUM_ERROR(gum::IOError, "File " << filename << " not found");
235  }
236 
237  CSVParser<> parser(in);
238  parser.next();
239  auto names = parser.current();
240 
241  NodeProperty< Sequence< std::string > > modals;
242 
243  for (gum::Idx col = 0; col < names.size(); col++) {
244  try {
245  gum::NodeId graphId = src.idFromName(names[col]);
246  modals.insert(col, gum::Sequence< std::string >());
247 
248  for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
249  modals[col].insert(src.variable(graphId).label(i));
250  } catch (const gum::NotFound&) {
251  // no problem : a column which is not in the BN...
252  }
253  }
254 
255  return modals;
256  }
257 
258  } /* namespace learning */
259 
260 } /* namespace gum */
261 
262 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
Class representing a Bayesian Network.
Definition: BayesNet.h:76
static BayesNet< GUM_SCALAR > createBN(ParamEstimator< ALLOC > &estimator, const DAG &dag)
create a BN from a DAG using a one pass generator (typically ML)
Database __score_database
the database to be used by the scores and parameter estimators
double __EMepsilon
epsilon for EM. If epsilon = 0.0, EM is not used.
void __createScore()
create the score used for learning
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
const std::vector< std::string > & names() const
returns the names of the variables in the database
AprioriType __apriori_type
the a priori selected for the score and parameters
BayesNet< GUM_SCALAR > learnParameters(const DAG &dag, bool take_into_account_score=true)
learns a BN (its parameters) when its structure is known
NodeProperty< Sequence< std::string > > __labelsFromBN(const std::string &filename, const BayesNet< GUM_SCALAR > &src)
reads the first line of a file to find the column names and returns, for each column matching a variable of the BN, that variable's labels
DAG __initial_dag
an initial DAG given to learners
std::size_t nbVariables() const noexcept
returns the number of variables (columns) of the database
genericBNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols)
default constructor
Database * __apriori_database
the database used by the Dirichlet a priori
bool hasMissingValues() const
indicates whether the database contains some missing values
ParamEstimator * __createParamEstimator(DBRowGeneratorParser<> &parser, bool take_into_account_score=true)
create the parameter estimator used for learning
DAG __learnDAG()
returns the DAG learnt
genericBNLearner & operator=(const genericBNLearner &)
copy operator
virtual ~BNLearner()
destructor
const DatabaseTable & databaseTable() const
returns the internal database table
std::string checkScoreAprioriCompatibility()
checks whether the current score and apriori are compatible
DAG2BNLearner __Dag2BN
the parametric EM
A listener that allows BNLearner to be used as a proxy for its inner algorithms.
A basic pack of learning algorithms that can easily be used.
DBRowGeneratorParser & parser()
returns the parser for the database
BayesNet< GUM_SCALAR > learnBN()
learn a Bayes Net from a file (must have read the db before)
void __createApriori()
create the apriori used for learning
BNLearner & operator=(const BNLearner &)
copy operator
void setEpsilon(double eps)
Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)|.
Size Idx
Type for indexes.
Definition: types.h:50
BNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols={"?"})
default constructor
std::size_t Size
In aGrUM, sizes and hashed values are of type std::size_t (an unsigned integer type whose width is platform-dependent).
Definition: types.h:45
const DatabaseTable & database() const
returns the database used by the BNLearner
const std::vector< std::string > & names() const
returns the names of the variables in the database
Size NodeId
Type for node ids.
Definition: graphElements.h:97
iterator handler() const
returns a new unsafe handler pointing to the 1st record of the database
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52