aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
DBRowGeneratorEM_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief A DBRowGenerator class that returns incomplete rows as EM would do
24  *
25  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
26  */
27 #include <agrum/tools/database/DBRowGeneratorIdentity.h>
28 
29 #ifndef DOXYGEN_SHOULD_SKIP_THIS
30 
31 namespace gum {
32 
33  namespace learning {
34 
35  /// returns the allocator used
    ///
    /// The visible body forwards to DBRowGenerator< ALLOC >::getAllocator() and
    /// returns its result unchanged.
    // NOTE(review): listing lines 37-38 (the function declarator) are not present
    // in this extract, so the exact return type cannot be read from here.
36  template < typename GUM_SCALAR, template < typename > class ALLOC >
39  return DBRowGenerator< ALLOC >::getAllocator();
40  }
41 
42 
43  /// default constructor
    ///
    /// @param bn the Bayesian network: it appears in the initializer list, its
    /// size() is the first constructor argument of both filled rows, and it is
    /// finally passed to setBayesNet().
    /// @param alloc the allocator passed to the initializer list and to both
    /// filled rows.
    // NOTE(review): listing lines 45-46, 48, 50, 52-53 and 58 are not present in
    // this extract, so the other parameters and initializers are not visible.
44  template < typename GUM_SCALAR, template < typename > class ALLOC >
47  const BayesNet< GUM_SCALAR >& bn,
49  const typename DBRowGeneratorEM< GUM_SCALAR, ALLOC >::allocator_type& alloc) :
51  bn,
54  alloc),
    // both output buffers are built from (bn.size(), 1.0, alloc); the 1.0 is
    // presumably the initial row weight — TODO confirm against DBRow
55  _filled_row1_(bn.size(), 1.0, alloc), _filled_row2_(bn.size(), 1.0, alloc) {
    // the remaining setup that depends on the network is delegated to setBayesNet()
56  setBayesNet(bn);
57 
59  }
60 
61 
62  /// copy constructor with a given allocator
    ///
    /// @param alloc an allocator of type DBRowGeneratorEM::allocator_type.
    // NOTE(review): listing lines 64-65, 67-71, 73, 77 and 81 are not present in
    // this extract: the declarator, the member initializers, the statement
    // executed right after the null test and the loop body cannot be seen here.
63  template < typename GUM_SCALAR, template < typename > class ALLOC >
66  const typename DBRowGeneratorEM< GUM_SCALAR, ALLOC >::allocator_type& alloc) :
    // the joint instantiation is only dealt with when the source owns one
72  if (from._joint_inst_ != nullptr) {
    // note that the sequence is read from this object's own _joint_inst_, not
    // from from's; the missing line 73 presumably creates it — TODO confirm
74  const auto& var_seq = _joint_inst_->variablesSequence();
75  const std::size_t size = var_seq.size();
    // one pass per variable of the instantiation (loop body not in this extract)
76  for (std::size_t i = std::size_t(0); i < size; ++i) {
78  }
79  }
80 
82  }
83 
84 
85  /// copy constructor
    // NOTE(review): listing lines 87-89 (declarator and body) are not present in
    // this extract; only the documentation line and the template header remain.
86  template < typename GUM_SCALAR, template < typename > class ALLOC >
90 
91 
92  /// move constructor with a given allocator
    ///
    /// @param alloc an allocator of type DBRowGeneratorEM::allocator_type.
    // NOTE(review): listing lines 94-95, 97-101, 103, 107 and 111 are not present
    // in this extract: the declarator, the member initializers, the statement
    // executed right after the null test and the loop body cannot be seen here.
93  template < typename GUM_SCALAR, template < typename > class ALLOC >
96  const typename DBRowGeneratorEM< GUM_SCALAR, ALLOC >::allocator_type& alloc) :
    // the joint instantiation is only dealt with when the source owns one
102  if (from._joint_inst_ != nullptr) {
    // the sequence is read from this object's own _joint_inst_; the missing
    // line 103 presumably sets that pointer — TODO confirm
104  const auto& var_seq = _joint_inst_->variablesSequence();
105  const std::size_t size = var_seq.size();
    // one pass per variable of the instantiation (loop body not in this extract)
106  for (std::size_t i = std::size_t(0); i < size; ++i) {
108  }
109  }
110 
112  }
113 
114 
115  /// move constructor
    // NOTE(review): listing lines 117-119 (declarator and body) are not present
    // in this extract; only the documentation line and the template header remain.
116  template < typename GUM_SCALAR, template < typename > class ALLOC >
120 
121 
122  /// virtual copy constructor with a given allocator
    ///
    /// @param alloc an allocator of type DBRowGeneratorEM::allocator_type.
    /// @return the local `generator` (its declaration is not in this extract).
    // NOTE(review): listing lines 124, 126-127, 129 and 131 are not present in
    // this extract: the declarator, the statements preceding the try, the
    // statement inside the try and the first statement of the handler are missing.
123  template < typename GUM_SCALAR, template < typename > class ALLOC >
125  const typename DBRowGeneratorEM< GUM_SCALAR, ALLOC >::allocator_type& alloc) const {
128  try {
    // any exception is caught, then rethrown unchanged; the missing line 131
    // presumably releases what was acquired before the try — TODO confirm
130  } catch (...) {
132  throw;
133  }
134  return generator;
135  }
136 
137 
138  /// virtual copy constructor
    ///
    /// Delegates to the one-argument clone(), passing the allocator obtained from
    /// this->getAllocator(), and returns whatever that overload returns.
    // NOTE(review): listing line 140 (the function declarator) is not present in
    // this extract.
139  template < typename GUM_SCALAR, template < typename > class ALLOC >
141  return clone(this->getAllocator());
142  }
143 
144 
145  /// destructor
    ///
    /// Deletes the object pointed to by _joint_inst_ when that pointer is not null.
    // NOTE(review): listing lines 147 (declarator) and 149 are not present in this
    // extract.
146  template < typename GUM_SCALAR, template < typename > class ALLOC >
148  if (_joint_inst_ != nullptr) delete _joint_inst_;
150  }
151 
152 
153  /// copy operator
    ///
    /// @param from the generator to copy from.
    /// @return *this.
    // NOTE(review): listing lines 155, 158-166, 174 and 178 are not present in
    // this extract: the first line of the declarator, the statements that follow
    // the self-assignment test, the statement right after the second null test
    // and the loop body cannot be seen here.
154  template < typename GUM_SCALAR, template < typename > class ALLOC >
156  const DBRowGeneratorEM< GUM_SCALAR, ALLOC >& from) {
    // self-assignment is a no-op
157  if (this != &from) {
167 
    // drop the instantiation currently owned, leaving the pointer null so that it
    // stays null when from has no instantiation either
168  if (_joint_inst_ != nullptr) {
169  delete _joint_inst_;
170  _joint_inst_ = nullptr;
171  }
172 
173  if (from._joint_inst_ != nullptr) {
    // _joint_inst_ is dereferenced here although it was just reset to nullptr:
    // the missing line 174 presumably assigns it a new object — TODO confirm
175  const auto& var_seq = _joint_inst_->variablesSequence();
176  const std::size_t size = var_seq.size();
    // one pass per variable of the instantiation (loop body not in this extract)
177  for (std::size_t i = std::size_t(0); i < size; ++i) {
179  }
180  }
181  }
182 
183  return *this;
184  }
185 
186 
187  /// move operator
    ///
    /// @return *this.
    // NOTE(review): listing lines 189-190, 192-200, 208 and 212 are not present
    // in this extract: the declarator, the statements that follow the
    // self-assignment test, the statement right after the second null test and
    // the loop body cannot be seen here.
188  template < typename GUM_SCALAR, template < typename > class ALLOC >
    // self-assignment is a no-op
191  if (this != &from) {
201 
    // drop the instantiation currently owned, leaving the pointer null so that it
    // stays null when from has no instantiation either
202  if (_joint_inst_ != nullptr) {
203  delete _joint_inst_;
204  _joint_inst_ = nullptr;
205  }
206 
207  if (from._joint_inst_ != nullptr) {
    // _joint_inst_ is dereferenced here although it was just reset to nullptr:
    // the missing line 208 presumably assigns it a new object — TODO confirm
209  const auto& var_seq = _joint_inst_->variablesSequence();
210  const std::size_t size = var_seq.size();
    // one pass per variable of the instantiation (loop body not in this extract)
211  for (std::size_t i = std::size_t(0); i < size; ++i) {
213  }
214  }
215  }
216 
217  return *this;
218  }
219 
220 
221  /// generates new lines from those the generator gets in input
    ///
    /// Each call first decrements the remaining-row counter, then returns either
    /// the stored input row (when _input_row_ is not null) or one of the two
    /// internal buffers _filled_row1_ / _filled_row2_, alternating between them
    /// from one call to the next through the _use_filled_row1_ flag.
    // NOTE(review): listing lines 223-224, 232, 236, 244 and 248 are not present
    // in this extract: the declarator, the statements that set the row weight and
    // the bodies of the two fill loops cannot be seen here.
222  template < typename GUM_SCALAR, template < typename > class ALLOC >
225  this->decreaseRemainingRows();
226 
227  // if everything is observed, return the input row
228  if (_input_row_ != nullptr) return *_input_row_;
229 
    // the two branches below are symmetric: they differ only in the buffer they
    // return and in the value written back to _use_filled_row1_. The alternation
    // presumably keeps the previously returned row intact — TODO confirm
230  if (_use_filled_row1_) {
231  // get the weight of the row from the joint probability
233 
234  // fill the values of the row
    // one iteration per missing value (_nb_miss_ of them)
235  for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i)
237 
    // step the joint instantiation so that the next call sees its next state
238  _joint_inst_->inc();
239  _use_filled_row1_ = false;
240 
241  return _filled_row1_;
242  } else {
243  // get the weight of the row from the joint probability
245 
246  // fill the values of the row
    // one iteration per missing value (_nb_miss_ of them)
247  for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i)
249 
    // step the joint instantiation so that the next call sees its next state
250  _joint_inst_->inc();
251  _use_filled_row1_ = true;
252 
253  return _filled_row2_;
254  }
255  }
256 
257 
258  /// computes the rows it will provide in output
    ///
    /// @param row the input database row; its values are read through row.row().
    /// @return 1 when no column of interest holds an unobserved value (only the
    /// address of row is then stored in _input_row_), otherwise
    /// _joint_proba_.domainSize().
    ///
    /// A discrete value equal to std::numeric_limits< std::size_t >::max() is what
    /// this function treats as "unobserved".
    // NOTE(review): listing lines 260, 268, 271, 274, 278-279, 285, 298-299, 303,
    // 306, 312, 316-317, 325, 329-330, 335, 342, 345, 349-350, 356, 364, 367,
    // 371-372, 378, 386, 388, 390, 394 and 401 are not present in this extract:
    // the declarator, the case labels, the error-raising calls and most statements
    // that record targets/evidence cannot be seen here.
259  template < typename GUM_SCALAR, template < typename > class ALLOC >
261  const DBRow< DBTranslatedValue, ALLOC >& row) {
262  // check if there are unobserved values among the columns of interest.
263  // If this is the case, set them as targets
264  bool found_unobserved = false;
265  const auto& xrow = row.row();
266  for (const auto col: this->columns_of_interest_) {
267  switch (this->column_types_[col]) {
269  if (xrow[col].discr_val == std::numeric_limits< std::size_t >::max()) {
    // something is done only once, on the first unobserved value met (that
    // statement, listing line 271, is not in this extract)
270  if (!found_unobserved) {
272  found_unobserved = true;
273  }
275  }
276  break;
277 
    // NOTE(review): the message below spells the class name "BDRowGeneratorEM"
    // and has no space after "column"; both look like typos. The same text is
    // repeated twice further down. Left untouched: it is runtime text.
280  "The BDRowGeneratorEM does not handle yet continuous "
281  << "variables. But the variable in column" << col << " is continuous.");
282  break;
283 
284  default:
286  "DBTranslatedValueType " << int(this->column_types_[col])
287  << " is not supported yet");
288  }
289  }
290 
291  // if there is no unobserved value, make the _input_row_ point to the row
    // only the address is stored, no copy of row is made
292  if (!found_unobserved) {
293  _input_row_ = &row;
294  return std::size_t(1);
295  }
296 
    // from here on, at least one column of interest is unobserved
297  _input_row_ = nullptr;
300 
301  // here, there are missing symbols, so we should compute the distribution
302  // of the missing values. For this purpose, we use Variable Elimination
304 
305  // add the targets and fill the output row with the observed values
    // The two branches below have the same visible skeleton and are selected by
    // whether nodeId2columns_ is empty. In both, i indexes _missing_cols_: a
    // column of interest equal to _missing_cols_[i] takes the first arm and
    // advances i, any other column takes the else arm. Once i reaches _nb_miss_,
    // end_miss short-circuits the test so _missing_cols_[i] is no longer read.
    // NOTE(review): this single forward scan presumably relies on _missing_cols_
    // being ordered like columns_of_interest_ — TODO confirm.
307  if (this->nodeId2columns_.empty()) {
308  std::size_t i = std::size_t(0);
309  bool end_miss = false;
310  for (const auto col: this->columns_of_interest_) {
311  if (!end_miss && (col == _missing_cols_[i])) {
313  ++i;
314  if (i == _nb_miss_) end_miss = true;
315  } else {
318  }
319  }
320  } else {
321  std::size_t i = std::size_t(0);
322  bool end_miss = false;
323  for (const auto col: this->columns_of_interest_) {
324  if (!end_miss && (col == _missing_cols_[i])) {
326  ++i;
327  if (i == _nb_miss_) end_miss = true;
328  } else {
331  }
332  }
333  }
334 
336 
337  // add the evidence and the target
    // unlike the loops above, these loops visit every column of the row
    // (0 .. xrow.size()-1), not only the columns of interest
338  const std::size_t row_size = xrow.size();
339  if (this->nodeId2columns_.empty()) {
340  for (std::size_t col = std::size_t(0); col < row_size; ++col) {
341  switch (this->column_types_[col]) {
343  // only observed values are evidence
344  if (xrow[col].discr_val != std::numeric_limits< std::size_t >::max()) {
346  }
347  break;
348 
351  "The BDRowGeneratorEM does not handle yet continuous "
352  << "variables. But the variable in column" << col << " is continuous.");
353  break;
354 
355  default:
357  "DBTranslatedValueType " << int(this->column_types_[col])
358  << " is not supported yet");
359  }
360  }
361  } else {
362  for (std::size_t col = std::size_t(0); col < row_size; ++col) {
363  switch (this->column_types_[col]) {
365  // only observed values are evidence
366  if (xrow[col].discr_val != std::numeric_limits< std::size_t >::max()) {
368  }
369  break;
370 
373  "The BDRowGeneratorEM does not handle yet continuous "
374  << "variables. But the variable in column" << col << " is continuous.");
375  break;
376 
377  default:
379  "DBTranslatedValueType " << int(this->column_types_[col])
380  << " is not supported yet");
381  }
382  }
383  }
384 
385  // get the potential of the target set
    // constness is cast away from the reference returned by
    // ve.jointPosterior(target_set); the left-hand side of this statement
    // (listing line 386) is not in this extract
387  = const_cast< Potential< GUM_SCALAR >& >(ve.jointPosterior(target_set));
    // the instantiation left over from a previous call, if any, is released; the
    // missing line 390 presumably creates its replacement — TODO confirm
389  if (_joint_inst_ != nullptr) delete _joint_inst_;
391 
392  // get the mapping between variables of the joint proba and the
393  // columns in the database
    // _missing_cols_ is overwritten here: with an empty nodeId2columns_, entry i
    // becomes the BN node id of var_sequence[i] converted to std::size_t (the
    // declaration of var_sequence and the body of the else loop are not in this
    // extract)
395  if (this->nodeId2columns_.empty()) {
396  for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i) {
397  _missing_cols_[i] = std::size_t(this->bn_->nodeId(*(var_sequence[i])));
398  }
399  } else {
400  for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i) {
402  }
403  }
404 
405  return std::size_t(_joint_proba_.domainSize());
406  }
407 
408 
409  /// assign a new Bayes net to the generator
    ///
    /// Visible steps:
    ///  -# when nodeId2columns_ is not empty, every node id it holds
    ///     (iter.first()) is checked against new_bn.dag(); an id that does not
    ///     exist there leads to the "does not belong to the Bayesian network"
    ///     message;
    ///  -# a size is computed as the largest node id of new_bn.dag() when
    ///     nodeId2columns_ is empty, or as the largest column index
    ///     (iter.second()) otherwise.
    // NOTE(review): listing lines 411, 419, 427 and 440-441 are not present in
    // this extract: the declarator, the call that raises the error, the statement
    // between the two steps and the statements that consume `size` cannot be
    // seen here.
410  template < typename GUM_SCALAR, template < typename > class ALLOC >
412  // check that if nodeId2columns is not empty, then all the columns
413  // correspond to nodes of the BN
414  if (!this->nodeId2columns_.empty()) {
415  const DAG& dag = new_bn.dag();
416  for (auto iter = this->nodeId2columns_.begin(); iter != this->nodeId2columns_.end();
417  ++iter) {
    // iter.first() is the node id, iter.second() the associated database column
418  if (!dag.existsNode(iter.first())) {
420  "Column " << iter.second() << " of the database is associated to Node ID "
421  << iter.first()
422  << ", which does not belong to the Bayesian network");
423  }
424  }
425  }
426 
428 
429  // we determine the size of the filled rows
    // NOTE(review): what is computed below is a maximum index (it stays 0 when
    // nothing is larger), not a count; the missing lines 440-441 presumably
    // account for that before resizing the filled rows — TODO confirm
430  std::size_t size = std::size_t(0);
431  if (this->nodeId2columns_.empty()) {
432  for (auto node: new_bn.dag())
433  if (std::size_t(node) > size) size = std::size_t(node);
434  } else {
435  for (auto iter = this->nodeId2columns_.begin(); iter != this->nodeId2columns_.end();
436  ++iter) {
437  if (iter.second() > size) size = iter.second();
438  }
439  }
442  }
443 
444  } /* namespace learning */
445 
446 } /* namespace gum */
447 
448 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)