aGrUM  0.21.0
a C++ library for (probabilistic) graphical models
genericBNLearner_inl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief A pack of learning algorithms that can easily be used
24  *
25  * The pack currently contains K2, GreedyHillClimbing, 3off2 and
26  * LocalSearchWithTabuList
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 
31 // to help IDE parser
32 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner.h>
33 #include <agrum/tools/graphs/undiGraph.h>
34 
35 namespace gum {
36 
37  namespace learning {
38 
39  // returns the row filter: a mutable reference to the DBRowGeneratorParser obtained by dereferencing _parser_ (no null check is done here)
40  INLINE DBRowGeneratorParser<>& genericBNLearner::Database::parser() { return *_parser_; }
41 
42  // returns the modalities of the variables
44  return _domain_sizes_;
45  }
46 
47  // returns the names of the variables in the database
49  return _database_.variableNames();
50  }
51 
52  /// assign new weight to the rows of the learning database
54  if (_database_.nbRows() == std::size_t(0)) return;
55  const double weight = new_weight / double(_database_.nbRows());
57  }
58 
59  // returns the node id corresponding to a variable name
61  try {
63  return _nodeId2cols_.first(cols[0]);
64  } catch (...) {
66  "Variable " << var_name << " could not be found in the database");
67  }
68  }
69 
70 
71  // returns the variable name corresponding to a given node id
73  try {
75  } catch (...) {
77  "Variable of Id " << id << " could not be found in the database");
78  }
79  }
80 
81 
82  /// returns the internal database table
84  return _database_;
85  }
86 
87 
88  /// returns the set of missing symbols taken into account
90  return _database_.missingSymbols();
91  }
92 
93 
94  /// returns the mapping between node ids and their columns in the database
95  INLINE const Bijection< NodeId, std::size_t >&
97  return _nodeId2cols_;
98  }
99 
100 
101  /// returns the number of records in the database
103 
104 
105  /// returns the number of records in the database
107 
108 
109  /// sets the weight of the ith record
110  INLINE void genericBNLearner::Database::setWeight(const std::size_t i, const double weight) {
112  }
113 
114 
115  /// returns the weight of the ith record, as reported by _database_.weight(i)
116  INLINE double genericBNLearner::Database::weight(const std::size_t i) const {
117  return _database_.weight(i);   // pure delegation; NOTE(review): an out-of-range i is presumably handled by _database_ -- confirm
118  }
119 
120 
121  /// returns the weight of the whole database (forwards to the no-argument _database_.weight())
122  INLINE double genericBNLearner::Database::weight() const { return _database_.weight(); }
123 
124 
125  // ===========================================================================
126 
127  // returns the node id corresponding to a variable name
130  }
131 
132  // returns the variable name corresponding to a given node id
134  return scoreDatabase_.nameFromId(id);
135  }
136 
137  /// assign new weight to the rows of the learning database
140  }
141 
142  /// assign new weight to the ith row of the learning database
145  }
146 
147  /// returns the weight of the ith record of the score database (forwards to scoreDatabase_.weight(i))
148  INLINE double genericBNLearner::recordWeight(const std::size_t i) const {
149  return scoreDatabase_.weight(i);   // pure delegation; no index check is performed at this level
150  }
151 
152  /// returns the weight of the whole database
154 
155  // sets an initial DAG structure
157 
159 
160  // indicate that we wish to use an AIC score
164  }
165 
166  // indicate that we wish to use a BD score
170  }
171 
172  // indicate that we wish to use a BDeu score
176  }
177 
178  // indicate that we wish to use a BIC score
182  }
183 
184  // indicate that we wish to use a K2 score
188  }
189 
190  // indicate that we wish to use a Log2Likelihood score
194  }
195 
196  // sets the max indegree
199  }
200 
201  // indicate that we wish to use 3off2
205  }
206 
207  // indicate that we wish to use 3off2
211  }
212 
213  /// indicate that we wish to use the NML correction for 3off2
216  }
217 
218  /// indicate that we wish to use the MDL correction for 3off2
221  }
222 
223  /// indicate that we wish to use the NoCorr correction for 3off2
226  }
227 
228  /// get the list of arcs hiding latent variables
231  }
232 
233  // indicate that we wish to use a K2 algorithm
237  }
238 
239  // indicate that we wish to use a K2 algorithm
243  }
244 
245  // indicate that we wish to use a greedy hill climbing algorithm
248  }
249 
250  // indicate that we wish to use a local search with tabu list
256  }
257 
258  /// use the EM algorithm to learn parameters
260 
261 
264  }
265 
266  // assign a set of forbidden edges
269  }
270  // assign a set of forbidden edges from an UndiGraph
273  }
274 
275  // assign a new possible edge
278  }
279 
280  // remove a forbidden edge
283  }
284 
285  // assign a new forbidden edge
288  }
289 
290  // remove a forbidden edge
293  }
294 
295  // assign a new forbidden edge
297  const std::string& head) {
299  }
300 
301  // remove a forbidden edge
303  const std::string& head) {
305  }
306 
307  // assign a set of forbidden arcs
310  }
311 
312  // assign a new forbidden arc
315  }
316 
317  // remove a forbidden arc
320  }
321 
322  // assign a new forbidden arc
325  }
326 
327  // remove a forbidden arc
330  }
331 
332  // assign a new forbidden arc
334  const std::string& head) {
336  }
337 
338  // remove a forbidden arc
340  const std::string& head) {
342  }
343 
344  // assign a set of forbidden arcs
347  }
348 
349  // assign a new forbidden arc
352  }
353 
354  // remove a forbidden arc
357  }
358 
359  // assign a new forbidden arc
361  const std::string& head) {
363  }
364 
365  // remove a forbidden arc
367  const std::string& head) {
369  }
370 
371  // assign a new forbidden arc
374  }
375 
376  // remove a forbidden arc
379  }
380 
381  // sets a partial order on the nodes
384  }
385 
386  INLINE void
389  NodeId rank = 0;
390  for (const auto& slice: slices) {
391  for (const auto& name: slice) {
393  }
394  rank++;
395  }
397  }
398 
399  // sets the apriori weight
401  if (weight < 0) { GUM_ERROR(OutOfBounds, "the weight of the apriori must be positive") }
402 
405  }
406 
407  // use the apriori smoothing
411  }
412 
413  // use the apriori smoothing
415  if (weight < 0) { GUM_ERROR(OutOfBounds, "the weight of the apriori must be positive") }
416 
419 
421  }
422 
423  // use the Dirichlet apriori
425  if (weight < 0) { GUM_ERROR(OutOfBounds, "the weight of the apriori must be positive") }
426 
430 
432  }
433 
434 
435  // use the apriori BDeu
437  if (weight < 0) { GUM_ERROR(OutOfBounds, "the weight of the apriori must be positive") }
438 
441 
443  }
444 
445 
446  // returns the type (as a string) of a given apriori
448  switch (aprioriType_) {
449  case AprioriType::NO_APRIORI:
450  return AprioriNoApriori<>::type::type;
451 
452  case AprioriType::SMOOTHING:
453  return AprioriSmoothing<>::type::type;
454 
457 
458  case AprioriType::BDEU:
459  return AprioriBDeu<>::type::type;
460 
461  default:
463  "genericBNLearner getAprioriType does "
464  "not support yet this apriori");
465  }
466  }
467 
468  // returns the names of the variables in the database
470  return scoreDatabase_.names();
471  }
472 
473  // returns the modalities of the variables in the database
475  return scoreDatabase_.domainSizes();
476  }
477 
478  // returns the modalities of a variable in the database
480  return scoreDatabase_.domainSizes()[var];
481  }
482  // returns the modalities of a variable in the database
485  }
486 
487  /// returns the current database rows' ranges used for learning
488  INLINE const std::vector< std::pair< std::size_t, std::size_t > >&
490  return ranges_;
491  }
492 
493  /// reset the ranges to the one range corresponding to the whole database
495 
496  /// returns the database used by the BNLearner
498  return scoreDatabase_.databaseTable();
499  }
500 
502 
504  } /* namespace learning */
505 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)