aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
BayesNetFactory_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation of the BayesNetFactory class.
25  *
26  * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27 
28  */
29 
30 #include <agrum/BN/BayesNetFactory.h>
31 
32 namespace gum {
33 
34  // Default constructor.
35  // @param bn A pointer over the BayesNet filled by this factory.
36  // @throw DuplicateElement Raised if two variables in bn share the same
37  // name.
38  template < typename GUM_SCALAR >
39  INLINE BayesNetFactory< GUM_SCALAR >::BayesNetFactory(BayesNet< GUM_SCALAR >* bn) :
40  _parents_(0), _impl_(0), _bn_(bn) {
43 
44  for (auto node: bn->nodes()) {
46  GUM_ERROR(DuplicateElement, "Name already used: " << bn->variable(node).name())
47 
49  }
50 
51  resetVerbose();
52  }
53 
54  // Copy constructor.
55  // The copy will have an exact copy of the constructed BayesNet in source.
56  template < typename GUM_SCALAR >
57  INLINE
59  _parents_(nullptr),
60  _impl_(nullptr), _bn_(nullptr) {
62 
63  if (source.state() != factory_state::NONE) {
64  GUM_ERROR(OperationNotAllowed, "Illegal state to proceed make a copy.")
65  } else {
67  _bn_ = new BayesNet< GUM_SCALAR >(*(source._bn_));
68  }
69  }
70 
71  // Destructor
72  template < typename GUM_SCALAR >
75 
76  if (_parents_ != nullptr) delete _parents_;
77 
78  if (_impl_ != nullptr) {
79  //@todo better than throwing an exception from inside a destructor but
80  // still ...
81  std::cerr << "[BN factory] Implementation defined for a variable but not used. "
82  "You should call endVariableDeclaration() before "
83  "deleting me."
84  << std::endl;
85  exit(1);
86  }
87  }
88 
89  // Returns the BayesNet created by this factory.
90  template < typename GUM_SCALAR >
92  return _bn_;
93  }
94 
95  template < typename GUM_SCALAR >
97  return _bn_->variable(id);
98  }
99 
100  // Returns the current state of the factory.
101  template < typename GUM_SCALAR >
103  // This is ok because there is always at least the state NONE in the stack.
104  return _states_.back();
105  }
106 
107  // Returns the NodeId of a variable given its name.
108  // @throw NotFound Raised if no variable matches the name.
109  template < typename GUM_SCALAR >
111  try {
112  return _varNameMap_[name];
113  } catch (NotFound&) { GUM_ERROR(NotFound, name) }
114  }
115 
116  // Returns a constant reference to a variable given its name.
117  // @throw NotFound Raised if no variable matches the name.
118  template < typename GUM_SCALAR >
119  INLINE const DiscreteVariable&
121  try {
122  return _bn_->variable(variableId(name));
123  } catch (NotFound&) { GUM_ERROR(NotFound, name) }
124  }
125 
126  // Returns the domainSize of the cpt for the node n.
127  // @throw NotFound raised if no such NodeId exists.
128  // @throw OperationNotAllowed if there is no Bayesian networks.
129  template < typename GUM_SCALAR >
131  return _bn_->cpt(n).domainSize();
132  }
133 
134  // Tells the factory that we're in a network declaration.
135  template < typename GUM_SCALAR >
137  if (state() != factory_state::NONE) {
138  _illegalStateError_("startNetworkDeclaration");
139  } else {
141  }
142  }
143 
144  // Tells the factory to add a property to the current network.
145  template < typename GUM_SCALAR >
147  const std::string& propValue) {
149  }
150 
151  // Tells the factory that we're out of a network declaration.
152  template < typename GUM_SCALAR >
154  if (state() != factory_state::NETWORK) {
155  _illegalStateError_("endNetworkDeclaration");
156  } else {
157  _states_.pop_back();
158  }
159  }
160 
161  // Tells the factory that we're in a variable declaration.
162  // A variable is considered as a LabelizedVariable while its type is not defined.
163  template < typename GUM_SCALAR >
165  if (state() != factory_state::NONE) {
166  _illegalStateError_("startVariableDeclaration");
167  } else {
169  _stringBag_.push_back("name");
170  _stringBag_.push_back("desc");
171  _stringBag_.push_back("L");
172  }
173  }
174 
175  // Tells the factory the current variable's name.
176  template < typename GUM_SCALAR >
178  if (state() != factory_state::VARIABLE) {
179  _illegalStateError_("variableName");
180  } else {
181  if (_varNameMap_.exists(name)) { GUM_ERROR(DuplicateElement, "Name already used: " << name) }
182 
183  _foo_flag_ = true;
184  _stringBag_[0] = name;
185  }
186  }
187 
188  // Tells the factory the current variable's description.
189  template < typename GUM_SCALAR >
191  if (state() != factory_state::VARIABLE) {
192  _illegalStateError_("variableDescription");
193  } else {
194  _bar_flag_ = true;
195  _stringBag_[1] = desc;
196  }
197  }
198 
199  // Tells the factory the current variable's type.
200  // L : Labelized
201  // R : Range
202  // C : Continuous
203  // D : Discretized
204  template < typename GUM_SCALAR >
206  if (state() != factory_state::VARIABLE) {
207  _illegalStateError_("variableType");
208  } else {
209  switch (type) {
210  case VarType::Discretized:
211  _stringBag_[2] = "D";
212  break;
213  case VarType::Range:
214  _stringBag_[2] = "R";
215  break;
216  case VarType::Continuous:
218  "Continuous variable (" + _stringBag_[0]
219  + ") are not supported in Bayesian networks.")
220  case VarType::Labelized:
221  _stringBag_[2] = "L";
222  break;
223  }
224  }
225  }
226 
227  // Adds a modality to the current variable.
228  // @throw DuplicateElement If the current variable already has a modality
229  // with the same name.
230  template < typename GUM_SCALAR >
232  if (state() != factory_state::VARIABLE) {
233  _illegalStateError_("addModality");
234  } else {
237  }
238  }
239 
240  // Adds the min value to the current variable.
241  // @throw DuplicateElement If the current variable already has a modality
242  // with the same name.
243  template < typename GUM_SCALAR >
244  INLINE void BayesNetFactory< GUM_SCALAR >::addMin(const long& min) {
245  if (state() != factory_state::VARIABLE) {
246  _illegalStateError_("addMin");
247  } else {
249  }
250  }
251 
252  // Adds the max value to the current variable.
253  // NOTE(review): the @throw DuplicateElement text previously here was copied
254  // from addModality(); confirm whether this method can raise it.
255  template < typename GUM_SCALAR >
256  INLINE void BayesNetFactory< GUM_SCALAR >::addMax(const long& max) {
257  if (state() != factory_state::VARIABLE) {
258  _illegalStateError_("addMax");
259  } else {
261  }
262  }
263 
264  // Adds a tick to the current variable.
265  // @throw DuplicateElement If the current variable already has a modality
266  // with the same name.
267  template < typename GUM_SCALAR >
269  if (state() != factory_state::VARIABLE) {
270  _illegalStateError_("addTick");
271  } else {
273  }
274  }
275 
276  // @brief Defines the implementation to use for Potential.
277  // @warning The implementation must be empty.
278  // @warning The pointer is always delegated to Potential! No copy of it
279  // is made.
280  // @todo When copy of a MultiDimImplementation is available use a copy
281  // behaviour for this method.
282  // @throw NotFound Raised if no variable matches var.
283  // @throw OperationNotAllowed Raised if impl is not empty.
284  // @throw OperationNotAllowed If an implementation is already defined for the
285  // current variable.
286  template < typename GUM_SCALAR >
287  INLINE void
290  = dynamic_cast< MultiDimImplementation< GUM_SCALAR >* >(adressable);
291 
292  if (state() != factory_state::VARIABLE) {
293  _illegalStateError_("setVariableCPTImplementation");
294  } else {
295  if (impl == 0) {
297  "An implementation for this variable is already "
298  "defined.")
299  } else if (impl->nbrDim() > 0) {
300  GUM_ERROR(OperationNotAllowed, "This implementation is not empty.")
301  }
302 
303  _impl_ = impl;
304  }
305  }
306 
307  // Tells the factory that we're out of a variable declaration.
308  template < typename GUM_SCALAR >
310  if (state() != factory_state::VARIABLE) {
311  _illegalStateError_("endVariableDeclaration");
312  } else if (_foo_flag_ && (_stringBag_.size() > 4)) {
313  DiscreteVariable* var = nullptr;
314 
315  // if the current variable is a LabelizedVariable
316  if (_stringBag_[2] == "L") {
318  = new LabelizedVariable(_stringBag_[0], (_bar_flag_) ? _stringBag_[1] : "", 0);
319 
320  for (size_t i = 3; i < _stringBag_.size(); ++i) {
321  l->addLabel(_stringBag_[i]);
322  }
323 
324  var = l;
325  // if the current variable is a RangeVariable
326  } else if (_stringBag_[2] == "R") {
328  (_bar_flag_) ? _stringBag_[1] : "",
329  std::stol(_stringBag_[3]),
330  std::stol(_stringBag_[4]));
331 
332  var = r;
333  // if the current variable is a DiscretizedVariable
334  } else if (_stringBag_[2] == "D") {
337  (_bar_flag_) ? _stringBag_[1] : "");
338 
339  for (size_t i = 3; i < _stringBag_.size(); ++i) {
340  d->addTick(std::stof(_stringBag_[i]));
341  }
342 
343  var = d;
344  }
345 
346  if (_impl_ != 0) {
348  _impl_ = 0;
349  } else {
351  }
352 
354 
355  delete var;
356 
357  _resetParts_();
358  _states_.pop_back();
359 
360  return retVal;
361  } else {
363  msg << "Not enough modalities (";
364 
365  if (_stringBag_.size() > 3) {
366  msg << _stringBag_.size() - 3;
367  } else {
368  msg << 0;
369  }
370 
371  msg << ") declared for variable ";
372 
373  if (_foo_flag_) {
374  msg << _stringBag_[0];
375  } else {
376  msg << "unknown";
377  }
378 
379  _resetParts_();
380 
381  _states_.pop_back();
383  }
384 
385  // For noisy compilers
386  return 0;
387  }
388 
389  // Tells the factory that we're declaring parents for some variable.
390  // @var The concerned variable's name.
391  template < typename GUM_SCALAR >
393  if (state() != factory_state::NONE) {
394  _illegalStateError_("startParentsDeclaration");
395  } else {
400  }
401  }
402 
403  // Tells the factory for which variable we're declaring parents.
404  // @var The parent's name.
405  // @throw NotFound Raised if var does not exist.
406  template < typename GUM_SCALAR >
408  if (state() != factory_state::PARENTS) {
409  _illegalStateError_("addParent");
410  } else {
413  }
414  }
415 
416  // Tells the factory that we've finished declaring parents for some
417  // variable. When parents exist, endParentsDeclaration creates some arcs.
418  // These arcs are created in the inverse order of the order of the parent
419  // specification.
420  template < typename GUM_SCALAR >
422  if (state() != factory_state::PARENTS) {
423  _illegalStateError_("endParentsDeclaration");
424  } else {
426 
427  // PLEASE NOTE THAT THE ORDER IS INVERSE
428 
429  for (size_t i = _stringBag_.size() - 1; i > 0; --i) {
431  }
432 
433  _resetParts_();
434 
435  _states_.pop_back();
436  }
437  }
438 
439  // Tells the factory that we're declaring a conditional probability table
440  // for some variable.
441  // @param var The concerned variable's name.
442  template < typename GUM_SCALAR >
443  INLINE void
445  if (state() != factory_state::NONE) {
446  _illegalStateError_("startRawProbabilityDeclaration");
447  } else {
451  }
452  }
453 
454  // @brief Fills the variable's table with the values in rawTable.
455  // Parse the parents in the same order in which they were added to the
456  // variable.
457  // Given a sequence [var, p_1, p_2, ...,p_n-1, p_n] of parents, modalities are
458  // parsed
459  // in the given order (if all p_i are binary):
460  // [0, 0, ..., 0, 0], [0, 0, ..., 0, 1],
461  // [0, 0, ..., 1, 0], [0, 0, ..., 1, 1],
462  // ...,
463  // [1, 1, ..., 1, 0], [1, 1, ..., 1, 1].
464  // @param rawTable The raw table.
465  template < typename GUM_SCALAR >
466  INLINE void
468  const std::vector< float >& rawTable) {
469  if (state() != factory_state::RAW_CPT) {
470  _illegalStateError_("rawConditionalTable");
471  } else {
473  }
474  }
475 
476  template < typename GUM_SCALAR >
478  const std::vector< std::string >& variables,
479  const std::vector< float >& rawTable) {
482 
483  List< const DiscreteVariable* > varList;
484 
485  for (size_t i = 0; i < variables.size(); ++i) {
487  }
488 
489  // varList.pushFront(&( _bn_->variable( _varNameMap_[ _stringBag_[0]])));
490 
491  Idx nbrVar = varList.size();
492 
493  std::vector< Idx > modCounter;
494 
495  // initializing the array
496  for (NodeId i = 0; i < nbrVar; i++) {
498  }
499 
500  Idx j = 0;
501 
502  do {
503  for (NodeId i = 0; i < nbrVar; i++) {
505  }
506 
507  if (j < rawTable.size()) {
509  } else {
511  }
512 
513  j++;
514  } while (_increment_(modCounter, varList));
515  }
516 
517  template < typename GUM_SCALAR >
518  INLINE void
520  if (state() != factory_state::RAW_CPT) {
521  _illegalStateError_("rawConditionalTable");
522  } else {
524  }
525  }
526 
527  template < typename GUM_SCALAR >
529  const std::vector< float >& rawTable) {
531 
533 
534  // the main loop is on the first variables. The others are in the right
535  // order.
536  const DiscreteVariable& first = table.variable(0);
537  Idx j = 0;
538 
542 
543  cptInst.unsetEnd();
544  }
545  }
546 
547  template < typename GUM_SCALAR >
549  List< const DiscreteVariable* >& varList) {
550  bool last = true;
551 
552  for (NodeId j = 0; j < modCounter.size(); j++) {
553  last = (modCounter[j] == (varList[j]->domainSize() - 1)) && last;
554 
555  if (!last) break;
556  }
557 
558  if (last) { return false; }
559 
560  bool add = false;
561 
562  NodeId i = NodeId(varList.size() - 1);
563 
564  do {
565  if (modCounter[i] == (varList[i]->domainSize() - 1)) {
566  modCounter[i] = 0;
567  add = true;
568  } else {
569  modCounter[i] += 1;
570  add = false;
571  }
572 
573  i--;
574  } while (add);
575 
576  return true;
577  }
578 
579  // Tells the factory that we finished declaring a conditional probability
580  // table.
581  template < typename GUM_SCALAR >
583  if (state() != factory_state::RAW_CPT) {
584  _illegalStateError_("endRawProbabilityDeclaration");
585  } else {
586  _resetParts_();
587  _states_.pop_back();
588  }
589  }
590 
591  // Tells the factory that we're starting a factorized declaration.
592  template < typename GUM_SCALAR >
593  INLINE void
595  if (state() != factory_state::NONE) {
596  _illegalStateError_("startFactorizedProbabilityDeclaration");
597  } else {
602  }
603  }
604 
605  // Tells the factory that we start an entry of a factorized conditional
606  // probability table.
607  template < typename GUM_SCALAR >
609  if (state() != factory_state::FACT_CPT) {
610  _illegalStateError_("startFactorizedEntry");
611  } else {
612  _parents_ = new Instantiation();
614  }
615  }
616 
617  // Tells the factory that we finished declaring an entry of a factorized
618  // conditional probability table.
619  template < typename GUM_SCALAR >
621  if (state() != factory_state::FACT_ENTRY) {
622  _illegalStateError_("endFactorizedEntry");
623  } else {
624  delete _parents_;
625  _parents_ = 0;
626  _states_.pop_back();
627  }
628  }
629 
630  // Tells the factory on which modality we want to instantiate one of
631  // variable's parent.
632  template < typename GUM_SCALAR >
634  const std::string& modality) {
635  if (state() != factory_state::FACT_ENTRY) {
636  _illegalStateError_("string");
637  } else {
642  }
643  }
644 
645  // @brief Gives the values of the variable with respect to precedent
646  // parents modality.
647  // If some parents have no modality set, then we apply values for all
648  // instantiations of that parent.
649  //
650  // This means you can declare a default value for the table by doing
651  // @code
652  // BayesNetFactory factory;
653  // // Do stuff
654  // factory.startVariableDeclaration();
655  // factory.variableName("foo");
656  // factory.endVariableDeclaration();
657  // factory.startParentsDeclaration("foo");
658  // // add parents
659  // factory.endParentsDeclaration();
660  // factory.startFactorizedProbabilityDeclaration("foo");
661  // std::vector<float> seq;
662  // seq.insert(0.4); // if foo true
663  // seq.insert(0.6); // if foo false
664  // factory.setVariableValues(seq); // fills the table with a default value
665  // // finish your stuff
666  // factory.endFactorizedProbabilityDeclaration();
667  // @endcode
668  // as for raw Probability, if value's size is different than the number of
669  // modalities of the current variable, we don't use the supplementary values and
670  // we fill by 0 the missing values.
671  template < typename GUM_SCALAR >
672  INLINE void
674  if (state() != factory_state::FACT_ENTRY) {
675  _illegalStateError_("setVariableValues");
676  } else {
679 
680  if (_parents_->domainSize() > 0) {
683  // Creating an instantiation containing all the variables not ins
684  // _parents_.
686  inst_default << var;
687 
688  for (auto node: _bn_->parents(varId)) {
690  }
691 
692  // Filling the variable's table.
694  (_bn_->cpt(varId))
695  .set(inst,
697  : (GUM_SCALAR)0);
698  }
699  } else {
702  var_inst << var;
703 
704  for (var_inst.setFirst(); !var_inst.end(); ++var_inst) {
706 
708  (_bn_->cpt(varId))
709  .set(inst,
711  : (GUM_SCALAR)0);
712  }
713  }
714  }
715  }
716  }
717 
718  template < typename GUM_SCALAR >
720  if (state() != factory_state::FACT_ENTRY) {
721  _illegalStateError_("setVariableValues");
722  } else {
724  // Checking consistency between values and var.
725 
726  if (values.size() != var.domainSize()) {
728  var.name() << " : invalid number of modalities: found " << values.size()
729  << " while needed " << var.domainSize())
730  }
731 
733  }
734  }
735 
736  // Tells the factory that we finished declaring a conditional probability
737  // table.
738  template < typename GUM_SCALAR >
740  if (state() != factory_state::FACT_CPT) {
741  _illegalStateError_("endFactorizedProbabilityDeclaration");
742  } else {
743  _resetParts_();
744  _states_.pop_back();
745  }
746  }
747 
748  // @brief Define a variable.
749  // You can only call this method if the factory is in the NONE or NETWORK
750  // state.
751  // The variable is added by copy.
752  // @param var The pointer over a DiscreteVariable used to define a new
753  // variable in the built BayesNet.
754  // @throw DuplicateElement Raised if a variable with the same name already
755  // exists.
756  // @throw OperationNotAllowed Raised if redefineParents == false and if table
757  // is not a valid CPT for var in the current state
758  // of the BayesNet.
759  template < typename GUM_SCALAR >
761  if ((state() != factory_state::NONE)) {
762  _illegalStateError_("setVariable");
763  } else {
764  try {
766  GUM_ERROR(DuplicateElement, "Name already used: " << var.name())
767  } catch (NotFound&) {
768  // The var name is unused
770  }
771  }
772  }
773 
774  // @brief Define a variable's CPT.
775  // You can only call this method if the factory is in the NONE or NETWORK
776  // state.
777  // Be careful that table is given to the built BayesNet, so it will be
778  // deleted with it, and you should not directly access it after you call
779  // this method.
780  // When the redefineParents flag is set to true the constructed BayesNet's
781  // DAG is changed to fit with table's definition.
782  // @param var The name of the concerned variable.
783  // @param table A pointer over the CPT used for var.
784  // @param redefineParents If true redefine parents of the variable to match
785  // table's
786  // variables set.
787  //
788  // @throw NotFound Raised if no variable matches var.
789  // @throw OperationNotAllowed Raised if redefineParents == false and if table
790  // is not a valid CPT for var in the current state
791  // of the BayesNet.
792  template < typename GUM_SCALAR >
795  bool redefineParents) {
796  auto pot = dynamic_cast< Potential< GUM_SCALAR >* >(table);
797 
798  if (state() != factory_state::NONE) {
799  _illegalStateError_("setVariableCPT");
800  } else {
804  // If we have to change the structure of the BayesNet, then we call a sub
805  // method.
806 
807  if (redefineParents) {
809  } else if (pot->contains(var)) {
810  for (auto node: _bn_->parents(varId)) {
811  if (!pot->contains(_bn_->variable(node))) {
812  GUM_ERROR(OperationNotAllowed, "The CPT is not valid in the current BayesNet.")
813  }
814  }
815 
816  // CPT are created when a variable is added.
818  }
819  }
820  }
821 
822  // Raise an OperationNotAllowed with the message "Illegal state call (<s>) in state <STATE>".
823  template < typename GUM_SCALAR >
825  std::string msg = "Illegal state call (";
826  msg += s;
827  msg += ") in state ";
828 
829  switch (state()) {
830  case factory_state::NONE: {
831  msg += "NONE";
832  break;
833  }
834 
835  case factory_state::NETWORK: {
836  msg += "NETWORK";
837  break;
838  }
839 
840  case factory_state::VARIABLE: {
841  msg += "VARIABLE";
842  break;
843  }
844 
845  case factory_state::PARENTS: {
846  msg += "PARENTS";
847  break;
848  }
849 
850  case factory_state::RAW_CPT: {
851  msg += "RAW_CPT";
852  break;
853  }
854 
855  case factory_state::FACT_CPT: {
856  msg += "FACT_CPT";
857  break;
858  }
859 
860  case factory_state::FACT_ENTRY: {
861  msg += "FACT_ENTRY";
862  break;
863  }
864 
865  default: {
866  msg += "Unknown state";
867  }
868  }
869 
871  }
872 
873  // Check if a variable with the given name exists, if not raise a NotFound
874  // exception.
875  template < typename GUM_SCALAR >
878  }
879 
880  // Check if var exists and if mod is one of its modalities, if not raise a
881  // NotFound exception.
882  template < typename GUM_SCALAR >
884  const std::string& mod) {
887 
888  for (Idx i = 0; i < var.domainSize(); ++i) {
889  if (mod == var.label(i)) { return i; }
890  }
891 
893  }
894 
895  // Check if in _stringBag_ there is no other modality with the same name.
896  template < typename GUM_SCALAR >
898  for (size_t i = 3; i < _stringBag_.size(); ++i) {
899  if (mod == _stringBag_[i]) { GUM_ERROR(DuplicateElement, "Label already used: " << mod) }
900  }
901  }
902 
903  // Sub method of setVariableCPT() which redefine the BayesNet's DAG with
904  // respect to table.
905  template < typename GUM_SCALAR >
907  Potential< GUM_SCALAR >* table) {
910 
911  for (auto v: table->variablesSequence()) {
912  if (v != (&var)) {
915  }
916  }
917 
918  // CPT are created when a variable is added.
920  }
921 
922  // Reset the different parts used to construct the BayesNet.
923  template < typename GUM_SCALAR >
925  _foo_flag_ = false;
926  _bar_flag_ = false;
927  _stringBag_.clear();
928  }
929 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643