aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
BNDatabaseGenerator_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief Template implementation of BNDatabaseGenerator, which draws sample databases from a BayesNet.
24  *
25  * @author Santiago CORTIJO
26  */
27 
28 #include <agrum/BN/database/BNDatabaseGenerator.h>
29 
30 #include <agrum/tools/core/timer.h>
31 
32 
33 namespace gum {
34  namespace learning {
35 
36 
37  /// constructor from the Bayesian network to sample from
38  template < typename GUM_SCALAR >
39  BNDatabaseGenerator< GUM_SCALAR >::BNDatabaseGenerator(
40  const BayesNet< GUM_SCALAR >& bn) :
41  bn__(bn) {
42  // for debugging purposes
43  GUM_CONSTRUCTOR(BNDatabaseGenerator);
44 
45  // get the node names => they will serve as ids
46  NodeId id = 0;
47  for (const auto& var: bn__.dag()) {
48  auto name = bn__.variable(var).name();
49  names2ids__.insert(name, var);
50  ++id;
51  }
52  nbVars__ = id;
53  varOrder__.resize(nbVars__);
54  std::iota(varOrder__.begin(), varOrder__.end(), (Idx)0);
55  }
56 
57  /// destructor
58  template < typename GUM_SCALAR >
61  }
62 
63 
64  /// draw instances from bn__
65  template < typename GUM_SCALAR >
67  Timer timer;
68  int progress = 0;
69 
70  timer.reset();
71 
72  if (onProgress.hasListener()) {
74  }
75  database__.clear();
77  for (auto& row: database__) {
79  }
80  // get the order in which the nodes will be sampled
83 
84  // create instantiations in advance
85  for (Idx node = 0; node < nbVars__; ++node)
87 
88  // create the random generator
90  std::mt19937 gen(rd());
92 
93  // perform the sampling
94  log2likelihood__ = 0;
95  const gum::DAG& dag = bn__.dag();
96  for (Idx i = 0; i < nbSamples; ++i) {
97  if (onProgress.hasListener()) {
98  int p = int((i * 100) / nbSamples);
99  if (p != progress) {
100  progress = p;
102  }
103  }
104  std::vector< Idx >& sample = database__.at(i);
105  for (Idx j = 0; j < nbVars__; ++j) {
106  const gum::NodeId node = topOrder[j];
107  const auto& var = bn__.variable(node);
108  const auto& cpt = bn__.cpt(node);
109 
111  for (auto par: dag.parents(node))
113 
114  const double nb = distro(gen);
115  double cumul = 0.0;
116  for (inst.chgVal(var, 0); !inst.end(); inst.incVar(var)) {
117  cumul += cpt[inst];
118  if (cumul >= nb) break;
119  }
120 
121  if (inst.end()) inst.chgVal(var, var.domainSize() - 1);
122  sample.at(node) = inst.val(var);
123 
125  }
126  }
127 
128  drawnSamples__ = true;
129 
130  if (onProgress.hasListener()) {
132  ss << "Database of size " << nbSamples << " generated in " << timer.step()
133  << " seconds. Log2likelihood : " << log2likelihood__;
134  GUM_EMIT1(onStop, ss.str());
135  }
136 
137  return log2likelihood__;
138  }
139 
140  /// writes the previously drawn database into a csv file
141  template < typename GUM_SCALAR >
143  bool useLabels,
144  bool append,
146  bool checkOnAppend) const {
147  if (!drawnSamples__) {
148  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
149  }
150 
151  if (csvSeparator.find("\n") != std::string::npos) {
153  "csvSeparator must not contain end-line characters");
154  }
155 
156  bool includeHeader = true;
157  if (append) {
159  if (csvFile) {
161  if (checkOnAppend && varOrder != varOrder__) {
162  GUM_ERROR(
164  "Inconsistent variable order in csvFile when appending. You "
165  "can use setVarOrderFromCSV(url) function to get the right "
166  "order. You could also set parameter checkOnAppend=false if you "
167  "know what you are doing.");
168  }
169  includeHeader = false;
170  }
171  csvFile.close();
172  }
173 
174 
176 
178  bool firstCol = true;
179  if (includeHeader) {
180  for (const auto& i: varOrder__) {
181  if (firstCol) {
182  firstCol = false;
183  } else {
184  os << csvSeparator;
185  }
186  os << bn__.variable(i).name();
187  }
188  }
189  os << std::endl;
190 
191  bool firstRow = true;
192  for (const auto& row: database__) {
193  if (firstRow) {
194  firstRow = false;
195  } else {
196  os << std::endl;
197  }
198  firstCol = true;
199  for (const auto& i: varOrder__) {
200  if (firstCol) {
201  firstCol = false;
202  } else {
203  os << csvSeparator;
204  }
205  if (useLabels) {
206  os << bn__.variable(i).label(row.at(i));
207  } else {
208  os << row[i];
209  }
210  }
211  }
212 
213  os.close();
214  }
215 
216  /// returns the drawn samples as a DatabaseTable
217  template < typename GUM_SCALAR >
218  DatabaseTable<>
220  if (!drawnSamples__)
221  GUM_ERROR(OperationNotAllowed, "proceed() must be called first.");
222 
223  DatabaseTable<> db;
226  for (const auto& i: varOrder__) {
228  }
229 
230  // create the translators
231  for (std::size_t i = 0; i < nbVars__; ++i) {
232  const Variable& var = bn__.variable(varOrder__[i]);
234  }
235 
236 
237  // db.setVariableNames(varNames);
238  // db.setVariableNames(varOrderNames());
239 
240  if (useLabels) {
242  for (const auto& row: database__) {
243  for (Idx i = 0; i < nbVars__; ++i) {
244  Idx j = varOrder__.at(i);
245  xrow[i] = bn__.variable(j).label(row.at(j));
246  }
247  db.insertRow(xrow);
248  }
249  } else {
251  for (std::size_t i = 0; i < nbVars__; ++i) {
253  }
255  const auto xmiss = gum::learning::DatabaseTable<>::IsMissing::False;
256  for (const auto& row: database__) {
257  for (Idx i = 0; i < nbVars__; ++i) {
258  Idx j = varOrder__.at(i);
259 
261  xrow[i].discr_val = std::size_t(row.at(j));
262  else
263  xrow[i].cont_val = float(row.at(j));
264  }
265  }
267  }
268 
269  return db;
270  }
271 
272 
273  /// returns database using specified data order
274  template < typename GUM_SCALAR >
275  std::vector< std::vector< Idx > >
277  if (!drawnSamples__)
278  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
279 
280  auto db(database__);
281  for (Idx i = 0; i < database__.size(); ++i) {
282  for (Idx j = 0; j < nbVars__; ++j) {
283  db.at(i).at(j) = (Idx)database__.at(i).at(varOrder__.at(j));
284  }
285  }
286  return db;
287  }
288 
289  /// change columns order
290  template < typename GUM_SCALAR >
292  const std::vector< Idx >& varOrder) {
293  if (varOrder.size() != nbVars__) {
295  "varOrder's size must be equal to the number of variables");
296  }
297  std::vector< bool > usedVars(nbVars__, false);
298  for (const auto& i: varOrder) {
299  if (i >= nbVars__) {
300  GUM_ERROR(FatalError, "varOrder contains invalid variables");
301  }
302  if (usedVars.at(i))
303  GUM_ERROR(FatalError, "varOrder must not have repeated variables");
304  usedVars.at(i) = true;
305  }
306 
307  if (std::find(usedVars.begin(), usedVars.end(), false) != usedVars.end()) {
308  GUM_ERROR(FatalError, "varOrder must contain all variables");
309  }
310 
312  }
313 
314  /// change columns order using variable names
315  template < typename GUM_SCALAR >
317  const std::vector< std::string >& varOrder) {
320  for (const auto& vname: varOrder) {
322  }
324  }
325 
326  /// change columns order according to a csv file
327  template < typename GUM_SCALAR >
329  const std::string& csvFileURL,
330  const std::string& csvSeparator) {
332  }
333 
334  /// set columns in topological order
335  template < typename GUM_SCALAR >
337  std::vector< Idx > varOrder;
339  for (const auto& v: bn__.topologicalOrder()) {
341  }
343  }
344 
345  /// set columns in anti-topological order
346  template < typename GUM_SCALAR >
348  std::vector< Idx > varOrder;
350  for (const auto& v: bn__.topologicalOrder()) {
352  }
355  }
356 
357  /// set columns in random order
358  template < typename GUM_SCALAR >
362  for (const auto& var: bn__.dag()) {
364  }
366  std::mt19937 g(rd());
369  }
370 
371 
372  /// returns variable order indexes
373  template < typename GUM_SCALAR >
375  return varOrder__;
376  }
377 
378  /// returns variable order.
379  template < typename GUM_SCALAR >
380  std::vector< std::string >
384  for (const auto& i: varOrder__) {
386  }
387 
388  return varNames;
389  }
390 
391  /// returns log2Likelihood of generated samples
392  template < typename GUM_SCALAR >
394  if (!drawnSamples__) {
395  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
396  }
397  return log2likelihood__;
398  }
399 
400  /// returns varOrder from a csv file
401  template < typename GUM_SCALAR >
403  const std::string& csvFileURL,
404  const std::string& csvSeparator) const {
406  std::vector< Idx > varOrder;
407  if (csvFile) {
409  csvFile.close();
410  } else {
411  GUM_ERROR(NotFound, "csvFileURL does not exist");
412  }
413 
414  return varOrder;
415  }
416 
417  /// returns varOrder from a csv file
418  template < typename GUM_SCALAR >
420  std::ifstream& csvFile,
421  const std::string& csvSeparator) const {
422  std::string line;
425  while (std::getline(csvFile, line)) {
426  std::size_t i = 0;
427  auto pos = line.find(csvSeparator);
428  while (pos != std::string::npos) {
430  pos += csvSeparator.length();
431  i = pos;
433 
434  if (pos == std::string::npos)
436  }
437  break;
438  }
439 
440  std::vector< Size > varOrder;
442 
443  for (const auto& hf: header_found) {
445  }
446 
447  return varOrder;
448  }
449  } /* namespace learning */
450 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)