aGrUM  0.14.2
BNDatabaseGenerator_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}@lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it wil be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
27 
28 #include <agrum/core/timer.h>
29 
30 
31 namespace gum {
32  namespace learning {
33 
34 
// Constructor. Listing line 37 (presumably the qualified constructor name)
// is not reproduced on this page.
// Keeps a reference to the Bayesian network `bn` (it is not copied, so `bn`
// must outlive this generator), records a name -> node id bijection for every
// node of its DAG, and initializes the column order to 0, 1, ..., __nbVars-1.
36  template < typename GUM_SCALAR >
38  const BayesNet< GUM_SCALAR >& bn) :
39  __bn(bn) {
40  // for debugging purposes
41  GUM_CONSTRUCTOR(BNDatabaseGenerator);
42 
43  // get the node names => they will serve as ids
     // (`id` only counts the nodes; it ends up equal to the number of nodes)
44  NodeId id = 0;
45  for (const auto& var : __bn.dag()) {
46  auto name = __bn.variable(var).name();
47  __names2ids.insert(name, var);
48  ++id;
49  }
50  __nbVars = id;
     // NOTE(review): the identity order 0..__nbVars-1 matches the node ids
     // only if the network's ids are contiguous from 0 -- confirm.
     // drawSamples() and setVarOrder() rely on the same assumption.
51  __varOrder.resize(__nbVars);
52  std::iota(__varOrder.begin(), __varOrder.end(), (Idx)0);
53  }
54 
// Destructor. Listing line 57 (presumably the qualified destructor name) is
// not reproduced on this page.
// Only the GUM_DESTRUCTOR debugging hook runs. The referenced network is not
// owned by the generator and is left untouched.
56  template < typename GUM_SCALAR >
58  GUM_DESTRUCTOR(BNDatabaseGenerator);
59  }
60 
61 
63  template < typename GUM_SCALAR >
65  Timer timer;
66  int progress = 0;
67 
68  timer.reset();
69 
70  if (onProgress.hasListener()) {
71  GUM_EMIT2(onProgress, progress, timer.step());
72  }
73  __database.clear();
74  __database.resize(nbSamples);
75  for (auto& row : __database) {
76  row.resize(__nbVars);
77  }
78  // get the order in which the nodes will be sampled
79  const gum::Sequence< gum::NodeId >& topOrder = __bn.topologicalOrder();
80  std::vector< gum::Instantiation > instantiations;
81 
82  // create instantiations in advance
83  for (Idx node = 0; node < __nbVars; ++node)
84  instantiations.push_back(gum::Instantiation(__bn.cpt(node)));
85 
86  // create the random generator
87  std::random_device rd;
88  std::mt19937 gen(rd());
89  std::uniform_real_distribution<> distro(0.0, 1.0);
90 
91  // perform the sampling
92  __log2likelihood = 0;
93  const gum::DAG& dag = __bn.dag();
94  for (Idx i = 0; i < nbSamples; ++i) {
95  if (onProgress.hasListener()) {
96  int p = int((i * 100) / nbSamples);
97  if (p != progress) {
98  progress = p;
99  GUM_EMIT2(onProgress, progress, timer.step());
100  }
101  }
102  std::vector< Idx >& sample = __database.at(i);
103  for (Idx j = 0; j < __nbVars; ++j) {
104  const gum::NodeId node = topOrder[j];
105  const auto& var = __bn.variable(node);
106  const auto& cpt = __bn.cpt(node);
107 
108  gum::Instantiation& inst = instantiations[node];
109  for (auto par : dag.parents(node))
110  inst.chgVal(__bn.variable(par), sample.at(par));
111 
112  const double nb = distro(gen);
113  double cumul = 0.0;
114  for (inst.chgVal(var, 0); !inst.end(); inst.incVar(var)) {
115  cumul += cpt[inst];
116  if (cumul >= nb) break;
117  }
118 
119  if (inst.end()) inst.chgVal(var, var.domainSize() - 1);
120  sample.at(node) = inst.val(var);
121 
122  __log2likelihood += std::log2(__bn.cpt(node)[inst]);
123  }
124  }
125 
126  __drawnSamples = true;
127 
128  if (onProgress.hasListener()) {
129  std::stringstream ss;
130  ss << "Database of size " << nbSamples << " generated in " << timer.step()
131  << " seconds. Log2likelihood : " << __log2likelihood;
132  GUM_EMIT1(onStop, ss.str());
133  }
134 
135  return __log2likelihood;
136  }
137 
139  template < typename GUM_SCALAR >
140  void BNDatabaseGenerator< GUM_SCALAR >::toCSV(const std::string& csvFileURL,
141  bool useLabels,
142  bool append,
143  std::string csvSeparator,
144  bool checkOnAppend) const {
145  if (!__drawnSamples) {
146  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
147  }
148 
149  if (csvSeparator.find("\n") != std::string::npos) {
151  "csvSeparator must not contain end-line characters");
152  }
153 
154  bool includeHeader = true;
155  if (append) {
156  std::ifstream csvFile(csvFileURL);
157  if (csvFile) {
158  auto varOrder = __varOrderFromCSV(csvFile, csvSeparator);
159  if (checkOnAppend && varOrder != __varOrder) {
160  GUM_ERROR(
162  "Inconsistent variable order in csvFile when appending. You "
163  "can use setVarOrderFromCSV(url) function to get the right "
164  "order. You could also set parameter checkOnAppend=false if you "
165  "know what you are doing.");
166  }
167  includeHeader = false;
168  }
169  csvFile.close();
170  }
171 
172 
173  auto ofstreamFlag = append ? std::ofstream::app : std::ofstream::out;
174 
175  std::ofstream os(csvFileURL, ofstreamFlag);
176  bool firstCol = true;
177  if (includeHeader) {
178  for (const auto& i : __varOrder) {
179  if (firstCol) {
180  firstCol = false;
181  } else {
182  os << csvSeparator;
183  }
184  os << __bn.variable(i).name();
185  }
186  }
187  os << std::endl;
188 
189  bool firstRow = true;
190  for (const auto& row : __database) {
191  if (firstRow) {
192  firstRow = false;
193  } else {
194  os << std::endl;
195  }
196  firstCol = true;
197  for (const auto& i : __varOrder) {
198  if (firstCol) {
199  firstCol = false;
200  } else {
201  os << csvSeparator;
202  }
203  if (useLabels) {
204  os << __bn.variable(i).label(row.at(i));
205  } else {
206  os << row[i];
207  }
208  }
209  }
210 
211  os.close();
212  }
213 
// toDatabaseTable(bool useLabels). Listing lines 216-217 (presumably the
// qualified signature) and 253 (presumably the declaration of `xmiss`, used
// below) are not reproduced on this page.
// Builds a DatabaseTable holding every generated sample, with columns in
// __varOrder order and one translator per variable.
// - useLabels == true: rows are inserted as strings, i.e. the variables'
//   labels.
// - useLabels == false: rows are inserted as already-translated values: the
//   value index for DISCRETE translators, otherwise the index cast to float.
// Raises OperationNotAllowed if drawSamples() has not been called.
215  template < typename GUM_SCALAR >
218  if (!__drawnSamples)
219  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
220 
221  DatabaseTable<> db;
     // NOTE(review): varNames is only referenced by the commented-out
     // setVariableNames calls below.
222  std::vector< std::string > varNames;
223  varNames.reserve(__nbVars);
224  for (const auto& i : __varOrder) {
225  varNames.push_back(__names2ids.first(i));
226  }
227 
228  // create the translators
229  for (std::size_t i = 0; i < __nbVars; ++i) {
230  const Variable& var = __bn.variable(__varOrder[i]);
231  db.insertTranslator(var, i);
232  }
233 
234 
235  // db.setVariableNames(varNames);
236  // db.setVariableNames(varOrderNames());
237 
238  if (useLabels) {
239  std::vector< std::string > xrow(__nbVars);
240  for (const auto& row : __database) {
241  for (Idx i = 0; i < __nbVars; ++i) {
242  Idx j = __varOrder.at(i);
243  xrow[i] = __bn.variable(j).label(row.at(j));
244  }
245  db.insertRow(xrow);
246  }
247  } else {
248  std::vector< DBTranslatedValueType > translatorType(__nbVars);
249  for (std::size_t i = 0; i < __nbVars; ++i) {
250  translatorType[i] = db.translator(i).getValType();
251  }
     // xrow is reused: each sample overwrites every cell before insertion
252  DBRow< DBTranslatedValue > xrow(__nbVars);
254  for (const auto& row : __database) {
255  for (Idx i = 0; i < __nbVars; ++i) {
256  Idx j = __varOrder.at(i);
257 
258  if (translatorType[i] == DBTranslatedValueType::DISCRETE)
259  xrow[i].discr_val = std::size_t(row.at(j));
260  else
261  xrow[i].cont_val = float(row.at(j));
262  }
     // Insert once per sample. This call used to sit after the loop, which
     // kept only the last sample.
264  db.insertRow(xrow, xmiss);
263  }
265  }
266 
267  return db;
268  }
269 
270 
// database(). Listing line 274 (presumably the qualified name) is not
// reproduced on this page.
// Returns a copy of the generated samples with columns permuted into
// __varOrder order: cell j of each returned row holds the value of node
// __varOrder[j].
// Raises OperationNotAllowed if drawSamples() has not been called.
272  template < typename GUM_SCALAR >
273  std::vector< std::vector< Idx > >
275  if (!__drawnSamples)
276  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
277 
     // copy first so every row already has __nbVars cells, then overwrite each
     // cell from the original (node-id indexed) database
278  auto db(__database);
279  for (Idx i = 0; i < __database.size(); ++i) {
280  for (Idx j = 0; j < __nbVars; ++j) {
281  db.at(i).at(j) = (Idx)__database.at(i).at(__varOrder.at(j));
282  }
283  }
284  return db;
285  }
286 
// setVarOrder(const std::vector< Idx >& varOrder). Listing lines 289
// (presumably the qualified name), 292 (the opening of the size-check
// GUM_ERROR) and 309 are not reproduced on this page.
// Validates a new column order: exactly __nbVars entries, each in
// [0, __nbVars), with no repetition. Every violation raises FatalError
// (except possibly the size check, whose type is on the missing line 292).
// NOTE(review): no assignment to __varOrder is visible in this listing. It is
// presumably on the omitted line 309 -- confirm against the source file.
288  template < typename GUM_SCALAR >
290  const std::vector< Idx >& varOrder) {
291  if (varOrder.size() != __nbVars) {
293  "varOrder's size must be equal to the number of variables");
294  }
295  std::vector< bool > usedVars(__nbVars, false);
296  for (const auto& i : varOrder) {
297  if (i >= __nbVars) {
298  GUM_ERROR(FatalError, "varOrder contains invalid variables");
299  }
300  if (usedVars.at(i))
301  GUM_ERROR(FatalError, "varOrder must not have repeated variables");
302  usedVars.at(i) = true;
303  }
304 
     // Cannot fire after the size, range and duplicate checks above (__nbVars
     // distinct ids in [0, __nbVars) cover every id); kept as a safeguard.
305  if (std::find(usedVars.begin(), usedVars.end(), false) != usedVars.end()) {
306  GUM_ERROR(FatalError, "varOrder must contain all variables");
307  }
308 
310  }
311 
// setVarOrder(const std::vector< std::string >& varOrder). Listing line 314
// (presumably the qualified name) is not reproduced on this page.
// Translates each variable name into its node id through __names2ids, then
// delegates all validation to the id-based setVarOrder().
// NOTE(review): what happens for a name absent from __names2ids depends on
// Bijection::second(), which is not shown here (presumably it throws) --
// confirm.
313  template < typename GUM_SCALAR >
315  const std::vector< std::string >& varOrder) {
316  std::vector< Idx > varOrderIdx;
317  varOrderIdx.reserve(varOrder.size());
318  for (const auto& vname : varOrder) {
319  varOrderIdx.push_back(__names2ids.second(vname));
320  }
321  setVarOrder(varOrderIdx);
322  }
323 
// setVarOrderFromCSV(csvFileURL, csvSeparator). Listing line 326 (presumably
// the qualified name) is not reproduced on this page.
// Takes the column order from the header of the CSV file at csvFileURL,
// split on csvSeparator, and applies it through the id-based setVarOrder().
// That call rejects headers that do not list every variable exactly once.
325  template < typename GUM_SCALAR >
327  const std::string& csvFileURL, const std::string& csvSeparator) {
328  setVarOrder(__varOrderFromCSV(csvFileURL, csvSeparator));
329  }
330 
// setTopologicalVarOrder(). Listing line 333 (presumably the qualified name)
// is not reproduced on this page.
// Orders the columns following __bn.topologicalOrder(), so every node's
// column comes after the columns of its parents.
332  template < typename GUM_SCALAR >
334  std::vector< Idx > varOrder;
335  varOrder.reserve(__nbVars);
336  for (const auto& v : __bn.topologicalOrder()) {
337  varOrder.push_back(v);
338  }
339  setVarOrder(varOrder);
340  }
341 
// setAntiTopologicalVarOrder(). Listing line 344 (presumably the qualified
// name) is not reproduced on this page.
// Orders the columns following the reverse of __bn.topologicalOrder(), so
// every node's column comes before the columns of its parents.
343  template < typename GUM_SCALAR >
345  std::vector< Idx > varOrder;
346  varOrder.reserve(__nbVars);
347  for (const auto& v : __bn.topologicalOrder()) {
348  varOrder.push_back(v);
349  }
350  std::reverse(varOrder.begin(), varOrder.end());
351  setVarOrder(varOrder);
352  }
353 
// setRandomVarOrder(). Listing line 356 (presumably the qualified name) is
// not reproduced on this page.
// Shuffles the variable names and applies the result through the name-based
// setVarOrder(). The shuffle uses a std::mt19937 seeded from
// std::random_device, and no seed can be supplied, so the order is not
// reproducible.
355  template < typename GUM_SCALAR >
357  std::vector< std::string > varOrder;
358  varOrder.reserve(__bn.size());
359  for (const auto& var : __bn.dag()) {
360  varOrder.push_back(__bn.variable(var).name());
361  }
362  std::random_device rd;
363  std::mt19937 g(rd());
364  std::shuffle(varOrder.begin(), varOrder.end(), g);
365  setVarOrder(varOrder);
366  }
367 
368 
// varOrder(): the current column order, expressed as node ids.
// The caller receives its own vector; changing it does not affect the
// generator (use setVarOrder() for that).
370  template < typename GUM_SCALAR >
371  std::vector< Idx > BNDatabaseGenerator< GUM_SCALAR >::varOrder() const {
372  std::vector< Idx > orderCopy(__varOrder.cbegin(), __varOrder.cend());
     return orderCopy;
373  }
374 
// varOrderNames(). Listing line 378 (presumably the qualified name) is not
// reproduced on this page.
// Returns the names of the variables in current column (__varOrder) order,
// looked up through the name <-> id bijection __names2ids.
376  template < typename GUM_SCALAR >
377  std::vector< std::string >
379  std::vector< std::string > varNames;
380  varNames.reserve(__nbVars);
381  for (const auto& i : __varOrder) {
382  varNames.push_back(__names2ids.first(i));
383  }
384 
385  return varNames;
386  }
387 
// log2likelihood(). Listing line 390 (presumably the qualified name) is not
// reproduced on this page.
// Returns the log2-likelihood computed by the last drawSamples() call.
// Raises OperationNotAllowed if drawSamples() has not been called.
389  template < typename GUM_SCALAR >
391  if (!__drawnSamples) {
392  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
393  }
394  return __log2likelihood;
395  }
396 
// __varOrderFromCSV(csvFileURL, csvSeparator). Listing line 399 (presumably
// the return type and qualified name) is not reproduced on this page.
// Opens csvFileURL and returns the node ids named in its header, parsed by
// the ifstream overload.
// Raises NotFound whenever the file cannot be opened for reading. That
// includes a file that exists but is unreadable, despite the message.
398  template < typename GUM_SCALAR >
400  const std::string& csvFileURL, const std::string& csvSeparator) const {
401  std::ifstream csvFile(csvFileURL);
402  std::vector< Idx > varOrder;
403  if (csvFile) {
404  varOrder = __varOrderFromCSV(csvFile, csvSeparator);
405  csvFile.close();
406  } else {
407  GUM_ERROR(NotFound, "csvFileURL does not exist");
408  }
409 
410  return varOrder;
411  }
412 
// __varOrderFromCSV(std::ifstream& csvFile, const std::string& csvSeparator).
// Listing line 415 (presumably the return type and qualified name) is not
// reproduced on this page.
// Reads the first line of csvFile (the header), splits it on csvSeparator,
// and maps each column name to its node id through __names2ids.
// - Returns an empty vector if the stream has no line or the header is empty.
// - A header containing no separator is a single column. Previously it
//   yielded an empty order, which broke single-variable networks.
// - A trailing '\r' (CRLF files) is ignored.
// - An empty csvSeparator makes the whole header one column name. Splitting
//   on "" would never advance and used to loop forever.
// NOTE(review): an unknown column name makes __names2ids.second() fail,
// presumably by throwing -- confirm.
414  template < typename GUM_SCALAR >
416  std::ifstream& csvFile, const std::string& csvSeparator) const {
417  std::string line;
     std::vector< Idx > varOrder;
     varOrder.reserve(__nbVars);

     // only the header (first line) is inspected
     if (!std::getline(csvFile, line)) return varOrder;
     if (!line.empty() && line.back() == '\r') line.pop_back();
     if (line.empty()) return varOrder;

     if (csvSeparator.empty()) {
       varOrder.push_back(__names2ids.second(line));
       return varOrder;
     }

     std::size_t start = 0;
     while (true) {
       const std::size_t pos = line.find(csvSeparator, start);
       if (pos == std::string::npos) {
         // last (or only) column
         varOrder.push_back(__names2ids.second(line.substr(start)));
         break;
       }
       varOrder.push_back(__names2ids.second(line.substr(start, pos - start)));
       start = pos + csvSeparator.length();
     }

442  return varOrder;
443  }
444  } /* namespace learning */
445 } /* namespace gum */
void insert(const T1 &first, const T2 &second)
Inserts a new association in the gum::Bijection.
Class representing a Bayesian Network.
Definition: BayesNet.h:76
const T2 & second(const T1 &first) const
Returns the second value of a pair given its first value.
Base class for every random variable.
Definition: variable.h:63
Signaler2< Size, double > onProgress
Progression (percent) and time.
double __log2likelihood
log2Likelihood of generated samples
double step() const
Returns the delta time between now and the last reset() call (or the constructor).
Definition: timer_inl.h:39
const T1 & first(const T2 &second) const
Returns the first value of a pair given its second value.
#define GUM_EMIT1(signal, arg1)
Definition: signaler1.h:40
std::vector< std::string > varOrderNames() const
returns variable order.
double drawSamples(Size nbSamples)
generates and stores the database and returns its log2-likelihood, using ProgressNotifier for notifications ...
The generic class for storing (ordered) sequences of objects.
Definition: sequence.h:1019
virtual void insertRow(const std::vector< std::string, ALLOC< std::string > > &new_row) final
insert a new row at the end of the database
Instantiation & chgVal(const DiscreteVariable &v, Idx newval)
Assign newval to variable v in the Instantiation.
Bijection< std::string, NodeId > __names2ids
bijection nodes names
Class used to compute response times for benchmark purposes.
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
void incVar(const DiscreteVariable &v)
Operator increment for variable v only.
DatabaseTable toDatabaseTable(bool useLabels=true) const
generates a DatabaseTable
void setTopologicalVarOrder()
set columns in topological order
Idx val(Idx i) const
Returns the current value of the variable at position i.
void reset()
Reset the timer.
Definition: timer_inl.h:29
const BayesNet< GUM_SCALAR > & __bn
Bayesian network.
#define GUM_EMIT2(signal, arg1, arg2)
Definition: signaler2.h:40
const DBTranslator< ALLOC > & translator(const std::size_t k, const bool k_is_input_col=false) const
returns either the kth translator of the database table or the first one reading the kth column of th...
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
std::vector< Idx > __varOrderFromCSV(const std::string &csvFileURL, const std::string &csvSeparator=",") const
returns varOrder from a csv file
void toCSV(const std::string &csvFileURL, bool useLabels=true, bool append=false, std::string csvSeparator=",", bool checkOnAppend=false) const
generates csv database according to bn
double log2likelihood() const
returns log2Likelihood of generated samples
std::vector< std::vector< Idx > > database() const
generates database according to bn into a std::vector
The class for storing a record in a database.
Definition: DBRow.h:53
std::vector< std::vector< Idx > > __database
generated database
The class representing a tabular database as used by learning tasks.
void setVarOrder(const std::vector< Idx > &varOrder)
change columns order
std::size_t insertTranslator(const DBTranslator< ALLOC > &translator, const std::size_t input_column, const bool unique_column=true)
insert a new translator into the database table
Class for assigning/browsing values to tuples of discrete variables.
Definition: instantiation.h:80
void setRandomVarOrder()
set columns in random order
void setAntiTopologicalVarOrder()
set columns in anti-topological order
std::vector< Idx > varOrder() const
returns variable order indexes
Class used to compute response times for benchmark purposes. This class represents a classic timer...
Definition: timer.h:48
Size Idx
Type for indexes.
Definition: types.h:50
Signaler1< const std::string &> onStop
with a possible explanation for stopping
bool __drawnSamples
whether drawSamples has been already called.
std::vector< Idx > __varOrder
variable order in generated database
void setVarOrderFromCSV(const std::string &csvFileURL, const std::string &csvSeparator=",")
change columns order according to a csv file
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:45
Base class for dag.
Definition: DAG.h:99
Size NodeId
Type for node ids.
Definition: graphElements.h:97
BNDatabaseGenerator(const BayesNet< GUM_SCALAR > &bn)
default constructor
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52
bool end() const
Returns true if the Instantiation reached the end.