aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
BNLearner_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief A pack of learning algorithms that can easily be used
24  *
25  * The pack currently contains the K2, GreedyHillClimbing and
26  * LocalSearchWithTabuList algorithms
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 #include <fstream>
31 
32 #ifndef DOXYGEN_SHOULD_SKIP_THIS
33 
34 // to help IDE parser
35 # include <agrum/BN/learning/BNLearner.h>
36 
37 # include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h>
38 
39 namespace gum {
40 
41  namespace learning {
42  template < typename GUM_SCALAR >
43  BNLearner< GUM_SCALAR >::BNLearner(
44  const std::string& filename,
45  const std::vector< std::string >& missing_symbols) :
46  genericBNLearner(filename, missing_symbols) {
47  GUM_CONSTRUCTOR(BNLearner);
48  }
49 
50  template < typename GUM_SCALAR >
51  BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable<>& db) :
52  genericBNLearner(db) {
53  GUM_CONSTRUCTOR(BNLearner);
54  }
55 
56  template < typename GUM_SCALAR >
57  BNLearner< GUM_SCALAR >::BNLearner(
58  const std::string& filename,
59  const gum::BayesNet< GUM_SCALAR >& bn,
60  const std::vector< std::string >& missing_symbols) :
61  genericBNLearner(filename, bn, missing_symbols) {
62  GUM_CONSTRUCTOR(BNLearner);
63  }
64 
65  /// copy constructor
66  template < typename GUM_SCALAR >
67  BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src) :
68  genericBNLearner(src) {
69  GUM_CONSTRUCTOR(BNLearner);
70  }
71 
72  /// move constructor
73  template < typename GUM_SCALAR >
74  BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) :
75  genericBNLearner(src) {
76  GUM_CONSTRUCTOR(BNLearner);
77  }
78 
79  /// destructor
80  template < typename GUM_SCALAR >
81  BNLearner< GUM_SCALAR >::~BNLearner() {
82  GUM_DESTRUCTOR(BNLearner);
83  }
84 
85  /// @}
86 
87  // ##########################################################################
88  /// @name Operators
89  // ##########################################################################
90  /// @{
91 
92  /// copy operator
93  template < typename GUM_SCALAR >
94  BNLearner< GUM_SCALAR >&
95  BNLearner< GUM_SCALAR >::operator=(const BNLearner< GUM_SCALAR >& src) {
96  genericBNLearner::operator=(src);
97  return *this;
98  }
99 
100  /// move operator
101  template < typename GUM_SCALAR >
102  BNLearner< GUM_SCALAR >&
103  BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) {
104  genericBNLearner::operator=(std::move(src));
105  return *this;
106  }
107 
108  /// learn a Bayes Net from a file
109  template < typename GUM_SCALAR >
110  BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
111  // create the score, the apriori and the estimator
112  auto notification = checkScoreAprioriCompatibility();
113  if (notification != "") {
114  std::cout << "[aGrUM notification] " << notification << std::endl;
115  }
116  createApriori__();
117  createScore__();
118 
119  std::unique_ptr< ParamEstimator<> > param_estimator(
120  createParamEstimator__(score_database__.parser(), true));
121 
122  return Dag2BN__.createBN< GUM_SCALAR >(*(param_estimator.get()),
123  learnDAG__());
124  }
125 
126  /// learns a BN (its parameters) when its structure is known
127  template < typename GUM_SCALAR >
128  BayesNet< GUM_SCALAR >
129  BNLearner< GUM_SCALAR >::learnParameters(const DAG& dag,
130  bool take_into_account_score) {
131  // if the dag contains no node, return an empty BN
132  if (dag.size() == 0) return BayesNet< GUM_SCALAR >();
133 
134  // check that the dag corresponds to the database
135  std::vector< NodeId > ids;
136  ids.reserve(dag.sizeNodes());
137  for (const auto node: dag)
138  ids.push_back(node);
139  std::sort(ids.begin(), ids.end());
140 
141  if (ids.back() >= score_database__.names().size()) {
142  std::stringstream str;
143  str << "Learning parameters corresponding to the dag is impossible "
144  << "because the database does not contain the following nodeID";
145  std::vector< NodeId > bad_ids;
146  for (const auto node: ids) {
147  if (node >= score_database__.names().size()) bad_ids.push_back(node);
148  }
149  if (bad_ids.size() > 1) str << 's';
150  str << ": ";
151  bool deja = false;
152  for (const auto node: bad_ids) {
153  if (deja)
154  str << ", ";
155  else
156  deja = true;
157  str << node;
158  }
159  GUM_ERROR(MissingVariableInDatabase, str.str());
160  }
161 
162  // create the apriori
163  createApriori__();
164 
165  if (EMepsilon__ == 0.0) {
166  // check that the database does not contain any missing value
167  if (score_database__.databaseTable().hasMissingValues()
168  || ((apriori_database__ != nullptr)
169  && (apriori_type__ == AprioriType::DIRICHLET_FROM_DATABASE)
170  && apriori_database__->databaseTable().hasMissingValues())) {
171  GUM_ERROR(MissingValueInDatabase,
172  "In general, the BNLearner is unable to cope with "
173  << "missing values in databases. To learn parameters in "
174  << "such situations, you should first use method "
175  << "useEM()");
176  }
177 
178  // create the usual estimator
179  DBRowGeneratorParser<> parser(score_database__.databaseTable().handler(),
180  DBRowGeneratorSet<>());
181  std::unique_ptr< ParamEstimator<> > param_estimator(
182  createParamEstimator__(parser, take_into_account_score));
183 
184  return Dag2BN__.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
185  } else {
186  // EM !
187  BNLearnerListener listener(this, Dag2BN__);
188 
189  // get the column types
190  const auto& database = score_database__.databaseTable();
191  const std::size_t nb_vars = database.nbVariables();
192  const std::vector< gum::learning::DBTranslatedValueType > col_types(
193  nb_vars,
194  gum::learning::DBTranslatedValueType::DISCRETE);
195 
196  // create the bootstrap estimator
197  DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
198  DBRowGeneratorSet<> genset_bootstrap;
199  genset_bootstrap.insertGenerator(generator_bootstrap);
200  DBRowGeneratorParser<> parser_bootstrap(database.handler(),
201  genset_bootstrap);
202  std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
203  createParamEstimator__(parser_bootstrap, take_into_account_score));
204 
205  // create the EM estimator
206  BayesNet< GUM_SCALAR > dummy_bn;
207  DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
208  DBRowGenerator<>& gen_EM = generator_EM; // fix for g++-4.8
209  DBRowGeneratorSet<> genset_EM;
210  genset_EM.insertGenerator(gen_EM);
211  DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
212  std::unique_ptr< ParamEstimator<> > param_estimator_EM(
213  createParamEstimator__(parser_EM, take_into_account_score));
214 
215  Dag2BN__.setEpsilon(EMepsilon__);
216  return Dag2BN__.createBN< GUM_SCALAR >(*(param_estimator_bootstrap.get()),
217  *(param_estimator_EM.get()),
218  dag);
219  }
220  }
221 
222 
223  /// learns a BN (its parameters) when its structure is known
224  template < typename GUM_SCALAR >
225  BayesNet< GUM_SCALAR >
226  BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) {
227  return learnParameters(initial_dag__, take_into_account_score);
228  }
229 
230 
231  template < typename GUM_SCALAR >
232  NodeProperty< Sequence< std::string > >
233  BNLearner< GUM_SCALAR >::labelsFromBN__(const std::string& filename,
234  const BayesNet< GUM_SCALAR >& src) {
235  std::ifstream in(filename, std::ifstream::in);
236 
237  if ((in.rdstate() & std::ifstream::failbit) != 0) {
238  GUM_ERROR(gum::IOError, "File " << filename << " not found");
239  }
240 
241  CSVParser<> parser(in);
242  parser.next();
243  auto names = parser.current();
244 
245  NodeProperty< Sequence< std::string > > modals;
246 
247  for (gum::Idx col = 0; col < names.size(); col++) {
248  try {
249  gum::NodeId graphId = src.idFromName(names[col]);
250  modals.insert(col, gum::Sequence< std::string >());
251 
252  for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
253  modals[col].insert(src.variable(graphId).label(i));
254  } catch (const gum::NotFound&) {
255  // no problem : a column which is not in the BN...
256  }
257  }
258 
259  return modals;
260  }
261 
262  } /* namespace learning */
263 
264 } /* namespace gum */
265 
266 #endif /* DOXYGEN_SHOULD_SKIP_THIS */