aGrUM
0.21.0
a C++ library for (probabilistic) graphical models
Miic.cpp
Go to the documentation of this file.
1
2
/**
3
*
4
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe
5
* GONZALES(@AMU) info_at_agrum_dot_org
6
*
7
* This library is free software: you can redistribute it and/or modify
8
* it under the terms of the GNU Lesser General Public License as published by
9
* the Free Software Foundation, either version 3 of the License, or
10
* (at your option) any later version.
11
*
12
* This library is distributed in the hope that it will be useful,
13
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
* GNU Lesser General Public License for more details.
16
*
17
* You should have received a copy of the GNU Lesser General Public License
18
* along with this library. If not, see <http://www.gnu.org/licenses/>.
19
*
20
*/
21
22
23
/** @file
24
* @brief Implementation of gum::learning::ThreeOffTwo and MIIC
25
*
26
* @author Quentin FALCAND, Marvin LASSERRE and Pierre-Henri WUILLEMIN(@LIP6)
27
*/
28
29
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
30
#
include
<
agrum
/
tools
/
core
/
hashTable
.
h
>
31
#
include
<
agrum
/
tools
/
core
/
heap
.
h
>
32
#
include
<
agrum
/
tools
/
core
/
timer
.
h
>
33
#
include
<
agrum
/
tools
/
graphs
/
mixedGraph
.
h
>
34
#
include
<
agrum
/
BN
/
learning
/
Miic
.
h
>
35
#
include
<
agrum
/
BN
/
learning
/
paramUtils
/
DAG2BNLearner
.
h
>
36
#
include
<
agrum
/
tools
/
stattests
/
correctedMutualInformation
.
h
>
37
38
39
namespace
gum
{
40
41
namespace
learning
{
42
43
/// default constructor
44
Miic
::
Miic
() :
_maxLog_
(100),
_size_
(0) {
GUM_CONSTRUCTOR
(
Miic
); }
45
46
/// default constructor with maxLog
47
Miic
::
Miic
(
int
maxLog
) :
_maxLog_
(
maxLog
),
_size_
(0) {
GUM_CONSTRUCTOR
(
Miic
); }
48
49
/// copy constructor
50
Miic
::
Miic
(
const
Miic
&
from
) :
ApproximationScheme
(
from
),
_size_
(
from
.
_size_
) {
51
GUM_CONS_CPY
(
Miic
);
52
}
53
54
/// move constructor
55
Miic
::
Miic
(
Miic
&&
from
) :
ApproximationScheme
(
std
::
move
(
from
)),
_size_
(
from
.
_size_
) {
56
GUM_CONS_MOV
(
Miic
);
57
}
58
59
/// destructor
60
Miic
::~
Miic
() {
GUM_DESTRUCTOR
(
Miic
); }
61
62
/// copy operator
63
Miic
&
Miic
::
operator
=(
const
Miic
&
from
) {
64
ApproximationScheme
::
operator
=(
from
);
65
return
*
this
;
66
}
67
68
/// move operator
69
Miic
&
Miic
::
operator
=(
Miic
&&
from
) {
70
ApproximationScheme
::
operator
=(
std
::
move
(
from
));
71
return
*
this
;
72
}
73
74
75
/// Heap comparator for CondRanking: e1 precedes e2 when its score (second
/// member) is strictly greater.
bool GreaterPairOn2nd::operator()(const CondRanking& e1, const CondRanking& e2) const {
  return e2.second < e1.second;
}
78
79
/// Sort comparator for Ranking: e1 precedes e2 when the absolute value of its
/// score (second member) is strictly greater.
bool GreaterAbsPairOn2nd::operator()(const Ranking& e1, const Ranking& e2) const {
  const auto magnitude1 = std::abs(e1.second);
  const auto magnitude2 = std::abs(e2.second);
  return magnitude2 < magnitude1;
}
82
83
bool
GreaterTupleOnLast
::
operator
()(
const
ProbabilisticRanking
&
e1
,
84
const
ProbabilisticRanking
&
e2
)
const
{
85
double
p1xz
=
std
::
get
< 2 >(
e1
);
86
double
p1yz
=
std
::
get
< 3 >(
e1
);
87
double
p2xz
=
std
::
get
< 2 >(
e2
);
88
double
p2yz
=
std
::
get
< 3 >(
e2
);
89
double
I1
=
std
::
get
< 1 >(
e1
);
90
double
I2
=
std
::
get
< 1 >(
e2
);
91
// First, we look at the sign of information.
92
// Then, the probability values
93
// and finally the abs value of information.
94
if
((
I1
< 0 &&
I2
< 0) || (
I1
>= 0 &&
I2
>= 0)) {
95
if
(
std
::
max
(
p1xz
,
p1yz
) ==
std
::
max
(
p2xz
,
p2yz
)) {
96
return
std
::
abs
(
I1
) >
std
::
abs
(
I2
);
97
}
else
{
98
return
std
::
max
(
p1xz
,
p1yz
) >
std
::
max
(
p2xz
,
p2yz
);
99
}
100
}
else
{
101
return
I1
<
I2
;
102
}
103
}
104
105
/// learns the structure of a MixedGraph
106
MixedGraph
Miic
::
learnMixedStructure
(
CorrectedMutualInformation
<>&
mutualInformation
,
107
MixedGraph
graph
) {
108
timer_
.
reset
();
109
current_step_
= 0;
110
111
// clear the vector of latent arcs to be sure
112
_latentCouples_
.
clear
();
113
114
/// the heap of ranks, with the score, and the NodeIds of x, y and z.
115
Heap
<
CondRanking
,
GreaterPairOn2nd
>
rank
;
116
117
/// the variables separation sets
118
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >
sep_set
;
119
120
initiation_
(
mutualInformation
,
graph
,
sep_set
,
rank
);
121
122
iteration_
(
mutualInformation
,
graph
,
sep_set
,
rank
);
123
124
if
(
_useMiic_
) {
125
orientationMiic_
(
mutualInformation
,
graph
,
sep_set
);
126
}
else
{
127
orientation3off2_
(
mutualInformation
,
graph
,
sep_set
);
128
}
129
130
return
graph
;
131
}
132
133
/*
134
* PHASE 1 : INITIATION
135
*
136
* We go over all edges and test if the variables are independent. If they
137
* are,
138
* the edge is deleted. If not, the best contributor is found.
139
*/
140
void
Miic
::
initiation_
(
CorrectedMutualInformation
<>&
mutualInformation
,
141
MixedGraph
&
graph
,
142
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
,
143
Heap
<
CondRanking
,
GreaterPairOn2nd
>&
rank
) {
144
NodeId
x
,
y
;
145
EdgeSet
edges
=
graph
.
edges
();
146
Size
steps_init
=
edges
.
size
();
147
148
for
(
const
Edge
&
edge
:
edges
) {
149
x
=
edge
.
first
();
150
y
=
edge
.
second
();
151
double
Ixy
=
mutualInformation
.
score
(
x
,
y
);
152
153
if
(
Ixy
<= 0) {
//< K
154
graph
.
eraseEdge
(
edge
);
155
sepSet
.
insert
(
std
::
make_pair
(
x
,
y
),
_emptySet_
);
156
}
else
{
157
findBestContributor_
(
x
,
y
,
_emptySet_
,
graph
,
mutualInformation
,
rank
);
158
}
159
160
++
current_step_
;
161
if
(
onProgress
.
hasListener
()) {
162
GUM_EMIT3
(
onProgress
, (
current_step_
* 33) /
steps_init
, 0.,
timer_
.
step
());
163
}
164
}
165
}
166
167
/*
168
* PHASE 2 : ITERATION
169
*
170
* As long as we find important nodes for edges, we go over them to see if
171
* we can assess the independence of the variables.
172
*/
173
void
Miic
::
iteration_
(
CorrectedMutualInformation
<>&
mutualInformation
,
174
MixedGraph
&
graph
,
175
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
,
176
Heap
<
CondRanking
,
GreaterPairOn2nd
>&
rank
) {
177
// if no triples to further examine pass
178
CondRanking
best
;
179
180
Size
steps_init
=
current_step_
;
181
Size
steps_iter
=
rank
.
size
();
182
183
try
{
184
while
(
rank
.
top
().
second
> 0.5) {
185
best
=
rank
.
pop
();
186
187
const
NodeId
x
=
std
::
get
< 0 >(*(
best
.
first
));
188
const
NodeId
y
=
std
::
get
< 1 >(*(
best
.
first
));
189
const
NodeId
z
=
std
::
get
< 2 >(*(
best
.
first
));
190
std
::
vector
<
NodeId
>
ui
=
std
::
move
(
std
::
get
< 3 >(*(
best
.
first
)));
191
192
ui
.
push_back
(
z
);
193
const
double
i_xy_ui
=
mutualInformation
.
score
(
x
,
y
,
ui
);
194
if
(
i_xy_ui
< 0) {
195
graph
.
eraseEdge
(
Edge
(
x
,
y
));
196
sepSet
.
insert
(
std
::
make_pair
(
x
,
y
),
std
::
move
(
ui
));
197
}
else
{
198
findBestContributor_
(
x
,
y
,
ui
,
graph
,
mutualInformation
,
rank
);
199
}
200
201
delete
best
.
first
;
202
203
++
current_step_
;
204
if
(
onProgress
.
hasListener
()) {
205
GUM_EMIT3
(
onProgress
,
206
(
current_step_
* 66) / (
steps_init
+
steps_iter
),
207
0.,
208
timer_
.
step
());
209
}
210
}
211
}
catch
(...) {}
// here, rank is empty
212
current_step_
=
steps_init
+
steps_iter
;
213
if
(
onProgress
.
hasListener
()) {
GUM_EMIT3
(
onProgress
, 66, 0.,
timer_
.
step
()); }
214
current_step_
=
steps_init
+
steps_iter
;
215
}
216
217
/*
218
* PHASE 3 : ORIENTATION
219
*
220
* Try to assess v-structures and propagate them.
221
*/
222
void
Miic
::
orientation3off2_
(
223
CorrectedMutualInformation
<>&
mutualInformation
,
224
MixedGraph
&
graph
,
225
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
) {
226
std
::
vector
<
Ranking
>
triples
=
unshieldedTriples_
(
graph
,
mutualInformation
,
sepSet
);
227
Size
steps_orient
=
triples
.
size
();
228
Size
past_steps
=
current_step_
;
229
230
// marks always correspond to the head of the arc/edge. - is for a forbidden
231
// arc, > for a mandatory arc
232
// we start by adding the mandatory arcs
233
for
(
auto
iter
=
_initialMarks_
.
begin
();
iter
!=
_initialMarks_
.
end
(); ++
iter
) {
234
if
(
graph
.
existsEdge
(
iter
.
key
().
first
,
iter
.
key
().
second
) &&
iter
.
val
() ==
'>'
) {
235
graph
.
eraseEdge
(
Edge
(
iter
.
key
().
first
,
iter
.
key
().
second
));
236
graph
.
addArc
(
iter
.
key
().
first
,
iter
.
key
().
second
);
237
}
238
}
239
240
NodeId
i
= 0;
241
// list of elements that we shouldn't read again, ie elements that are
242
// eligible to
243
// rule 0 after the first time they are tested, and elements on which rule 1
244
// has been applied
245
while
(
i
<
triples
.
size
()) {
246
// if i not in do_not_reread
247
Ranking
triple
=
triples
[
i
];
248
NodeId
x
,
y
,
z
;
249
x
=
std
::
get
< 0 >(*
triple
.
first
);
250
y
=
std
::
get
< 1 >(*
triple
.
first
);
251
z
=
std
::
get
< 2 >(*
triple
.
first
);
252
253
std
::
vector
<
NodeId
>
ui
;
254
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
255
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
256
if
(
sepSet
.
exists
(
key
)) {
257
ui
=
sepSet
[
key
];
258
}
else
if
(
sepSet
.
exists
(
rev_key
)) {
259
ui
=
sepSet
[
rev_key
];
260
}
261
double
Ixyz_ui
=
triple
.
second
;
262
bool
reset
{
false
};
263
// try Rule 0
264
if
(
Ixyz_ui
< 0) {
265
// if ( z not in Sep[x,y])
266
if
(
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
267
if
(!
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
x
)) {
268
// when we try to add an arc to the graph, we always verify if
269
// we are allowed to do so, ie it is not a forbidden arc an it
270
// does not create a cycle
271
if
(!
_existsDirectedPath_
(
graph
,
z
,
x
) && !
isForbidenArc_
(
x
,
z
)) {
272
reset
=
true
;
273
graph
.
eraseEdge
(
Edge
(
x
,
z
));
274
graph
.
addArc
(
x
,
z
);
275
}
else
if
(
_existsDirectedPath_
(
graph
,
z
,
x
) && !
isForbidenArc_
(
z
,
x
)) {
276
reset
=
true
;
277
graph
.
eraseEdge
(
Edge
(
x
,
z
));
278
// if we find a cycle, we force the competing edge
279
graph
.
addArc
(
z
,
x
);
280
if
(
std
::
find
(
_latentCouples_
.
begin
(),
_latentCouples_
.
end
(),
Arc
(
z
,
x
))
281
==
_latentCouples_
.
end
()) {
282
_latentCouples_
.
emplace_back
(
z
,
x
);
283
}
284
}
285
}
else
if
(!
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
y
)) {
286
if
(!
_existsDirectedPath_
(
graph
,
z
,
y
) && !
isForbidenArc_
(
x
,
z
)) {
287
reset
=
true
;
288
graph
.
eraseEdge
(
Edge
(
y
,
z
));
289
graph
.
addArc
(
y
,
z
);
290
}
else
if
(
_existsDirectedPath_
(
graph
,
z
,
y
) && !
isForbidenArc_
(
z
,
y
)) {
291
reset
=
true
;
292
graph
.
eraseEdge
(
Edge
(
y
,
z
));
293
// if we find a cycle, we force the competing edge
294
graph
.
addArc
(
z
,
y
);
295
if
(
std
::
find
(
_latentCouples_
.
begin
(),
_latentCouples_
.
end
(),
Arc
(
z
,
y
))
296
==
_latentCouples_
.
end
()) {
297
_latentCouples_
.
emplace_back
(
z
,
y
);
298
}
299
}
300
}
else
{
301
// checking if the anti-directed arc already exists, to register a
302
// latent variable
303
if
(
graph
.
existsArc
(
z
,
x
) &&
_isNotLatentCouple_
(
z
,
x
)) {
304
_latentCouples_
.
emplace_back
(
z
,
x
);
305
}
306
if
(
graph
.
existsArc
(
z
,
y
) &&
_isNotLatentCouple_
(
z
,
y
)) {
307
_latentCouples_
.
emplace_back
(
z
,
y
);
308
}
309
}
310
}
311
}
else
{
// try Rule 1
312
if
(
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
y
) && !
graph
.
existsArc
(
y
,
z
)) {
313
if
(!
_existsDirectedPath_
(
graph
,
y
,
z
) && !
isForbidenArc_
(
z
,
y
)) {
314
reset
=
true
;
315
graph
.
eraseEdge
(
Edge
(
z
,
y
));
316
graph
.
addArc
(
z
,
y
);
317
}
else
if
(
_existsDirectedPath_
(
graph
,
y
,
z
) && !
isForbidenArc_
(
y
,
z
)) {
318
reset
=
true
;
319
graph
.
eraseEdge
(
Edge
(
z
,
y
));
320
// if we find a cycle, we force the competing edge
321
graph
.
addArc
(
y
,
z
);
322
if
(
std
::
find
(
_latentCouples_
.
begin
(),
_latentCouples_
.
end
(),
Arc
(
y
,
z
))
323
==
_latentCouples_
.
end
()) {
324
_latentCouples_
.
emplace_back
(
y
,
z
);
325
}
326
}
327
}
328
if
(
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
x
) && !
graph
.
existsArc
(
x
,
z
)) {
329
if
(!
_existsDirectedPath_
(
graph
,
x
,
z
) && !
isForbidenArc_
(
z
,
x
)) {
330
reset
=
true
;
331
graph
.
eraseEdge
(
Edge
(
z
,
x
));
332
graph
.
addArc
(
z
,
x
);
333
}
else
if
(
_existsDirectedPath_
(
graph
,
x
,
z
) && !
isForbidenArc_
(
x
,
z
)) {
334
reset
=
true
;
335
graph
.
eraseEdge
(
Edge
(
z
,
x
));
336
// if we find a cycle, we force the competing edge
337
graph
.
addArc
(
x
,
z
);
338
if
(
std
::
find
(
_latentCouples_
.
begin
(),
_latentCouples_
.
end
(),
Arc
(
x
,
z
))
339
==
_latentCouples_
.
end
()) {
340
_latentCouples_
.
emplace_back
(
x
,
z
);
341
}
342
}
343
}
344
}
// if rule 0 or rule 1
345
346
// if what we want to add already exists : pass to the next triplet
347
if
(
reset
) {
348
i
= 0;
349
}
else
{
350
++
i
;
351
}
352
if
(
onProgress
.
hasListener
()) {
353
GUM_EMIT3
(
onProgress
,
354
((
current_step_
+
i
) * 100) / (
past_steps
+
steps_orient
),
355
0.,
356
timer_
.
step
());
357
}
358
}
// while
359
360
// erasing the the double headed arcs
361
for
(
const
Arc
&
arc
:
_latentCouples_
) {
362
graph
.
eraseArc
(
Arc
(
arc
.
head
(),
arc
.
tail
()));
363
}
364
}
365
366
/// variant trying to propagate both orientations in a bidirected arc
367
void
Miic
::
orientationLatents_
(
368
CorrectedMutualInformation
<>&
mutualInformation
,
369
MixedGraph
&
graph
,
370
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
) {
371
std
::
vector
<
Ranking
>
triples
=
unshieldedTriples_
(
graph
,
mutualInformation
,
sepSet
);
372
Size
steps_orient
=
triples
.
size
();
373
Size
past_steps
=
current_step_
;
374
375
NodeId
i
= 0;
376
// list of elements that we shouldnt read again, ie elements that are
377
// eligible to
378
// rule 0 after the first time they are tested, and elements on which rule 1
379
// has been applied
380
while
(
i
<
triples
.
size
()) {
381
// if i not in do_not_reread
382
Ranking
triple
=
triples
[
i
];
383
NodeId
x
,
y
,
z
;
384
x
=
std
::
get
< 0 >(*
triple
.
first
);
385
y
=
std
::
get
< 1 >(*
triple
.
first
);
386
z
=
std
::
get
< 2 >(*
triple
.
first
);
387
388
std
::
vector
<
NodeId
>
ui
;
389
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
390
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
391
if
(
sepSet
.
exists
(
key
)) {
392
ui
=
sepSet
[
key
];
393
}
else
if
(
sepSet
.
exists
(
rev_key
)) {
394
ui
=
sepSet
[
rev_key
];
395
}
396
double
Ixyz_ui
=
triple
.
second
;
397
// try Rule 0
398
if
(
Ixyz_ui
< 0) {
399
// if ( z not in Sep[x,y])
400
if
(
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
401
// if what we want to add already exists : pass
402
if
((
graph
.
existsArc
(
x
,
z
) ||
graph
.
existsArc
(
z
,
x
))
403
&& (
graph
.
existsArc
(
y
,
z
) ||
graph
.
existsArc
(
z
,
y
))) {
404
++
i
;
405
}
else
{
406
i
= 0;
407
graph
.
eraseEdge
(
Edge
(
x
,
z
));
408
graph
.
eraseEdge
(
Edge
(
y
,
z
));
409
// checking for cycles
410
if
(
graph
.
existsArc
(
z
,
x
)) {
411
graph
.
eraseArc
(
Arc
(
z
,
x
));
412
try
{
413
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
x
);
414
// if we find a cycle, we force the competing edge
415
_latentCouples_
.
emplace_back
(
z
,
x
);
416
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
x
,
z
); }
417
graph
.
addArc
(
z
,
x
);
418
}
else
{
419
try
{
420
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
x
);
421
// if we find a cycle, we force the competing edge
422
graph
.
addArc
(
z
,
x
);
423
_latentCouples_
.
emplace_back
(
z
,
x
);
424
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
x
,
z
); }
425
}
426
if
(
graph
.
existsArc
(
z
,
y
)) {
427
graph
.
eraseArc
(
Arc
(
z
,
y
));
428
try
{
429
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
y
);
430
// if we find a cycle, we force the competing edge
431
_latentCouples_
.
emplace_back
(
z
,
y
);
432
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
y
,
z
); }
433
graph
.
addArc
(
z
,
y
);
434
}
else
{
435
try
{
436
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
y
);
437
// if we find a cycle, we force the competing edge
438
graph
.
addArc
(
z
,
y
);
439
_latentCouples_
.
emplace_back
(
z
,
y
);
440
441
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
y
,
z
); }
442
}
443
if
(
graph
.
existsArc
(
z
,
x
) &&
_isNotLatentCouple_
(
z
,
x
)) {
444
_latentCouples_
.
emplace_back
(
z
,
x
);
445
}
446
if
(
graph
.
existsArc
(
z
,
y
) &&
_isNotLatentCouple_
(
z
,
y
)) {
447
_latentCouples_
.
emplace_back
(
z
,
y
);
448
}
449
}
450
}
else
{
451
++
i
;
452
}
453
}
else
{
// try Rule 1
454
bool
reset
{
false
};
455
if
(
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
y
) && !
graph
.
existsArc
(
y
,
z
)) {
456
reset
=
true
;
457
graph
.
eraseEdge
(
Edge
(
z
,
y
));
458
try
{
459
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
y
,
z
);
460
// if we find a cycle, we force the competing edge
461
graph
.
addArc
(
y
,
z
);
462
_latentCouples_
.
emplace_back
(
y
,
z
);
463
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
z
,
y
); }
464
}
465
if
(
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
x
) && !
graph
.
existsArc
(
x
,
z
)) {
466
reset
=
true
;
467
graph
.
eraseEdge
(
Edge
(
z
,
x
));
468
try
{
469
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
x
,
z
);
470
// if we find a cycle, we force the competing edge
471
graph
.
addArc
(
x
,
z
);
472
_latentCouples_
.
emplace_back
(
x
,
z
);
473
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
z
,
x
); }
474
}
475
476
if
(
reset
) {
477
i
= 0;
478
}
else
{
479
++
i
;
480
}
481
}
// if rule 0 or rule 1
482
if
(
onProgress
.
hasListener
()) {
483
GUM_EMIT3
(
onProgress
,
484
((
current_step_
+
i
) * 100) / (
past_steps
+
steps_orient
),
485
0.,
486
timer_
.
step
());
487
}
488
}
// while
489
490
// erasing the the double headed arcs
491
for
(
const
Arc
&
arc
:
_latentCouples_
) {
492
graph
.
eraseArc
(
Arc
(
arc
.
head
(),
arc
.
tail
()));
493
}
494
}
495
496
/// varient using the orientation protocol of MIIC
497
void
Miic
::
orientationMiic_
(
498
CorrectedMutualInformation
<>&
mutualInformation
,
499
MixedGraph
&
graph
,
500
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
) {
501
// structure to store the orientations marks -, o, or >,
502
// Considers the head of the arc/edge first node -* second node
503
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>
marks
=
_initialMarks_
;
504
505
// marks always correspond to the head of the arc/edge. - is for a forbidden
506
// arc, > for a mandatory arc
507
// we start by adding the mandatory arcs
508
for
(
auto
iter
=
marks
.
begin
();
iter
!=
marks
.
end
(); ++
iter
) {
509
if
(
graph
.
existsEdge
(
iter
.
key
().
first
,
iter
.
key
().
second
) &&
iter
.
val
() ==
'>'
) {
510
graph
.
eraseEdge
(
Edge
(
iter
.
key
().
first
,
iter
.
key
().
second
));
511
graph
.
addArc
(
iter
.
key
().
first
,
iter
.
key
().
second
);
512
}
513
}
514
515
std
::
vector
<
ProbabilisticRanking
>
proba_triples
516
=
unshieldedTriplesMiic_
(
graph
,
mutualInformation
,
sepSet
,
marks
);
517
518
const
Size
steps_orient
=
proba_triples
.
size
();
519
Size
past_steps
=
current_step_
;
520
521
ProbabilisticRanking
best
;
522
if
(
steps_orient
> 0) {
best
=
proba_triples
[0]; }
523
524
while
(!
proba_triples
.
empty
() &&
std
::
max
(
std
::
get
< 2 >(
best
),
std
::
get
< 3 >(
best
)) > 0.5) {
525
const
NodeId
x
=
std
::
get
< 0 >(*
std
::
get
< 0 >(
best
));
526
const
NodeId
y
=
std
::
get
< 1 >(*
std
::
get
< 0 >(
best
));
527
const
NodeId
z
=
std
::
get
< 2 >(*
std
::
get
< 0 >(
best
));
528
529
const
double
i3
=
std
::
get
< 1 >(
best
);
530
531
const
double
p1
=
std
::
get
< 2 >(
best
);
532
const
double
p2
=
std
::
get
< 3 >(
best
);
533
if
(
i3
<= 0) {
534
_orientingVstructureMiic_
(
graph
,
marks
,
x
,
y
,
z
,
p1
,
p2
);
535
}
else
{
536
_propagatingOrientationMiic_
(
graph
,
marks
,
x
,
y
,
z
,
p1
,
p2
);
537
}
538
539
delete
std
::
get
< 0 >(
best
);
540
proba_triples
.
erase
(
proba_triples
.
begin
());
541
// actualisation of the list of triples
542
proba_triples
=
updateProbaTriples_
(
graph
,
proba_triples
);
543
544
if
(!
proba_triples
.
empty
())
best
=
proba_triples
[0];
545
546
++
current_step_
;
547
if
(
onProgress
.
hasListener
()) {
548
GUM_EMIT3
(
onProgress
,
549
(
current_step_
* 100) / (
steps_orient
+
past_steps
),
550
0.,
551
timer_
.
step
());
552
}
553
}
// while
554
555
// erasing the double headed arcs
556
for
(
auto
iter
=
_latentCouples_
.
rbegin
();
iter
!=
_latentCouples_
.
rend
(); ++
iter
) {
557
graph
.
eraseArc
(
Arc
(
iter
->
head
(),
iter
->
tail
()));
558
if
(
_existsDirectedPath_
(
graph
,
iter
->
head
(),
iter
->
tail
())) {
559
// if we find a cycle, we force the competing edge
560
graph
.
addArc
(
iter
->
head
(),
iter
->
tail
());
561
graph
.
eraseArc
(
Arc
(
iter
->
tail
(),
iter
->
head
()));
562
*
iter
=
Arc
(
iter
->
head
(),
iter
->
tail
());
563
}
564
}
565
566
if
(
onProgress
.
hasListener
()) {
GUM_EMIT3
(
onProgress
, 100, 0.,
timer_
.
step
()); }
567
}
568
569
/// finds the best contributor node for a pair given a conditioning set
570
void
Miic
::
findBestContributor_
(
NodeId
x
,
571
NodeId
y
,
572
const
std
::
vector
<
NodeId
>&
ui
,
573
const
MixedGraph
&
graph
,
574
CorrectedMutualInformation
<>&
mutualInformation
,
575
Heap
<
CondRanking
,
GreaterPairOn2nd
>&
rank
) {
576
double
maxP
= -1.0;
577
NodeId
maxZ
= 0;
578
579
// compute N
580
// __N = I.N();
581
const
double
Ixy_ui
=
mutualInformation
.
score
(
x
,
y
,
ui
);
582
583
for
(
const
NodeId
z
:
graph
) {
584
// if z!=x and z!=y and z not in ui
585
if
(
z
!=
x
&&
z
!=
y
&&
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
586
double
Pnv
;
587
double
Pb
;
588
589
// Computing Pnv
590
const
double
Ixyz_ui
=
mutualInformation
.
score
(
x
,
y
,
z
,
ui
);
591
double
calc_expo1
= -
Ixyz_ui
*
M_LN2
;
592
// if exponential are too high or to low, crop them at _maxLog_
593
if
(
calc_expo1
>
_maxLog_
) {
594
Pnv
= 0.0;
595
}
else
if
(
calc_expo1
< -
_maxLog_
) {
596
Pnv
= 1.0;
597
}
else
{
598
Pnv
= 1 / (1 +
std
::
exp
(
calc_expo1
));
599
}
600
601
// Computing Pb
602
const
double
Ixz_ui
=
mutualInformation
.
score
(
x
,
z
,
ui
);
603
const
double
Iyz_ui
=
mutualInformation
.
score
(
y
,
z
,
ui
);
604
605
calc_expo1
= -(
Ixz_ui
-
Ixy_ui
) *
M_LN2
;
606
double
calc_expo2
= -(
Iyz_ui
-
Ixy_ui
) *
M_LN2
;
607
608
// if exponential are too high or to low, crop them at _maxLog_
609
if
(
calc_expo1
>
_maxLog_
||
calc_expo2
>
_maxLog_
) {
610
Pb
= 0.0;
611
}
else
if
(
calc_expo1
< -
_maxLog_
&&
calc_expo2
< -
_maxLog_
) {
612
Pb
= 1.0;
613
}
else
{
614
double
expo1
,
expo2
;
615
if
(
calc_expo1
< -
_maxLog_
) {
616
expo1
= 0.0;
617
}
else
{
618
expo1
=
std
::
exp
(
calc_expo1
);
619
}
620
if
(
calc_expo2
< -
_maxLog_
) {
621
expo2
= 0.0;
622
}
else
{
623
expo2
=
std
::
exp
(
calc_expo2
);
624
}
625
Pb
= 1 / (1 +
expo1
+
expo2
);
626
}
627
628
// Getting max(min(Pnv, pb))
629
const
double
min_pnv_pb
=
std
::
min
(
Pnv
,
Pb
);
630
if
(
min_pnv_pb
>
maxP
) {
631
maxP
=
min_pnv_pb
;
632
maxZ
=
z
;
633
}
634
}
// if z not in (x, y)
635
}
// for z in graph.nodes
636
// storing best z in rank_
637
CondRanking
final
;
638
auto
tup
=
new
CondThreePoints
{
x
,
y
,
maxZ
,
ui
};
639
final
.
first
=
tup
;
640
final
.
second
=
maxP
;
641
rank
.
insert
(
final
);
642
}
643
644
/// gets the list of unshielded triples in the graph in decreasing value of
645
///|I'(x, y, z|{ui})|
646
std
::
vector
<
Ranking
>
Miic
::
unshieldedTriples_
(
647
const
MixedGraph
&
graph
,
648
CorrectedMutualInformation
<>&
mutualInformation
,
649
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
) {
650
std
::
vector
<
Ranking
>
triples
;
651
for
(
NodeId
z
:
graph
) {
652
for
(
NodeId
x
:
graph
.
neighbours
(
z
)) {
653
for
(
NodeId
y
:
graph
.
neighbours
(
z
)) {
654
if
(
y
<
x
&& !
graph
.
existsEdge
(
x
,
y
)) {
655
std
::
vector
<
NodeId
>
ui
;
656
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
657
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
658
if
(
sepSet
.
exists
(
key
)) {
659
ui
=
sepSet
[
key
];
660
}
else
if
(
sepSet
.
exists
(
rev_key
)) {
661
ui
=
sepSet
[
rev_key
];
662
}
663
// remove z from ui if it's present
664
const
auto
iter_z_place
=
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
);
665
if
(
iter_z_place
!=
ui
.
end
()) {
ui
.
erase
(
iter_z_place
); }
666
667
double
Ixyz_ui
=
mutualInformation
.
score
(
x
,
y
,
z
,
ui
);
668
Ranking
triple
;
669
auto
tup
=
new
ThreePoints
{
x
,
y
,
z
};
670
triple
.
first
=
tup
;
671
triple
.
second
=
Ixyz_ui
;
672
triples
.
push_back
(
triple
);
673
}
674
}
675
}
676
}
677
std
::
sort
(
triples
.
begin
(),
triples
.
end
(),
GreaterAbsPairOn2nd
());
678
return
triples
;
679
}
680
681
/// gets the list of unshielded triples in the graph in decreasing value of
682
///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC
683
std
::
vector
<
ProbabilisticRanking
>
Miic
::
unshieldedTriplesMiic_
(
684
const
MixedGraph
&
graph
,
685
CorrectedMutualInformation
<>&
mutualInformation
,
686
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
,
687
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>&
marks
) {
688
std
::
vector
<
ProbabilisticRanking
>
triples
;
689
for
(
NodeId
z
:
graph
) {
690
for
(
NodeId
x
:
graph
.
neighbours
(
z
)) {
691
for
(
NodeId
y
:
graph
.
neighbours
(
z
)) {
692
if
(
y
<
x
&& !
graph
.
existsEdge
(
x
,
y
)) {
693
std
::
vector
<
NodeId
>
ui
;
694
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
695
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
696
if
(
sepSet
.
exists
(
key
)) {
697
ui
=
sepSet
[
key
];
698
}
else
if
(
sepSet
.
exists
(
rev_key
)) {
699
ui
=
sepSet
[
rev_key
];
700
}
701
// remove z from ui if it's present
702
const
auto
iter_z_place
=
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
);
703
if
(
iter_z_place
!=
ui
.
end
()) {
ui
.
erase
(
iter_z_place
); }
704
705
const
double
Ixyz_ui
=
mutualInformation
.
score
(
x
,
y
,
z
,
ui
);
706
auto
tup
=
new
ThreePoints
{
x
,
y
,
z
};
707
ProbabilisticRanking
triple
{
tup
,
Ixyz_ui
, 0.5, 0.5};
708
triples
.
push_back
(
triple
);
709
if
(!
marks
.
exists
({
x
,
z
})) {
marks
.
insert
({
x
,
z
},
'o'
); }
710
if
(!
marks
.
exists
({
z
,
x
})) {
marks
.
insert
({
z
,
x
},
'o'
); }
711
if
(!
marks
.
exists
({
y
,
z
})) {
marks
.
insert
({
y
,
z
},
'o'
); }
712
if
(!
marks
.
exists
({
z
,
y
})) {
marks
.
insert
({
z
,
y
},
'o'
); }
713
}
714
}
715
}
716
}
717
triples
=
updateProbaTriples_
(
graph
,
triples
);
718
std
::
sort
(
triples
.
begin
(),
triples
.
end
(),
GreaterTupleOnLast
());
719
return
triples
;
720
}
721
722
/// Gets the orientation probabilities like MIIC for the orientation phase
723
std
::
vector
<
ProbabilisticRanking
>
724
Miic
::
updateProbaTriples_
(
const
MixedGraph
&
graph
,
725
std
::
vector
<
ProbabilisticRanking
>
probaTriples
) {
726
for
(
auto
&
triple
:
probaTriples
) {
727
NodeId
x
,
y
,
z
;
728
x
=
std
::
get
< 0 >(*
std
::
get
< 0 >(
triple
));
729
y
=
std
::
get
< 1 >(*
std
::
get
< 0 >(
triple
));
730
z
=
std
::
get
< 2 >(*
std
::
get
< 0 >(
triple
));
731
const
double
Ixyz
=
std
::
get
< 1 >(
triple
);
732
double
Pxz
=
std
::
get
< 2 >(
triple
);
733
double
Pyz
=
std
::
get
< 3 >(
triple
);
734
735
if
(
Ixyz
<= 0) {
736
const
double
expo
=
std
::
exp
(
Ixyz
);
737
const
double
P0
= (1 +
expo
) / (1 + 3 *
expo
);
738
// distinguish between the initialization and the update process
739
if
(
Pxz
==
Pyz
&&
Pyz
== 0.5) {
740
std
::
get
< 2 >(
triple
) =
P0
;
741
std
::
get
< 3 >(
triple
) =
P0
;
742
}
else
{
743
if
(
graph
.
existsArc
(
x
,
z
) &&
Pxz
>=
P0
) {
744
std
::
get
< 3 >(
triple
) =
Pxz
* (1 / (1 +
expo
) - 0.5) + 0.5;
745
}
else
if
(
graph
.
existsArc
(
y
,
z
) &&
Pyz
>=
P0
) {
746
std
::
get
< 2 >(
triple
) =
Pyz
* (1 / (1 +
expo
) - 0.5) + 0.5;
747
}
748
}
749
}
else
{
750
const
double
expo
=
std
::
exp
(-
Ixyz
);
751
if
(
graph
.
existsArc
(
x
,
z
) &&
Pxz
>= 0.5) {
752
std
::
get
< 3 >(
triple
) =
Pxz
* (1 / (1 +
expo
) - 0.5) + 0.5;
753
}
else
if
(
graph
.
existsArc
(
y
,
z
) &&
Pyz
>= 0.5) {
754
std
::
get
< 2 >(
triple
) =
Pyz
* (1 / (1 +
expo
) - 0.5) + 0.5;
755
}
756
}
757
}
758
std
::
sort
(
probaTriples
.
begin
(),
probaTriples
.
end
(),
GreaterTupleOnLast
());
759
return
probaTriples
;
760
}
761
762
/// learns the structure of an Bayesian network, ie a DAG, from an Essential
763
/// graph.
764
DAG
Miic
::
learnStructure
(
CorrectedMutualInformation
<>&
I
,
MixedGraph
initialGraph
) {
765
MixedGraph
essentialGraph
=
learnMixedStructure
(
I
,
initialGraph
);
766
// orientate remaining edges
767
768
const
Sequence
<
NodeId
>
order
=
essentialGraph
.
topologicalOrder
();
769
770
// first, forbidden arcs force arc in the other direction
771
for
(
NodeId
x
:
order
) {
772
const
auto
nei_x
=
essentialGraph
.
neighbours
(
x
);
773
for
(
NodeId
y
:
nei_x
)
774
if
(
isForbidenArc_
(
x
,
y
)) {
775
essentialGraph
.
eraseEdge
(
Edge
(
x
,
y
));
776
if
(
isForbidenArc_
(
y
,
x
)) {
777
GUM_TRACE
(
"Neither arc allowed for edge ("
<<
x
<<
","
<<
y
<<
")"
)
778
}
else
{
779
GUM_TRACE
(
"Forced orientation : "
<<
y
<<
"->"
<<
x
)
780
essentialGraph
.
addArc
(
y
,
x
);
781
}
782
}
else
if
(
isForbidenArc_
(
y
,
x
)) {
783
essentialGraph
.
eraseEdge
(
Edge
(
x
,
y
));
784
GUM_TRACE
(
"Forced orientation : "
<<
x
<<
"->"
<<
y
)
785
essentialGraph
.
addArc
(
x
,
y
);
786
}
787
}
788
GUM_TRACE
(
essentialGraph
.
toDot
());
789
790
// first, propagate existing orientations
791
bool
newOrientation
=
true
;
792
while
(
newOrientation
) {
793
newOrientation
=
false
;
794
for
(
NodeId
x
:
order
) {
795
if
(!
essentialGraph
.
parents
(
x
).
empty
()) {
796
newOrientation
|=
propagatesRemainingOrientableEdges_
(
essentialGraph
,
x
);
797
}
798
}
799
}
800
GUM_TRACE
(
essentialGraph
.
toDot
());
801
propagatesOrientationInChainOfRemainingEdges_
(
essentialGraph
);
802
GUM_TRACE
(
essentialGraph
.
toDot
());
803
804
// then decide the orientation for double arcs
805
for
(
NodeId
x
:
order
)
806
for
(
NodeId
y
:
essentialGraph
.
parents
(
x
))
807
if
(
essentialGraph
.
parents
(
y
).
contains
(
x
)) {
808
GUM_TRACE
(
" + Resolving double arcs (poorly)"
)
809
essentialGraph
.
eraseArc
(
Arc
(
y
,
x
));
810
}
811
812
DAG
dag
;
813
for
(
auto
node
:
essentialGraph
) {
814
dag
.
addNodeWithId
(
node
);
815
}
816
for
(
const
Arc
&
arc
:
essentialGraph
.
arcs
()) {
817
dag
.
addArc
(
arc
.
tail
(),
arc
.
head
());
818
}
819
820
return
dag
;
821
}
822
823
bool
Miic
::
isOrientable_
(
const
MixedGraph
&
graph
,
NodeId
xi
,
NodeId
xj
)
const
{
824
// no cycle
825
if
(
_existsDirectedPath_
(
graph
,
xj
,
xi
)) {
826
GUM_TRACE
(
"cycle("
<<
xi
<<
"-"
<<
xj
<<
")"
)
827
return
false
;
828
}
829
830
// R1
831
if
(!(
graph
.
parents
(
xi
) -
graph
.
adjacents
(
xj
)).
empty
()) {
832
GUM_TRACE
(
"R1("
<<
xi
<<
"-"
<<
xj
<<
")"
)
833
return
true
;
834
}
835
836
// R2
837
if
(
_existsDirectedPath_
(
graph
,
xi
,
xj
)) {
838
GUM_TRACE
(
"R2("
<<
xi
<<
"-"
<<
xj
<<
")"
)
839
return
true
;
840
}
841
842
// R3
843
int
nbr
= 0;
844
for
(
const
auto
p
:
graph
.
parents
(
xj
)) {
845
if
(!
graph
.
mixedOrientedPath
(
xi
,
p
).
empty
()) {
846
nbr
+= 1;
847
if
(
nbr
== 2) {
848
GUM_TRACE
(
"R3("
<<
xi
<<
"-"
<<
xj
<<
")"
)
849
return
true
;
850
}
851
}
852
}
853
return
false
;
854
}
855
856
/// Orients every remaining undirected edge of @p essentialGraph.
///
/// Repeats, while edges remain, for the component of the first remaining
/// edge:
///  1. explore it through undirected edges from an endpoint of that edge and
///     choose as root the node with the most children (strict improvement
///     only);
///  2. explore it again from that root, orienting every undirected edge
///     n - next as n -> next when next is expanded.
void Miic::propagatesOrientationInChainOfRemainingEdges_(MixedGraph& essentialGraph) {
  // removes and returns an arbitrary element of a non-empty node set
  const auto popAny = [](NodeSet& nodes) {
    const NodeId node = *(nodes.begin());
    nodes.erase(node);
    return node;
  };

  // then decide the orientation for remaining edges
  while (!essentialGraph.edges().empty()) {
    NodeId root        = (*(essentialGraph.edges().begin())).first();
    Size   maxChildren = essentialGraph.children(root).size();

    NodeSet    visited;
    NodeSet    stack{root};
    const auto schedule = [&stack, &visited](NodeId n) {
      if (!stack.contains(n) && !visited.contains(n)) { stack.insert(n); }
    };

    // check the best root for the set of neighbours
    while (!stack.empty()) {
      const NodeId next = popAny(stack);
      if (visited.contains(next)) continue;

      const Size nbChildren = essentialGraph.children(next).size();
      if (nbChildren > maxChildren) {
        maxChildren = nbChildren;
        root        = next;
      }
      for (const auto n: essentialGraph.neighbours(next)) {
        schedule(n);
      }
      visited.insert(next);
    }

    // orientation now
    visited.clear();
    stack.clear();
    stack.insert(root);
    while (!stack.empty()) {
      const NodeId next = popAny(stack);
      if (visited.contains(next)) continue;

      // copy: the edges around `next` are erased inside the loop
      const auto nei = essentialGraph.neighbours(next);
      for (const auto n: nei) {
        schedule(n);
        GUM_TRACE(" + amap reasonably orientation for " << n << "->" << next);
        essentialGraph.eraseEdge(Edge(n, next));
        essentialGraph.addArc(n, next);
      }
      visited.insert(next);
    }
  }
}
896
897
/// Propagates the orientation from a node to its neighbours
898
bool
Miic
::
propagatesRemainingOrientableEdges_
(
MixedGraph
&
graph
,
NodeId
xj
) {
899
bool
res
=
false
;
900
const
auto
neighbours
=
graph
.
neighbours
(
xj
);
901
for
(
auto
&
xi
:
neighbours
) {
902
bool
i_j
=
isOrientable_
(
graph
,
xi
,
xj
);
903
bool
j_i
=
isOrientable_
(
graph
,
xj
,
xi
);
904
if
(
i_j
||
j_i
) {
905
GUM_TRACE
(
" + Removing edge ("
<<
xi
<<
","
<<
xj
<<
")"
)
906
graph
.
eraseEdge
(
Edge
(
xi
,
xj
));
907
res
=
true
;
908
}
909
if
(
i_j
) {
910
GUM_TRACE
(
" + add arc ("
<<
xi
<<
","
<<
xj
<<
")"
)
911
graph
.
addArc
(
xi
,
xj
);
912
propagatesRemainingOrientableEdges_
(
graph
,
xj
);
913
}
914
if
(
j_i
) {
915
GUM_TRACE
(
" + add arc ("
<<
xi
<<
","
<<
xj
<<
")"
)
916
graph
.
addArc
(
xj
,
xi
);
917
propagatesRemainingOrientableEdges_
(
graph
,
xi
);
918
}
919
if
(
i_j
&&
j_i
) {
920
GUM_TRACE
(
" + add arc ("
<<
xi
<<
","
<<
xj
<<
")"
)
921
_latentCouples_
.
emplace_back
(
xi
,
xj
);
922
}
923
}
924
925
return
res
;
926
}
927
928
/// get the list of arcs hiding latent variables
929
const
std
::
vector
<
Arc
>
Miic
::
latentVariables
()
const
{
return
_latentCouples_
; }
930
931
/// learns the structure and the parameters of a BN
932
template
<
typename
GUM_SCALAR
,
typename
GRAPH_CHANGES_SELECTOR
,
typename
PARAM_ESTIMATOR
>
933
BayesNet
<
GUM_SCALAR
>
Miic
::
learnBN
(
GRAPH_CHANGES_SELECTOR
&
selector
,
934
PARAM_ESTIMATOR
&
estimator
,
935
DAG
initial_dag
) {
936
return
DAG2BNLearner
<>::
createBN
<
GUM_SCALAR
>(
estimator
,
937
learnStructure
(
selector
,
initial_dag
));
938
}
939
940
void
Miic
::
setMiicBehaviour
() {
this
->
_useMiic_
=
true
; }
941
942
void
Miic
::
set3of2Behaviour
() {
this
->
_useMiic_
=
false
; }
943
944
/**
 * Replaces the initial marks (user constraints) with the given table.
 *
 * @param constraints marks indexed by ordered node pairs; isForbidenArc_ treats
 *        a '-' mark on (x, y) as forbidding the arc x -> y
 */
void Miic::addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints) {
  // the parameter is already our own copy: move it instead of copying it again
  this->_initialMarks_ = std::move(constraints);
}
947
948
bool
Miic
::
_existsNonTrivialDirectedPath_
(
const
MixedGraph
&
graph
,
949
const
NodeId
n1
,
950
const
NodeId
n2
) {
951
for
(
const
auto
parent
:
graph
.
parents
(
n2
)) {
952
if
(
graph
.
existsArc
(
parent
,
953
n2
))
// if there is a double arc, pass
954
continue
;
955
if
(
parent
==
n1
)
// trivial directed path => not recognized
956
continue
;
957
if
(
_existsDirectedPath_
(
graph
,
n1
,
parent
))
return
true
;
958
}
959
return
false
;
960
}
961
962
bool
Miic
::
_existsDirectedPath_
(
const
MixedGraph
&
graph
,
const
NodeId
n1
,
const
NodeId
n2
) {
963
// not recursive version => use a FIFO for simulating the recursion
964
List
<
NodeId
>
nodeFIFO
;
965
// mark[node] = successor if visited, else mark[node] does not exist
966
Set
<
NodeId
>
mark
;
967
968
mark
.
insert
(
n2
);
969
nodeFIFO
.
pushBack
(
n2
);
970
971
NodeId
current
;
972
973
while
(!
nodeFIFO
.
empty
()) {
974
current
=
nodeFIFO
.
front
();
975
nodeFIFO
.
popFront
();
976
977
// check the parents
978
for
(
const
auto
new_one
:
graph
.
parents
(
current
)) {
979
if
(
graph
.
existsArc
(
current
,
980
new_one
))
// if there is a double arc, pass
981
continue
;
982
983
if
(
new_one
==
n1
) {
return
true
; }
984
985
if
(
mark
.
exists
(
new_one
))
// if this node is already marked, do not
986
continue
;
// check it again
987
988
mark
.
insert
(
new_one
);
989
nodeFIFO
.
pushBack
(
new_one
);
990
}
991
}
992
993
return
false
;
994
}
995
996
void
Miic
::
_orientingVstructureMiic_
(
MixedGraph
&
graph
,
997
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>&
marks
,
998
NodeId
x
,
999
NodeId
y
,
1000
NodeId
z
,
1001
double
p1
,
1002
double
p2
) {
1003
// v-structure discovery
1004
if
(
marks
[{
x
,
z
}] ==
'o'
&&
marks
[{
y
,
z
}] ==
'o'
) {
// If x-z-y
1005
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
z
,
x
)) {
1006
graph
.
eraseEdge
(
Edge
(
x
,
z
));
1007
graph
.
addArc
(
x
,
z
);
1008
GUM_TRACE
(
"1.a Removing edge ("
<<
x
<<
","
<<
z
<<
")"
)
1009
GUM_TRACE
(
"1.a Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1010
marks
[{
x
,
z
}] =
'>'
;
1011
if
(
graph
.
existsArc
(
z
,
x
) &&
_isNotLatentCouple_
(
z
,
x
)) {
1012
GUM_TRACE
(
"Adding latent couple ("
<<
z
<<
","
<<
x
<<
")"
)
1013
_latentCouples_
.
emplace_back
(
z
,
x
);
1014
}
1015
if
(!
_arcProbas_
.
exists
(
Arc
(
x
,
z
)))
_arcProbas_
.
insert
(
Arc
(
x
,
z
),
p1
);
1016
}
else
{
1017
graph
.
eraseEdge
(
Edge
(
x
,
z
));
1018
GUM_TRACE
(
"1.b Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1019
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
x
,
z
)) {
1020
graph
.
addArc
(
z
,
x
);
1021
GUM_TRACE
(
"1.b Removing edge ("
<<
x
<<
","
<<
z
<<
")"
)
1022
marks
[{
z
,
x
}] =
'>'
;
1023
}
1024
}
1025
1026
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
z
,
y
)) {
1027
graph
.
eraseEdge
(
Edge
(
y
,
z
));
1028
graph
.
addArc
(
y
,
z
);
1029
GUM_TRACE
(
"1.c Removing edge ("
<<
y
<<
","
<<
z
<<
")"
)
1030
GUM_TRACE
(
"1.c Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1031
marks
[{
y
,
z
}] =
'>'
;
1032
if
(
graph
.
existsArc
(
z
,
y
) &&
_isNotLatentCouple_
(
z
,
y
)) {
1033
_latentCouples_
.
emplace_back
(
z
,
y
);
1034
}
1035
if
(!
_arcProbas_
.
exists
(
Arc
(
y
,
z
)))
_arcProbas_
.
insert
(
Arc
(
y
,
z
),
p2
);
1036
}
else
{
1037
graph
.
eraseEdge
(
Edge
(
y
,
z
));
1038
GUM_TRACE
(
"1.d Removing edge ("
<<
y
<<
","
<<
z
<<
")"
)
1039
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
y
,
z
)) {
1040
graph
.
addArc
(
z
,
y
);
1041
GUM_TRACE
(
"1.d Adding arc ("
<<
z
<<
","
<<
y
<<
")"
)
1042
marks
[{
z
,
y
}] =
'>'
;
1043
}
1044
}
1045
}
else
if
(
marks
[{
x
,
z
}] ==
'>'
&&
marks
[{
y
,
z
}] ==
'o'
) {
// If x->z-y
1046
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
z
,
y
)) {
1047
graph
.
eraseEdge
(
Edge
(
y
,
z
));
1048
graph
.
addArc
(
y
,
z
);
1049
GUM_TRACE
(
"2.a Removing edge ("
<<
y
<<
","
<<
z
<<
")"
)
1050
GUM_TRACE
(
"2.a Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1051
marks
[{
y
,
z
}] =
'>'
;
1052
if
(
graph
.
existsArc
(
z
,
y
) &&
_isNotLatentCouple_
(
z
,
y
)) {
1053
_latentCouples_
.
emplace_back
(
z
,
y
);
1054
}
1055
if
(!
_arcProbas_
.
exists
(
Arc
(
y
,
z
)))
_arcProbas_
.
insert
(
Arc
(
y
,
z
),
p2
);
1056
}
else
{
1057
graph
.
eraseEdge
(
Edge
(
y
,
z
));
1058
GUM_TRACE
(
"2.b Removing edge ("
<<
y
<<
","
<<
z
<<
")"
)
1059
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
y
,
z
)) {
1060
graph
.
addArc
(
z
,
y
);
1061
GUM_TRACE
(
"2.b Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1062
marks
[{
z
,
y
}] =
'>'
;
1063
}
1064
}
1065
}
else
if
(
marks
[{
y
,
z
}] ==
'>'
&&
marks
[{
x
,
z
}] ==
'o'
) {
1066
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
z
,
x
)) {
1067
graph
.
eraseEdge
(
Edge
(
x
,
z
));
1068
graph
.
addArc
(
x
,
z
);
1069
GUM_TRACE
(
"3.a Removing edge ("
<<
x
<<
","
<<
z
<<
")"
)
1070
GUM_TRACE
(
"3.a Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1071
marks
[{
x
,
z
}] =
'>'
;
1072
if
(
graph
.
existsArc
(
z
,
x
) &&
_isNotLatentCouple_
(
z
,
x
)) {
1073
_latentCouples_
.
emplace_back
(
z
,
x
);
1074
}
1075
if
(!
_arcProbas_
.
exists
(
Arc
(
x
,
z
)))
_arcProbas_
.
insert
(
Arc
(
x
,
z
),
p1
);
1076
}
else
{
1077
graph
.
eraseEdge
(
Edge
(
x
,
z
));
1078
GUM_TRACE
(
"3.b Removing edge ("
<<
x
<<
","
<<
z
<<
")"
)
1079
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
x
,
z
)) {
1080
graph
.
addArc
(
z
,
x
);
1081
GUM_TRACE
(
"3.b Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1082
marks
[{
z
,
x
}] =
'>'
;
1083
}
1084
}
1085
}
1086
}
1087
1088
1089
void
Miic
::
_propagatingOrientationMiic_
(
MixedGraph
&
graph
,
1090
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>&
marks
,
1091
NodeId
x
,
1092
NodeId
y
,
1093
NodeId
z
,
1094
double
p1
,
1095
double
p2
) {
1096
// orientation propagation
1097
if
(
marks
[{
x
,
z
}] ==
'>'
&&
marks
[{
y
,
z
}] ==
'o'
&&
marks
[{
z
,
y
}] !=
'-'
) {
1098
graph
.
eraseEdge
(
Edge
(
z
,
y
));
1099
// std::cout << "4. Removing edge (" << z << "," << y << ")" <<
1100
// std::endl;
1101
if
(!
_existsDirectedPath_
(
graph
,
y
,
z
) &&
graph
.
parents
(
y
).
empty
()) {
1102
graph
.
addArc
(
z
,
y
);
1103
GUM_TRACE
(
"4.a Adding arc ("
<<
z
<<
","
<<
y
<<
")"
)
1104
marks
[{
z
,
y
}] =
'>'
;
1105
marks
[{
y
,
z
}] =
'-'
;
1106
if
(!
_arcProbas_
.
exists
(
Arc
(
z
,
y
)))
_arcProbas_
.
insert
(
Arc
(
z
,
y
),
p2
);
1107
}
else
if
(!
_existsDirectedPath_
(
graph
,
z
,
y
) &&
graph
.
parents
(
z
).
empty
()) {
1108
graph
.
addArc
(
y
,
z
);
1109
GUM_TRACE
(
"4.b Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1110
marks
[{
z
,
y
}] =
'-'
;
1111
marks
[{
y
,
z
}] =
'>'
;
1112
_latentCouples_
.
emplace_back
(
y
,
z
);
1113
if
(!
_arcProbas_
.
exists
(
Arc
(
y
,
z
)))
_arcProbas_
.
insert
(
Arc
(
y
,
z
),
p2
);
1114
}
else
if
(!
_existsDirectedPath_
(
graph
,
y
,
z
)) {
1115
graph
.
addArc
(
z
,
y
);
1116
GUM_TRACE
(
"4.c Adding arc ("
<<
z
<<
","
<<
y
<<
")"
)
1117
marks
[{
z
,
y
}] =
'>'
;
1118
marks
[{
y
,
z
}] =
'-'
;
1119
if
(!
_arcProbas_
.
exists
(
Arc
(
z
,
y
)))
_arcProbas_
.
insert
(
Arc
(
z
,
y
),
p2
);
1120
}
else
if
(!
_existsDirectedPath_
(
graph
,
z
,
y
)) {
1121
graph
.
addArc
(
y
,
z
);
1122
GUM_TRACE
(
"4.d Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1123
_latentCouples_
.
emplace_back
(
y
,
z
);
1124
marks
[{
z
,
y
}] =
'-'
;
1125
marks
[{
y
,
z
}] =
'>'
;
1126
if
(!
_arcProbas_
.
exists
(
Arc
(
y
,
z
)))
_arcProbas_
.
insert
(
Arc
(
y
,
z
),
p2
);
1127
}
1128
}
else
if
(
marks
[{
y
,
z
}] ==
'>'
&&
marks
[{
x
,
z
}] ==
'o'
&&
marks
[{
z
,
x
}] !=
'-'
) {
1129
graph
.
eraseEdge
(
Edge
(
z
,
x
));
1130
GUM_TRACE
(
"5. Removing edge ("
<<
z
<<
","
<<
x
<<
")"
)
1131
if
(!
_existsDirectedPath_
(
graph
,
x
,
z
) &&
graph
.
parents
(
x
).
empty
()) {
1132
graph
.
addArc
(
z
,
x
);
1133
GUM_TRACE
(
"5.a Adding arc ("
<<
z
<<
","
<<
x
<<
")"
)
1134
marks
[{
z
,
x
}] =
'>'
;
1135
marks
[{
x
,
z
}] =
'-'
;
1136
if
(!
_arcProbas_
.
exists
(
Arc
(
z
,
x
)))
_arcProbas_
.
insert
(
Arc
(
z
,
x
),
p1
);
1137
}
else
if
(!
_existsDirectedPath_
(
graph
,
z
,
x
) &&
graph
.
parents
(
z
).
empty
()) {
1138
graph
.
addArc
(
x
,
z
);
1139
GUM_TRACE
(
"5.b Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1140
marks
[{
z
,
x
}] =
'-'
;
1141
marks
[{
x
,
z
}] =
'>'
;
1142
_latentCouples_
.
emplace_back
(
x
,
z
);
1143
if
(!
_arcProbas_
.
exists
(
Arc
(
x
,
z
)))
_arcProbas_
.
insert
(
Arc
(
x
,
z
),
p1
);
1144
}
else
if
(!
_existsDirectedPath_
(
graph
,
x
,
z
)) {
1145
graph
.
addArc
(
z
,
x
);
1146
GUM_TRACE
(
"5.c Adding arc ("
<<
z
<<
","
<<
x
<<
")"
)
1147
marks
[{
z
,
x
}] =
'>'
;
1148
marks
[{
x
,
z
}] =
'-'
;
1149
if
(!
_arcProbas_
.
exists
(
Arc
(
z
,
x
)))
_arcProbas_
.
insert
(
Arc
(
z
,
x
),
p1
);
1150
}
else
if
(!
_existsDirectedPath_
(
graph
,
z
,
x
)) {
1151
graph
.
addArc
(
x
,
z
);
1152
GUM_TRACE
(
"5.d Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1153
marks
[{
z
,
x
}] =
'-'
;
1154
marks
[{
x
,
z
}] =
'>'
;
1155
_latentCouples_
.
emplace_back
(
x
,
z
);
1156
if
(!
_arcProbas_
.
exists
(
Arc
(
x
,
z
)))
_arcProbas_
.
insert
(
Arc
(
x
,
z
),
p1
);
1157
}
1158
}
1159
}
1160
1161
bool
Miic
::
_isNotLatentCouple_
(
const
NodeId
x
,
const
NodeId
y
) {
1162
const
auto
&
lbeg
=
_latentCouples_
.
begin
();
1163
const
auto
&
lend
=
_latentCouples_
.
end
();
1164
1165
return
(
std
::
find
(
lbeg
,
lend
,
Arc
(
x
,
y
)) ==
lend
)
1166
&& (
std
::
find
(
lbeg
,
lend
,
Arc
(
y
,
x
)) ==
lend
);
1167
}
1168
1169
/// Tells whether the user constraints forbid the arc x -> y
/// (i.e. the initial mark of the pair (x, y) is '-').
bool Miic::isForbidenArc_(NodeId x, NodeId y) const {
  const std::pair< NodeId, NodeId > key{x, y};
  if (!_initialMarks_.exists(key)) return false;
  return _initialMarks_[key] == '-';
}
1172
}
/* namespace learning */
1173
1174
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31