aGrUM  0.21.0
a C++ library for (probabilistic) graphical models
BNLearner_tpl.h
/**
 *
 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 * info_at_agrum_dot_org
 *
 * This library is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/** @file
 * @brief A pack of learning algorithms that can easily be used
 *
 * The pack currently contains K2, GreedyHillClimbing and
 * LocalSearchWithTabuList
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */
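/* Illustrative usage sketch (not part of the original file): the BNLearner defined in
 * this template is typically constructed from a CSV file, optionally with a list of
 * symbols to be read as missing values and/or a Bayes net whose variables fix the
 * domains of the columns. The file name, the symbols and the template BN below are
 * assumptions of this example only.
 *
 *   std::vector< std::string > missing = {"?", "N/A"};
 *   gum::learning::BNLearner< double > learner("data.csv", missing);
 *
 *   // variant: reuse the variables (and their labels) of an existing BN
 *   gum::BayesNet< double > templateBN;   // hypothetical BN with the desired domains
 *   gum::learning::BNLearner< double > learner2("data.csv", templateBN, missing);
 */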
#include <fstream>

#ifndef DOXYGEN_SHOULD_SKIP_THIS

// to help IDE parser
#  include <agrum/BN/learning/BNLearner.h>

#  include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h>

namespace gum {

  namespace learning {
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const std::string&                filename,
                                       const std::vector< std::string >& missing_symbols) :
        genericBNLearner(filename, missing_symbols) {
      GUM_CONSTRUCTOR(BNLearner);
    }

    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable<>& db) : genericBNLearner(db) {
      GUM_CONSTRUCTOR(BNLearner);
    }

    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const std::string&                 filename,
                                       const gum::BayesNet< GUM_SCALAR >& bn,
                                       const std::vector< std::string >&  missing_symbols) :
        genericBNLearner(filename, bn, missing_symbols) {
      GUM_CONSTRUCTOR(BNLearner);
    }

    /// copy constructor
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src) : genericBNLearner(src) {
      GUM_CONSTRUCTOR(BNLearner);
    }

    /// move constructor
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) : genericBNLearner(src) {
      GUM_CONSTRUCTOR(BNLearner);
    }

    /// destructor
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::~BNLearner() {
      GUM_DESTRUCTOR(BNLearner);
    }

    /// @}

    // ##########################################################################
    /// @name Operators
    // ##########################################################################
    /// @{

    /// copy operator
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >&
       BNLearner< GUM_SCALAR >::operator=(const BNLearner< GUM_SCALAR >& src) {
      genericBNLearner::operator=(src);
      return *this;
    }

    /// move operator
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) {
      genericBNLearner::operator=(std::move(src));
      return *this;
    }
    /// learn a Bayes Net from a file
    template < typename GUM_SCALAR >
    BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
      // create the score, the apriori and the estimator
      auto notification = checkScoreAprioriCompatibility();
      if (notification != "") { std::cout << "[aGrUM notification] " << notification << std::endl; }
      createApriori_();
      createScore_();

      std::unique_ptr< ParamEstimator<> > param_estimator(
         createParamEstimator_(scoreDatabase_.parser(), true));

      return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), learnDag_());
    }
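    /* Illustrative sketch (not part of the original file): learnBN() is usually preceded
     * by the selection of a structure-learning algorithm, a score and a prior through the
     * genericBNLearner interface. The selector names below (useGreedyHillClimbing,
     * useScoreBIC, useAprioriSmoothing) and the weight 1.0 are assumptions of this example.
     *
     *   gum::learning::BNLearner< double > learner("data.csv");
     *   learner.useGreedyHillClimbing();     // structure-learning algorithm
     *   learner.useScoreBIC();               // score driving the search
     *   learner.useAprioriSmoothing(1.0);    // smoothing prior with weight 1
     *   auto bn = learner.learnBN();         // learns the DAG, then estimates the CPTs
     */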

    /// learns a BN (its parameters) when its structure is known
    template < typename GUM_SCALAR >
    BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(const DAG& dag,
                                                                    bool       takeIntoAccountScore) {
      // if the dag contains no node, return an empty BN
      if (dag.size() == 0) return BayesNet< GUM_SCALAR >();

      // check that the dag corresponds to the database
      std::vector< NodeId > ids;
      ids.reserve(dag.sizeNodes());
      for (const auto node: dag)
        ids.push_back(node);
      std::sort(ids.begin(), ids.end());

      if (ids.back() >= scoreDatabase_.names().size()) {
        std::stringstream str;
        str << "Learning parameters corresponding to the dag is impossible "
            << "because the database does not contain the following nodeID";
        std::vector< NodeId > bad_ids;
        for (const auto node: ids) {
          if (node >= scoreDatabase_.names().size()) bad_ids.push_back(node);
        }
        if (bad_ids.size() > 1) str << 's';
        str << ": ";
        bool deja = false;
        for (const auto node: bad_ids) {
          if (deja)
            str << ", ";
          else
            deja = true;
          str << node;
        }
        GUM_ERROR(MissingVariableInDatabase, str.str())
      }

      // create the apriori
      createApriori_();

      if (epsilonEM_ == 0.0) {
        // check that the database does not contain any missing value
        if (scoreDatabase_.databaseTable().hasMissingValues()
            || ((aprioriDatabase_ != nullptr)
                && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
                && aprioriDatabase_->databaseTable().hasMissingValues())) {
          GUM_ERROR(MissingValueInDatabase,
                    "In general, the BNLearner is unable to cope with "
                       << "missing values in databases. To learn parameters in "
                       << "such situations, you should first use method "
                       << "useEM()");
        }

        // create the usual estimator
        DBRowGeneratorParser<> parser(scoreDatabase_.databaseTable().handler(),
                                      DBRowGeneratorSet<>());
        std::unique_ptr< ParamEstimator<> > param_estimator(
           createParamEstimator_(parser, takeIntoAccountScore));

        return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
      } else {
        // EM !
        BNLearnerListener listener(this, Dag2BN_);

        // get the column types
        const auto&       database = scoreDatabase_.databaseTable();
        const std::size_t nb_vars  = database.nbVariables();
        const std::vector< gum::learning::DBTranslatedValueType > col_types(
           nb_vars,
           gum::learning::DBTranslatedValueType::DISCRETE);

        // create the bootstrap estimator
        DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
        DBRowGeneratorSet<>           genset_bootstrap;
        genset_bootstrap.insertGenerator(generator_bootstrap);
        DBRowGeneratorParser<> parser_bootstrap(database.handler(), genset_bootstrap);
        std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
           createParamEstimator_(parser_bootstrap, takeIntoAccountScore));

        // create the EM estimator
        BayesNet< GUM_SCALAR >         dummy_bn;
        DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
        DBRowGenerator<>&              gen_EM = generator_EM;   // fix for g++-4.8
        DBRowGeneratorSet<>            genset_EM;
        genset_EM.insertGenerator(gen_EM);
        DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
        std::unique_ptr< ParamEstimator<> > param_estimator_EM(
           createParamEstimator_(parser_EM, takeIntoAccountScore));

        Dag2BN_.setEpsilon(epsilonEM_);
        return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator_bootstrap.get()),
                                              *(param_estimator_EM.get()),
                                              dag);
      }
    }
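    /* Illustrative sketch (not part of the original file): learning only the parameters of
     * a known structure. When the database contains missing values, useEM() (the method
     * mentioned in the error message above) must be called first; the threshold 1e-3 and
     * the DAG variable are assumptions of this example.
     *
     *   gum::learning::BNLearner< double > learner("data.csv", {"?"});
     *   learner.useEM(1e-3);                           // enables EM; epsilon is its convergence threshold
     *   auto bn = learner.learnParameters(dag, true);  // dag: a gum::DAG over the database's columns
     */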

    /// learns a BN (its parameters) when its structure is known
    template < typename GUM_SCALAR >
    BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) {
      return learnParameters(initialDag_, take_into_account_score);
    }

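    /// returns, for each column of the CSV header whose name is a variable of the given BN,
    /// the sequence of labels of that variable (columns absent from the BN are skipped)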
    template < typename GUM_SCALAR >
    NodeProperty< Sequence< std::string > >
       BNLearner< GUM_SCALAR >::_labelsFromBN_(const std::string&            filename,
                                               const BayesNet< GUM_SCALAR >& src) {
      std::ifstream in(filename, std::ifstream::in);

      if ((in.rdstate() & std::ifstream::failbit) != 0) {
        GUM_ERROR(gum::IOError, "File " << filename << " not found")
      }

      CSVParser<> parser(in);
      parser.next();
      auto names = parser.current();

      NodeProperty< Sequence< std::string > > modals;

      for (gum::Idx col = 0; col < names.size(); col++) {
        try {
          gum::NodeId graphId = src.idFromName(names[col]);
          modals.insert(col, gum::Sequence< std::string >());

          for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
            modals[col].insert(src.variable(graphId).label(i));
        } catch (const gum::NotFound&) {
          // no problem : a column which is not in the BN...
        }
      }

      return modals;
    }

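    /// returns a textual, line-per-item description of the current state of the learner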
    template < typename GUM_SCALAR >
    std::string BNLearner< GUM_SCALAR >::toString() const {
      const auto st = state();

      Size maxkey = 0;
      for (const auto& tuple: st)
        if (std::get< 0 >(tuple).length() > maxkey) maxkey = std::get< 0 >(tuple).length();

      std::stringstream s;
      for (const auto& tuple: st) {
        s << std::setiosflags(std::ios::left) << std::setw(maxkey) << std::get< 0 >(tuple) << " : "
          << std::get< 1 >(tuple);
        if (std::get< 2 >(tuple) != "") s << " (" << std::get< 2 >(tuple) << ")";
        s << std::endl;
      }
      return s.str();
    }

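    /// returns the current state of the learner as a vector of (key, value, comment) triples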
    template < typename GUM_SCALAR >
    std::vector< std::tuple< std::string, std::string, std::string > >
       BNLearner< GUM_SCALAR >::state() const {
      std::vector< std::tuple< std::string, std::string, std::string > > vals;

      std::string key;
      std::string comment;
      const auto& db = database();

      vals.emplace_back("Filename", filename_, "");
      vals.emplace_back("Size",
                        "(" + std::to_string(nbRows()) + "," + std::to_string(nbCols()) + ")",
                        "");

      std::string vars = "";
      for (NodeId i = 0; i < db.nbVariables(); i++) {
        if (i > 0) vars += ", ";
        vars += nameFromId(i) + "[" + std::to_string(db.domainSize(i)) + "]";
      }
      vals.emplace_back("Variables", vars, "");
      vals.emplace_back("Missing values", hasMissingValues() ? "True" : "False", "");

      key = "Algorithm";
      switch (selectedAlgo_) {
        case AlgoType::GREEDY_HILL_CLIMBING:
          vals.emplace_back(key, "Greedy Hill Climbing", "");
          break;
        case AlgoType::K2: {
          vals.emplace_back(key, "K2", "");
          const auto& k2order = algoK2_.order();
          vars = "";
          for (NodeId i = 0; i < k2order.size(); i++) {
            if (i > 0) vars += ", ";
            vars += nameFromId(k2order.atPos(i));
          }
          vals.emplace_back("K2 order", vars, "");
        } break;
        case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST:
          vals.emplace_back(key, "Local Search with Tabu List", "");
          vals.emplace_back("Tabu list size", std::to_string(nbDecreasingChanges_), "");
          break;
        case AlgoType::THREE_OFF_TWO:
          vals.emplace_back(key, "3off2", "");
          break;
        case AlgoType::MIIC:
          vals.emplace_back(key, "MIIC", "");
          break;
        default:
          vals.emplace_back(key, "(unknown)", "?");
          break;
      }

      if (selectedAlgo_ != AlgoType::MIIC && selectedAlgo_ != AlgoType::THREE_OFF_TWO) {
        key = "Score";
        switch (scoreType_) {
          case ScoreType::K2:
            vals.emplace_back(key, "K2", "");
            break;
          case ScoreType::AIC:
            vals.emplace_back(key, "AIC", "");
            break;
          case ScoreType::BIC:
            vals.emplace_back(key, "BIC", "");
            break;
          case ScoreType::BD:
            vals.emplace_back(key, "BD", "");
            break;
          case ScoreType::BDeu:
            vals.emplace_back(key, "BDeu", "");
            break;
          case ScoreType::LOG2LIKELIHOOD:
            vals.emplace_back(key, "Log2Likelihood", "");
            break;
          default:
            vals.emplace_back(key, "(unknown)", "?");
            break;
        }
      } else {
        key = "Correction";
        switch (kmode3Off2_) {
          case CorrectedMutualInformation<>::KModeTypes::MDL:
            vals.emplace_back(key, "MDL", "");
            break;
          case CorrectedMutualInformation<>::KModeTypes::NML:
            vals.emplace_back(key, "NML", "");
            break;
          case CorrectedMutualInformation<>::KModeTypes::NoCorr:
            vals.emplace_back(key, "No correction", "");
            break;
          default:
            vals.emplace_back(key, "(unknown)", "?");
            break;
        }
      }

      key     = "Prior";
      comment = checkScoreAprioriCompatibility();
      switch (aprioriType_) {
        case AprioriType::NO_APRIORI:
          vals.emplace_back(key, "-", comment);
          break;
        case AprioriType::DIRICHLET_FROM_DATABASE:
          vals.emplace_back(key, "Dirichlet", comment);
          vals.emplace_back("Dirichlet database", aprioriDbname_, "");
          break;
        case AprioriType::BDEU:
          vals.emplace_back(key, "BDEU", comment);
          break;
        case AprioriType::SMOOTHING:
          vals.emplace_back(key, "Smoothing", comment);
          break;
        default:
          vals.emplace_back(key, "(unknown)", "?");
          break;
      }

      if (aprioriType_ != AprioriType::NO_APRIORI)
        vals.emplace_back("Prior weight", std::to_string(aprioriWeight_), "");

      if (databaseWeight() != double(nbRows())) {
        vals.emplace_back("Database weight", std::to_string(databaseWeight()), "");
      }

      if (epsilonEM_ > 0.0) {
        comment = "";
        if (!hasMissingValues()) comment = "But no missing values in this database";
        vals.emplace_back("EM", "True", "");
        vals.emplace_back("EM epsilon", std::to_string(epsilonEM_), comment);
      }

      std::string res;
      bool        nofirst;
      if (constraintIndegree_.maxIndegree() < std::numeric_limits< Size >::max()) {
        vals.emplace_back("Constraint Max InDegree",
                          std::to_string(constraintIndegree_.maxIndegree()),
                          "Used only for score-based algorithms.");
      }
      if (!constraintForbiddenArcs_.arcs().empty()) {
        res     = "{";
        nofirst = false;
        for (const auto& arc: constraintForbiddenArcs_.arcs()) {
          if (nofirst)
            res += ", ";
          else
            nofirst = true;
          res += nameFromId(arc.tail()) + "->" + nameFromId(arc.head());
        }
        res += "}";
        vals.emplace_back("Constraint Forbidden Arcs", res, "");
      }
      if (!constraintMandatoryArcs_.arcs().empty()) {
        res     = "{";
        nofirst = false;
        for (const auto& arc: constraintMandatoryArcs_.arcs()) {
          if (nofirst)
            res += ", ";
          else
            nofirst = true;
          res += nameFromId(arc.tail()) + "->" + nameFromId(arc.head());
        }
        res += "}";
        vals.emplace_back("Constraint Mandatory Arcs", res, "");
      }
      if (!constraintPossibleEdges_.edges().empty()) {
        res     = "{";
        nofirst = false;
        for (const auto& edge: constraintPossibleEdges_.edges()) {
          if (nofirst)
            res += ", ";
          else
            nofirst = true;
          res += nameFromId(edge.first()) + "--" + nameFromId(edge.second());
        }
        res += "}";
        vals.emplace_back("Constraint Possible Edges",
                          res,
                          "Used only for score-based algorithms.");
      }
      if (!constraintSliceOrder_.sliceOrder().empty()) {
        res     = "{";
        nofirst = false;
        const auto& order = constraintSliceOrder_.sliceOrder();
        for (const auto& p: order) {
          if (nofirst)
            res += ", ";
          else
            nofirst = true;
          res += nameFromId(p.first) + ":" + std::to_string(p.second);
        }
        res += "}";
        vals.emplace_back("Constraint Slice Order", res, "Used only for score-based algorithms.");
      }
      if (initialDag_.size() != 0) {
        vals.emplace_back("Initial DAG", "True", initialDag_.toDot());
      }

      return vals;
    }

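    /// prints the current state of a BNLearner into an output stream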
    template < typename GUM_SCALAR >
    INLINE std::ostream& operator<<(std::ostream& output, const BNLearner< GUM_SCALAR >& learner) {
      output << learner.toString();
      return output;
    }

  } /* namespace learning */

} /* namespace gum */

#endif /* DOXYGEN_SHOULD_SKIP_THIS */