aGrUM  0.21.0
a C++ library for (probabilistic) graphical models
Miic.cpp
Go to the documentation of this file.
1 
2 /**
3  *
4  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe
5  * GONZALES(@AMU) info_at_agrum_dot_org
6  *
7  * This library is free software: you can redistribute it and/or modify
8  * it under the terms of the GNU Lesser General Public License as published by
9  * the Free Software Foundation, either version 3 of the License, or
10  * (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15  * GNU Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public License
18  * along with this library. If not, see <http://www.gnu.org/licenses/>.
19  *
20  */
21 
22 
23 /** @file
24  * @brief Implementation of gum::learning::ThreeOffTwo and MIIC
25  *
26  * @author Quentin FALCAND, Marvin LASSERRE and Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 
29 #include <agrum/tools/core/math/math_utils.h>
30 #include <agrum/tools/core/hashTable.h>
31 #include <agrum/tools/core/heap.h>
32 #include <agrum/tools/core/timer.h>
33 #include <agrum/tools/graphs/mixedGraph.h>
34 #include <agrum/BN/learning/Miic.h>
35 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
36 #include <agrum/tools/stattests/correctedMutualInformation.h>
37 
38 
39 namespace gum {
40 
41  namespace learning {
42 
43  /// default constructor
45 
46  /// default constructor with maxLog
48 
49  /// copy constructor
52  }
53 
54  /// move constructor
57  }
58 
59  /// destructor
61 
62  /// copy operator
63  Miic& Miic::operator=(const Miic& from) {
65  return *this;
66  }
67 
68  /// move operator
71  return *this;
72  }
73 
74 
75  bool GreaterPairOn2nd::operator()(const CondRanking& e1, const CondRanking& e2) const {
76  return e1.second > e2.second;
77  }
78 
79  bool GreaterAbsPairOn2nd::operator()(const Ranking& e1, const Ranking& e2) const {
80  return std::abs(e1.second) > std::abs(e2.second);
81  }
82 
84  const ProbabilisticRanking& e2) const {
85  double p1xz = std::get< 2 >(e1);
86  double p1yz = std::get< 3 >(e1);
87  double p2xz = std::get< 2 >(e2);
88  double p2yz = std::get< 3 >(e2);
89  double I1 = std::get< 1 >(e1);
90  double I2 = std::get< 1 >(e2);
91  // First, we look at the sign of information.
92  // Then, the probability values
93  // and finally the abs value of information.
94  if ((I1 < 0 && I2 < 0) || (I1 >= 0 && I2 >= 0)) {
95  if (std::max(p1xz, p1yz) == std::max(p2xz, p2yz)) {
96  return std::abs(I1) > std::abs(I2);
97  } else {
98  return std::max(p1xz, p1yz) > std::max(p2xz, p2yz);
99  }
100  } else {
101  return I1 < I2;
102  }
103  }
104 
105  /// learns the structure of a MixedGraph
107  MixedGraph graph) {
108  timer_.reset();
109  current_step_ = 0;
110 
111  // clear the vector of latent arcs to be sure
113 
114  /// the heap of ranks, with the score, and the NodeIds of x, y and z.
116 
117  /// the variables separation sets
119 
121 
123 
124  if (_useMiic_) {
126  } else {
128  }
129 
130  return graph;
131  }
132 
133  /*
134  * PHASE 1 : INITIATION
135  *
136  * We go over all edges and test if the variables are independent. If they
137  * are,
138  * the edge is deleted. If not, the best contributor is found.
139  */
141  MixedGraph& graph,
144  NodeId x, y;
145  EdgeSet edges = graph.edges();
147 
148  for (const Edge& edge: edges) {
149  x = edge.first();
150  y = edge.second();
151  double Ixy = mutualInformation.score(x, y);
152 
153  if (Ixy <= 0) { //< K
156  } else {
158  }
159 
160  ++current_step_;
161  if (onProgress.hasListener()) {
163  }
164  }
165  }
166 
167  /*
168  * PHASE 2 : ITERATION
169  *
170  * As long as we find important nodes for edges, we go over them to see if
171  * we can assess the independence of the variables.
172  */
174  MixedGraph& graph,
177  // if no triples to further examine pass
179 
181  Size steps_iter = rank.size();
182 
183  try {
184  while (rank.top().second > 0.5) {
185  best = rank.pop();
186 
187  const NodeId x = std::get< 0 >(*(best.first));
188  const NodeId y = std::get< 1 >(*(best.first));
189  const NodeId z = std::get< 2 >(*(best.first));
190  std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
191 
192  ui.push_back(z);
193  const double i_xy_ui = mutualInformation.score(x, y, ui);
194  if (i_xy_ui < 0) {
195  graph.eraseEdge(Edge(x, y));
197  } else {
199  }
200 
201  delete best.first;
202 
203  ++current_step_;
204  if (onProgress.hasListener()) {
206  (current_step_ * 66) / (steps_init + steps_iter),
207  0.,
208  timer_.step());
209  }
210  }
211  } catch (...) {} // here, rank is empty
213  if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 66, 0., timer_.step()); }
215  }
216 
217  /*
218  * PHASE 3 : ORIENTATION
219  *
220  * Try to assess v-structures and propagate them.
221  */
224  MixedGraph& graph,
225  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
229 
230  // marks always correspond to the head of the arc/edge. - is for a forbidden
231  // arc, > for a mandatory arc
232  // we start by adding the mandatory arcs
233  for (auto iter = _initialMarks_.begin(); iter != _initialMarks_.end(); ++iter) {
234  if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') {
237  }
238  }
239 
240  NodeId i = 0;
241  // list of elements that we shouldn't read again, ie elements that are
242  // eligible to
243  // rule 0 after the first time they are tested, and elements on which rule 1
244  // has been applied
245  while (i < triples.size()) {
246  // if i not in do_not_reread
247  Ranking triple = triples[i];
248  NodeId x, y, z;
249  x = std::get< 0 >(*triple.first);
250  y = std::get< 1 >(*triple.first);
251  z = std::get< 2 >(*triple.first);
252 
253  std::vector< NodeId > ui;
254  std::pair< NodeId, NodeId > key = {x, y};
255  std::pair< NodeId, NodeId > rev_key = {y, x};
256  if (sepSet.exists(key)) {
257  ui = sepSet[key];
258  } else if (sepSet.exists(rev_key)) {
259  ui = sepSet[rev_key];
260  }
261  double Ixyz_ui = triple.second;
262  bool reset{false};
263  // try Rule 0
264  if (Ixyz_ui < 0) {
265  // if ( z not in Sep[x,y])
266  if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
267  if (!graph.existsArc(x, z) && !graph.existsArc(z, x)) {
268  // when we try to add an arc to the graph, we always verify if
269  // we are allowed to do so, ie it is not a forbidden arc an it
270  // does not create a cycle
271  if (!_existsDirectedPath_(graph, z, x) && !isForbidenArc_(x, z)) {
272  reset = true;
273  graph.eraseEdge(Edge(x, z));
274  graph.addArc(x, z);
275  } else if (_existsDirectedPath_(graph, z, x) && !isForbidenArc_(z, x)) {
276  reset = true;
277  graph.eraseEdge(Edge(x, z));
278  // if we find a cycle, we force the competing edge
279  graph.addArc(z, x);
281  == _latentCouples_.end()) {
283  }
284  }
285  } else if (!graph.existsArc(y, z) && !graph.existsArc(z, y)) {
286  if (!_existsDirectedPath_(graph, z, y) && !isForbidenArc_(x, z)) {
287  reset = true;
288  graph.eraseEdge(Edge(y, z));
289  graph.addArc(y, z);
290  } else if (_existsDirectedPath_(graph, z, y) && !isForbidenArc_(z, y)) {
291  reset = true;
292  graph.eraseEdge(Edge(y, z));
293  // if we find a cycle, we force the competing edge
294  graph.addArc(z, y);
296  == _latentCouples_.end()) {
298  }
299  }
300  } else {
301  // checking if the anti-directed arc already exists, to register a
302  // latent variable
303  if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
305  }
306  if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
308  }
309  }
310  }
311  } else { // try Rule 1
312  if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) {
313  if (!_existsDirectedPath_(graph, y, z) && !isForbidenArc_(z, y)) {
314  reset = true;
315  graph.eraseEdge(Edge(z, y));
316  graph.addArc(z, y);
317  } else if (_existsDirectedPath_(graph, y, z) && !isForbidenArc_(y, z)) {
318  reset = true;
319  graph.eraseEdge(Edge(z, y));
320  // if we find a cycle, we force the competing edge
321  graph.addArc(y, z);
323  == _latentCouples_.end()) {
325  }
326  }
327  }
328  if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) {
329  if (!_existsDirectedPath_(graph, x, z) && !isForbidenArc_(z, x)) {
330  reset = true;
331  graph.eraseEdge(Edge(z, x));
332  graph.addArc(z, x);
333  } else if (_existsDirectedPath_(graph, x, z) && !isForbidenArc_(x, z)) {
334  reset = true;
335  graph.eraseEdge(Edge(z, x));
336  // if we find a cycle, we force the competing edge
337  graph.addArc(x, z);
339  == _latentCouples_.end()) {
341  }
342  }
343  }
344  } // if rule 0 or rule 1
345 
346  // if what we want to add already exists : pass to the next triplet
347  if (reset) {
348  i = 0;
349  } else {
350  ++i;
351  }
352  if (onProgress.hasListener()) {
354  ((current_step_ + i) * 100) / (past_steps + steps_orient),
355  0.,
356  timer_.step());
357  }
358  } // while
359 
360  // erasing the the double headed arcs
361  for (const Arc& arc: _latentCouples_) {
362  graph.eraseArc(Arc(arc.head(), arc.tail()));
363  }
364  }
365 
366  /// variant trying to propagate both orientations in a bidirected arc
369  MixedGraph& graph,
370  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
374 
375  NodeId i = 0;
376  // list of elements that we shouldnt read again, ie elements that are
377  // eligible to
378  // rule 0 after the first time they are tested, and elements on which rule 1
379  // has been applied
380  while (i < triples.size()) {
381  // if i not in do_not_reread
382  Ranking triple = triples[i];
383  NodeId x, y, z;
384  x = std::get< 0 >(*triple.first);
385  y = std::get< 1 >(*triple.first);
386  z = std::get< 2 >(*triple.first);
387 
388  std::vector< NodeId > ui;
389  std::pair< NodeId, NodeId > key = {x, y};
390  std::pair< NodeId, NodeId > rev_key = {y, x};
391  if (sepSet.exists(key)) {
392  ui = sepSet[key];
393  } else if (sepSet.exists(rev_key)) {
394  ui = sepSet[rev_key];
395  }
396  double Ixyz_ui = triple.second;
397  // try Rule 0
398  if (Ixyz_ui < 0) {
399  // if ( z not in Sep[x,y])
400  if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
401  // if what we want to add already exists : pass
402  if ((graph.existsArc(x, z) || graph.existsArc(z, x))
403  && (graph.existsArc(y, z) || graph.existsArc(z, y))) {
404  ++i;
405  } else {
406  i = 0;
407  graph.eraseEdge(Edge(x, z));
408  graph.eraseEdge(Edge(y, z));
409  // checking for cycles
410  if (graph.existsArc(z, x)) {
411  graph.eraseArc(Arc(z, x));
412  try {
414  // if we find a cycle, we force the competing edge
416  } catch (gum::NotFound) { graph.addArc(x, z); }
417  graph.addArc(z, x);
418  } else {
419  try {
421  // if we find a cycle, we force the competing edge
422  graph.addArc(z, x);
424  } catch (gum::NotFound) { graph.addArc(x, z); }
425  }
426  if (graph.existsArc(z, y)) {
427  graph.eraseArc(Arc(z, y));
428  try {
430  // if we find a cycle, we force the competing edge
432  } catch (gum::NotFound) { graph.addArc(y, z); }
433  graph.addArc(z, y);
434  } else {
435  try {
437  // if we find a cycle, we force the competing edge
438  graph.addArc(z, y);
440 
441  } catch (gum::NotFound) { graph.addArc(y, z); }
442  }
443  if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
445  }
446  if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
448  }
449  }
450  } else {
451  ++i;
452  }
453  } else { // try Rule 1
454  bool reset{false};
455  if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) {
456  reset = true;
457  graph.eraseEdge(Edge(z, y));
458  try {
460  // if we find a cycle, we force the competing edge
461  graph.addArc(y, z);
463  } catch (gum::NotFound) { graph.addArc(z, y); }
464  }
465  if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) {
466  reset = true;
467  graph.eraseEdge(Edge(z, x));
468  try {
470  // if we find a cycle, we force the competing edge
471  graph.addArc(x, z);
473  } catch (gum::NotFound) { graph.addArc(z, x); }
474  }
475 
476  if (reset) {
477  i = 0;
478  } else {
479  ++i;
480  }
481  } // if rule 0 or rule 1
482  if (onProgress.hasListener()) {
484  ((current_step_ + i) * 100) / (past_steps + steps_orient),
485  0.,
486  timer_.step());
487  }
488  } // while
489 
490  // erasing the the double headed arcs
491  for (const Arc& arc: _latentCouples_) {
492  graph.eraseArc(Arc(arc.head(), arc.tail()));
493  }
494  }
495 
496 /// variant using the orientation protocol of MIIC
499  MixedGraph& graph,
500  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
501  // structure to store the orientations marks -, o, or >,
502  // Considers the head of the arc/edge first node -* second node
504 
505  // marks always correspond to the head of the arc/edge. - is for a forbidden
506  // arc, > for a mandatory arc
507  // we start by adding the mandatory arcs
508  for (auto iter = marks.begin(); iter != marks.end(); ++iter) {
509  if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') {
512  }
513  }
514 
517 
520 
522  if (steps_orient > 0) { best = proba_triples[0]; }
523 
524  while (!proba_triples.empty() && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) {
525  const NodeId x = std::get< 0 >(*std::get< 0 >(best));
526  const NodeId y = std::get< 1 >(*std::get< 0 >(best));
527  const NodeId z = std::get< 2 >(*std::get< 0 >(best));
528 
529  const double i3 = std::get< 1 >(best);
530 
531  const double p1 = std::get< 2 >(best);
532  const double p2 = std::get< 3 >(best);
533  if (i3 <= 0) {
535  } else {
537  }
538 
539  delete std::get< 0 >(best);
541  // actualisation of the list of triples
543 
544  if (!proba_triples.empty()) best = proba_triples[0];
545 
546  ++current_step_;
547  if (onProgress.hasListener()) {
549  (current_step_ * 100) / (steps_orient + past_steps),
550  0.,
551  timer_.step());
552  }
553  } // while
554 
555  // erasing the double headed arcs
556  for (auto iter = _latentCouples_.rbegin(); iter != _latentCouples_.rend(); ++iter) {
557  graph.eraseArc(Arc(iter->head(), iter->tail()));
558  if (_existsDirectedPath_(graph, iter->head(), iter->tail())) {
559  // if we find a cycle, we force the competing edge
560  graph.addArc(iter->head(), iter->tail());
561  graph.eraseArc(Arc(iter->tail(), iter->head()));
562  *iter = Arc(iter->head(), iter->tail());
563  }
564  }
565 
566  if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 100, 0., timer_.step()); }
567  }
568 
569  /// finds the best contributor node for a pair given a conditioning set
571  NodeId y,
572  const std::vector< NodeId >& ui,
573  const MixedGraph& graph,
576  double maxP = -1.0;
577  NodeId maxZ = 0;
578 
579  // compute N
580  // __N = I.N();
581  const double Ixy_ui = mutualInformation.score(x, y, ui);
582 
583  for (const NodeId z: graph) {
584  // if z!=x and z!=y and z not in ui
585  if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
586  double Pnv;
587  double Pb;
588 
589  // Computing Pnv
590  const double Ixyz_ui = mutualInformation.score(x, y, z, ui);
591  double calc_expo1 = -Ixyz_ui * M_LN2;
592  // if exponential are too high or to low, crop them at _maxLog_
593  if (calc_expo1 > _maxLog_) {
594  Pnv = 0.0;
595  } else if (calc_expo1 < -_maxLog_) {
596  Pnv = 1.0;
597  } else {
598  Pnv = 1 / (1 + std::exp(calc_expo1));
599  }
600 
601  // Computing Pb
602  const double Ixz_ui = mutualInformation.score(x, z, ui);
603  const double Iyz_ui = mutualInformation.score(y, z, ui);
604 
605  calc_expo1 = -(Ixz_ui - Ixy_ui) * M_LN2;
606  double calc_expo2 = -(Iyz_ui - Ixy_ui) * M_LN2;
607 
608  // if exponential are too high or to low, crop them at _maxLog_
609  if (calc_expo1 > _maxLog_ || calc_expo2 > _maxLog_) {
610  Pb = 0.0;
611  } else if (calc_expo1 < -_maxLog_ && calc_expo2 < -_maxLog_) {
612  Pb = 1.0;
613  } else {
614  double expo1, expo2;
615  if (calc_expo1 < -_maxLog_) {
616  expo1 = 0.0;
617  } else {
618  expo1 = std::exp(calc_expo1);
619  }
620  if (calc_expo2 < -_maxLog_) {
621  expo2 = 0.0;
622  } else {
623  expo2 = std::exp(calc_expo2);
624  }
625  Pb = 1 / (1 + expo1 + expo2);
626  }
627 
628  // Getting max(min(Pnv, pb))
629  const double min_pnv_pb = std::min(Pnv, Pb);
630  if (min_pnv_pb > maxP) {
631  maxP = min_pnv_pb;
632  maxZ = z;
633  }
634  } // if z not in (x, y)
635  } // for z in graph.nodes
636  // storing best z in rank_
638  auto tup = new CondThreePoints{x, y, maxZ, ui};
639  final.first = tup;
640  final.second = maxP;
641  rank.insert(final);
642  }
643 
644  /// gets the list of unshielded triples in the graph in decreasing value of
645  ///|I'(x, y, z|{ui})|
647  const MixedGraph& graph,
649  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
651  for (NodeId z: graph) {
652  for (NodeId x: graph.neighbours(z)) {
653  for (NodeId y: graph.neighbours(z)) {
654  if (y < x && !graph.existsEdge(x, y)) {
655  std::vector< NodeId > ui;
656  std::pair< NodeId, NodeId > key = {x, y};
657  std::pair< NodeId, NodeId > rev_key = {y, x};
658  if (sepSet.exists(key)) {
659  ui = sepSet[key];
660  } else if (sepSet.exists(rev_key)) {
661  ui = sepSet[rev_key];
662  }
663  // remove z from ui if it's present
664  const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
665  if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
666 
667  double Ixyz_ui = mutualInformation.score(x, y, z, ui);
668  Ranking triple;
669  auto tup = new ThreePoints{x, y, z};
670  triple.first = tup;
673  }
674  }
675  }
676  }
678  return triples;
679  }
680 
681  /// gets the list of unshielded triples in the graph in decreasing value of
682  ///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC
684  const MixedGraph& graph,
686  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
687  HashTable< std::pair< NodeId, NodeId >, char >& marks) {
689  for (NodeId z: graph) {
690  for (NodeId x: graph.neighbours(z)) {
691  for (NodeId y: graph.neighbours(z)) {
692  if (y < x && !graph.existsEdge(x, y)) {
693  std::vector< NodeId > ui;
694  std::pair< NodeId, NodeId > key = {x, y};
695  std::pair< NodeId, NodeId > rev_key = {y, x};
696  if (sepSet.exists(key)) {
697  ui = sepSet[key];
698  } else if (sepSet.exists(rev_key)) {
699  ui = sepSet[rev_key];
700  }
701  // remove z from ui if it's present
702  const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
703  if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
704 
705  const double Ixyz_ui = mutualInformation.score(x, y, z, ui);
706  auto tup = new ThreePoints{x, y, z};
709  if (!marks.exists({x, z})) { marks.insert({x, z}, 'o'); }
710  if (!marks.exists({z, x})) { marks.insert({z, x}, 'o'); }
711  if (!marks.exists({y, z})) { marks.insert({y, z}, 'o'); }
712  if (!marks.exists({z, y})) { marks.insert({z, y}, 'o'); }
713  }
714  }
715  }
716  }
719  return triples;
720  }
721 
722  /// Gets the orientation probabilities like MIIC for the orientation phase
726  for (auto& triple: probaTriples) {
727  NodeId x, y, z;
728  x = std::get< 0 >(*std::get< 0 >(triple));
729  y = std::get< 1 >(*std::get< 0 >(triple));
730  z = std::get< 2 >(*std::get< 0 >(triple));
731  const double Ixyz = std::get< 1 >(triple);
732  double Pxz = std::get< 2 >(triple);
733  double Pyz = std::get< 3 >(triple);
734 
735  if (Ixyz <= 0) {
736  const double expo = std::exp(Ixyz);
737  const double P0 = (1 + expo) / (1 + 3 * expo);
738  // distinguish between the initialization and the update process
739  if (Pxz == Pyz && Pyz == 0.5) {
740  std::get< 2 >(triple) = P0;
741  std::get< 3 >(triple) = P0;
742  } else {
743  if (graph.existsArc(x, z) && Pxz >= P0) {
744  std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
745  } else if (graph.existsArc(y, z) && Pyz >= P0) {
746  std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
747  }
748  }
749  } else {
750  const double expo = std::exp(-Ixyz);
751  if (graph.existsArc(x, z) && Pxz >= 0.5) {
752  std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
753  } else if (graph.existsArc(y, z) && Pyz >= 0.5) {
754  std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
755  }
756  }
757  }
759  return probaTriples;
760  }
761 
762 /// learns the structure of a Bayesian network, i.e. a DAG, from an essential
763 /// graph.
766  // orientate remaining edges
767 
769 
770  // first, forbidden arcs force arc in the other direction
771  for (NodeId x: order) {
772  const auto nei_x = essentialGraph.neighbours(x);
773  for (NodeId y: nei_x)
774  if (isForbidenArc_(x, y)) {
776  if (isForbidenArc_(y, x)) {
777  GUM_TRACE("Neither arc allowed for edge (" << x << "," << y << ")")
778  } else {
779  GUM_TRACE("Forced orientation : " << y << "->" << x)
781  }
782  } else if (isForbidenArc_(y, x)) {
784  GUM_TRACE("Forced orientation : " << x << "->" << y)
786  }
787  }
789 
790  // first, propagate existing orientations
791  bool newOrientation = true;
792  while (newOrientation) {
793  newOrientation = false;
794  for (NodeId x: order) {
795  if (!essentialGraph.parents(x).empty()) {
797  }
798  }
799  }
803 
804  // then decide the orientation for double arcs
805  for (NodeId x: order)
808  GUM_TRACE(" + Resolving double arcs (poorly)")
810  }
811 
812  DAG dag;
813  for (auto node: essentialGraph) {
815  }
816  for (const Arc& arc: essentialGraph.arcs()) {
817  dag.addArc(arc.tail(), arc.head());
818  }
819 
820  return dag;
821  }
822 
824  // no cycle
825  if (_existsDirectedPath_(graph, xj, xi)) {
826  GUM_TRACE("cycle(" << xi << "-" << xj << ")")
827  return false;
828  }
829 
830  // R1
831  if (!(graph.parents(xi) - graph.adjacents(xj)).empty()) {
832  GUM_TRACE("R1(" << xi << "-" << xj << ")")
833  return true;
834  }
835 
836  // R2
837  if (_existsDirectedPath_(graph, xi, xj)) {
838  GUM_TRACE("R2(" << xi << "-" << xj << ")")
839  return true;
840  }
841 
842  // R3
843  int nbr = 0;
844  for (const auto p: graph.parents(xj)) {
845  if (!graph.mixedOrientedPath(xi, p).empty()) {
846  nbr += 1;
847  if (nbr == 2) {
848  GUM_TRACE("R3(" << xi << "-" << xj << ")")
849  return true;
850  }
851  }
852  }
853  return false;
854  }
855 
857  // then decide the orientation for remaining edges
858  while (!essentialGraph.edges().empty()) {
859  const auto& edge = *(essentialGraph.edges().begin());
860  NodeId root = edge.first();
863  NodeSet stack{root};
864  // check the best root for the set of neighbours
865  while (!stack.empty()) {
866  NodeId next = *(stack.begin());
867  stack.erase(next);
868  if (visited.contains(next)) continue;
871  root = next;
872  }
873  for (const auto n: essentialGraph.neighbours(next))
874  if (!stack.contains(n) && !visited.contains(n)) stack.insert(n);
876  }
877  // orientation now
878  visited.clear();
879  stack.clear();
880  stack.insert(root);
881  while (!stack.empty()) {
882  NodeId next = *(stack.begin());
883  stack.erase(next);
884  if (visited.contains(next)) continue;
885  const auto nei = essentialGraph.neighbours(next);
886  for (const auto n: nei) {
887  if (!stack.contains(n) && !visited.contains(n)) stack.insert(n);
888  GUM_TRACE(" + amap reasonably orientation for " << n << "->" << next);
891  }
893  }
894  }
895  }
896 
897  /// Propagates the orientation from a node to its neighbours
899  bool res = false;
900  const auto neighbours = graph.neighbours(xj);
901  for (auto& xi: neighbours) {
902  bool i_j = isOrientable_(graph, xi, xj);
903  bool j_i = isOrientable_(graph, xj, xi);
904  if (i_j || j_i) {
905  GUM_TRACE(" + Removing edge (" << xi << "," << xj << ")")
906  graph.eraseEdge(Edge(xi, xj));
907  res = true;
908  }
909  if (i_j) {
910  GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
911  graph.addArc(xi, xj);
913  }
914  if (j_i) {
915  GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
916  graph.addArc(xj, xi);
918  }
919  if (i_j && j_i) {
920  GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
922  }
923  }
924 
925  return res;
926  }
927 
928  /// get the list of arcs hiding latent variables
929  const std::vector< Arc > Miic::latentVariables() const { return _latentCouples_; }
930 
931  /// learns the structure and the parameters of a BN
932  template < typename GUM_SCALAR, typename GRAPH_CHANGES_SELECTOR, typename PARAM_ESTIMATOR >
935  DAG initial_dag) {
938  }
939 
940  void Miic::setMiicBehaviour() { this->_useMiic_ = true; }
941 
942  void Miic::set3of2Behaviour() { this->_useMiic_ = false; }
943 
945  this->_initialMarks_ = constraints;
946  }
947 
949  const NodeId n1,
950  const NodeId n2) {
951  for (const auto parent: graph.parents(n2)) {
952  if (graph.existsArc(parent,
953  n2)) // if there is a double arc, pass
954  continue;
955  if (parent == n1) // trivial directed path => not recognized
956  continue;
957  if (_existsDirectedPath_(graph, n1, parent)) return true;
958  }
959  return false;
960  }
961 
962  bool Miic::_existsDirectedPath_(const MixedGraph& graph, const NodeId n1, const NodeId n2) {
963  // not recursive version => use a FIFO for simulating the recursion
964  List< NodeId > nodeFIFO;
965  // mark[node] = successor if visited, else mark[node] does not exist
966  Set< NodeId > mark;
967 
968  mark.insert(n2);
970 
971  NodeId current;
972 
973  while (!nodeFIFO.empty()) {
974  current = nodeFIFO.front();
975  nodeFIFO.popFront();
976 
977  // check the parents
978  for (const auto new_one: graph.parents(current)) {
979  if (graph.existsArc(current,
980  new_one)) // if there is a double arc, pass
981  continue;
982 
983  if (new_one == n1) { return true; }
984 
985  if (mark.exists(new_one)) // if this node is already marked, do not
986  continue; // check it again
987 
990  }
991  }
992 
993  return false;
994  }
995 
997  HashTable< std::pair< NodeId, NodeId >, char >& marks,
998  NodeId x,
999  NodeId y,
1000  NodeId z,
1001  double p1,
1002  double p2) {
1003  // v-structure discovery
1004  if (marks[{x, z}] == 'o' && marks[{y, z}] == 'o') { // If x-z-y
1006  graph.eraseEdge(Edge(x, z));
1007  graph.addArc(x, z);
1008  GUM_TRACE("1.a Removing edge (" << x << "," << z << ")")
1009  GUM_TRACE("1.a Adding arc (" << x << "," << z << ")")
1010  marks[{x, z}] = '>';
1011  if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
1012  GUM_TRACE("Adding latent couple (" << z << "," << x << ")")
1014  }
1015  if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1016  } else {
1017  graph.eraseEdge(Edge(x, z));
1018  GUM_TRACE("1.b Adding arc (" << x << "," << z << ")")
1020  graph.addArc(z, x);
1021  GUM_TRACE("1.b Removing edge (" << x << "," << z << ")")
1022  marks[{z, x}] = '>';
1023  }
1024  }
1025 
1027  graph.eraseEdge(Edge(y, z));
1028  graph.addArc(y, z);
1029  GUM_TRACE("1.c Removing edge (" << y << "," << z << ")")
1030  GUM_TRACE("1.c Adding arc (" << y << "," << z << ")")
1031  marks[{y, z}] = '>';
1032  if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
1034  }
1035  if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1036  } else {
1037  graph.eraseEdge(Edge(y, z));
1038  GUM_TRACE("1.d Removing edge (" << y << "," << z << ")")
1040  graph.addArc(z, y);
1041  GUM_TRACE("1.d Adding arc (" << z << "," << y << ")")
1042  marks[{z, y}] = '>';
1043  }
1044  }
1045  } else if (marks[{x, z}] == '>' && marks[{y, z}] == 'o') { // If x->z-y
1047  graph.eraseEdge(Edge(y, z));
1048  graph.addArc(y, z);
1049  GUM_TRACE("2.a Removing edge (" << y << "," << z << ")")
1050  GUM_TRACE("2.a Adding arc (" << y << "," << z << ")")
1051  marks[{y, z}] = '>';
1052  if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
1054  }
1055  if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1056  } else {
1057  graph.eraseEdge(Edge(y, z));
1058  GUM_TRACE("2.b Removing edge (" << y << "," << z << ")")
1060  graph.addArc(z, y);
1061  GUM_TRACE("2.b Adding arc (" << y << "," << z << ")")
1062  marks[{z, y}] = '>';
1063  }
1064  }
1065  } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o') {
1067  graph.eraseEdge(Edge(x, z));
1068  graph.addArc(x, z);
1069  GUM_TRACE("3.a Removing edge (" << x << "," << z << ")")
1070  GUM_TRACE("3.a Adding arc (" << x << "," << z << ")")
1071  marks[{x, z}] = '>';
1072  if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
1074  }
1075  if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1076  } else {
1077  graph.eraseEdge(Edge(x, z));
1078  GUM_TRACE("3.b Removing edge (" << x << "," << z << ")")
1080  graph.addArc(z, x);
1081  GUM_TRACE("3.b Adding arc (" << x << "," << z << ")")
1082  marks[{z, x}] = '>';
1083  }
1084  }
1085  }
1086  }
1087 
1088 
1090  HashTable< std::pair< NodeId, NodeId >, char >& marks,
1091  NodeId x,
1092  NodeId y,
1093  NodeId z,
1094  double p1,
1095  double p2) {
1096  // orientation propagation
1097  if (marks[{x, z}] == '>' && marks[{y, z}] == 'o' && marks[{z, y}] != '-') {
1098  graph.eraseEdge(Edge(z, y));
1099  // std::cout << "4. Removing edge (" << z << "," << y << ")" <<
1100  // std::endl;
1101  if (!_existsDirectedPath_(graph, y, z) && graph.parents(y).empty()) {
1102  graph.addArc(z, y);
1103  GUM_TRACE("4.a Adding arc (" << z << "," << y << ")")
1104  marks[{z, y}] = '>';
1105  marks[{y, z}] = '-';
1106  if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2);
1107  } else if (!_existsDirectedPath_(graph, z, y) && graph.parents(z).empty()) {
1108  graph.addArc(y, z);
1109  GUM_TRACE("4.b Adding arc (" << y << "," << z << ")")
1110  marks[{z, y}] = '-';
1111  marks[{y, z}] = '>';
1113  if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1114  } else if (!_existsDirectedPath_(graph, y, z)) {
1115  graph.addArc(z, y);
1116  GUM_TRACE("4.c Adding arc (" << z << "," << y << ")")
1117  marks[{z, y}] = '>';
1118  marks[{y, z}] = '-';
1119  if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2);
1120  } else if (!_existsDirectedPath_(graph, z, y)) {
1121  graph.addArc(y, z);
1122  GUM_TRACE("4.d Adding arc (" << y << "," << z << ")")
1124  marks[{z, y}] = '-';
1125  marks[{y, z}] = '>';
1126  if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1127  }
1128  } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o' && marks[{z, x}] != '-') {
1129  graph.eraseEdge(Edge(z, x));
1130  GUM_TRACE("5. Removing edge (" << z << "," << x << ")")
1131  if (!_existsDirectedPath_(graph, x, z) && graph.parents(x).empty()) {
1132  graph.addArc(z, x);
1133  GUM_TRACE("5.a Adding arc (" << z << "," << x << ")")
1134  marks[{z, x}] = '>';
1135  marks[{x, z}] = '-';
1136  if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1);
1137  } else if (!_existsDirectedPath_(graph, z, x) && graph.parents(z).empty()) {
1138  graph.addArc(x, z);
1139  GUM_TRACE("5.b Adding arc (" << x << "," << z << ")")
1140  marks[{z, x}] = '-';
1141  marks[{x, z}] = '>';
1143  if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1144  } else if (!_existsDirectedPath_(graph, x, z)) {
1145  graph.addArc(z, x);
1146  GUM_TRACE("5.c Adding arc (" << z << "," << x << ")")
1147  marks[{z, x}] = '>';
1148  marks[{x, z}] = '-';
1149  if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1);
1150  } else if (!_existsDirectedPath_(graph, z, x)) {
1151  graph.addArc(x, z);
1152  GUM_TRACE("5.d Adding arc (" << x << "," << z << ")")
1153  marks[{z, x}] = '-';
1154  marks[{x, z}] = '>';
1156  if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1157  }
1158  }
1159  }
1160 
1161  bool Miic::_isNotLatentCouple_(const NodeId x, const NodeId y) {
1162  const auto& lbeg = _latentCouples_.begin();
1163  const auto& lend = _latentCouples_.end();
1164 
1165  return (std::find(lbeg, lend, Arc(x, y)) == lend)
1166  && (std::find(lbeg, lend, Arc(y, x)) == lend);
1167  }
1168 
1170  return (_initialMarks_.exists({x, y}) && _initialMarks_[{x, y}] == '-');
1171  }
1172  } /* namespace learning */
1173 
1174 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)