aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
Miic.cpp
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief Implementation of gum::learning::Miic
24  *
25  * @author Quentin FALCAND and Pierre-Henri WUILLEMIN(@LIP6)
26  */
27 
28 #include <agrum/tools/core/math/math_utils.h>
29 #include <agrum/tools/core/hashTable.h>
30 #include <agrum/tools/core/heap.h>
31 #include <agrum/tools/core/timer.h>
32 #include <agrum/tools/graphs/mixedGraph.h>
33 #include <agrum/BN/learning/Miic.h>
34 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
35 #include <agrum/tools/stattests/correctedMutualInformation.h>
36 
37 
38 namespace gum {
39 
40  namespace learning {
41 
42  /// default constructor
44 
45  /// default constructor with maxLog
47 
48  /// copy constructor
51  }
52 
53  /// move constructor
56  }
57 
58  /// destructor
60 
61  /// copy operator
/// @param from the Miic instance to copy from
/// @return a reference to *this, which allows chaining assignments
/// NOTE(review): this listing jumps from line 62 to line 64, so any statement
/// on line 63 is not visible here -- confirm against the full source.
62  Miic& Miic::operator=(const Miic& from) {
64  return *this;
65  }
66 
67  /// move operator
70  return *this;
71  }
72 
73 
75  const std::pair<
77  double >& e1,
78  const std::pair<
80  double >& e2) const {
81  return e1.second > e2.second;
82  }
83 
85  const std::pair< std::tuple< NodeId, NodeId, NodeId >*, double >& e1,
86  const std::pair< std::tuple< NodeId, NodeId, NodeId >*, double >& e2)
87  const {
88  return std::abs(e1.second) > std::abs(e2.second);
89  }
90 
92  const std::
93  tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double, double >&
94  e1,
95  const std::
96  tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double, double >&
97  e2) const {
98  double p1xz = std::get< 2 >(e1);
99  double p1yz = std::get< 3 >(e1);
100  double p2xz = std::get< 2 >(e2);
101  double p2yz = std::get< 3 >(e2);
102  double I1 = std::get< 1 >(e1);
103  double I2 = std::get< 1 >(e2);
104  // First, we compare the signs of the information values.
105  // Then, the probability values,
106  // and finally the absolute values of the information.
107  if ((I1 < 0 && I2 < 0) || (I1 >= 0 && I2 >= 0)) {
108  if (std::max(p1xz, p1yz) == std::max(p2xz, p2yz)) {
109  return std::abs(I1) > std::abs(I2);
110  } else {
111  return std::max(p1xz, p1yz) > std::max(p2xz, p2yz);
112  }
113  } else {
114  return I1 < I2;
115  }
116  }
117 
118  /// learns the structure of a MixedGraph
120  MixedGraph graph) {
121  timer_.reset();
122  current_step_ = 0;
123 
124  // clear the vector of latent arcs to be sure
126 
127  /// the heap of ranks, with the score, and the NodeIds of x, y and z.
128  Heap<
130  double >,
132  rank_;
133 
134  /// the variables separation sets
136 
138 
140 
141  // std::cout << "Le graphe contient: " << graph.sizeEdges() << " edges." <<
142  // std::endl; std::cout << "En voici la liste: " << graph.edges() <<
143  // std::endl;
144 
145  if (usemiic__) {
147  } else {
149  }
150 
151  return graph;
152  }
153 
154  /*
155  * PHASE 1 : INITIATION
156  *
157  * We go over all edges and test if the variables are independent. If they
158  * are,
159  * the edge is deleted. If not, the best contributor is found.
160  */
163  MixedGraph& graph,
165  Heap<
167  double >,
168  GreaterPairOn2nd >& rank_) {
169  NodeId x, y;
170  EdgeSet edges = graph.edges();
172 
173  for (const Edge& edge: edges) {
174  x = edge.first();
175  y = edge.second();
176  double Ixy = I.score(x, y);
177 
178  if (Ixy <= 0) { //< K
181  } else {
183  }
184 
185  ++current_step_;
186  if (onProgress.hasListener()) {
188  (current_step_ * 33) / steps_init,
189  0.,
190  timer_.step());
191  }
192  }
193  }
194 
195  /*
196  * PHASE 2 : ITERATION
197  *
198  * As long as we find important nodes for edges, we go over them to see if
199  * we can assess the independence of the variables.
200  */
203  MixedGraph& graph,
205  Heap<
207  double >,
208  GreaterPairOn2nd >& rank_) {
209  // if no triples to further examine pass
211  double >
212  best;
213 
216 
217  try {
218  while (rank_.top().second > 0.5) {
219  best = rank_.pop();
220 
221  const NodeId x = std::get< 0 >(*(best.first));
222  const NodeId y = std::get< 1 >(*(best.first));
223  const NodeId z = std::get< 2 >(*(best.first));
224  std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
225 
226  ui.push_back(z);
227  const double Ixy_ui = I.score(x, y, ui);
228  if (Ixy_ui < 0) {
229  graph.eraseEdge(Edge(x, y));
231  } else {
233  }
234 
235  delete best.first;
236 
237  ++current_step_;
238  if (onProgress.hasListener()) {
240  (current_step_ * 66) / (steps_init + steps_iter),
241  0.,
242  timer_.step());
243  }
244  }
245  } catch (...) {} // here, rank is empty
247  if (onProgress.hasListener()) {
248  GUM_EMIT3(onProgress, 66, 0., timer_.step());
249  }
251  }
252 
253  /*
254  * PHASE 3 : ORIENTATION
255  *
256  * Try to assess v-structures and propagate them.
257  */
260  MixedGraph& graph,
261  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
262  sep_set) {
263  std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*, double > >
267 
268  // marks always correspond to the head of the arc/edge. - is for a forbidden
269  // arc, > for a mandatory arc
270  // we start by adding the mandatory arcs
271  for (auto iter = initial_marks__.begin(); iter != initial_marks__.end();
272  ++iter) {
274  && iter.val() == '>') {
277  }
278  }
279 
280  NodeId i = 0;
281  // list of elements that we shouldn't read again, i.e. elements that are
282  // eligible for
283  // rule 0 after the first time they are tested, and elements on which rule 1
284  // has been applied
285  while (i < triples.size()) {
286  // if i not in do_not_reread
287  std::pair< std::tuple< NodeId, NodeId, NodeId >*, double > triple
288  = triples[i];
289  NodeId x, y, z;
290  x = std::get< 0 >(*triple.first);
291  y = std::get< 1 >(*triple.first);
292  z = std::get< 2 >(*triple.first);
293 
294  std::vector< NodeId > ui;
295  std::pair< NodeId, NodeId > key = {x, y};
296  std::pair< NodeId, NodeId > rev_key = {y, x};
297  if (sep_set.exists(key)) {
298  ui = sep_set[key];
299  } else if (sep_set.exists(rev_key)) {
300  ui = sep_set[rev_key];
301  }
302  double Ixyz_ui = triple.second;
303  bool reset{false};
304  // try Rule 0
305  if (Ixyz_ui < 0) {
306  // if ( z not in Sep[x,y])
307  if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
308  if (!graph.existsArc(x, z) && !graph.existsArc(z, x)) {
309  // when we try to add an arc to the graph, we always verify if
310  // we are allowed to do so, i.e. it is not a forbidden arc and it
311  // does not create a cycle
312  if (!existsDirectedPath__(graph, z, x)
313  && !(initial_marks__.exists({x, z})
314  && initial_marks__[{x, z}] == '-')) {
315  reset = true;
316  graph.eraseEdge(Edge(x, z));
317  graph.addArc(x, z);
318  } else if (existsDirectedPath__(graph, z, x)
319  && !(initial_marks__.exists({z, x})
320  && initial_marks__[{z, x}] == '-')) {
321  reset = true;
322  graph.eraseEdge(Edge(x, z));
323  // if we find a cycle, we force the competing edge
324  graph.addArc(z, x);
327  Arc(z, x))
328  == latent_couples__.end()) {
330  }
331  }
332  } else if (!graph.existsArc(y, z) && !graph.existsArc(z, y)) {
333  if (!existsDirectedPath__(graph, z, y)
334  && !(initial_marks__.exists({x, z})
335  && initial_marks__[{x, z}] == '-')) {
336  reset = true;
337  graph.eraseEdge(Edge(y, z));
338  graph.addArc(y, z);
339  } else if (existsDirectedPath__(graph, z, y)
340  && !(initial_marks__.exists({z, y})
341  && initial_marks__[{z, y}] == '-')) {
342  reset = true;
343  graph.eraseEdge(Edge(y, z));
344  // if we find a cycle, we force the competing edge
345  graph.addArc(z, y);
348  Arc(z, y))
349  == latent_couples__.end()) {
351  }
352  }
353  } else {
354  // checking if the anti-directed arc already exists, to register a
355  // latent variable
356  if (graph.existsArc(z, x)
359  Arc(z, x))
360  == latent_couples__.end()
363  Arc(x, z))
364  == latent_couples__.end()) {
366  }
367  if (graph.existsArc(z, y)
370  Arc(z, y))
371  == latent_couples__.end()
374  Arc(y, z))
375  == latent_couples__.end()) {
377  }
378  }
379  }
380  } else { // try Rule 1
381  if (graph.existsArc(x, z) && !graph.existsArc(z, y)
382  && !graph.existsArc(y, z)) {
383  if (!existsDirectedPath__(graph, y, z)
384  && !(initial_marks__.exists({z, y})
385  && initial_marks__[{z, y}] == '-')) {
386  reset = true;
387  graph.eraseEdge(Edge(z, y));
388  graph.addArc(z, y);
389  } else if (existsDirectedPath__(graph, y, z)
390  && !(initial_marks__.exists({y, z})
391  && initial_marks__[{y, z}] == '-')) {
392  reset = true;
393  graph.eraseEdge(Edge(z, y));
394  // if we find a cycle, we force the competing edge
395  graph.addArc(y, z);
398  Arc(y, z))
399  == latent_couples__.end()) {
401  }
402  }
403  }
404  if (graph.existsArc(y, z) && !graph.existsArc(z, x)
405  && !graph.existsArc(x, z)) {
406  if (!existsDirectedPath__(graph, x, z)
407  && !(initial_marks__.exists({z, x})
408  && initial_marks__[{z, x}] == '-')) {
409  reset = true;
410  graph.eraseEdge(Edge(z, x));
411  graph.addArc(z, x);
412  } else if (existsDirectedPath__(graph, x, z)
413  && !(initial_marks__.exists({x, z})
414  && initial_marks__[{x, z}] == '-')) {
415  reset = true;
416  graph.eraseEdge(Edge(z, x));
417  // if we find a cycle, we force the competing edge
418  graph.addArc(x, z);
421  Arc(x, z))
422  == latent_couples__.end()) {
424  }
425  }
426  }
427  } // if rule 0 or rule 1
428 
429  // if what we want to add already exists : pass to the next triplet
430  if (reset) {
431  i = 0;
432  } else {
433  ++i;
434  }
435  if (onProgress.hasListener()) {
437  ((current_step_ + i) * 100) / (past_steps + steps_orient),
438  0.,
439  timer_.step());
440  }
441  } // while
442 
443  // erasing the double-headed arcs
444  for (const Arc& arc: latent_couples__) {
445  graph.eraseArc(Arc(arc.head(), arc.tail()));
446  }
447  }
448 
449  /// variant trying to propagate both orientations in a bidirected arc
452  MixedGraph& graph,
453  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
454  sep_set) {
455  std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*, double > >
459 
460  NodeId i = 0;
461  // list of elements that we shouldn't read again, i.e. elements that are
462  // eligible for
463  // rule 0 after the first time they are tested, and elements on which rule 1
464  // has been applied
465  while (i < triples.size()) {
466  // if i not in do_not_reread
467  std::pair< std::tuple< NodeId, NodeId, NodeId >*, double > triple
468  = triples[i];
469  NodeId x, y, z;
470  x = std::get< 0 >(*triple.first);
471  y = std::get< 1 >(*triple.first);
472  z = std::get< 2 >(*triple.first);
473 
474  std::vector< NodeId > ui;
475  std::pair< NodeId, NodeId > key = {x, y};
476  std::pair< NodeId, NodeId > rev_key = {y, x};
477  if (sep_set.exists(key)) {
478  ui = sep_set[key];
479  } else if (sep_set.exists(rev_key)) {
480  ui = sep_set[rev_key];
481  }
482  double Ixyz_ui = triple.second;
483  // try Rule 0
484  if (Ixyz_ui < 0) {
485  // if ( z not in Sep[x,y])
486  if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
487  // if what we want to add already exists : pass
488  if ((graph.existsArc(x, z) || graph.existsArc(z, x))
489  && (graph.existsArc(y, z) || graph.existsArc(z, y))) {
490  ++i;
491  } else {
492  i = 0;
493  graph.eraseEdge(Edge(x, z));
494  graph.eraseEdge(Edge(y, z));
495  // checking for cycles
496  if (graph.existsArc(z, x)) {
497  graph.eraseArc(Arc(z, x));
498  try {
500  // if we find a cycle, we force the competing edge
502  } catch (gum::NotFound) { graph.addArc(x, z); }
503  graph.addArc(z, x);
504  } else {
505  try {
507  // if we find a cycle, we force the competing edge
508  graph.addArc(z, x);
510  } catch (gum::NotFound) { graph.addArc(x, z); }
511  }
512  if (graph.existsArc(z, y)) {
513  graph.eraseArc(Arc(z, y));
514  try {
516  // if we find a cycle, we force the competing edge
518  } catch (gum::NotFound) { graph.addArc(y, z); }
519  graph.addArc(z, y);
520  } else {
521  try {
523  // if we find a cycle, we force the competing edge
524  graph.addArc(z, y);
526 
527  } catch (gum::NotFound) { graph.addArc(y, z); }
528  }
529  if (graph.existsArc(z, x)
532  Arc(z, x))
533  == latent_couples__.end()
536  Arc(x, z))
537  == latent_couples__.end()) {
539  }
540  if (graph.existsArc(z, y)
543  Arc(z, y))
544  == latent_couples__.end()
547  Arc(y, z))
548  == latent_couples__.end()) {
550  }
551  }
552  } else {
553  ++i;
554  }
555  } else { // try Rule 1
556  bool reset{false};
557  if (graph.existsArc(x, z) && !graph.existsArc(z, y)
558  && !graph.existsArc(y, z)) {
559  reset = true;
560  graph.eraseEdge(Edge(z, y));
561  try {
563  // if we find a cycle, we force the competing edge
564  graph.addArc(y, z);
566  } catch (gum::NotFound) { graph.addArc(z, y); }
567  }
568  if (graph.existsArc(y, z) && !graph.existsArc(z, x)
569  && !graph.existsArc(x, z)) {
570  reset = true;
571  graph.eraseEdge(Edge(z, x));
572  try {
574  // if we find a cycle, we force the competing edge
575  graph.addArc(x, z);
577  } catch (gum::NotFound) { graph.addArc(z, x); }
578  }
579 
580  if (reset) {
581  i = 0;
582  } else {
583  ++i;
584  }
585  } // if rule 0 or rule 1
586  if (onProgress.hasListener()) {
588  ((current_step_ + i) * 100) / (past_steps + steps_orient),
589  0.,
590  timer_.step());
591  }
592  } // while
593 
594  // erasing the double-headed arcs
595  for (const Arc& arc: latent_couples__) {
596  graph.eraseArc(Arc(arc.head(), arc.tail()));
597  }
598  }
599 
600  /// variant using the orientation protocol of MIIC
601  void
603  MixedGraph& graph,
604  const HashTable< std::pair< NodeId, NodeId >,
605  std::vector< NodeId > >& sep_set) {
606  // structure to store the orientations marks -, o, or >,
607  // Considers the head of the arc/edge first node -* second node
609 
610  // marks always correspond to the head of the arc/edge. - is for a forbidden
611  // arc, > for a mandatory arc
612  // we start by adding the mandatory arcs
613  for (auto iter = marks.begin(); iter != marks.end(); ++iter) {
615  && iter.val() == '>') {
618  }
619  }
620 
622  double,
623  double,
624  double > >
626 
629 
630  std::tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double, double >
631  best;
632  if (steps_orient > 0) { best = proba_triples[0]; }
633 
634  while (!proba_triples.empty()
635  && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) {
636  NodeId x, y, z;
637  x = std::get< 0 >(*std::get< 0 >(best));
638  y = std::get< 1 >(*std::get< 0 >(best));
639  z = std::get< 2 >(*std::get< 0 >(best));
640  // std::cout << "Triple: (" << x << "," << y << "," << z << ")" <<
641  // std::endl;
642 
643  const double i3 = std::get< 1 >(best);
644 
645  if (i3 <= 0) {
646  // v-structure discovery
647  if (marks[{x, z}] == 'o' && marks[{y, z}] == 'o') { // If x-z-y
648  if (!existsDirectedPath__(graph, z, x, false)) {
649  graph.eraseEdge(Edge(x, z));
650  graph.addArc(x, z);
651  // std::cout << "1.a Removing edge (" << x << "," << z << ")" <<
652  // std::endl; std::cout << "1.a Adding arc (" << x << "," << z << ")"
653  // << std::endl;
654  marks[{x, z}] = '>';
655  if (graph.existsArc(z, x)
658  Arc(z, x))
659  == latent_couples__.end()
662  Arc(x, z))
663  == latent_couples__.end()) {
664  // std::cout << "Adding latent couple (" << z << "," << x << ")" <<
665  // std::endl;
667  }
668  if (!arc_probas__.exists(Arc(x, z)))
669  arc_probas__.insert(Arc(x, z), std::get< 2 >(best));
670  } else {
671  graph.eraseEdge(Edge(x, z));
672  // std::cout << "1.b Adding arc (" << x << "," << z << ")" <<
673  // std::endl;
674  if (!existsDirectedPath__(graph, x, z, false)) {
675  graph.addArc(z, x);
676  // std::cout << "1.b Removing edge (" << x << "," << z << ")" <<
677  // std::endl;
678  marks[{z, x}] = '>';
679  }
680  }
681 
682  if (!existsDirectedPath__(graph, z, y, false)) {
683  graph.eraseEdge(Edge(y, z));
684  graph.addArc(y, z);
685  // std::cout << "1.c Removing edge (" << y << "," << z << ")" <<
686  // std::endl; std::cout << "1.c Adding arc (" << y << "," << z << ")"
687  // << std::endl;
688  marks[{y, z}] = '>';
689  if (graph.existsArc(z, y)
692  Arc(z, y))
693  == latent_couples__.end()
696  Arc(y, z))
697  == latent_couples__.end()) {
699  }
700  if (!arc_probas__.exists(Arc(y, z)))
701  arc_probas__.insert(Arc(y, z), std::get< 3 >(best));
702  } else {
703  graph.eraseEdge(Edge(y, z));
704  // std::cout << "1.d Removing edge (" << y << "," << z << ")" <<
705  // std::endl;
706  if (!existsDirectedPath__(graph, y, z, false)) {
707  graph.addArc(z, y);
708  // std::cout << "1.d Adding arc (" << z << "," << y << ")" <<
709  // std::endl;
710  marks[{z, y}] = '>';
711  }
712  }
713  } else if (marks[{x, z}] == '>' && marks[{y, z}] == 'o') { // If x->z-y
714  if (!existsDirectedPath__(graph, z, y, false)) {
715  graph.eraseEdge(Edge(y, z));
716  graph.addArc(y, z);
717  // std::cout << "2.a Removing edge (" << y << "," << z << ")" <<
718  // std::endl; std::cout << "2.a Adding arc (" << y << "," << z << ")"
719  // << std::endl;
720  marks[{y, z}] = '>';
721  if (graph.existsArc(z, y)
724  Arc(z, y))
725  == latent_couples__.end()
728  Arc(y, z))
729  == latent_couples__.end()) {
731  }
732  if (!arc_probas__.exists(Arc(y, z)))
733  arc_probas__.insert(Arc(y, z), std::get< 3 >(best));
734  } else {
735  graph.eraseEdge(Edge(y, z));
736  // std::cout << "2.b Removing edge (" << y << "," << z << ")" <<
737  // std::endl;
738  if (!existsDirectedPath__(graph, y, z, false)) {
739  graph.addArc(z, y);
740  // std::cout << "2.b Adding arc (" << y << "," << z << ")" <<
741  // std::endl;
742  marks[{z, y}] = '>';
743  }
744  }
745  } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o') {
746  if (!existsDirectedPath__(graph, z, x, false)) {
747  graph.eraseEdge(Edge(x, z));
748  graph.addArc(x, z);
749  // std::cout << "3.a Removing edge (" << x << "," << z << ")" <<
750  // std::endl; std::cout << "3.a Adding arc (" << x << "," << z << ")"
751  // << std::endl;
752  marks[{x, z}] = '>';
753  if (graph.existsArc(z, x)
756  Arc(z, x))
757  == latent_couples__.end()
760  Arc(x, z))
761  == latent_couples__.end()) {
763  }
764  if (!arc_probas__.exists(Arc(x, z)))
765  arc_probas__.insert(Arc(x, z), std::get< 2 >(best));
766  } else {
767  graph.eraseEdge(Edge(x, z));
768  // std::cout << "3.b Removing edge (" << x << "," << z << ")" <<
769  // std::endl;
770  if (!existsDirectedPath__(graph, x, z, false)) {
771  graph.addArc(z, x);
772  // std::cout << "3.b Adding arc (" << x << "," << z << ")" <<
773  // std::endl;
774  marks[{z, x}] = '>';
775  }
776  }
777  }
778 
779  } else {
780  // orientation propagation
781  if (marks[{x, z}] == '>' && marks[{y, z}] == 'o'
782  && marks[{z, y}] != '-') {
783  graph.eraseEdge(Edge(z, y));
784  // std::cout << "4. Removing edge (" << z << "," << y << ")" <<
785  // std::endl;
786  if (!existsDirectedPath__(graph, y, z) && graph.parents(y).empty()) {
787  graph.addArc(z, y);
788  // std::cout << "4.a Adding arc (" << z << "," << y << ")" <<
789  // std::endl;
790  marks[{z, y}] = '>';
791  marks[{y, z}] = '-';
792  if (!arc_probas__.exists(Arc(z, y)))
793  arc_probas__.insert(Arc(z, y), std::get< 3 >(best));
794  } else if (!existsDirectedPath__(graph, z, y)
795  && graph.parents(z).empty()) {
796  graph.addArc(y, z);
797  // std::cout << "4.b Adding arc (" << y << "," << z << ")" <<
798  // std::endl;
799  marks[{z, y}] = '-';
800  marks[{y, z}] = '>';
802  if (!arc_probas__.exists(Arc(y, z)))
803  arc_probas__.insert(Arc(y, z), std::get< 3 >(best));
804  } else if (!existsDirectedPath__(graph, y, z)) {
805  graph.addArc(z, y);
806  // std::cout << "4.c Adding arc (" << z << "," << y << ")" <<
807  // std::endl;
808  marks[{z, y}] = '>';
809  marks[{y, z}] = '-';
810  if (!arc_probas__.exists(Arc(z, y)))
811  arc_probas__.insert(Arc(z, y), std::get< 3 >(best));
812  } else if (!existsDirectedPath__(graph, z, y)) {
813  graph.addArc(y, z);
814  // std::cout << "4.d Adding arc (" << y << "," << z << ")" <<
815  // std::endl;
817  marks[{z, y}] = '-';
818  marks[{y, z}] = '>';
819  if (!arc_probas__.exists(Arc(y, z)))
820  arc_probas__.insert(Arc(y, z), std::get< 3 >(best));
821  }
822 
823  } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o'
824  && marks[{z, x}] != '-') {
825  graph.eraseEdge(Edge(z, x));
826  // std::cout << "5. Removing edge (" << z << "," << x << ")" <<
827  // std::endl;
828  if (!existsDirectedPath__(graph, x, z) && graph.parents(x).empty()) {
829  graph.addArc(z, x);
830  // std::cout << "5.a Adding arc (" << z << "," << x << ")" <<
831  // std::endl;
832  marks[{z, x}] = '>';
833  marks[{x, z}] = '-';
834  if (!arc_probas__.exists(Arc(z, x)))
835  arc_probas__.insert(Arc(z, x), std::get< 2 >(best));
836  } else if (!existsDirectedPath__(graph, z, x)
837  && graph.parents(z).empty()) {
838  graph.addArc(x, z);
839  // std::cout << "5.b Adding arc (" << x << "," << z << ")" <<
840  // std::endl;
841  marks[{z, x}] = '-';
842  marks[{x, z}] = '>';
844  if (!arc_probas__.exists(Arc(x, z)))
845  arc_probas__.insert(Arc(x, z), std::get< 2 >(best));
846  } else if (!existsDirectedPath__(graph, x, z)) {
847  graph.addArc(z, x);
848  // std::cout << "5.c Adding arc (" << z << "," << x << ")" <<
849  // std::endl;
850  marks[{z, x}] = '>';
851  marks[{x, z}] = '-';
852  if (!arc_probas__.exists(Arc(z, x)))
853  arc_probas__.insert(Arc(z, x), std::get< 2 >(best));
854  } else if (!existsDirectedPath__(graph, z, x)) {
855  graph.addArc(x, z);
856  // std::cout << "5.d Adding arc (" << x << "," << z << ")" <<
857  // std::endl;
858  marks[{z, x}] = '-';
859  marks[{x, z}] = '>';
861  if (!arc_probas__.exists(Arc(x, z)))
862  arc_probas__.insert(Arc(x, z), std::get< 2 >(best));
863  }
864  }
865  }
866 
867  delete std::get< 0 >(best);
869  // actualisation of the list of triples
871 
872  if (!proba_triples.empty()) best = proba_triples[0];
873 
874  ++current_step_;
875  if (onProgress.hasListener()) {
877  (current_step_ * 100) / (steps_orient + past_steps),
878  0.,
879  timer_.step());
880  }
881  } // while
882 
883  // erasing the double headed arcs
885  ++iter) {
886  graph.eraseArc(Arc(iter->head(), iter->tail()));
887  if (existsDirectedPath__(graph, iter->head(), iter->tail())) {
888  // if we find a cycle, we force the competing edge
889  graph.addArc(iter->head(), iter->tail());
890  graph.eraseArc(Arc(iter->tail(), iter->head()));
891  *iter = Arc(iter->head(), iter->tail());
892  }
893  }
894 
895  if (onProgress.hasListener()) {
896  GUM_EMIT3(onProgress, 100, 0., timer_.step());
897  }
898  }
899 
900  /// finds the best contributor node for a pair given a conditioning set
902  NodeId x,
903  NodeId y,
904  const std::vector< NodeId >& ui,
905  const MixedGraph& graph,
907  Heap<
909  double >,
910  GreaterPairOn2nd >& rank_) {
911  double maxP = -1.0;
912  NodeId maxZ = 0;
913 
914  // compute N
915  //__N = I.N();
916  const double Ixy_ui = I.score(x, y, ui);
917 
918  for (const NodeId z: graph) {
919  // if z!=x and z!=y and z not in ui
920  if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
921  double Pnv;
922  double Pb;
923 
924  // Computing Pnv
925  const double Ixyz_ui = I.score(x, y, z, ui);
926  double calc_expo1 = -Ixyz_ui * M_LN2;
927  // if exponentials are too high or too low, crop them at |maxLog__|
928  if (calc_expo1 > maxLog__) {
929  Pnv = 0.0;
930  } else if (calc_expo1 < -maxLog__) {
931  Pnv = 1.0;
932  } else {
933  Pnv = 1 / (1 + std::exp(calc_expo1));
934  }
935 
936  // Computing Pb
937  const double Ixz_ui = I.score(x, z, ui);
938  const double Iyz_ui = I.score(y, z, ui);
939 
940  calc_expo1 = -(Ixz_ui - Ixy_ui) * M_LN2;
941  double calc_expo2 = -(Iyz_ui - Ixy_ui) * M_LN2;
942 
943  // if exponentials are too high or too low, crop them at maxLog__
944  if (calc_expo1 > maxLog__ || calc_expo2 > maxLog__) {
945  Pb = 0.0;
946  } else if (calc_expo1 < -maxLog__ && calc_expo2 < -maxLog__) {
947  Pb = 1.0;
948  } else {
949  double expo1, expo2;
950  if (calc_expo1 < -maxLog__) {
951  expo1 = 0.0;
952  } else {
953  expo1 = std::exp(calc_expo1);
954  }
955  if (calc_expo2 < -maxLog__) {
956  expo2 = 0.0;
957  } else {
958  expo2 = std::exp(calc_expo2);
959  }
960  Pb = 1 / (1 + expo1 + expo2);
961  }
962 
963  // Getting max(min(Pnv, pb))
964  const double min_pnv_pb = std::min(Pnv, Pb);
965  if (min_pnv_pb > maxP) {
966  maxP = min_pnv_pb;
967  maxZ = z;
968  }
969  } // if z not in (x, y)
970  } // for z in graph.nodes
971  // storing best z in rank_
973  double >
974  final;
975  auto tup
976  = new std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >{x,
977  y,
978  maxZ,
979  ui};
980  final.first = tup;
981  final.second = maxP;
982  rank_.insert(final);
983  }
984 
985  /// gets the list of unshielded triples in the graph in decreasing value of
986  ///|I'(x, y, z|{ui})|
987  std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*, double > >
989  const MixedGraph& graph,
991  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
992  sep_set) {
993  std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*, double > >
994  triples;
995  for (NodeId z: graph) {
996  for (NodeId x: graph.neighbours(z)) {
997  for (NodeId y: graph.neighbours(z)) {
998  if (y < x && !graph.existsEdge(x, y)) {
999  std::vector< NodeId > ui;
1000  std::pair< NodeId, NodeId > key = {x, y};
1001  std::pair< NodeId, NodeId > rev_key = {y, x};
1002  if (sep_set.exists(key)) {
1003  ui = sep_set[key];
1004  } else if (sep_set.exists(rev_key)) {
1005  ui = sep_set[rev_key];
1006  }
1007  // remove z from ui if it's present
1008  const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
1009  if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
1010 
1011  double Ixyz_ui = I.score(x, y, z, ui);
1012  std::pair< std::tuple< NodeId, NodeId, NodeId >*, double > triple;
1013  auto tup = new std::tuple< NodeId, NodeId, NodeId >{x, y, z};
1014  triple.first = tup;
1015  triple.second = Ixyz_ui;
1017  }
1018  }
1019  }
1020  }
1022  return triples;
1023  }
1024 
1025  /// gets the list of unshielded triples in the graph in decreasing value of
1026  ///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC
1027  std::vector<
1028  std::
1029  tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double, double > >
1031  const MixedGraph& graph,
1033  const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
1034  sep_set,
1035  HashTable< std::pair< NodeId, NodeId >, char >& marks) {
1037  double,
1038  double,
1039  double > >
1040  triples;
1041  for (NodeId z: graph) {
1042  for (NodeId x: graph.neighbours(z)) {
1043  for (NodeId y: graph.neighbours(z)) {
1044  if (y < x && !graph.existsEdge(x, y)) {
1045  std::vector< NodeId > ui;
1046  std::pair< NodeId, NodeId > key = {x, y};
1047  std::pair< NodeId, NodeId > rev_key = {y, x};
1048  if (sep_set.exists(key)) {
1049  ui = sep_set[key];
1050  } else if (sep_set.exists(rev_key)) {
1051  ui = sep_set[rev_key];
1052  }
1053  // remove z from ui if it's present
1054  const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
1055  if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
1056 
1057  const double Ixyz_ui = I.score(x, y, z, ui);
1058  auto tup = new std::tuple< NodeId, NodeId, NodeId >{x, y, z};
1059  std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
1060  double,
1061  double,
1062  double >
1063  triple{tup, Ixyz_ui, 0.5, 0.5};
1065  if (!marks.exists({x, z})) { marks.insert({x, z}, 'o'); }
1066  if (!marks.exists({z, x})) { marks.insert({z, x}, 'o'); }
1067  if (!marks.exists({y, z})) { marks.insert({y, z}, 'o'); }
1068  if (!marks.exists({z, y})) { marks.insert({z, y}, 'o'); }
1069  }
1070  }
1071  }
1072  }
1075  return triples;
1076  }
1077 
1078  /// Gets the orientation probabilities like MIIC for the orientation phase
1079  std::vector<
1080  std::
1081  tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double, double > >
1083  const MixedGraph& graph,
1085  double,
1086  double,
1087  double > > proba_triples) {
1088  for (auto& triple: proba_triples) {
1089  NodeId x, y, z;
1090  x = std::get< 0 >(*std::get< 0 >(triple));
1091  y = std::get< 1 >(*std::get< 0 >(triple));
1092  z = std::get< 2 >(*std::get< 0 >(triple));
1093  const double Ixyz = std::get< 1 >(triple);
1094  double Pxz = std::get< 2 >(triple);
1095  double Pyz = std::get< 3 >(triple);
1096 
1097  if (Ixyz <= 0) {
1098  const double expo = std::exp(Ixyz);
1099  const double P0 = (1 + expo) / (1 + 3 * expo);
1100  // distinguish between the initialization and the update process
1101  if (Pxz == Pyz && Pyz == 0.5) {
1102  std::get< 2 >(triple) = P0;
1103  std::get< 3 >(triple) = P0;
1104  } else {
1105  if (graph.existsArc(x, z) && Pxz >= P0) {
1106  std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
1107  } else if (graph.existsArc(y, z) && Pyz >= P0) {
1108  std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
1109  }
1110  }
1111  } else {
1112  const double expo = std::exp(-Ixyz);
1113  if (graph.existsArc(x, z) && Pxz >= 0.5) {
1114  std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
1115  } else if (graph.existsArc(y, z) && Pyz >= 0.5) {
1116  std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
1117  }
1118  }
1119  }
1121  return proba_triples;
1122  }
1123 
1124  /// learns the structure of a Bayesian network, i.e. a DAG, from an essential
1125  /// graph.
1129  // std::cout << "Le mixed graph mesdames et messieurs: "
1130  //<< essentialGraph.toDot() << std::endl;
1131 
1132  // Second, orientate remaining edges
1134  // first, propagate existing orientations
1135  for (NodeId x: order) {
1136  if (!essentialGraph.parents(x).empty()) {
1138  }
1139  }
1140  // std::cout << "Le mixed graph après une première propagation mesdames et
1141  // messieurs: "
1142  //<< essentialGraph.toDot() << std::endl;
1143  // then decide the orientation by the topological order and propagate them
1144  for (NodeId x: order) {
1145  if (!essentialGraph.neighbours(x).empty()) {
1147  }
1148  }
1149 
1150  // std::cout << "Le mixed graph après une deuxième propagation mesdames et
1151  // messieurs: "
1152  //<< essentialGraph.toDot() << std::endl;
1153  // std::cout << "Le graphe contient maintenant : " <<
1154  // essentialGraph.sizeArcs() << " arcs."
1155  //<< std::endl;
1156  // std::cout << "Que voici: " << essentialGraph.arcs() << std::endl;
1157  // turn the mixed graph into a dag
1158  DAG dag;
1159  for (auto node: essentialGraph) {
1161  }
1162  for (const Arc& arc: essentialGraph.arcs()) {
1163  dag.addArc(arc.tail(), arc.head());
1164  }
1165 
1166  return dag;
1167  }
1168 
1169  /// Propagates the orientation from a node to its neighbours
1171  const auto neighbours = graph.neighbours(node);
1172  for (auto& neighbour: neighbours) {
1173  // std::cout << "Orientation de l'edge (" << node << "," << neighbour <<
1174  // ")" << std::endl;
1178  && initial_marks__[{node, neighbour}] == '-')
1179  && graph.parents(neighbour).empty()) {
1182  // std::cout << "1. Removing edge (" << neighbour << "," << node << ")"
1183  // << std::endl; std::cout << "1. Adding arc (" << node << "," <<
1184  // neighbour << ")" << std::endl;
1186  } else if (!existsDirectedPath__(graph, node, neighbour)
1188  && initial_marks__[{neighbour, node}] == '-')
1189  && graph.parents(node).empty()) {
1192  // std::cout << "2. Removing edge (" << neighbour << "," << node << ")"
1193  // << std::endl; std::cout << "2. Adding arc (" << neighbour << "," <<
1194  // node << ")" << std::endl;
1195  } else if (!existsDirectedPath__(graph, node, neighbour)
1197  && initial_marks__[{neighbour, node}] == '-')) {
1200  if (!graph.parents(neighbour).empty()
1201  && !graph.parents(node).empty()) {
1203  }
1204 
1205  // std::cout << "3. Removing edge (" << neighbour << "," << node << ")"
1206  // << std::endl; std::cout << "3. Adding arc (" << neighbour << "," <<
1207  // node << ")" << std::endl;
1208  } else if (!existsDirectedPath__(graph, neighbour, node)
1210  && initial_marks__[{node, neighbour}] == '-')) {
1213  if (!graph.parents(neighbour).empty()
1214  && !graph.parents(node).empty()) {
1216  }
1217  // std::cout << "4. Removing edge (" << neighbour << "," << node << ")"
1218  // << std::endl; std::cout << "4. Adding arc (" << node << "," <<
1219  // neighbour << ")" << std::endl;
1220  }
1221  // else if (!graph.parents(neighbour).empty()
1222  //&& !graph.parents(node).empty()) {
1223  // graph.eraseEdge(Edge(neighbour, node));
1224  // graph.addArc(node, neighbour);
1225  //__latent_couples.push_back(Arc(node, neighbour));
1226  //}
1227  else {
1229  // std::cout << "5. Removing edge (" << neighbour << "," << node << ")"
1230  // << std::endl;
1231  }
1232  }
1233  }
1234  }
1235 
1236  /// Accessor for the arcs recorded as hiding latent variables: returns a
/// copy of the latent_couples__ member.
1237  const std::vector< Arc > Miic::latentVariables() const {
1238  return this->latent_couples__;
1239  }
1240 
1241  /// learns the structure and the parameters of a BN
1242  template < typename GUM_SCALAR,
1243  typename GRAPH_CHANGES_SELECTOR,
1244  typename PARAM_ESTIMATOR >
1247  DAG initial_dag) {
1248  return DAG2BNLearner<>::createBN< GUM_SCALAR >(
1249  estimator,
1251  }
1252 
/// Turns on the usemiic__ flag, so that the `if (usemiic__)` test made while
/// learning the mixed-graph structure takes its first branch.
1253  void Miic::setMiicBehaviour() { usemiic__ = true; }
/// Turns off the usemiic__ flag, so that the `if (usemiic__)` test made while
/// learning the mixed-graph structure falls through to its else branch.
1254  void Miic::set3off2Behaviour() { usemiic__ = false; }
1255 
1257  HashTable< std::pair< NodeId, NodeId >, char > constraints) {
1258  this->initial_marks__ = constraints;
1259  }
1260 
1261 
1263  const NodeId n1,
1264  const NodeId n2,
1265  const bool countArc) const {
1266  // not recursive version => use a FIFO for simulating the recursion
1267  List< NodeId > nodeFIFO;
1268  nodeFIFO.pushBack(n2);
1269 
1270  // mark[node] = successor if visited, else mark[node] does not exist
1272  mark.insert(n2, n2);
1273 
1274  NodeId current;
1275 
1276  while (!nodeFIFO.empty()) {
1277  current = nodeFIFO.front();
1278  nodeFIFO.popFront();
1279 
1280  // check the parents
1281 
1282  for (const auto new_one: graph.parents(current)) {
1283  if (!countArc && current == n2
1284  && new_one == n1) // If countArc is set to false
1285  continue; // paths of length 1 are ignored
1286 
1287  if (mark.exists(new_one)) // if this node is already marked, do not
1288  continue; // check it again
1289 
1290  if (graph.existsArc(current,
1291  new_one)) // if there is a double arc, pass
1292  continue;
1293 
1295 
1296  if (new_one == n1) { return true; }
1297 
1299  }
1300  }
1301 
1302  return false;
1303  }
1304 
1305  } /* namespace learning */
1306 
1307 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)