30 #ifndef DOXYGEN_SHOULD_SKIP_THIS 40 template <
typename GUM_SCALAR >
42 const std::string& filename,
43 const std::vector< std::string >& missing_symbols) :
44 genericBNLearner(filename, missing_symbols) {
48 template <
typename GUM_SCALAR >
54 template <
typename GUM_SCALAR >
56 const std::string& filename,
58 const std::vector< std::string >& missing_symbols) :
64 template <
typename GUM_SCALAR >
71 template <
typename GUM_SCALAR >
78 template <
typename GUM_SCALAR >
91 template <
typename GUM_SCALAR >
93 operator=(
const BNLearner< GUM_SCALAR >& src) {
99 template <
typename GUM_SCALAR >
101 operator=(BNLearner< GUM_SCALAR >&& src) {
107 template <
typename GUM_SCALAR >
111 if (notification !=
"") {
112 std::cout <<
"[aGrUM notification] " << notification << std::endl;
117 std::unique_ptr< ParamEstimator<> > param_estimator(
125 template <
typename GUM_SCALAR >
126 BayesNet< GUM_SCALAR >
128 bool take_into_account_score) {
130 if (dag.size() == 0)
return BayesNet< GUM_SCALAR >();
133 std::vector< NodeId > ids;
134 ids.reserve(dag.sizeNodes());
135 for (
const auto node : dag)
137 std::sort(ids.begin(), ids.end());
140 std::stringstream str;
141 str <<
"Learning parameters corresponding to the dag is impossible " 142 <<
"because the database does not contain the following nodeID";
143 std::vector< NodeId > bad_ids;
144 for (
const auto node : ids) {
147 if (bad_ids.size() > 1) str <<
's';
150 for (
const auto node : bad_ids) {
157 GUM_ERROR(MissingVariableInDatabase, str.str());
170 "In general, the BNLearner is unable to cope with " 171 <<
"missing values in databases. To learn parameters in " 172 <<
"such situations, you should first use method " 178 DBRowGeneratorSet<>());
179 std::unique_ptr< ParamEstimator<> > param_estimator(
185 BNLearnerListener listener(
this,
__Dag2BN);
190 const std::vector< gum::learning::DBTranslatedValueType > col_types(
194 DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
195 DBRowGeneratorSet<> genset_bootstrap;
196 genset_bootstrap.insertGenerator(generator_bootstrap);
199 std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
203 BayesNet< GUM_SCALAR > dummy_bn;
204 DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
205 DBRowGenerator<>& gen_EM = generator_EM;
206 DBRowGeneratorSet<> genset_EM;
207 genset_EM.insertGenerator(gen_EM);
209 std::unique_ptr< ParamEstimator<> > param_estimator_EM(
214 *(param_estimator_bootstrap.get()), *(param_estimator_EM.get()), dag);
220 template <
typename GUM_SCALAR >
221 BayesNet< GUM_SCALAR >
227 template <
typename GUM_SCALAR >
228 NodeProperty< Sequence< std::string > >
230 const BayesNet< GUM_SCALAR >& src) {
231 std::ifstream in(filename, std::ifstream::in);
233 if ((in.rdstate() & std::ifstream::failbit) != 0) {
237 CSVParser<> parser(in);
239 auto names = parser.current();
241 NodeProperty< Sequence< std::string > > modals;
248 for (
gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
249 modals[col].insert(src.variable(graphId).label(i));
Class representing a Bayesian Network.
static BayesNet< GUM_SCALAR > createBN(ParamEstimator< ALLOC > &estimator, const DAG &dag)
create a BN from a DAG using a one pass generator (typically ML)
Database __score_database
the database to be used by the scores and parameter estimators
double __EMepsilon
epsilon for EM; if epsilon=0.0, no EM is performed
void __createScore()
create the score used for learning
gum is the global namespace for all aGrUM entities
const std::vector< std::string > & names() const
returns the names of the variables in the database
AprioriType __apriori_type
the a priori selected for the score and parameters
BayesNet< GUM_SCALAR > learnParameters(const DAG &dag, bool take_into_account_score=true)
learns a BN (its parameters) when its structure is known
NodeProperty< Sequence< std::string > > __labelsFromBN(const std::string &filename, const BayesNet< GUM_SCALAR > &src)
read the first line of a file to find column names
DAG __initial_dag
an initial DAG given to learners
std::size_t nbVariables() const noexcept
returns the number of variables (columns) of the database
genericBNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols)
constructor from a database file and the list of symbols denoting missing values
Database * __apriori_database
the database used by the Dirichlet a priori
bool hasMissingValues() const
indicates whether the database contains some missing values
ParamEstimator * __createParamEstimator(DBRowGeneratorParser<> &parser, bool take_into_account_score=true)
create the parameter estimator used for learning
DAG __learnDAG()
returns the DAG learnt
genericBNLearner & operator=(const genericBNLearner &)
copy operator
virtual ~BNLearner()
destructor
const DatabaseTable & databaseTable() const
returns the internal database table
std::string checkScoreAprioriCompatibility()
checks whether the current score and apriori are compatible
DAG2BNLearner __Dag2BN
the parametric EM
A listener that allows BNLearner to be used as a proxy for its inner algorithms.
A basic pack of learning algorithms that can easily be used.
DBRowGeneratorParser & parser()
returns the parser for the database
BayesNet< GUM_SCALAR > learnBN()
learn a Bayes Net from a file (must have read the db before)
void __createApriori()
create the apriori used for learning
BNLearner & operator=(const BNLearner &)
copy operator
void setEpsilon(double eps)
Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)|.
Size Idx
Type for indexes.
BNLearner(const std::string &filename, const std::vector< std::string > &missing_symbols={"?"})
constructor from a database file and the list of symbols denoting missing values (default: {"?"})
std::size_t Size
In aGrUM, hashed values are unsigned long int.
const DatabaseTable & database() const
returns the database used by the BNLearner
const std::vector< std::string > & names() const
returns the names of the variables in the database
Size NodeId
Type for node ids.
iterator handler() const
returns a new unsafe handler pointing to the 1st record of the database
#define GUM_ERROR(type, msg)