aGrUM  0.16.0
genericBNLearner_tpl.h
Go to the documentation of this file.
1 
22 #include <algorithm>
23 
25 
26 namespace gum {
27 
28  namespace learning {
29 
// Body of a routine taking a CSV filename, a BayesNet and the symbols that
// denote missing values. It sets up __database, __nodeId2cols, __domain_sizes
// and __parser from them.
// NOTE(review): the listing omits the declarator line (listing line 31), so the
// routine's name is not visible here. Presumably this is the Database
// constructor taking (filename, bn, missing_symbols) -- confirm against the header.
30  template < typename GUM_SCALAR >
32  const std::string& filename,
33  const BayesNet< GUM_SCALAR >& bn,
34  const std::vector< std::string >& missing_symbols) {
35  // assign to each column name in the database its position
// NOTE(review): listing line 36 is absent from this extract.
37  DBInitializerFromCSV<> initializer(filename);
38  const auto& xvar_names = initializer.variableNames();
39  std::size_t nb_vars = xvar_names.size();
// var_names maps each name returned by variableNames() to its index in that list
40  HashTable< std::string, std::size_t > var_names(nb_vars);
41  for (std::size_t i = std::size_t(0); i < nb_vars; ++i)
42  var_names.insert(xvar_names[i], i);
43 
44  // we use the bn to insert the translators into the database table
// the node ids of bn's DAG are collected and sorted, so the loop below
// processes them in increasing NodeId order
45  std::vector< NodeId > nodes;
46  nodes.reserve(bn.dag().sizeNodes());
47  for (const auto node : bn.dag())
48  nodes.push_back(node);
49  std::sort(nodes.begin(), nodes.end());
// i is the running index stored with each node in __nodeId2cols
50  std::size_t i = std::size_t(0);
51  for (auto node : nodes) {
52  const Variable& var = bn.variable(node);
53  try {
// the input column passed to insertTranslator is looked up by variable name
54  __database.insertTranslator(var, var_names[var.name()], missing_symbols);
55  } catch (NotFound&) {
// NOTE(review): listing line 56 is absent; the line below is only the message
// argument of what is presumably a GUM_ERROR(...) call -- confirm.
57  "Variable '" << var.name() << "' is missing");
58  }
59  __nodeId2cols.insert(NodeId(node), i++);
60  }
61 
62  // fill the database
63  initializer.fillDatabase(__database);
64 
65  // get the domain sizes of the variables
66  for (auto dom : __database.domainSizes())
67  __domain_sizes.push_back(dom);
68 
69  // create the parser
// NOTE(review): listing line 71, the right-hand side of this assignment, is
// absent from this extract.
70  __parser =
72  }
73 
74 
// Adds every variable of __database, in column order, to a BayesNet named bn
// and returns it.
// NOTE(review): listing lines 76-77 (the declarator and, presumably, the
// declaration of bn) are absent from this extract. The routine is presumably
// __BNVars() -- confirm against the header.
75  template < typename GUM_SCALAR >
78  const std::size_t nb_vars = __database.nbVariables();
79  for (std::size_t i = 0; i < nb_vars; ++i) {
// reference dynamic_cast: throws std::bad_cast if the i-th variable is not a
// DiscreteVariable
80  const DiscreteVariable& var =
81  dynamic_cast< const DiscreteVariable& >(__database.variable(i));
82  bn.add(var);
83  }
84  return bn;
85  }
86 
87 
// Constructor fragment: initialises __score_database from (filename, bn,
// missing_symbols) and then invokes GUM_CONSTRUCTOR(genericBNLearner).
// NOTE(review): listing lines 89, 91 and 94 are absent from this extract, so the
// declarator, the declaration of the bn parameter and any line between the
// initialiser and GUM_CONSTRUCTOR are not visible here.
88  template < typename GUM_SCALAR >
90  const std::string& filename,
92  const std::vector< std::string >& missing_symbols) :
93  __score_database(filename, bn, missing_symbols) {
95  GUM_CONSTRUCTOR(genericBNLearner);
96  }
97 
98 
// Passes new_ranges to a local object named score, then stores the ranges
// reported back by that object in __ranges.
// NOTE(review): listing lines 99, 101 and 106 are absent from this extract, so
// the declarator and the declaration of score are not visible here. The routine
// is presumably useDatabaseRanges() -- confirm against the header.
100  template < template < typename > class XALLOC >
102  const std::vector< std::pair< std::size_t, std::size_t >,
103  XALLOC< std::pair< std::size_t, std::size_t > > >&
104  new_ranges) {
105  // use a score to detect whether the ranges are ok
107  score.setRanges(new_ranges);
// __ranges takes what score.ranges() returns, not new_ranges itself
108  __ranges = score.ranges();
109  }
110  } // namespace learning
111 } // namespace gum
void insert(const T1 &first, const T2 &second)
Inserts a new association in the gum::Bijection.
Class representing a Bayesian Network.
Definition: BayesNet.h:78
const std::vector< std::string, ALLOC< std::string > > & variableNames()
returns the names of the variables in the input dataset
void setRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
sets new ranges to perform the countings used by the score
Base class for every random variable.
Definition: variable.h:66
Database __score_database
the database to be used by the scores and parameter estimators
const DiscreteVariable & variable(NodeId id) const final
Returns a gum::DiscreteVariable given its gum::NodeId in the gum::BayesNet.
Definition: BayesNet_tpl.h:202
static void __checkFileName(const std::string &filename)
checks whether the extension of a CSV filename is correct
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
DBVector< std::size_t > domainSizes() const
returns the domain sizes of all the variables in the database table
The class used to pack sets of generators.
NodeId add(const DiscreteVariable &var)
Add a variable to the gum::BayesNet.
Definition: BayesNet_tpl.h:232
DatabaseTable __database
the database itself
Base class for discrete random variable.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
The class for generic Hash Tables.
Definition: hashTable.h:679
the class for computing Log2-likelihood scores
std::size_t nbVariables() const noexcept
returns the number of variables (columns) of the database
genericBNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols)
default constructor
void fillDatabase(DATABASE< ALLOC > &database, const bool retry_insertion=false)
fills the rows of the database table
const DatabaseTable & databaseTable() const
returns the internal database table
std::vector< std::pair< std::size_t, std::size_t > > __ranges
the set of rows' ranges within the database in which learning is done
const Variable & variable(const std::size_t k, const bool k_is_input_col=false) const
returns either the kth variable of the database table or the first one corresponding to the kth colum...
std::size_t insertTranslator(const DBTranslator< ALLOC > &translator, const std::size_t input_column, const bool unique_column=true)
insert a new translator into the database table
DBRowGeneratorParser * __parser
the parser used for reading the database
A pack of learning algorithms that can easily be used.
DBRowGeneratorParser & parser()
returns the parser for the database
void useDatabaseRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
use a new set of database rows' ranges to perform learning
The class for initializing DatabaseTable and RawDatabaseTable instances from CSV files.
BayesNet< GUM_SCALAR > __BNVars() const
std::vector< std::size_t > __domain_sizes
the domain sizes of the variables (useful to speed-up computations)
Bijection< NodeId, std::size_t > __nodeId2cols
a bijection assigning to each NodeId the index of its column in the database
Database(const std::string &file, const std::vector< std::string > &missing_symbols)
default constructor
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
const std::string & name() const
returns the name of the variable
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
const DAG & dag() const
Returns a constant reference to the dag of this Bayes Net.
Definition: DAGmodel_inl.h:63
Size NodeId
Type for node ids.
Definition: graphElements.h:98
the no a priori class: corresponds to 0 weight-sample
iterator handler() const
returns a new unsafe handler pointing to the 1st record of the database
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55