aGrUM  0.21.0
a C++ library for (probabilistic) graphical models
BNDatabaseGenerator_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
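23  * @brief Template implementation of BNDatabaseGenerator: draws samples from a Bayesian network and exports them (CSV file or DatabaseTable).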
24  *
25  * @author Santiago CORTIJO
26  */
27 
28 #include <agrum/BN/database/BNDatabaseGenerator.h>
29 
30 #include <agrum/tools/core/timer.h>
31 
32 
33 namespace gum {
34  namespace learning {
35 
36 
37  /// default constructor
38  template < typename GUM_SCALAR >
39  BNDatabaseGenerator< GUM_SCALAR >::BNDatabaseGenerator(const BayesNet< GUM_SCALAR >& bn) :
40  _bn_(bn) {
41  GUM_CONSTRUCTOR(BNDatabaseGenerator);
42 
43  // get the node names => they will serve as ids
44  NodeId id = 0;
45  for (const auto& var: _bn_.dag()) {
46  auto name = _bn_.variable(var).name();
47  _names2ids_.insert(name, var);
48  ++id;
49  }
50  _nbVars_ = id;
51  _varOrder_.resize(_nbVars_);
52  std::iota(_varOrder_.begin(), _varOrder_.end(), (Idx)0);
53  }
54 
55  /// destructor
56  template < typename GUM_SCALAR >
59  }
60 
61 
62  /// draw instances from _bn_
63  template < typename GUM_SCALAR >
65  Timer timer;
66  int progress = 0;
67 
68  timer.reset();
69 
71  _database_.clear();
73  for (auto& row: _database_) {
75  }
76  // get the order in which the nodes will be sampled
79 
80  // create instantiations in advance
81  for (Idx node = 0; node < _nbVars_; ++node)
83 
84  // create the random generator
86  std::mt19937 gen(rd());
88 
89  // perform the sampling
90  _log2likelihood_ = 0;
91  const gum::DAG& dag = _bn_.dag();
92  for (Idx i = 0; i < nbSamples; ++i) {
93  if (onProgress.hasListener()) {
94  int p = int((i * 100) / nbSamples);
95  if (p != progress) {
96  progress = p;
98  }
99  }
100  std::vector< Idx >& sample = _database_.at(i);
101  for (Idx j = 0; j < _nbVars_; ++j) {
102  const gum::NodeId node = topOrder[j];
103  const auto& var = _bn_.variable(node);
104  const auto& cpt = _bn_.cpt(node);
105 
107  for (auto par: dag.parents(node))
109 
110  const double nb = distro(gen);
111  double cumul = 0.0;
112  for (inst.chgVal(var, 0); !inst.end(); inst.incVar(var)) {
113  cumul += cpt[inst];
114  if (cumul >= nb) break;
115  }
116 
117  if (inst.end()) inst.chgVal(var, var.domainSize() - 1);
118  sample.at(node) = inst.val(var);
119 
121  }
122  }
123 
124  _drawnSamples_ = true;
125 
126  if (onProgress.hasListener()) {
128  ss << "Database of size " << nbSamples << " generated in " << timer.step()
129  << " seconds. Log2likelihood : " << _log2likelihood_;
130  GUM_EMIT1(onStop, ss.str());
131  }
132 
133  return _log2likelihood_;
134  }
135 
136  /// writes the previously drawn samples to a csv file (drawSamples() must have been called)
137  template < typename GUM_SCALAR >
139  bool useLabels,
140  bool append,
142  bool checkOnAppend) const {
143  if (!_drawnSamples_) { GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.") }
144 
145  if (csvSeparator.find("\n") != std::string::npos) {
146  GUM_ERROR(InvalidArgument, "csvSeparator must not contain end-line characters")
147  }
148 
149  bool includeHeader = true;
150  if (append) {
152  if (csvFile) {
154  if (checkOnAppend && varOrder != _varOrder_) {
156  "Inconsistent variable order in csvFile when appending. You "
157  "can use setVarOrderFromCSV(url) function to get the right "
158  "order. You could also set parameter checkOnAppend=false if you "
159  "know what you are doing.");
160  }
161  includeHeader = false;
162  }
163  csvFile.close();
164  }
165 
166 
168 
170  bool firstCol = true;
171  if (includeHeader) {
172  for (const auto& i: _varOrder_) {
173  if (firstCol) {
174  firstCol = false;
175  } else {
176  os << csvSeparator;
177  }
178  os << _bn_.variable(i).name();
179  }
180  }
181  os << std::endl;
182 
183  bool firstRow = true;
184  for (const auto& row: _database_) {
185  if (firstRow) {
186  firstRow = false;
187  } else {
188  os << std::endl;
189  }
190  firstCol = true;
191  for (const auto& i: _varOrder_) {
192  if (firstCol) {
193  firstCol = false;
194  } else {
195  os << csvSeparator;
196  }
197  if (useLabels) {
198  os << _bn_.variable(i).label(row.at(i));
199  } else {
200  os << row[i];
201  }
202  }
203  }
204 
205  os.close();
206  }
207 
208  /// generates a DatabaseTable from the previously drawn samples
209  template < typename GUM_SCALAR >
211  if (!_drawnSamples_) GUM_ERROR(OperationNotAllowed, "proceed() must be called first.")
212 
213  DatabaseTable<> db;
216  for (const auto& i: _varOrder_) {
218  }
219 
220  // create the translators
221  for (std::size_t i = 0; i < _nbVars_; ++i) {
222  const Variable& var = _bn_.variable(_varOrder_[i]);
224  }
225 
226 
227  // db.setVariableNames(varNames);
228  // db.setVariableNames(varOrderNames());
229 
230  if (useLabels) {
232  for (const auto& row: _database_) {
233  for (Idx i = 0; i < _nbVars_; ++i) {
234  Idx j = _varOrder_.at(i);
235  xrow[i] = _bn_.variable(j).label(row.at(j));
236  }
237  db.insertRow(xrow);
238  }
239  } else {
241  for (std::size_t i = 0; i < _nbVars_; ++i) {
243  }
245  const auto xmiss = gum::learning::DatabaseTable<>::IsMissing::False;
246  for (const auto& row: _database_) {
247  for (Idx i = 0; i < _nbVars_; ++i) {
248  Idx j = _varOrder_.at(i);
249 
251  xrow[i].discr_val = std::size_t(row.at(j));
252  else
253  xrow[i].cont_val = float(row.at(j));
254  }
255  }
257  }
258 
259  return db;
260  }
261 
262 
263  /// returns database using specified data order
264  template < typename GUM_SCALAR >
266  if (!_drawnSamples_) GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.")
267 
268  auto db(_database_);
269  for (Idx i = 0; i < _database_.size(); ++i) {
270  for (Idx j = 0; j < _nbVars_; ++j) {
271  db.at(i).at(j) = (Idx)_database_.at(i).at(_varOrder_.at(j));
272  }
273  }
274  return db;
275  }
276 
277  /// change columns order
278  template < typename GUM_SCALAR >
280  if (varOrder.size() != _nbVars_)
281  GUM_ERROR(FatalError, "varOrder's size must be equal to the number of variables")
282 
283  std::vector< bool > usedVars(_nbVars_, false);
284  for (const auto& i: varOrder) {
285  if (i >= _nbVars_) GUM_ERROR(FatalError, "varOrder contains invalid variables")
286  if (usedVars.at(i)) GUM_ERROR(FatalError, "varOrder must not have repeated variables")
287  usedVars.at(i) = true;
288  }
289 
290  if (std::find(usedVars.begin(), usedVars.end(), false) != usedVars.end()) {
291  GUM_ERROR(FatalError, "varOrder must contain all variables")
292  }
293 
295  }
296 
297  /// change columns order using variable names
298  template < typename GUM_SCALAR >
299  void
303  for (const auto& vname: varOrder) {
305  }
307  }
308 
309  /// change columns order according to a csv file
310  template < typename GUM_SCALAR >
312  const std::string& csvSeparator) {
314  }
315 
316  /// set columns in topological order
317  template < typename GUM_SCALAR >
319  std::vector< Idx > varOrder;
321  for (const auto& v: _bn_.topologicalOrder()) {
323  }
325  }
326 
327  /// set columns in anti-topological order
328  template < typename GUM_SCALAR >
330  std::vector< Idx > varOrder;
332  for (const auto& v: _bn_.topologicalOrder()) {
334  }
337  }
338 
339  /// set columns in random order
340  template < typename GUM_SCALAR >
344  for (const auto& var: _bn_.dag()) {
346  }
348  std::mt19937 g(rd());
351  }
352 
353 
354  /// returns variable order indexes
355  template < typename GUM_SCALAR >
357  return _varOrder_;
358  }
359 
360  /// returns variable order.
361  template < typename GUM_SCALAR >
365  for (const auto& i: _varOrder_) {
367  }
368 
369  return varNames;
370  }
371 
372  /// returns log2Likelihood of generated samples
373  template < typename GUM_SCALAR >
375  if (!_drawnSamples_) { GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.") }
376  return _log2likelihood_;
377  }
378 
379  /// returns varOrder from a csv file
380  template < typename GUM_SCALAR >
381  std::vector< Idx >
383  const std::string& csvSeparator) const {
385  std::vector< Idx > varOrder;
386  if (csvFile) {
388  csvFile.close();
389  } else {
390  GUM_ERROR(NotFound, "csvFileURL does not exist")
391  }
392 
393  return varOrder;
394  }
395 
396  /// returns varOrder read from the header (first line) of an already opened csv stream
397  template < typename GUM_SCALAR >
398  std::vector< Idx >
400  const std::string& csvSeparator) const {
401  std::string line;
404  while (std::getline(csvFile, line)) {
405  std::size_t i = 0;
406  auto pos = line.find(csvSeparator);
407  while (pos != std::string::npos) {
409  pos += csvSeparator.length();
410  i = pos;
412 
414  }
415  break;
416  }
417 
418  std::vector< Size > varOrder;
420 
421  for (const auto& hf: header_found) {
423  }
424 
425  return varOrder;
426  }
427  } /* namespace learning */
428 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)