32 #ifndef DOXYGEN_SHOULD_SKIP_THIS 35 # include <agrum/BN/learning/BNLearner.h> 37 # include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h> 42 template <
typename GUM_SCALAR >
43 BNLearner< GUM_SCALAR >::BNLearner(
const std::string& filename,
44 const std::vector< std::string >& missing_symbols) :
45 genericBNLearner(filename, missing_symbols) {
46 GUM_CONSTRUCTOR(BNLearner);
49 template <
typename GUM_SCALAR >
50 BNLearner< GUM_SCALAR >::BNLearner(
const DatabaseTable<>& db) : genericBNLearner(db) {
51 GUM_CONSTRUCTOR(BNLearner);
54 template <
typename GUM_SCALAR >
55 BNLearner< GUM_SCALAR >::BNLearner(
const std::string& filename,
56 const gum::BayesNet< GUM_SCALAR >& bn,
57 const std::vector< std::string >& missing_symbols) :
58 genericBNLearner(filename, bn, missing_symbols) {
59 GUM_CONSTRUCTOR(BNLearner);
63 template <
typename GUM_SCALAR >
64 BNLearner< GUM_SCALAR >::BNLearner(
const BNLearner< GUM_SCALAR >& src) : genericBNLearner(src) {
65 GUM_CONSTRUCTOR(BNLearner);
69 template <
typename GUM_SCALAR >
70 BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) : genericBNLearner(src) {
71 GUM_CONSTRUCTOR(BNLearner);
75 template <
typename GUM_SCALAR >
76 BNLearner< GUM_SCALAR >::~BNLearner() {
77 GUM_DESTRUCTOR(BNLearner);
88 template <
typename GUM_SCALAR >
89 BNLearner< GUM_SCALAR >&
90 BNLearner< GUM_SCALAR >::operator=(
const BNLearner< GUM_SCALAR >& src) {
91 genericBNLearner::operator=(src);
96 template <
typename GUM_SCALAR >
97 BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) {
98 genericBNLearner::operator=(std::move(src));
103 template <
typename GUM_SCALAR >
104 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
106 auto notification = checkScoreAprioriCompatibility();
107 if (notification !=
"") { std::cout <<
"[aGrUM notification] " << notification << std::endl; }
111 std::unique_ptr< ParamEstimator<> > param_estimator(
112 createParamEstimator_(scoreDatabase_.parser(),
true));
114 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), learnDag_());
// NOTE(review): the numeric prefixes on the code lines of this block look like
// leftover line numbers from a lossy extraction. Several lines (loop bodies,
// closing braces, the end of an error message, the last createBN argument)
// also appear to be missing. The notes below flag each visible gap.
/**
 * @brief Learns the parameters of a Bayesian network whose structure is @p dag.
 *
 * Visible behavior:
 *  - an empty @p dag yields a default-constructed BayesNet;
 *  - node ids of @p dag that are >= scoreDatabase_.names().size() raise
 *    MissingVariableInDatabase;
 *  - when epsilonEM_ == 0.0:
 *      - missing values in the score database raise MissingValueInDatabase;
 *      - so do missing values in aprioriDatabase_, if it is set and
 *        aprioriType_ is DIRICHLET_FROM_DATABASE;
 *      - otherwise a single estimator is built and passed to Dag2BN_.createBN;
 *  - the EM statements further down do the following:
 *      - build a bootstrap estimator (DBRowGenerator4CompleteRows);
 *      - build an EM estimator (DBRowGeneratorEM);
 *      - set Dag2BN_'s epsilon to epsilonEM_;
 *      - call Dag2BN_.createBN with both estimators.
 *
 * @param dag                  the structure whose parameters are learned
 * @param takeIntoAccountScore forwarded to every createParamEstimator_ call
 * @return the Bayesian network built by Dag2BN_
 * @throws MissingVariableInDatabase if a node id of @p dag has no database column
 * @throws MissingValueInDatabase    if epsilonEM_ == 0.0 and missing values are found
 */
118 template <
typename GUM_SCALAR >
119 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(
const DAG& dag,
120 bool takeIntoAccountScore) {
// An empty structure yields an empty network.
122 if (dag.size() == 0)
return BayesNet< GUM_SCALAR >();
// Collect the node ids of the dag, sorted, to compare them with the database columns.
125 std::vector< NodeId > ids;
126 ids.reserve(dag.sizeNodes());
127 for (
const auto node: dag)
// NOTE(review): no loop body is visible. As written, the std::sort below IS
// the body of this for, so ids is never filled. ids.back() on that empty
// vector would then be undefined behavior. Presumably an
// `ids.push_back(node);` line was lost — confirm.
129 std::sort(ids.begin(), ids.end());
// Every node id must index an existing database column; otherwise the
// offending ids are reported in a MissingVariableInDatabase error.
131 if (ids.back() >= scoreDatabase_.names().size()) {
132 std::stringstream str;
133 str <<
"Learning parameters corresponding to the dag is impossible " 134 <<
"because the database does not contain the following nodeID";
135 std::vector< NodeId > bad_ids;
136 for (
const auto node: ids) {
137 if (node >= scoreDatabase_.names().size()) bad_ids.push_back(node);
// Pluralize "nodeID" when several ids are missing.
139 if (bad_ids.size() > 1) str <<
's';
// NOTE(review): the body of the loop below (presumably appending each bad id
// to str) and the closing braces of this check are not visible — confirm.
142 for (
const auto node: bad_ids) {
149 GUM_ERROR(MissingVariableInDatabase, str.str())
// epsilonEM_ == 0.0 selects plain (non-EM) estimation, which rejects missing values.
155 if (epsilonEM_ == 0.0) {
// Missing values are rejected in the score database, and in the apriori
// database when it is set and used with DIRICHLET_FROM_DATABASE.
157 if (scoreDatabase_.databaseTable().hasMissingValues()
158 || ((aprioriDatabase_ !=
nullptr)
159 && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
160 && aprioriDatabase_->databaseTable().hasMissingValues())) {
// NOTE(review): the message below appears truncated after
// "you should first use method ", and the GUM_ERROR call is not closed in
// this view — confirm against the intended text.
161 GUM_ERROR(MissingValueInDatabase,
162 "In general, the BNLearner is unable to cope with " 163 <<
"missing values in databases. To learn parameters in " 164 <<
"such situations, you should first use method " 169 DBRowGeneratorParser<> parser(scoreDatabase_.databaseTable().handler(),
170 DBRowGeneratorSet<>());
// Build the estimator on a parser with an empty generator set, then learn.
171 std::unique_ptr< ParamEstimator<> > param_estimator(
172 createParamEstimator_(parser, takeIntoAccountScore));
174 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
// NOTE(review): no closing brace or `else` is visible before this point. As
// written, the statements below follow the `return` above inside the
// epsilonEM_ == 0.0 branch. Presumably they are meant to be the EM
// (epsilonEM_ != 0.0) path — confirm.
// A BNLearnerListener is constructed on this learner and Dag2BN_
// (presumably to relay EM progress — confirm).
177 BNLearnerListener listener(
this, Dag2BN_);
// Column types handed to the row generators: every column is DISCRETE.
180 const auto& database = scoreDatabase_.databaseTable();
181 const std::size_t nb_vars = database.nbVariables();
// NOTE(review): only the DISCRETE argument is visible. nb_vars is otherwise
// unused, so presumably the intended call is col_types(nb_vars, DISCRETE).
182 const std::vector< gum::learning::DBTranslatedValueType > col_types(
184 gum::learning::DBTranslatedValueType::DISCRETE);
// Bootstrap estimator built on a DBRowGenerator4CompleteRows generator
// (presumably restricted to rows without missing values — confirm).
187 DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
188 DBRowGeneratorSet<> genset_bootstrap;
189 genset_bootstrap.insertGenerator(generator_bootstrap);
190 DBRowGeneratorParser<> parser_bootstrap(database.handler(), genset_bootstrap);
191 std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
192 createParamEstimator_(parser_bootstrap, takeIntoAccountScore));
// EM estimator built on a DBRowGeneratorEM over a default-constructed BayesNet.
195 BayesNet< GUM_SCALAR > dummy_bn;
196 DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
// The EM generator is inserted through a DBRowGenerator<> base reference.
197 DBRowGenerator<>& gen_EM = generator_EM;
198 DBRowGeneratorSet<> genset_EM;
199 genset_EM.insertGenerator(gen_EM);
200 DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
201 std::unique_ptr< ParamEstimator<> > param_estimator_EM(
202 createParamEstimator_(parser_EM, takeIntoAccountScore));
// Set Dag2BN_'s epsilon to epsilonEM_, then learn with both estimators.
204 Dag2BN_.setEpsilon(epsilonEM_);
205 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator_bootstrap.get()),
206 *(param_estimator_EM.get()),
// NOTE(review): this createBN call is not closed in this view. Presumably a
// final `dag);` argument and the closing braces are missing — confirm.
213 template <
typename GUM_SCALAR >
214 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(
bool take_into_account_score) {
215 return learnParameters(initialDag_, take_into_account_score);
219 template <
typename GUM_SCALAR >
220 NodeProperty< Sequence< std::string > >
221 BNLearner< GUM_SCALAR >::_labelsFromBN_(
const std::string& filename,
222 const BayesNet< GUM_SCALAR >& src) {
223 std::ifstream in(filename, std::ifstream::in);
225 if ((in.rdstate() & std::ifstream::failbit) != 0) {
226 GUM_ERROR(gum::IOError,
"File " << filename <<
" not found")
229 CSVParser<> parser(in);
231 auto names = parser.current();
233 NodeProperty< Sequence< std::string > > modals;
235 for (gum::Idx col = 0; col < names.size(); col++) {
237 gum::NodeId graphId = src.idFromName(names[col]);
238 modals.insert(col, gum::Sequence< std::string >());
240 for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
241 modals[col].insert(src.variable(graphId).label(i));
242 }
catch (
const gum::NotFound&) {