aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
BayesNet_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Template implementation of BN/BayesNet.h class.
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6) and Lionel TORTI
27  */
28 
29 #include <limits>
30 #include <set>
31 
32 #include <agrum/BN/BayesNet.h>
33 
34 #include <agrum/tools/variables/rangeVariable.h>
35 #include <agrum/tools/variables/labelizedVariable.h>
36 #include <agrum/tools/variables/discretizedVariable.h>
37 
38 #include <agrum/tools/multidim/aggregators/amplitude.h>
39 #include <agrum/tools/multidim/aggregators/and.h>
40 #include <agrum/tools/multidim/aggregators/count.h>
41 #include <agrum/tools/multidim/aggregators/exists.h>
42 #include <agrum/tools/multidim/aggregators/forall.h>
43 #include <agrum/tools/multidim/aggregators/max.h>
44 #include <agrum/tools/multidim/aggregators/median.h>
45 #include <agrum/tools/multidim/aggregators/min.h>
46 #include <agrum/tools/multidim/aggregators/or.h>
47 #include <agrum/tools/multidim/aggregators/sum.h>
48 
49 #include <agrum/tools/multidim/ICIModels/multiDimNoisyAND.h>
50 #include <agrum/tools/multidim/ICIModels/multiDimNoisyORCompound.h>
51 #include <agrum/tools/multidim/ICIModels/multiDimNoisyORNet.h>
52 
53 #include <agrum/tools/multidim/ICIModels/multiDimLogit.h>
54 
55 #include <agrum/BN/generator/simpleCPTGenerator.h>
56 #include <agrum/tools/core/utils_string.h>
57 
58 namespace gum {
59  template < typename GUM_SCALAR >
61  std::string node,
63  std::string name = node;
64  auto ds = default_domain_size;
65  long range_min = 0;
66  long range_max = long(ds) - 1;
67  std::vector< std::string > labels;
69 
70  if (*(node.rbegin()) == ']') {
71  auto posBrack = node.find('[');
72  if (posBrack != std::string::npos) {
73  name = node.substr(0, posBrack);
74  const auto& s_args = node.substr(posBrack + 1, node.size() - posBrack - 2);
75  const auto& args = split(s_args, ",");
76  if (args.size() == 0) { // n[]
77  GUM_ERROR(InvalidArgument, "Empty range for variable " << node)
78  } else if (args.size() == 1) { // n[4]
79  ds = static_cast< Size >(std::stoi(args[0]));
80  range_min = 0;
81  range_max = long(ds) - 1;
82  } else if (args.size() == 2) { // n[5,10]
83  range_min = std::stol(args[0]);
84  range_max = std::stol(args[1]);
85  if (1 + range_max - range_min < 2) {
86  GUM_ERROR(InvalidArgument, "Invalid range for variable " << node);
87  }
88  ds = static_cast< Size >(1 + range_max - range_min);
89  } else { // n[3.14,5,10,12]
90  for (const auto& tick: args) {
91  ticks.push_back(static_cast< GUM_SCALAR >(std::atof(tick.c_str())));
92  }
93  ds = static_cast< Size >(args.size() - 1);
94  }
95  }
96  } else if (*(node.rbegin()) == '}') { // node like "n{one|two|three}"
97  auto posBrack = node.find('{');
98  if (posBrack != std::string::npos) {
99  name = node.substr(0, posBrack);
100  labels = split(node.substr(posBrack + 1, node.size() - posBrack - 2), "|");
101  if (labels.size() < 2) {
102  GUM_ERROR(InvalidArgument, "Not enough labels in node " << node);
103  }
104  if (!hasUniqueElts(labels)) {
105  GUM_ERROR(InvalidArgument, "Duplicate labels in node " << node);
106  }
107  ds = static_cast< Size >(labels.size());
108  }
109  }
110 
111  if (ds == 0) {
112  GUM_ERROR(InvalidArgument, "No value for variable " << name << ".");
113  } else if (ds == 1) {
115  "Only one value for variable " << name
116  << " (2 at least are needed).");
117  }
118 
119  // now we add the node in the BN
120  NodeId idVar;
121  try {
122  idVar = bn.idFromName(name);
123  } catch (gum::NotFound&) {
124  if (!labels.empty()) {
126  } else if (!ticks.empty()) {
128  } else {
130  }
131  }
132 
133  return idVar;
134  }
135 
136  template < typename GUM_SCALAR >
137  BayesNet< GUM_SCALAR >
138  BayesNet< GUM_SCALAR >::fastPrototype(const std::string& dotlike,
139  Size domainSize) {
141 
142 
143  for (const auto& chaine: split(dotlike, ";")) {
144  NodeId lastId = 0;
145  bool notfirst = false;
146  for (const auto& souschaine: split(chaine, "->")) {
147  bool forward = true;
148  for (const auto& node: split(souschaine, "<-")) {
149  auto idVar = build_node(bn, node, domainSize);
150  if (notfirst) {
151  if (forward) {
152  bn.addArc(lastId, idVar);
153  forward = false;
154  } else {
155  bn.addArc(idVar, lastId);
156  }
157  } else {
158  notfirst = true;
159  forward = false;
160  }
161  lastId = idVar;
162  }
163  }
164  }
165  bn.generateCPTs();
166  bn.setProperty("name", "fastPrototype");
167  return bn;
168  }
169 
170  template < typename GUM_SCALAR >
173  }
174 
175  template < typename GUM_SCALAR >
179  }
180 
181  template < typename GUM_SCALAR >
185 
187  }
188 
189  template < typename GUM_SCALAR >
190  BayesNet< GUM_SCALAR >&
192  if (this != &source) {
195 
198  }
199 
200  return *this;
201  }
202 
203  template < typename GUM_SCALAR >
206  for (const auto p: probaMap__) {
207  delete p.second;
208  }
209  }
210 
211  template < typename GUM_SCALAR >
212  INLINE const DiscreteVariable&
214  return varMap__.get(id);
215  }
216 
217  template < typename GUM_SCALAR >
218  INLINE void
220  const std::string& new_name) {
222  }
223 
224  template < typename GUM_SCALAR >
225  INLINE void
227  const std::string& old_label,
228  const std::string& new_label) {
229  if (variable(id).varType() != VarType::Labelized) {
230  GUM_ERROR(NotFound, "Variable " << id << " is not a LabelizedVariable.");
231  }
232  LabelizedVariable* var = dynamic_cast< LabelizedVariable* >(
233  const_cast< DiscreteVariable* >(&variable(id)));
234 
236  }
237 
238 
239  template < typename GUM_SCALAR >
241  return varMap__.get(var);
242  }
243 
244  template < typename GUM_SCALAR >
246  auto ptr = new MultiDimArray< GUM_SCALAR >();
247  NodeId res = 0;
248 
249  try {
250  res = add(var, ptr);
251 
252  } catch (Exception&) {
253  delete ptr;
254  throw;
255  }
256 
257  return res;
258  }
259 
260  template < typename GUM_SCALAR >
262  unsigned int nbrmod) {
263  if (nbrmod < 2) {
265  "Variable " << name << "needs more than " << nbrmod
266  << " modalities");
267  }
268 
269  RangeVariable v(name, name, 0, nbrmod - 1);
270  return add(v);
271  }
272 
273  template < typename GUM_SCALAR >
274  INLINE NodeId
278  NodeId res = 0;
279 
281 
282  return res;
283  }
284 
285  template < typename GUM_SCALAR >
287  NodeId id) {
288  auto ptr = new MultiDimArray< GUM_SCALAR >();
289  NodeId res = 0;
290 
291  try {
292  res = add(var, ptr, id);
293 
294  } catch (Exception&) {
295  delete ptr;
296  throw;
297  }
298 
299  return res;
300  }
301 
302  template < typename GUM_SCALAR >
303  NodeId
306  NodeId id) {
307  varMap__.insert(id, var);
308  this->dag_.addNodeWithId(id);
309 
310  auto cpt = new Potential< GUM_SCALAR >(aContent);
311  (*cpt) << variable(id);
313  return id;
314  }
315 
316  template < typename GUM_SCALAR >
318  return varMap__.idFromName(name);
319  }
320 
321  template < typename GUM_SCALAR >
322  INLINE const DiscreteVariable&
325  }
326 
327  template < typename GUM_SCALAR >
328  INLINE const Potential< GUM_SCALAR >&
330  return *(probaMap__[varId]);
331  }
332 
333  template < typename GUM_SCALAR >
335  return varMap__;
336  }
337 
338  template < typename GUM_SCALAR >
340  erase(varMap__.get(var));
341  }
342 
343  template < typename GUM_SCALAR >
345  if (varMap__.exists(varId)) {
346  // Reduce the variable child's CPT
347  const NodeSet& children = this->children(varId);
348 
349  for (const auto c: children) {
351  }
352 
353  delete probaMap__[varId];
354 
357  this->dag_.eraseNode(varId);
358  }
359  }
360 
361  template < typename GUM_SCALAR >
362  void BayesNet< GUM_SCALAR >::clear() {
363  if (!this->empty()) {
364  auto l = this->nodes();
365  for (const auto no: l) {
366  this->erase(no);
367  }
368  }
369  }
370 
371  template < typename GUM_SCALAR >
373  if (this->dag_.existsArc(tail, head)) {
375  "The arc (" << tail << "," << head << ") already exists.")
376  }
377 
378  this->dag_.addArc(tail, head);
379  // Add parent in the child's CPT
380  (*(probaMap__[head])) << variable(tail);
381  }
382 
383  template < typename GUM_SCALAR >
385  const std::string& head) {
386  try {
387  addArc(this->idFromName(tail), this->idFromName(head));
388  } catch (DuplicateElement) {
390  "The arc " << tail << "->" << head << " already exists.")
391  }
392  }
393 
394  template < typename GUM_SCALAR >
396  if (varMap__.exists(arc.tail()) && varMap__.exists(arc.head())) {
397  NodeId head = arc.head(), tail = arc.tail();
398  this->dag_.eraseArc(arc);
399  // Remove parent froms child's CPT
400  (*(probaMap__[head])) >> variable(tail);
401  }
402  }
403 
404  template < typename GUM_SCALAR >
406  eraseArc(Arc(tail, head));
407  }
408 
409  template < typename GUM_SCALAR >
410  void BayesNet< GUM_SCALAR >::reverseArc(const Arc& arc) {
411  // check that the arc exsists
412  if (!varMap__.exists(arc.tail()) || !varMap__.exists(arc.head())
413  || !dag().existsArc(arc)) {
414  GUM_ERROR(InvalidArc, "a nonexisting arc cannot be reversed");
415  }
416 
417  NodeId tail = arc.tail(), head = arc.head();
418 
419  // check that the reversal does not induce a cycle
420  try {
421  DAG d = dag();
422  d.eraseArc(arc);
423  d.addArc(head, tail);
424  } catch (Exception&) {
425  GUM_ERROR(InvalidArc, "this arc reversal would induce a directed cycle");
426  }
427 
428  // with the same notations as Shachter (1986), "evaluating influence
429  // diagrams",
430  // p.878, we shall first compute the product of probabilities:
431  // pi_j^old (x_j | x_c^old(j) ) * pi_i^old (x_i | x_c^old(i) )
433 
434  // modify the topology of the graph: add to tail all the parents of head
435  // and add to head all the parents of tail
438  for (const auto node: this->parents(tail))
440  for (const auto node: this->parents(head))
442  // remove arc (head, tail)
443  eraseArc(arc);
444 
445  // add the necessary arcs to the tail
446  for (const auto p: new_parents) {
447  if ((p != tail) && !dag().existsArc(p, tail)) { addArc(p, tail); }
448  }
449 
450  addArc(head, tail);
451  // add the necessary arcs to the head
453 
454  for (const auto p: new_parents) {
455  if ((p != head) && !dag().existsArc(p, head)) { addArc(p, head); }
456  }
457 
459 
460  // update the conditional distributions of head and tail
461  Set< const DiscreteVariable* > del_vars;
462  del_vars << &(variable(tail));
465 
466  auto& cpt_head = const_cast< Potential< GUM_SCALAR >& >(cpt(head));
468 
471  auto& cpt_tail = const_cast< Potential< GUM_SCALAR >& >(cpt(tail));
473  }
474 
475  template < typename GUM_SCALAR >
477  reverseArc(Arc(tail, head));
478  }
479 
480 
481  //==============================================
482  // Aggregators
483  //=============================================
484  template < typename GUM_SCALAR >
486  return add(var, new aggregator::Amplitude< GUM_SCALAR >());
487  }
488 
489  template < typename GUM_SCALAR >
491  if (var.domainSize() > 2) GUM_ERROR(SizeError, "an AND has to be boolean");
492 
493  return add(var, new aggregator::And< GUM_SCALAR >());
494  }
495 
496  template < typename GUM_SCALAR >
498  Idx value) {
499  return add(var, new aggregator::Count< GUM_SCALAR >(value));
500  }
501 
502  template < typename GUM_SCALAR >
504  Idx value) {
505  if (var.domainSize() > 2) GUM_ERROR(SizeError, "an EXISTS has to be boolean");
506 
507  return add(var, new aggregator::Exists< GUM_SCALAR >(value));
508  }
509 
510  template < typename GUM_SCALAR >
512  Idx value) {
513  if (var.domainSize() > 2) GUM_ERROR(SizeError, "an EXISTS has to be boolean");
514 
515  return add(var, new aggregator::Forall< GUM_SCALAR >(value));
516  }
517 
518  template < typename GUM_SCALAR >
520  return add(var, new aggregator::Max< GUM_SCALAR >());
521  }
522 
523  template < typename GUM_SCALAR >
525  return add(var, new aggregator::Median< GUM_SCALAR >());
526  }
527 
528  template < typename GUM_SCALAR >
530  return add(var, new aggregator::Min< GUM_SCALAR >());
531  }
532 
533  template < typename GUM_SCALAR >
535  if (var.domainSize() > 2) GUM_ERROR(SizeError, "an OR has to be boolean");
536 
537  return add(var, new aggregator::Or< GUM_SCALAR >());
538  }
539 
540  template < typename GUM_SCALAR >
542  return add(var, new aggregator::Sum< GUM_SCALAR >());
543  }
544 
545  //================================
546  // ICIModels
547  //================================
548  template < typename GUM_SCALAR >
552  }
553 
554  template < typename GUM_SCALAR >
555  INLINE NodeId
559  }
560 
561  template < typename GUM_SCALAR >
565  }
566 
567  template < typename GUM_SCALAR >
571  }
572 
573  template < typename GUM_SCALAR >
577  }
578 
579  template < typename GUM_SCALAR >
582  NodeId id) {
584  }
585 
586  template < typename GUM_SCALAR >
589  NodeId id) {
591  }
592 
593  template < typename GUM_SCALAR >
596  NodeId id) {
597  return add(var, new MultiDimLogit< GUM_SCALAR >(external_weight), id);
598  }
599 
600  template < typename GUM_SCALAR >
601  INLINE NodeId
604  NodeId id) {
605  return add(var,
607  id);
608  }
609 
610  template < typename GUM_SCALAR >
613  NodeId id) {
615  }
616 
617  template < typename GUM_SCALAR >
619  NodeId head,
621  auto* CImodel = dynamic_cast< const MultiDimICIModel< GUM_SCALAR >* >(
622  cpt(head).content());
623 
624  if (CImodel != 0) {
625  // or is OK
626  addArc(tail, head);
627 
629  } else {
631  "Head variable (" << variable(head).name()
632  << ") is not a CIModel variable !");
633  }
634  }
635 
636  template < typename GUM_SCALAR >
638  const BayesNet< GUM_SCALAR >& bn) {
639  output << bn.toString();
640  return output;
641  }
642 
643  /// begin Multiple Change for all CPTs
644  template < typename GUM_SCALAR >
646  for (const auto node: nodes())
648  }
649 
650  /// end Multiple Change for all CPTs
651  template < typename GUM_SCALAR >
653  for (const auto node: nodes())
655  }
656 
657  /// clear all potentials
658  template < typename GUM_SCALAR >
660  // Removing previous potentials
661  for (const auto& elt: probaMap__) {
662  delete elt.second;
663  }
664 
665  probaMap__.clear();
666  }
667 
668  /// copy of potentials from a BN to another, using names of vars as ref.
669  template < typename GUM_SCALAR >
671  const BayesNet< GUM_SCALAR >& source) {
672  // Copying potentials
673 
674  for (const auto src: source.probaMap__) {
675  // First we build the node's CPT
678  for (gum::Idx i = 0; i < src.second->nbrDim(); i++) {
680  }
683 
684  // We add the CPT to the CPT's hashmap
686  }
687  }
688 
689  template < typename GUM_SCALAR >
691  for (const auto node: nodes())
692  generateCPT(node);
693  }
694 
695  template < typename GUM_SCALAR >
698 
700  }
701 
702  template < typename GUM_SCALAR >
705  if (cpt(id).nbrDim() != newPot->nbrDim()) {
707  "cannot exchange potentials with different "
708  "dimensions for variable with id "
709  << id);
710  }
711 
712  for (Idx i = 0; i < cpt(id).nbrDim(); i++) {
713  if (&cpt(id).variable(i) != &(newPot->variable(i))) {
715  "cannot exchange potentials because, for variable with id "
716  << id << ", dimension " << i << " differs. ");
717  }
718  }
719 
721  }
722 
723  template < typename GUM_SCALAR >
725  NodeId id,
727  delete probaMap__[id];
728  probaMap__[id] = newPot;
729  }
730 
731  template < typename GUM_SCALAR >
735  }
736 
737 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
NodeId build_node(gum::BayesNet< GUM_SCALAR > &bn, std::string node, gum::Size default_domain_size)
Definition: BayesNet_tpl.h:60