aGrUM  0.17.1
a C++ library for (probabilistic) graphical models
correctedMutualInformation_tpl.h
Go to the documentation of this file.
1 
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 
32 namespace gum {
33 
34  namespace learning {
35 
37  template < template < typename > class ALLOC >
40  return __NH.getAllocator();
41  }
42 
43 
45  template < template < typename > class ALLOC >
47  const DBRowGeneratorParser< ALLOC >& parser,
48  const Apriori< ALLOC >& apriori,
49  const std::vector< std::pair< std::size_t, std::size_t >,
50  ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
51  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
52  nodeId2columns,
54  __NH(parser, apriori, ranges, nodeId2columns, alloc),
55  __k_NML(parser, apriori, ranges, nodeId2columns, alloc),
56  __score_MDL(parser, apriori, ranges, nodeId2columns, alloc),
57  __ICache(alloc), __KCache(alloc) {
58  GUM_CONSTRUCTOR(CorrectedMutualInformation);
59  }
60 
61 
63  template < template < typename > class ALLOC >
65  const DBRowGeneratorParser< ALLOC >& parser,
66  const Apriori< ALLOC >& apriori,
67  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
68  nodeId2columns,
70  __NH(parser, apriori, nodeId2columns, alloc),
71  __k_NML(parser, apriori, nodeId2columns, alloc),
72  __score_MDL(parser, apriori, nodeId2columns, alloc), __ICache(alloc),
73  __KCache(alloc) {
74  GUM_CONSTRUCTOR(CorrectedMutualInformation);
75  }
76 
77 
79  template < template < typename > class ALLOC >
81  const CorrectedMutualInformation< ALLOC >& from,
83  __NH(from.__NH, alloc),
84  __k_NML(from.__k_NML, alloc), __score_MDL(from.__score_MDL, alloc),
85  __kmode(from.__kmode), __use_ICache(from.__use_ICache),
86  __use_HCache(from.__use_HCache), __use_KCache(from.__use_KCache),
87  __use_CnrCache(from.__use_CnrCache), __ICache(from.__ICache, alloc),
88  __KCache(from.__KCache, alloc) {
89  GUM_CONS_CPY(CorrectedMutualInformation);
90  }
91 
92 
94  template < template < typename > class ALLOC >
96  const CorrectedMutualInformation< ALLOC >& from) :
98 
99 
101  template < template < typename > class ALLOC >
103  CorrectedMutualInformation< ALLOC >&& from,
105  __NH(std::move(from.__NH), alloc),
106  __k_NML(std::move(from.__k_NML), alloc),
107  __score_MDL(std::move(from.__score_MDL), alloc), __kmode(from.__kmode),
108  __use_ICache(from.__use_ICache), __use_HCache(from.__use_HCache),
109  __use_KCache(from.__use_KCache), __use_CnrCache(from.__use_CnrCache),
110  __ICache(std::move(from.__ICache), alloc),
111  __KCache(std::move(from.__KCache), alloc) {
112  GUM_CONS_MOV(CorrectedMutualInformation);
113  }
114 
115 
117  template < template < typename > class ALLOC >
119  CorrectedMutualInformation< ALLOC >&& from) :
120  CorrectedMutualInformation(std::move(from), from.getAllocator()) {}
121 
122 
124  template < template < typename > class ALLOC >
125  CorrectedMutualInformation< ALLOC >*
128  alloc) const {
129  ALLOC< CorrectedMutualInformation< ALLOC > > allocator(alloc);
130  CorrectedMutualInformation< ALLOC >* new_score = allocator.allocate(1);
131  try {
132  allocator.construct(new_score, *this, alloc);
133  } catch (...) {
134  allocator.deallocate(new_score, 1);
135  throw;
136  }
137 
138  return new_score;
139  }
140 
141 
143  template < template < typename > class ALLOC >
144  CorrectedMutualInformation< ALLOC >*
146  return clone(this->getAllocator());
147  }
148 
149 
151  template < template < typename > class ALLOC >
153  // for debugging purposes
154  GUM_DESTRUCTOR(CorrectedMutualInformation);
155  }
156 
157 
159  template < template < typename > class ALLOC >
160  CorrectedMutualInformation< ALLOC >&
162  const CorrectedMutualInformation< ALLOC >& from) {
163  if (this != &from) {
164  __NH = from.__NH;
165  __k_NML = from.__k_NML;
166  __score_MDL = from.__score_MDL;
167  __kmode = from.__kmode;
168  __use_ICache = from.__use_ICache;
169  __use_HCache = from.__use_HCache;
170  __use_KCache = from.__use_KCache;
171  __use_CnrCache = from.__use_CnrCache;
172  __ICache = from.__ICache;
173  __KCache = from.__KCache;
174  }
175  return *this;
176  }
177 
178 
180  template < template < typename > class ALLOC >
181  CorrectedMutualInformation< ALLOC >&
183  CorrectedMutualInformation< ALLOC >&& from) {
184  if (this != &from) {
185  __NH = std::move(from.__NH);
186  __k_NML = std::move(from.__k_NML);
187  __score_MDL = std::move(from.__score_MDL);
188  __kmode = from.__kmode;
189  __use_ICache = from.__use_ICache;
190  __use_HCache = from.__use_HCache;
191  __use_KCache = from.__use_KCache;
192  __use_CnrCache = from.__use_CnrCache;
193  __ICache = std::move(from.__ICache);
194  __KCache = std::move(from.__KCache);
195  }
196  return *this;
197  }
198 
199 
201  template < template < typename > class ALLOC >
202  INLINE void CorrectedMutualInformation< ALLOC >::useCache(bool on_off) {
203  useICache(on_off);
204  useHCache(on_off);
205  useKCache(on_off);
206  useCnrCache(on_off);
207  }
208 
209 
211  template < template < typename > class ALLOC >
212  INLINE void CorrectedMutualInformation< ALLOC >::useICache(bool on_off) {
213  if (!on_off) __ICache.clear();
214  __use_ICache = on_off;
215  }
216 
217 
219  template < template < typename > class ALLOC >
220  INLINE void CorrectedMutualInformation< ALLOC >::useHCache(bool on_off) {
221  if (!on_off) __NH.clearCache();
222  __use_HCache = on_off;
223  __NH.useCache(on_off);
224  }
225 
226 
228  template < template < typename > class ALLOC >
229  INLINE void CorrectedMutualInformation< ALLOC >::useKCache(bool on_off) {
230  if (!on_off) __KCache.clear();
231  __use_KCache = on_off;
232  }
233 
234 
236  template < template < typename > class ALLOC >
237  INLINE void CorrectedMutualInformation< ALLOC >::useCnrCache(bool on_off) {
238  if (!on_off) __k_NML.clearCache();
239  __use_CnrCache = on_off;
240  __k_NML.useCache(on_off);
241  }
242 
243 
245  template < template < typename > class ALLOC >
247  __NH.clear();
248  __k_NML.clear();
249  __score_MDL.clear();
250  clearCache();
251  }
252 
253 
255  template < template < typename > class ALLOC >
257  __NH.clearCache();
258  __k_NML.clearCache();
259  __ICache.clear();
260  __KCache.clear();
261  }
262 
263 
265  template < template < typename > class ALLOC >
267  __ICache.clear();
268  }
269 
270 
272  template < template < typename > class ALLOC >
274  __NH.clearCache();
275  }
276 
277 
279  template < template < typename > class ALLOC >
281  __KCache.clear();
282  }
283 
284 
286  template < template < typename > class ALLOC >
288  __k_NML.clearCache();
289  }
290 
291 
293  template < template < typename > class ALLOC >
294  void
296  __NH.setMaxNbThreads(nb);
297  __k_NML.setMaxNbThreads(nb);
298  __score_MDL.setMaxNbThreads(nb);
299  }
300 
301 
303  template < template < typename > class ALLOC >
305  return __NH.nbThreads();
306  }
307 
308 
311  template < template < typename > class ALLOC >
313  const std::size_t nb) const {
314  __NH.setMinNbRowsPerThread(nb);
315  __k_NML.setMinNbRowsPerThread(nb);
316  __score_MDL.setMinNbRowsPerThread(nb);
317  }
318 
319 
321  template < template < typename > class ALLOC >
322  INLINE std::size_t
324  return __NH.minNbRowsPerThread();
325  }
326 
327 
329 
335  template < template < typename > class ALLOC >
336  template < template < typename > class XALLOC >
338  const std::vector< std::pair< std::size_t, std::size_t >,
339  XALLOC< std::pair< std::size_t, std::size_t > > >&
340  new_ranges) {
341  std::vector< std::pair< std::size_t, std::size_t >,
342  ALLOC< std::pair< std::size_t, std::size_t > > >
343  old_ranges = ranges();
344 
345  __NH.setRanges(new_ranges);
346  __k_NML.setRanges(new_ranges);
347  __score_MDL.setRanges(new_ranges);
348 
349  if (old_ranges != ranges()) clear();
350  }
351 
352 
354  template < template < typename > class ALLOC >
356  std::vector< std::pair< std::size_t, std::size_t >,
357  ALLOC< std::pair< std::size_t, std::size_t > > >
358  old_ranges = ranges();
359  __NH.clearRanges();
360  __k_NML.clearRanges();
361  __score_MDL.clearRanges();
362  if (old_ranges != ranges()) clear();
363  }
364 
365 
367  template < template < typename > class ALLOC >
368  INLINE const std::vector< std::pair< std::size_t, std::size_t >,
369  ALLOC< std::pair< std::size_t, std::size_t > > >&
371  return __NH.ranges();
372  }
373 
374 
376  template < template < typename > class ALLOC >
378  clearCache();
379  __kmode = KModeTypes::MDL;
380  }
381 
382 
384  template < template < typename > class ALLOC >
386  clearCache();
387  __kmode = KModeTypes::NML;
388  }
389 
390 
392  template < template < typename > class ALLOC >
394  clearCache();
395  __kmode = KModeTypes::NoCorr;
396  }
397 
398 
400  template < template < typename > class ALLOC >
402  NodeId var2) {
403  return score(var1, var2, __empty_conditioning_set);
404  }
405 
406 
408  template < template < typename > class ALLOC >
410  NodeId var1,
411  NodeId var2,
412  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
413  return __NI_score(var1, var2, conditioning_ids)
414  - __K_score(var1, var2, conditioning_ids);
415  }
416 
417 
419  template < template < typename > class ALLOC >
421  NodeId var2,
422  NodeId var3) {
423  return score(var1, var2, var3, __empty_conditioning_set);
424  }
425 
426 
428  template < template < typename > class ALLOC >
430  NodeId var1,
431  NodeId var2,
432  NodeId var3,
433  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
434  return __NI_score(var1, var2, var3, conditioning_ids)
435  + __K_score(var1, var2, var3, conditioning_ids);
436  }
437 
438 
440  template < template < typename > class ALLOC >
441  double CorrectedMutualInformation< ALLOC >::__NI_score(
442  NodeId var_x,
443  NodeId var_y,
444  const std::vector< NodeId, ALLOC< NodeId > >& vars_z) {
445  /*
446  * We have a few partial entropies to compute in order to have the
447  * 2-point mutual information:
448  * I(x;y) = H(x) + H(y) - H(x,y)
449  * correspondingly
450  * I(x;y) = Hx + Hy - Hxy
451  * or
452  * I(x;y|z) = H(x,z) + H(y,z) - H(z) - H(x,y,z)
453  * correspondingly
454  * I(x;y|z) = Hxz + Hyz - Hz - Hxyz
455  * Note that Entropy H is equal to 1/N times the log2Likelihood,
456  * where N is the size of the database.
457  * Remember that we return N times I(x;y|z)
458  */
459 
460  // if the score has already been computed, get its value
461  const IdSet< ALLOC > idset_xyz(var_x, var_y, vars_z, false, false);
462  if (__use_ICache) {
463  try {
464  return __ICache.score(idset_xyz);
465  } catch (const NotFound&) {}
466  }
467 
468  // compute the score
469 
470  // here, we distinguish nodesets with conditioning nodes from those
471  // without conditioning nodes
472  double score;
473  if (!vars_z.empty()) {
474  std::vector< NodeId, ALLOC< NodeId > > vars(vars_z);
475  // std::sort(vars.begin(), vars.end());
476  vars.push_back(var_x);
477  vars.push_back(var_y);
478  const double NHxyz = -__NH.score(IdSet< ALLOC >(vars, false, true));
479 
480  vars.pop_back();
481  const double NHxz = -__NH.score(IdSet< ALLOC >(vars, false, true));
482 
483  vars.pop_back();
484  vars.push_back(var_y);
485  const double NHyz = -__NH.score(IdSet< ALLOC >(vars, false, true));
486 
487  vars.pop_back();
488  const double NHz = -__NH.score(IdSet< ALLOC >(vars, false, true));
489 
490  const double NHxz_NHyz = NHxz + NHyz;
491  double NHz_NHxyz = NHz + NHxyz;
492 
493  // avoid numeric instability due to rounding errors
494  double ratio = 1;
495  if (NHxz_NHyz > 0) {
496  ratio = (NHxz_NHyz - NHz_NHxyz) / NHxz_NHyz;
497  } else if (NHz_NHxyz > 0) {
498  ratio = (NHxz_NHyz - NHz_NHxyz) / NHz_NHxyz;
499  }
500  if (ratio < 0) ratio = -ratio;
501  if (ratio < __threshold) {
502  NHz_NHxyz = NHxz_NHyz; // ensure that the score is equal to 0
503  }
504 
505  score = NHxz_NHyz - NHz_NHxyz;
506  } else {
507  const double NHxy = -__NH.score(
508  IdSet< ALLOC >(var_x, var_y, __empty_conditioning_set, true, false));
509  const double NHx = -__NH.score(var_x);
510  const double NHy = -__NH.score(var_y);
511 
512  double NHx_NHy = NHx + NHy;
513 
514  // avoid numeric instability due to rounding errors
515  double ratio = 1;
516  if (NHx_NHy > 0) {
517  ratio = (NHx_NHy - NHxy) / NHx_NHy;
518  } else if (NHxy > 0) {
519  ratio = (NHx_NHy - NHxy) / NHxy;
520  }
521  if (ratio < 0) ratio = -ratio;
522  if (ratio < __threshold) {
523  NHx_NHy = NHxy; // ensure that the score is equal to 0
524  }
525 
526  score = NHx_NHy - NHxy;
527  }
528 
529 
530  // shall we put the score into the cache?
531  if (__use_ICache) { __ICache.insert(idset_xyz, score); }
532 
533  return score;
534  }
535 
536 
538  template < template < typename > class ALLOC >
539  INLINE double CorrectedMutualInformation< ALLOC >::__NI_score(
540  NodeId var_x,
541  NodeId var_y,
542  NodeId var_z,
543  const std::vector< NodeId, ALLOC< NodeId > >& ui_ids) {
544  // conditional 3-point mutual information formula:
545  // I(x;y;z|{ui}) = I(x;y|{ui}) - I(x;y|z,{ui})
546  std::vector< NodeId, ALLOC< NodeId > > uiz_ids = ui_ids;
547  uiz_ids.push_back(var_z);
548  return __NI_score(var_x, var_y, ui_ids) - __NI_score(var_x, var_y, uiz_ids);
549  }
550 
551 
553  template < template < typename > class ALLOC >
554  double CorrectedMutualInformation< ALLOC >::__K_score(
555  NodeId var1,
556  NodeId var2,
557  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
558  // if no penalty, return 0
559  if (__kmode == KModeTypes::NoCorr) return 0.0;
560 
561 
562  // If using the K cache, verify whether the set isn't already known
563  IdSet< ALLOC > idset;
564  if (__use_KCache) {
565  idset = std::move(IdSet< ALLOC >(var1, var2, conditioning_ids, false));
566  try {
567  return __KCache.score(idset);
568  } catch (const NotFound&) {}
569  }
570 
571  // compute the score
572  double score;
573  size_t rx, ry, rui;
574  switch (__kmode) {
575  case KModeTypes::MDL: {
576  const auto& database = __NH.database();
577  const auto& node2cols = __NH.nodeId2Columns();
578 
579  rui = 1;
580  if (!node2cols.empty()) {
581  rx = database.domainSize(node2cols.second(var1));
582  ry = database.domainSize(node2cols.second(var2));
583  for (const NodeId i: conditioning_ids) {
584  rui *= database.domainSize(node2cols.second(i));
585  }
586  } else {
587  rx = database.domainSize(var1);
588  ry = database.domainSize(var2);
589  for (const NodeId i: conditioning_ids) {
590  rui *= database.domainSize(i);
591  }
592  }
593 
594  // compute the size of the database, including the a priori
595  if (!__use_KCache) {
596  idset = std::move(IdSet< ALLOC >(var1, var2, conditioning_ids, false));
597  }
598  const double N = __score_MDL.N(idset);
599 
600  score = 0.5 * (rx - 1) * (ry - 1) * rui * std::log2(N);
601  } break;
602 
603  case KModeTypes::NML:
604  score = __k_NML.score(var1, var2, conditioning_ids);
605  break;
606 
607  default:
608  GUM_ERROR(NotImplementedYet,
609  "CorrectedMutualInformation mode does "
610  "not support yet this correction");
611  }
612 
613  // shall we put the score into the cache?
614  if (__use_KCache) { __KCache.insert(idset, score); }
615  return score;
616  }
617 
618 
620  template < template < typename > class ALLOC >
621  INLINE double CorrectedMutualInformation< ALLOC >::__K_score(
622  NodeId var1,
623  NodeId var2,
624  NodeId var3,
625  const std::vector< NodeId, ALLOC< NodeId > >& ui_ids) {
626  // k(x;y;z|ui) = k(x;y|ui,z) - k(x;y|ui)
627  std::vector< NodeId, ALLOC< NodeId > > uiz_ids = ui_ids;
628  uiz_ids.push_back(var3);
629  return __K_score(var1, var2, uiz_ids) - __K_score(var1, var2, ui_ids);
630  }
631 
632 
633  } /* namespace learning */
634 
635 } /* namespace gum */
636 
637 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
void useNML()
use the kNML penalty function
void clearRanges()
resets the ranges to the single range covering the whole database
void useCnrCache(bool on_off)
turn on/off the use of the CnrCache (the cache for the Cnr formula)
virtual void useCache(bool on_off)
turn on/off the use of all the caches
void useHCache(bool on_off)
turn on/off the use of the HCache (the cache for the entropies)
virtual ~CorrectedMutualInformation()
destructor
virtual void clear()
clears all the data structures from memory
void clearICache()
clears the ICache (the mutual information cache)
CorrectedMutualInformation< ALLOC > & operator=(const CorrectedMutualInformation< ALLOC > &from)
copy assignment operator
ALLOC< NodeId > allocator_type
type of the allocators passed as arguments to methods
allocator_type getAllocator() const
returns the allocator used by the score
STL namespace.
void useKCache(bool on_off)
turn on/off the use of the KCache (the cache for the penalties)
virtual CorrectedMutualInformation< ALLOC > * clone() const
virtual copy constructor
virtual void clearCache()
clears all the current caches
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
virtual std::size_t minNbRowsPerThread() const
returns the minimum number of rows that each thread should process
const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > & ranges() const
returns the current ranges
double score(NodeId var1, NodeId var2)
returns the 2-point mutual information corresponding to a given nodeset
void setRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
sets new ranges of database rows over which the countings used by the mutual information are performed
CorrectedMutualInformation(const DBRowGeneratorParser< ALLOC > &parser, const Apriori< ALLOC > &apriori, const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > &ranges, const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > &nodeId2columns=Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), const allocator_type &alloc=allocator_type())
main constructor: the countings are performed only on the database rows within the given ranges
void clearKCache()
clears the KCache (the cache for the penalties)
void clearHCache()
clears the HCache (the cache for the entropies)
virtual std::size_t nbThreads() const
returns the number of threads used to parse the database
virtual void setMinNbRowsPerThread(const std::size_t nb) const
changes the minimum number of rows a thread should process in a multithreading context ...
void useMDL()
use the MDL penalty function
void useICache(bool on_off)
turn on/off the use of the ICache (the mutual information cache)
void clearCnrCache()
clears the CnrCache (the cache for the Cnr formula)
Size NodeId
Type for node ids.
Definition: graphElements.h:98
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
void useNoCorr()
use no correction/penalty function
virtual void setMaxNbThreads(std::size_t nb) const
changes the max number of threads used to parse the database