aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
jointTargetedMNInference_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of the non pure virtual methods of class
25
* JointTargetedMNInference.
26
*/
27
#
include
<
agrum
/
MN
/
inference
/
tools
/
jointTargetedMNInference
.
h
>
28
#
include
<
agrum
/
tools
/
variables
/
rangeVariable
.
h
>
29
30
namespace
gum
{
31
32
33
// Default Constructor
34
template
<
typename
GUM_SCALAR >
35
JointTargetedMNInference< GUM_SCALAR >::JointTargetedMNInference(
36
const
IMarkovNet< GUM_SCALAR >* mn) :
37
MarginalTargetedMNInference< GUM_SCALAR >(mn) {
38
// assign a MN if this has not been done before (due to virtual inheritance)
39
if
(
this
->hasNoModel_()) {
40
MarkovNetInference< GUM_SCALAR >::_setMarkovNetDuringConstruction_(mn);
41
}
42
GUM_CONSTRUCTOR(JointTargetedMNInference);
43
}
44
45
46
// Destructor
47
template
<
typename
GUM_SCALAR
>
48
JointTargetedMNInference
<
GUM_SCALAR
>::~
JointTargetedMNInference
() {
49
GUM_DESTRUCTOR
(
JointTargetedMNInference
);
50
}
51
52
53
// assigns a new MN to the inference engine
54
template
<
typename
GUM_SCALAR
>
55
void
JointTargetedMNInference
<
GUM_SCALAR
>::
onModelChanged_
(
const
GraphicalModel
*
mn
) {
56
MarginalTargetedMNInference
<
GUM_SCALAR
>::
onModelChanged_
(
mn
);
57
onAllJointTargetsErased_
();
58
_joint_targets_
.
clear
();
59
}
60
61
62
// ##############################################################################
63
// Targets
64
// ##############################################################################
65
66
// return true if target is a nodeset target.
67
template
<
typename
GUM_SCALAR
>
68
INLINE
bool
JointTargetedMNInference
<
GUM_SCALAR
>::
isJointTarget
(
const
NodeSet
&
vars
)
const
{
69
if
(
this
->
hasNoModel_
())
70
GUM_ERROR
(
NullElement
,
71
"No Markov net has been assigned to the "
72
"inference algorithm"
);
73
74
const
auto
&
gra
=
this
->
MN
().
graph
();
75
for
(
const
auto
var
:
vars
) {
76
if
(!
gra
.
exists
(
var
)) {
77
GUM_ERROR
(
UndefinedElement
,
var
<<
" is not a NodeId in the Markov network"
)
78
}
79
}
80
81
return
_joint_targets_
.
contains
(
vars
);
82
}
83
84
85
// Clear all previously defined single targets
86
template
<
typename
GUM_SCALAR
>
87
INLINE
void
JointTargetedMNInference
<
GUM_SCALAR
>::
eraseAllMarginalTargets
() {
88
MarginalTargetedMNInference
<
GUM_SCALAR
>::
eraseAllTargets
();
89
}
90
91
92
// Clear all previously defined targets (single targets and sets of targets)
93
template
<
typename
GUM_SCALAR
>
94
INLINE
void
JointTargetedMNInference
<
GUM_SCALAR
>::
eraseAllJointTargets
() {
95
if
(
_joint_targets_
.
size
() > 0) {
96
// we already are in target mode. So no this->setTargetedMode_(); is needed
97
onAllJointTargetsErased_
();
98
_joint_targets_
.
clear
();
99
this
->
setState_
(
MarkovNetInference
<
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
100
}
101
}
102
103
104
// Clear all previously defined targets (single and joint targets)
105
template
<
typename
GUM_SCALAR
>
106
INLINE
void
JointTargetedMNInference
<
GUM_SCALAR
>::
eraseAllTargets
() {
107
eraseAllMarginalTargets
();
108
eraseAllJointTargets
();
109
}
110
111
112
// Add a set of nodes as a new target
113
template
<
typename
GUM_SCALAR
>
114
void
JointTargetedMNInference
<
GUM_SCALAR
>::
addJointTarget
(
const
NodeSet
&
joint_target
) {
115
// check if the nodes in the target belong to the Markov network
116
if
(
this
->
hasNoModel_
())
117
GUM_ERROR
(
NullElement
,
118
"No Markov net has been assigned to the "
119
"inference algorithm"
);
120
121
const
auto
&
dag
=
this
->
MN
().
graph
();
122
for
(
const
auto
node
:
joint_target
) {
123
if
(!
dag
.
exists
(
node
)) {
124
GUM_ERROR
(
UndefinedElement
,
125
"at least one one in "
<<
joint_target
<<
" does not belong to the mn"
);
126
}
127
}
128
129
if
(
isExactJointComputable_
(
joint_target
))
return
;
130
if
(!
superForJointComputable_
(
joint_target
).
empty
())
return
;
131
132
// check if joint_target is a subset of an already existing target
133
for
(
const
auto
&
target
:
_joint_targets_
) {
134
if
(
target
.
isProperSupersetOf
(
joint_target
))
return
;
135
}
136
137
// check if joint_target is not a superset of an already existing target
138
// in this case, we need to remove old existing target
139
for
(
auto
iter
=
_joint_targets_
.
beginSafe
();
iter
!=
_joint_targets_
.
endSafe
(); ++
iter
) {
140
if
(
iter
->
isProperSubsetOf
(
joint_target
))
eraseJointTarget
(*
iter
);
141
}
142
143
this
->
setTargetedMode_
();
// does nothing if already in targeted mode
144
_joint_targets_
.
insert
(
joint_target
);
145
onJointTargetAdded_
(
joint_target
);
146
this
->
setState_
(
MarkovNetInference
<
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
147
}
148
149
150
// removes an existing set target
151
template
<
typename
GUM_SCALAR
>
152
void
JointTargetedMNInference
<
GUM_SCALAR
>::
eraseJointTarget
(
const
NodeSet
&
joint_target
) {
153
// check if the nodes in the target belong to the Markov network
154
if
(
this
->
hasNoModel_
())
155
GUM_ERROR
(
NullElement
,
156
"No Markov net has been assigned to the "
157
"inference algorithm"
);
158
159
const
auto
&
dag
=
this
->
MN
().
graph
();
160
for
(
const
auto
node
:
joint_target
) {
161
if
(!
dag
.
exists
(
node
)) {
162
GUM_ERROR
(
UndefinedElement
,
163
"at least one one in "
<<
joint_target
<<
" does not belong to the mn"
);
164
}
165
}
166
167
// check that the joint_target set does not contain the new target
168
if
(
_joint_targets_
.
contains
(
joint_target
)) {
169
// note that we have to be in target mode when we are here
170
// so, no this->setTargetedMode_(); is necessary
171
onJointTargetErased_
(
joint_target
);
172
_joint_targets_
.
erase
(
joint_target
);
173
this
->
setState_
(
MarkovNetInference
<
GUM_SCALAR
>::
StateOfInference
::
OutdatedStructure
);
174
}
175
}
176
177
178
/// returns the list of target sets
179
template
<
typename
GUM_SCALAR
>
180
INLINE
const
Set
<
NodeSet
>&
181
JointTargetedMNInference
<
GUM_SCALAR
>::
jointTargets
()
const
noexcept
{
182
return
_joint_targets_
;
183
}
184
185
/// returns the number of target sets
186
template
<
typename
GUM_SCALAR
>
187
INLINE
Size
JointTargetedMNInference
<
GUM_SCALAR
>::
nbrJointTargets
()
const
noexcept
{
188
return
_joint_targets_
.
size
();
189
}
190
191
192
// ##############################################################################
193
// Inference
194
// ##############################################################################
195
196
// Compute the posterior of a nodeset.
197
template
<
typename
GUM_SCALAR
>
198
const
Potential
<
GUM_SCALAR
>&
199
JointTargetedMNInference
<
GUM_SCALAR
>::
jointPosterior
(
const
NodeSet
&
nodes
) {
200
// try to get the smallest set of targets that contains "nodes"
201
bool
found_exact_target
=
false
;
202
NodeSet
super_target
;
203
204
if
(
isExactJointComputable_
(
nodes
)) {
205
found_exact_target
=
true
;
206
}
else
{
207
super_target
=
superForJointComputable_
(
nodes
);
208
if
(
super_target
.
empty
()) {
209
GUM_ERROR
(
UndefinedElement
,
210
"No joint target containing "
<<
nodes
<<
" could be found among "
211
<<
_joint_targets_
);
212
}
213
}
214
215
if
(!
this
->
isInferenceDone
()) {
this
->
makeInference
(); }
216
217
if
(
found_exact_target
)
218
return
jointPosterior_
(
nodes
);
219
else
220
return
jointPosterior_
(
nodes
,
super_target
);
221
}
222
223
224
// Compute the posterior of a node
225
template
<
typename
GUM_SCALAR
>
226
const
Potential
<
GUM_SCALAR
>&
JointTargetedMNInference
<
GUM_SCALAR
>::
posterior
(
NodeId
node
) {
227
if
(
this
->
isTarget
(
node
))
228
return
MarginalTargetedMNInference
<
GUM_SCALAR
>::
posterior
(
node
);
229
else
230
return
jointPosterior
(
NodeSet
{
node
});
231
}
232
233
// Compute the posterior of a node
234
template
<
typename
GUM_SCALAR
>
235
const
Potential
<
GUM_SCALAR
>&
236
JointTargetedMNInference
<
GUM_SCALAR
>::
posterior
(
const
std
::
string
&
nodeName
) {
237
return
posterior
(
this
->
MN
().
idFromName
(
nodeName
));
238
}
239
240
// ##############################################################################
241
// Entropy
242
// ##############################################################################
243
template
<
typename
GUM_SCALAR
>
244
GUM_SCALAR
JointTargetedMNInference
<
GUM_SCALAR
>::
I
(
const
std
::
string
&
Xname
,
245
const
std
::
string
&
Yname
) {
246
return
I
(
this
->
MN
().
idFromName
(
Xname
),
this
->
MN
().
idFromName
(
Yname
));
247
}
248
249
template
<
typename
GUM_SCALAR
>
250
GUM_SCALAR
JointTargetedMNInference
<
GUM_SCALAR
>::
VI
(
const
std
::
string
&
Xname
,
251
const
std
::
string
&
Yname
) {
252
return
VI
(
this
->
MN
().
idFromName
(
Xname
),
this
->
MN
().
idFromName
(
Yname
));
253
}
254
255
256
/* Mutual information between X and Y
257
* @see http://en.wikipedia.org/wiki/Mutual_information
258
*
259
* @warning Due to limitation of @joint, may not be able to compute this value
260
* @throw OperationNotAllowed in these cases
261
*/
262
template
<
typename
GUM_SCALAR
>
263
GUM_SCALAR
JointTargetedMNInference
<
GUM_SCALAR
>::
I
(
NodeId
X
,
NodeId
Y
) {
264
Potential
<
GUM_SCALAR
>
pX
,
pY
, *
pXY
=
nullptr
;
265
if
(
X
==
Y
) {
GUM_ERROR
(
OperationNotAllowed
,
"Mutual Information I(X,Y) with X==Y"
) }
266
267
try
{
268
// here use unnormalized joint posterior rather than just posterior
269
// to avoid saving the posterior in the cache of the inference engines
270
// like LazyPropagation or Shafer-Shenoy.
271
pXY
=
this
->
unnormalizedJointPosterior_
({
X
,
Y
});
272
pXY
->
normalize
();
273
pX
=
pXY
->
margSumOut
({&(
this
->
MN
().
variable
(
Y
))});
274
pY
=
pXY
->
margSumOut
({&(
this
->
MN
().
variable
(
X
))});
275
}
catch
(...) {
276
if
(
pXY
!=
nullptr
) {
delete
pXY
; }
277
throw
;
278
}
279
280
Instantiation
i
(*
pXY
);
281
auto
res
= (
GUM_SCALAR
)0;
282
283
for
(
i
.
setFirst
(); !
i
.
end
(); ++
i
) {
284
GUM_SCALAR
vXY
= (*
pXY
)[
i
];
285
GUM_SCALAR
vX
=
pX
[
i
];
286
GUM_SCALAR
vY
=
pY
[
i
];
287
288
if
(
vXY
> (
GUM_SCALAR
)0) {
289
if
(
vX
== (
GUM_SCALAR
)0 ||
vY
== (
GUM_SCALAR
)0) {
290
GUM_ERROR
(
OperationNotAllowed
,
291
"Mutual Information (X,Y) with P(X)=0 or P(Y)=0 "
292
"and P(X,Y)>0"
);
293
}
294
295
res
+=
vXY
* (
std
::
log2
(
vXY
) -
std
::
log2
(
vX
) -
std
::
log2
(
vY
));
296
}
297
}
298
299
delete
pXY
;
300
301
return
res
;
302
}
303
304
305
/** Variation of information between X and Y
306
* @see http://en.wikipedia.org/wiki/Variation_of_information
307
*
308
* @warning Due to limitation of @joint, may not be able to compute this value
309
* @throw OperationNotAllowed in these cases
310
*/
311
template
<
typename
GUM_SCALAR
>
312
INLINE
GUM_SCALAR
JointTargetedMNInference
<
GUM_SCALAR
>::
VI
(
NodeId
X
,
NodeId
Y
) {
313
return
this
->
H
(
X
) +
this
->
H
(
Y
) - 2 *
I
(
X
,
Y
);
314
}
315
316
317
template
<
typename
GUM_SCALAR
>
318
Potential
<
GUM_SCALAR
>
319
JointTargetedMNInference
<
GUM_SCALAR
>::
evidenceJointImpact
(
const
NodeSet
&
targets
,
320
const
NodeSet
&
evs
) {
321
if
(!(
evs
*
targets
).
empty
()) {
322
GUM_ERROR
(
InvalidArgument
,
323
"Targets ("
<<
targets
<<
") can not intersect evs ("
<<
evs
<<
")."
);
324
}
325
auto
condset
=
this
->
MN
().
minimalCondSet
(
targets
,
evs
);
326
327
this
->
eraseAllTargets
();
328
this
->
eraseAllEvidence
();
329
330
Instantiation
iTarget
;
331
Potential
<
GUM_SCALAR
>
res
;
332
for
(
const
auto
&
target
:
targets
) {
333
res
.
add
(
this
->
MN
().
variable
(
target
));
334
iTarget
.
add
(
this
->
MN
().
variable
(
target
));
335
}
336
this
->
addJointTarget
(
targets
);
337
338
for
(
const
auto
&
n
:
condset
) {
339
res
.
add
(
this
->
MN
().
variable
(
n
));
340
this
->
addEvidence
(
n
, 0);
341
}
342
343
Instantiation
inst
(
res
);
344
for
(
inst
.
setFirstOut
(
iTarget
); !
inst
.
end
();
inst
.
incOut
(
iTarget
)) {
345
// inferring
346
for
(
const
auto
&
n
:
condset
)
347
this
->
chgEvidence
(
n
,
inst
.
val
(
this
->
MN
().
variable
(
n
)));
348
this
->
makeInference
();
349
// populate res
350
for
(
inst
.
setFirstIn
(
iTarget
); !
inst
.
end
();
inst
.
incIn
(
iTarget
)) {
351
res
.
set
(
inst
,
this
->
jointPosterior
(
targets
)[
inst
]);
352
}
353
inst
.
setFirstIn
(
iTarget
);
// remove inst.end() flag
354
}
355
356
return
res
;
357
}
358
359
template
<
typename
GUM_SCALAR
>
360
Potential
<
GUM_SCALAR
>
JointTargetedMNInference
<
GUM_SCALAR
>::
evidenceJointImpact
(
361
const
std
::
vector
<
std
::
string
>&
targets
,
362
const
std
::
vector
<
std
::
string
>&
evs
) {
363
const
auto
&
mn
=
this
->
MN
();
364
return
evidenceJointImpact
(
mn
.
nodeset
(
targets
),
mn
.
nodeset
(
evs
));
365
}
366
367
368
template
<
typename
GUM_SCALAR
>
369
GUM_SCALAR
370
JointTargetedMNInference
<
GUM_SCALAR
>::
jointMutualInformation
(
const
NodeSet
&
targets
) {
371
const
auto
&
mn
=
this
->
MN
();
372
const
Size
siz
=
targets
.
size
();
373
if
(
siz
<= 1) {
374
GUM_ERROR
(
InvalidArgument
,
375
"jointMutualInformation needs at least 2 variables (targets="
<<
targets
<<
")"
);
376
}
377
378
this
->
eraseAllTargets
();
379
this
->
eraseAllEvidence
();
380
this
->
addJointTarget
(
targets
);
381
this
->
makeInference
();
382
const
auto
po
=
this
->
jointPosterior
(
targets
);
383
384
gum
::
Instantiation
caracteristic
;
385
gum
::
Instantiation
variables
;
386
for
(
const
auto
nod
:
targets
) {
387
const
auto
&
var
=
mn
.
variable
(
nod
);
388
auto
pv
=
new
gum
::
RangeVariable
(
var
.
name
(),
""
, 0, 1);
389
caracteristic
.
add
(*
pv
);
390
variables
.
add
(
var
);
391
}
392
393
Set
<
const
DiscreteVariable
* >
sov
;
394
395
const
GUM_SCALAR
start
= (
siz
% 2 == 0) ?
GUM_SCALAR
(-1.0) :
GUM_SCALAR
(1.0);
396
GUM_SCALAR
sign
;
397
GUM_SCALAR
res
=
GUM_SCALAR
(0.0);
398
399
caracteristic
.
setFirst
();
400
for
(
caracteristic
.
inc
(); !
caracteristic
.
end
();
caracteristic
.
inc
()) {
401
sov
.
clear
();
402
sign
=
start
;
403
for
(
Idx
i
= 0;
i
<
caracteristic
.
nbrDim
();
i
++) {
404
if
(
caracteristic
.
val
(
i
) == 1) {
405
sign
= -
sign
;
406
sov
.
insert
(&
variables
.
variable
(
i
));
407
}
408
}
409
res
+=
sign
*
po
.
margSumIn
(
sov
).
entropy
();
410
}
411
412
for
(
Idx
i
= 0;
i
<
caracteristic
.
nbrDim
();
i
++) {
413
delete
&
caracteristic
.
variable
(
i
);
414
}
415
416
return
res
;
417
}
418
419
template
<
typename
GUM_SCALAR
>
420
GUM_SCALAR
JointTargetedMNInference
<
GUM_SCALAR
>::
jointMutualInformation
(
421
const
std
::
vector
<
std
::
string
>&
targets
) {
422
return
jointMutualInformation
(
this
->
MN
().
nodeset
(
targets
));
423
}
424
425
template
<
typename
GUM_SCALAR
>
426
bool
JointTargetedMNInference
<
GUM_SCALAR
>::
isExactJointComputable_
(
const
NodeSet
&
vars
) {
427
if
(
_joint_targets_
.
contains
(
vars
))
return
true
;
428
429
return
false
;
430
}
431
432
template
<
typename
GUM_SCALAR
>
433
NodeSet
JointTargetedMNInference
<
GUM_SCALAR
>::
superForJointComputable_
(
const
NodeSet
&
vars
) {
434
for
(
const
auto
&
target
:
_joint_targets_
)
435
if
(
vars
.
isProperSubsetOf
(
target
))
return
target
;
436
437
for
(
const
auto
&
factor
:
this
->
MN
().
factors
())
438
if
(
vars
.
isProperSubsetOf
(
factor
.
first
))
return
factor
.
first
;
439
440
return
NodeSet
();
441
}
442
443
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643