aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
Miic.cpp
Go to the documentation of this file.
1
2
/**
3
*
4
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe
5
* GONZALES(@AMU) info_at_agrum_dot_org
6
*
7
* This library is free software: you can redistribute it and/or modify
8
* it under the terms of the GNU Lesser General Public License as published by
9
* the Free Software Foundation, either version 3 of the License, or
10
* (at your option) any later version.
11
*
12
* This library is distributed in the hope that it will be useful,
13
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
* GNU Lesser General Public License for more details.
16
*
17
* You should have received a copy of the GNU Lesser General Public License
18
* along with this library. If not, see <http://www.gnu.org/licenses/>.
19
*
20
*/
21
22
23
/** @file
24
* @brief Implementation of gum::learning::ThreeOffTwo and MIIC
25
*
26
* @author Quentin FALCAND, Marvin LASSERRE and Pierre-Henri WUILLEMIN(@LIP6)
27
*/
28
29
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
30
#
include
<
agrum
/
tools
/
core
/
hashTable
.
h
>
31
#
include
<
agrum
/
tools
/
core
/
heap
.
h
>
32
#
include
<
agrum
/
tools
/
core
/
timer
.
h
>
33
#
include
<
agrum
/
tools
/
graphs
/
mixedGraph
.
h
>
34
#
include
<
agrum
/
BN
/
learning
/
Miic
.
h
>
35
#
include
<
agrum
/
BN
/
learning
/
paramUtils
/
DAG2BNLearner
.
h
>
36
#
include
<
agrum
/
tools
/
stattests
/
correctedMutualInformation
.
h
>
37
38
39
namespace
gum
{
40
41
namespace
learning
{
42
43
/// default constructor
44
Miic
::
Miic
() :
_maxLog_
(100),
_size_
(0) {
GUM_CONSTRUCTOR
(
Miic
); }
45
46
/// default constructor with maxLog
47
Miic
::
Miic
(
int
maxLog
) :
_maxLog_
(
maxLog
),
_size_
(0) {
GUM_CONSTRUCTOR
(
Miic
); }
48
49
/// copy constructor
50
Miic
::
Miic
(
const
Miic
&
from
) :
ApproximationScheme
(
from
),
_size_
(
from
.
_size_
) {
51
GUM_CONS_CPY
(
Miic
);
52
}
53
54
/// move constructor
55
Miic
::
Miic
(
Miic
&&
from
) :
ApproximationScheme
(
std
::
move
(
from
)),
_size_
(
from
.
_size_
) {
56
GUM_CONS_MOV
(
Miic
);
57
}
58
59
/// destructor
60
Miic
::~
Miic
() {
GUM_DESTRUCTOR
(
Miic
); }
61
62
/// copy operator
63
Miic
&
Miic
::
operator
=(
const
Miic
&
from
) {
64
ApproximationScheme
::
operator
=(
from
);
65
return
*
this
;
66
}
67
68
/// move operator
69
Miic
&
Miic
::
operator
=(
Miic
&&
from
) {
70
ApproximationScheme
::
operator
=(
std
::
move
(
from
));
71
return
*
this
;
72
}
73
74
75
/// Orders two conditional rankings by decreasing score.
///
/// @return true when the score (second member) of @p e1 is strictly greater
/// than the score of @p e2.
bool GreaterPairOn2nd::operator()(const CondRanking& e1, const CondRanking& e2) const {
  const auto& lhsScore = e1.second;
  const auto& rhsScore = e2.second;
  return lhsScore > rhsScore;
}
78
79
/// Orders two rankings by decreasing absolute value of their score.
///
/// @return true when |score of e1| is strictly greater than |score of e2|.
bool GreaterAbsPairOn2nd::operator()(const Ranking& e1, const Ranking& e2) const {
  const auto lhsMagnitude = std::abs(e1.second);
  const auto rhsMagnitude = std::abs(e2.second);
  return lhsMagnitude > rhsMagnitude;
}
82
83
bool
GreaterTupleOnLast
::
operator
()(
const
ProbabilisticRanking
&
e1
,
84
const
ProbabilisticRanking
&
e2
)
const
{
85
double
p1xz
=
std
::
get
< 2 >(
e1
);
86
double
p1yz
=
std
::
get
< 3 >(
e1
);
87
double
p2xz
=
std
::
get
< 2 >(
e2
);
88
double
p2yz
=
std
::
get
< 3 >(
e2
);
89
double
I1
=
std
::
get
< 1 >(
e1
);
90
double
I2
=
std
::
get
< 1 >(
e2
);
91
// First, we look at the sign of information.
92
// Then, the probability values
93
// and finally the abs value of information.
94
if
((
I1
< 0 &&
I2
< 0) || (
I1
>= 0 &&
I2
>= 0)) {
95
if
(
std
::
max
(
p1xz
,
p1yz
) ==
std
::
max
(
p2xz
,
p2yz
)) {
96
return
std
::
abs
(
I1
) >
std
::
abs
(
I2
);
97
}
else
{
98
return
std
::
max
(
p1xz
,
p1yz
) >
std
::
max
(
p2xz
,
p2yz
);
99
}
100
}
else
{
101
return
I1
<
I2
;
102
}
103
}
104
105
/// learns the structure of a MixedGraph
106
MixedGraph
Miic
::
learnMixedStructure
(
CorrectedMutualInformation
<>&
mutualInformation
,
107
MixedGraph
graph
) {
108
timer_
.
reset
();
109
current_step_
= 0;
110
111
// clear the vector of latent arcs to be sure
112
_latentCouples_
.
clear
();
113
114
/// the heap of ranks, with the score, and the NodeIds of x, y and z.
115
Heap
<
CondRanking
,
GreaterPairOn2nd
>
rank
;
116
117
/// the variables separation sets
118
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >
sep_set
;
119
120
initiation_
(
mutualInformation
,
graph
,
sep_set
,
rank
);
121
122
iteration_
(
mutualInformation
,
graph
,
sep_set
,
rank
);
123
124
// std::cout << "Le graphe contient: " << graph.sizeEdges() << " edges." <<
125
// std::endl; std::cout << "En voici la liste: " << graph.edges() <<
126
// std::endl;
127
128
if
(
_useMiic_
) {
129
orientationMiic_
(
mutualInformation
,
graph
,
sep_set
);
130
}
else
{
131
orientation3off2_
(
mutualInformation
,
graph
,
sep_set
);
132
}
133
134
return
graph
;
135
}
136
137
/*
138
* PHASE 1 : INITIATION
139
*
140
* We go over all edges and test if the variables are independent. If they
141
* are,
142
* the edge is deleted. If not, the best contributor is found.
143
*/
144
void
Miic
::
initiation_
(
CorrectedMutualInformation
<>&
mutualInformation
,
145
MixedGraph
&
graph
,
146
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
,
147
Heap
<
CondRanking
,
GreaterPairOn2nd
>&
rank
) {
148
NodeId
x
,
y
;
149
EdgeSet
edges
=
graph
.
edges
();
150
Size
steps_init
=
edges
.
size
();
151
152
for
(
const
Edge
&
edge
:
edges
) {
153
x
=
edge
.
first
();
154
y
=
edge
.
second
();
155
double
Ixy
=
mutualInformation
.
score
(
x
,
y
);
156
157
if
(
Ixy
<= 0) {
//< K
158
graph
.
eraseEdge
(
edge
);
159
sepSet
.
insert
(
std
::
make_pair
(
x
,
y
),
_emptySet_
);
160
}
else
{
161
findBestContributor_
(
x
,
y
,
_emptySet_
,
graph
,
mutualInformation
,
rank
);
162
}
163
164
++
current_step_
;
165
if
(
onProgress
.
hasListener
()) {
166
GUM_EMIT3
(
onProgress
, (
current_step_
* 33) /
steps_init
, 0.,
timer_
.
step
());
167
}
168
}
169
}
170
171
/*
172
* PHASE 2 : ITERATION
173
*
174
* As long as we find important nodes for edges, we go over them to see if
175
* we can assess the independence of the variables.
176
*/
177
void
Miic
::
iteration_
(
CorrectedMutualInformation
<>&
mutualInformation
,
178
MixedGraph
&
graph
,
179
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
,
180
Heap
<
CondRanking
,
GreaterPairOn2nd
>&
rank
) {
181
// if no triples to further examine pass
182
CondRanking
best
;
183
184
Size
steps_init
=
current_step_
;
185
Size
steps_iter
=
rank
.
size
();
186
187
try
{
188
while
(
rank
.
top
().
second
> 0.5) {
189
best
=
rank
.
pop
();
190
191
const
NodeId
x
=
std
::
get
< 0 >(*(
best
.
first
));
192
const
NodeId
y
=
std
::
get
< 1 >(*(
best
.
first
));
193
const
NodeId
z
=
std
::
get
< 2 >(*(
best
.
first
));
194
std
::
vector
<
NodeId
>
ui
=
std
::
move
(
std
::
get
< 3 >(*(
best
.
first
)));
195
196
ui
.
push_back
(
z
);
197
const
double
i_xy_ui
=
mutualInformation
.
score
(
x
,
y
,
ui
);
198
if
(
i_xy_ui
< 0) {
199
graph
.
eraseEdge
(
Edge
(
x
,
y
));
200
sepSet
.
insert
(
std
::
make_pair
(
x
,
y
),
std
::
move
(
ui
));
201
}
else
{
202
findBestContributor_
(
x
,
y
,
ui
,
graph
,
mutualInformation
,
rank
);
203
}
204
205
delete
best
.
first
;
206
207
++
current_step_
;
208
if
(
onProgress
.
hasListener
()) {
209
GUM_EMIT3
(
onProgress
,
210
(
current_step_
* 66) / (
steps_init
+
steps_iter
),
211
0.,
212
timer_
.
step
());
213
}
214
}
215
}
catch
(...) {}
// here, rank is empty
216
current_step_
=
steps_init
+
steps_iter
;
217
if
(
onProgress
.
hasListener
()) {
GUM_EMIT3
(
onProgress
, 66, 0.,
timer_
.
step
()); }
218
current_step_
=
steps_init
+
steps_iter
;
219
}
220
221
/*
222
* PHASE 3 : ORIENTATION
223
*
224
* Try to assess v-structures and propagate them.
225
*/
226
void
Miic
::
orientation3off2_
(
227
CorrectedMutualInformation
<>&
mutualInformation
,
228
MixedGraph
&
graph
,
229
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
) {
230
std
::
vector
<
Ranking
>
triples
=
unshieldedTriples_
(
graph
,
mutualInformation
,
sepSet
);
231
Size
steps_orient
=
triples
.
size
();
232
Size
past_steps
=
current_step_
;
233
234
// marks always correspond to the head of the arc/edge. - is for a forbidden
235
// arc, > for a mandatory arc
236
// we start by adding the mandatory arcs
237
for
(
auto
iter
=
_initialMarks_
.
begin
();
iter
!=
_initialMarks_
.
end
(); ++
iter
) {
238
if
(
graph
.
existsEdge
(
iter
.
key
().
first
,
iter
.
key
().
second
) &&
iter
.
val
() ==
'>'
) {
239
graph
.
eraseEdge
(
Edge
(
iter
.
key
().
first
,
iter
.
key
().
second
));
240
graph
.
addArc
(
iter
.
key
().
first
,
iter
.
key
().
second
);
241
}
242
}
243
244
NodeId
i
= 0;
245
// list of elements that we shouldnt read again, ie elements that are
246
// eligible to
247
// rule 0 after the first time they are tested, and elements on which rule 1
248
// has been applied
249
while
(
i
<
triples
.
size
()) {
250
// if i not in do_not_reread
251
Ranking
triple
=
triples
[
i
];
252
NodeId
x
,
y
,
z
;
253
x
=
std
::
get
< 0 >(*
triple
.
first
);
254
y
=
std
::
get
< 1 >(*
triple
.
first
);
255
z
=
std
::
get
< 2 >(*
triple
.
first
);
256
257
std
::
vector
<
NodeId
>
ui
;
258
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
259
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
260
if
(
sepSet
.
exists
(
key
)) {
261
ui
=
sepSet
[
key
];
262
}
else
if
(
sepSet
.
exists
(
rev_key
)) {
263
ui
=
sepSet
[
rev_key
];
264
}
265
double
Ixyz_ui
=
triple
.
second
;
266
bool
reset
{
false
};
267
// try Rule 0
268
if
(
Ixyz_ui
< 0) {
269
// if ( z not in Sep[x,y])
270
if
(
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
271
if
(!
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
x
)) {
272
// when we try to add an arc to the graph, we always verify if
273
// we are allowed to do so, ie it is not a forbidden arc an it
274
// does not create a cycle
275
if
(!
_existsDirectedPath_
(
graph
,
z
,
x
) && !
isForbidenArc_
(
x
,
z
)) {
276
reset
=
true
;
277
graph
.
eraseEdge
(
Edge
(
x
,
z
));
278
graph
.
addArc
(
x
,
z
);
279
}
else
if
(
_existsDirectedPath_
(
graph
,
z
,
x
) && !
isForbidenArc_
(
z
,
x
)) {
280
reset
=
true
;
281
graph
.
eraseEdge
(
Edge
(
x
,
z
));
282
// if we find a cycle, we force the competing edge
283
graph
.
addArc
(
z
,
x
);
284
if
(
std
::
find
(
_latentCouples_
.
begin
(),
_latentCouples_
.
end
(),
Arc
(
z
,
x
))
285
==
_latentCouples_
.
end
()) {
286
_latentCouples_
.
emplace_back
(
z
,
x
);
287
}
288
}
289
}
else
if
(!
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
y
)) {
290
if
(!
_existsDirectedPath_
(
graph
,
z
,
y
) && !
isForbidenArc_
(
x
,
z
)) {
291
reset
=
true
;
292
graph
.
eraseEdge
(
Edge
(
y
,
z
));
293
graph
.
addArc
(
y
,
z
);
294
}
else
if
(
_existsDirectedPath_
(
graph
,
z
,
y
) && !
isForbidenArc_
(
z
,
y
)) {
295
reset
=
true
;
296
graph
.
eraseEdge
(
Edge
(
y
,
z
));
297
// if we find a cycle, we force the competing edge
298
graph
.
addArc
(
z
,
y
);
299
if
(
std
::
find
(
_latentCouples_
.
begin
(),
_latentCouples_
.
end
(),
Arc
(
z
,
y
))
300
==
_latentCouples_
.
end
()) {
301
_latentCouples_
.
emplace_back
(
z
,
y
);
302
}
303
}
304
}
else
{
305
// checking if the anti-directed arc already exists, to register a
306
// latent variable
307
if
(
graph
.
existsArc
(
z
,
x
) &&
_isNotLatentCouple_
(
z
,
x
)) {
308
_latentCouples_
.
emplace_back
(
z
,
x
);
309
}
310
if
(
graph
.
existsArc
(
z
,
y
) &&
_isNotLatentCouple_
(
z
,
y
)) {
311
_latentCouples_
.
emplace_back
(
z
,
y
);
312
}
313
}
314
}
315
}
else
{
// try Rule 1
316
if
(
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
y
) && !
graph
.
existsArc
(
y
,
z
)) {
317
if
(!
_existsDirectedPath_
(
graph
,
y
,
z
) && !
isForbidenArc_
(
z
,
y
)) {
318
reset
=
true
;
319
graph
.
eraseEdge
(
Edge
(
z
,
y
));
320
graph
.
addArc
(
z
,
y
);
321
}
else
if
(
_existsDirectedPath_
(
graph
,
y
,
z
) && !
isForbidenArc_
(
y
,
z
)) {
322
reset
=
true
;
323
graph
.
eraseEdge
(
Edge
(
z
,
y
));
324
// if we find a cycle, we force the competing edge
325
graph
.
addArc
(
y
,
z
);
326
if
(
std
::
find
(
_latentCouples_
.
begin
(),
_latentCouples_
.
end
(),
Arc
(
y
,
z
))
327
==
_latentCouples_
.
end
()) {
328
_latentCouples_
.
emplace_back
(
y
,
z
);
329
}
330
}
331
}
332
if
(
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
x
) && !
graph
.
existsArc
(
x
,
z
)) {
333
if
(!
_existsDirectedPath_
(
graph
,
x
,
z
) && !
isForbidenArc_
(
z
,
x
)) {
334
reset
=
true
;
335
graph
.
eraseEdge
(
Edge
(
z
,
x
));
336
graph
.
addArc
(
z
,
x
);
337
}
else
if
(
_existsDirectedPath_
(
graph
,
x
,
z
) && !
isForbidenArc_
(
x
,
z
)) {
338
reset
=
true
;
339
graph
.
eraseEdge
(
Edge
(
z
,
x
));
340
// if we find a cycle, we force the competing edge
341
graph
.
addArc
(
x
,
z
);
342
if
(
std
::
find
(
_latentCouples_
.
begin
(),
_latentCouples_
.
end
(),
Arc
(
x
,
z
))
343
==
_latentCouples_
.
end
()) {
344
_latentCouples_
.
emplace_back
(
x
,
z
);
345
}
346
}
347
}
348
}
// if rule 0 or rule 1
349
350
// if what we want to add already exists : pass to the next triplet
351
if
(
reset
) {
352
i
= 0;
353
}
else
{
354
++
i
;
355
}
356
if
(
onProgress
.
hasListener
()) {
357
GUM_EMIT3
(
onProgress
,
358
((
current_step_
+
i
) * 100) / (
past_steps
+
steps_orient
),
359
0.,
360
timer_
.
step
());
361
}
362
}
// while
363
364
// erasing the the double headed arcs
365
for
(
const
Arc
&
arc
:
_latentCouples_
) {
366
graph
.
eraseArc
(
Arc
(
arc
.
head
(),
arc
.
tail
()));
367
}
368
}
369
370
/// varient trying to propagate both orientations in a bidirected arc
371
void
Miic
::
orientationLatents_
(
372
CorrectedMutualInformation
<>&
mutualInformation
,
373
MixedGraph
&
graph
,
374
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
) {
375
std
::
vector
<
Ranking
>
triples
=
unshieldedTriples_
(
graph
,
mutualInformation
,
sepSet
);
376
Size
steps_orient
=
triples
.
size
();
377
Size
past_steps
=
current_step_
;
378
379
NodeId
i
= 0;
380
// list of elements that we shouldnt read again, ie elements that are
381
// eligible to
382
// rule 0 after the first time they are tested, and elements on which rule 1
383
// has been applied
384
while
(
i
<
triples
.
size
()) {
385
// if i not in do_not_reread
386
Ranking
triple
=
triples
[
i
];
387
NodeId
x
,
y
,
z
;
388
x
=
std
::
get
< 0 >(*
triple
.
first
);
389
y
=
std
::
get
< 1 >(*
triple
.
first
);
390
z
=
std
::
get
< 2 >(*
triple
.
first
);
391
392
std
::
vector
<
NodeId
>
ui
;
393
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
394
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
395
if
(
sepSet
.
exists
(
key
)) {
396
ui
=
sepSet
[
key
];
397
}
else
if
(
sepSet
.
exists
(
rev_key
)) {
398
ui
=
sepSet
[
rev_key
];
399
}
400
double
Ixyz_ui
=
triple
.
second
;
401
// try Rule 0
402
if
(
Ixyz_ui
< 0) {
403
// if ( z not in Sep[x,y])
404
if
(
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
405
// if what we want to add already exists : pass
406
if
((
graph
.
existsArc
(
x
,
z
) ||
graph
.
existsArc
(
z
,
x
))
407
&& (
graph
.
existsArc
(
y
,
z
) ||
graph
.
existsArc
(
z
,
y
))) {
408
++
i
;
409
}
else
{
410
i
= 0;
411
graph
.
eraseEdge
(
Edge
(
x
,
z
));
412
graph
.
eraseEdge
(
Edge
(
y
,
z
));
413
// checking for cycles
414
if
(
graph
.
existsArc
(
z
,
x
)) {
415
graph
.
eraseArc
(
Arc
(
z
,
x
));
416
try
{
417
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
x
);
418
// if we find a cycle, we force the competing edge
419
_latentCouples_
.
emplace_back
(
z
,
x
);
420
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
x
,
z
); }
421
graph
.
addArc
(
z
,
x
);
422
}
else
{
423
try
{
424
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
x
);
425
// if we find a cycle, we force the competing edge
426
graph
.
addArc
(
z
,
x
);
427
_latentCouples_
.
emplace_back
(
z
,
x
);
428
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
x
,
z
); }
429
}
430
if
(
graph
.
existsArc
(
z
,
y
)) {
431
graph
.
eraseArc
(
Arc
(
z
,
y
));
432
try
{
433
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
y
);
434
// if we find a cycle, we force the competing edge
435
_latentCouples_
.
emplace_back
(
z
,
y
);
436
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
y
,
z
); }
437
graph
.
addArc
(
z
,
y
);
438
}
else
{
439
try
{
440
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
z
,
y
);
441
// if we find a cycle, we force the competing edge
442
graph
.
addArc
(
z
,
y
);
443
_latentCouples_
.
emplace_back
(
z
,
y
);
444
445
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
y
,
z
); }
446
}
447
if
(
graph
.
existsArc
(
z
,
x
) &&
_isNotLatentCouple_
(
z
,
x
)) {
448
_latentCouples_
.
emplace_back
(
z
,
x
);
449
}
450
if
(
graph
.
existsArc
(
z
,
y
) &&
_isNotLatentCouple_
(
z
,
y
)) {
451
_latentCouples_
.
emplace_back
(
z
,
y
);
452
}
453
}
454
}
else
{
455
++
i
;
456
}
457
}
else
{
// try Rule 1
458
bool
reset
{
false
};
459
if
(
graph
.
existsArc
(
x
,
z
) && !
graph
.
existsArc
(
z
,
y
) && !
graph
.
existsArc
(
y
,
z
)) {
460
reset
=
true
;
461
graph
.
eraseEdge
(
Edge
(
z
,
y
));
462
try
{
463
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
y
,
z
);
464
// if we find a cycle, we force the competing edge
465
graph
.
addArc
(
y
,
z
);
466
_latentCouples_
.
emplace_back
(
y
,
z
);
467
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
z
,
y
); }
468
}
469
if
(
graph
.
existsArc
(
y
,
z
) && !
graph
.
existsArc
(
z
,
x
) && !
graph
.
existsArc
(
x
,
z
)) {
470
reset
=
true
;
471
graph
.
eraseEdge
(
Edge
(
z
,
x
));
472
try
{
473
std
::
vector
<
NodeId
>
path
=
graph
.
directedPath
(
x
,
z
);
474
// if we find a cycle, we force the competing edge
475
graph
.
addArc
(
x
,
z
);
476
_latentCouples_
.
emplace_back
(
x
,
z
);
477
}
catch
(
gum
::
NotFound
) {
graph
.
addArc
(
z
,
x
); }
478
}
479
480
if
(
reset
) {
481
i
= 0;
482
}
else
{
483
++
i
;
484
}
485
}
// if rule 0 or rule 1
486
if
(
onProgress
.
hasListener
()) {
487
GUM_EMIT3
(
onProgress
,
488
((
current_step_
+
i
) * 100) / (
past_steps
+
steps_orient
),
489
0.,
490
timer_
.
step
());
491
}
492
}
// while
493
494
// erasing the the double headed arcs
495
for
(
const
Arc
&
arc
:
_latentCouples_
) {
496
graph
.
eraseArc
(
Arc
(
arc
.
head
(),
arc
.
tail
()));
497
}
498
}
499
500
/// varient using the orientation protocol of MIIC
501
void
Miic
::
orientationMiic_
(
502
CorrectedMutualInformation
<>&
mutualInformation
,
503
MixedGraph
&
graph
,
504
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
) {
505
// structure to store the orientations marks -, o, or >,
506
// Considers the head of the arc/edge first node -* second node
507
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>
marks
=
_initialMarks_
;
508
509
// marks always correspond to the head of the arc/edge. - is for a forbidden
510
// arc, > for a mandatory arc
511
// we start by adding the mandatory arcs
512
for
(
auto
iter
=
marks
.
begin
();
iter
!=
marks
.
end
(); ++
iter
) {
513
if
(
graph
.
existsEdge
(
iter
.
key
().
first
,
iter
.
key
().
second
) &&
iter
.
val
() ==
'>'
) {
514
graph
.
eraseEdge
(
Edge
(
iter
.
key
().
first
,
iter
.
key
().
second
));
515
graph
.
addArc
(
iter
.
key
().
first
,
iter
.
key
().
second
);
516
}
517
}
518
519
std
::
vector
<
ProbabilisticRanking
>
proba_triples
520
=
unshieldedTriplesMiic_
(
graph
,
mutualInformation
,
sepSet
,
marks
);
521
522
const
Size
steps_orient
=
proba_triples
.
size
();
523
Size
past_steps
=
current_step_
;
524
525
ProbabilisticRanking
best
;
526
if
(
steps_orient
> 0) {
best
=
proba_triples
[0]; }
527
528
while
(!
proba_triples
.
empty
() &&
std
::
max
(
std
::
get
< 2 >(
best
),
std
::
get
< 3 >(
best
)) > 0.5) {
529
const
NodeId
x
=
std
::
get
< 0 >(*
std
::
get
< 0 >(
best
));
530
const
NodeId
y
=
std
::
get
< 1 >(*
std
::
get
< 0 >(
best
));
531
const
NodeId
z
=
std
::
get
< 2 >(*
std
::
get
< 0 >(
best
));
532
533
const
double
i3
=
std
::
get
< 1 >(
best
);
534
535
const
double
p1
=
std
::
get
< 2 >(
best
);
536
const
double
p2
=
std
::
get
< 3 >(
best
);
537
if
(
i3
<= 0) {
538
_orientingVstructureMiic_
(
graph
,
marks
,
x
,
y
,
z
,
p1
,
p2
);
539
}
else
{
540
_propagatingOrientationMiic_
(
graph
,
marks
,
x
,
y
,
z
,
p1
,
p2
);
541
}
542
543
delete
std
::
get
< 0 >(
best
);
544
proba_triples
.
erase
(
proba_triples
.
begin
());
545
// actualisation of the list of triples
546
proba_triples
=
updateProbaTriples_
(
graph
,
proba_triples
);
547
548
if
(!
proba_triples
.
empty
())
best
=
proba_triples
[0];
549
550
++
current_step_
;
551
if
(
onProgress
.
hasListener
()) {
552
GUM_EMIT3
(
onProgress
,
553
(
current_step_
* 100) / (
steps_orient
+
past_steps
),
554
0.,
555
timer_
.
step
());
556
}
557
}
// while
558
559
// erasing the double headed arcs
560
for
(
auto
iter
=
_latentCouples_
.
rbegin
();
iter
!=
_latentCouples_
.
rend
(); ++
iter
) {
561
graph
.
eraseArc
(
Arc
(
iter
->
head
(),
iter
->
tail
()));
562
if
(
_existsDirectedPath_
(
graph
,
iter
->
head
(),
iter
->
tail
())) {
563
// if we find a cycle, we force the competing edge
564
graph
.
addArc
(
iter
->
head
(),
iter
->
tail
());
565
graph
.
eraseArc
(
Arc
(
iter
->
tail
(),
iter
->
head
()));
566
*
iter
=
Arc
(
iter
->
head
(),
iter
->
tail
());
567
}
568
}
569
570
if
(
onProgress
.
hasListener
()) {
GUM_EMIT3
(
onProgress
, 100, 0.,
timer_
.
step
()); }
571
}
572
573
/// finds the best contributor node for a pair given a conditioning set
574
void
Miic
::
findBestContributor_
(
NodeId
x
,
575
NodeId
y
,
576
const
std
::
vector
<
NodeId
>&
ui
,
577
const
MixedGraph
&
graph
,
578
CorrectedMutualInformation
<>&
mutualInformation
,
579
Heap
<
CondRanking
,
GreaterPairOn2nd
>&
rank
) {
580
double
maxP
= -1.0;
581
NodeId
maxZ
= 0;
582
583
// compute N
584
// __N = I.N();
585
const
double
Ixy_ui
=
mutualInformation
.
score
(
x
,
y
,
ui
);
586
587
for
(
const
NodeId
z
:
graph
) {
588
// if z!=x and z!=y and z not in ui
589
if
(
z
!=
x
&&
z
!=
y
&&
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
) ==
ui
.
end
()) {
590
double
Pnv
;
591
double
Pb
;
592
593
// Computing Pnv
594
const
double
Ixyz_ui
=
mutualInformation
.
score
(
x
,
y
,
z
,
ui
);
595
double
calc_expo1
= -
Ixyz_ui
*
M_LN2
;
596
// if exponentials are too high or to low, crop them at | __maxLog|
597
if
(
calc_expo1
>
_maxLog_
) {
598
Pnv
= 0.0;
599
}
else
if
(
calc_expo1
< -
_maxLog_
) {
600
Pnv
= 1.0;
601
}
else
{
602
Pnv
= 1 / (1 +
std
::
exp
(
calc_expo1
));
603
}
604
605
// Computing Pb
606
const
double
Ixz_ui
=
mutualInformation
.
score
(
x
,
z
,
ui
);
607
const
double
Iyz_ui
=
mutualInformation
.
score
(
y
,
z
,
ui
);
608
609
calc_expo1
= -(
Ixz_ui
-
Ixy_ui
) *
M_LN2
;
610
double
calc_expo2
= -(
Iyz_ui
-
Ixy_ui
) *
M_LN2
;
611
612
// if exponentials are too high or to low, crop them at _maxLog_
613
if
(
calc_expo1
>
_maxLog_
||
calc_expo2
>
_maxLog_
) {
614
Pb
= 0.0;
615
}
else
if
(
calc_expo1
< -
_maxLog_
&&
calc_expo2
< -
_maxLog_
) {
616
Pb
= 1.0;
617
}
else
{
618
double
expo1
,
expo2
;
619
if
(
calc_expo1
< -
_maxLog_
) {
620
expo1
= 0.0;
621
}
else
{
622
expo1
=
std
::
exp
(
calc_expo1
);
623
}
624
if
(
calc_expo2
< -
_maxLog_
) {
625
expo2
= 0.0;
626
}
else
{
627
expo2
=
std
::
exp
(
calc_expo2
);
628
}
629
Pb
= 1 / (1 +
expo1
+
expo2
);
630
}
631
632
// Getting max(min(Pnv, pb))
633
const
double
min_pnv_pb
=
std
::
min
(
Pnv
,
Pb
);
634
if
(
min_pnv_pb
>
maxP
) {
635
maxP
=
min_pnv_pb
;
636
maxZ
=
z
;
637
}
638
}
// if z not in (x, y)
639
}
// for z in graph.nodes
640
// storing best z in rank_
641
CondRanking
final
;
642
auto
tup
=
new
CondThreePoints
{
x
,
y
,
maxZ
,
ui
};
643
final
.
first
=
tup
;
644
final
.
second
=
maxP
;
645
rank
.
insert
(
final
);
646
}
647
648
/// gets the list of unshielded triples in the graph in decreasing value of
649
///|I'(x, y, z|{ui})|
650
std
::
vector
<
Ranking
>
Miic
::
unshieldedTriples_
(
651
const
MixedGraph
&
graph
,
652
CorrectedMutualInformation
<>&
mutualInformation
,
653
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
) {
654
std
::
vector
<
Ranking
>
triples
;
655
for
(
NodeId
z
:
graph
) {
656
for
(
NodeId
x
:
graph
.
neighbours
(
z
)) {
657
for
(
NodeId
y
:
graph
.
neighbours
(
z
)) {
658
if
(
y
<
x
&& !
graph
.
existsEdge
(
x
,
y
)) {
659
std
::
vector
<
NodeId
>
ui
;
660
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
661
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
662
if
(
sepSet
.
exists
(
key
)) {
663
ui
=
sepSet
[
key
];
664
}
else
if
(
sepSet
.
exists
(
rev_key
)) {
665
ui
=
sepSet
[
rev_key
];
666
}
667
// remove z from ui if it's present
668
const
auto
iter_z_place
=
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
);
669
if
(
iter_z_place
!=
ui
.
end
()) {
ui
.
erase
(
iter_z_place
); }
670
671
double
Ixyz_ui
=
mutualInformation
.
score
(
x
,
y
,
z
,
ui
);
672
Ranking
triple
;
673
auto
tup
=
new
ThreePoints
{
x
,
y
,
z
};
674
triple
.
first
=
tup
;
675
triple
.
second
=
Ixyz_ui
;
676
triples
.
push_back
(
triple
);
677
}
678
}
679
}
680
}
681
std
::
sort
(
triples
.
begin
(),
triples
.
end
(),
GreaterAbsPairOn2nd
());
682
return
triples
;
683
}
684
685
/// gets the list of unshielded triples in the graph in decreasing value of
686
///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC
687
std
::
vector
<
ProbabilisticRanking
>
Miic
::
unshieldedTriplesMiic_
(
688
const
MixedGraph
&
graph
,
689
CorrectedMutualInformation
<>&
mutualInformation
,
690
const
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
std
::
vector
<
NodeId
> >&
sepSet
,
691
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>&
marks
) {
692
std
::
vector
<
ProbabilisticRanking
>
triples
;
693
for
(
NodeId
z
:
graph
) {
694
for
(
NodeId
x
:
graph
.
neighbours
(
z
)) {
695
for
(
NodeId
y
:
graph
.
neighbours
(
z
)) {
696
if
(
y
<
x
&& !
graph
.
existsEdge
(
x
,
y
)) {
697
std
::
vector
<
NodeId
>
ui
;
698
std
::
pair
<
NodeId
,
NodeId
>
key
= {
x
,
y
};
699
std
::
pair
<
NodeId
,
NodeId
>
rev_key
= {
y
,
x
};
700
if
(
sepSet
.
exists
(
key
)) {
701
ui
=
sepSet
[
key
];
702
}
else
if
(
sepSet
.
exists
(
rev_key
)) {
703
ui
=
sepSet
[
rev_key
];
704
}
705
// remove z from ui if it's present
706
const
auto
iter_z_place
=
std
::
find
(
ui
.
begin
(),
ui
.
end
(),
z
);
707
if
(
iter_z_place
!=
ui
.
end
()) {
ui
.
erase
(
iter_z_place
); }
708
709
const
double
Ixyz_ui
=
mutualInformation
.
score
(
x
,
y
,
z
,
ui
);
710
auto
tup
=
new
ThreePoints
{
x
,
y
,
z
};
711
ProbabilisticRanking
triple
{
tup
,
Ixyz_ui
, 0.5, 0.5};
712
triples
.
push_back
(
triple
);
713
if
(!
marks
.
exists
({
x
,
z
})) {
marks
.
insert
({
x
,
z
},
'o'
); }
714
if
(!
marks
.
exists
({
z
,
x
})) {
marks
.
insert
({
z
,
x
},
'o'
); }
715
if
(!
marks
.
exists
({
y
,
z
})) {
marks
.
insert
({
y
,
z
},
'o'
); }
716
if
(!
marks
.
exists
({
z
,
y
})) {
marks
.
insert
({
z
,
y
},
'o'
); }
717
}
718
}
719
}
720
}
721
triples
=
updateProbaTriples_
(
graph
,
triples
);
722
std
::
sort
(
triples
.
begin
(),
triples
.
end
(),
GreaterTupleOnLast
());
723
return
triples
;
724
}
725
726
/// Gets the orientation probabilities like MIIC for the orientation phase
727
std
::
vector
<
ProbabilisticRanking
>
728
Miic
::
updateProbaTriples_
(
const
MixedGraph
&
graph
,
729
std
::
vector
<
ProbabilisticRanking
>
probaTriples
) {
730
for
(
auto
&
triple
:
probaTriples
) {
731
NodeId
x
,
y
,
z
;
732
x
=
std
::
get
< 0 >(*
std
::
get
< 0 >(
triple
));
733
y
=
std
::
get
< 1 >(*
std
::
get
< 0 >(
triple
));
734
z
=
std
::
get
< 2 >(*
std
::
get
< 0 >(
triple
));
735
const
double
Ixyz
=
std
::
get
< 1 >(
triple
);
736
double
Pxz
=
std
::
get
< 2 >(
triple
);
737
double
Pyz
=
std
::
get
< 3 >(
triple
);
738
739
if
(
Ixyz
<= 0) {
740
const
double
expo
=
std
::
exp
(
Ixyz
);
741
const
double
P0
= (1 +
expo
) / (1 + 3 *
expo
);
742
// distinguish betweeen the initialization and the update process
743
if
(
Pxz
==
Pyz
&&
Pyz
== 0.5) {
744
std
::
get
< 2 >(
triple
) =
P0
;
745
std
::
get
< 3 >(
triple
) =
P0
;
746
}
else
{
747
if
(
graph
.
existsArc
(
x
,
z
) &&
Pxz
>=
P0
) {
748
std
::
get
< 3 >(
triple
) =
Pxz
* (1 / (1 +
expo
) - 0.5) + 0.5;
749
}
else
if
(
graph
.
existsArc
(
y
,
z
) &&
Pyz
>=
P0
) {
750
std
::
get
< 2 >(
triple
) =
Pyz
* (1 / (1 +
expo
) - 0.5) + 0.5;
751
}
752
}
753
}
else
{
754
const
double
expo
=
std
::
exp
(-
Ixyz
);
755
if
(
graph
.
existsArc
(
x
,
z
) &&
Pxz
>= 0.5) {
756
std
::
get
< 3 >(
triple
) =
Pxz
* (1 / (1 +
expo
) - 0.5) + 0.5;
757
}
else
if
(
graph
.
existsArc
(
y
,
z
) &&
Pyz
>= 0.5) {
758
std
::
get
< 2 >(
triple
) =
Pyz
* (1 / (1 +
expo
) - 0.5) + 0.5;
759
}
760
}
761
}
762
std
::
sort
(
probaTriples
.
begin
(),
probaTriples
.
end
(),
GreaterTupleOnLast
());
763
return
probaTriples
;
764
}
765
766
/// learns the structure of an Bayesian network, ie a DAG, from an Essential
767
/// graph.
768
DAG
Miic
::
learnStructure
(
CorrectedMutualInformation
<>&
I
,
MixedGraph
initialGraph
) {
769
MixedGraph
essentialGraph
=
learnMixedStructure
(
I
,
initialGraph
);
770
// orientate remaining edges
771
772
const
Sequence
<
NodeId
>
order
=
essentialGraph
.
topologicalOrder
();
773
774
// first, forbidden arcs force arc in the other direction
775
for
(
NodeId
x
:
order
) {
776
const
auto
nei_x
=
essentialGraph
.
neighbours
(
x
);
777
for
(
NodeId
y
:
nei_x
)
778
if
(
isForbidenArc_
(
x
,
y
)) {
779
essentialGraph
.
eraseEdge
(
Edge
(
x
,
y
));
780
if
(
isForbidenArc_
(
y
,
x
)) {
781
GUM_TRACE
(
"Neither arc allowed for edge ("
<<
x
<<
","
<<
y
<<
")"
)
782
}
else
{
783
GUM_TRACE
(
"Forced orientation : "
<<
y
<<
"->"
<<
x
)
784
essentialGraph
.
addArc
(
y
,
x
);
785
}
786
}
else
if
(
isForbidenArc_
(
y
,
x
)) {
787
essentialGraph
.
eraseEdge
(
Edge
(
x
,
y
));
788
GUM_TRACE
(
"Forced orientation : "
<<
x
<<
"->"
<<
y
)
789
essentialGraph
.
addArc
(
x
,
y
);
790
}
791
}
792
GUM_TRACE
(
essentialGraph
.
toDot
());
793
794
// first, propagate existing orientations
795
bool
newOrientation
=
true
;
796
while
(
newOrientation
) {
797
newOrientation
=
false
;
798
for
(
NodeId
x
:
order
) {
799
if
(!
essentialGraph
.
parents
(
x
).
empty
()) {
800
newOrientation
|=
propagatesRemainingOrientableEdges_
(
essentialGraph
,
x
);
801
}
802
}
803
}
804
GUM_TRACE
(
essentialGraph
.
toDot
());
805
propagatesOrientationInChainOfRemainingEdges_
(
essentialGraph
);
806
GUM_TRACE
(
essentialGraph
.
toDot
());
807
808
// then decide the orientation for double arcs
809
for
(
NodeId
x
:
order
)
810
for
(
NodeId
y
:
essentialGraph
.
parents
(
x
))
811
if
(
essentialGraph
.
parents
(
y
).
contains
(
x
)) {
812
GUM_TRACE
(
" + Resolving double arcs (poorly)"
)
813
essentialGraph
.
eraseArc
(
Arc
(
y
,
x
));
814
}
815
816
// std::cout << "Le mixed graph après une deuxième propagation mesdames et
817
// messieurs: "
818
//<< essentialGraph.toDot() << std::endl;
819
// std::cout << "Le graphe contient maintenant : " <<
820
// essentialGraph.sizeArcs() << " arcs."
821
//<< std::endl;
822
// std::cout << "Que voici: " << essentialGraph.arcs() << std::endl;
823
// turn the mixed graph into a dag
824
DAG
dag
;
825
for
(
auto
node
:
essentialGraph
) {
826
dag
.
addNodeWithId
(
node
);
827
}
828
for
(
const
Arc
&
arc
:
essentialGraph
.
arcs
()) {
829
dag
.
addArc
(
arc
.
tail
(),
arc
.
head
());
830
}
831
832
return
dag
;
833
}
834
835
bool
Miic
::
isOrientable_
(
const
MixedGraph
&
graph
,
NodeId
xi
,
NodeId
xj
)
const
{
836
// no cycle
837
if
(
_existsDirectedPath_
(
graph
,
xj
,
xi
)) {
838
GUM_TRACE
(
"cycle("
<<
xi
<<
"-"
<<
xj
<<
")"
)
839
return
false
;
840
}
841
842
// R1
843
if
(!(
graph
.
parents
(
xi
) -
graph
.
adjacents
(
xj
)).
empty
()) {
844
GUM_TRACE
(
"R1("
<<
xi
<<
"-"
<<
xj
<<
")"
)
845
return
true
;
846
}
847
848
// R2
849
if
(
_existsDirectedPath_
(
graph
,
xi
,
xj
)) {
850
GUM_TRACE
(
"R2("
<<
xi
<<
"-"
<<
xj
<<
")"
)
851
return
true
;
852
}
853
854
// R3
855
int
nbr
= 0;
856
for
(
const
auto
p
:
graph
.
parents
(
xj
)) {
857
if
(!
graph
.
mixedOrientedPath
(
xi
,
p
).
empty
()) {
858
nbr
+= 1;
859
if
(
nbr
== 2) {
860
GUM_TRACE
(
"R3("
<<
xi
<<
"-"
<<
xj
<<
")"
)
861
return
true
;
862
}
863
}
864
}
865
return
false
;
866
}
867
868
/// Orients all the edges remaining in the graph.
///
/// While an edge remains, the set of nodes reachable through edges from one
/// of its ends is explored and the node having the largest number of
/// children is selected (the starting node wins ties). A second exploration
/// starts from the selected node: each time a node is visited, every edge
/// n - node still attached to it is replaced by the arc n -> node.
void Miic::propagatesOrientationInChainOfRemainingEdges_(MixedGraph& essentialGraph) {
  // then decide the orientation for remaining edges
  while (!essentialGraph.edges().empty()) {
    const auto& firstEdge = *(essentialGraph.edges().begin());

    NodeId  root         = firstEdge.first();
    Size    mostChildren = essentialGraph.children(root).size();
    NodeSet visited;
    NodeSet stack{root};

    // check the best root for the set of neighbours
    while (!stack.empty()) {
      const NodeId current = *(stack.begin());
      stack.erase(current);
      if (visited.contains(current)) { continue; }

      const Size nbChildren = essentialGraph.children(current).size();
      if (nbChildren > mostChildren) {
        mostChildren = nbChildren;
        root         = current;
      }

      for (const auto neighbour: essentialGraph.neighbours(current)) {
        const bool alreadyKnown = stack.contains(neighbour) || visited.contains(neighbour);
        if (!alreadyKnown) { stack.insert(neighbour); }
      }
      visited.insert(current);
    }

    // orientation now
    visited.clear();
    stack.clear();
    stack.insert(root);
    while (!stack.empty()) {
      const NodeId current = *(stack.begin());
      stack.erase(current);
      if (visited.contains(current)) { continue; }

      // snapshot: the edges of current are erased while iterating
      const auto neighbourhood = essentialGraph.neighbours(current);
      for (const auto neighbour: neighbourhood) {
        const bool alreadyKnown = stack.contains(neighbour) || visited.contains(neighbour);
        if (!alreadyKnown) { stack.insert(neighbour); }

        GUM_TRACE(" + amap reasonably orientation for " << neighbour << "->" << current);
        essentialGraph.eraseEdge(Edge(neighbour, current));
        essentialGraph.addArc(neighbour, current);
      }
      visited.insert(current);
    }
  }
}
908
909
/// Propagates the orientation from a node to its neighbours
910
bool
Miic
::
propagatesRemainingOrientableEdges_
(
MixedGraph
&
graph
,
NodeId
xj
) {
911
bool
res
=
false
;
912
const
auto
neighbours
=
graph
.
neighbours
(
xj
);
913
for
(
auto
&
xi
:
neighbours
) {
914
bool
i_j
=
isOrientable_
(
graph
,
xi
,
xj
);
915
bool
j_i
=
isOrientable_
(
graph
,
xj
,
xi
);
916
if
(
i_j
||
j_i
) {
917
GUM_TRACE
(
" + Removing edge ("
<<
xi
<<
","
<<
xj
<<
")"
)
918
graph
.
eraseEdge
(
Edge
(
xi
,
xj
));
919
res
=
true
;
920
}
921
if
(
i_j
) {
922
GUM_TRACE
(
" + add arc ("
<<
xi
<<
","
<<
xj
<<
")"
)
923
graph
.
addArc
(
xi
,
xj
);
924
propagatesRemainingOrientableEdges_
(
graph
,
xj
);
925
}
926
if
(
j_i
) {
927
GUM_TRACE
(
" + add arc ("
<<
xi
<<
","
<<
xj
<<
")"
)
928
graph
.
addArc
(
xj
,
xi
);
929
propagatesRemainingOrientableEdges_
(
graph
,
xi
);
930
}
931
if
(
i_j
&&
j_i
) {
932
GUM_TRACE
(
" + add arc ("
<<
xi
<<
","
<<
xj
<<
")"
)
933
_latentCouples_
.
emplace_back
(
xi
,
xj
);
934
}
935
}
936
937
return
res
;
938
}
939
940
/// get the list of arcs hiding latent variables
941
const
std
::
vector
<
Arc
>
Miic
::
latentVariables
()
const
{
return
_latentCouples_
; }
942
943
/// learns the structure and the parameters of a BN
944
template
<
typename
GUM_SCALAR
,
typename
GRAPH_CHANGES_SELECTOR
,
typename
PARAM_ESTIMATOR
>
945
BayesNet
<
GUM_SCALAR
>
Miic
::
learnBN
(
GRAPH_CHANGES_SELECTOR
&
selector
,
946
PARAM_ESTIMATOR
&
estimator
,
947
DAG
initial_dag
) {
948
return
DAG2BNLearner
<>::
createBN
<
GUM_SCALAR
>(
estimator
,
949
learnStructure
(
selector
,
initial_dag
));
950
}
951
952
void
Miic
::
setMiicBehaviour
() {
this
->
_useMiic_
=
true
; }
953
954
void
Miic
::
set3of2Behaviour
() {
this
->
_useMiic_
=
false
; }
955
956
/// Replaces the initial marks with the given constraints.
///
/// @param constraints table associating a mark (a char) with ordered pairs of
/// nodes; it is copied into _initialMarks_, whose previous content is lost
void Miic::addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints) {
  _initialMarks_ = constraints;
}
959
960
bool
Miic
::
_existsNonTrivialDirectedPath_
(
const
MixedGraph
&
graph
,
961
const
NodeId
n1
,
962
const
NodeId
n2
) {
963
for
(
const
auto
parent
:
graph
.
parents
(
n2
)) {
964
if
(
graph
.
existsArc
(
parent
,
965
n2
))
// if there is a double arc, pass
966
continue
;
967
if
(
parent
==
n1
)
// trivial directed path => not recognized
968
continue
;
969
if
(
_existsDirectedPath_
(
graph
,
n1
,
parent
))
return
true
;
970
}
971
return
false
;
972
}
973
974
bool
Miic
::
_existsDirectedPath_
(
const
MixedGraph
&
graph
,
const
NodeId
n1
,
const
NodeId
n2
) {
975
// not recursive version => use a FIFO for simulating the recursion
976
List
<
NodeId
>
nodeFIFO
;
977
// mark[node] = successor if visited, else mark[node] does not exist
978
Set
<
NodeId
>
mark
;
979
980
mark
.
insert
(
n2
);
981
nodeFIFO
.
pushBack
(
n2
);
982
983
NodeId
current
;
984
985
while
(!
nodeFIFO
.
empty
()) {
986
current
=
nodeFIFO
.
front
();
987
nodeFIFO
.
popFront
();
988
989
// check the parents
990
for
(
const
auto
new_one
:
graph
.
parents
(
current
)) {
991
if
(
graph
.
existsArc
(
current
,
992
new_one
))
// if there is a double arc, pass
993
continue
;
994
995
if
(
new_one
==
n1
) {
return
true
; }
996
997
if
(
mark
.
exists
(
new_one
))
// if this node is already marked, do not
998
continue
;
// check it again
999
1000
mark
.
insert
(
new_one
);
1001
nodeFIFO
.
pushBack
(
new_one
);
1002
}
1003
}
1004
1005
return
false
;
1006
}
1007
1008
void
Miic
::
_orientingVstructureMiic_
(
MixedGraph
&
graph
,
1009
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>&
marks
,
1010
NodeId
x
,
1011
NodeId
y
,
1012
NodeId
z
,
1013
double
p1
,
1014
double
p2
) {
1015
// v-structure discovery
1016
if
(
marks
[{
x
,
z
}] ==
'o'
&&
marks
[{
y
,
z
}] ==
'o'
) {
// If x-z-y
1017
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
z
,
x
)) {
1018
graph
.
eraseEdge
(
Edge
(
x
,
z
));
1019
graph
.
addArc
(
x
,
z
);
1020
GUM_TRACE
(
"1.a Removing edge ("
<<
x
<<
","
<<
z
<<
")"
)
1021
GUM_TRACE
(
"1.a Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1022
marks
[{
x
,
z
}] =
'>'
;
1023
if
(
graph
.
existsArc
(
z
,
x
) &&
_isNotLatentCouple_
(
z
,
x
)) {
1024
GUM_TRACE
(
"Adding latent couple ("
<<
z
<<
","
<<
x
<<
")"
)
1025
_latentCouples_
.
emplace_back
(
z
,
x
);
1026
}
1027
if
(!
_arcProbas_
.
exists
(
Arc
(
x
,
z
)))
_arcProbas_
.
insert
(
Arc
(
x
,
z
),
p1
);
1028
}
else
{
1029
graph
.
eraseEdge
(
Edge
(
x
,
z
));
1030
GUM_TRACE
(
"1.b Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1031
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
x
,
z
)) {
1032
graph
.
addArc
(
z
,
x
);
1033
GUM_TRACE
(
"1.b Removing edge ("
<<
x
<<
","
<<
z
<<
")"
)
1034
marks
[{
z
,
x
}] =
'>'
;
1035
}
1036
}
1037
1038
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
z
,
y
)) {
1039
graph
.
eraseEdge
(
Edge
(
y
,
z
));
1040
graph
.
addArc
(
y
,
z
);
1041
GUM_TRACE
(
"1.c Removing edge ("
<<
y
<<
","
<<
z
<<
")"
)
1042
GUM_TRACE
(
"1.c Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1043
marks
[{
y
,
z
}] =
'>'
;
1044
if
(
graph
.
existsArc
(
z
,
y
) &&
_isNotLatentCouple_
(
z
,
y
)) {
1045
_latentCouples_
.
emplace_back
(
z
,
y
);
1046
}
1047
if
(!
_arcProbas_
.
exists
(
Arc
(
y
,
z
)))
_arcProbas_
.
insert
(
Arc
(
y
,
z
),
p2
);
1048
}
else
{
1049
graph
.
eraseEdge
(
Edge
(
y
,
z
));
1050
GUM_TRACE
(
"1.d Removing edge ("
<<
y
<<
","
<<
z
<<
")"
)
1051
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
y
,
z
)) {
1052
graph
.
addArc
(
z
,
y
);
1053
GUM_TRACE
(
"1.d Adding arc ("
<<
z
<<
","
<<
y
<<
")"
)
1054
marks
[{
z
,
y
}] =
'>'
;
1055
}
1056
}
1057
}
else
if
(
marks
[{
x
,
z
}] ==
'>'
&&
marks
[{
y
,
z
}] ==
'o'
) {
// If x->z-y
1058
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
z
,
y
)) {
1059
graph
.
eraseEdge
(
Edge
(
y
,
z
));
1060
graph
.
addArc
(
y
,
z
);
1061
GUM_TRACE
(
"2.a Removing edge ("
<<
y
<<
","
<<
z
<<
")"
)
1062
GUM_TRACE
(
"2.a Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1063
marks
[{
y
,
z
}] =
'>'
;
1064
if
(
graph
.
existsArc
(
z
,
y
) &&
_isNotLatentCouple_
(
z
,
y
)) {
1065
_latentCouples_
.
emplace_back
(
z
,
y
);
1066
}
1067
if
(!
_arcProbas_
.
exists
(
Arc
(
y
,
z
)))
_arcProbas_
.
insert
(
Arc
(
y
,
z
),
p2
);
1068
}
else
{
1069
graph
.
eraseEdge
(
Edge
(
y
,
z
));
1070
GUM_TRACE
(
"2.b Removing edge ("
<<
y
<<
","
<<
z
<<
")"
)
1071
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
y
,
z
)) {
1072
graph
.
addArc
(
z
,
y
);
1073
GUM_TRACE
(
"2.b Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1074
marks
[{
z
,
y
}] =
'>'
;
1075
}
1076
}
1077
}
else
if
(
marks
[{
y
,
z
}] ==
'>'
&&
marks
[{
x
,
z
}] ==
'o'
) {
1078
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
z
,
x
)) {
1079
graph
.
eraseEdge
(
Edge
(
x
,
z
));
1080
graph
.
addArc
(
x
,
z
);
1081
GUM_TRACE
(
"3.a Removing edge ("
<<
x
<<
","
<<
z
<<
")"
)
1082
GUM_TRACE
(
"3.a Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1083
marks
[{
x
,
z
}] =
'>'
;
1084
if
(
graph
.
existsArc
(
z
,
x
) &&
_isNotLatentCouple_
(
z
,
x
)) {
1085
_latentCouples_
.
emplace_back
(
z
,
x
);
1086
}
1087
if
(!
_arcProbas_
.
exists
(
Arc
(
x
,
z
)))
_arcProbas_
.
insert
(
Arc
(
x
,
z
),
p1
);
1088
}
else
{
1089
graph
.
eraseEdge
(
Edge
(
x
,
z
));
1090
GUM_TRACE
(
"3.b Removing edge ("
<<
x
<<
","
<<
z
<<
")"
)
1091
if
(!
_existsNonTrivialDirectedPath_
(
graph
,
x
,
z
)) {
1092
graph
.
addArc
(
z
,
x
);
1093
GUM_TRACE
(
"3.b Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1094
marks
[{
z
,
x
}] =
'>'
;
1095
}
1096
}
1097
}
1098
}
1099
1100
1101
void
Miic
::
_propagatingOrientationMiic_
(
MixedGraph
&
graph
,
1102
HashTable
<
std
::
pair
<
NodeId
,
NodeId
>,
char
>&
marks
,
1103
NodeId
x
,
1104
NodeId
y
,
1105
NodeId
z
,
1106
double
p1
,
1107
double
p2
) {
1108
// orientation propagation
1109
if
(
marks
[{
x
,
z
}] ==
'>'
&&
marks
[{
y
,
z
}] ==
'o'
&&
marks
[{
z
,
y
}] !=
'-'
) {
1110
graph
.
eraseEdge
(
Edge
(
z
,
y
));
1111
// std::cout << "4. Removing edge (" << z << "," << y << ")" <<
1112
// std::endl;
1113
if
(!
_existsDirectedPath_
(
graph
,
y
,
z
) &&
graph
.
parents
(
y
).
empty
()) {
1114
graph
.
addArc
(
z
,
y
);
1115
GUM_TRACE
(
"4.a Adding arc ("
<<
z
<<
","
<<
y
<<
")"
)
1116
marks
[{
z
,
y
}] =
'>'
;
1117
marks
[{
y
,
z
}] =
'-'
;
1118
if
(!
_arcProbas_
.
exists
(
Arc
(
z
,
y
)))
_arcProbas_
.
insert
(
Arc
(
z
,
y
),
p2
);
1119
}
else
if
(!
_existsDirectedPath_
(
graph
,
z
,
y
) &&
graph
.
parents
(
z
).
empty
()) {
1120
graph
.
addArc
(
y
,
z
);
1121
GUM_TRACE
(
"4.b Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1122
marks
[{
z
,
y
}] =
'-'
;
1123
marks
[{
y
,
z
}] =
'>'
;
1124
_latentCouples_
.
emplace_back
(
y
,
z
);
1125
if
(!
_arcProbas_
.
exists
(
Arc
(
y
,
z
)))
_arcProbas_
.
insert
(
Arc
(
y
,
z
),
p2
);
1126
}
else
if
(!
_existsDirectedPath_
(
graph
,
y
,
z
)) {
1127
graph
.
addArc
(
z
,
y
);
1128
GUM_TRACE
(
"4.c Adding arc ("
<<
z
<<
","
<<
y
<<
")"
)
1129
marks
[{
z
,
y
}] =
'>'
;
1130
marks
[{
y
,
z
}] =
'-'
;
1131
if
(!
_arcProbas_
.
exists
(
Arc
(
z
,
y
)))
_arcProbas_
.
insert
(
Arc
(
z
,
y
),
p2
);
1132
}
else
if
(!
_existsDirectedPath_
(
graph
,
z
,
y
)) {
1133
graph
.
addArc
(
y
,
z
);
1134
GUM_TRACE
(
"4.d Adding arc ("
<<
y
<<
","
<<
z
<<
")"
)
1135
_latentCouples_
.
emplace_back
(
y
,
z
);
1136
marks
[{
z
,
y
}] =
'-'
;
1137
marks
[{
y
,
z
}] =
'>'
;
1138
if
(!
_arcProbas_
.
exists
(
Arc
(
y
,
z
)))
_arcProbas_
.
insert
(
Arc
(
y
,
z
),
p2
);
1139
}
1140
}
else
if
(
marks
[{
y
,
z
}] ==
'>'
&&
marks
[{
x
,
z
}] ==
'o'
&&
marks
[{
z
,
x
}] !=
'-'
) {
1141
graph
.
eraseEdge
(
Edge
(
z
,
x
));
1142
GUM_TRACE
(
"5. Removing edge ("
<<
z
<<
","
<<
x
<<
")"
)
1143
if
(!
_existsDirectedPath_
(
graph
,
x
,
z
) &&
graph
.
parents
(
x
).
empty
()) {
1144
graph
.
addArc
(
z
,
x
);
1145
GUM_TRACE
(
"5.a Adding arc ("
<<
z
<<
","
<<
x
<<
")"
)
1146
marks
[{
z
,
x
}] =
'>'
;
1147
marks
[{
x
,
z
}] =
'-'
;
1148
if
(!
_arcProbas_
.
exists
(
Arc
(
z
,
x
)))
_arcProbas_
.
insert
(
Arc
(
z
,
x
),
p1
);
1149
}
else
if
(!
_existsDirectedPath_
(
graph
,
z
,
x
) &&
graph
.
parents
(
z
).
empty
()) {
1150
graph
.
addArc
(
x
,
z
);
1151
GUM_TRACE
(
"5.b Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1152
marks
[{
z
,
x
}] =
'-'
;
1153
marks
[{
x
,
z
}] =
'>'
;
1154
_latentCouples_
.
emplace_back
(
x
,
z
);
1155
if
(!
_arcProbas_
.
exists
(
Arc
(
x
,
z
)))
_arcProbas_
.
insert
(
Arc
(
x
,
z
),
p1
);
1156
}
else
if
(!
_existsDirectedPath_
(
graph
,
x
,
z
)) {
1157
graph
.
addArc
(
z
,
x
);
1158
GUM_TRACE
(
"5.c Adding arc ("
<<
z
<<
","
<<
x
<<
")"
)
1159
marks
[{
z
,
x
}] =
'>'
;
1160
marks
[{
x
,
z
}] =
'-'
;
1161
if
(!
_arcProbas_
.
exists
(
Arc
(
z
,
x
)))
_arcProbas_
.
insert
(
Arc
(
z
,
x
),
p1
);
1162
}
else
if
(!
_existsDirectedPath_
(
graph
,
z
,
x
)) {
1163
graph
.
addArc
(
x
,
z
);
1164
GUM_TRACE
(
"5.d Adding arc ("
<<
x
<<
","
<<
z
<<
")"
)
1165
marks
[{
z
,
x
}] =
'-'
;
1166
marks
[{
x
,
z
}] =
'>'
;
1167
_latentCouples_
.
emplace_back
(
x
,
z
);
1168
if
(!
_arcProbas_
.
exists
(
Arc
(
x
,
z
)))
_arcProbas_
.
insert
(
Arc
(
x
,
z
),
p1
);
1169
}
1170
}
1171
}
1172
1173
bool
Miic
::
_isNotLatentCouple_
(
const
NodeId
x
,
const
NodeId
y
) {
1174
const
auto
&
lbeg
=
_latentCouples_
.
begin
();
1175
const
auto
&
lend
=
_latentCouples_
.
end
();
1176
1177
return
(
std
::
find
(
lbeg
,
lend
,
Arc
(
x
,
y
)) ==
lend
)
1178
&& (
std
::
find
(
lbeg
,
lend
,
Arc
(
y
,
x
)) ==
lend
);
1179
}
1180
1181
/// Tells whether the arc x -> y is forbidden by the initial marks.
///
/// @param x tail of the arc
/// @param y head of the arc
/// @return true if _initialMarks_ holds the pair (x, y) with the mark '-'
bool Miic::isForbidenArc_(NodeId x, NodeId y) const {
  const std::pair< NodeId, NodeId > key(x, y);

  // no mark recorded for this pair: nothing forbids the arc
  if (!_initialMarks_.exists(key)) return false;

  return _initialMarks_[key] == '-';
}
1184
}
/* namespace learning */
1185
1186
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31