aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
BNLearner_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief A pack of learning algorithms that can easily be used
24  *
25  * The pack currently contains K2, GreedyHillClimbing and
26  * LocalSearchWithTabuList
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 #include <fstream>
31 
32 #ifndef DOXYGEN_SHOULD_SKIP_THIS
33 
34 // to help IDE parser
35 # include <agrum/BN/learning/BNLearner.h>
36 
37 # include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h>
38 
39 namespace gum {
40 
41  namespace learning {
42  template < typename GUM_SCALAR >
43  BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename,
44  const std::vector< std::string >& missing_symbols) :
45  genericBNLearner(filename, missing_symbols) {
46  GUM_CONSTRUCTOR(BNLearner);
47  }
48 
49  template < typename GUM_SCALAR >
50  BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable<>& db) : genericBNLearner(db) {
51  GUM_CONSTRUCTOR(BNLearner);
52  }
53 
54  template < typename GUM_SCALAR >
55  BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename,
56  const gum::BayesNet< GUM_SCALAR >& bn,
57  const std::vector< std::string >& missing_symbols) :
58  genericBNLearner(filename, bn, missing_symbols) {
59  GUM_CONSTRUCTOR(BNLearner);
60  }
61 
62  /// copy constructor
63  template < typename GUM_SCALAR >
64  BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src) : genericBNLearner(src) {
65  GUM_CONSTRUCTOR(BNLearner);
66  }
67 
68  /// move constructor
69  template < typename GUM_SCALAR >
70  BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) : genericBNLearner(src) {
71  GUM_CONSTRUCTOR(BNLearner);
72  }
73 
74  /// destructor
75  template < typename GUM_SCALAR >
76  BNLearner< GUM_SCALAR >::~BNLearner() {
77  GUM_DESTRUCTOR(BNLearner);
78  }
79 
80  /// @}
81 
82  // ##########################################################################
83  /// @name Operators
84  // ##########################################################################
85  /// @{
86 
87  /// copy operator
88  template < typename GUM_SCALAR >
89  BNLearner< GUM_SCALAR >&
90  BNLearner< GUM_SCALAR >::operator=(const BNLearner< GUM_SCALAR >& src) {
91  genericBNLearner::operator=(src);
92  return *this;
93  }
94 
95  /// move operator
96  template < typename GUM_SCALAR >
97  BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) {
98  genericBNLearner::operator=(std::move(src));
99  return *this;
100  }
101 
102  /// learn a Bayes Net from a file
103  template < typename GUM_SCALAR >
104  BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
105  // create the score, the apriori and the estimator
106  auto notification = checkScoreAprioriCompatibility();
107  if (notification != "") { std::cout << "[aGrUM notification] " << notification << std::endl; }
108  createApriori_();
109  createScore_();
110 
111  std::unique_ptr< ParamEstimator<> > param_estimator(
112  createParamEstimator_(scoreDatabase_.parser(), true));
113 
114  return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), learnDag_());
115  }
116 
117  /// learns a BN (its parameters) when its structure is known
118  template < typename GUM_SCALAR >
119  BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(const DAG& dag,
120  bool takeIntoAccountScore) {
121  // if the dag contains no node, return an empty BN
122  if (dag.size() == 0) return BayesNet< GUM_SCALAR >();
123 
124  // check that the dag corresponds to the database
125  std::vector< NodeId > ids;
126  ids.reserve(dag.sizeNodes());
127  for (const auto node: dag)
128  ids.push_back(node);
129  std::sort(ids.begin(), ids.end());
130 
131  if (ids.back() >= scoreDatabase_.names().size()) {
132  std::stringstream str;
133  str << "Learning parameters corresponding to the dag is impossible "
134  << "because the database does not contain the following nodeID";
135  std::vector< NodeId > bad_ids;
136  for (const auto node: ids) {
137  if (node >= scoreDatabase_.names().size()) bad_ids.push_back(node);
138  }
139  if (bad_ids.size() > 1) str << 's';
140  str << ": ";
141  bool deja = false;
142  for (const auto node: bad_ids) {
143  if (deja)
144  str << ", ";
145  else
146  deja = true;
147  str << node;
148  }
149  GUM_ERROR(MissingVariableInDatabase, str.str())
150  }
151 
152  // create the apriori
153  createApriori_();
154 
155  if (epsilonEM_ == 0.0) {
156  // check that the database does not contain any missing value
157  if (scoreDatabase_.databaseTable().hasMissingValues()
158  || ((aprioriDatabase_ != nullptr)
159  && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
160  && aprioriDatabase_->databaseTable().hasMissingValues())) {
161  GUM_ERROR(MissingValueInDatabase,
162  "In general, the BNLearner is unable to cope with "
163  << "missing values in databases. To learn parameters in "
164  << "such situations, you should first use method "
165  << "useEM()");
166  }
167 
168  // create the usual estimator
169  DBRowGeneratorParser<> parser(scoreDatabase_.databaseTable().handler(),
170  DBRowGeneratorSet<>());
171  std::unique_ptr< ParamEstimator<> > param_estimator(
172  createParamEstimator_(parser, takeIntoAccountScore));
173 
174  return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
175  } else {
176  // EM !
177  BNLearnerListener listener(this, Dag2BN_);
178 
179  // get the column types
180  const auto& database = scoreDatabase_.databaseTable();
181  const std::size_t nb_vars = database.nbVariables();
182  const std::vector< gum::learning::DBTranslatedValueType > col_types(
183  nb_vars,
184  gum::learning::DBTranslatedValueType::DISCRETE);
185 
186  // create the bootstrap estimator
187  DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
188  DBRowGeneratorSet<> genset_bootstrap;
189  genset_bootstrap.insertGenerator(generator_bootstrap);
190  DBRowGeneratorParser<> parser_bootstrap(database.handler(), genset_bootstrap);
191  std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
192  createParamEstimator_(parser_bootstrap, takeIntoAccountScore));
193 
194  // create the EM estimator
195  BayesNet< GUM_SCALAR > dummy_bn;
196  DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
197  DBRowGenerator<>& gen_EM = generator_EM; // fix for g++-4.8
198  DBRowGeneratorSet<> genset_EM;
199  genset_EM.insertGenerator(gen_EM);
200  DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
201  std::unique_ptr< ParamEstimator<> > param_estimator_EM(
202  createParamEstimator_(parser_EM, takeIntoAccountScore));
203 
204  Dag2BN_.setEpsilon(epsilonEM_);
205  return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator_bootstrap.get()),
206  *(param_estimator_EM.get()),
207  dag);
208  }
209  }
210 
211 
212  /// learns a BN (its parameters) when its structure is known
213  template < typename GUM_SCALAR >
214  BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) {
215  return learnParameters(initialDag_, take_into_account_score);
216  }
217 
218 
219  template < typename GUM_SCALAR >
220  NodeProperty< Sequence< std::string > >
221  BNLearner< GUM_SCALAR >::_labelsFromBN_(const std::string& filename,
222  const BayesNet< GUM_SCALAR >& src) {
223  std::ifstream in(filename, std::ifstream::in);
224 
225  if ((in.rdstate() & std::ifstream::failbit) != 0) {
226  GUM_ERROR(gum::IOError, "File " << filename << " not found")
227  }
228 
229  CSVParser<> parser(in);
230  parser.next();
231  auto names = parser.current();
232 
233  NodeProperty< Sequence< std::string > > modals;
234 
235  for (gum::Idx col = 0; col < names.size(); col++) {
236  try {
237  gum::NodeId graphId = src.idFromName(names[col]);
238  modals.insert(col, gum::Sequence< std::string >());
239 
240  for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
241  modals[col].insert(src.variable(graphId).label(i));
242  } catch (const gum::NotFound&) {
243  // no problem : a column which is not in the BN...
244  }
245  }
246 
247  return modals;
248  }
249 
250  } /* namespace learning */
251 
252 } /* namespace gum */
253 
254 #endif /* DOXYGEN_SHOULD_SKIP_THIS */