aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
jointTargetedInference_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of the non pure virtual methods of class
25
* JointTargetedInference.
26
*/
27
#include <agrum/BN/inference/tools/jointTargetedInference.h>

#include <cmath>
#include <memory>
#include <vector>

#include <agrum/tools/variables/rangeVariable.h>
29
30
namespace
gum
{
31
32
33
// Default Constructor
34
template
<
typename
GUM_SCALAR >
35
JointTargetedInference< GUM_SCALAR >::JointTargetedInference(
36
const
IBayesNet< GUM_SCALAR >* bn) :
37
MarginalTargetedInference< GUM_SCALAR >(bn) {
38
// assign a BN if this has not been done before (due to virtual inheritance)
39
if
(
this
->hasNoModel_()) {
40
BayesNetInference< GUM_SCALAR >::setBayesNetDuringConstruction__(bn);
41
}
42
GUM_CONSTRUCTOR(JointTargetedInference);
43
}
44
45
46
// Destructor
47
template
<
typename
GUM_SCALAR
>
48
JointTargetedInference
<
GUM_SCALAR
>::~
JointTargetedInference
() {
49
GUM_DESTRUCTOR
(
JointTargetedInference
);
50
}
51
52
53
// assigns a new BN to the inference engine
54
template
<
typename
GUM_SCALAR
>
55
void
JointTargetedInference
<
GUM_SCALAR
>::
onModelChanged_
(
56
const
GraphicalModel
*
bn
) {
57
MarginalTargetedInference
<
GUM_SCALAR
>::
onModelChanged_
(
bn
);
58
onAllJointTargetsErased_
();
59
joint_targets__
.
clear
();
60
}
61
62
63
// ##############################################################################
64
// Targets
65
// ##############################################################################
66
67
// return true if target is a nodeset target.
68
template
<
typename
GUM_SCALAR
>
69
INLINE
bool
JointTargetedInference
<
GUM_SCALAR
>::
isJointTarget
(
70
const
NodeSet
&
vars
)
const
{
71
if
(
this
->
hasNoModel_
())
72
GUM_ERROR
(
NullElement
,
73
"No Bayes net has been assigned to the "
74
"inference algorithm"
);
75
76
const
auto
&
dag
=
this
->
BN
().
dag
();
77
for
(
const
auto
var
:
vars
) {
78
if
(!
dag
.
exists
(
var
)) {
79
GUM_ERROR
(
UndefinedElement
,
var
<<
" is not a NodeId in the bn"
);
80
}
81
}
82
83
return
joint_targets__
.
contains
(
vars
);
84
}
85
86
87
// Clear all previously defined single targets
88
template
<
typename
GUM_SCALAR
>
89
INLINE
void
JointTargetedInference
<
GUM_SCALAR
>::
eraseAllMarginalTargets
() {
90
MarginalTargetedInference
<
GUM_SCALAR
>::
eraseAllTargets
();
91
}
92
93
94
// Clear all previously defined targets (single targets and sets of targets)
95
template
<
typename
GUM_SCALAR
>
96
INLINE
void
JointTargetedInference
<
GUM_SCALAR
>::
eraseAllJointTargets
() {
97
if
(
joint_targets__
.
size
() > 0) {
98
// we already are in target mode. So no this->setTargetedMode_(); is needed
99
onAllJointTargetsErased_
();
100
joint_targets__
.
clear
();
101
this
->
setState_
(
GraphicalModelInference
<
102
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
103
}
104
}
105
106
107
// Clear all previously defined targets (single and joint targets)
108
template
<
typename
GUM_SCALAR
>
109
INLINE
void
JointTargetedInference
<
GUM_SCALAR
>::
eraseAllTargets
() {
110
eraseAllMarginalTargets
();
111
eraseAllJointTargets
();
112
}
113
114
115
// Add a set of nodes as a new target
116
template
<
typename
GUM_SCALAR
>
117
void
JointTargetedInference
<
GUM_SCALAR
>::
addJointTarget
(
118
const
NodeSet
&
joint_target
) {
119
// check if the nodes in the target belong to the Bayesian network
120
if
(
this
->
hasNoModel_
())
121
GUM_ERROR
(
NullElement
,
122
"No Bayes net has been assigned to the "
123
"inference algorithm"
);
124
125
const
auto
&
dag
=
this
->
BN
().
dag
();
126
for
(
const
auto
node
:
joint_target
) {
127
if
(!
dag
.
exists
(
node
)) {
128
GUM_ERROR
(
UndefinedElement
,
129
"at least one one in "
<<
joint_target
130
<<
" does not belong to the bn"
);
131
}
132
}
133
134
// check that the joint_target set does not contain the new target
135
if
(
joint_targets__
.
contains
(
joint_target
))
return
;
136
137
// check if joint_target is a subset of an already existing target
138
for
(
const
auto
&
target
:
joint_targets__
) {
139
if
(
target
.
isProperSupersetOf
(
joint_target
))
return
;
140
}
141
142
// check if joint_target is not a superset of an already existing target
143
// in this case, we need to remove old existing target
144
for
(
auto
iter
=
joint_targets__
.
beginSafe
();
145
iter
!=
joint_targets__
.
endSafe
();
146
++
iter
) {
147
if
(
iter
->
isProperSubsetOf
(
joint_target
))
eraseJointTarget
(*
iter
);
148
}
149
150
this
->
setTargetedMode_
();
// does nothing if already in targeted mode
151
joint_targets__
.
insert
(
joint_target
);
152
onJointTargetAdded_
(
joint_target
);
153
this
->
setState_
(
154
GraphicalModelInference
<
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
155
}
156
157
158
// removes an existing set target
159
template
<
typename
GUM_SCALAR
>
160
void
JointTargetedInference
<
GUM_SCALAR
>::
eraseJointTarget
(
161
const
NodeSet
&
joint_target
) {
162
// check if the nodes in the target belong to the Bayesian network
163
if
(
this
->
hasNoModel_
())
164
GUM_ERROR
(
NullElement
,
165
"No Bayes net has been assigned to the "
166
"inference algorithm"
);
167
168
const
auto
&
dag
=
this
->
BN
().
dag
();
169
for
(
const
auto
node
:
joint_target
) {
170
if
(!
dag
.
exists
(
node
)) {
171
GUM_ERROR
(
UndefinedElement
,
172
"at least one one in "
<<
joint_target
173
<<
" does not belong to the bn"
);
174
}
175
}
176
177
// check that the joint_target set does not contain the new target
178
if
(
joint_targets__
.
contains
(
joint_target
)) {
179
// note that we have to be in target mode when we are here
180
// so, no this->setTargetedMode_(); is necessary
181
onJointTargetErased_
(
joint_target
);
182
joint_targets__
.
erase
(
joint_target
);
183
this
->
setState_
(
GraphicalModelInference
<
184
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
185
}
186
}
187
188
189
/// returns the list of target sets
190
template
<
typename
GUM_SCALAR
>
191
INLINE
const
Set
<
NodeSet
>&
192
JointTargetedInference
<
GUM_SCALAR
>::
jointTargets
()
const
noexcept
{
193
return
joint_targets__
;
194
}
195
196
/// returns the number of target sets
197
template
<
typename
GUM_SCALAR
>
198
INLINE
Size
199
JointTargetedInference
<
GUM_SCALAR
>::
nbrJointTargets
()
const
noexcept
{
200
return
joint_targets__
.
size
();
201
}
202
203
204
// ##############################################################################
205
// Inference
206
// ##############################################################################
207
208
// Compute the posterior of a nodeset.
209
template
<
typename
GUM_SCALAR
>
210
const
Potential
<
GUM_SCALAR
>&
211
JointTargetedInference
<
GUM_SCALAR
>::
jointPosterior
(
const
NodeSet
&
nodes
) {
212
// try to get the smallest set of targets that contains "nodes"
213
NodeSet
set
;
214
bool
found_exact_target
=
false
;
215
216
if
(
joint_targets__
.
contains
(
nodes
)) {
217
set
=
nodes
;
218
found_exact_target
=
true
;
219
}
else
{
220
for
(
const
auto
&
target
:
joint_targets__
) {
221
if
(
nodes
.
isProperSubsetOf
(
target
)) {
222
set
=
target
;
223
break
;
224
}
225
}
226
}
227
228
if
(
set
.
empty
()) {
229
GUM_ERROR
(
UndefinedElement
,
230
" no joint target containing "
<<
nodes
<<
" could be found among "
231
<<
joint_targets__
);
232
}
233
234
if
(!
this
->
isInferenceDone
()) {
this
->
makeInference
(); }
235
236
if
(
found_exact_target
)
237
return
jointPosterior_
(
nodes
);
238
else
239
return
jointPosterior_
(
nodes
,
set
);
240
}
241
242
243
// Compute the posterior of a node
244
template
<
typename
GUM_SCALAR
>
245
const
Potential
<
GUM_SCALAR
>&
246
JointTargetedInference
<
GUM_SCALAR
>::
posterior
(
NodeId
node
) {
247
if
(
this
->
isTarget
(
node
))
248
return
MarginalTargetedInference
<
GUM_SCALAR
>::
posterior
(
node
);
249
else
250
return
jointPosterior
(
NodeSet
{
node
});
251
}
252
253
// Compute the posterior of a node
254
template
<
typename
GUM_SCALAR
>
255
const
Potential
<
GUM_SCALAR
>&
256
JointTargetedInference
<
GUM_SCALAR
>::
posterior
(
const
std
::
string
&
nodeName
) {
257
return
posterior
(
this
->
BN
().
idFromName
(
nodeName
));
258
}
259
260
// ##############################################################################
261
// Mutual Information
262
// ##############################################################################
263
template
<
typename
GUM_SCALAR
>
264
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
I
(
const
std
::
string
&
Xname
,
265
const
std
::
string
&
Yname
) {
266
return
I
(
this
->
BN
().
idFromName
(
Xname
),
this
->
BN
().
idFromName
(
Yname
));
267
}
268
269
template
<
typename
GUM_SCALAR
>
270
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
VI
(
const
std
::
string
&
Xname
,
271
const
std
::
string
&
Yname
) {
272
return
VI
(
this
->
BN
().
idFromName
(
Xname
),
this
->
BN
().
idFromName
(
Yname
));
273
}
274
275
/* Mutual information between X and Y
276
*
277
* @see http://en.wikipedia.org/wiki/Mutual_information
278
*
279
* @warning Due to limitation of @joint, may not be able to compute this value
280
* @throw OperationNotAllowed in these cases
281
*/
282
template
<
typename
GUM_SCALAR
>
283
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
I
(
NodeId
X
,
NodeId
Y
) {
284
Potential
<
GUM_SCALAR
>
pX
,
pY
, *
pXY
=
nullptr
;
285
if
(
X
==
Y
) {
286
GUM_ERROR
(
OperationNotAllowed
,
"Mutual Information I(X,Y) with X==Y"
);
287
}
288
289
try
{
290
// here use unnormalized joint posterior rather than just posterior
291
// to avoid saving the posterior in the cache of the inference engines
292
// like LazyPropagation or SahferShenoy.
293
pXY
=
this
->
unnormalizedJointPosterior_
({
X
,
Y
});
294
pXY
->
normalize
();
295
pX
=
pXY
->
margSumOut
({&(
this
->
BN
().
variable
(
Y
))});
296
pY
=
pXY
->
margSumOut
({&(
this
->
BN
().
variable
(
X
))});
297
}
catch
(...) {
298
if
(
pXY
!=
nullptr
) {
delete
pXY
; }
299
throw
;
300
}
301
302
Instantiation
i
(*
pXY
);
303
auto
res
= (
GUM_SCALAR
)0;
304
305
for
(
i
.
setFirst
(); !
i
.
end
(); ++
i
) {
306
GUM_SCALAR
vXY
= (*
pXY
)[
i
];
307
GUM_SCALAR
vX
=
pX
[
i
];
308
GUM_SCALAR
vY
=
pY
[
i
];
309
310
if
(
vXY
> (
GUM_SCALAR
)0) {
311
if
(
vX
== (
GUM_SCALAR
)0 ||
vY
== (
GUM_SCALAR
)0) {
312
GUM_ERROR
(
OperationNotAllowed
,
313
"Mutual Information (X,Y) with P(X)=0 or P(Y)=0 "
314
"and P(X,Y)>0"
);
315
}
316
317
res
+=
vXY
* (
std
::
log2
(
vXY
) -
std
::
log2
(
vX
) -
std
::
log2
(
vY
));
318
}
319
}
320
321
delete
pXY
;
322
323
return
res
;
324
}
325
326
327
/** Variation of information between X and Y
328
* @see http://en.wikipedia.org/wiki/Variation_of_information
329
*
330
* @warning Due to limitation of @joint, may not be able to compute this value
331
* @throw OperationNotAllowed in these cases
332
*/
333
template
<
typename
GUM_SCALAR
>
334
INLINE
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
VI
(
NodeId
X
,
NodeId
Y
) {
335
return
this
->
H
(
X
) +
this
->
H
(
Y
) - 2 *
I
(
X
,
Y
);
336
}
337
338
339
template
<
typename
GUM_SCALAR
>
340
Potential
<
GUM_SCALAR
>
341
JointTargetedInference
<
GUM_SCALAR
>::
evidenceJointImpact
(
342
const
NodeSet
&
targets
,
343
const
NodeSet
&
evs
) {
344
if
(!(
evs
*
targets
).
empty
()) {
345
GUM_ERROR
(
InvalidArgument
,
346
"Targets ("
<<
targets
<<
") can not intersect evs ("
<<
evs
347
<<
")."
);
348
}
349
auto
condset
=
this
->
BN
().
minimalCondSet
(
targets
,
evs
);
350
351
this
->
eraseAllTargets
();
352
this
->
eraseAllEvidence
();
353
354
Instantiation
iTarget
;
355
Potential
<
GUM_SCALAR
>
res
;
356
for
(
const
auto
&
target
:
targets
) {
357
res
.
add
(
this
->
BN
().
variable
(
target
));
358
iTarget
.
add
(
this
->
BN
().
variable
(
target
));
359
}
360
this
->
addJointTarget
(
targets
);
361
362
for
(
const
auto
&
n
:
condset
) {
363
res
.
add
(
this
->
BN
().
variable
(
n
));
364
this
->
addEvidence
(
n
, 0);
365
}
366
367
Instantiation
inst
(
res
);
368
for
(
inst
.
setFirstOut
(
iTarget
); !
inst
.
end
();
inst
.
incOut
(
iTarget
)) {
369
// inferring
370
for
(
const
auto
&
n
:
condset
)
371
this
->
chgEvidence
(
n
,
inst
.
val
(
this
->
BN
().
variable
(
n
)));
372
this
->
makeInference
();
373
// populate res
374
for
(
inst
.
setFirstIn
(
iTarget
); !
inst
.
end
();
inst
.
incIn
(
iTarget
)) {
375
res
.
set
(
inst
,
this
->
jointPosterior
(
targets
)[
inst
]);
376
}
377
inst
.
setFirstIn
(
iTarget
);
// remove inst.end() flag
378
}
379
380
return
res
;
381
}
382
383
template
<
typename
GUM_SCALAR
>
384
Potential
<
GUM_SCALAR
>
385
JointTargetedInference
<
GUM_SCALAR
>::
evidenceJointImpact
(
386
const
std
::
vector
<
std
::
string
>&
targets
,
387
const
std
::
vector
<
std
::
string
>&
evs
) {
388
const
auto
&
bn
=
this
->
BN
();
389
return
evidenceJointImpact
(
bn
.
nodeset
(
targets
),
bn
.
nodeset
(
evs
));
390
}
391
392
393
template
<
typename
GUM_SCALAR
>
394
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
jointMutualInformation
(
395
const
NodeSet
&
targets
) {
396
const
auto
&
bn
=
this
->
BN
();
397
const
Size
siz
=
targets
.
size
();
398
if
(
siz
<= 1) {
399
GUM_ERROR
(
InvalidArgument
,
400
"jointMutualInformation needs at least 2 variables (targets="
401
<<
targets
<<
")"
);
402
}
403
404
this
->
eraseAllTargets
();
405
this
->
eraseAllEvidence
();
406
this
->
addJointTarget
(
targets
);
407
this
->
makeInference
();
408
const
auto
po
=
this
->
jointPosterior
(
targets
);
409
410
gum
::
Instantiation
caracteristic
;
411
gum
::
Instantiation
variables
;
412
for
(
const
auto
nod
:
targets
) {
413
const
auto
&
var
=
bn
.
variable
(
nod
);
414
auto
pv
=
new
gum
::
RangeVariable
(
var
.
name
(),
""
, 0, 1);
415
caracteristic
.
add
(*
pv
);
416
variables
.
add
(
var
);
417
}
418
419
Set
<
const
DiscreteVariable
* >
sov
;
420
421
const
GUM_SCALAR
start
= (
siz
% 2 == 0) ?
GUM_SCALAR
(-1.0) :
GUM_SCALAR
(1.0);
422
GUM_SCALAR
sign
;
423
GUM_SCALAR
res
=
GUM_SCALAR
(0.0);
424
425
caracteristic
.
setFirst
();
426
for
(
caracteristic
.
inc
(); !
caracteristic
.
end
();
caracteristic
.
inc
()) {
427
sov
.
clear
();
428
sign
=
start
;
429
for
(
Idx
i
= 0;
i
<
caracteristic
.
nbrDim
();
i
++) {
430
if
(
caracteristic
.
val
(
i
) == 1) {
431
sign
= -
sign
;
432
sov
.
insert
(&
variables
.
variable
(
i
));
433
}
434
}
435
res
+=
sign
*
po
.
margSumIn
(
sov
).
entropy
();
436
}
437
438
for
(
Idx
i
= 0;
i
<
caracteristic
.
nbrDim
();
i
++) {
439
delete
&
caracteristic
.
variable
(
i
);
440
}
441
442
return
res
;
443
}
444
445
template
<
typename
GUM_SCALAR
>
446
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
jointMutualInformation
(
447
const
std
::
vector
<
std
::
string
>&
targets
) {
448
return
jointMutualInformation
(
this
->
BN
().
ids
(
targets
));
449
}
450
451
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669