aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
correctedMutualInformation_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief The class computing n times the corrected mutual information,
25
* as used in the 3off2 algorithm
26
*
27
* @author Quentin FALCAND, Christophe GONZALES(@AMU) and Pierre-Henri
28
* WUILLEMIN(@LIP6).
29
*/
30
#
ifndef
DOXYGEN_SHOULD_SKIP_THIS
31
32
namespace
gum
{
33
34
namespace
learning
{
35
36
/// returns the allocator used by the score
37
template
<
template
<
typename
>
class
ALLOC >
38
typename
CorrectedMutualInformation<
ALLOC
>::
allocator_type
39
CorrectedMutualInformation
<
ALLOC
>::
getAllocator
()
const
{
40
return
_NH_
.
getAllocator
();
41
}
42
43
44
/// default constructor
45
template
<
template
<
typename
>
class
ALLOC
>
46
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
47
const
DBRowGeneratorParser
<
ALLOC
>&
parser
,
48
const
Apriori
<
ALLOC
>&
apriori
,
49
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
50
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
ranges
,
51
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
nodeId2columns
,
52
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
) :
53
_NH_
(
parser
,
apriori
,
ranges
,
nodeId2columns
,
alloc
),
54
_k_NML_
(
parser
,
apriori
,
ranges
,
nodeId2columns
,
alloc
),
55
_score_MDL_
(
parser
,
apriori
,
ranges
,
nodeId2columns
,
alloc
),
_ICache_
(
alloc
),
56
_KCache_
(
alloc
) {
57
GUM_CONSTRUCTOR
(
CorrectedMutualInformation
);
58
}
59
60
61
/// default constructor
62
template
<
template
<
typename
>
class
ALLOC
>
63
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
64
const
DBRowGeneratorParser
<
ALLOC
>&
parser
,
65
const
Apriori
<
ALLOC
>&
apriori
,
66
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
nodeId2columns
,
67
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
) :
68
_NH_
(
parser
,
apriori
,
nodeId2columns
,
alloc
),
69
_k_NML_
(
parser
,
apriori
,
nodeId2columns
,
alloc
),
70
_score_MDL_
(
parser
,
apriori
,
nodeId2columns
,
alloc
),
_ICache_
(
alloc
),
_KCache_
(
alloc
) {
71
GUM_CONSTRUCTOR
(
CorrectedMutualInformation
);
72
}
73
74
75
/// copy constructor with a given allocator
76
template
<
template
<
typename
>
class
ALLOC
>
77
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
78
const
CorrectedMutualInformation
<
ALLOC
>&
from
,
79
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
) :
80
_NH_
(
from
.
_NH_
,
alloc
),
81
_k_NML_
(
from
.
_k_NML_
,
alloc
),
_score_MDL_
(
from
.
_score_MDL_
,
alloc
),
_kmode_
(
from
.
_kmode_
),
82
_use_ICache_
(
from
.
_use_ICache_
),
_use_HCache_
(
from
.
_use_HCache_
),
83
_use_KCache_
(
from
.
_use_KCache_
),
_use_CnrCache_
(
from
.
_use_CnrCache_
),
84
_ICache_
(
from
.
_ICache_
,
alloc
),
_KCache_
(
from
.
_KCache_
,
alloc
) {
85
GUM_CONS_CPY
(
CorrectedMutualInformation
);
86
}
87
88
89
/// copy constructor
90
template
<
template
<
typename
>
class
ALLOC
>
91
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
92
const
CorrectedMutualInformation
<
ALLOC
>&
from
) :
93
CorrectedMutualInformation
(
from
,
from
.
getAllocator
()) {}
94
95
96
/// move constructor with a given allocator
97
template
<
template
<
typename
>
class
ALLOC
>
98
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
99
CorrectedMutualInformation
<
ALLOC
>&&
from
,
100
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
) :
101
_NH_
(
std
::
move
(
from
.
_NH_
),
alloc
),
102
_k_NML_
(
std
::
move
(
from
.
_k_NML_
),
alloc
),
_score_MDL_
(
std
::
move
(
from
.
_score_MDL_
),
alloc
),
103
_kmode_
(
from
.
_kmode_
),
_use_ICache_
(
from
.
_use_ICache_
),
_use_HCache_
(
from
.
_use_HCache_
),
104
_use_KCache_
(
from
.
_use_KCache_
),
_use_CnrCache_
(
from
.
_use_CnrCache_
),
105
_ICache_
(
std
::
move
(
from
.
_ICache_
),
alloc
),
_KCache_
(
std
::
move
(
from
.
_KCache_
),
alloc
) {
106
GUM_CONS_MOV
(
CorrectedMutualInformation
);
107
}
108
109
110
/// move constructor
111
template
<
template
<
typename
>
class
ALLOC
>
112
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
113
CorrectedMutualInformation
<
ALLOC
>&&
from
) :
114
CorrectedMutualInformation
(
std
::
move
(
from
),
from
.
getAllocator
()) {}
115
116
117
/// virtual copy constructor with a given allocator
118
template
<
template
<
typename
>
class
ALLOC
>
119
CorrectedMutualInformation
<
ALLOC
>*
CorrectedMutualInformation
<
ALLOC
>::
clone
(
120
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
)
const
{
121
ALLOC
<
CorrectedMutualInformation
<
ALLOC
> >
allocator
(
alloc
);
122
CorrectedMutualInformation
<
ALLOC
>*
new_score
=
allocator
.
allocate
(1);
123
try
{
124
allocator
.
construct
(
new_score
, *
this
,
alloc
);
125
}
catch
(...) {
126
allocator
.
deallocate
(
new_score
, 1);
127
throw
;
128
}
129
130
return
new_score
;
131
}
132
133
134
/// virtual copy constructor
135
template
<
template
<
typename
>
class
ALLOC
>
136
CorrectedMutualInformation
<
ALLOC
>*
CorrectedMutualInformation
<
ALLOC
>::
clone
()
const
{
137
return
clone
(
this
->
getAllocator
());
138
}
139
140
141
/// destructor
142
template
<
template
<
typename
>
class
ALLOC
>
143
CorrectedMutualInformation
<
ALLOC
>::~
CorrectedMutualInformation
() {
144
// for debugging purposes
145
GUM_DESTRUCTOR
(
CorrectedMutualInformation
);
146
}
147
148
149
/// copy operator
150
template
<
template
<
typename
>
class
ALLOC
>
151
CorrectedMutualInformation
<
ALLOC
>&
CorrectedMutualInformation
<
ALLOC
>::
operator
=(
152
const
CorrectedMutualInformation
<
ALLOC
>&
from
) {
153
if
(
this
!= &
from
) {
154
_NH_
=
from
.
_NH_
;
155
_k_NML_
=
from
.
_k_NML_
;
156
_score_MDL_
=
from
.
_score_MDL_
;
157
_kmode_
=
from
.
_kmode_
;
158
_use_ICache_
=
from
.
_use_ICache_
;
159
_use_HCache_
=
from
.
_use_HCache_
;
160
_use_KCache_
=
from
.
_use_KCache_
;
161
_use_CnrCache_
=
from
.
_use_CnrCache_
;
162
_ICache_
=
from
.
_ICache_
;
163
_KCache_
=
from
.
_KCache_
;
164
}
165
return
*
this
;
166
}
167
168
169
/// move operator
170
template
<
template
<
typename
>
class
ALLOC
>
171
CorrectedMutualInformation
<
ALLOC
>&
172
CorrectedMutualInformation
<
ALLOC
>::
operator
=(
CorrectedMutualInformation
<
ALLOC
>&&
from
) {
173
if
(
this
!= &
from
) {
174
_NH_
=
std
::
move
(
from
.
_NH_
);
175
_k_NML_
=
std
::
move
(
from
.
_k_NML_
);
176
_score_MDL_
=
std
::
move
(
from
.
_score_MDL_
);
177
_kmode_
=
from
.
_kmode_
;
178
_use_ICache_
=
from
.
_use_ICache_
;
179
_use_HCache_
=
from
.
_use_HCache_
;
180
_use_KCache_
=
from
.
_use_KCache_
;
181
_use_CnrCache_
=
from
.
_use_CnrCache_
;
182
_ICache_
=
std
::
move
(
from
.
_ICache_
);
183
_KCache_
=
std
::
move
(
from
.
_KCache_
);
184
}
185
return
*
this
;
186
}
187
188
189
/// turn on/off the use of all the caches
190
template
<
template
<
typename
>
class
ALLOC
>
191
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useCache
(
bool
on_off
) {
192
useICache
(
on_off
);
193
useHCache
(
on_off
);
194
useKCache
(
on_off
);
195
useCnrCache
(
on_off
);
196
}
197
198
199
/// turn on/off the use of the I cache
200
template
<
template
<
typename
>
class
ALLOC
>
201
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useICache
(
bool
on_off
) {
202
if
(!
on_off
)
_ICache_
.
clear
();
203
_use_ICache_
=
on_off
;
204
}
205
206
207
/// turn on/off the use of the H cache
208
template
<
template
<
typename
>
class
ALLOC
>
209
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useHCache
(
bool
on_off
) {
210
if
(!
on_off
)
_NH_
.
clearCache
();
211
_use_HCache_
=
on_off
;
212
_NH_
.
useCache
(
on_off
);
213
}
214
215
216
/// turn on/off the use of the K cache
217
template
<
template
<
typename
>
class
ALLOC
>
218
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useKCache
(
bool
on_off
) {
219
if
(!
on_off
)
_KCache_
.
clear
();
220
_use_KCache_
=
on_off
;
221
}
222
223
224
/// turn on/off the use of the Cnr cache
225
template
<
template
<
typename
>
class
ALLOC
>
226
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useCnrCache
(
bool
on_off
) {
227
if
(!
on_off
)
_k_NML_
.
clearCache
();
228
_use_CnrCache_
=
on_off
;
229
_k_NML_
.
useCache
(
on_off
);
230
}
231
232
233
/// clears all the data structures from memory
234
template
<
template
<
typename
>
class
ALLOC
>
235
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clear
() {
236
_NH_
.
clear
();
237
_k_NML_
.
clear
();
238
_score_MDL_
.
clear
();
239
clearCache
();
240
}
241
242
243
/// clears the current cache (clear nodesets as well)
244
template
<
template
<
typename
>
class
ALLOC
>
245
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearCache
() {
246
_NH_
.
clearCache
();
247
_k_NML_
.
clearCache
();
248
_ICache_
.
clear
();
249
_KCache_
.
clear
();
250
}
251
252
253
/// clears the ICache (the mutual information cache)
254
template
<
template
<
typename
>
class
ALLOC
>
255
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearICache
() {
256
_ICache_
.
clear
();
257
}
258
259
260
/// clears the HCache (the cache for the entropies)
261
template
<
template
<
typename
>
class
ALLOC
>
262
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearHCache
() {
263
_NH_
.
clearCache
();
264
}
265
266
267
/// clears the KCache (the cache for the penalties)
268
template
<
template
<
typename
>
class
ALLOC
>
269
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearKCache
() {
270
_KCache_
.
clear
();
271
}
272
273
274
/// clears the CnrCache (the cache for the Cnr formula)
275
template
<
template
<
typename
>
class
ALLOC
>
276
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearCnrCache
() {
277
_k_NML_
.
clearCache
();
278
}
279
280
281
/// changes the max number of threads used to parse the database
282
template
<
template
<
typename
>
class
ALLOC
>
283
void
CorrectedMutualInformation
<
ALLOC
>::
setMaxNbThreads
(
std
::
size_t
nb
)
const
{
284
_NH_
.
setMaxNbThreads
(
nb
);
285
_k_NML_
.
setMaxNbThreads
(
nb
);
286
_score_MDL_
.
setMaxNbThreads
(
nb
);
287
}
288
289
290
/// returns the number of threads used to parse the database
291
template
<
template
<
typename
>
class
ALLOC
>
292
std
::
size_t
CorrectedMutualInformation
<
ALLOC
>::
nbThreads
()
const
{
293
return
_NH_
.
nbThreads
();
294
}
295
296
297
/** @brief changes the number min of rows a thread should process in a
298
* multithreading context */
299
template
<
template
<
typename
>
class
ALLOC
>
300
void
CorrectedMutualInformation
<
ALLOC
>::
setMinNbRowsPerThread
(
const
std
::
size_t
nb
)
const
{
301
_NH_
.
setMinNbRowsPerThread
(
nb
);
302
_k_NML_
.
setMinNbRowsPerThread
(
nb
);
303
_score_MDL_
.
setMinNbRowsPerThread
(
nb
);
304
}
305
306
307
/// returns the minimum of rows that each thread should process
308
template
<
template
<
typename
>
class
ALLOC
>
309
INLINE
std
::
size_t
CorrectedMutualInformation
<
ALLOC
>::
minNbRowsPerThread
()
const
{
310
return
_NH_
.
minNbRowsPerThread
();
311
}
312
313
314
/// sets new ranges to perform the countings used by the score
315
/** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
316
* indices. The countings are then performed only on the union of the
317
* rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
318
* cross validation tasks, in which part of the database should be ignored.
319
* An empty set of ranges is equivalent to an interval [X,Y) ranging over
320
* the whole database. */
321
template
<
template
<
typename
>
class
ALLOC
>
322
template
<
template
<
typename
>
class
XALLOC
>
323
void
CorrectedMutualInformation
<
ALLOC
>::
setRanges
(
324
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
325
XALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
new_ranges
) {
326
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
327
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >
328
old_ranges
=
ranges
();
329
330
_NH_
.
setRanges
(
new_ranges
);
331
_k_NML_
.
setRanges
(
new_ranges
);
332
_score_MDL_
.
setRanges
(
new_ranges
);
333
334
if
(
old_ranges
!=
ranges
())
clear
();
335
}
336
337
338
/// reset the ranges to the one range corresponding to the whole database
339
template
<
template
<
typename
>
class
ALLOC
>
340
void
CorrectedMutualInformation
<
ALLOC
>::
clearRanges
() {
341
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
342
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >
343
old_ranges
=
ranges
();
344
_NH_
.
clearRanges
();
345
_k_NML_
.
clearRanges
();
346
_score_MDL_
.
clearRanges
();
347
if
(
old_ranges
!=
ranges
())
clear
();
348
}
349
350
351
/// returns the current ranges
352
template
<
template
<
typename
>
class
ALLOC
>
353
INLINE
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
354
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
355
CorrectedMutualInformation
<
ALLOC
>::
ranges
()
const
{
356
return
_NH_
.
ranges
();
357
}
358
359
360
/// use the MDL penalty function
361
template
<
template
<
typename
>
class
ALLOC
>
362
void
CorrectedMutualInformation
<
ALLOC
>::
useMDL
() {
363
clearCache
();
364
_kmode_
=
KModeTypes
::
MDL
;
365
}
366
367
368
/// use the kNML penalty function
369
template
<
template
<
typename
>
class
ALLOC
>
370
void
CorrectedMutualInformation
<
ALLOC
>::
useNML
() {
371
clearCache
();
372
_kmode_
=
KModeTypes
::
NML
;
373
}
374
375
376
/// use no correction/penalty function
377
template
<
template
<
typename
>
class
ALLOC
>
378
void
CorrectedMutualInformation
<
ALLOC
>::
useNoCorr
() {
379
clearCache
();
380
_kmode_
=
KModeTypes
::
NoCorr
;
381
}
382
383
384
/// returns the 2-point mutual information corresponding to a given nodeset
385
template
<
template
<
typename
>
class
ALLOC
>
386
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
score
(
NodeId
var1
,
NodeId
var2
) {
387
return
score
(
var1
,
var2
,
_empty_conditioning_set_
);
388
}
389
390
391
/// returns the 2-point mutual information corresponding to a given nodeset
392
template
<
template
<
typename
>
class
ALLOC
>
393
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
score
(
394
NodeId
var1
,
395
NodeId
var2
,
396
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_ids
) {
397
return
_NI_score_
(
var1
,
var2
,
conditioning_ids
) -
_K_score_
(
var1
,
var2
,
conditioning_ids
);
398
}
399
400
401
/// returns the 3-point mutual information corresponding to a given nodeset
402
template
<
template
<
typename
>
class
ALLOC
>
403
INLINE
double
404
CorrectedMutualInformation
<
ALLOC
>::
score
(
NodeId
var1
,
NodeId
var2
,
NodeId
var3
) {
405
return
score
(
var1
,
var2
,
var3
,
_empty_conditioning_set_
);
406
}
407
408
409
/// returns the 3-point mutual information corresponding to a given nodeset
410
template
<
template
<
typename
>
class
ALLOC
>
411
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
score
(
412
NodeId
var1
,
413
NodeId
var2
,
414
NodeId
var3
,
415
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_ids
) {
416
return
_NI_score_
(
var1
,
var2
,
var3
,
conditioning_ids
)
417
+
_K_score_
(
var1
,
var2
,
var3
,
conditioning_ids
);
418
}
419
420
421
/// return N times the mutual information for conditioned pairs of variables
422
template
<
template
<
typename
>
class
ALLOC
>
423
double
CorrectedMutualInformation
<
ALLOC
>::
_NI_score_
(
424
NodeId
var_x
,
425
NodeId
var_y
,
426
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
vars_z
) {
427
/*
428
* We have a few partial entropies to compute in order to have the
429
* 2-point mutual information:
430
* I(x;y) = H(x) + H(y) - H(x,y)
431
* correspondingly
432
* I(x;y) = Hx + Hy - Hxy
433
* or
434
* I(x;y|z) = H(x,z) + H(y,z) - H(z) - H(x,y,z)
435
* correspondingly
436
* I(x;y|z) = Hxz + Hyz - Hz - Hxyz
437
* Note that Entropy H is equal to 1/N times the log2Likelihood,
438
* where N is the size of the database.
439
* Remember that we return N times I(x;y|z)
440
*/
441
442
// if the score has already been computed, get its value
443
const
IdCondSet
<
ALLOC
>
idset_xyz
(
var_x
,
var_y
,
vars_z
,
false
,
false
);
444
if
(
_use_ICache_
) {
445
try
{
446
return
_ICache_
.
score
(
idset_xyz
);
447
}
catch
(
const
NotFound
&) {}
448
}
449
450
// compute the score
451
452
// here, we distinguish nodesets with conditioning nodes from those
453
// without conditioning nodes
454
double
score
;
455
if
(!
vars_z
.
empty
()) {
456
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >
vars
(
vars_z
);
457
// std::sort(vars.begin(), vars.end());
458
vars
.
push_back
(
var_x
);
459
vars
.
push_back
(
var_y
);
460
const
double
NHxyz
= -
_NH_
.
score
(
IdCondSet
<
ALLOC
>(
vars
,
false
,
true
));
461
462
vars
.
pop_back
();
463
const
double
NHxz
= -
_NH_
.
score
(
IdCondSet
<
ALLOC
>(
vars
,
false
,
true
));
464
465
vars
.
pop_back
();
466
vars
.
push_back
(
var_y
);
467
const
double
NHyz
= -
_NH_
.
score
(
IdCondSet
<
ALLOC
>(
vars
,
false
,
true
));
468
469
vars
.
pop_back
();
470
const
double
NHz
= -
_NH_
.
score
(
IdCondSet
<
ALLOC
>(
vars
,
false
,
true
));
471
472
const
double
NHxz_NHyz
=
NHxz
+
NHyz
;
473
double
NHz_NHxyz
=
NHz
+
NHxyz
;
474
475
// avoid numeric instability due to rounding errors
476
double
ratio
= 1;
477
if
(
NHxz_NHyz
> 0) {
478
ratio
= (
NHxz_NHyz
-
NHz_NHxyz
) /
NHxz_NHyz
;
479
}
else
if
(
NHz_NHxyz
> 0) {
480
ratio
= (
NHxz_NHyz
-
NHz_NHxyz
) /
NHz_NHxyz
;
481
}
482
if
(
ratio
< 0)
ratio
= -
ratio
;
483
if
(
ratio
<
_threshold_
) {
484
NHz_NHxyz
=
NHxz_NHyz
;
// ensure that the score is equal to 0
485
}
486
487
score
=
NHxz_NHyz
-
NHz_NHxyz
;
488
}
else
{
489
const
double
NHxy
490
= -
_NH_
.
score
(
IdCondSet
<
ALLOC
>(
var_x
,
var_y
,
_empty_conditioning_set_
,
true
,
false
));
491
const
double
NHx
= -
_NH_
.
score
(
var_x
);
492
const
double
NHy
= -
_NH_
.
score
(
var_y
);
493
494
double
NHx_NHy
=
NHx
+
NHy
;
495
496
// avoid numeric instability due to rounding errors
497
double
ratio
= 1;
498
if
(
NHx_NHy
> 0) {
499
ratio
= (
NHx_NHy
-
NHxy
) /
NHx_NHy
;
500
}
else
if
(
NHxy
> 0) {
501
ratio
= (
NHx_NHy
-
NHxy
) /
NHxy
;
502
}
503
if
(
ratio
< 0)
ratio
= -
ratio
;
504
if
(
ratio
<
_threshold_
) {
505
NHx_NHy
=
NHxy
;
// ensure that the score is equal to 0
506
}
507
508
score
=
NHx_NHy
-
NHxy
;
509
}
510
511
512
// shall we put the score into the cache?
513
if
(
_use_ICache_
) {
_ICache_
.
insert
(
idset_xyz
,
score
); }
514
515
return
score
;
516
}
517
518
519
/// return N times the mutual information for conditioned triples of variables
520
template
<
template
<
typename
>
class
ALLOC
>
521
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
_NI_score_
(
522
NodeId
var_x
,
523
NodeId
var_y
,
524
NodeId
var_z
,
525
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
ui_ids
) {
526
// conditional 3-point mutual information formula:
527
// I(x;y;z|{ui}) = I(x;y|{ui}) - I(x;y|z,{ui})
528
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >
uiz_ids
=
ui_ids
;
529
uiz_ids
.
push_back
(
var_z
);
530
return
_NI_score_
(
var_x
,
var_y
,
ui_ids
) -
_NI_score_
(
var_x
,
var_y
,
uiz_ids
);
531
}
532
533
534
/// 2pt penalty
535
template
<
template
<
typename
>
class
ALLOC
>
536
double
CorrectedMutualInformation
<
ALLOC
>::
_K_score_
(
537
NodeId
var1
,
538
NodeId
var2
,
539
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_ids
) {
540
// if no penalty, return 0
541
if
(
_kmode_
==
KModeTypes
::
NoCorr
)
return
0.0;
542
543
544
// If using the K cache, verify whether the set isn't already known
545
IdCondSet
<
ALLOC
>
idset
;
546
if
(
_use_KCache_
) {
547
idset
=
std
::
move
(
IdCondSet
<
ALLOC
>(
var1
,
var2
,
conditioning_ids
,
false
));
548
try
{
549
return
_KCache_
.
score
(
idset
);
550
}
catch
(
const
NotFound
&) {}
551
}
552
553
// compute the score
554
double
score
;
555
size_t
rx
,
ry
,
rui
;
556
switch
(
_kmode_
) {
557
case
KModeTypes
::
MDL
: {
558
const
auto
&
database
=
_NH_
.
database
();
559
const
auto
&
node2cols
=
_NH_
.
nodeId2Columns
();
560
561
rui
= 1;
562
if
(!
node2cols
.
empty
()) {
563
rx
=
database
.
domainSize
(
node2cols
.
second
(
var1
));
564
ry
=
database
.
domainSize
(
node2cols
.
second
(
var2
));
565
for
(
const
NodeId
i
:
conditioning_ids
) {
566
rui
*=
database
.
domainSize
(
node2cols
.
second
(
i
));
567
}
568
}
else
{
569
rx
=
database
.
domainSize
(
var1
);
570
ry
=
database
.
domainSize
(
var2
);
571
for
(
const
NodeId
i
:
conditioning_ids
) {
572
rui
*=
database
.
domainSize
(
i
);
573
}
574
}
575
576
// compute the size of the database, including the a priori
577
if
(!
_use_KCache_
) {
578
idset
=
std
::
move
(
IdCondSet
<
ALLOC
>(
var1
,
var2
,
conditioning_ids
,
false
));
579
}
580
const
double
N
=
_score_MDL_
.
N
(
idset
);
581
582
score
= 0.5 * (
rx
- 1) * (
ry
- 1) *
rui
*
std
::
log2
(
N
);
583
}
break
;
584
585
case
KModeTypes
::
NML
:
586
score
=
_k_NML_
.
score
(
var1
,
var2
,
conditioning_ids
);
587
break
;
588
589
default
:
590
GUM_ERROR
(
NotImplementedYet
,
591
"CorrectedMutualInformation mode does "
592
"not support yet this correction"
);
593
}
594
595
// shall we put the score into the cache?
596
if
(
_use_KCache_
) {
_KCache_
.
insert
(
idset
,
score
); }
597
return
score
;
598
}
599
600
601
/// 3pt penalty
602
template
<
template
<
typename
>
class
ALLOC
>
603
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
_K_score_
(
604
NodeId
var1
,
605
NodeId
var2
,
606
NodeId
var3
,
607
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
ui_ids
) {
608
// k(x;y;z|ui) = k(x;y|ui,z) - k(x;y|ui)
609
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >
uiz_ids
=
ui_ids
;
610
uiz_ids
.
push_back
(
var3
);
611
return
_K_score_
(
var1
,
var2
,
uiz_ids
) -
_K_score_
(
var1
,
var2
,
ui_ids
);
612
}
613
614
615
}
/* namespace learning */
616
617
}
/* namespace gum */
618
619
#
endif
/* DOXYGEN_SHOULD_SKIP_THIS */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31