aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
Miic.cpp
Go to the documentation of this file.
1 
2 /**
3  *
4  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe
5  * GONZALES(@AMU) info_at_agrum_dot_org
6  *
7  * This library is free software: you can redistribute it and/or modify
8  * it under the terms of the GNU Lesser General Public License as published by
9  * the Free Software Foundation, either version 3 of the License, or
10  * (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15  * GNU Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public License
18  * along with this library. If not, see <http://www.gnu.org/licenses/>.
19  *
20  */
21 
22 
23 /** @file
24  * @brief Implementation of gum::learning::ThreeOffTwo and MIIC
25  *
26  * @author Quentin FALCAND, Marvin LASSERRE and Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 
29 #include <agrum/tools/core/math/math_utils.h>
30 #include <agrum/tools/core/hashTable.h>
31 #include <agrum/tools/core/heap.h>
32 #include <agrum/tools/core/timer.h>
33 #include <agrum/tools/graphs/mixedGraph.h>
34 #include <agrum/BN/learning/Miic.h>
35 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
36 #include <agrum/tools/stattests/correctedMutualInformation.h>
37 
38 
39 namespace gum {
40 
41  namespace learning {
42 
43  /// default constructor
45 
46  /// default constructor with maxLog
48 
49  /// copy constructor
52  }
53 
54  /// move constructor
57  }
58 
59  /// destructor
61 
62  /// copy operator
63  Miic& Miic::operator=(const Miic& from) {
65  return *this;
66  }
67 
68  /// move operator
71  return *this;
72  }
73 
74 
75  bool GreaterPairOn2nd::operator()(const CondRanking& e1, const CondRanking& e2) const {
76  return e1.second > e2.second;
77  }
78 
79  bool GreaterAbsPairOn2nd::operator()(const Ranking& e1, const Ranking& e2) const {
80  return std::abs(e1.second) > std::abs(e2.second);
81  }
82 
84  const ProbabilisticRanking& e2) const {
85  double p1xz = std::get< 2 >(e1);
86  double p1yz = std::get< 3 >(e1);
87  double p2xz = std::get< 2 >(e2);
88  double p2yz = std::get< 3 >(e2);
89  double I1 = std::get< 1 >(e1);
90  double I2 = std::get< 1 >(e2);
91  // First, we look at the sign of information.
92  // Then, the probability values
93  // and finally the abs value of information.
94  if ((I1 < 0 && I2 < 0) || (I1 >= 0 && I2 >= 0)) {
95  if (std::max(p1xz, p1yz) == std::max(p2xz, p2yz)) {
96  return std::abs(I1) > std::abs(I2);
97  } else {
98  return std::max(p1xz, p1yz) > std::max(p2xz, p2yz);
99  }
100  } else {
101  return I1 < I2;
102  }
103  }
104 
105  /// learns the structure of a MixedGraph
107  MixedGraph graph) {
108  timer_.reset();
109  current_step_ = 0;
110 
111  // clear the vector of latent arcs to be sure
113 
114  /// the heap of ranks, with the score, and the NodeIds of x, y and z.
116 
117  /// the variables separation sets
119 
121 
123 
124  // std::cout << "Le graphe contient: " << graph.sizeEdges() << " edges." <<
125  // std::endl; std::cout << "En voici la liste: " << graph.edges() <<
126  // std::endl;
127 
128  if (_useMiic_) {
130  } else {
132  }
133 
134  return graph;
135  }
136 
137  /*
138  * PHASE 1 : INITIATION
139  *
140  * We go over all edges and test if the variables are independent. If they
141  * are,
142  * the edge is deleted. If not, the best contributor is found.
143  */
145  MixedGraph& graph,
148  NodeId x, y;
149  EdgeSet edges = graph.edges();
151 
152  for (const Edge& edge: edges) {
153  x = edge.first();
154  y = edge.second();
155  double Ixy = mutualInformation.score(x, y);
156 
157  if (Ixy <= 0) { //< K
160  } else {
162  }
163 
164  ++current_step_;
165  if (onProgress.hasListener()) {
167  }
168  }
169  }
170 
171  /*
172  * PHASE 2 : ITERATION
173  *
174  * As long as we find important nodes for edges, we go over them to see if
175  * we can assess the independence of the variables.
176  */
178  MixedGraph& graph,
181  // if no triples to further examine pass
183 
185  Size steps_iter = rank.size();
186 
187  try {
188  while (rank.top().second > 0.5) {
189  best = rank.pop();
190 
191  const NodeId x = std::get< 0 >(*(best.first));
192  const NodeId y = std::get< 1 >(*(best.first));
193  const NodeId z = std::get< 2 >(*(best.first));
194  std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
195 
196  ui.push_back(z);
197  const double i_xy_ui = mutualInformation.score(x, y, ui);
198  if (i_xy_ui < 0) {
199  graph.eraseEdge(Edge(x, y));
201  } else {
203  }
204 
205  delete best.first;
206 
207  ++current_step_;
208  if (onProgress.hasListener()) {
210  (current_step_ * 66) / (steps_init + steps_iter),
211  0.,
212  timer_.step());
213  }
214  }
215  } catch (...) {} // here, rank is empty
217  if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 66, 0., timer_.step()); }
219  }
220 
221  /*
222  * PHASE 3 : ORIENTATION
223  *
224  * Try to assess v-structures and propagate them.
225  */
228  MixedGraph& graph,
229  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
233 
234  // marks always correspond to the head of the arc/edge. - is for a forbidden
235  // arc, > for a mandatory arc
236  // we start by adding the mandatory arcs
237  for (auto iter = _initialMarks_.begin(); iter != _initialMarks_.end(); ++iter) {
238  if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') {
241  }
242  }
243 
244  NodeId i = 0;
245  // list of elements that we shouldn't read again, ie elements that are
246  // eligible to
247  // rule 0 after the first time they are tested, and elements on which rule 1
248  // has been applied
249  while (i < triples.size()) {
250  // if i not in do_not_reread
251  Ranking triple = triples[i];
252  NodeId x, y, z;
253  x = std::get< 0 >(*triple.first);
254  y = std::get< 1 >(*triple.first);
255  z = std::get< 2 >(*triple.first);
256 
257  std::vector< NodeId > ui;
258  std::pair< NodeId, NodeId > key = {x, y};
259  std::pair< NodeId, NodeId > rev_key = {y, x};
260  if (sepSet.exists(key)) {
261  ui = sepSet[key];
262  } else if (sepSet.exists(rev_key)) {
263  ui = sepSet[rev_key];
264  }
265  double Ixyz_ui = triple.second;
266  bool reset{false};
267  // try Rule 0
268  if (Ixyz_ui < 0) {
269  // if ( z not in Sep[x,y])
270  if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
271  if (!graph.existsArc(x, z) && !graph.existsArc(z, x)) {
272  // when we try to add an arc to the graph, we always verify if
273  // we are allowed to do so, ie it is not a forbidden arc and it
274  // does not create a cycle
275  if (!_existsDirectedPath_(graph, z, x) && !isForbidenArc_(x, z)) {
276  reset = true;
277  graph.eraseEdge(Edge(x, z));
278  graph.addArc(x, z);
279  } else if (_existsDirectedPath_(graph, z, x) && !isForbidenArc_(z, x)) {
280  reset = true;
281  graph.eraseEdge(Edge(x, z));
282  // if we find a cycle, we force the competing edge
283  graph.addArc(z, x);
285  == _latentCouples_.end()) {
287  }
288  }
289  } else if (!graph.existsArc(y, z) && !graph.existsArc(z, y)) {
290  if (!_existsDirectedPath_(graph, z, y) && !isForbidenArc_(x, z)) {
291  reset = true;
292  graph.eraseEdge(Edge(y, z));
293  graph.addArc(y, z);
294  } else if (_existsDirectedPath_(graph, z, y) && !isForbidenArc_(z, y)) {
295  reset = true;
296  graph.eraseEdge(Edge(y, z));
297  // if we find a cycle, we force the competing edge
298  graph.addArc(z, y);
300  == _latentCouples_.end()) {
302  }
303  }
304  } else {
305  // checking if the anti-directed arc already exists, to register a
306  // latent variable
307  if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
309  }
310  if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
312  }
313  }
314  }
315  } else { // try Rule 1
316  if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) {
317  if (!_existsDirectedPath_(graph, y, z) && !isForbidenArc_(z, y)) {
318  reset = true;
319  graph.eraseEdge(Edge(z, y));
320  graph.addArc(z, y);
321  } else if (_existsDirectedPath_(graph, y, z) && !isForbidenArc_(y, z)) {
322  reset = true;
323  graph.eraseEdge(Edge(z, y));
324  // if we find a cycle, we force the competing edge
325  graph.addArc(y, z);
327  == _latentCouples_.end()) {
329  }
330  }
331  }
332  if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) {
333  if (!_existsDirectedPath_(graph, x, z) && !isForbidenArc_(z, x)) {
334  reset = true;
335  graph.eraseEdge(Edge(z, x));
336  graph.addArc(z, x);
337  } else if (_existsDirectedPath_(graph, x, z) && !isForbidenArc_(x, z)) {
338  reset = true;
339  graph.eraseEdge(Edge(z, x));
340  // if we find a cycle, we force the competing edge
341  graph.addArc(x, z);
343  == _latentCouples_.end()) {
345  }
346  }
347  }
348  } // if rule 0 or rule 1
349 
350  // if what we want to add already exists : pass to the next triplet
351  if (reset) {
352  i = 0;
353  } else {
354  ++i;
355  }
356  if (onProgress.hasListener()) {
358  ((current_step_ + i) * 100) / (past_steps + steps_orient),
359  0.,
360  timer_.step());
361  }
362  } // while
363 
364  // erasing the double-headed arcs
365  for (const Arc& arc: _latentCouples_) {
366  graph.eraseArc(Arc(arc.head(), arc.tail()));
367  }
368  }
369 
370  /// variant trying to propagate both orientations in a bidirected arc
373  MixedGraph& graph,
374  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
378 
379  NodeId i = 0;
380  // list of elements that we shouldn't read again, ie elements that are
381  // eligible to
382  // rule 0 after the first time they are tested, and elements on which rule 1
383  // has been applied
384  while (i < triples.size()) {
385  // if i not in do_not_reread
386  Ranking triple = triples[i];
387  NodeId x, y, z;
388  x = std::get< 0 >(*triple.first);
389  y = std::get< 1 >(*triple.first);
390  z = std::get< 2 >(*triple.first);
391 
392  std::vector< NodeId > ui;
393  std::pair< NodeId, NodeId > key = {x, y};
394  std::pair< NodeId, NodeId > rev_key = {y, x};
395  if (sepSet.exists(key)) {
396  ui = sepSet[key];
397  } else if (sepSet.exists(rev_key)) {
398  ui = sepSet[rev_key];
399  }
400  double Ixyz_ui = triple.second;
401  // try Rule 0
402  if (Ixyz_ui < 0) {
403  // if ( z not in Sep[x,y])
404  if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
405  // if what we want to add already exists : pass
406  if ((graph.existsArc(x, z) || graph.existsArc(z, x))
407  && (graph.existsArc(y, z) || graph.existsArc(z, y))) {
408  ++i;
409  } else {
410  i = 0;
411  graph.eraseEdge(Edge(x, z));
412  graph.eraseEdge(Edge(y, z));
413  // checking for cycles
414  if (graph.existsArc(z, x)) {
415  graph.eraseArc(Arc(z, x));
416  try {
418  // if we find a cycle, we force the competing edge
420  } catch (gum::NotFound) { graph.addArc(x, z); }
421  graph.addArc(z, x);
422  } else {
423  try {
425  // if we find a cycle, we force the competing edge
426  graph.addArc(z, x);
428  } catch (gum::NotFound) { graph.addArc(x, z); }
429  }
430  if (graph.existsArc(z, y)) {
431  graph.eraseArc(Arc(z, y));
432  try {
434  // if we find a cycle, we force the competing edge
436  } catch (gum::NotFound) { graph.addArc(y, z); }
437  graph.addArc(z, y);
438  } else {
439  try {
441  // if we find a cycle, we force the competing edge
442  graph.addArc(z, y);
444 
445  } catch (gum::NotFound) { graph.addArc(y, z); }
446  }
447  if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
449  }
450  if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
452  }
453  }
454  } else {
455  ++i;
456  }
457  } else { // try Rule 1
458  bool reset{false};
459  if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) {
460  reset = true;
461  graph.eraseEdge(Edge(z, y));
462  try {
464  // if we find a cycle, we force the competing edge
465  graph.addArc(y, z);
467  } catch (gum::NotFound) { graph.addArc(z, y); }
468  }
469  if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) {
470  reset = true;
471  graph.eraseEdge(Edge(z, x));
472  try {
474  // if we find a cycle, we force the competing edge
475  graph.addArc(x, z);
477  } catch (gum::NotFound) { graph.addArc(z, x); }
478  }
479 
480  if (reset) {
481  i = 0;
482  } else {
483  ++i;
484  }
485  } // if rule 0 or rule 1
486  if (onProgress.hasListener()) {
488  ((current_step_ + i) * 100) / (past_steps + steps_orient),
489  0.,
490  timer_.step());
491  }
492  } // while
493 
494  // erasing the double-headed arcs
495  for (const Arc& arc: _latentCouples_) {
496  graph.eraseArc(Arc(arc.head(), arc.tail()));
497  }
498  }
499 
500  /// variant using the orientation protocol of MIIC
503  MixedGraph& graph,
504  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
505  // structure to store the orientations marks -, o, or >,
506  // Considers the head of the arc/edge first node -* second node
508 
509  // marks always correspond to the head of the arc/edge. - is for a forbidden
510  // arc, > for a mandatory arc
511  // we start by adding the mandatory arcs
512  for (auto iter = marks.begin(); iter != marks.end(); ++iter) {
513  if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') {
516  }
517  }
518 
521 
524 
526  if (steps_orient > 0) { best = proba_triples[0]; }
527 
528  while (!proba_triples.empty() && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) {
529  const NodeId x = std::get< 0 >(*std::get< 0 >(best));
530  const NodeId y = std::get< 1 >(*std::get< 0 >(best));
531  const NodeId z = std::get< 2 >(*std::get< 0 >(best));
532 
533  const double i3 = std::get< 1 >(best);
534 
535  const double p1 = std::get< 2 >(best);
536  const double p2 = std::get< 3 >(best);
537  if (i3 <= 0) {
539  } else {
541  }
542 
543  delete std::get< 0 >(best);
545  // actualisation of the list of triples
547 
548  if (!proba_triples.empty()) best = proba_triples[0];
549 
550  ++current_step_;
551  if (onProgress.hasListener()) {
553  (current_step_ * 100) / (steps_orient + past_steps),
554  0.,
555  timer_.step());
556  }
557  } // while
558 
559  // erasing the double headed arcs
560  for (auto iter = _latentCouples_.rbegin(); iter != _latentCouples_.rend(); ++iter) {
561  graph.eraseArc(Arc(iter->head(), iter->tail()));
562  if (_existsDirectedPath_(graph, iter->head(), iter->tail())) {
563  // if we find a cycle, we force the competing edge
564  graph.addArc(iter->head(), iter->tail());
565  graph.eraseArc(Arc(iter->tail(), iter->head()));
566  *iter = Arc(iter->head(), iter->tail());
567  }
568  }
569 
570  if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 100, 0., timer_.step()); }
571  }
572 
573  /// finds the best contributor node for a pair given a conditioning set
575  NodeId y,
576  const std::vector< NodeId >& ui,
577  const MixedGraph& graph,
580  double maxP = -1.0;
581  NodeId maxZ = 0;
582 
583  // compute N
584  // __N = I.N();
585  const double Ixy_ui = mutualInformation.score(x, y, ui);
586 
587  for (const NodeId z: graph) {
588  // if z!=x and z!=y and z not in ui
589  if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
590  double Pnv;
591  double Pb;
592 
593  // Computing Pnv
594  const double Ixyz_ui = mutualInformation.score(x, y, z, ui);
595  double calc_expo1 = -Ixyz_ui * M_LN2;
596  // if exponentials are too high or too low, crop them at |_maxLog_|
597  if (calc_expo1 > _maxLog_) {
598  Pnv = 0.0;
599  } else if (calc_expo1 < -_maxLog_) {
600  Pnv = 1.0;
601  } else {
602  Pnv = 1 / (1 + std::exp(calc_expo1));
603  }
604 
605  // Computing Pb
606  const double Ixz_ui = mutualInformation.score(x, z, ui);
607  const double Iyz_ui = mutualInformation.score(y, z, ui);
608 
609  calc_expo1 = -(Ixz_ui - Ixy_ui) * M_LN2;
610  double calc_expo2 = -(Iyz_ui - Ixy_ui) * M_LN2;
611 
612  // if exponentials are too high or too low, crop them at |_maxLog_|
613  if (calc_expo1 > _maxLog_ || calc_expo2 > _maxLog_) {
614  Pb = 0.0;
615  } else if (calc_expo1 < -_maxLog_ && calc_expo2 < -_maxLog_) {
616  Pb = 1.0;
617  } else {
618  double expo1, expo2;
619  if (calc_expo1 < -_maxLog_) {
620  expo1 = 0.0;
621  } else {
622  expo1 = std::exp(calc_expo1);
623  }
624  if (calc_expo2 < -_maxLog_) {
625  expo2 = 0.0;
626  } else {
627  expo2 = std::exp(calc_expo2);
628  }
629  Pb = 1 / (1 + expo1 + expo2);
630  }
631 
632  // Getting max(min(Pnv, pb))
633  const double min_pnv_pb = std::min(Pnv, Pb);
634  if (min_pnv_pb > maxP) {
635  maxP = min_pnv_pb;
636  maxZ = z;
637  }
638  } // if z not in (x, y)
639  } // for z in graph.nodes
640  // storing best z in rank_
642  auto tup = new CondThreePoints{x, y, maxZ, ui};
643  final.first = tup;
644  final.second = maxP;
645  rank.insert(final);
646  }
647 
648  /// gets the list of unshielded triples in the graph in decreasing value of
649  ///|I'(x, y, z|{ui})|
651  const MixedGraph& graph,
653  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
655  for (NodeId z: graph) {
656  for (NodeId x: graph.neighbours(z)) {
657  for (NodeId y: graph.neighbours(z)) {
658  if (y < x && !graph.existsEdge(x, y)) {
659  std::vector< NodeId > ui;
660  std::pair< NodeId, NodeId > key = {x, y};
661  std::pair< NodeId, NodeId > rev_key = {y, x};
662  if (sepSet.exists(key)) {
663  ui = sepSet[key];
664  } else if (sepSet.exists(rev_key)) {
665  ui = sepSet[rev_key];
666  }
667  // remove z from ui if it's present
668  const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
669  if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
670 
671  double Ixyz_ui = mutualInformation.score(x, y, z, ui);
672  Ranking triple;
673  auto tup = new ThreePoints{x, y, z};
674  triple.first = tup;
677  }
678  }
679  }
680  }
682  return triples;
683  }
684 
685  /// gets the list of unshielded triples in the graph in decreasing value of
686  ///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC
688  const MixedGraph& graph,
690  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
691  HashTable< std::pair< NodeId, NodeId >, char >& marks) {
693  for (NodeId z: graph) {
694  for (NodeId x: graph.neighbours(z)) {
695  for (NodeId y: graph.neighbours(z)) {
696  if (y < x && !graph.existsEdge(x, y)) {
697  std::vector< NodeId > ui;
698  std::pair< NodeId, NodeId > key = {x, y};
699  std::pair< NodeId, NodeId > rev_key = {y, x};
700  if (sepSet.exists(key)) {
701  ui = sepSet[key];
702  } else if (sepSet.exists(rev_key)) {
703  ui = sepSet[rev_key];
704  }
705  // remove z from ui if it's present
706  const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
707  if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
708 
709  const double Ixyz_ui = mutualInformation.score(x, y, z, ui);
710  auto tup = new ThreePoints{x, y, z};
713  if (!marks.exists({x, z})) { marks.insert({x, z}, 'o'); }
714  if (!marks.exists({z, x})) { marks.insert({z, x}, 'o'); }
715  if (!marks.exists({y, z})) { marks.insert({y, z}, 'o'); }
716  if (!marks.exists({z, y})) { marks.insert({z, y}, 'o'); }
717  }
718  }
719  }
720  }
723  return triples;
724  }
725 
726  /// Gets the orientation probabilities like MIIC for the orientation phase
730  for (auto& triple: probaTriples) {
731  NodeId x, y, z;
732  x = std::get< 0 >(*std::get< 0 >(triple));
733  y = std::get< 1 >(*std::get< 0 >(triple));
734  z = std::get< 2 >(*std::get< 0 >(triple));
735  const double Ixyz = std::get< 1 >(triple);
736  double Pxz = std::get< 2 >(triple);
737  double Pyz = std::get< 3 >(triple);
738 
739  if (Ixyz <= 0) {
740  const double expo = std::exp(Ixyz);
741  const double P0 = (1 + expo) / (1 + 3 * expo);
742  // distinguish between the initialization and the update process
743  if (Pxz == Pyz && Pyz == 0.5) {
744  std::get< 2 >(triple) = P0;
745  std::get< 3 >(triple) = P0;
746  } else {
747  if (graph.existsArc(x, z) && Pxz >= P0) {
748  std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
749  } else if (graph.existsArc(y, z) && Pyz >= P0) {
750  std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
751  }
752  }
753  } else {
754  const double expo = std::exp(-Ixyz);
755  if (graph.existsArc(x, z) && Pxz >= 0.5) {
756  std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
757  } else if (graph.existsArc(y, z) && Pyz >= 0.5) {
758  std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
759  }
760  }
761  }
763  return probaTriples;
764  }
765 
766  /// learns the structure of a Bayesian network, ie a DAG, from an Essential
767  /// graph.
770  // orientate remaining edges
771 
773 
774  // first, forbidden arcs force arc in the other direction
775  for (NodeId x: order) {
776  const auto nei_x = essentialGraph.neighbours(x);
777  for (NodeId y: nei_x)
778  if (isForbidenArc_(x, y)) {
780  if (isForbidenArc_(y, x)) {
781  GUM_TRACE("Neither arc allowed for edge (" << x << "," << y << ")")
782  } else {
783  GUM_TRACE("Forced orientation : " << y << "->" << x)
785  }
786  } else if (isForbidenArc_(y, x)) {
788  GUM_TRACE("Forced orientation : " << x << "->" << y)
790  }
791  }
793 
794  // first, propagate existing orientations
795  bool newOrientation = true;
796  while (newOrientation) {
797  newOrientation = false;
798  for (NodeId x: order) {
799  if (!essentialGraph.parents(x).empty()) {
801  }
802  }
803  }
807 
808  // then decide the orientation for double arcs
809  for (NodeId x: order)
812  GUM_TRACE(" + Resolving double arcs (poorly)")
814  }
815 
816  // std::cout << "Le mixed graph après une deuxième propagation mesdames et
817  // messieurs: "
818  //<< essentialGraph.toDot() << std::endl;
819  // std::cout << "Le graphe contient maintenant : " <<
820  // essentialGraph.sizeArcs() << " arcs."
821  //<< std::endl;
822  // std::cout << "Que voici: " << essentialGraph.arcs() << std::endl;
823  // turn the mixed graph into a dag
824  DAG dag;
825  for (auto node: essentialGraph) {
827  }
828  for (const Arc& arc: essentialGraph.arcs()) {
829  dag.addArc(arc.tail(), arc.head());
830  }
831 
832  return dag;
833  }
834 
836  // no cycle
837  if (_existsDirectedPath_(graph, xj, xi)) {
838  GUM_TRACE("cycle(" << xi << "-" << xj << ")")
839  return false;
840  }
841 
842  // R1
843  if (!(graph.parents(xi) - graph.adjacents(xj)).empty()) {
844  GUM_TRACE("R1(" << xi << "-" << xj << ")")
845  return true;
846  }
847 
848  // R2
849  if (_existsDirectedPath_(graph, xi, xj)) {
850  GUM_TRACE("R2(" << xi << "-" << xj << ")")
851  return true;
852  }
853 
854  // R3
855  int nbr = 0;
856  for (const auto p: graph.parents(xj)) {
857  if (!graph.mixedOrientedPath(xi, p).empty()) {
858  nbr += 1;
859  if (nbr == 2) {
860  GUM_TRACE("R3(" << xi << "-" << xj << ")")
861  return true;
862  }
863  }
864  }
865  return false;
866  }
867 
869  // then decide the orientation for remaining edges
870  while (!essentialGraph.edges().empty()) {
871  const auto& edge = *(essentialGraph.edges().begin());
872  NodeId root = edge.first();
875  NodeSet stack{root};
876  // check the best root for the set of neighbours
877  while (!stack.empty()) {
878  NodeId next = *(stack.begin());
879  stack.erase(next);
880  if (visited.contains(next)) continue;
883  root = next;
884  }
885  for (const auto n: essentialGraph.neighbours(next))
886  if (!stack.contains(n) && !visited.contains(n)) stack.insert(n);
888  }
889  // orientation now
890  visited.clear();
891  stack.clear();
892  stack.insert(root);
893  while (!stack.empty()) {
894  NodeId next = *(stack.begin());
895  stack.erase(next);
896  if (visited.contains(next)) continue;
897  const auto nei = essentialGraph.neighbours(next);
898  for (const auto n: nei) {
899  if (!stack.contains(n) && !visited.contains(n)) stack.insert(n);
900  GUM_TRACE(" + amap reasonably orientation for " << n << "->" << next);
903  }
905  }
906  }
907  }
908 
909  /// Propagates the orientation from a node to its neighbours
911  bool res = false;
912  const auto neighbours = graph.neighbours(xj);
913  for (auto& xi: neighbours) {
914  bool i_j = isOrientable_(graph, xi, xj);
915  bool j_i = isOrientable_(graph, xj, xi);
916  if (i_j || j_i) {
917  GUM_TRACE(" + Removing edge (" << xi << "," << xj << ")")
918  graph.eraseEdge(Edge(xi, xj));
919  res = true;
920  }
921  if (i_j) {
922  GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
923  graph.addArc(xi, xj);
925  }
926  if (j_i) {
927  GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
928  graph.addArc(xj, xi);
930  }
931  if (i_j && j_i) {
932  GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
934  }
935  }
936 
937  return res;
938  }
939 
940  /// get the list of arcs hiding latent variables
941  const std::vector< Arc > Miic::latentVariables() const { return _latentCouples_; }
942 
943  /// learns the structure and the parameters of a BN
944  template < typename GUM_SCALAR, typename GRAPH_CHANGES_SELECTOR, typename PARAM_ESTIMATOR >
947  DAG initial_dag) {
950  }
951 
952  void Miic::setMiicBehaviour() { this->_useMiic_ = true; }
953 
954  void Miic::set3of2Behaviour() { this->_useMiic_ = false; }
955 
957  this->_initialMarks_ = constraints;
958  }
959 
961  const NodeId n1,
962  const NodeId n2) {
963  for (const auto parent: graph.parents(n2)) {
964  if (graph.existsArc(parent,
965  n2)) // if there is a double arc, pass
966  continue;
967  if (parent == n1) // trivial directed path => not recognized
968  continue;
969  if (_existsDirectedPath_(graph, n1, parent)) return true;
970  }
971  return false;
972  }
973 
974  bool Miic::_existsDirectedPath_(const MixedGraph& graph, const NodeId n1, const NodeId n2) {
975  // not recursive version => use a FIFO for simulating the recursion
976  List< NodeId > nodeFIFO;
977  // mark[node] = successor if visited, else mark[node] does not exist
978  Set< NodeId > mark;
979 
980  mark.insert(n2);
982 
983  NodeId current;
984 
985  while (!nodeFIFO.empty()) {
986  current = nodeFIFO.front();
987  nodeFIFO.popFront();
988 
989  // check the parents
990  for (const auto new_one: graph.parents(current)) {
991  if (graph.existsArc(current,
992  new_one)) // if there is a double arc, pass
993  continue;
994 
995  if (new_one == n1) { return true; }
996 
997  if (mark.exists(new_one)) // if this node is already marked, do not
998  continue; // check it again
999 
1000  mark.insert(new_one);
1002  }
1003  }
1004 
1005  return false;
1006  }
1007 
1009  HashTable< std::pair< NodeId, NodeId >, char >& marks,
1010  NodeId x,
1011  NodeId y,
1012  NodeId z,
1013  double p1,
1014  double p2) {
1015  // v-structure discovery
1016  if (marks[{x, z}] == 'o' && marks[{y, z}] == 'o') { // If x-z-y
1018  graph.eraseEdge(Edge(x, z));
1019  graph.addArc(x, z);
1020  GUM_TRACE("1.a Removing edge (" << x << "," << z << ")")
1021  GUM_TRACE("1.a Adding arc (" << x << "," << z << ")")
1022  marks[{x, z}] = '>';
1023  if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
1024  GUM_TRACE("Adding latent couple (" << z << "," << x << ")")
1026  }
1027  if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1028  } else {
1029  graph.eraseEdge(Edge(x, z));
1030  GUM_TRACE("1.b Adding arc (" << x << "," << z << ")")
1032  graph.addArc(z, x);
1033  GUM_TRACE("1.b Removing edge (" << x << "," << z << ")")
1034  marks[{z, x}] = '>';
1035  }
1036  }
1037 
1039  graph.eraseEdge(Edge(y, z));
1040  graph.addArc(y, z);
1041  GUM_TRACE("1.c Removing edge (" << y << "," << z << ")")
1042  GUM_TRACE("1.c Adding arc (" << y << "," << z << ")")
1043  marks[{y, z}] = '>';
1044  if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
1046  }
1047  if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1048  } else {
1049  graph.eraseEdge(Edge(y, z));
1050  GUM_TRACE("1.d Removing edge (" << y << "," << z << ")")
1052  graph.addArc(z, y);
1053  GUM_TRACE("1.d Adding arc (" << z << "," << y << ")")
1054  marks[{z, y}] = '>';
1055  }
1056  }
1057  } else if (marks[{x, z}] == '>' && marks[{y, z}] == 'o') { // If x->z-y
1059  graph.eraseEdge(Edge(y, z));
1060  graph.addArc(y, z);
1061  GUM_TRACE("2.a Removing edge (" << y << "," << z << ")")
1062  GUM_TRACE("2.a Adding arc (" << y << "," << z << ")")
1063  marks[{y, z}] = '>';
1064  if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
1066  }
1067  if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1068  } else {
1069  graph.eraseEdge(Edge(y, z));
1070  GUM_TRACE("2.b Removing edge (" << y << "," << z << ")")
1072  graph.addArc(z, y);
1073  GUM_TRACE("2.b Adding arc (" << y << "," << z << ")")
1074  marks[{z, y}] = '>';
1075  }
1076  }
1077  } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o') {
1079  graph.eraseEdge(Edge(x, z));
1080  graph.addArc(x, z);
1081  GUM_TRACE("3.a Removing edge (" << x << "," << z << ")")
1082  GUM_TRACE("3.a Adding arc (" << x << "," << z << ")")
1083  marks[{x, z}] = '>';
1084  if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
1086  }
1087  if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1088  } else {
1089  graph.eraseEdge(Edge(x, z));
1090  GUM_TRACE("3.b Removing edge (" << x << "," << z << ")")
1092  graph.addArc(z, x);
1093  GUM_TRACE("3.b Adding arc (" << x << "," << z << ")")
1094  marks[{z, x}] = '>';
1095  }
1096  }
1097  }
1098  }
1099 
1100 
1102  HashTable< std::pair< NodeId, NodeId >, char >& marks,
1103  NodeId x,
1104  NodeId y,
1105  NodeId z,
1106  double p1,
1107  double p2) {
1108  // orientation propagation
1109  if (marks[{x, z}] == '>' && marks[{y, z}] == 'o' && marks[{z, y}] != '-') {
1110  graph.eraseEdge(Edge(z, y));
1111  // std::cout << "4. Removing edge (" << z << "," << y << ")" <<
1112  // std::endl;
1113  if (!_existsDirectedPath_(graph, y, z) && graph.parents(y).empty()) {
1114  graph.addArc(z, y);
1115  GUM_TRACE("4.a Adding arc (" << z << "," << y << ")")
1116  marks[{z, y}] = '>';
1117  marks[{y, z}] = '-';
1118  if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2);
1119  } else if (!_existsDirectedPath_(graph, z, y) && graph.parents(z).empty()) {
1120  graph.addArc(y, z);
1121  GUM_TRACE("4.b Adding arc (" << y << "," << z << ")")
1122  marks[{z, y}] = '-';
1123  marks[{y, z}] = '>';
1125  if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1126  } else if (!_existsDirectedPath_(graph, y, z)) {
1127  graph.addArc(z, y);
1128  GUM_TRACE("4.c Adding arc (" << z << "," << y << ")")
1129  marks[{z, y}] = '>';
1130  marks[{y, z}] = '-';
1131  if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2);
1132  } else if (!_existsDirectedPath_(graph, z, y)) {
1133  graph.addArc(y, z);
1134  GUM_TRACE("4.d Adding arc (" << y << "," << z << ")")
1136  marks[{z, y}] = '-';
1137  marks[{y, z}] = '>';
1138  if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1139  }
1140  } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o' && marks[{z, x}] != '-') {
1141  graph.eraseEdge(Edge(z, x));
1142  GUM_TRACE("5. Removing edge (" << z << "," << x << ")")
1143  if (!_existsDirectedPath_(graph, x, z) && graph.parents(x).empty()) {
1144  graph.addArc(z, x);
1145  GUM_TRACE("5.a Adding arc (" << z << "," << x << ")")
1146  marks[{z, x}] = '>';
1147  marks[{x, z}] = '-';
1148  if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1);
1149  } else if (!_existsDirectedPath_(graph, z, x) && graph.parents(z).empty()) {
1150  graph.addArc(x, z);
1151  GUM_TRACE("5.b Adding arc (" << x << "," << z << ")")
1152  marks[{z, x}] = '-';
1153  marks[{x, z}] = '>';
1155  if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1156  } else if (!_existsDirectedPath_(graph, x, z)) {
1157  graph.addArc(z, x);
1158  GUM_TRACE("5.c Adding arc (" << z << "," << x << ")")
1159  marks[{z, x}] = '>';
1160  marks[{x, z}] = '-';
1161  if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1);
1162  } else if (!_existsDirectedPath_(graph, z, x)) {
1163  graph.addArc(x, z);
1164  GUM_TRACE("5.d Adding arc (" << x << "," << z << ")")
1165  marks[{z, x}] = '-';
1166  marks[{x, z}] = '>';
1168  if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1169  }
1170  }
1171  }
1172 
1173  bool Miic::_isNotLatentCouple_(const NodeId x, const NodeId y) {
1174  const auto& lbeg = _latentCouples_.begin();
1175  const auto& lend = _latentCouples_.end();
1176 
1177  return (std::find(lbeg, lend, Arc(x, y)) == lend)
1178  && (std::find(lbeg, lend, Arc(y, x)) == lend);
1179  }
1180 
1182  return (_initialMarks_.exists({x, y}) && _initialMarks_[{x, y}] == '-');
1183  }
1184  } /* namespace learning */
1185 
1186 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)