aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
genericBNLearner_inl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief A pack of learning algorithms that can easily be used
24  *
25  * The pack currently contains K2, GreedyHillClimbing, 3off2 and
26  * LocalSearchWithTabuList.
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 
31 // to help IDE parser
32 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h>
33 #include <agrum/tools/graphs/undiGraph.h>
34 
35 namespace gum {
36 
37  namespace learning {
38 
39  // returns the row filter
40  INLINE DBRowGeneratorParser<>& genericBNLearner::Database::parser() { return *_parser_; }
41 
42  // returns the modalities of the variables
44  return _domain_sizes_;
45  }
46 
47  // returns the names of the variables in the database
49  return _database_.variableNames();
50  }
51 
52  /// assign new weight to the rows of the learning database
54  if (_database_.nbRows() == std::size_t(0)) return;
55  const double weight = new_weight / double(_database_.nbRows());
57  }
58 
59  // returns the node id corresponding to a variable name
61  try {
63  return _nodeId2cols_.first(cols[0]);
64  } catch (...) {
66  "Variable " << var_name << " could not be found in the database");
67  }
68  }
69 
70 
71  // returns the variable name corresponding to a given node id
73  try {
75  } catch (...) {
77  "Variable of Id " << id << " could not be found in the database");
78  }
79  }
80 
81 
82  /// returns the internal database table
84  return _database_;
85  }
86 
87 
88  /// returns the set of missing symbols taken into account
90  return _database_.missingSymbols();
91  }
92 
93 
94  /// returns the mapping between node ids and their columns in the database
95  INLINE const Bijection< NodeId, std::size_t >&
97  return _nodeId2cols_;
98  }
99 
100 
101  /// returns the number of records in the database
103 
104 
105  /// returns the number of records in the database
107 
108 
109  /// sets the weight of the ith record
110  INLINE void genericBNLearner::Database::setWeight(const std::size_t i, const double weight) {
112  }
113 
114 
115  /// returns the weight of the ith record
116  INLINE double genericBNLearner::Database::weight(const std::size_t i) const {
117  return _database_.weight(i);
118  }
119 
120 
121  /// returns the weight of the whole database
122  INLINE double genericBNLearner::Database::weight() const { return _database_.weight(); }
123 
124 
125  // ===========================================================================
126 
127  // returns the node id corresponding to a variable name
130  }
131 
132  // returns the variable name corresponding to a given node id
134  return scoreDatabase_.nameFromId(id);
135  }
136 
137  /// assign new weight to the rows of the learning database
140  }
141 
142  /// assign new weight to the ith row of the learning database
145  }
146 
147  /// returns the weight of the ith record
148  INLINE double genericBNLearner::recordWeight(const std::size_t i) const {
149  return scoreDatabase_.weight(i);
150  }
151 
152  /// returns the weight of the whole database
154 
155  // sets an initial DAG structure
157 
158  // indicate that we wish to use an AIC score
162  }
163 
164  // indicate that we wish to use a BD score
168  }
169 
170  // indicate that we wish to use a BDeu score
174  }
175 
176  // indicate that we wish to use a BIC score
180  }
181 
182  // indicate that we wish to use a K2 score
186  }
187 
188  // indicate that we wish to use a Log2Likelihood score
192  }
193 
194  // sets the max indegree
197  }
198 
199  // indicate that we wish to use 3off2
203  }
204 
205  // indicate that we wish to use 3off2
209  }
210 
211  /// indicate that we wish to use the NML correction for 3off2
214  }
215 
216  /// indicate that we wish to use the MDL correction for 3off2
219  }
220 
221  /// indicate that we wish to use the NoCorr correction for 3off2
224  }
225 
226  /// get the list of arcs hiding latent variables
229  }
230 
231  // indicate that we wish to use a K2 algorithm
235  }
236 
237  // indicate that we wish to use a K2 algorithm
241  }
242 
243  // indicate that we wish to use a greedy hill climbing algorithm
246  }
247 
248  // indicate that we wish to use a local search with tabu list
253  }
254 
255  /// use the EM algorithm to learn parameters
257 
258 
261  }
262 
263  // assign a set of forbidden edges
266  }
267  // assign a set of forbidden edges from an UndiGraph
270  }
271 
272  // assign a new possible edge
275  }
276 
277  // remove a forbidden edge
280  }
281 
282  // assign a new forbidden edge
285  }
286 
287  // remove a forbidden edge
290  }
291 
292  // assign a new forbidden edge
294  const std::string& head) {
296  }
297 
298  // remove a forbidden edge
300  const std::string& head) {
302  }
303 
304  // assign a set of forbidden arcs
307  }
308 
309  // assign a new forbidden arc
312  }
313 
314  // remove a forbidden arc
317  }
318 
319  // assign a new forbidden arc
322  }
323 
324  // remove a forbidden arc
327  }
328 
329  // assign a new forbidden arc
331  const std::string& head) {
333  }
334 
335  // remove a forbidden arc
337  const std::string& head) {
339  }
340 
341  // assign a set of forbidden arcs
344  }
345 
346  // assign a new forbidden arc
349  }
350 
351  // remove a forbidden arc
354  }
355 
356  // assign a new forbidden arc
358  const std::string& head) {
360  }
361 
362  // remove a forbidden arc
364  const std::string& head) {
366  }
367 
368  // assign a new forbidden arc
371  }
372 
373  // remove a forbidden arc
376  }
377 
378  // sets a partial order on the nodes
381  }
382 
383  INLINE void
386  NodeId rank = 0;
387  for (const auto& slice: slices) {
388  for (const auto& name: slice) {
390  }
391  rank++;
392  }
394  }
395 
396  // sets the apriori weight
398  if (weight < 0) { GUM_ERROR(OutOfBounds, "the weight of the apriori must be positive") }
399 
402  }
403 
404  // use the apriori smoothing
408  }
409 
410  // use the apriori smoothing
412  if (weight < 0) { GUM_ERROR(OutOfBounds, "the weight of the apriori must be positive") }
413 
416 
418  }
419 
420  // use the Dirichlet apriori
422  if (weight < 0) { GUM_ERROR(OutOfBounds, "the weight of the apriori must be positive") }
423 
427 
429  }
430 
431 
432  // use the apriori BDeu
434  if (weight < 0) { GUM_ERROR(OutOfBounds, "the weight of the apriori must be positive") }
435 
438 
440  }
441 
442 
443  // returns the type (as a string) of a given apriori
445  switch (aprioriType_) {
446  case AprioriType::NO_APRIORI:
447  return AprioriNoApriori<>::type::type;
448 
449  case AprioriType::SMOOTHING:
450  return AprioriSmoothing<>::type::type;
451 
454 
455  case AprioriType::BDEU:
456  return AprioriBDeu<>::type::type;
457 
458  default:
460  "genericBNLearner getAprioriType does "
461  "not support yet this apriori");
462  }
463  }
464 
465  // returns the names of the variables in the database
467  return scoreDatabase_.names();
468  }
469 
470  // returns the modalities of the variables in the database
472  return scoreDatabase_.domainSizes();
473  }
474 
475  // returns the modalities of a variable in the database
477  return scoreDatabase_.domainSizes()[var];
478  }
479  // returns the modalities of a variable in the database
482  }
483 
484  /// returns the current database rows' ranges used for learning
485  INLINE const std::vector< std::pair< std::size_t, std::size_t > >&
487  return ranges_;
488  }
489 
490  /// reset the ranges to the one range corresponding to the whole database
492 
493  /// returns the database used by the BNLearner
495  return scoreDatabase_.databaseTable();
496  }
497 
499 
501  } /* namespace learning */
502 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)