aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
CNLoopyPropagation_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
#
include
<
agrum
/
CN
/
inference
/
CNLoopyPropagation
.
h
>
23
24
namespace
gum
{
25
namespace
credal
{
26
27
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::saveInference(const std::string& path) {
  // Write the inference results to "<path without extension>.res".
  // If `path` names an evidence file (".evi"), its content is copied at the
  // top of the result file before the "[RESULTATS]" section.
  // @param path input file path; its last 4 characters are assumed to be an
  //             extension like ".evi" (no validation is performed here).
  // @throws NotFound if the output (or evidence) file cannot be opened.
  std::string path_name = path.substr(0, path.size() - 4);
  path_name             = path_name + ".res";

  std::ofstream res(path_name.c_str(), std::ios::out | std::ios::trunc);

  if (!res.good()) {
    GUM_ERROR(NotFound,
              "CNLoopyPropagation<GUM_SCALAR>::saveInference(std::"
              "string & path) : could not open file : "
                 + path_name);
  }

  // NOTE(review): the second substr argument is a count, not an end index;
  // passing path.size() is harmless (clamped) but misleading.
  std::string ext = path.substr(path.size() - 3, path.size());

  // If we were given an evidence file, echo it verbatim into the result file.
  if (std::strcmp(ext.c_str(), "evi") == 0) {
    std::ifstream evi(path.c_str(), std::ios::in);
    std::string   ligne;

    if (!evi.good()) {
      GUM_ERROR(NotFound,
                "CNLoopyPropagation<GUM_SCALAR>::saveInference(std::"
                "string & path) : could not open file : "
                   + ext);
    }

    while (evi.good()) {
      getline(evi, ligne);
      res << ligne << "\n";
    }

    evi.close();
  }

  res << "[RESULTATS]"
      << "\n";

  for (auto node: _bnet_->nodes()) {
    // compute the posterior distribution [msg_p_min, msg_p_max] for this node
    GUM_SCALAR msg_p_min = 1.0;
    GUM_SCALAR msg_p_max = 0.0;

    // evidence case: the posterior is the (hard) observed value itself
    if (_infE_::evidence_.exists(node)) {
      if (_infE_::evidence_[node][1] == 0.) {
        msg_p_min = 0.;
      } else if (_infE_::evidence_[node][1] == 1.) {
        msg_p_min = 1.;
      }

      msg_p_max = msg_p_min;
    }
    // otherwise combine the node's pi (NodesP) and lambda (NodesL) intervals
    else {
      GUM_SCALAR min = NodesP_min_[node];
      GUM_SCALAR max;

      // NodesP_max_ only stores an entry when the interval is non-degenerate
      if (NodesP_max_.exists(node)) {
        max = NodesP_max_[node];
      } else {
        max = min;
      }

      GUM_SCALAR lmin = NodesL_min_[node];
      GUM_SCALAR lmax;

      if (NodesL_max_.exists(node)) {
        lmax = NodesL_max_[node];
      } else {
        lmax = lmin;
      }

      // limit case on min: pi = inf with l = 0 is inconsistent
      if (min == INF_ && lmin == 0.) {
        std::cout << "proba ERR (negatif) : pi = inf, l = 0" << std::endl;
      }

      if (lmin == INF_) {   // infinite-likelihood case
        msg_p_min = GUM_SCALAR(1.);
      } else if (min == 0. || lmin == 0.) {
        msg_p_min = GUM_SCALAR(0.);
      } else {
        // standard Pearl combination: 1 / (1 + (1/pi - 1) / lambda)
        msg_p_min = GUM_SCALAR(1. / (1. + ((1. / min - 1.) * 1. / lmin)));
      }

      // limit case on max
      if (max == INF_ && lmax == 0.) {
        std::cout << "proba ERR (negatif) : pi = inf, l = 0" << std::endl;
      }

      if (lmax == INF_) {   // infinite-likelihood case
        msg_p_max = GUM_SCALAR(1.);
      } else if (max == 0. || lmax == 0.) {
        msg_p_max = GUM_SCALAR(0.);
      } else {
        msg_p_max = GUM_SCALAR(1. / (1. + ((1. / max - 1.) * 1. / lmax)));
      }
    }

    // NaN repair (x != x is the self-comparison NaN test): if only one bound
    // is NaN, borrow the other one
    if (msg_p_min != msg_p_min && msg_p_max == msg_p_max) { msg_p_min = msg_p_max; }

    if (msg_p_max != msg_p_max && msg_p_min == msg_p_min) { msg_p_max = msg_p_min; }

    if (msg_p_max != msg_p_max && msg_p_min != msg_p_min) {
      std::cout << std::endl;
      std::cout << "pas de proba calculable (verifier observations)" << std::endl;
    }

    res << "P(" << _bnet_->variable(node).name() << " | e) = ";

    if (_infE_::evidence_.exists(node)) {
      res << "(observe)" << std::endl;
    } else {
      res << std::endl;
    }

    // label(0) gets the complement interval, printed upper bound first
    res << "\t\t" << _bnet_->variable(node).label(0) << " [ " << (GUM_SCALAR)1. - msg_p_max;

    if (msg_p_min != msg_p_max) {
      res << ", " << (GUM_SCALAR)1. - msg_p_min << " ] | ";
    } else {
      res << " ] | ";
    }

    res << _bnet_->variable(node).label(1) << " [ " << msg_p_min;

    if (msg_p_min != msg_p_max) {
      res << ", " << msg_p_max << " ]" << std::endl;
    } else {
      res << " ]" << std::endl;
    }
  }   // end of : for each node

  res.close();
}
163
164
/**
165
* pour les fonctions suivantes, les GUM_SCALAR min/max doivent etre
166
* initialises
167
* (min a 1 et max a 0) pour comparer avec les resultats intermediaires
168
*/
169
170
/**
171
* une fois les cpts marginalises sur X et Ui, on calcul le min/max,
172
*/
173
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::compute_ext_(GUM_SCALAR&                msg_l_min,
                                                    GUM_SCALAR&                msg_l_max,
                                                    std::vector< GUM_SCALAR >& lx,
                                                    GUM_SCALAR&                num_min,
                                                    GUM_SCALAR&                num_max,
                                                    GUM_SCALAR&                den_min,
                                                    GUM_SCALAR&                den_max) {
  // Given the marginalized numerator/denominator bounds and the candidate
  // likelihood values in `lx`, compute the extreme values of the lambda
  // message and widen [msg_l_min, msg_l_max] accordingly.
  // `msg_l_min == msg_l_max == -2` is the "not yet set" sentinel: the first
  // computed result initializes the interval.
  GUM_SCALAR num_min_tmp = 1.;
  GUM_SCALAR den_min_tmp = 1.;
  GUM_SCALAR num_max_tmp = 1.;
  GUM_SCALAR den_max_tmp = 1.;

  GUM_SCALAR res_min = 1.0, res_max = 0.0;

  auto lsize = lx.size();

  for (decltype(lsize) i = 0; i < lsize; i++) {
    bool non_defini_min = false;
    bool non_defini_max = false;

    if (lx[i] == INF_) {
      // infinite likelihood: ratio degenerates to num/den; note the min
      // bound pairs num_min with den_max (and vice versa) to get extremes
      num_min_tmp = num_min;
      den_min_tmp = den_max;
      num_max_tmp = num_max;
      den_max_tmp = den_min;
    } else if (lx[i] == (GUM_SCALAR)1.) {
      // neutral likelihood: message is exactly 1
      num_min_tmp = GUM_SCALAR(1.);
      den_min_tmp = GUM_SCALAR(1.);
      num_max_tmp = GUM_SCALAR(1.);
      den_max_tmp = GUM_SCALAR(1.);
    } else if (lx[i] > (GUM_SCALAR)1.) {
      // li = 1 / (l - 1) > 0 : extremes obtained with (num_min, den_max)
      GUM_SCALAR li = GUM_SCALAR(1.) / (lx[i] - GUM_SCALAR(1.));
      num_min_tmp   = num_min + li;
      den_min_tmp   = den_max + li;
      num_max_tmp   = num_max + li;
      den_max_tmp   = den_min + li;
    } else if (lx[i] < (GUM_SCALAR)1.) {
      // li < 0 : the pairing flips compared to the previous case
      GUM_SCALAR li = GUM_SCALAR(1.) / (lx[i] - GUM_SCALAR(1.));
      num_min_tmp   = num_max + li;
      den_min_tmp   = den_min + li;
      num_max_tmp   = num_min + li;
      den_max_tmp   = den_max + li;
    }

    // 0/0 is undefined; x/0 (x != 0) is +inf; otherwise plain division
    if (den_min_tmp == 0. && num_min_tmp == 0.) {
      non_defini_min = true;
    } else if (den_min_tmp == 0. && num_min_tmp != 0.) {
      res_min = INF_;
    } else if (den_min_tmp != INF_ || num_min_tmp != INF_) {
      res_min = num_min_tmp / den_min_tmp;
    }

    if (den_max_tmp == 0. && num_max_tmp == 0.) {
      non_defini_max = true;
    } else if (den_max_tmp == 0. && num_max_tmp != 0.) {
      res_max = INF_;
    } else if (den_max_tmp != INF_ || num_max_tmp != INF_) {
      res_max = num_max_tmp / den_max_tmp;
    }

    // repair a half-defined result by borrowing the other bound
    if (non_defini_max && non_defini_min) {
      std::cout << "undefined msg" << std::endl;
      continue;
    } else if (non_defini_min && !non_defini_max) {
      res_min = res_max;
    } else if (non_defini_max && !non_defini_min) {
      res_max = res_min;
    }

    // clamp tiny negative rounding artifacts to zero
    if (res_min < 0.) { res_min = 0.; }

    if (res_max < 0.) { res_max = 0.; }

    // first result: initialize the output interval (sentinel is -2)
    if (msg_l_min == msg_l_max && msg_l_min == -2.) {
      msg_l_min = res_min;
      msg_l_max = res_max;
    }

    // widen the interval to contain the new result
    if (res_max > msg_l_max) { msg_l_max = res_max; }

    if (res_min < msg_l_min) { msg_l_min = res_min; }
  }   // end of : for each lx
}
258
259
/**
260
* extremes pour une combinaison des parents, message vers parent
261
*/
262
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::compute_ext_(
   std::vector< std::vector< GUM_SCALAR > >& combi_msg_p,
   const NodeId&                             id,
   GUM_SCALAR&                               msg_l_min,
   GUM_SCALAR&                               msg_l_max,
   std::vector< GUM_SCALAR >&                lx,
   const Idx&                                pos) {
  // Extremes for one combination of parent messages, for the lambda message
  // sent toward a parent: marginalize the binary CPT bounds of node `id`
  // over all parent configurations, splitting the CPT entries between a
  // "denominator" part (target parent = 0) and a "numerator" part (target
  // parent = 1, offset by `pos`), then delegate the interval update to the
  // scalar compute_ext_ overload.
  GUM_SCALAR num_min = 0.;
  GUM_SCALAR num_max = 0.;
  GUM_SCALAR den_min = 0.;
  GUM_SCALAR den_max = 0.;

  auto taille = combi_msg_p.size();

  // one iterator per parent distribution; advanced odometer-style below
  std::vector< typename std::vector< GUM_SCALAR >::iterator > it(taille);

  for (decltype(taille) i = 0; i < taille; i++) {
    it[i] = combi_msg_p[i].begin();
  }

  Size pp = pos;

  Size combi_den = 0;
  Size combi_num = pp;

  // marginalization over all configurations of the other parents
  while (it[taille - 1] != combi_msg_p[taille - 1].end()) {
    GUM_SCALAR prod = 1.;

    for (decltype(taille) k = 0; k < taille; k++) {
      prod *= *it[k];
    }

    den_min += (_cn_->get_binaryCPT_min()[id][combi_den] * prod);
    den_max += (_cn_->get_binaryCPT_max()[id][combi_den] * prod);

    num_min += (_cn_->get_binaryCPT_min()[id][combi_num] * prod);
    num_max += (_cn_->get_binaryCPT_max()[id][combi_num] * prod);

    combi_den++;
    combi_num++;

    // skip the CPT slice where the target parent takes the other value
    if (pp != 0) {
      if (combi_den % pp == 0) {
        combi_den += pp;
        combi_num += pp;
      }
    }

    // odometer increment: bump the first iterator, carry into the next
    // distribution each time one wraps around
    ++it[0];

    for (decltype(taille) i = 0; (i < taille - 1) && (it[i] == combi_msg_p[i].end()); ++i) {
      it[i] = combi_msg_p[i].begin();
      ++it[i + 1];
    }
  }   // end of : marginalization

  compute_ext_(msg_l_min, msg_l_max, lx, num_min, num_max, den_min, den_max);
}
323
324
/**
325
* extremes pour une combinaison des parents, message vers enfant
326
* marginalisation cpts
327
*/
328
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::compute_ext_(
   std::vector< std::vector< GUM_SCALAR > >& combi_msg_p,
   const NodeId&                             id,
   GUM_SCALAR&                               msg_p_min,
   GUM_SCALAR&                               msg_p_max) {
  // Extremes for one combination of parent messages, for the pi message sent
  // toward a child: marginalize the binary CPT bounds of node `id` over the
  // given parent distributions, then narrow/widen [msg_p_min, msg_p_max]
  // (callers initialize msg_p_min high and msg_p_max low, so the first
  // result always updates both bounds).
  GUM_SCALAR min = 0.;
  GUM_SCALAR max = 0.;

  auto taille = combi_msg_p.size();

  // one iterator per parent distribution; advanced odometer-style below
  std::vector< typename std::vector< GUM_SCALAR >::iterator > it(taille);

  for (decltype(taille) i = 0; i < taille; i++) {
    it[i] = combi_msg_p[i].begin();
  }

  int  combi  = 0;
  auto theEnd = combi_msg_p[taille - 1].end();

  while (it[taille - 1] != theEnd) {
    GUM_SCALAR prod = 1.;

    for (decltype(taille) k = 0; k < taille; k++) {
      prod *= *it[k];
    }

    min += (_cn_->get_binaryCPT_min()[id][combi] * prod);
    max += (_cn_->get_binaryCPT_max()[id][combi] * prod);

    combi++;

    // odometer increment: bump the first iterator, carry into the next
    // distribution each time one wraps around
    ++it[0];

    for (decltype(taille) i = 0; (i < taille - 1) && (it[i] == combi_msg_p[i].end()); ++i) {
      it[i] = combi_msg_p[i].begin();
      ++it[i + 1];
    }
  }

  if (min < msg_p_min) { msg_p_min = min; }

  if (max > msg_p_max) { msg_p_max = max; }
}
373
374
/**
375
* enumerate combinations messages parents, pour message vers enfant
376
*/
377
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::enum_combi_(
   std::vector< std::vector< std::vector< GUM_SCALAR > > >& msgs_p,
   const NodeId&                                            id,
   GUM_SCALAR&                                              msg_p_min,
   GUM_SCALAR&                                              msg_p_max) {
  // Enumerate all combinations of parent messages (each parent contributes
  // one or two candidate distributions) and compute the extreme pi message
  // of node `id` toward a child. The enumeration is parallelized with
  // OpenMP; each thread accumulates local extremes and merges them into the
  // shared output under a critical section.
  auto taille = msgs_p.size();

  // source node (no parents): the message is directly the CPT bounds
  if (taille == 0) {
    msg_p_min = _cn_->get_binaryCPT_min()[id][0];
    msg_p_max = _cn_->get_binaryCPT_max()[id][0];
    return;
  }

  decltype(taille) msgPerm = 1;
#pragma omp parallel
  {
    // thread-local copies of the running extremes
    GUM_SCALAR msg_pmin = msg_p_min;
    GUM_SCALAR msg_pmax = msg_p_max;

    std::vector< std::vector< GUM_SCALAR > > combi_msg_p(taille);

    decltype(taille) confs = 1;

#pragma omp for

    // each thread multiplies the sizes of its share of parents; the partial
    // products are then combined atomically into the total permutation count
    for (long i = 0; i < long(taille); i++) {
      confs *= msgs_p[i].size();
    }

#pragma omp atomic
    msgPerm *= confs;
#pragma omp barrier
#pragma omp flush   // ( msgPerm ) let the compiler choose what to flush (due to mvsc)

#pragma omp for

    for (int j = 0; j < int(msgPerm); j++) {
      // decode the jth combination: each parent with two candidate
      // distributions consumes one bit of jvalue
      auto jvalue = j;

      for (decltype(taille) i = 0; i < taille; i++) {
        if (msgs_p[i].size() == 2) {
          combi_msg_p[i] = (jvalue & 1) ? msgs_p[i][1] : msgs_p[i][0];
          jvalue /= 2;
        } else {
          combi_msg_p[i] = msgs_p[i][0];
        }
      }

      compute_ext_(combi_msg_p, id, msg_pmin, msg_pmax);
    }

    // since min is INF_ and max is 0 at init, there is no issue having more
    // threads here than during the for loop
#pragma omp critical(msgpminmax)
    {
#pragma omp flush   //( msg_p_min )
      //#pragma omp flush ( msg_p_max ) let the compiler choose what to
      // flush (due to mvsc)

      if (msg_p_min > msg_pmin) { msg_p_min = msg_pmin; }

      if (msg_p_max < msg_pmax) { msg_p_max = msg_pmax; }
    }
  }
  return;
}
447
448
/**
449
* comme precedemment mais pour message parent, vraisemblance prise en
450
* compte
451
*/
452
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::enum_combi_(
   std::vector< std::vector< std::vector< GUM_SCALAR > > >& msgs_p,
   const NodeId&                                            id,
   GUM_SCALAR&                                              real_msg_l_min,
   GUM_SCALAR&                                              real_msg_l_max,
   std::vector< GUM_SCALAR >&                               lx,
   const Idx&                                               pos) {
  // Same enumeration as the pi-message overload, but for the lambda message
  // sent toward a parent: the likelihood candidates `lx` are taken into
  // account, and `pos` locates the receiving parent in the CPT.
  // Work on local copies so the shared outputs are only touched inside the
  // critical section at the end.
  GUM_SCALAR msg_l_min = real_msg_l_min;
  GUM_SCALAR msg_l_max = real_msg_l_max;

  auto taille = msgs_p.size();

  // one parent node, the one receiving the message: no enumeration needed,
  // the numerator/denominator are read directly from the CPT bounds
  if (taille == 0) {
    GUM_SCALAR num_min = _cn_->get_binaryCPT_min()[id][1];
    GUM_SCALAR num_max = _cn_->get_binaryCPT_max()[id][1];
    GUM_SCALAR den_min = _cn_->get_binaryCPT_min()[id][0];
    GUM_SCALAR den_max = _cn_->get_binaryCPT_max()[id][0];

    compute_ext_(msg_l_min, msg_l_max, lx, num_min, num_max, den_min, den_max);

    real_msg_l_min = msg_l_min;
    real_msg_l_max = msg_l_max;
    return;
  }

  decltype(taille) msgPerm = 1;
#pragma omp parallel
  {
    // thread-local running extremes (sentinel -2 means "not yet computed")
    GUM_SCALAR                               msg_lmin = msg_l_min;
    GUM_SCALAR                               msg_lmax = msg_l_max;
    std::vector< std::vector< GUM_SCALAR > > combi_msg_p(taille);

    decltype(taille) confs = 1;
#pragma omp for

    for (int i = 0; i < int(taille); i++) {
      confs *= msgs_p[i].size();
    }

#pragma omp atomic
    msgPerm *= confs;
#pragma omp barrier
#pragma omp flush(msgPerm)

    // direct binary representation of config, no need for iterators
#pragma omp for

    for (long j = 0; j < long(msgPerm); j++) {
      // decode the jth combination: each parent with two candidate
      // distributions consumes one bit of jvalue
      auto jvalue = j;

      for (decltype(taille) i = 0; i < taille; i++) {
        if (msgs_p[i].size() == 2) {
          combi_msg_p[i] = (jvalue & 1) ? msgs_p[i][1] : msgs_p[i][0];
          jvalue /= 2;
        } else {
          combi_msg_p[i] = msgs_p[i][0];
        }
      }

      compute_ext_(combi_msg_p, id, msg_lmin, msg_lmax, lx, pos);
    }

    // there may be more threads here than in the for loop, therefore the
    // positivity test is NECESSARY (init is -2)
#pragma omp critical(msglminmax)
    {
#pragma omp flush(msg_l_min)
#pragma omp flush(msg_l_max)

      if ((msg_l_min > msg_lmin || msg_l_min == -2) && msg_lmin > 0) { msg_l_min = msg_lmin; }

      if ((msg_l_max < msg_lmax || msg_l_max == -2) && msg_lmax > 0) { msg_l_max = msg_lmax; }
    }
  }

  real_msg_l_min = msg_l_min;
  real_msg_l_max = msg_l_max;
}
533
534
template
<
typename
GUM_SCALAR
>
535
void
CNLoopyPropagation
<
GUM_SCALAR
>::
makeInference
() {
536
if
(
InferenceUpToDate_
) {
return
; }
537
538
initialize_
();
539
540
_infE_
::
initApproximationScheme
();
541
542
switch
(
_inferenceType_
) {
543
case
InferenceType
::
nodeToNeighbours
:
544
makeInferenceNodeToNeighbours_
();
545
break
;
546
547
case
InferenceType
::
ordered
:
548
makeInferenceByOrderedArcs_
();
549
break
;
550
551
case
InferenceType
::
randomOrder
:
552
makeInferenceByRandomOrder_
();
553
break
;
554
}
555
556
//_updateMarginals();
557
updateIndicatrices_
();
// will call updateMarginals_()
558
559
computeExpectations_
();
560
561
InferenceUpToDate_
=
true
;
562
}
563
564
template
<
typename
GUM_SCALAR
>
565
void
CNLoopyPropagation
<
GUM_SCALAR
>::
eraseAllEvidence
() {
566
_infE_
::
eraseAllEvidence
();
567
568
ArcsL_min_
.
clear
();
569
ArcsL_max_
.
clear
();
570
ArcsP_min_
.
clear
();
571
ArcsP_max_
.
clear
();
572
NodesL_min_
.
clear
();
573
NodesL_max_
.
clear
();
574
NodesP_min_
.
clear
();
575
NodesP_max_
.
clear
();
576
577
InferenceUpToDate_
=
false
;
578
579
if
(
msg_l_sent_
.
size
() > 0) {
580
for
(
auto
node
:
_bnet_
->
nodes
()) {
581
delete
msg_l_sent_
[
node
];
582
}
583
}
584
585
msg_l_sent_
.
clear
();
586
update_l_
.
clear
();
587
update_p_
.
clear
();
588
589
active_nodes_set
.
clear
();
590
next_active_nodes_set
.
clear
();
591
}
592
593
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::initialize_() {
  // Initialize all per-node and per-arc message tables in topological order:
  // evidence nodes get their hard value; other nodes get an initial pi
  // interval computed by enumerating their parents' initial messages.
  const DAG& graphe = _bnet_->dag();

  // use const iterators with cbegin when available
  for (auto node: _bnet_->topologicalOrder()) {
    update_p_.set(node, false);
    update_l_.set(node, false);
    // owned by msg_l_sent_; released in eraseAllEvidence()
    NodeSet* parents_ = new NodeSet();
    msg_l_sent_.set(node, parents_);

    // fast path for evidence nodes
    if (_infE_::evidence_.exists(node)) {
      // only hard (0/1) evidence is supported by this algorithm
      if (_infE_::evidence_[node][1] != 0. && _infE_::evidence_[node][1] != 1.) {
        GUM_ERROR(OperationNotAllowed,
                  "CNLoopyPropagation can only handle HARD evidences")
      }

      active_nodes_set.insert(node);
      update_l_.set(node, true);
      update_p_.set(node, true);

      // evidence 1 => lambda = INF_, pi = 1 ; evidence 0 => lambda = pi = 0
      if (_infE_::evidence_[node][1] == (GUM_SCALAR)1.) {
        NodesL_min_.set(node, INF_);
        NodesP_min_.set(node, (GUM_SCALAR)1.);
      } else if (_infE_::evidence_[node][1] == (GUM_SCALAR)0.) {
        NodesL_min_.set(node, (GUM_SCALAR)0.);
        NodesP_min_.set(node, (GUM_SCALAR)0.);
      }

      std::vector< GUM_SCALAR > marg(2);
      marg[1] = NodesP_min_[node];
      marg[0] = 1 - marg[1];

      _infE_::oldMarginalMin_.set(node, marg);
      _infE_::oldMarginalMax_.set(node, marg);

      continue;
    }

    NodeSet par_ = graphe.parents(node);
    NodeSet enf_ = graphe.children(node);

    // roots and leaves start active so messages begin flowing from them
    if (par_.size() == 0) {
      active_nodes_set.insert(node);
      update_p_.set(node, true);
      update_l_.set(node, true);
    }

    if (enf_.size() == 0) {
      active_nodes_set.insert(node);
      update_p_.set(node, true);
      update_l_.set(node, true);
    }

    /**
     * messages and so parents need to be read in order of appearance
     * use potentials instead of dag
     */
    const auto parents = &_bnet_->cpt(node).variablesSequence();

    std::vector< std::vector< std::vector< GUM_SCALAR > > > msgs_p;
    std::vector< std::vector< GUM_SCALAR > >                msg_p;
    std::vector< GUM_SCALAR >                               distri(2);

    // +1 from start to avoid counting the node itself
    // use const iterators when available with cbegin
    for (auto jt = ++parents->begin(), theEnd = parents->end(); jt != theEnd; ++jt) {
      // compute each parent's probability distribution once, instead of at
      // each combination of messages
      distri[1] = NodesP_min_[_bnet_->nodeId(**jt)];
      distri[0] = (GUM_SCALAR)1. - distri[1];
      msg_p.push_back(distri);

      // a second candidate distribution only exists when the parent's pi
      // interval is non-degenerate
      if (NodesP_max_.exists(_bnet_->nodeId(**jt))) {
        distri[1] = NodesP_max_[_bnet_->nodeId(**jt)];
        distri[0] = (GUM_SCALAR)1. - distri[1];
        msg_p.push_back(distri);
      }

      msgs_p.push_back(msg_p);
      msg_p.clear();
    }

    GUM_SCALAR msg_p_min = 1.;
    GUM_SCALAR msg_p_max = 0.;

    // indicator nodes keep the (degenerate) init values
    if (_cn_->currentNodeType(node) != CredalNet< GUM_SCALAR >::NodeType::Indic) {
      enum_combi_(msgs_p, node, msg_p_min, msg_p_max);
    }

    // clamp negative rounding artifacts
    if (msg_p_min <= (GUM_SCALAR)0.) { msg_p_min = (GUM_SCALAR)0.; }

    if (msg_p_max <= (GUM_SCALAR)0.) { msg_p_max = (GUM_SCALAR)0.; }

    NodesP_min_.set(node, msg_p_min);
    std::vector< GUM_SCALAR > marg(2);
    marg[1] = msg_p_min;
    marg[0] = 1 - msg_p_min;

    _infE_::oldMarginalMin_.set(node, marg);

    // store the upper bound only for non-degenerate intervals; `marg` is
    // reused so oldMarginalMax_ below gets the max (or the min if equal)
    if (msg_p_min != msg_p_max) {
      marg[1] = msg_p_max;
      marg[0] = 1 - msg_p_max;
      NodesP_max_.insert(node, msg_p_max);
    }

    _infE_::oldMarginalMax_.set(node, marg);

    NodesL_min_.set(node, (GUM_SCALAR)1.);
  }

  // seed the arc messages with the tail node's initial pi/lambda values
  for (auto arc: _bnet_->arcs()) {
    ArcsP_min_.set(arc, NodesP_min_[arc.tail()]);

    if (NodesP_max_.exists(arc.tail())) { ArcsP_max_.set(arc, NodesP_max_[arc.tail()]); }

    ArcsL_min_.set(arc, NodesL_min_[arc.tail()]);
  }
}
714
715
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::makeInferenceNodeToNeighbours_() {
  // Message-passing schedule: each active node sends pi messages to its
  // children and lambda messages to its parents; nodes whose incoming
  // messages changed become active for the next round. Iterate until the
  // approximation scheme converges or no node is active.
  const DAG& graphe = _bnet_->dag();

  GUM_SCALAR eps;
  // to validate TestSuite
  _infE_::continueApproximationScheme(1.);

  do {
    for (auto node: active_nodes_set) {
      for (auto chil: graphe.children(node)) {
        // indicator nodes do not receive pi messages
        if (_cn_->currentNodeType(chil) == CredalNet< GUM_SCALAR >::NodeType::Indic) {
          continue;
        }

        msgP_(node, chil);
      }

      for (auto par: graphe.parents(node)) {
        // NOTE(review): this tests the type of `node` (the sender), not
        // `par` — asymmetric with the children loop above, and invariant
        // across the loop; confirm this is intended.
        if (_cn_->currentNodeType(node) == CredalNet< GUM_SCALAR >::NodeType::Indic) {
          continue;
        }

        msgL_(node, par);
      }
    }

    eps = calculateEpsilon_();

    _infE_::updateApproximationScheme();

    // swap in the nodes activated by this round's message updates
    active_nodes_set.clear();
    active_nodes_set = next_active_nodes_set;
    next_active_nodes_set.clear();

  } while (_infE_::continueApproximationScheme(eps) && active_nodes_set.size() > 0);

  _infE_::stopApproximationScheme();   // just to be sure the
                                       // approximationScheme has been notified of
                                       // the end of the loop
}
756
757
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::makeInferenceByRandomOrder_() {
  // Message-passing schedule: at each iteration, partially shuffle the arc
  // sequence, then send both messages (pi tail->head, lambda head->tail)
  // along every arc, skipping arcs that touch an indicator node.
  Size nbrArcs = _bnet_->dag().sizeArcs();

  std::vector< cArcP > seq;
  seq.reserve(nbrArcs);

  for (const auto& arc: _bnet_->arcs()) {
    seq.push_back(&arc);
  }

  GUM_SCALAR eps;
  // validate TestSuite
  _infE_::continueApproximationScheme(1.);

  do {
    // nbrArcs/2 random transpositions; NOTE(review): uses rand(), so the
    // schedule depends on the global C PRNG state (not reproducible unless
    // the caller seeds it)
    for (Size j = 0, theEnd = nbrArcs / 2; j < theEnd; j++) {
      auto w1 = rand() % nbrArcs, w2 = rand() % nbrArcs;

      if (w1 == w2) { continue; }

      std::swap(seq[w1], seq[w2]);
    }

    for (const auto it: seq) {
      if (_cn_->currentNodeType(it->tail()) == CredalNet< GUM_SCALAR >::NodeType::Indic
          || _cn_->currentNodeType(it->head()) == CredalNet< GUM_SCALAR >::NodeType::Indic) {
        continue;
      }

      msgP_(it->tail(), it->head());
      msgL_(it->head(), it->tail());
    }

    eps = calculateEpsilon_();

    _infE_::updateApproximationScheme();

  } while (_infE_::continueApproximationScheme(eps));
}
797
798
// gives slightly worse results for some variable/modalities than other
799
// inference
800
// types (node D on 2U network loose 0.03 precision)
801
template
<
typename
GUM_SCALAR
>
802
void
CNLoopyPropagation
<
GUM_SCALAR
>::
makeInferenceByOrderedArcs_
() {
803
Size
nbrArcs
=
_bnet_
->
dag
().
sizeArcs
();
804
805
std
::
vector
<
cArcP
>
seq
;
806
seq
.
reserve
(
nbrArcs
);
807
808
for
(
const
auto
&
arc
:
_bnet_
->
arcs
()) {
809
seq
.
push_back
(&
arc
);
810
}
811
812
GUM_SCALAR
eps
;
813
// validate TestSuite
814
_infE_
::
continueApproximationScheme
(1.);
815
816
do
{
817
for
(
const
auto
it
:
seq
) {
818
if
(
_cn_
->
currentNodeType
(
it
->
tail
()) ==
CredalNet
<
GUM_SCALAR
>::
NodeType
::
Indic
819
||
_cn_
->
currentNodeType
(
it
->
head
()) ==
CredalNet
<
GUM_SCALAR
>::
NodeType
::
Indic
) {
820
continue
;
821
}
822
823
msgP_
(
it
->
tail
(),
it
->
head
());
824
msgL_
(
it
->
head
(),
it
->
tail
());
825
}
826
827
eps
=
calculateEpsilon_
();
828
829
_infE_
::
updateApproximationScheme
();
830
831
}
while
(
_infE_
::
continueApproximationScheme
(
eps
));
832
}
833
834
template
<
typename
GUM_SCALAR
>
835
void
CNLoopyPropagation
<
GUM_SCALAR
>::
msgL_
(
const
NodeId
Y
,
const
NodeId
X
) {
836
NodeSet
const
&
children
=
_bnet_
->
children
(
Y
);
837
NodeSet
const
&
parents_
=
_bnet_
->
parents
(
Y
);
838
839
const
auto
parents
= &
_bnet_
->
cpt
(
Y
).
variablesSequence
();
840
841
if
(((
children
.
size
() +
parents
->
size
() - 1) == 1) && (!
_infE_
::
evidence_
.
exists
(
Y
))) {
842
return
;
843
}
844
845
bool
update_l
=
update_l_
[
Y
];
846
bool
update_p
=
update_p_
[
Y
];
847
848
if
(!
update_p
&& !
update_l
) {
return
; }
849
850
msg_l_sent_
[
Y
]->
insert
(
X
);
851
852
// for future refresh LM/PI
853
if
(
msg_l_sent_
[
Y
]->
size
() ==
parents_
.
size
()) {
854
msg_l_sent_
[
Y
]->
clear
();
855
update_l_
[
Y
] =
false
;
856
}
857
858
// refresh LM_part
859
if
(
update_l
) {
860
if
(!
children
.
empty
() && !
_infE_
::
evidence_
.
exists
(
Y
)) {
861
GUM_SCALAR
lmin
= 1.;
862
GUM_SCALAR
lmax
= 1.;
863
864
for
(
auto
chil
:
children
) {
865
lmin
*=
ArcsL_min_
[
Arc
(
Y
,
chil
)];
866
867
if
(
ArcsL_max_
.
exists
(
Arc
(
Y
,
chil
))) {
868
lmax
*=
ArcsL_max_
[
Arc
(
Y
,
chil
)];
869
}
else
{
870
lmax
*=
ArcsL_min_
[
Arc
(
Y
,
chil
)];
871
}
872
}
873
874
lmin
=
lmax
;
875
876
if
(
lmax
!=
lmax
&&
lmin
==
lmin
) {
lmax
=
lmin
; }
877
878
if
(
lmax
!=
lmax
&&
lmin
!=
lmin
) {
879
std
::
cout
<<
"no likelihood defined [lmin, lmax] (incompatibles "
880
"evidence ?)"
881
<<
std
::
endl
;
882
}
883
884
if
(
lmin
< 0.) {
lmin
= 0.; }
885
886
if
(
lmax
< 0.) {
lmax
= 0.; }
887
888
// no need to update nodeL if evidence since nodeL will never be used
889
890
NodesL_min_
[
Y
] =
lmin
;
891
892
if
(
lmin
!=
lmax
) {
893
NodesL_max_
.
set
(
Y
,
lmax
);
894
}
else
if
(
NodesL_max_
.
exists
(
Y
)) {
895
NodesL_max_
.
erase
(
Y
);
896
}
897
898
}
// end of : node has children & no evidence
899
900
}
// end of : if update_l
901
902
GUM_SCALAR
lmin
=
NodesL_min_
[
Y
];
903
GUM_SCALAR
lmax
;
904
905
if
(
NodesL_max_
.
exists
(
Y
)) {
906
lmax
=
NodesL_max_
[
Y
];
907
}
else
{
908
lmax
=
lmin
;
909
}
910
911
/**
912
* lmin == lmax == 1 => sends 1 as message to parents
913
*/
914
915
if
(
lmin
==
lmax
&&
lmin
== 1.) {
916
ArcsL_min_
[
Arc
(
X
,
Y
)] =
lmin
;
917
918
if
(
ArcsL_max_
.
exists
(
Arc
(
X
,
Y
))) {
ArcsL_max_
.
erase
(
Arc
(
X
,
Y
)); }
919
920
return
;
921
}
922
923
// garder pour chaque noeud un table des parents maj, une fois tous maj,
924
// stop
925
// jusque notification msg L ou P
926
927
if
(
update_p
||
update_l
) {
928
std
::
vector
<
std
::
vector
<
std
::
vector
<
GUM_SCALAR
> > >
msgs_p
;
929
std
::
vector
<
std
::
vector
<
GUM_SCALAR
> >
msg_p
;
930
std
::
vector
<
GUM_SCALAR
>
distri
(2);
931
932
Idx
pos
;
933
934
// +1 from start to avoid counting_ itself
935
// use const iterators with cbegin when available
936
for
(
auto
jt
= ++
parents
->
begin
(),
theEnd
=
parents
->
end
();
jt
!=
theEnd
; ++
jt
) {
937
if
(
_bnet_
->
nodeId
(**
jt
) ==
X
) {
938
// retirer la variable courante de la taille
939
pos
=
parents
->
pos
(*
jt
) - 1;
940
continue
;
941
}
942
943
// compute probability distribution to avoid doing it multiple times
944
// (at each combination of messages)
945
distri
[1] =
ArcsP_min_
[
Arc
(
_bnet_
->
nodeId
(**
jt
),
Y
)];
946
distri
[0] =
GUM_SCALAR
(1.) -
distri
[1];
947
msg_p
.
push_back
(
distri
);
948
949
if
(
ArcsP_max_
.
exists
(
Arc
(
_bnet_
->
nodeId
(**
jt
),
Y
))) {
950
distri
[1] =
ArcsP_max_
[
Arc
(
_bnet_
->
nodeId
(**
jt
),
Y
)];
951
distri
[0] =
GUM_SCALAR
(1.) -
distri
[1];
952
msg_p
.
push_back
(
distri
);
953
}
954
955
msgs_p
.
push_back
(
msg_p
);
956
msg_p
.
clear
();
957
}
958
959
GUM_SCALAR
min
= -2.;
960
GUM_SCALAR
max
= -2.;
961
962
std
::
vector
<
GUM_SCALAR
>
lx
;
963
lx
.
push_back
(
lmin
);
964
965
if
(
lmin
!=
lmax
) {
lx
.
push_back
(
lmax
); }
966
967
enum_combi_
(
msgs_p
,
Y
,
min
,
max
,
lx
,
pos
);
968
969
if
(
min
== -2. ||
max
== -2.) {
970
if
(
min
!= -2.) {
971
max
=
min
;
972
}
else
if
(
max
!= -2.) {
973
min
=
max
;
974
}
else
{
975
std
::
cout
<<
std
::
endl
;
976
std
::
cout
<<
"!!!! pas de message L calculable !!!!"
<<
std
::
endl
;
977
return
;
978
}
979
}
980
981
if
(
min
< 0.) {
min
= 0.; }
982
983
if
(
max
< 0.) {
max
= 0.; }
984
985
bool
update
=
false
;
986
987
if
(
min
!=
ArcsL_min_
[
Arc
(
X
,
Y
)]) {
988
ArcsL_min_
[
Arc
(
X
,
Y
)] =
min
;
989
update
=
true
;
990
}
991
992
if
(
ArcsL_max_
.
exists
(
Arc
(
X
,
Y
))) {
993
if
(
max
!=
ArcsL_max_
[
Arc
(
X
,
Y
)]) {
994
if
(
max
!=
min
) {
995
ArcsL_max_
[
Arc
(
X
,
Y
)] =
max
;
996
}
else
{
// if ( max == min )
997
ArcsL_max_
.
erase
(
Arc
(
X
,
Y
));
998
}
999
1000
update
=
true
;
1001
}
1002
}
else
{
1003
if
(
max
!=
min
) {
1004
ArcsL_max_
.
insert
(
Arc
(
X
,
Y
),
max
);
1005
update
=
true
;
1006
}
1007
}
1008
1009
if
(
update
) {
1010
update_l_
.
set
(
X
,
true
);
1011
next_active_nodes_set
.
insert
(
X
);
1012
}
1013
1014
}
// end of update_p || update_l
1015
}
1016
1017
template
<
typename
GUM_SCALAR
>
1018
void
CNLoopyPropagation
<
GUM_SCALAR
>::
msgP_
(
const
NodeId
X
,
const
NodeId
demanding_child
) {
1019
NodeSet
const
&
children
=
_bnet_
->
children
(
X
);
1020
1021
const
auto
parents
= &
_bnet_
->
cpt
(
X
).
variablesSequence
();
1022
1023
if
(((
children
.
size
() +
parents
->
size
() - 1) == 1) && (!
_infE_
::
evidence_
.
exists
(
X
))) {
1024
return
;
1025
}
1026
1027
// LM_part ---- from all children but one --- the lonely one will get the
1028
// message
1029
1030
if
(
_infE_
::
evidence_
.
exists
(
X
)) {
1031
ArcsP_min_
[
Arc
(
X
,
demanding_child
)] =
_infE_
::
evidence_
[
X
][1];
1032
1033
if
(
ArcsP_max_
.
exists
(
Arc
(
X
,
demanding_child
))) {
1034
ArcsP_max_
.
erase
(
Arc
(
X
,
demanding_child
));
1035
}
1036
1037
return
;
1038
}
1039
1040
bool
update_l
=
update_l_
[
X
];
1041
bool
update_p
=
update_p_
[
X
];
1042
1043
if
(!
update_p
&& !
update_l
) {
return
; }
1044
1045
GUM_SCALAR
lmin
= 1.;
1046
GUM_SCALAR
lmax
= 1.;
1047
1048
// use cbegin if available
1049
for
(
auto
chil
:
children
) {
1050
if
(
chil
==
demanding_child
) {
continue
; }
1051
1052
lmin
*=
ArcsL_min_
[
Arc
(
X
,
chil
)];
1053
1054
if
(
ArcsL_max_
.
exists
(
Arc
(
X
,
chil
))) {
1055
lmax
*=
ArcsL_max_
[
Arc
(
X
,
chil
)];
1056
}
else
{
1057
lmax
*=
ArcsL_min_
[
Arc
(
X
,
chil
)];
1058
}
1059
}
1060
1061
if
(
lmin
!=
lmin
&&
lmax
==
lmax
) {
lmin
=
lmax
; }
1062
1063
if
(
lmax
!=
lmax
&&
lmin
==
lmin
) {
lmax
=
lmin
; }
1064
1065
if
(
lmax
!=
lmax
&&
lmin
!=
lmin
) {
1066
std
::
cout
<<
"pas de vraisemblance definie [lmin, lmax] (observations "
1067
"incompatibles ?)"
1068
<<
std
::
endl
;
1069
return
;
1070
}
1071
1072
if
(
lmin
< 0.) {
lmin
= 0.; }
1073
1074
if
(
lmax
< 0.) {
lmax
= 0.; }
1075
1076
// refresh PI_part
1077
GUM_SCALAR
min
=
INF_
;
1078
GUM_SCALAR
max
= 0.;
1079
1080
if
(
update_p
) {
1081
std
::
vector
<
std
::
vector
<
std
::
vector
<
GUM_SCALAR
> > >
msgs_p
;
1082
std
::
vector
<
std
::
vector
<
GUM_SCALAR
> >
msg_p
;
1083
std
::
vector
<
GUM_SCALAR
>
distri
(2);
1084
1085
// +1 from start to avoid counting_ itself
1086
// use const_iterators if available
1087
for
(
auto
jt
= ++
parents
->
begin
(),
theEnd
=
parents
->
end
();
jt
!=
theEnd
; ++
jt
) {
1088
// compute probability distribution to avoid doing it multiple times
1089
// (at
1090
// each combination of messages)
1091
distri
[1] =
ArcsP_min_
[
Arc
(
_bnet_
->
nodeId
(**
jt
),
X
)];
1092
distri
[0] =
GUM_SCALAR
(1.) -
distri
[1];
1093
msg_p
.
push_back
(
distri
);
1094
1095
if
(
ArcsP_max_
.
exists
(
Arc
(
_bnet_
->
nodeId
(**
jt
),
X
))) {
1096
distri
[1] =
ArcsP_max_
[
Arc
(
_bnet_
->
nodeId
(**
jt
),
X
)];
1097
distri
[0] =
GUM_SCALAR
(1.) -
distri
[1];
1098
msg_p
.
push_back
(
distri
);
1099
}
1100
1101
msgs_p
.
push_back
(
msg_p
);
1102
msg_p
.
clear
();
1103
}
1104
1105
enum_combi_
(
msgs_p
,
X
,
min
,
max
);
1106
1107
if
(
min
< 0.) {
min
= 0.; }
1108
1109
if
(
max
< 0.) {
max
= 0.; }
1110
1111
if
(
min
==
INF_
||
max
==
INF_
) {
1112
std
::
cout
<<
" ERREUR msg P min = max = INF "
<<
std
::
endl
;
1113
std
::
cout
.
flush
();
1114
return
;
1115
}
1116
1117
NodesP_min_
[
X
] =
min
;
1118
1119
if
(
min
!=
max
) {
1120
NodesP_max_
.
set
(
X
,
max
);
1121
}
else
if
(
NodesP_max_
.
exists
(
X
)) {
1122
NodesP_max_
.
erase
(
X
);
1123
}
1124
1125
update_p_
.
set
(
X
,
false
);
1126
1127
}
// end of update_p
1128
else
{
1129
min
=
NodesP_min_
[
X
];
1130
1131
if
(
NodesP_max_
.
exists
(
X
)) {
1132
max
=
NodesP_max_
[
X
];
1133
}
else
{
1134
max
=
min
;
1135
}
1136
}
1137
1138
if
(
update_p
||
update_l
) {
1139
GUM_SCALAR
msg_p_min
;
1140
GUM_SCALAR
msg_p_max
;
1141
1142
// cas limites sur min
1143
if
(
min
==
INF_
&&
lmin
== 0.) {
1144
std
::
cout
<<
"MESSAGE P ERR (negatif) : pi = inf, l = 0"
<<
std
::
endl
;
1145
}
1146
1147
if
(
lmin
==
INF_
) {
// cas infini
1148
msg_p_min
=
GUM_SCALAR
(1.);
1149
}
else
if
(
min
== 0. ||
lmin
== 0.) {
1150
msg_p_min
= 0;
1151
}
else
{
1152
msg_p_min
=
GUM_SCALAR
(1. / (1. + ((1. /
min
- 1.) * 1. /
lmin
)));
1153
}
1154
1155
// cas limites sur max
1156
if
(
max
==
INF_
&&
lmax
== 0.) {
1157
std
::
cout
<<
"MESSAGE P ERR (negatif) : pi = inf, l = 0"
<<
std
::
endl
;
1158
}
1159
1160
if
(
lmax
==
INF_
) {
// cas infini
1161
msg_p_max
=
GUM_SCALAR
(1.);
1162
}
else
if
(
max
== 0. ||
lmax
== 0.) {
1163
msg_p_max
= 0;
1164
}
else
{
1165
msg_p_max
=
GUM_SCALAR
(1. / (1. + ((1. /
max
- 1.) * 1. /
lmax
)));
1166
}
1167
1168
if
(
msg_p_min
!=
msg_p_min
&&
msg_p_max
==
msg_p_max
) {
1169
msg_p_min
=
msg_p_max
;
1170
std
::
cout
<<
std
::
endl
;
1171
std
::
cout
<<
"msg_p_min is NaN"
<<
std
::
endl
;
1172
}
1173
1174
if
(
msg_p_max
!=
msg_p_max
&&
msg_p_min
==
msg_p_min
) {
1175
msg_p_max
=
msg_p_min
;
1176
std
::
cout
<<
std
::
endl
;
1177
std
::
cout
<<
"msg_p_max is NaN"
<<
std
::
endl
;
1178
}
1179
1180
if
(
msg_p_max
!=
msg_p_max
&&
msg_p_min
!=
msg_p_min
) {
1181
std
::
cout
<<
std
::
endl
;
1182
std
::
cout
<<
"pas de message P calculable (verifier observations)"
<<
std
::
endl
;
1183
return
;
1184
}
1185
1186
if
(
msg_p_min
< 0.) {
msg_p_min
= 0.; }
1187
1188
if
(
msg_p_max
< 0.) {
msg_p_max
= 0.; }
1189
1190
bool
update
=
false
;
1191
1192
if
(
msg_p_min
!=
ArcsP_min_
[
Arc
(
X
,
demanding_child
)]) {
1193
ArcsP_min_
[
Arc
(
X
,
demanding_child
)] =
msg_p_min
;
1194
update
=
true
;
1195
}
1196
1197
if
(
ArcsP_max_
.
exists
(
Arc
(
X
,
demanding_child
))) {
1198
if
(
msg_p_max
!=
ArcsP_max_
[
Arc
(
X
,
demanding_child
)]) {
1199
if
(
msg_p_max
!=
msg_p_min
) {
1200
ArcsP_max_
[
Arc
(
X
,
demanding_child
)] =
msg_p_max
;
1201
}
else
{
// if ( msg_p_max == msg_p_min )
1202
ArcsP_max_
.
erase
(
Arc
(
X
,
demanding_child
));
1203
}
1204
1205
update
=
true
;
1206
}
1207
}
else
{
1208
if
(
msg_p_max
!=
msg_p_min
) {
1209
ArcsP_max_
.
insert
(
Arc
(
X
,
demanding_child
),
msg_p_max
);
1210
update
=
true
;
1211
}
1212
}
1213
1214
if
(
update
) {
1215
update_p_
.
set
(
demanding_child
,
true
);
1216
next_active_nodes_set
.
insert
(
demanding_child
);
1217
}
1218
1219
}
// end of : update_l || update_p
1220
}
1221
1222
/**
 * Recomputes, for every node, the lambda ([lmin, lmax]) and pi ([min, max])
 * values from the arc messages currently stored in ArcsL_min_/ArcsL_max_ and
 * ArcsP_min_/ArcsP_max_, writing the results into NodesL_min_/NodesL_max_ and
 * NodesP_min_/NodesP_max_.
 *
 * @param refreshIndic if false, nodes whose current type is Indic are skipped.
 */
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::refreshLMsPIs_(bool refreshIndic) {
  for (auto node: _bnet_->nodes()) {
    // Indicator nodes are only refreshed when explicitly requested.
    if ((!refreshIndic)
        && _cn_->currentNodeType(node) == CredalNet< GUM_SCALAR >::NodeType::Indic) {
      continue;
    }

    NodeSet const& children = _bnet_->children(node);

    // Variables sequence of the node's CPT: position 0 is the node itself,
    // the remaining entries are its parents.
    auto parents = &_bnet_->cpt(node).variablesSequence();

    if (update_l_[node]) {
      GUM_SCALAR lmin = 1.;
      GUM_SCALAR lmax = 1.;

      // Lambda is only recomputed for unobserved nodes that have children;
      // it is the product of the lambda-messages received from all children.
      if (!children.empty() && !_infE_::evidence_.exists(node)) {
        for (auto chil: children) {
          lmin *= ArcsL_min_[Arc(node, chil)];

          // A missing max entry means the message interval is degenerate
          // (max == min), so the min value is reused.
          if (ArcsL_max_.exists(Arc(node, chil))) {
            lmax *= ArcsL_max_[Arc(node, chil)];
          } else {
            lmax *= ArcsL_min_[Arc(node, chil)];
          }
        }

        // x != x is the NaN test: repair a NaN lmin from a valid lmax.
        if (lmin != lmin && lmax == lmax) { lmin = lmax; }

        // NOTE(review): this assignment is unconditional, collapsing the
        // interval to [lmin, lmin]; the analogous code in msgP_ only copies
        // under a NaN condition (lmax != lmax && lmin == lmin). Confirm
        // whether discarding lmax here is intentional.
        lmax = lmin;

        if (lmax != lmax && lmin != lmin) {
          std::cout << "pas de vraisemblance definie [lmin, lmax] (observations "
                       "incompatibles ?)"
                    << std::endl;
          return;
        }

        // Clamp tiny negative values produced by floating-point error.
        if (lmin < 0.) { lmin = 0.; }

        if (lmax < 0.) { lmax = 0.; }

        NodesL_min_[node] = lmin;

        // Store lmax only when the interval is non-degenerate; otherwise
        // make sure no stale max entry remains.
        if (lmin != lmax) {
          NodesL_max_.set(node, lmax);
        } else if (NodesL_max_.exists(node)) {
          NodesL_max_.erase(node);
        }
      }

    }   // end of : update_l

    if (update_p_[node]) {
      // Pi is only recomputed for unobserved nodes that have parents.
      if ((parents->size() - 1) > 0 && !_infE_::evidence_.exists(node)) {
        std::vector< std::vector< std::vector< GUM_SCALAR > > > msgs_p;
        std::vector< std::vector< GUM_SCALAR > >                msg_p;
        std::vector< GUM_SCALAR >                               distri(2);

        // +1 from start to avoid counting_ itself
        // cbegin
        for (auto jt = ++parents->begin(), theEnd = parents->end(); jt != theEnd; ++jt) {
          // compute probability distribution to avoid doing it multiple
          // times
          // (at each combination of messages)
          distri[1] = ArcsP_min_[Arc(_bnet_->nodeId(**jt), node)];
          distri[0] = GUM_SCALAR(1.) - distri[1];
          msg_p.push_back(distri);

          // When a max message exists, both extreme distributions are
          // enumerated for this parent.
          if (ArcsP_max_.exists(Arc(_bnet_->nodeId(**jt), node))) {
            distri[1] = ArcsP_max_[Arc(_bnet_->nodeId(**jt), node)];
            distri[0] = GUM_SCALAR(1.) - distri[1];
            msg_p.push_back(distri);
          }

          msgs_p.push_back(msg_p);
          msg_p.clear();
        }

        GUM_SCALAR min = INF_;
        GUM_SCALAR max = 0.;

        // Enumerate all combinations of parent messages to bound pi.
        enum_combi_(msgs_p, node, min, max);

        if (min < 0.) { min = 0.; }

        if (max < 0.) { max = 0.; }

        NodesP_min_[node] = min;

        // Keep the max entry only for non-degenerate intervals.
        if (min != max) {
          NodesP_max_.set(node, max);
        } else if (NodesP_max_.exists(node)) {
          NodesP_max_.erase(node);
        }

        update_p_[node] = false;
      }
    }   // end of update_p

  }   // end of : for each node
}
1324
1325
/**
 * Combines, for every node, the stored pi ([min, max]) and lambda
 * ([lmin, lmax]) values into marginal probability bounds, written into
 * _infE_::marginalMin_ / _infE_::marginalMax_ for both binary states.
 * Observed nodes get degenerate marginals taken directly from the evidence.
 */
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::updateMarginals_() {
  for (auto node: _bnet_->nodes()) {
    GUM_SCALAR msg_p_min = 1.;
    GUM_SCALAR msg_p_max = 0.;

    if (_infE_::evidence_.exists(node)) {
      // Hard evidence: the marginal of state 1 is exactly 0 or 1.
      if (_infE_::evidence_[node][1] == 0.) {
        msg_p_min = (GUM_SCALAR)0.;
      } else if (_infE_::evidence_[node][1] == 1.) {
        msg_p_min = 1.;
      }

      msg_p_max = msg_p_min;
    } else {
      // Missing max entries denote degenerate intervals (max == min).
      GUM_SCALAR min = NodesP_min_[node];
      GUM_SCALAR max;

      if (NodesP_max_.exists(node)) {
        max = NodesP_max_[node];
      } else {
        max = min;
      }

      GUM_SCALAR lmin = NodesL_min_[node];
      GUM_SCALAR lmax;
      if (NodesL_max_.exists(node)) {
        lmax = NodesL_max_[node];
      } else {
        lmax = lmin;
      }

      if (min == INF_ || max == INF_) {
        std::cout << " min ou max === INF_ !!!!!!!!!!!!!!!!!!!!!!!!!! " << std::endl;
        return;
      }

      // NOTE(review): unreachable — the previous guard already returned
      // whenever min == INF_. Kept as in the original; confirm intent.
      if (min == INF_ && lmin == 0.) {
        std::cout << "proba ERR (negatif) : pi = inf, l = 0" << std::endl;
        return;
      }

      // Lower marginal of state 1: P = 1 / (1 + ((1/pi - 1) / lambda)),
      // with the limit cases handled explicitly.
      if (lmin == INF_) {
        msg_p_min = GUM_SCALAR(1.);
      } else if (min == 0. || lmin == 0.) {
        msg_p_min = GUM_SCALAR(0.);
      } else {
        msg_p_min = GUM_SCALAR(1. / (1. + ((1. / min - 1.) * 1. / lmin)));
      }

      if (max == INF_ && lmax == 0.) {
        std::cout << "proba ERR (negatif) : pi = inf, l = 0" << std::endl;
        return;
      }

      // Upper marginal of state 1, same formula with the max bounds.
      if (lmax == INF_) {
        msg_p_max = GUM_SCALAR(1.);
      } else if (max == 0. || lmax == 0.) {
        msg_p_max = GUM_SCALAR(0.);
      } else {
        msg_p_max = GUM_SCALAR(1. / (1. + ((1. / max - 1.) * 1. / lmax)));
      }
    }

    // x != x is the NaN test: repair one NaN bound from the other.
    if (msg_p_min != msg_p_min && msg_p_max == msg_p_max) {
      msg_p_min = msg_p_max;
      std::cout << std::endl;
      std::cout << "msg_p_min is NaN" << std::endl;
    }

    if (msg_p_max != msg_p_max && msg_p_min == msg_p_min) {
      msg_p_max = msg_p_min;
      std::cout << std::endl;
      std::cout << "msg_p_max is NaN" << std::endl;
    }

    // Both bounds NaN: no marginal can be computed, abort the refresh.
    if (msg_p_max != msg_p_max && msg_p_min != msg_p_min) {
      std::cout << std::endl;
      std::cout << "Please check the observations (no proba can be computed)" << std::endl;
      return;
    }

    // Clamp tiny negative values produced by floating-point error.
    if (msg_p_min < 0.) { msg_p_min = 0.; }

    if (msg_p_max < 0.) { msg_p_max = 0.; }

    // Binary node: state 0 bounds are the complements of state 1 bounds
    // (the min of state 0 comes from the max of state 1, and vice versa).
    _infE_::marginalMin_[node][0] = 1 - msg_p_max;
    _infE_::marginalMax_[node][0] = 1 - msg_p_min;
    _infE_::marginalMin_[node][1] = msg_p_min;
    _infE_::marginalMax_[node][1] = msg_p_max;
  }
}
1417
1418
/**
 * Refreshes the lambda/pi values and the marginal bounds, then returns the
 * convergence criterion computed by the inference engine.
 *
 * @return the epsilon value produced by _infE_::computeEpsilon_().
 */
template < typename GUM_SCALAR >
GUM_SCALAR CNLoopyPropagation< GUM_SCALAR >::calculateEpsilon_() {
  // Order matters: marginals are derived from the freshly updated L/PI values.
  refreshLMsPIs_();
  updateMarginals_();

  return _infE_::computeEpsilon_();
}
1425
1426
template
<
typename
GUM_SCALAR
>
1427
void
CNLoopyPropagation
<
GUM_SCALAR
>::
updateIndicatrices_
() {
1428
for
(
auto
node
:
_bnet_
->
nodes
()) {
1429
if
(
_cn_
->
currentNodeType
(
node
) !=
CredalNet
<
GUM_SCALAR
>::
NodeType
::
Indic
) {
continue
; }
1430
1431
for
(
auto
pare
:
_bnet_
->
parents
(
node
)) {
1432
msgP_
(
pare
,
node
);
1433
}
1434
}
1435
1436
refreshLMsPIs_
(
true
);
1437
updateMarginals_
();
1438
}
1439
1440
template
<
typename
GUM_SCALAR
>
1441
void
CNLoopyPropagation
<
GUM_SCALAR
>::
computeExpectations_
() {
1442
if
(
_infE_
::
modal_
.
empty
()) {
return
; }
1443
1444
std
::
vector
<
std
::
vector
<
GUM_SCALAR
> >
vertices
(2,
std
::
vector
<
GUM_SCALAR
>(2));
1445
1446
for
(
auto
node
:
_bnet_
->
nodes
()) {
1447
vertices
[0][0] =
_infE_
::
marginalMin_
[
node
][0];
1448
vertices
[0][1] =
_infE_
::
marginalMax_
[
node
][1];
1449
1450
vertices
[1][0] =
_infE_
::
marginalMax_
[
node
][0];
1451
vertices
[1][1] =
_infE_
::
marginalMin_
[
node
][1];
1452
1453
for
(
auto
vertex
= 0,
vend
= 2;
vertex
!=
vend
;
vertex
++) {
1454
_infE_
::
updateExpectations_
(
node
,
vertices
[
vertex
]);
1455
// test credal sets vertices elim
1456
// remove with L2U since variables are binary
1457
// but does the user know that ?
1458
_infE_
::
updateCredalSets_
(
1459
node
,
1460
vertices
[
vertex
]);
// no redundancy elimination with 2 vertices
1461
}
1462
}
1463
}
1464
1465
template
<
typename
GUM_SCALAR
>
1466
CNLoopyPropagation
<
GUM_SCALAR
>::
CNLoopyPropagation
(
const
CredalNet
<
GUM_SCALAR
>&
cnet
) :
1467
InferenceEngine
<
GUM_SCALAR
>::
InferenceEngine
(
cnet
) {
1468
if
(!
cnet
.
isSeparatelySpecified
()) {
1469
GUM_ERROR
(
OperationNotAllowed
,
1470
"CNLoopyPropagation is only available "
1471
"with separately specified nets"
);
1472
}
1473
1474
// test for binary cn
1475
for
(
auto
node
:
cnet
.
current_bn
().
nodes
())
1476
if
(
cnet
.
current_bn
().
variable
(
node
).
domainSize
() != 2) {
1477
GUM_ERROR
(
OperationNotAllowed
,
1478
"CNLoopyPropagation is only available "
1479
"with binary credal networks"
);
1480
}
1481
1482
// test if compute CPTMinMax has been called
1483
if
(!
cnet
.
hasComputedBinaryCPTMinMax
()) {
1484
GUM_ERROR
(
OperationNotAllowed
,
1485
"CNLoopyPropagation only works when "
1486
"\"computeBinaryCPTMinMax()\" has been called for "
1487
"this credal net"
);
1488
}
1489
1490
_cn_
= &
cnet
;
1491
_bnet_
= &
cnet
.
current_bn
();
1492
1493
_inferenceType_
=
InferenceType
::
nodeToNeighbours
;
1494
InferenceUpToDate_
=
false
;
1495
1496
GUM_CONSTRUCTOR
(
CNLoopyPropagation
);
1497
}
1498
1499
template
<
typename
GUM_SCALAR
>
1500
CNLoopyPropagation
<
GUM_SCALAR
>::~
CNLoopyPropagation
() {
1501
InferenceUpToDate_
=
false
;
1502
1503
if
(
msg_l_sent_
.
size
() > 0) {
1504
for
(
auto
node
:
_bnet_
->
nodes
()) {
1505
delete
msg_l_sent_
[
node
];
1506
}
1507
}
1508
1509
//_msg_l_sent.clear();
1510
//_update_l.clear();
1511
//_update_p.clear();
1512
1513
GUM_DESTRUCTOR
(
CNLoopyPropagation
);
1514
}
1515
1516
/**
 * Selects which message-passing ordering the next inference run will use.
 *
 * @param inft the InferenceType to use (e.g. nodeToNeighbours).
 */
template < typename GUM_SCALAR >
void CNLoopyPropagation< GUM_SCALAR >::inferenceType(InferenceType inft) {
  _inferenceType_ = inft;
}
1520
1521
/**
 * Accessor for the currently selected message-passing ordering.
 *
 * @return the InferenceType that will be used by the next inference run.
 */
template < typename GUM_SCALAR >
typename CNLoopyPropagation< GUM_SCALAR >::InferenceType
    CNLoopyPropagation< GUM_SCALAR >::inferenceType() {
  return _inferenceType_;
}
1526
}
// namespace credal
1527
}
// end of namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::credal
namespace for all credal networks entities
Definition:
LpInterface.cpp:37