aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
correctedMutualInformation_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief The class computing n times the corrected mutual information,
25
* as used in the 3off2 algorithm
26
*
27
* @author Quentin FALCAND, Christophe GONZALES(@AMU) and Pierre-Henri
28
* WUILLEMIN(@LIP6).
29
*/
30
#
ifndef
DOXYGEN_SHOULD_SKIP_THIS
31
32
namespace
gum
{
33
34
namespace
learning
{
35
36
/// returns the allocator used by the score
37
template
<
template
<
typename
>
class
ALLOC >
38
typename
CorrectedMutualInformation<
ALLOC
>::
allocator_type
39
CorrectedMutualInformation
<
ALLOC
>::
getAllocator
()
const
{
40
return
NH__
.
getAllocator
();
41
}
42
43
44
/// default constructor
45
template
<
template
<
typename
>
class
ALLOC
>
46
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
47
const
DBRowGeneratorParser
<
ALLOC
>&
parser
,
48
const
Apriori
<
ALLOC
>&
apriori
,
49
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
50
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
ranges
,
51
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
52
nodeId2columns
,
53
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
) :
54
NH__
(
parser
,
apriori
,
ranges
,
nodeId2columns
,
alloc
),
55
k_NML__
(
parser
,
apriori
,
ranges
,
nodeId2columns
,
alloc
),
56
score_MDL__
(
parser
,
apriori
,
ranges
,
nodeId2columns
,
alloc
),
57
ICache__
(
alloc
),
KCache__
(
alloc
) {
58
GUM_CONSTRUCTOR
(
CorrectedMutualInformation
);
59
}
60
61
62
/// default constructor
63
template
<
template
<
typename
>
class
ALLOC
>
64
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
65
const
DBRowGeneratorParser
<
ALLOC
>&
parser
,
66
const
Apriori
<
ALLOC
>&
apriori
,
67
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
68
nodeId2columns
,
69
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
) :
70
NH__
(
parser
,
apriori
,
nodeId2columns
,
alloc
),
71
k_NML__
(
parser
,
apriori
,
nodeId2columns
,
alloc
),
72
score_MDL__
(
parser
,
apriori
,
nodeId2columns
,
alloc
),
ICache__
(
alloc
),
73
KCache__
(
alloc
) {
74
GUM_CONSTRUCTOR
(
CorrectedMutualInformation
);
75
}
76
77
78
/// copy constructor with a given allocator
79
template
<
template
<
typename
>
class
ALLOC
>
80
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
81
const
CorrectedMutualInformation
<
ALLOC
>&
from
,
82
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
) :
83
NH__
(
from
.
NH__
,
alloc
),
84
k_NML__
(
from
.
k_NML__
,
alloc
),
score_MDL__
(
from
.
score_MDL__
,
alloc
),
85
kmode__
(
from
.
kmode__
),
use_ICache__
(
from
.
use_ICache__
),
86
use_HCache__
(
from
.
use_HCache__
),
use_KCache__
(
from
.
use_KCache__
),
87
use_CnrCache__
(
from
.
use_CnrCache__
),
ICache__
(
from
.
ICache__
,
alloc
),
88
KCache__
(
from
.
KCache__
,
alloc
) {
89
GUM_CONS_CPY
(
CorrectedMutualInformation
);
90
}
91
92
93
/// copy constructor
94
template
<
template
<
typename
>
class
ALLOC
>
95
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
96
const
CorrectedMutualInformation
<
ALLOC
>&
from
) :
97
CorrectedMutualInformation
(
from
,
from
.
getAllocator
()) {}
98
99
100
/// move constructor with a given allocator
101
template
<
template
<
typename
>
class
ALLOC
>
102
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
103
CorrectedMutualInformation
<
ALLOC
>&&
from
,
104
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
alloc
) :
105
NH__
(
std
::
move
(
from
.
NH__
),
alloc
),
106
k_NML__
(
std
::
move
(
from
.
k_NML__
),
alloc
),
107
score_MDL__
(
std
::
move
(
from
.
score_MDL__
),
alloc
),
kmode__
(
from
.
kmode__
),
108
use_ICache__
(
from
.
use_ICache__
),
use_HCache__
(
from
.
use_HCache__
),
109
use_KCache__
(
from
.
use_KCache__
),
use_CnrCache__
(
from
.
use_CnrCache__
),
110
ICache__
(
std
::
move
(
from
.
ICache__
),
alloc
),
111
KCache__
(
std
::
move
(
from
.
KCache__
),
alloc
) {
112
GUM_CONS_MOV
(
CorrectedMutualInformation
);
113
}
114
115
116
/// move constructor
117
template
<
template
<
typename
>
class
ALLOC
>
118
CorrectedMutualInformation
<
ALLOC
>::
CorrectedMutualInformation
(
119
CorrectedMutualInformation
<
ALLOC
>&&
from
) :
120
CorrectedMutualInformation
(
std
::
move
(
from
),
from
.
getAllocator
()) {}
121
122
123
/// virtual copy constructor with a given allocator
124
template
<
template
<
typename
>
class
ALLOC
>
125
CorrectedMutualInformation
<
ALLOC
>*
126
CorrectedMutualInformation
<
ALLOC
>::
clone
(
127
const
typename
CorrectedMutualInformation
<
ALLOC
>::
allocator_type
&
128
alloc
)
const
{
129
ALLOC
<
CorrectedMutualInformation
<
ALLOC
> >
allocator
(
alloc
);
130
CorrectedMutualInformation
<
ALLOC
>*
new_score
=
allocator
.
allocate
(1);
131
try
{
132
allocator
.
construct
(
new_score
, *
this
,
alloc
);
133
}
catch
(...) {
134
allocator
.
deallocate
(
new_score
, 1);
135
throw
;
136
}
137
138
return
new_score
;
139
}
140
141
142
/// virtual copy constructor
143
template
<
template
<
typename
>
class
ALLOC
>
144
CorrectedMutualInformation
<
ALLOC
>*
145
CorrectedMutualInformation
<
ALLOC
>::
clone
()
const
{
146
return
clone
(
this
->
getAllocator
());
147
}
148
149
150
/// destructor
151
template
<
template
<
typename
>
class
ALLOC
>
152
CorrectedMutualInformation
<
ALLOC
>::~
CorrectedMutualInformation
() {
153
// for debugging purposes
154
GUM_DESTRUCTOR
(
CorrectedMutualInformation
);
155
}
156
157
158
/// copy operator
159
template
<
template
<
typename
>
class
ALLOC
>
160
CorrectedMutualInformation
<
ALLOC
>&
161
CorrectedMutualInformation
<
ALLOC
>::
operator
=(
162
const
CorrectedMutualInformation
<
ALLOC
>&
from
) {
163
if
(
this
!= &
from
) {
164
NH__
=
from
.
NH__
;
165
k_NML__
=
from
.
k_NML__
;
166
score_MDL__
=
from
.
score_MDL__
;
167
kmode__
=
from
.
kmode__
;
168
use_ICache__
=
from
.
use_ICache__
;
169
use_HCache__
=
from
.
use_HCache__
;
170
use_KCache__
=
from
.
use_KCache__
;
171
use_CnrCache__
=
from
.
use_CnrCache__
;
172
ICache__
=
from
.
ICache__
;
173
KCache__
=
from
.
KCache__
;
174
}
175
return
*
this
;
176
}
177
178
179
/// move operator
180
template
<
template
<
typename
>
class
ALLOC
>
181
CorrectedMutualInformation
<
ALLOC
>&
182
CorrectedMutualInformation
<
ALLOC
>::
operator
=(
183
CorrectedMutualInformation
<
ALLOC
>&&
from
) {
184
if
(
this
!= &
from
) {
185
NH__
=
std
::
move
(
from
.
NH__
);
186
k_NML__
=
std
::
move
(
from
.
k_NML__
);
187
score_MDL__
=
std
::
move
(
from
.
score_MDL__
);
188
kmode__
=
from
.
kmode__
;
189
use_ICache__
=
from
.
use_ICache__
;
190
use_HCache__
=
from
.
use_HCache__
;
191
use_KCache__
=
from
.
use_KCache__
;
192
use_CnrCache__
=
from
.
use_CnrCache__
;
193
ICache__
=
std
::
move
(
from
.
ICache__
);
194
KCache__
=
std
::
move
(
from
.
KCache__
);
195
}
196
return
*
this
;
197
}
198
199
200
/// turn on/off the use of all the caches
201
template
<
template
<
typename
>
class
ALLOC
>
202
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useCache
(
bool
on_off
) {
203
useICache
(
on_off
);
204
useHCache
(
on_off
);
205
useKCache
(
on_off
);
206
useCnrCache
(
on_off
);
207
}
208
209
210
/// turn on/off the use of the I cache
211
template
<
template
<
typename
>
class
ALLOC
>
212
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useICache
(
bool
on_off
) {
213
if
(!
on_off
)
ICache__
.
clear
();
214
use_ICache__
=
on_off
;
215
}
216
217
218
/// turn on/off the use of the H cache
219
template
<
template
<
typename
>
class
ALLOC
>
220
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useHCache
(
bool
on_off
) {
221
if
(!
on_off
)
NH__
.
clearCache
();
222
use_HCache__
=
on_off
;
223
NH__
.
useCache
(
on_off
);
224
}
225
226
227
/// turn on/off the use of the K cache
228
template
<
template
<
typename
>
class
ALLOC
>
229
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useKCache
(
bool
on_off
) {
230
if
(!
on_off
)
KCache__
.
clear
();
231
use_KCache__
=
on_off
;
232
}
233
234
235
/// turn on/off the use of the Cnr cache
236
template
<
template
<
typename
>
class
ALLOC
>
237
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
useCnrCache
(
bool
on_off
) {
238
if
(!
on_off
)
k_NML__
.
clearCache
();
239
use_CnrCache__
=
on_off
;
240
k_NML__
.
useCache
(
on_off
);
241
}
242
243
244
/// clears all the data structures from memory
245
template
<
template
<
typename
>
class
ALLOC
>
246
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clear
() {
247
NH__
.
clear
();
248
k_NML__
.
clear
();
249
score_MDL__
.
clear
();
250
clearCache
();
251
}
252
253
254
/// clears the current cache (clear nodesets as well)
255
template
<
template
<
typename
>
class
ALLOC
>
256
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearCache
() {
257
NH__
.
clearCache
();
258
k_NML__
.
clearCache
();
259
ICache__
.
clear
();
260
KCache__
.
clear
();
261
}
262
263
264
/// clears the ICache (the mutual information cache)
265
template
<
template
<
typename
>
class
ALLOC
>
266
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearICache
() {
267
ICache__
.
clear
();
268
}
269
270
271
/// clears the HCache (the cache for the entropies)
272
template
<
template
<
typename
>
class
ALLOC
>
273
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearHCache
() {
274
NH__
.
clearCache
();
275
}
276
277
278
/// clears the KCache (the cache for the penalties)
279
template
<
template
<
typename
>
class
ALLOC
>
280
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearKCache
() {
281
KCache__
.
clear
();
282
}
283
284
285
/// clears the CnrCache (the cache for the Cnr formula)
286
template
<
template
<
typename
>
class
ALLOC
>
287
INLINE
void
CorrectedMutualInformation
<
ALLOC
>::
clearCnrCache
() {
288
k_NML__
.
clearCache
();
289
}
290
291
292
/// changes the max number of threads used to parse the database
293
template
<
template
<
typename
>
class
ALLOC
>
294
void
295
CorrectedMutualInformation
<
ALLOC
>::
setMaxNbThreads
(
std
::
size_t
nb
)
const
{
296
NH__
.
setMaxNbThreads
(
nb
);
297
k_NML__
.
setMaxNbThreads
(
nb
);
298
score_MDL__
.
setMaxNbThreads
(
nb
);
299
}
300
301
302
/// returns the number of threads used to parse the database
303
template
<
template
<
typename
>
class
ALLOC
>
304
std
::
size_t
CorrectedMutualInformation
<
ALLOC
>::
nbThreads
()
const
{
305
return
NH__
.
nbThreads
();
306
}
307
308
309
/** @brief changes the number min of rows a thread should process in a
310
* multithreading context */
311
template
<
template
<
typename
>
class
ALLOC
>
312
void
CorrectedMutualInformation
<
ALLOC
>::
setMinNbRowsPerThread
(
313
const
std
::
size_t
nb
)
const
{
314
NH__
.
setMinNbRowsPerThread
(
nb
);
315
k_NML__
.
setMinNbRowsPerThread
(
nb
);
316
score_MDL__
.
setMinNbRowsPerThread
(
nb
);
317
}
318
319
320
/// returns the minimum of rows that each thread should process
321
template
<
template
<
typename
>
class
ALLOC
>
322
INLINE
std
::
size_t
323
CorrectedMutualInformation
<
ALLOC
>::
minNbRowsPerThread
()
const
{
324
return
NH__
.
minNbRowsPerThread
();
325
}
326
327
328
/// sets new ranges to perform the countings used by the score
329
/** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
330
* indices. The countings are then performed only on the union of the
331
* rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
332
* cross validation tasks, in which part of the database should be ignored.
333
* An empty set of ranges is equivalent to an interval [X,Y) ranging over
334
* the whole database. */
335
template
<
template
<
typename
>
class
ALLOC
>
336
template
<
template
<
typename
>
class
XALLOC
>
337
void
CorrectedMutualInformation
<
ALLOC
>::
setRanges
(
338
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
339
XALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
340
new_ranges
) {
341
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
342
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >
343
old_ranges
=
ranges
();
344
345
NH__
.
setRanges
(
new_ranges
);
346
k_NML__
.
setRanges
(
new_ranges
);
347
score_MDL__
.
setRanges
(
new_ranges
);
348
349
if
(
old_ranges
!=
ranges
())
clear
();
350
}
351
352
353
/// reset the ranges to the one range corresponding to the whole database
354
template
<
template
<
typename
>
class
ALLOC
>
355
void
CorrectedMutualInformation
<
ALLOC
>::
clearRanges
() {
356
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
357
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >
358
old_ranges
=
ranges
();
359
NH__
.
clearRanges
();
360
k_NML__
.
clearRanges
();
361
score_MDL__
.
clearRanges
();
362
if
(
old_ranges
!=
ranges
())
clear
();
363
}
364
365
366
/// returns the current ranges
367
template
<
template
<
typename
>
class
ALLOC
>
368
INLINE
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
369
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
370
CorrectedMutualInformation
<
ALLOC
>::
ranges
()
const
{
371
return
NH__
.
ranges
();
372
}
373
374
375
/// use the MDL penalty function
376
template
<
template
<
typename
>
class
ALLOC
>
377
void
CorrectedMutualInformation
<
ALLOC
>::
useMDL
() {
378
clearCache
();
379
kmode__
=
KModeTypes
::
MDL
;
380
}
381
382
383
/// use the kNML penalty function
384
template
<
template
<
typename
>
class
ALLOC
>
385
void
CorrectedMutualInformation
<
ALLOC
>::
useNML
() {
386
clearCache
();
387
kmode__
=
KModeTypes
::
NML
;
388
}
389
390
391
/// use no correction/penalty function
392
template
<
template
<
typename
>
class
ALLOC
>
393
void
CorrectedMutualInformation
<
ALLOC
>::
useNoCorr
() {
394
clearCache
();
395
kmode__
=
KModeTypes
::
NoCorr
;
396
}
397
398
399
/// returns the 2-point mutual information corresponding to a given nodeset
400
template
<
template
<
typename
>
class
ALLOC
>
401
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
score
(
NodeId
var1
,
402
NodeId
var2
) {
403
return
score
(
var1
,
var2
,
empty_conditioning_set__
);
404
}
405
406
407
/// returns the 2-point mutual information corresponding to a given nodeset
408
template
<
template
<
typename
>
class
ALLOC
>
409
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
score
(
410
NodeId
var1
,
411
NodeId
var2
,
412
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_ids
) {
413
return
NI_score__
(
var1
,
var2
,
conditioning_ids
)
414
-
K_score__
(
var1
,
var2
,
conditioning_ids
);
415
}
416
417
418
/// returns the 3-point mutual information corresponding to a given nodeset
419
template
<
template
<
typename
>
class
ALLOC
>
420
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
score
(
NodeId
var1
,
421
NodeId
var2
,
422
NodeId
var3
) {
423
return
score
(
var1
,
var2
,
var3
,
empty_conditioning_set__
);
424
}
425
426
427
/// returns the 3-point mutual information corresponding to a given nodeset
428
template
<
template
<
typename
>
class
ALLOC
>
429
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
score
(
430
NodeId
var1
,
431
NodeId
var2
,
432
NodeId
var3
,
433
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_ids
) {
434
return
NI_score__
(
var1
,
var2
,
var3
,
conditioning_ids
)
435
+
K_score__
(
var1
,
var2
,
var3
,
conditioning_ids
);
436
}
437
438
439
/// return N times the mutual information for conditioned pairs of variables
440
template
<
template
<
typename
>
class
ALLOC
>
441
double
CorrectedMutualInformation
<
ALLOC
>::
NI_score__
(
442
NodeId
var_x
,
443
NodeId
var_y
,
444
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
vars_z
) {
445
/*
446
* We have a few partial entropies to compute in order to have the
447
* 2-point mutual information:
448
* I(x;y) = H(x) + H(y) - H(x,y)
449
* correspondingly
450
* I(x;y) = Hx + Hy - Hxy
451
* or
452
* I(x;y|z) = H(x,z) + H(y,z) - H(z) - H(x,y,z)
453
* correspondingly
454
* I(x;y|z) = Hxz + Hyz - Hz - Hxyz
455
* Note that Entropy H is equal to 1/N times the log2Likelihood,
456
* where N is the size of the database.
457
* Remember that we return N times I(x;y|z)
458
*/
459
460
// if the score has already been computed, get its value
461
const
IdCondSet
<
ALLOC
>
idset_xyz
(
var_x
,
var_y
,
vars_z
,
false
,
false
);
462
if
(
use_ICache__
) {
463
try
{
464
return
ICache__
.
score
(
idset_xyz
);
465
}
catch
(
const
NotFound
&) {}
466
}
467
468
// compute the score
469
470
// here, we distinguish nodesets with conditioning nodes from those
471
// without conditioning nodes
472
double
score
;
473
if
(!
vars_z
.
empty
()) {
474
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >
vars
(
vars_z
);
475
// std::sort(vars.begin(), vars.end());
476
vars
.
push_back
(
var_x
);
477
vars
.
push_back
(
var_y
);
478
const
double
NHxyz
= -
NH__
.
score
(
IdCondSet
<
ALLOC
>(
vars
,
false
,
true
));
479
480
vars
.
pop_back
();
481
const
double
NHxz
= -
NH__
.
score
(
IdCondSet
<
ALLOC
>(
vars
,
false
,
true
));
482
483
vars
.
pop_back
();
484
vars
.
push_back
(
var_y
);
485
const
double
NHyz
= -
NH__
.
score
(
IdCondSet
<
ALLOC
>(
vars
,
false
,
true
));
486
487
vars
.
pop_back
();
488
const
double
NHz
= -
NH__
.
score
(
IdCondSet
<
ALLOC
>(
vars
,
false
,
true
));
489
490
const
double
NHxz_NHyz
=
NHxz
+
NHyz
;
491
double
NHz_NHxyz
=
NHz
+
NHxyz
;
492
493
// avoid numeric instability due to rounding errors
494
double
ratio
= 1;
495
if
(
NHxz_NHyz
> 0) {
496
ratio
= (
NHxz_NHyz
-
NHz_NHxyz
) /
NHxz_NHyz
;
497
}
else
if
(
NHz_NHxyz
> 0) {
498
ratio
= (
NHxz_NHyz
-
NHz_NHxyz
) /
NHz_NHxyz
;
499
}
500
if
(
ratio
< 0)
ratio
= -
ratio
;
501
if
(
ratio
<
threshold__
) {
502
NHz_NHxyz
=
NHxz_NHyz
;
// ensure that the score is equal to 0
503
}
504
505
score
=
NHxz_NHyz
-
NHz_NHxyz
;
506
}
else
{
507
const
double
NHxy
508
= -
NH__
.
score
(
IdCondSet
<
ALLOC
>(
var_x
,
509
var_y
,
510
empty_conditioning_set__
,
511
true
,
512
false
));
513
const
double
NHx
= -
NH__
.
score
(
var_x
);
514
const
double
NHy
= -
NH__
.
score
(
var_y
);
515
516
double
NHx_NHy
=
NHx
+
NHy
;
517
518
// avoid numeric instability due to rounding errors
519
double
ratio
= 1;
520
if
(
NHx_NHy
> 0) {
521
ratio
= (
NHx_NHy
-
NHxy
) /
NHx_NHy
;
522
}
else
if
(
NHxy
> 0) {
523
ratio
= (
NHx_NHy
-
NHxy
) /
NHxy
;
524
}
525
if
(
ratio
< 0)
ratio
= -
ratio
;
526
if
(
ratio
<
threshold__
) {
527
NHx_NHy
=
NHxy
;
// ensure that the score is equal to 0
528
}
529
530
score
=
NHx_NHy
-
NHxy
;
531
}
532
533
534
// shall we put the score into the cache?
535
if
(
use_ICache__
) {
ICache__
.
insert
(
idset_xyz
,
score
); }
536
537
return
score
;
538
}
539
540
541
/// return N times the mutual information for conditioned triples of variables
542
template
<
template
<
typename
>
class
ALLOC
>
543
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
NI_score__
(
544
NodeId
var_x
,
545
NodeId
var_y
,
546
NodeId
var_z
,
547
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
ui_ids
) {
548
// conditional 3-point mutual information formula:
549
// I(x;y;z|{ui}) = I(x;y|{ui}) - I(x;y|z,{ui})
550
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >
uiz_ids
=
ui_ids
;
551
uiz_ids
.
push_back
(
var_z
);
552
return
NI_score__
(
var_x
,
var_y
,
ui_ids
) -
NI_score__
(
var_x
,
var_y
,
uiz_ids
);
553
}
554
555
556
/// 2pt penalty
557
template
<
template
<
typename
>
class
ALLOC
>
558
double
CorrectedMutualInformation
<
ALLOC
>::
K_score__
(
559
NodeId
var1
,
560
NodeId
var2
,
561
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
conditioning_ids
) {
562
// if no penalty, return 0
563
if
(
kmode__
==
KModeTypes
::
NoCorr
)
return
0.0;
564
565
566
// If using the K cache, verify whether the set isn't already known
567
IdCondSet
<
ALLOC
>
idset
;
568
if
(
use_KCache__
) {
569
idset
=
std
::
move
(
IdCondSet
<
ALLOC
>(
var1
,
var2
,
conditioning_ids
,
false
));
570
try
{
571
return
KCache__
.
score
(
idset
);
572
}
catch
(
const
NotFound
&) {}
573
}
574
575
// compute the score
576
double
score
;
577
size_t
rx
,
ry
,
rui
;
578
switch
(
kmode__
) {
579
case
KModeTypes
::
MDL
: {
580
const
auto
&
database
=
NH__
.
database
();
581
const
auto
&
node2cols
=
NH__
.
nodeId2Columns
();
582
583
rui
= 1;
584
if
(!
node2cols
.
empty
()) {
585
rx
=
database
.
domainSize
(
node2cols
.
second
(
var1
));
586
ry
=
database
.
domainSize
(
node2cols
.
second
(
var2
));
587
for
(
const
NodeId
i
:
conditioning_ids
) {
588
rui
*=
database
.
domainSize
(
node2cols
.
second
(
i
));
589
}
590
}
else
{
591
rx
=
database
.
domainSize
(
var1
);
592
ry
=
database
.
domainSize
(
var2
);
593
for
(
const
NodeId
i
:
conditioning_ids
) {
594
rui
*=
database
.
domainSize
(
i
);
595
}
596
}
597
598
// compute the size of the database, including the a priori
599
if
(!
use_KCache__
) {
600
idset
=
std
::
move
(
601
IdCondSet
<
ALLOC
>(
var1
,
var2
,
conditioning_ids
,
false
));
602
}
603
const
double
N
=
score_MDL__
.
N
(
idset
);
604
605
score
= 0.5 * (
rx
- 1) * (
ry
- 1) *
rui
*
std
::
log2
(
N
);
606
}
break
;
607
608
case
KModeTypes
::
NML
:
609
score
=
k_NML__
.
score
(
var1
,
var2
,
conditioning_ids
);
610
break
;
611
612
default
:
613
GUM_ERROR
(
NotImplementedYet
,
614
"CorrectedMutualInformation mode does "
615
"not support yet this correction"
);
616
}
617
618
// shall we put the score into the cache?
619
if
(
use_KCache__
) {
KCache__
.
insert
(
idset
,
score
); }
620
return
score
;
621
}
622
623
624
/// 3pt penalty
625
template
<
template
<
typename
>
class
ALLOC
>
626
INLINE
double
CorrectedMutualInformation
<
ALLOC
>::
K_score__
(
627
NodeId
var1
,
628
NodeId
var2
,
629
NodeId
var3
,
630
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
ui_ids
) {
631
// k(x;y;z|ui) = k(x;y|ui,z) - k(x;y|ui)
632
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >
uiz_ids
=
ui_ids
;
633
uiz_ids
.
push_back
(
var3
);
634
return
K_score__
(
var1
,
var2
,
uiz_ids
) -
K_score__
(
var1
,
var2
,
ui_ids
);
635
}
636
637
638
}
/* namespace learning */
639
640
}
/* namespace gum */
641
642
#
endif
/* DOXYGEN_SHOULD_SKIP_THIS */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31