aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
Miic.cpp
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/** @file
23
* @brief Implementation of gum::learning::Miic
24
*
25
* @author Quentin FALCAND and Pierre-Henri WUILLEMIN(@LIP6)
26
*/
27
28
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
29
#
include
<
agrum
/
tools
/
core
/
hashTable
.
h
>
30
#
include
<
agrum
/
tools
/
core
/
heap
.
h
>
31
#
include
<
agrum
/
tools
/
core
/
timer
.
h
>
32
#
include
<
agrum
/
tools
/
graphs
/
mixedGraph
.
h
>
33
#
include
<
agrum
/
BN
/
learning
/
Miic
.
h
>
34
#
include
<
agrum
/
BN
/
learning
/
paramUtils
/
DAG2BNLearner
.
h
>
35
#
include
<
agrum
/
tools
/
stattests
/
correctedMutualInformation
.
h
>
36
37
38
namespace
gum
{
39
40
namespace
learning
{
41
42
/// default constructor
43
Miic
::
Miic
() {
GUM_CONSTRUCTOR
(
Miic
); }
44
45
/// default constructor with maxLog
46
Miic
::
Miic
(
int
maxLog
) :
maxLog__
(
maxLog
) {
GUM_CONSTRUCTOR
(
Miic
); }
47
48
/// copy constructor
49
Miic
::
Miic
(
const
Miic
&
from
) :
ApproximationScheme
(
from
) {
50
GUM_CONS_CPY
(
Miic
);
51
}
52
53
/// move constructor
54
Miic
::
Miic
(
Miic
&&
from
) :
ApproximationScheme
(
std
::
move
(
from
)) {
55
GUM_CONS_MOV
(
Miic
);
56
}
57
58
/// destructor
59
Miic
::~
Miic
() {
GUM_DESTRUCTOR
(
Miic
); }
60
61
/// copy operator
62
Miic
&
Miic
::
operator
=(
const
Miic
&
from
) {
63
ApproximationScheme
::
operator
=(
from
);
64
return
*
this
;
65
}
66
67
/// move operator
68
Miic
&
Miic
::
operator
=(
Miic
&&
from
) {
69
ApproximationScheme
::
operator
=(
std
::
move
(
from
));
70
return
*
this
;
71
}
72
73
74
bool
GreaterPairOn2nd
::
operator
()(
75
const
std
::
pair
<
76
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >*,
77
double
>&
e1
,
78
const
std
::
pair
<
79
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >*,
80
double
>&
e2
)
const
{
81
return
e1
.
second
>
e2
.
second
;
82
}
83
84
bool
GreaterAbsPairOn2nd
::
operator
()(
85
const
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
>&
e1
,
86
const
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
>&
e2
)
87
const
{
88
return
std
::
abs
(
e1
.
second
) >
std
::
abs
(
e2
.
second
);
89
}
90
91
bool
GreaterTupleOnLast
::
operator
()(
92
const
std
::
93
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
,
double
,
double
>&
94
e1
,
95
const
std
::
96
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
,
double
,
double
>&
97
e2
)
const
{
98
double
p1xz
=
std
::
get
< 2 >(
e1
);
99
double
p1yz
=
std
::
get
< 3 >(
e1
);
100
double
p2xz
=
std
::
get
< 2 >(
e2
);
101
double
p2yz
=
std
::
get
< 3 >(
e2
);
102
double
I1
=
std
::
get
< 1 >(
e1
);
103
double
I2
=
std
::
get
< 1 >(
e2
);
104
// First, we look at the sign of information.
105
// Then, the probility values
106
// and finally the abs value of information.
107
if
((
I1
< 0 &&
I2
< 0) || (
I1
>= 0 &&
I2
>= 0)) {
108
if
(
std
::
max
(
p1xz
,
p1yz
) ==
std
::
max
(
p2xz
,
p2yz
)) {
109
return
std
::
abs
(
I1
) >
std
::
abs
(
I2
);
110
}
else
{
111
return
std
::
max
(
p1xz
,
p1yz
) >
std
::
max
(
p2xz
,
p2yz
);
112
}
113
}
else
{
114
return
I1
<
I2
;
115
}
116
}
117
118
/// learns the structure of a MixedGraph
119
MixedGraph
Miic
::
learnMixedStructure
(
CorrectedMutualInformation
<>&
I
,
120
MixedGraph
graph
) {
121
timer_
.
reset
();
122
current_step_
= 0;
123
124
// clear the vector of latent arcs to be sure
125
latent_couples__
.
clear
();
126
127
/// the heap of ranks, with the score, and the NodeIds of x, y and z.
128
Heap
<
129
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >*,
130
double
>,
131
GreaterPairOn2nd
>
132
rank_
;
133
134
/// the variables separation sets
135
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >
sep_set
;
136
137
initiation_
(
I
,
graph
,
sep_set
,
rank_
);
138
139
iteration_
(
I
,
graph
,
sep_set
,
rank_
);
140
141
// std::cout << "Le graphe contient: " << graph.sizeEdges() << " edges." <<
142
// std::endl; std::cout << "En voici la liste: " << graph.edges() <<
143
// std::endl;
144
145
if
(
usemiic__
) {
146
orientation_miic_
(
I
,
graph
,
sep_set
);
147
}
else
{
148
orientation_3off2_
(
I
,
graph
,
sep_set
);
149
}
150
151
return
graph
;
152
}
153
154
/*
155
* PHASE 1 : INITIATION
156
*
157
* We go over all edges and test if the variables are independent. If they
158
* are,
159
* the edge is deleted. If not, the best contributor is found.
160
*/
161
void
Miic
::
initiation_
(
162
CorrectedMutualInformation
<>&
I
,
163
MixedGraph
&
graph
,
164
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sep_set
,
165
Heap
<
166
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >*,
167
double
>,
168
GreaterPairOn2nd
>&
rank_
) {
169
NodeId
x
,
y
;
170
EdgeSet
edges
=
graph
.
edges
();
171
Size
steps_init
=
edges
.
size
();
172
173
for
(
const
Edge
&
edge
:
edges
) {
174
x
=
edge
.
first
();
175
y
=
edge
.
second
();
176
double
Ixy
=
I
.
score
(
x
,
y
);
177
178
if
(
Ixy
<= 0) {
//< K
179
graph
.
eraseEdge
(
edge
);
180
sep_set
.
insert
(
std
::
make_pair
(
x
,
y
),
empty_set__
);
181
}
else
{
182
findBestContributor_
(
x
,
y
,
empty_set__
,
graph
,
I
,
rank_
);
183
}
184
185
++
current_step_
;
186
if
(
onProgress
.
hasListener
()) {
187
GUM_EMIT3
(
onProgress
,
188
(
current_step_
* 33) /
steps_init
,
189
0.,
190
timer_
.
step
());
191
}
192
}
193
}
194
195
/*
196
* PHASE 2 : ITERATION
197
*
198
* As long as we find important nodes for edges, we go over them to see if
199
* we can assess the independence of the variables.
200
*/
201
void
Miic
::
iteration_
(
202
CorrectedMutualInformation
<>&
I
,
203
MixedGraph
&
graph
,
204
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sep_set
,
205
Heap
<
206
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >*,
207
double
>,
208
GreaterPairOn2nd
>&
rank_
) {
209
// if no triples to further examine pass
210
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >*,
211
double
>
212
best
;
213
214
Size
steps_init
=
current_step_
;
215
Size
steps_iter
=
rank_
.
size
();
216
217
try
{
218
while
(
rank_
.
top
().
second
> 0.5) {
219
best
=
rank_
.
pop
();
220
221
const
NodeId
x
=
std
::
get
< 0 >(*(
best
.
first
));
222
const
NodeId
y
=
std
::
get
< 1 >(*(
best
.
first
));
223
const
NodeId
z
=
std
::
get
< 2 >(*(
best
.
first
));
224
std
::
vector
<
NodeId
>
ui
=
std
::
move
(
std
::
get
< 3 >(*(
best
.
first
)));
225
226
ui
.
push_back
(
z
);
227
const
double
Ixy_ui
=
I
.
score
(
x
,
y
,
ui
);
228
if
(
Ixy_ui
< 0) {
229
graph
.
eraseEdge
(
Edge
(
x
,
y
));
230
sep_set
.
insert
(
std
::
make_pair
(
x
,
y
),
std
::
move
(
ui
));
231
}
else
{
232
findBestContributor_
(
x
,
y
,
ui
,
graph
,
I
,
rank_
);
233
}
234
235
delete
best
.
first
;
236
237
++
current_step_
;
238
if
(
onProgress
.
hasListener
()) {
239
GUM_EMIT3
(
onProgress
,
240
(
current_step_
* 66) / (
steps_init
+
steps_iter
),
241
0.,
242
timer_
.
step
());
243
}
244
}
245
}
catch
(...) {}
// here, rank is empty
246
current_step_
=
steps_init
+
steps_iter
;
247
if
(
onProgress
.
hasListener
()) {
248
GUM_EMIT3
(
onProgress
, 66, 0.,
timer_
.
step
());
249
}
250
current_step_
=
steps_init
+
steps_iter
;
251
}
252
253
/*
254
* PHASE 3 : ORIENTATION
255
*
256
* Try to assess v-structures and propagate them.
257
*/
258
void
Miic
::
orientation_3off2_
(
259
CorrectedMutualInformation
<>&
I
,
260
MixedGraph
&
graph
,
261
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
262
sep_set
) {
263
std
::
vector
<
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
> >
264
triples
=
getUnshieldedTriples_
(
graph
,
I
,
sep_set
);
265
Size
steps_orient
=
triples
.
size
();
266
Size
past_steps
=
current_step_
;
267
268
// marks always correspond to the head of the arc/edge. - is for a forbidden
269
// arc, > for a mandatory arc
270
// we start by adding the mandatory arcs
271
for
(
auto
iter
=
initial_marks__
.
begin
();
iter
!=
initial_marks__
.
end
();
272
++
iter
) {
273
if
(
graph
.
existsEdge
(
iter
.
key
().
first
,
iter
.
key
().
second
)
274
&&
iter
.
val
() ==
'>'
) {
275
graph
.
eraseEdge
(
Edge
(
iter
.
key
().
first
,
iter
.
key
().
second
));
276
graph
.
addArc
(
iter
.
key
().
first
,
iter
.
key
().
second
);
277
}
278
}
279
280
NodeId
i
= 0;
281
// list of elements that we shouldnt read again, ie elements that are
282
// eligible to
283
// rule 0 after the first time they are tested, and elements on which rule 1
284
// has been applied
285
while
(
i
<
triples
.
size
()) {
286
// if i not in do_not_reread
287
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
>
triple
288
=
triples
[
i
];
289
NodeId
x
,
y
,
z
;
290
x
=
std
::
get
< 0 >(*
triple
.
first
);
291
y
=
std
::
get
< 1 >(*
triple
.
first
);
292
z
=
std
::
get
< 2 >(*
triple
.
first
);
293
294
std
::
vector
<
NodeId
>
ui
;
295
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
296
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
297
if
(
sep_set
.
exists
(
key
)) {
298
ui
=
sep_set
[
key
];
299
}
else
if
(
sep_set
.
exists
(
rev_key
)) {
300
ui
=
sep_set
[
rev_key
];
301
}
302
double
Ixyz_ui
=
triple
.
second
;
303
bool
reset
{
false
};
304
// try Rule 0
305
if
(
Ixyz_ui
< 0) {
306
// if ( z not in Sep[x,y])
307
if
(
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
308
if
(!
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
x
)) {
309
// when we try to add an arc to the graph, we always verify if
310
// we are allowed to do so, ie it is not a forbidden arc an it
311
// does not create a cycle
312
if
(!
existsDirectedPath__
(
graph
,
z
,
x
)
313
&& !(
initial_marks__
.
exists
({
x
,
z
})
314
&&
initial_marks__
[{
x
,
z
}] ==
'-'
)) {
315
reset
=
true
;
316
graph
.
eraseEdge
(
Edge
(
x
,
z
));
317
graph
.
addArc
(
x
,
z
);
318
}
else
if
(
existsDirectedPath__
(
graph
,
z
,
x
)
319
&& !(
initial_marks__
.
exists
({
z
,
x
})
320
&&
initial_marks__
[{
z
,
x
}] ==
'-'
)) {
321
reset
=
true
;
322
graph
.
eraseEdge
(
Edge
(
x
,
z
));
323
// if we find a cycle, we force the competing edge
324
graph
.
addArc
(
z
,
x
);
325
if
(
std
::
find
(
latent_couples__
.
begin
(),
326
latent_couples__
.
end
(),
327
Arc
(
z
,
x
))
328
==
latent_couples__
.
end
()) {
329
latent_couples__
.
push_back
(
Arc
(
z
,
x
));
330
}
331
}
332
}
else
if
(!
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
y
)) {
333
if
(!
existsDirectedPath__
(
graph
,
z
,
y
)
334
&& !(
initial_marks__
.
exists
({
x
,
z
})
335
&&
initial_marks__
[{
x
,
z
}] ==
'-'
)) {
336
reset
=
true
;
337
graph
.
eraseEdge
(
Edge
(
y
,
z
));
338
graph
.
addArc
(
y
,
z
);
339
}
else
if
(
existsDirectedPath__
(
graph
,
z
,
y
)
340
&& !(
initial_marks__
.
exists
({
z
,
y
})
341
&&
initial_marks__
[{
z
,
y
}] ==
'-'
)) {
342
reset
=
true
;
343
graph
.
eraseEdge
(
Edge
(
y
,
z
));
344
// if we find a cycle, we force the competing edge
345
graph
.
addArc
(
z
,
y
);
346
if
(
std
::
find
(
latent_couples__
.
begin
(),
347
latent_couples__
.
end
(),
348
Arc
(
z
,
y
))
349
==
latent_couples__
.
end
()) {
350
latent_couples__
.
push_back
(
Arc
(
z
,
y
));
351
}
352
}
353
}
else
{
354
// checking if the anti-directed arc already exists, to register a
355
// latent variable
356
if
(
graph
.
existsArc
(
z
,
x
)
357
&&
std
::
find
(
latent_couples__
.
begin
(),
358
latent_couples__
.
end
(),
359
Arc
(
z
,
x
))
360
==
latent_couples__
.
end
()
361
&&
std
::
find
(
latent_couples__
.
begin
(),
362
latent_couples__
.
end
(),
363
Arc
(
x
,
z
))
364
==
latent_couples__
.
end
()) {
365
latent_couples__
.
push_back
(
Arc
(
z
,
x
));
366
}
367
if
(
graph
.
existsArc
(
z
,
y
)
368
&&
std
::
find
(
latent_couples__
.
begin
(),
369
latent_couples__
.
end
(),
370
Arc
(
z
,
y
))
371
==
latent_couples__
.
end
()
372
&&
std
::
find
(
latent_couples__
.
begin
(),
373
latent_couples__
.
end
(),
374
Arc
(
y
,
z
))
375
==
latent_couples__
.
end
()) {
376
latent_couples__
.
push_back
(
Arc
(
z
,
y
));
377
}
378
}
379
}
380
}
else
{
// try Rule 1
381
if
(
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
y
)
382
&& !
graph
.
existsArc
(
y
,
z
)) {
383
if
(!
existsDirectedPath__
(
graph
,
y
,
z
)
384
&& !(
initial_marks__
.
exists
({
z
,
y
})
385
&&
initial_marks__
[{
z
,
y
}] ==
'-'
)) {
386
reset
=
true
;
387
graph
.
eraseEdge
(
Edge
(
z
,
y
));
388
graph
.
addArc
(
z
,
y
);
389
}
else
if
(
existsDirectedPath__
(
graph
,
y
,
z
)
390
&& !(
initial_marks__
.
exists
({
y
,
z
})
391
&&
initial_marks__
[{
y
,
z
}] ==
'-'
)) {
392
reset
=
true
;
393
graph
.
eraseEdge
(
Edge
(
z
,
y
));
394
// if we find a cycle, we force the competing edge
395
graph
.
addArc
(
y
,
z
);
396
if
(
std
::
find
(
latent_couples__
.
begin
(),
397
latent_couples__
.
end
(),
398
Arc
(
y
,
z
))
399
==
latent_couples__
.
end
()) {
400
latent_couples__
.
push_back
(
Arc
(
y
,
z
));
401
}
402
}
403
}
404
if
(
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
x
)
405
&& !
graph
.
existsArc
(
x
,
z
)) {
406
if
(!
existsDirectedPath__
(
graph
,
x
,
z
)
407
&& !(
initial_marks__
.
exists
({
z
,
x
})
408
&&
initial_marks__
[{
z
,
x
}] ==
'-'
)) {
409
reset
=
true
;
410
graph
.
eraseEdge
(
Edge
(
z
,
x
));
411
graph
.
addArc
(
z
,
x
);
412
}
else
if
(
existsDirectedPath__
(
graph
,
x
,
z
)
413
&& !(
initial_marks__
.
exists
({
x
,
z
})
414
&&
initial_marks__
[{
x
,
z
}] ==
'-'
)) {
415
reset
=
true
;
416
graph
.
eraseEdge
(
Edge
(
z
,
x
));
417
// if we find a cycle, we force the competing edge
418
graph
.
addArc
(
x
,
z
);
419
if
(
std
::
find
(
latent_couples__
.
begin
(),
420
latent_couples__
.
end
(),
421
Arc
(
x
,
z
))
422
==
latent_couples__
.
end
()) {
423
latent_couples__
.
push_back
(
Arc
(
x
,
z
));
424
}
425
}
426
}
427
}
// if rule 0 or rule 1
428
429
// if what we want to add already exists : pass to the next triplet
430
if
(
reset
) {
431
i
= 0;
432
}
else
{
433
++
i
;
434
}
435
if
(
onProgress
.
hasListener
()) {
436
GUM_EMIT3
(
onProgress
,
437
((
current_step_
+
i
) * 100) / (
past_steps
+
steps_orient
),
438
0.,
439
timer_
.
step
());
440
}
441
}
// while
442
443
// erasing the the double headed arcs
444
for
(
const
Arc
&
arc
:
latent_couples__
) {
445
graph
.
eraseArc
(
Arc
(
arc
.
head
(),
arc
.
tail
()));
446
}
447
}
448
449
/// varient trying to propagate both orientations in a bidirected arc
450
void
Miic
::
orientation_latents_
(
451
CorrectedMutualInformation
<>&
I
,
452
MixedGraph
&
graph
,
453
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
454
sep_set
) {
455
std
::
vector
<
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
> >
456
triples
=
getUnshieldedTriples_
(
graph
,
I
,
sep_set
);
457
Size
steps_orient
=
triples
.
size
();
458
Size
past_steps
=
current_step_
;
459
460
NodeId
i
= 0;
461
// list of elements that we shouldnt read again, ie elements that are
462
// eligible to
463
// rule 0 after the first time they are tested, and elements on which rule 1
464
// has been applied
465
while
(
i
<
triples
.
size
()) {
466
// if i not in do_not_reread
467
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
>
triple
468
=
triples
[
i
];
469
NodeId
x
,
y
,
z
;
470
x
=
std
::
get
< 0 >(*
triple
.
first
);
471
y
=
std
::
get
< 1 >(*
triple
.
first
);
472
z
=
std
::
get
< 2 >(*
triple
.
first
);
473
474
std
::
vector
<
NodeId
>
ui
;
475
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
476
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
477
if
(
sep_set
.
exists
(
key
)) {
478
ui
=
sep_set
[
key
];
479
}
else
if
(
sep_set
.
exists
(
rev_key
)) {
480
ui
=
sep_set
[
rev_key
];
481
}
482
double
Ixyz_ui
=
triple
.
second
;
483
// try Rule 0
484
if
(
Ixyz_ui
< 0) {
485
// if ( z not in Sep[x,y])
486
if
(
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
487
// if what we want to add already exists : pass
488
if
((
graph
.
existsArc
(
x
,
z
) ||
graph
.
existsArc
(
z
,
x
))
489
&& (
graph
.
existsArc
(
y
,
z
) ||
graph
.
existsArc
(
z
,
y
))) {
490
++
i
;
491
}
else
{
492
i
= 0;
493
graph
.
eraseEdge
(
Edge
(
x
,
z
));
494
graph
.
eraseEdge
(
Edge
(
y
,
z
));
495
// checking for cycles
496
if
(
graph
.
existsArc
(
z
,
x
)) {
497
graph
.
eraseArc
(
Arc
(
z
,
x
));
498
try
{
499
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
x
);
500
// if we find a cycle, we force the competing edge
501
latent_couples__
.
push_back
(
Arc
(
z
,
x
));
502
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
x
,
z
); }
503
graph
.
addArc
(
z
,
x
);
504
}
else
{
505
try
{
506
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
x
);
507
// if we find a cycle, we force the competing edge
508
graph
.
addArc
(
z
,
x
);
509
latent_couples__
.
push_back
(
Arc
(
z
,
x
));
510
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
x
,
z
); }
511
}
512
if
(
graph
.
existsArc
(
z
,
y
)) {
513
graph
.
eraseArc
(
Arc
(
z
,
y
));
514
try
{
515
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
y
);
516
// if we find a cycle, we force the competing edge
517
latent_couples__
.
push_back
(
Arc
(
z
,
y
));
518
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
y
,
z
); }
519
graph
.
addArc
(
z
,
y
);
520
}
else
{
521
try
{
522
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
y
);
523
// if we find a cycle, we force the competing edge
524
graph
.
addArc
(
z
,
y
);
525
latent_couples__
.
push_back
(
Arc
(
z
,
y
));
526
527
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
y
,
z
); }
528
}
529
if
(
graph
.
existsArc
(
z
,
x
)
530
&&
std
::
find
(
latent_couples__
.
begin
(),
531
latent_couples__
.
end
(),
532
Arc
(
z
,
x
))
533
==
latent_couples__
.
end
()
534
&&
std
::
find
(
latent_couples__
.
begin
(),
535
latent_couples__
.
end
(),
536
Arc
(
x
,
z
))
537
==
latent_couples__
.
end
()) {
538
latent_couples__
.
push_back
(
Arc
(
z
,
x
));
539
}
540
if
(
graph
.
existsArc
(
z
,
y
)
541
&&
std
::
find
(
latent_couples__
.
begin
(),
542
latent_couples__
.
end
(),
543
Arc
(
z
,
y
))
544
==
latent_couples__
.
end
()
545
&&
std
::
find
(
latent_couples__
.
begin
(),
546
latent_couples__
.
end
(),
547
Arc
(
y
,
z
))
548
==
latent_couples__
.
end
()) {
549
latent_couples__
.
push_back
(
Arc
(
z
,
y
));
550
}
551
}
552
}
else
{
553
++
i
;
554
}
555
}
else
{
// try Rule 1
556
bool
reset
{
false
};
557
if
(
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
y
)
558
&& !
graph
.
existsArc
(
y
,
z
)) {
559
reset
=
true
;
560
graph
.
eraseEdge
(
Edge
(
z
,
y
));
561
try
{
562
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
y
,
z
);
563
// if we find a cycle, we force the competing edge
564
graph
.
addArc
(
y
,
z
);
565
latent_couples__
.
push_back
(
Arc
(
y
,
z
));
566
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
z
,
y
); }
567
}
568
if
(
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
x
)
569
&& !
graph
.
existsArc
(
x
,
z
)) {
570
reset
=
true
;
571
graph
.
eraseEdge
(
Edge
(
z
,
x
));
572
try
{
573
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
x
,
z
);
574
// if we find a cycle, we force the competing edge
575
graph
.
addArc
(
x
,
z
);
576
latent_couples__
.
push_back
(
Arc
(
x
,
z
));
577
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
z
,
x
); }
578
}
579
580
if
(
reset
) {
581
i
= 0;
582
}
else
{
583
++
i
;
584
}
585
}
// if rule 0 or rule 1
586
if
(
onProgress
.
hasListener
()) {
587
GUM_EMIT3
(
onProgress
,
588
((
current_step_
+
i
) * 100) / (
past_steps
+
steps_orient
),
589
0.,
590
timer_
.
step
());
591
}
592
}
// while
593
594
// erasing the the double headed arcs
595
for
(
const
Arc
&
arc
:
latent_couples__
) {
596
graph
.
eraseArc
(
Arc
(
arc
.
head
(),
arc
.
tail
()));
597
}
598
}
599
600
/// varient using the orientation protocol of MIIC
601
void
602
Miic
::
orientation_miic_
(
CorrectedMutualInformation
<>&
I
,
603
MixedGraph
&
graph
,
604
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
605
std
::
vector
<
NodeId
> >&
sep_set
) {
606
// structure to store the orientations marks -, o, or >,
607
// Considers the head of the arc/edge first node -* second node
608
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>
marks
=
initial_marks__
;
609
610
// marks always correspond to the head of the arc/edge. - is for a forbidden
611
// arc, > for a mandatory arc
612
// we start by adding the mandatory arcs
613
for
(
auto
iter
=
marks
.
begin
();
iter
!=
marks
.
end
(); ++
iter
) {
614
if
(
graph
.
existsEdge
(
iter
.
key
().
first
,
iter
.
key
().
second
)
615
&&
iter
.
val
() ==
'>'
) {
616
graph
.
eraseEdge
(
Edge
(
iter
.
key
().
first
,
iter
.
key
().
second
));
617
graph
.
addArc
(
iter
.
key
().
first
,
iter
.
key
().
second
);
618
}
619
}
620
621
std
::
vector
<
std
::
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
622
double
,
623
double
,
624
double
> >
625
proba_triples
=
getUnshieldedTriplesMIIC_
(
graph
,
I
,
sep_set
,
marks
);
626
627
Size
steps_orient
=
proba_triples
.
size
();
628
Size
past_steps
=
current_step_
;
629
630
std
::
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
,
double
,
double
>
631
best
;
632
if
(
steps_orient
> 0) {
best
=
proba_triples
[0]; }
633
634
while
(!
proba_triples
.
empty
()
635
&&
std
::
max
(
std
::
get
< 2 >(
best
),
std
::
get
< 3 >(
best
)) > 0.5) {
636
NodeId
x
,
y
,
z
;
637
x
=
std
::
get
< 0 >(*
std
::
get
< 0 >(
best
));
638
y
=
std
::
get
< 1 >(*
std
::
get
< 0 >(
best
));
639
z
=
std
::
get
< 2 >(*
std
::
get
< 0 >(
best
));
640
// std::cout << "Triple: (" << x << "," << y << "," << z << ")" <<
641
// std::endl;
642
643
const
double
i3
=
std
::
get
< 1 >(
best
);
644
645
if
(
i3
<= 0) {
646
// v-structure discovery
647
if
(
marks
[{
x
,
z
}] ==
'o'
&&
marks
[{
y
,
z
}] ==
'o'
) {
// If x-z-y
648
if
(!
existsDirectedPath__
(
graph
,
z
,
x
,
false
)) {
649
graph
.
eraseEdge
(
Edge
(
x
,
z
));
650
graph
.
addArc
(
x
,
z
);
651
// std::cout << "1.a Removing edge (" << x << "," << z << ")" <<
652
// std::endl; std::cout << "1.a Adding arc (" << x << "," << z << ")"
653
// << std::endl;
654
marks
[{
x
,
z
}] =
'>'
;
655
if
(
graph
.
existsArc
(
z
,
x
)
656
&&
std
::
find
(
latent_couples__
.
begin
(),
657
latent_couples__
.
end
(),
658
Arc
(
z
,
x
))
659
==
latent_couples__
.
end
()
660
&&
std
::
find
(
latent_couples__
.
begin
(),
661
latent_couples__
.
end
(),
662
Arc
(
x
,
z
))
663
==
latent_couples__
.
end
()) {
664
// std::cout << "Adding latent couple (" << z << "," << x << ")" <<
665
// std::endl;
666
latent_couples__
.
push_back
(
Arc
(
z
,
x
));
667
}
668
if
(!
arc_probas__
.
exists
(
Arc
(
x
,
z
)))
669
arc_probas__
.
insert
(
Arc
(
x
,
z
),
std
::
get
< 2 >(
best
));
670
}
else
{
671
graph
.
eraseEdge
(
Edge
(
x
,
z
));
672
// std::cout << "1.b Adding arc (" << x << "," << z << ")" <<
673
// std::endl;
674
if
(!
existsDirectedPath__
(
graph
,
x
,
z
,
false
)) {
675
graph
.
addArc
(
z
,
x
);
676
// std::cout << "1.b Removing edge (" << x << "," << z << ")" <<
677
// std::endl;
678
marks
[{
z
,
x
}] =
'>'
;
679
}
680
}
681
682
if
(!
existsDirectedPath__
(
graph
,
z
,
y
,
false
)) {
683
graph
.
eraseEdge
(
Edge
(
y
,
z
));
684
graph
.
addArc
(
y
,
z
);
685
// std::cout << "1.c Removing edge (" << y << "," << z << ")" <<
686
// std::endl; std::cout << "1.c Adding arc (" << y << "," << z << ")"
687
// << std::endl;
688
marks
[{
y
,
z
}] =
'>'
;
689
if
(
graph
.
existsArc
(
z
,
y
)
690
&&
std
::
find
(
latent_couples__
.
begin
(),
691
latent_couples__
.
end
(),
692
Arc
(
z
,
y
))
693
==
latent_couples__
.
end
()
694
&&
std
::
find
(
latent_couples__
.
begin
(),
695
latent_couples__
.
end
(),
696
Arc
(
y
,
z
))
697
==
latent_couples__
.
end
()) {
698
latent_couples__
.
push_back
(
Arc
(
z
,
y
));
699
}
700
if
(!
arc_probas__
.
exists
(
Arc
(
y
,
z
)))
701
arc_probas__
.
insert
(
Arc
(
y
,
z
),
std
::
get
< 3 >(
best
));
702
}
else
{
703
graph
.
eraseEdge
(
Edge
(
y
,
z
));
704
// std::cout << "1.d Removing edge (" << y << "," << z << ")" <<
705
// std::endl;
706
if
(!
existsDirectedPath__
(
graph
,
y
,
z
,
false
)) {
707
graph
.
addArc
(
z
,
y
);
708
// std::cout << "1.d Adding arc (" << z << "," << y << ")" <<
709
// std::endl;
710
marks
[{
z
,
y
}] =
'>'
;
711
}
712
}
713
}
else
if
(
marks
[{
x
,
z
}] ==
'>'
&&
marks
[{
y
,
z
}] ==
'o'
) {
// If x->z-y
714
if
(!
existsDirectedPath__
(
graph
,
z
,
y
,
false
)) {
715
graph
.
eraseEdge
(
Edge
(
y
,
z
));
716
graph
.
addArc
(
y
,
z
);
717
// std::cout << "2.a Removing edge (" << y << "," << z << ")" <<
718
// std::endl; std::cout << "2.a Adding arc (" << y << "," << z << ")"
719
// << std::endl;
720
marks
[{
y
,
z
}] =
'>'
;
721
if
(
graph
.
existsArc
(
z
,
y
)
722
&&
std
::
find
(
latent_couples__
.
begin
(),
723
latent_couples__
.
end
(),
724
Arc
(
z
,
y
))
725
==
latent_couples__
.
end
()
726
&&
std
::
find
(
latent_couples__
.
begin
(),
727
latent_couples__
.
end
(),
728
Arc
(
y
,
z
))
729
==
latent_couples__
.
end
()) {
730
latent_couples__
.
push_back
(
Arc
(
z
,
y
));
731
}
732
if
(!
arc_probas__
.
exists
(
Arc
(
y
,
z
)))
733
arc_probas__
.
insert
(
Arc
(
y
,
z
),
std
::
get
< 3 >(
best
));
734
}
else
{
735
graph
.
eraseEdge
(
Edge
(
y
,
z
));
736
// std::cout << "2.b Removing edge (" << y << "," << z << ")" <<
737
// std::endl;
738
if
(!
existsDirectedPath__
(
graph
,
y
,
z
,
false
)) {
739
graph
.
addArc
(
z
,
y
);
740
// std::cout << "2.b Adding arc (" << y << "," << z << ")" <<
741
// std::endl;
742
marks
[{
z
,
y
}] =
'>'
;
743
}
744
}
745
}
else
if
(
marks
[{
y
,
z
}] ==
'>'
&&
marks
[{
x
,
z
}] ==
'o'
) {
746
if
(!
existsDirectedPath__
(
graph
,
z
,
x
,
false
)) {
747
graph
.
eraseEdge
(
Edge
(
x
,
z
));
748
graph
.
addArc
(
x
,
z
);
749
// std::cout << "3.a Removing edge (" << x << "," << z << ")" <<
750
// std::endl; std::cout << "3.a Adding arc (" << x << "," << z << ")"
751
// << std::endl;
752
marks
[{
x
,
z
}] =
'>'
;
753
if
(
graph
.
existsArc
(
z
,
x
)
754
&&
std
::
find
(
latent_couples__
.
begin
(),
755
latent_couples__
.
end
(),
756
Arc
(
z
,
x
))
757
==
latent_couples__
.
end
()
758
&&
std
::
find
(
latent_couples__
.
begin
(),
759
latent_couples__
.
end
(),
760
Arc
(
x
,
z
))
761
==
latent_couples__
.
end
()) {
762
latent_couples__
.
push_back
(
Arc
(
z
,
x
));
763
}
764
if
(!
arc_probas__
.
exists
(
Arc
(
x
,
z
)))
765
arc_probas__
.
insert
(
Arc
(
x
,
z
),
std
::
get
< 2 >(
best
));
766
}
else
{
767
graph
.
eraseEdge
(
Edge
(
x
,
z
));
768
// std::cout << "3.b Removing edge (" << x << "," << z << ")" <<
769
// std::endl;
770
if
(!
existsDirectedPath__
(
graph
,
x
,
z
,
false
)) {
771
graph
.
addArc
(
z
,
x
);
772
// std::cout << "3.b Adding arc (" << x << "," << z << ")" <<
773
// std::endl;
774
marks
[{
z
,
x
}] =
'>'
;
775
}
776
}
777
}
778
779
}
else
{
780
// orientation propagation
781
if
(
marks
[{
x
,
z
}] ==
'>'
&&
marks
[{
y
,
z
}] ==
'o'
782
&&
marks
[{
z
,
y
}] !=
'-'
) {
783
graph
.
eraseEdge
(
Edge
(
z
,
y
));
784
// std::cout << "4. Removing edge (" << z << "," << y << ")" <<
785
// std::endl;
786
if
(!
existsDirectedPath__
(
graph
,
y
,
z
) &&
graph
.
parents
(
y
).
empty
()) {
787
graph
.
addArc
(
z
,
y
);
788
// std::cout << "4.a Adding arc (" << z << "," << y << ")" <<
789
// std::endl;
790
marks
[{
z
,
y
}] =
'>'
;
791
marks
[{
y
,
z
}] =
'-'
;
792
if
(!
arc_probas__
.
exists
(
Arc
(
z
,
y
)))
793
arc_probas__
.
insert
(
Arc
(
z
,
y
),
std
::
get
< 3 >(
best
));
794
}
else
if
(!
existsDirectedPath__
(
graph
,
z
,
y
)
795
&&
graph
.
parents
(
z
).
empty
()) {
796
graph
.
addArc
(
y
,
z
);
797
// std::cout << "4.b Adding arc (" << y << "," << z << ")" <<
798
// std::endl;
799
marks
[{
z
,
y
}] =
'-'
;
800
marks
[{
y
,
z
}] =
'>'
;
801
latent_couples__
.
push_back
(
Arc
(
y
,
z
));
802
if
(!
arc_probas__
.
exists
(
Arc
(
y
,
z
)))
803
arc_probas__
.
insert
(
Arc
(
y
,
z
),
std
::
get
< 3 >(
best
));
804
}
else
if
(!
existsDirectedPath__
(
graph
,
y
,
z
)) {
805
graph
.
addArc
(
z
,
y
);
806
// std::cout << "4.c Adding arc (" << z << "," << y << ")" <<
807
// std::endl;
808
marks
[{
z
,
y
}] =
'>'
;
809
marks
[{
y
,
z
}] =
'-'
;
810
if
(!
arc_probas__
.
exists
(
Arc
(
z
,
y
)))
811
arc_probas__
.
insert
(
Arc
(
z
,
y
),
std
::
get
< 3 >(
best
));
812
}
else
if
(!
existsDirectedPath__
(
graph
,
z
,
y
)) {
813
graph
.
addArc
(
y
,
z
);
814
// std::cout << "4.d Adding arc (" << y << "," << z << ")" <<
815
// std::endl;
816
latent_couples__
.
push_back
(
Arc
(
y
,
z
));
817
marks
[{
z
,
y
}] =
'-'
;
818
marks
[{
y
,
z
}] =
'>'
;
819
if
(!
arc_probas__
.
exists
(
Arc
(
y
,
z
)))
820
arc_probas__
.
insert
(
Arc
(
y
,
z
),
std
::
get
< 3 >(
best
));
821
}
822
823
}
else
if
(
marks
[{
y
,
z
}] ==
'>'
&&
marks
[{
x
,
z
}] ==
'o'
824
&&
marks
[{
z
,
x
}] !=
'-'
) {
825
graph
.
eraseEdge
(
Edge
(
z
,
x
));
826
// std::cout << "5. Removing edge (" << z << "," << x << ")" <<
827
// std::endl;
828
if
(!
existsDirectedPath__
(
graph
,
x
,
z
) &&
graph
.
parents
(
x
).
empty
()) {
829
graph
.
addArc
(
z
,
x
);
830
// std::cout << "5.a Adding arc (" << z << "," << x << ")" <<
831
// std::endl;
832
marks
[{
z
,
x
}] =
'>'
;
833
marks
[{
x
,
z
}] =
'-'
;
834
if
(!
arc_probas__
.
exists
(
Arc
(
z
,
x
)))
835
arc_probas__
.
insert
(
Arc
(
z
,
x
),
std
::
get
< 2 >(
best
));
836
}
else
if
(!
existsDirectedPath__
(
graph
,
z
,
x
)
837
&&
graph
.
parents
(
z
).
empty
()) {
838
graph
.
addArc
(
x
,
z
);
839
// std::cout << "5.b Adding arc (" << x << "," << z << ")" <<
840
// std::endl;
841
marks
[{
z
,
x
}] =
'-'
;
842
marks
[{
x
,
z
}] =
'>'
;
843
latent_couples__
.
push_back
(
Arc
(
x
,
z
));
844
if
(!
arc_probas__
.
exists
(
Arc
(
x
,
z
)))
845
arc_probas__
.
insert
(
Arc
(
x
,
z
),
std
::
get
< 2 >(
best
));
846
}
else
if
(!
existsDirectedPath__
(
graph
,
x
,
z
)) {
847
graph
.
addArc
(
z
,
x
);
848
// std::cout << "5.c Adding arc (" << z << "," << x << ")" <<
849
// std::endl;
850
marks
[{
z
,
x
}] =
'>'
;
851
marks
[{
x
,
z
}] =
'-'
;
852
if
(!
arc_probas__
.
exists
(
Arc
(
z
,
x
)))
853
arc_probas__
.
insert
(
Arc
(
z
,
x
),
std
::
get
< 2 >(
best
));
854
}
else
if
(!
existsDirectedPath__
(
graph
,
z
,
x
)) {
855
graph
.
addArc
(
x
,
z
);
856
// std::cout << "5.d Adding arc (" << x << "," << z << ")" <<
857
// std::endl;
858
marks
[{
z
,
x
}] =
'-'
;
859
marks
[{
x
,
z
}] =
'>'
;
860
latent_couples__
.
push_back
(
Arc
(
x
,
z
));
861
if
(!
arc_probas__
.
exists
(
Arc
(
x
,
z
)))
862
arc_probas__
.
insert
(
Arc
(
x
,
z
),
std
::
get
< 2 >(
best
));
863
}
864
}
865
}
866
867
delete
std
::
get
< 0 >(
best
);
868
proba_triples
.
erase
(
proba_triples
.
begin
());
869
// actualisation of the list of triples
870
proba_triples
=
updateProbaTriples_
(
graph
,
proba_triples
);
871
872
if
(!
proba_triples
.
empty
())
best
=
proba_triples
[0];
873
874
++
current_step_
;
875
if
(
onProgress
.
hasListener
()) {
876
GUM_EMIT3
(
onProgress
,
877
(
current_step_
* 100) / (
steps_orient
+
past_steps
),
878
0.,
879
timer_
.
step
());
880
}
881
}
// while
882
883
// erasing the double headed arcs
884
for
(
auto
iter
=
latent_couples__
.
rbegin
();
iter
!=
latent_couples__
.
rend
();
885
++
iter
) {
886
graph
.
eraseArc
(
Arc
(
iter
->
head
(),
iter
->
tail
()));
887
if
(
existsDirectedPath__
(
graph
,
iter
->
head
(),
iter
->
tail
())) {
888
// if we find a cycle, we force the competing edge
889
graph
.
addArc
(
iter
->
head
(),
iter
->
tail
());
890
graph
.
eraseArc
(
Arc
(
iter
->
tail
(),
iter
->
head
()));
891
*
iter
=
Arc
(
iter
->
head
(),
iter
->
tail
());
892
}
893
}
894
895
if
(
onProgress
.
hasListener
()) {
896
GUM_EMIT3
(
onProgress
, 100, 0.,
timer_
.
step
());
897
}
898
}
899
900
/// finds the best contributor node for a pair given a conditioning set
901
void
Miic
::
findBestContributor_
(
902
NodeId
x
,
903
NodeId
y
,
904
const
std
::
vector
<
NodeId
>&
ui
,
905
const
MixedGraph
&
graph
,
906
CorrectedMutualInformation
<>&
I
,
907
Heap
<
908
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >*,
909
double
>,
910
GreaterPairOn2nd
>&
rank_
) {
911
double
maxP
= -1.0;
912
NodeId
maxZ
= 0;
913
914
// compute N
915
//__N = I.N();
916
const
double
Ixy_ui
=
I
.
score
(
x
,
y
,
ui
);
917
918
for
(
const
NodeId
z
:
graph
) {
919
// if z!=x and z!=y and z not in ui
920
if
(
z
!=
x
&&
z
!=
y
&&
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
921
double
Pnv
;
922
double
Pb
;
923
924
// Computing Pnv
925
const
double
Ixyz_ui
=
I
.
score
(
x
,
y
,
z
,
ui
);
926
double
calc_expo1
= -
Ixyz_ui
*
M_LN2
;
927
// if exponentials are too high or to low, crop them at |__maxLog|
928
if
(
calc_expo1
>
maxLog__
) {
929
Pnv
= 0.0;
930
}
else
if
(
calc_expo1
< -
maxLog__
) {
931
Pnv
= 1.0;
932
}
else
{
933
Pnv
= 1 / (1 +
std
::
exp
(
calc_expo1
));
934
}
935
936
// Computing Pb
937
const
double
Ixz_ui
=
I
.
score
(
x
,
z
,
ui
);
938
const
double
Iyz_ui
=
I
.
score
(
y
,
z
,
ui
);
939
940
calc_expo1
= -(
Ixz_ui
-
Ixy_ui
) *
M_LN2
;
941
double
calc_expo2
= -(
Iyz_ui
-
Ixy_ui
) *
M_LN2
;
942
943
// if exponentials are too high or to low, crop them at maxLog__
944
if
(
calc_expo1
>
maxLog__
||
calc_expo2
>
maxLog__
) {
945
Pb
= 0.0;
946
}
else
if
(
calc_expo1
< -
maxLog__
&&
calc_expo2
< -
maxLog__
) {
947
Pb
= 1.0;
948
}
else
{
949
double
expo1
,
expo2
;
950
if
(
calc_expo1
< -
maxLog__
) {
951
expo1
= 0.0;
952
}
else
{
953
expo1
=
std
::
exp
(
calc_expo1
);
954
}
955
if
(
calc_expo2
< -
maxLog__
) {
956
expo2
= 0.0;
957
}
else
{
958
expo2
=
std
::
exp
(
calc_expo2
);
959
}
960
Pb
= 1 / (1 +
expo1
+
expo2
);
961
}
962
963
// Getting max(min(Pnv, pb))
964
const
double
min_pnv_pb
=
std
::
min
(
Pnv
,
Pb
);
965
if
(
min_pnv_pb
>
maxP
) {
966
maxP
=
min_pnv_pb
;
967
maxZ
=
z
;
968
}
969
}
// if z not in (x, y)
970
}
// for z in graph.nodes
971
// storing best z in rank_
972
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >*,
973
double
>
974
final
;
975
auto
tup
976
=
new
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
,
std
::
vector
<
NodeId
> >{
x
,
977
y
,
978
maxZ
,
979
ui
};
980
final
.
first
=
tup
;
981
final
.
second
=
maxP
;
982
rank_
.
insert
(
final
);
983
}
984
985
/// gets the list of unshielded triples in the graph in decreasing value of
986
///|I'(x, y, z|{ui})|
987
std
::
vector
<
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
> >
988
Miic
::
getUnshieldedTriples_
(
989
const
MixedGraph
&
graph
,
990
CorrectedMutualInformation
<>&
I
,
991
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
992
sep_set
) {
993
std
::
vector
<
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
> >
994
triples
;
995
for
(
NodeId
z
:
graph
) {
996
for
(
NodeId
x
:
graph
.
neighbours
(
z
)) {
997
for
(
NodeId
y
:
graph
.
neighbours
(
z
)) {
998
if
(
y
<
x
&& !
graph
.
existsEdge
(
x
,
y
)) {
999
std
::
vector
<
NodeId
>
ui
;
1000
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
1001
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
1002
if
(
sep_set
.
exists
(
key
)) {
1003
ui
=
sep_set
[
key
];
1004
}
else
if
(
sep_set
.
exists
(
rev_key
)) {
1005
ui
=
sep_set
[
rev_key
];
1006
}
1007
// remove z from ui if it's present
1008
const
auto
iter_z_place
=
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
);
1009
if
(
iter_z_place
!=
ui
.
end
()) {
ui
.
erase
(
iter_z_place
); }
1010
1011
double
Ixyz_ui
=
I
.
score
(
x
,
y
,
z
,
ui
);
1012
std
::
pair
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
>
triple
;
1013
auto
tup
=
new
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>{
x
,
y
,
z
};
1014
triple
.
first
=
tup
;
1015
triple
.
second
=
Ixyz_ui
;
1016
triples
.
push_back
(
triple
);
1017
}
1018
}
1019
}
1020
}
1021
std
::
sort
(
triples
.
begin
(),
triples
.
end
(),
GreaterAbsPairOn2nd
());
1022
return
triples
;
1023
}
1024
1025
/// gets the list of unshielded triples in the graph in decreasing value of
1026
///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC
1027
std
::
vector
<
1028
std
::
1029
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
,
double
,
double
> >
1030
Miic
::
getUnshieldedTriplesMIIC_
(
1031
const
MixedGraph
&
graph
,
1032
CorrectedMutualInformation
<>&
I
,
1033
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
1034
sep_set
,
1035
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>&
marks
) {
1036
std
::
vector
<
std
::
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
1037
double
,
1038
double
,
1039
double
> >
1040
triples
;
1041
for
(
NodeId
z
:
graph
) {
1042
for
(
NodeId
x
:
graph
.
neighbours
(
z
)) {
1043
for
(
NodeId
y
:
graph
.
neighbours
(
z
)) {
1044
if
(
y
<
x
&& !
graph
.
existsEdge
(
x
,
y
)) {
1045
std
::
vector
<
NodeId
>
ui
;
1046
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
1047
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
1048
if
(
sep_set
.
exists
(
key
)) {
1049
ui
=
sep_set
[
key
];
1050
}
else
if
(
sep_set
.
exists
(
rev_key
)) {
1051
ui
=
sep_set
[
rev_key
];
1052
}
1053
// remove z from ui if it's present
1054
const
auto
iter_z_place
=
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
);
1055
if
(
iter_z_place
!=
ui
.
end
()) {
ui
.
erase
(
iter_z_place
); }
1056
1057
const
double
Ixyz_ui
=
I
.
score
(
x
,
y
,
z
,
ui
);
1058
auto
tup
=
new
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>{
x
,
y
,
z
};
1059
std
::
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
1060
double
,
1061
double
,
1062
double
>
1063
triple
{
tup
,
Ixyz_ui
, 0.5, 0.5};
1064
triples
.
push_back
(
triple
);
1065
if
(!
marks
.
exists
({
x
,
z
})) {
marks
.
insert
({
x
,
z
},
'o'
); }
1066
if
(!
marks
.
exists
({
z
,
x
})) {
marks
.
insert
({
z
,
x
},
'o'
); }
1067
if
(!
marks
.
exists
({
y
,
z
})) {
marks
.
insert
({
y
,
z
},
'o'
); }
1068
if
(!
marks
.
exists
({
z
,
y
})) {
marks
.
insert
({
z
,
y
},
'o'
); }
1069
}
1070
}
1071
}
1072
}
1073
triples
=
updateProbaTriples_
(
graph
,
triples
);
1074
std
::
sort
(
triples
.
begin
(),
triples
.
end
(),
GreaterTupleOnLast
());
1075
return
triples
;
1076
}
1077
1078
/// Gets the orientation probabilities like MIIC for the orientation phase
1079
std
::
vector
<
1080
std
::
1081
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
double
,
double
,
double
> >
1082
Miic
::
updateProbaTriples_
(
1083
const
MixedGraph
&
graph
,
1084
std
::
vector
<
std
::
tuple
<
std
::
tuple
<
NodeId
,
NodeId
,
NodeId
>*,
1085
double
,
1086
double
,
1087
double
> >
proba_triples
) {
1088
for
(
auto
&
triple
:
proba_triples
) {
1089
NodeId
x
,
y
,
z
;
1090
x
=
std
::
get
< 0 >(*
std
::
get
< 0 >(
triple
));
1091
y
=
std
::
get
< 1 >(*
std
::
get
< 0 >(
triple
));
1092
z
=
std
::
get
< 2 >(*
std
::
get
< 0 >(
triple
));
1093
const
double
Ixyz
=
std
::
get
< 1 >(
triple
);
1094
double
Pxz
=
std
::
get
< 2 >(
triple
);
1095
double
Pyz
=
std
::
get
< 3 >(
triple
);
1096
1097
if
(
Ixyz
<= 0) {
1098
const
double
expo
=
std
::
exp
(
Ixyz
);
1099
const
double
P0
= (1 +
expo
) / (1 + 3 *
expo
);
1100
// distinguish betweeen the initialization and the update process
1101
if
(
Pxz
==
Pyz
&&
Pyz
== 0.5) {
1102
std
::
get
< 2 >(
triple
) =
P0
;
1103
std
::
get
< 3 >(
triple
) =
P0
;
1104
}
else
{
1105
if
(
graph
.
existsArc
(
x
,
z
) &&
Pxz
>=
P0
) {
1106
std
::
get
< 3 >(
triple
) =
Pxz
* (1 / (1 +
expo
) - 0.5) + 0.5;
1107
}
else
if
(
graph
.
existsArc
(
y
,
z
) &&
Pyz
>=
P0
) {
1108
std
::
get
< 2 >(
triple
) =
Pyz
* (1 / (1 +
expo
) - 0.5) + 0.5;
1109
}
1110
}
1111
}
else
{
1112
const
double
expo
=
std
::
exp
(-
Ixyz
);
1113
if
(
graph
.
existsArc
(
x
,
z
) &&
Pxz
>= 0.5) {
1114
std
::
get
< 3 >(
triple
) =
Pxz
* (1 / (1 +
expo
) - 0.5) + 0.5;
1115
}
else
if
(
graph
.
existsArc
(
y
,
z
) &&
Pyz
>= 0.5) {
1116
std
::
get
< 2 >(
triple
) =
Pyz
* (1 / (1 +
expo
) - 0.5) + 0.5;
1117
}
1118
}
1119
}
1120
std
::
sort
(
proba_triples
.
begin
(),
proba_triples
.
end
(),
GreaterTupleOnLast
());
1121
return
proba_triples
;
1122
}
1123
1124
/// learns the structure of an Bayesian network, ie a DAG, from an Essential
1125
/// graph.
1126
DAG
Miic
::
learnStructure
(
CorrectedMutualInformation
<>&
I
,
1127
MixedGraph
initialGraph
) {
1128
MixedGraph
essentialGraph
=
learnMixedStructure
(
I
,
initialGraph
);
1129
// std::cout << "Le mixed graph mesdames et messieurs: "
1130
//<< essentialGraph.toDot() << std::endl;
1131
1132
// Second, orientate remaining edges
1133
const
Sequence
<
NodeId
>
order
=
essentialGraph
.
topologicalOrder
();
1134
// first, propagate existing orientations
1135
for
(
NodeId
x
:
order
) {
1136
if
(!
essentialGraph
.
parents
(
x
).
empty
()) {
1137
propagatesHead_
(
essentialGraph
,
x
);
1138
}
1139
}
1140
// std::cout << "Le mixed graph après une première propagation mesdames et
1141
// messieurs: "
1142
//<< essentialGraph.toDot() << std::endl;
1143
// then decide the orientation by the topological order and propagate them
1144
for
(
NodeId
x
:
order
) {
1145
if
(!
essentialGraph
.
neighbours
(
x
).
empty
()) {
1146
propagatesHead_
(
essentialGraph
,
x
);
1147
}
1148
}
1149
1150
// std::cout << "Le mixed graph après une deuxième propagation mesdames et
1151
// messieurs: "
1152
//<< essentialGraph.toDot() << std::endl;
1153
// std::cout << "Le graphe contient maintenant : " <<
1154
// essentialGraph.sizeArcs() << " arcs."
1155
//<< std::endl;
1156
// std::cout << "Que voici: " << essentialGraph.arcs() << std::endl;
1157
// turn the mixed graph into a dag
1158
DAG
dag
;
1159
for
(
auto
node
:
essentialGraph
) {
1160
dag
.
addNodeWithId
(
node
);
1161
}
1162
for
(
const
Arc
&
arc
:
essentialGraph
.
arcs
()) {
1163
dag
.
addArc
(
arc
.
tail
(),
arc
.
head
());
1164
}
1165
1166
return
dag
;
1167
}
1168
1169
/// Propagates the orientation from a node to its neighbours
1170
void
Miic
::
propagatesHead_
(
MixedGraph
&
graph
,
NodeId
node
) {
1171
const
auto
neighbours
=
graph
.
neighbours
(
node
);
1172
for
(
auto
&
neighbour
:
neighbours
) {
1173
// std::cout << "Orientation de l'edge (" << node << "," << neighbour <<
1174
// ")" << std::endl;
1175
if
(
graph
.
neighbours
(
neighbour
).
contains
(
node
)) {
1176
if
(!
existsDirectedPath__
(
graph
,
neighbour
,
node
)
1177
&& !(
initial_marks__
.
exists
({
node
,
neighbour
})
1178
&&
initial_marks__
[{
node
,
neighbour
}] ==
'-'
)
1179
&&
graph
.
parents
(
neighbour
).
empty
()) {
1180
graph
.
eraseEdge
(
Edge
(
neighbour
,
node
));
1181
graph
.
addArc
(
node
,
neighbour
);
1182
// std::cout << "1. Removing edge (" << neighbour << "," << node << ")"
1183
// << std::endl; std::cout << "1. Adding arc (" << node << "," <<
1184
// neighbour << ")" << std::endl;
1185
propagatesHead_
(
graph
,
neighbour
);
1186
}
else
if
(!
existsDirectedPath__
(
graph
,
node
,
neighbour
)
1187
&& !(
initial_marks__
.
exists
({
neighbour
,
node
})
1188
&&
initial_marks__
[{
neighbour
,
node
}] ==
'-'
)
1189
&&
graph
.
parents
(
node
).
empty
()) {
1190
graph
.
eraseEdge
(
Edge
(
neighbour
,
node
));
1191
graph
.
addArc
(
neighbour
,
node
);
1192
// std::cout << "2. Removing edge (" << neighbour << "," << node << ")"
1193
// << std::endl; std::cout << "2. Adding arc (" << neighbour << "," <<
1194
// node << ")" << std::endl;
1195
}
else
if
(!
existsDirectedPath__
(
graph
,
node
,
neighbour
)
1196
&& !(
initial_marks__
.
exists
({
neighbour
,
node
})
1197
&&
initial_marks__
[{
neighbour
,
node
}] ==
'-'
)) {
1198
graph
.
eraseEdge
(
Edge
(
neighbour
,
node
));
1199
graph
.
addArc
(
neighbour
,
node
);
1200
if
(!
graph
.
parents
(
neighbour
).
empty
()
1201
&& !
graph
.
parents
(
node
).
empty
()) {
1202
latent_couples__
.
push_back
(
Arc
(
node
,
neighbour
));
1203
}
1204
1205
// std::cout << "3. Removing edge (" << neighbour << "," << node << ")"
1206
// << std::endl; std::cout << "3. Adding arc (" << neighbour << "," <<
1207
// node << ")" << std::endl;
1208
}
else
if
(!
existsDirectedPath__
(
graph
,
neighbour
,
node
)
1209
&& !(
initial_marks__
.
exists
({
node
,
neighbour
})
1210
&&
initial_marks__
[{
node
,
neighbour
}] ==
'-'
)) {
1211
graph
.
eraseEdge
(
Edge
(
node
,
neighbour
));
1212
graph
.
addArc
(
node
,
neighbour
);
1213
if
(!
graph
.
parents
(
neighbour
).
empty
()
1214
&& !
graph
.
parents
(
node
).
empty
()) {
1215
latent_couples__
.
push_back
(
Arc
(
node
,
neighbour
));
1216
}
1217
// std::cout << "4. Removing edge (" << neighbour << "," << node << ")"
1218
// << std::endl; std::cout << "4. Adding arc (" << node << "," <<
1219
// neighbour << ")" << std::endl;
1220
}
1221
// else if (!graph.parents(neighbour).empty()
1222
//&& !graph.parents(node).empty()) {
1223
// graph.eraseEdge(Edge(neighbour, node));
1224
// graph.addArc(node, neighbour);
1225
//__latent_couples.push_back(Arc(node, neighbour));
1226
//}
1227
else
{
1228
graph
.
eraseEdge
(
Edge
(
neighbour
,
node
));
1229
// std::cout << "5. Removing edge (" << neighbour << "," << node << ")"
1230
// << std::endl;
1231
}
1232
}
1233
}
1234
}
1235
1236
/// get the list of arcs hiding latent variables
1237
const
std
::
vector
<
Arc
>
Miic
::
latentVariables
()
const
{
1238
return
latent_couples__
;
1239
}
1240
1241
/// learns the structure and the parameters of a BN
1242
template
<
typename
GUM_SCALAR
,
1243
typename
GRAPH_CHANGES_SELECTOR
,
1244
typename
PARAM_ESTIMATOR
>
1245
BayesNet
<
GUM_SCALAR
>
Miic
::
learnBN
(
GRAPH_CHANGES_SELECTOR
&
selector
,
1246
PARAM_ESTIMATOR
&
estimator
,
1247
DAG
initial_dag
) {
1248
return
DAG2BNLearner
<>::
createBN
<
GUM_SCALAR
>(
1249
estimator
,
1250
learnStructure
(
selector
,
initial_dag
));
1251
}
1252
1253
void
Miic
::
setMiicBehaviour
() {
this
->
usemiic__
=
true
; }
1254
void
Miic
::
set3off2Behaviour
() {
this
->
usemiic__
=
false
; }
1255
1256
void
Miic
::
addConstraints
(
1257
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>
constraints
) {
1258
this
->
initial_marks__
=
constraints
;
1259
}
1260
1261
1262
const
bool
Miic
::
existsDirectedPath__
(
const
MixedGraph
&
graph
,
1263
const
NodeId
n1
,
1264
const
NodeId
n2
,
1265
const
bool
countArc
)
const
{
1266
// not recursive version => use a FIFO for simulating the recursion
1267
List
<
NodeId
>
nodeFIFO
;
1268
nodeFIFO
.
pushBack
(
n2
);
1269
1270
// mark[node] = successor if visited, else mark[node] does not exist
1271
NodeProperty
<
NodeId
>
mark
;
1272
mark
.
insert
(
n2
,
n2
);
1273
1274
NodeId
current
;
1275
1276
while
(!
nodeFIFO
.
empty
()) {
1277
current
=
nodeFIFO
.
front
();
1278
nodeFIFO
.
popFront
();
1279
1280
// check the parents
1281
1282
for
(
const
auto
new_one
:
graph
.
parents
(
current
)) {
1283
if
(!
countArc
&&
current
==
n2
1284
&&
new_one
==
n1
)
// If countArc is set to false
1285
continue
;
// paths of length 1 are ignored
1286
1287
if
(
mark
.
exists
(
new_one
))
// if this node is already marked, do not
1288
continue
;
// check it again
1289
1290
if
(
graph
.
existsArc
(
current
,
1291
new_one
))
// if there is a double arc, pass
1292
continue
;
1293
1294
mark
.
insert
(
new_one
,
current
);
1295
1296
if
(
new_one
==
n1
) {
return
true
; }
1297
1298
nodeFIFO
.
pushBack
(
new_one
);
1299
}
1300
}
1301
1302
return
false
;
1303
}
1304
1305
}
/* namespace learning */
1306
1307
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31