aGrUM  0.16.0
BNDatabaseGenerator_tpl.h
Go to the documentation of this file.
1 
30 
31 #include <agrum/core/timer.h>
32 
33 
34 namespace gum {
35  namespace learning {
36 
37 
39  template < typename GUM_SCALAR >
41  const BayesNet< GUM_SCALAR >& bn) :
42  __bn(bn) {
43  // for debugging purposes
44  GUM_CONSTRUCTOR(BNDatabaseGenerator);
45 
46  // get the node names => they will serve as ids
47  NodeId id = 0;
48  for (const auto& var : __bn.dag()) {
49  auto name = __bn.variable(var).name();
50  __names2ids.insert(name, var);
51  ++id;
52  }
53  __nbVars = id;
54  __varOrder.resize(__nbVars);
55  std::iota(__varOrder.begin(), __varOrder.end(), (Idx)0);
56  }
57 
59  template < typename GUM_SCALAR >
61  GUM_DESTRUCTOR(BNDatabaseGenerator);
62  }
63 
64 
66  template < typename GUM_SCALAR >
68  Timer timer;
69  int progress = 0;
70 
71  timer.reset();
72 
73  if (onProgress.hasListener()) {
74  GUM_EMIT2(onProgress, progress, timer.step());
75  }
76  __database.clear();
77  __database.resize(nbSamples);
78  for (auto& row : __database) {
79  row.resize(__nbVars);
80  }
81  // get the order in which the nodes will be sampled
82  const gum::Sequence< gum::NodeId >& topOrder = __bn.topologicalOrder();
83  std::vector< gum::Instantiation > instantiations;
84 
85  // create instantiations in advance
86  for (Idx node = 0; node < __nbVars; ++node)
87  instantiations.push_back(gum::Instantiation(__bn.cpt(node)));
88 
89  // create the random generator
90  std::random_device rd;
91  std::mt19937 gen(rd());
92  std::uniform_real_distribution<> distro(0.0, 1.0);
93 
94  // perform the sampling
95  __log2likelihood = 0;
96  const gum::DAG& dag = __bn.dag();
97  for (Idx i = 0; i < nbSamples; ++i) {
98  if (onProgress.hasListener()) {
99  int p = int((i * 100) / nbSamples);
100  if (p != progress) {
101  progress = p;
102  GUM_EMIT2(onProgress, progress, timer.step());
103  }
104  }
105  std::vector< Idx >& sample = __database.at(i);
106  for (Idx j = 0; j < __nbVars; ++j) {
107  const gum::NodeId node = topOrder[j];
108  const auto& var = __bn.variable(node);
109  const auto& cpt = __bn.cpt(node);
110 
111  gum::Instantiation& inst = instantiations[node];
112  for (auto par : dag.parents(node))
113  inst.chgVal(__bn.variable(par), sample.at(par));
114 
115  const double nb = distro(gen);
116  double cumul = 0.0;
117  for (inst.chgVal(var, 0); !inst.end(); inst.incVar(var)) {
118  cumul += cpt[inst];
119  if (cumul >= nb) break;
120  }
121 
122  if (inst.end()) inst.chgVal(var, var.domainSize() - 1);
123  sample.at(node) = inst.val(var);
124 
125  __log2likelihood += std::log2(__bn.cpt(node)[inst]);
126  }
127  }
128 
129  __drawnSamples = true;
130 
131  if (onProgress.hasListener()) {
132  std::stringstream ss;
133  ss << "Database of size " << nbSamples << " generated in " << timer.step()
134  << " seconds. Log2likelihood : " << __log2likelihood;
135  GUM_EMIT1(onStop, ss.str());
136  }
137 
138  return __log2likelihood;
139  }
140 
142  template < typename GUM_SCALAR >
143  void BNDatabaseGenerator< GUM_SCALAR >::toCSV(const std::string& csvFileURL,
144  bool useLabels,
145  bool append,
146  std::string csvSeparator,
147  bool checkOnAppend) const {
148  if (!__drawnSamples) {
149  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
150  }
151 
152  if (csvSeparator.find("\n") != std::string::npos) {
154  "csvSeparator must not contain end-line characters");
155  }
156 
157  bool includeHeader = true;
158  if (append) {
159  std::ifstream csvFile(csvFileURL);
160  if (csvFile) {
161  auto varOrder = __varOrderFromCSV(csvFile, csvSeparator);
162  if (checkOnAppend && varOrder != __varOrder) {
163  GUM_ERROR(
165  "Inconsistent variable order in csvFile when appending. You "
166  "can use setVarOrderFromCSV(url) function to get the right "
167  "order. You could also set parameter checkOnAppend=false if you "
168  "know what you are doing.");
169  }
170  includeHeader = false;
171  }
172  csvFile.close();
173  }
174 
175 
176  auto ofstreamFlag = append ? std::ofstream::app : std::ofstream::out;
177 
178  std::ofstream os(csvFileURL, ofstreamFlag);
179  bool firstCol = true;
180  if (includeHeader) {
181  for (const auto& i : __varOrder) {
182  if (firstCol) {
183  firstCol = false;
184  } else {
185  os << csvSeparator;
186  }
187  os << __bn.variable(i).name();
188  }
189  }
190  os << std::endl;
191 
192  bool firstRow = true;
193  for (const auto& row : __database) {
194  if (firstRow) {
195  firstRow = false;
196  } else {
197  os << std::endl;
198  }
199  firstCol = true;
200  for (const auto& i : __varOrder) {
201  if (firstCol) {
202  firstCol = false;
203  } else {
204  os << csvSeparator;
205  }
206  if (useLabels) {
207  os << __bn.variable(i).label(row.at(i));
208  } else {
209  os << row[i];
210  }
211  }
212  }
213 
214  os.close();
215  }
216 
218  template < typename GUM_SCALAR >
221  if (!__drawnSamples)
222  GUM_ERROR(OperationNotAllowed, "proceed() must be called first.");
223 
224  DatabaseTable<> db;
225  std::vector< std::string > varNames;
226  varNames.reserve(__nbVars);
227  for (const auto& i : __varOrder) {
228  varNames.push_back(__names2ids.first(i));
229  }
230 
231  // create the translators
232  for (std::size_t i = 0; i < __nbVars; ++i) {
233  const Variable& var = __bn.variable(__varOrder[i]);
234  db.insertTranslator(var, i);
235  }
236 
237 
238  // db.setVariableNames(varNames);
239  // db.setVariableNames(varOrderNames());
240 
241  if (useLabels) {
242  std::vector< std::string > xrow(__nbVars);
243  for (const auto& row : __database) {
244  for (Idx i = 0; i < __nbVars; ++i) {
245  Idx j = __varOrder.at(i);
246  xrow[i] = __bn.variable(j).label(row.at(j));
247  }
248  db.insertRow(xrow);
249  }
250  } else {
251  std::vector< DBTranslatedValueType > translatorType(__nbVars);
252  for (std::size_t i = 0; i < __nbVars; ++i) {
253  translatorType[i] = db.translator(i).getValType();
254  }
255  DBRow< DBTranslatedValue > xrow(__nbVars);
257  for (const auto& row : __database) {
258  for (Idx i = 0; i < __nbVars; ++i) {
259  Idx j = __varOrder.at(i);
260 
261  if (translatorType[i] == DBTranslatedValueType::DISCRETE)
262  xrow[i].discr_val = std::size_t(row.at(j));
263  else
264  xrow[i].cont_val = float(row.at(j));
265  }
266  }
267  db.insertRow(xrow, xmiss);
268  }
269 
270  return db;
271  }
272 
273 
275  template < typename GUM_SCALAR >
276  std::vector< std::vector< Idx > >
278  if (!__drawnSamples)
279  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
280 
281  auto db(__database);
282  for (Idx i = 0; i < __database.size(); ++i) {
283  for (Idx j = 0; j < __nbVars; ++j) {
284  db.at(i).at(j) = (Idx)__database.at(i).at(__varOrder.at(j));
285  }
286  }
287  return db;
288  }
289 
291  template < typename GUM_SCALAR >
293  const std::vector< Idx >& varOrder) {
294  if (varOrder.size() != __nbVars) {
296  "varOrder's size must be equal to the number of variables");
297  }
298  std::vector< bool > usedVars(__nbVars, false);
299  for (const auto& i : varOrder) {
300  if (i >= __nbVars) {
301  GUM_ERROR(FatalError, "varOrder contains invalid variables");
302  }
303  if (usedVars.at(i))
304  GUM_ERROR(FatalError, "varOrder must not have repeated variables");
305  usedVars.at(i) = true;
306  }
307 
308  if (std::find(usedVars.begin(), usedVars.end(), false) != usedVars.end()) {
309  GUM_ERROR(FatalError, "varOrder must contain all variables");
310  }
311 
313  }
314 
316  template < typename GUM_SCALAR >
318  const std::vector< std::string >& varOrder) {
319  std::vector< Idx > varOrderIdx;
320  varOrderIdx.reserve(varOrder.size());
321  for (const auto& vname : varOrder) {
322  varOrderIdx.push_back(__names2ids.second(vname));
323  }
324  setVarOrder(varOrderIdx);
325  }
326 
328  template < typename GUM_SCALAR >
330  const std::string& csvFileURL, const std::string& csvSeparator) {
331  setVarOrder(__varOrderFromCSV(csvFileURL, csvSeparator));
332  }
333 
335  template < typename GUM_SCALAR >
337  std::vector< Idx > varOrder;
338  varOrder.reserve(__nbVars);
339  for (const auto& v : __bn.topologicalOrder()) {
340  varOrder.push_back(v);
341  }
342  setVarOrder(varOrder);
343  }
344 
346  template < typename GUM_SCALAR >
348  std::vector< Idx > varOrder;
349  varOrder.reserve(__nbVars);
350  for (const auto& v : __bn.topologicalOrder()) {
351  varOrder.push_back(v);
352  }
353  std::reverse(varOrder.begin(), varOrder.end());
354  setVarOrder(varOrder);
355  }
356 
358  template < typename GUM_SCALAR >
360  std::vector< std::string > varOrder;
361  varOrder.reserve(__bn.size());
362  for (const auto& var : __bn.dag()) {
363  varOrder.push_back(__bn.variable(var).name());
364  }
365  std::random_device rd;
366  std::mt19937 g(rd());
367  std::shuffle(varOrder.begin(), varOrder.end(), g);
368  setVarOrder(varOrder);
369  }
370 
371 
373  template < typename GUM_SCALAR >
374  std::vector< Idx > BNDatabaseGenerator< GUM_SCALAR >::varOrder() const {
375  return __varOrder;
376  }
377 
379  template < typename GUM_SCALAR >
380  std::vector< std::string >
382  std::vector< std::string > varNames;
383  varNames.reserve(__nbVars);
384  for (const auto& i : __varOrder) {
385  varNames.push_back(__names2ids.first(i));
386  }
387 
388  return varNames;
389  }
390 
392  template < typename GUM_SCALAR >
394  if (!__drawnSamples) {
395  GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.");
396  }
397  return __log2likelihood;
398  }
399 
401  template < typename GUM_SCALAR >
403  const std::string& csvFileURL, const std::string& csvSeparator) const {
404  std::ifstream csvFile(csvFileURL);
405  std::vector< Idx > varOrder;
406  if (csvFile) {
407  varOrder = __varOrderFromCSV(csvFile, csvSeparator);
408  csvFile.close();
409  } else {
410  GUM_ERROR(NotFound, "csvFileURL does not exist");
411  }
412 
413  return varOrder;
414  }
415 
417  template < typename GUM_SCALAR >
419  std::ifstream& csvFile, const std::string& csvSeparator) const {
420  std::string line;
421  std::vector< std::string > header_found;
422  header_found.reserve(__nbVars);
423  while (std::getline(csvFile, line)) {
424  std::size_t i = 0;
425  auto pos = line.find(csvSeparator);
426  while (pos != std::string::npos) {
427  header_found.push_back(line.substr(i, pos - i));
428  pos += csvSeparator.length();
429  i = pos;
430  pos = line.find(csvSeparator, pos);
431 
432  if (pos == std::string::npos)
433  header_found.push_back(line.substr(i, line.length()));
434  }
435  break;
436  }
437 
438  std::vector< Size > varOrder;
439  varOrder.reserve(__nbVars);
440 
441  for (const auto& hf : header_found) {
442  varOrder.push_back(__names2ids.second(hf));
443  }
444 
445  return varOrder;
446  }
447  } /* namespace learning */
448 } /* namespace gum */
void insert(const T1 &first, const T2 &second)
Inserts a new association in the gum::Bijection.
Class representing a Bayesian Network.
Definition: BayesNet.h:78
const T2 & second(const T1 &first) const
Returns the second value of a pair given its first value.
Base class for every random variable.
Definition: variable.h:66
Signaler2< Size, double > onProgress
Progression (percent) and time.
double __log2likelihood
log2Likelihood of generated samples
double step() const
Returns the delta time between now and the last reset() call (or the constructor).
Definition: timer_inl.h:42
const T1 & first(const T2 &second) const
Returns the first value of a pair given its second value.
#define GUM_EMIT1(signal, arg1)
Definition: signaler1.h:42
std::vector< std::string > varOrderNames() const
returns variable order.
double drawSamples(Size nbSamples)
generates and stores the database, returns log2likelihood using ProgressNotifier as notification ...
The generic class for storing (ordered) sequences of objects.
Definition: sequence.h:1022
virtual void insertRow(const std::vector< std::string, ALLOC< std::string > > &new_row) final
insert a new row at the end of the database
Instantiation & chgVal(const DiscreteVariable &v, Idx newval)
Assign newval to variable v in the Instantiation.
Bijection< std::string, NodeId > __names2ids
bijection between node names and node ids
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
void incVar(const DiscreteVariable &v)
Operator increment for variable v only.
DatabaseTable toDatabaseTable(bool useLabels=true) const
generates a DatabaseVectInRAM
void setTopologicalVarOrder()
set columns in topological order
Idx val(Idx i) const
Returns the current value of the variable at position i.
void reset()
Reset the timer.
Definition: timer_inl.h:32
const BayesNet< GUM_SCALAR > & __bn
Bayesian network.
#define GUM_EMIT2(signal, arg1, arg2)
Definition: signaler2.h:42
const DBTranslator< ALLOC > & translator(const std::size_t k, const bool k_is_input_col=false) const
returns either the kth translator of the database table or the first one reading the kth column of th...
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
std::vector< Idx > __varOrderFromCSV(const std::string &csvFileURL, const std::string &csvSeparator=",") const
returns varOrder from a csv file
void toCSV(const std::string &csvFileURL, bool useLabels=true, bool append=false, std::string csvSeparator=",", bool checkOnAppend=false) const
generates csv database according to bn
double log2likelihood() const
returns log2Likelihood of generated samples
std::vector< std::vector< Idx > > database() const
generates database according to bn into a std::vector
The class for storing a record in a database.
Definition: DBRow.h:56
std::vector< std::vector< Idx > > __database
generated database
The class representing a tabular database as used by learning tasks.
void setVarOrder(const std::vector< Idx > &varOrder)
change columns order
std::size_t insertTranslator(const DBTranslator< ALLOC > &translator, const std::size_t input_column, const bool unique_column=true)
insert a new translator into the database table
Class for assigning/browsing values to tuples of discrete variables.
Definition: instantiation.h:83
void setRandomVarOrder()
set columns in random order
void setAntiTopologicalVarOrder()
set columns in anti-topological order
std::vector< Idx > varOrder() const
returns variable order indexes
Class used to compute response times for benchmark purposes. This class represents a classic timer...
Definition: timer.h:51
Size Idx
Type for indexes.
Definition: types.h:53
Signaler1< const std::string &> onStop
with a possible explanation for stopping
bool __drawnSamples
whether drawSamples has been already called.
std::vector< Idx > __varOrder
variable order in generated database
void setVarOrderFromCSV(const std::string &csvFileURL, const std::string &csvSeparator=",")
change columns order according to a csv file
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:48
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Base class for dag.
Definition: DAG.h:102
Size NodeId
Type for node ids.
Definition: graphElements.h:98
BNDatabaseGenerator(const BayesNet< GUM_SCALAR > &bn)
default constructor
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
bool end() const
Returns true if the Instantiation reached the end.