aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
jointTargetedInference_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of the non pure virtual methods of class
 * JointTargetedInference.
 */
27
#include <agrum/BN/inference/tools/jointTargetedInference.h>

#include <cmath>
#include <memory>
#include <vector>

#include <agrum/tools/variables/rangeVariable.h>
29
30
namespace
gum
{
31
32
33
// Default Constructor
34
template
<
typename
GUM_SCALAR >
35
JointTargetedInference< GUM_SCALAR >::JointTargetedInference(
const
IBayesNet< GUM_SCALAR >* bn) :
36
MarginalTargetedInference< GUM_SCALAR >(bn) {
37
// assign a BN if this has not been done before (due to virtual inheritance)
38
if
(
this
->hasNoModel_()) {
39
BayesNetInference< GUM_SCALAR >::_setBayesNetDuringConstruction_(bn);
40
}
41
GUM_CONSTRUCTOR(JointTargetedInference);
42
}
43
44
45
// Destructor
46
template
<
typename
GUM_SCALAR
>
47
JointTargetedInference
<
GUM_SCALAR
>::~
JointTargetedInference
() {
48
GUM_DESTRUCTOR
(
JointTargetedInference
);
49
}
50
51
52
// assigns a new BN to the inference engine
53
template
<
typename
GUM_SCALAR
>
54
void
JointTargetedInference
<
GUM_SCALAR
>::
onModelChanged_
(
const
GraphicalModel
*
bn
) {
55
MarginalTargetedInference
<
GUM_SCALAR
>::
onModelChanged_
(
bn
);
56
onAllJointTargetsErased_
();
57
_joint_targets_
.
clear
();
58
}
59
60
61
// ##############################################################################
62
// Targets
63
// ##############################################################################
64
65
// return true if target is a nodeset target.
66
template
<
typename
GUM_SCALAR
>
67
INLINE
bool
JointTargetedInference
<
GUM_SCALAR
>::
isJointTarget
(
const
NodeSet
&
vars
)
const
{
68
if
(
this
->
hasNoModel_
())
69
GUM_ERROR
(
NullElement
,
70
"No Bayes net has been assigned to the "
71
"inference algorithm"
);
72
73
const
auto
&
dag
=
this
->
BN
().
dag
();
74
for
(
const
auto
var
:
vars
) {
75
if
(!
dag
.
exists
(
var
)) {
GUM_ERROR
(
UndefinedElement
,
var
<<
" is not a NodeId in the bn"
) }
76
}
77
78
return
_joint_targets_
.
contains
(
vars
);
79
}
80
81
82
// Clear all previously defined single targets
83
template
<
typename
GUM_SCALAR
>
84
INLINE
void
JointTargetedInference
<
GUM_SCALAR
>::
eraseAllMarginalTargets
() {
85
MarginalTargetedInference
<
GUM_SCALAR
>::
eraseAllTargets
();
86
}
87
88
89
// Clear all previously defined targets (single targets and sets of targets)
90
template
<
typename
GUM_SCALAR
>
91
INLINE
void
JointTargetedInference
<
GUM_SCALAR
>::
eraseAllJointTargets
() {
92
if
(
_joint_targets_
.
size
() > 0) {
93
// we already are in target mode. So no this->setTargetedMode_(); is needed
94
onAllJointTargetsErased_
();
95
_joint_targets_
.
clear
();
96
this
->
setState_
(
GraphicalModelInference
<
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
97
}
98
}
99
100
101
// Clear all previously defined targets (single and joint targets)
102
template
<
typename
GUM_SCALAR
>
103
INLINE
void
JointTargetedInference
<
GUM_SCALAR
>::
eraseAllTargets
() {
104
eraseAllMarginalTargets
();
105
eraseAllJointTargets
();
106
}
107
108
109
// Add a set of nodes as a new target
110
template
<
typename
GUM_SCALAR
>
111
void
JointTargetedInference
<
GUM_SCALAR
>::
addJointTarget
(
const
NodeSet
&
joint_target
) {
112
// check if the nodes in the target belong to the Bayesian network
113
if
(
this
->
hasNoModel_
())
114
GUM_ERROR
(
NullElement
,
115
"No Bayes net has been assigned to the "
116
"inference algorithm"
);
117
118
const
auto
&
dag
=
this
->
BN
().
dag
();
119
for
(
const
auto
node
:
joint_target
) {
120
if
(!
dag
.
exists
(
node
)) {
121
GUM_ERROR
(
UndefinedElement
,
122
"at least one one in "
<<
joint_target
<<
" does not belong to the bn"
);
123
}
124
}
125
126
// check that the joint_target set does not contain the new target
127
if
(
_joint_targets_
.
contains
(
joint_target
))
return
;
128
129
// check if joint_target is a subset of an already existing target
130
for
(
const
auto
&
target
:
_joint_targets_
) {
131
if
(
target
.
isProperSupersetOf
(
joint_target
))
return
;
132
}
133
134
// check if joint_target is not a superset of an already existing target
135
// in this case, we need to remove old existing target
136
for
(
auto
iter
=
_joint_targets_
.
beginSafe
();
iter
!=
_joint_targets_
.
endSafe
(); ++
iter
) {
137
if
(
iter
->
isProperSubsetOf
(
joint_target
))
eraseJointTarget
(*
iter
);
138
}
139
140
this
->
setTargetedMode_
();
// does nothing if already in targeted mode
141
_joint_targets_
.
insert
(
joint_target
);
142
onJointTargetAdded_
(
joint_target
);
143
this
->
setState_
(
GraphicalModelInference
<
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
144
}
145
146
147
// removes an existing set target
148
template
<
typename
GUM_SCALAR
>
149
void
JointTargetedInference
<
GUM_SCALAR
>::
eraseJointTarget
(
const
NodeSet
&
joint_target
) {
150
// check if the nodes in the target belong to the Bayesian network
151
if
(
this
->
hasNoModel_
())
152
GUM_ERROR
(
NullElement
,
153
"No Bayes net has been assigned to the "
154
"inference algorithm"
);
155
156
const
auto
&
dag
=
this
->
BN
().
dag
();
157
for
(
const
auto
node
:
joint_target
) {
158
if
(!
dag
.
exists
(
node
)) {
159
GUM_ERROR
(
UndefinedElement
,
160
"at least one one in "
<<
joint_target
<<
" does not belong to the bn"
);
161
}
162
}
163
164
// check that the joint_target set does not contain the new target
165
if
(
_joint_targets_
.
contains
(
joint_target
)) {
166
// note that we have to be in target mode when we are here
167
// so, no this->setTargetedMode_(); is necessary
168
onJointTargetErased_
(
joint_target
);
169
_joint_targets_
.
erase
(
joint_target
);
170
this
->
setState_
(
GraphicalModelInference
<
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
171
}
172
}
173
174
175
/// returns the list of target sets
176
template
<
typename
GUM_SCALAR
>
177
INLINE
const
Set
<
NodeSet
>&
JointTargetedInference
<
GUM_SCALAR
>::
jointTargets
()
const
noexcept
{
178
return
_joint_targets_
;
179
}
180
181
/// returns the number of target sets
182
template
<
typename
GUM_SCALAR
>
183
INLINE
Size
JointTargetedInference
<
GUM_SCALAR
>::
nbrJointTargets
()
const
noexcept
{
184
return
_joint_targets_
.
size
();
185
}
186
187
188
// ##############################################################################
189
// Inference
190
// ##############################################################################
191
192
// Compute the posterior of a nodeset.
193
template
<
typename
GUM_SCALAR
>
194
const
Potential
<
GUM_SCALAR
>&
195
JointTargetedInference
<
GUM_SCALAR
>::
jointPosterior
(
const
NodeSet
&
nodes
) {
196
// try to get the smallest set of targets that contains "nodes"
197
NodeSet
set
;
198
bool
found_exact_target
=
false
;
199
200
if
(
_joint_targets_
.
contains
(
nodes
)) {
201
set
=
nodes
;
202
found_exact_target
=
true
;
203
}
else
{
204
for
(
const
auto
&
target
:
_joint_targets_
) {
205
if
(
nodes
.
isProperSubsetOf
(
target
)) {
206
set
=
target
;
207
break
;
208
}
209
}
210
}
211
212
if
(
set
.
empty
()) {
213
GUM_ERROR
(
UndefinedElement
,
214
" no joint target containing "
<<
nodes
<<
" could be found among "
215
<<
_joint_targets_
);
216
}
217
218
if
(!
this
->
isInferenceDone
()) {
this
->
makeInference
(); }
219
220
if
(
found_exact_target
)
221
return
jointPosterior_
(
nodes
);
222
else
223
return
jointPosterior_
(
nodes
,
set
);
224
}
225
226
227
// Compute the posterior of a node
228
template
<
typename
GUM_SCALAR
>
229
const
Potential
<
GUM_SCALAR
>&
JointTargetedInference
<
GUM_SCALAR
>::
posterior
(
NodeId
node
) {
230
if
(
this
->
isTarget
(
node
))
231
return
MarginalTargetedInference
<
GUM_SCALAR
>::
posterior
(
node
);
232
else
233
return
jointPosterior
(
NodeSet
{
node
});
234
}
235
236
// Compute the posterior of a node
237
template
<
typename
GUM_SCALAR
>
238
const
Potential
<
GUM_SCALAR
>&
239
JointTargetedInference
<
GUM_SCALAR
>::
posterior
(
const
std
::
string
&
nodeName
) {
240
return
posterior
(
this
->
BN
().
idFromName
(
nodeName
));
241
}
242
243
// ##############################################################################
244
// Mutual Information
245
// ##############################################################################
246
template
<
typename
GUM_SCALAR
>
247
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
I
(
const
std
::
string
&
Xname
,
248
const
std
::
string
&
Yname
) {
249
return
I
(
this
->
BN
().
idFromName
(
Xname
),
this
->
BN
().
idFromName
(
Yname
));
250
}
251
252
template
<
typename
GUM_SCALAR
>
253
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
VI
(
const
std
::
string
&
Xname
,
254
const
std
::
string
&
Yname
) {
255
return
VI
(
this
->
BN
().
idFromName
(
Xname
),
this
->
BN
().
idFromName
(
Yname
));
256
}
257
258
/* Mutual information between X and Y
259
*
260
* @see http://en.wikipedia.org/wiki/Mutual_information
261
*
262
* @warning Due to limitation of @joint, may not be able to compute this value
263
* @throw OperationNotAllowed in these cases
264
*/
265
template
<
typename
GUM_SCALAR
>
266
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
I
(
NodeId
X
,
NodeId
Y
) {
267
Potential
<
GUM_SCALAR
>
pX
,
pY
, *
pXY
=
nullptr
;
268
if
(
X
==
Y
) {
GUM_ERROR
(
OperationNotAllowed
,
"Mutual Information I(X,Y) with X==Y"
) }
269
270
try
{
271
// here use unnormalized joint posterior rather than just posterior
272
// to avoid saving the posterior in the cache of the inference engines
273
// like LazyPropagation or SahferShenoy.
274
pXY
=
this
->
unnormalizedJointPosterior_
({
X
,
Y
});
275
pXY
->
normalize
();
276
pX
=
pXY
->
margSumOut
({&(
this
->
BN
().
variable
(
Y
))});
277
pY
=
pXY
->
margSumOut
({&(
this
->
BN
().
variable
(
X
))});
278
}
catch
(...) {
279
if
(
pXY
!=
nullptr
) {
delete
pXY
; }
280
throw
;
281
}
282
283
Instantiation
i
(*
pXY
);
284
auto
res
= (
GUM_SCALAR
)0;
285
286
for
(
i
.
setFirst
(); !
i
.
end
(); ++
i
) {
287
GUM_SCALAR
vXY
= (*
pXY
)[
i
];
288
GUM_SCALAR
vX
=
pX
[
i
];
289
GUM_SCALAR
vY
=
pY
[
i
];
290
291
if
(
vXY
> (
GUM_SCALAR
)0) {
292
if
(
vX
== (
GUM_SCALAR
)0 ||
vY
== (
GUM_SCALAR
)0) {
293
GUM_ERROR
(
OperationNotAllowed
,
294
"Mutual Information (X,Y) with P(X)=0 or P(Y)=0 "
295
"and P(X,Y)>0"
);
296
}
297
298
res
+=
vXY
* (
std
::
log2
(
vXY
) -
std
::
log2
(
vX
) -
std
::
log2
(
vY
));
299
}
300
}
301
302
delete
pXY
;
303
304
return
res
;
305
}
306
307
308
/** Variation of information between X and Y
309
* @see http://en.wikipedia.org/wiki/Variation_of_information
310
*
311
* @warning Due to limitation of @joint, may not be able to compute this value
312
* @throw OperationNotAllowed in these cases
313
*/
314
template
<
typename
GUM_SCALAR
>
315
INLINE
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
VI
(
NodeId
X
,
NodeId
Y
) {
316
return
this
->
H
(
X
) +
this
->
H
(
Y
) - 2 *
I
(
X
,
Y
);
317
}
318
319
320
template
<
typename
GUM_SCALAR
>
321
Potential
<
GUM_SCALAR
>
322
JointTargetedInference
<
GUM_SCALAR
>::
evidenceJointImpact
(
const
NodeSet
&
targets
,
323
const
NodeSet
&
evs
) {
324
if
(!(
evs
*
targets
).
empty
()) {
325
GUM_ERROR
(
InvalidArgument
,
326
"Targets ("
<<
targets
<<
") can not intersect evs ("
<<
evs
<<
")."
);
327
}
328
auto
condset
=
this
->
BN
().
minimalCondSet
(
targets
,
evs
);
329
330
this
->
eraseAllTargets
();
331
this
->
eraseAllEvidence
();
332
333
Instantiation
iTarget
;
334
Potential
<
GUM_SCALAR
>
res
;
335
for
(
const
auto
&
target
:
targets
) {
336
res
.
add
(
this
->
BN
().
variable
(
target
));
337
iTarget
.
add
(
this
->
BN
().
variable
(
target
));
338
}
339
this
->
addJointTarget
(
targets
);
340
341
for
(
const
auto
&
n
:
condset
) {
342
res
.
add
(
this
->
BN
().
variable
(
n
));
343
this
->
addEvidence
(
n
, 0);
344
}
345
346
Instantiation
inst
(
res
);
347
for
(
inst
.
setFirstOut
(
iTarget
); !
inst
.
end
();
inst
.
incOut
(
iTarget
)) {
348
// inferring
349
for
(
const
auto
&
n
:
condset
)
350
this
->
chgEvidence
(
n
,
inst
.
val
(
this
->
BN
().
variable
(
n
)));
351
this
->
makeInference
();
352
// populate res
353
for
(
inst
.
setFirstIn
(
iTarget
); !
inst
.
end
();
inst
.
incIn
(
iTarget
)) {
354
res
.
set
(
inst
,
this
->
jointPosterior
(
targets
)[
inst
]);
355
}
356
inst
.
setFirstIn
(
iTarget
);
// remove inst.end() flag
357
}
358
359
return
res
;
360
}
361
362
template
<
typename
GUM_SCALAR
>
363
Potential
<
GUM_SCALAR
>
JointTargetedInference
<
GUM_SCALAR
>::
evidenceJointImpact
(
364
const
std
::
vector
<
std
::
string
>&
targets
,
365
const
std
::
vector
<
std
::
string
>&
evs
) {
366
const
auto
&
bn
=
this
->
BN
();
367
return
evidenceJointImpact
(
bn
.
nodeset
(
targets
),
bn
.
nodeset
(
evs
));
368
}
369
370
371
template
<
typename
GUM_SCALAR
>
372
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
jointMutualInformation
(
const
NodeSet
&
targets
) {
373
const
auto
&
bn
=
this
->
BN
();
374
const
Size
siz
=
targets
.
size
();
375
if
(
siz
<= 1) {
376
GUM_ERROR
(
InvalidArgument
,
377
"jointMutualInformation needs at least 2 variables (targets="
<<
targets
<<
")"
);
378
}
379
380
this
->
eraseAllTargets
();
381
this
->
eraseAllEvidence
();
382
this
->
addJointTarget
(
targets
);
383
this
->
makeInference
();
384
const
auto
po
=
this
->
jointPosterior
(
targets
);
385
386
gum
::
Instantiation
caracteristic
;
387
gum
::
Instantiation
variables
;
388
for
(
const
auto
nod
:
targets
) {
389
const
auto
&
var
=
bn
.
variable
(
nod
);
390
auto
pv
=
new
gum
::
RangeVariable
(
var
.
name
(),
""
, 0, 1);
391
caracteristic
.
add
(*
pv
);
392
variables
.
add
(
var
);
393
}
394
395
Set
<
const
DiscreteVariable
* >
sov
;
396
397
const
GUM_SCALAR
start
= (
siz
% 2 == 0) ?
GUM_SCALAR
(-1.0) :
GUM_SCALAR
(1.0);
398
GUM_SCALAR
sign
;
399
GUM_SCALAR
res
=
GUM_SCALAR
(0.0);
400
401
caracteristic
.
setFirst
();
402
for
(
caracteristic
.
inc
(); !
caracteristic
.
end
();
caracteristic
.
inc
()) {
403
sov
.
clear
();
404
sign
=
start
;
405
for
(
Idx
i
= 0;
i
<
caracteristic
.
nbrDim
();
i
++) {
406
if
(
caracteristic
.
val
(
i
) == 1) {
407
sign
= -
sign
;
408
sov
.
insert
(&
variables
.
variable
(
i
));
409
}
410
}
411
res
+=
sign
*
po
.
margSumIn
(
sov
).
entropy
();
412
}
413
414
for
(
Idx
i
= 0;
i
<
caracteristic
.
nbrDim
();
i
++) {
415
delete
&
caracteristic
.
variable
(
i
);
416
}
417
418
return
res
;
419
}
420
421
template
<
typename
GUM_SCALAR
>
422
GUM_SCALAR
JointTargetedInference
<
GUM_SCALAR
>::
jointMutualInformation
(
423
const
std
::
vector
<
std
::
string
>&
targets
) {
424
return
jointMutualInformation
(
this
->
BN
().
ids
(
targets
));
425
}
426
427
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643