aGrUM  0.16.0
correctedMutualInformation_tpl.h
Go to the documentation of this file.
1 
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 
32 namespace gum {
33 
34  namespace learning {
35 
37  template < template < typename > class ALLOC >
40  return __NH.getAllocator();
41  }
42 
43 
45  template < template < typename > class ALLOC >
47  const DBRowGeneratorParser< ALLOC >& parser,
48  const Apriori< ALLOC >& apriori,
49  const std::vector< std::pair< std::size_t, std::size_t >,
50  ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
51  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
52  nodeId2columns,
54  __NH(parser, apriori, ranges, nodeId2columns, alloc),
55  __k_NML(parser, apriori, ranges, nodeId2columns, alloc),
56  __score_MDL(parser, apriori, ranges, nodeId2columns, alloc),
57  __ICache(alloc), __KCache(alloc) {
58  GUM_CONSTRUCTOR(CorrectedMutualInformation);
59  }
60 
61 
63  template < template < typename > class ALLOC >
65  const DBRowGeneratorParser< ALLOC >& parser,
66  const Apriori< ALLOC >& apriori,
67  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
68  nodeId2columns,
70  __NH(parser, apriori, nodeId2columns, alloc),
71  __k_NML(parser, apriori, nodeId2columns, alloc),
72  __score_MDL(parser, apriori, nodeId2columns, alloc), __ICache(alloc),
73  __KCache(alloc) {
74  GUM_CONSTRUCTOR(CorrectedMutualInformation);
75  }
76 
77 
79  template < template < typename > class ALLOC >
81  const CorrectedMutualInformation< ALLOC >& from,
83  __NH(from.__NH, alloc),
84  __k_NML(from.__k_NML, alloc), __score_MDL(from.__score_MDL, alloc),
85  __kmode(from.__kmode), __use_ICache(from.__use_ICache),
86  __use_HCache(from.__use_HCache), __use_KCache(from.__use_KCache),
87  __use_CnrCache(from.__use_CnrCache), __ICache(from.__ICache, alloc),
88  __KCache(from.__KCache, alloc) {
89  GUM_CONS_CPY(CorrectedMutualInformation);
90  }
91 
92 
94  template < template < typename > class ALLOC >
96  const CorrectedMutualInformation< ALLOC >& from) :
98 
99 
101  template < template < typename > class ALLOC >
103  CorrectedMutualInformation< ALLOC >&& from,
105  __NH(std::move(from.__NH), alloc),
106  __k_NML(std::move(from.__k_NML), alloc),
107  __score_MDL(std::move(from.__score_MDL), alloc), __kmode(from.__kmode),
108  __use_ICache(from.__use_ICache), __use_HCache(from.__use_HCache),
109  __use_KCache(from.__use_KCache), __use_CnrCache(from.__use_CnrCache),
110  __ICache(std::move(from.__ICache), alloc),
111  __KCache(std::move(from.__KCache), alloc) {
112  GUM_CONS_MOV(CorrectedMutualInformation);
113  }
114 
115 
117  template < template < typename > class ALLOC >
119  CorrectedMutualInformation< ALLOC >&& from) :
120  CorrectedMutualInformation(std::move(from), from.getAllocator()) {}
121 
122 
124  template < template < typename > class ALLOC >
125  CorrectedMutualInformation< ALLOC >*
128  alloc) const {
129  ALLOC< CorrectedMutualInformation< ALLOC > > allocator(alloc);
130  CorrectedMutualInformation< ALLOC >* new_score = allocator.allocate(1);
131  try {
132  allocator.construct(new_score, *this, alloc);
133  } catch (...) {
134  allocator.deallocate(new_score, 1);
135  throw;
136  }
137 
138  return new_score;
139  }
140 
141 
143  template < template < typename > class ALLOC >
144  CorrectedMutualInformation< ALLOC >*
146  return clone(this->getAllocator());
147  }
148 
149 
151  template < template < typename > class ALLOC >
153  // for debugging purposes
154  GUM_DESTRUCTOR(CorrectedMutualInformation);
155  }
156 
157 
159  template < template < typename > class ALLOC >
160  CorrectedMutualInformation< ALLOC >& CorrectedMutualInformation< ALLOC >::
161  operator=(const CorrectedMutualInformation< ALLOC >& from) {
162  if (this != &from) {
163  __NH = from.__NH;
164  __k_NML = from.__k_NML;
165  __score_MDL = from.__score_MDL;
166  __kmode = from.__kmode;
167  __use_ICache = from.__use_ICache;
168  __use_HCache = from.__use_HCache;
169  __use_KCache = from.__use_KCache;
170  __use_CnrCache = from.__use_CnrCache;
171  __ICache = from.__ICache;
172  __KCache = from.__KCache;
173  }
174  return *this;
175  }
176 
177 
179  template < template < typename > class ALLOC >
180  CorrectedMutualInformation< ALLOC >& CorrectedMutualInformation< ALLOC >::
181  operator=(CorrectedMutualInformation< ALLOC >&& from) {
182  if (this != &from) {
183  __NH = std::move(from.__NH);
184  __k_NML = std::move(from.__k_NML);
185  __score_MDL = std::move(from.__score_MDL);
186  __kmode = from.__kmode;
187  __use_ICache = from.__use_ICache;
188  __use_HCache = from.__use_HCache;
189  __use_KCache = from.__use_KCache;
190  __use_CnrCache = from.__use_CnrCache;
191  __ICache = std::move(from.__ICache);
192  __KCache = std::move(from.__KCache);
193  }
194  return *this;
195  }
196 
197 
199  template < template < typename > class ALLOC >
200  INLINE void CorrectedMutualInformation< ALLOC >::useCache(bool on_off) {
201  useICache(on_off);
202  useHCache(on_off);
203  useKCache(on_off);
204  useCnrCache(on_off);
205  }
206 
207 
209  template < template < typename > class ALLOC >
210  INLINE void CorrectedMutualInformation< ALLOC >::useICache(bool on_off) {
211  if (!on_off) __ICache.clear();
212  __use_ICache = on_off;
213  }
214 
215 
217  template < template < typename > class ALLOC >
218  INLINE void CorrectedMutualInformation< ALLOC >::useHCache(bool on_off) {
219  if (!on_off) __NH.clearCache();
220  __use_HCache = on_off;
221  __NH.useCache(on_off);
222  }
223 
224 
226  template < template < typename > class ALLOC >
227  INLINE void CorrectedMutualInformation< ALLOC >::useKCache(bool on_off) {
228  if (!on_off) __KCache.clear();
229  __use_KCache = on_off;
230  }
231 
232 
234  template < template < typename > class ALLOC >
235  INLINE void CorrectedMutualInformation< ALLOC >::useCnrCache(bool on_off) {
236  if (!on_off) __k_NML.clearCache();
237  __use_CnrCache = on_off;
238  __k_NML.useCache(on_off);
239  }
240 
241 
243  template < template < typename > class ALLOC >
245  __NH.clear();
246  __k_NML.clear();
247  __score_MDL.clear();
248  clearCache();
249  }
250 
251 
253  template < template < typename > class ALLOC >
255  __NH.clearCache();
256  __k_NML.clearCache();
257  __ICache.clear();
258  __KCache.clear();
259  }
260 
261 
263  template < template < typename > class ALLOC >
265  __ICache.clear();
266  }
267 
268 
270  template < template < typename > class ALLOC >
272  __NH.clearCache();
273  }
274 
275 
277  template < template < typename > class ALLOC >
279  __KCache.clear();
280  }
281 
282 
284  template < template < typename > class ALLOC >
286  __k_NML.clearCache();
287  }
288 
289 
291  template < template < typename > class ALLOC >
292  void
294  __NH.setMaxNbThreads(nb);
295  __k_NML.setMaxNbThreads(nb);
296  __score_MDL.setMaxNbThreads(nb);
297  }
298 
299 
301  template < template < typename > class ALLOC >
303  return __NH.nbThreads();
304  }
305 
306 
309  template < template < typename > class ALLOC >
311  const std::size_t nb) const {
312  __NH.setMinNbRowsPerThread(nb);
313  __k_NML.setMinNbRowsPerThread(nb);
314  __score_MDL.setMinNbRowsPerThread(nb);
315  }
316 
317 
319  template < template < typename > class ALLOC >
320  INLINE std::size_t
322  return __NH.minNbRowsPerThread();
323  }
324 
325 
327 
333  template < template < typename > class ALLOC >
334  template < template < typename > class XALLOC >
336  const std::vector< std::pair< std::size_t, std::size_t >,
337  XALLOC< std::pair< std::size_t, std::size_t > > >&
338  new_ranges) {
339  std::vector< std::pair< std::size_t, std::size_t >,
340  ALLOC< std::pair< std::size_t, std::size_t > > >
341  old_ranges = ranges();
342 
343  __NH.setRanges(new_ranges);
344  __k_NML.setRanges(new_ranges);
345  __score_MDL.setRanges(new_ranges);
346 
347  if (old_ranges != ranges()) clear();
348  }
349 
350 
352  template < template < typename > class ALLOC >
354  std::vector< std::pair< std::size_t, std::size_t >,
355  ALLOC< std::pair< std::size_t, std::size_t > > >
356  old_ranges = ranges();
357  __NH.clearRanges();
358  __k_NML.clearRanges();
359  __score_MDL.clearRanges();
360  if (old_ranges != ranges()) clear();
361  }
362 
363 
365  template < template < typename > class ALLOC >
366  INLINE const std::vector< std::pair< std::size_t, std::size_t >,
367  ALLOC< std::pair< std::size_t, std::size_t > > >&
369  return __NH.ranges();
370  }
371 
372 
374  template < template < typename > class ALLOC >
376  clearCache();
377  __kmode = KModeTypes::MDL;
378  }
379 
380 
382  template < template < typename > class ALLOC >
384  clearCache();
385  __kmode = KModeTypes::NML;
386  }
387 
388 
390  template < template < typename > class ALLOC >
392  clearCache();
393  __kmode = KModeTypes::NoCorr;
394  }
395 
396 
398  template < template < typename > class ALLOC >
400  NodeId var2) {
401  return score(var1, var2, __empty_conditioning_set);
402  }
403 
404 
406  template < template < typename > class ALLOC >
408  NodeId var1,
409  NodeId var2,
410  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
411  return __NI_score(var1, var2, conditioning_ids)
412  - __K_score(var1, var2, conditioning_ids);
413  }
414 
415 
417  template < template < typename > class ALLOC >
419  NodeId var2,
420  NodeId var3) {
421  return score(var1, var2, var3, __empty_conditioning_set);
422  }
423 
424 
426  template < template < typename > class ALLOC >
428  NodeId var1,
429  NodeId var2,
430  NodeId var3,
431  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
432  return __NI_score(var1, var2, var3, conditioning_ids)
433  + __K_score(var1, var2, var3, conditioning_ids);
434  }
435 
436 
438  template < template < typename > class ALLOC >
439  double CorrectedMutualInformation< ALLOC >::__NI_score(
440  NodeId var_x,
441  NodeId var_y,
442  const std::vector< NodeId, ALLOC< NodeId > >& vars_z) {
443  /*
444  * We have a few partial entropies to compute in order to have the
445  * 2-point mutual information:
446  * I(x;y) = H(x) + H(y) - H(x,y)
447  * correspondingly
448  * I(x;y) = Hx + Hy - Hxy
449  * or
450  * I(x;y|z) = H(x,z) + H(y,z) - H(z) - H(x,y,z)
451  * correspondingly
452  * I(x;y|z) = Hxz + Hyz - Hz - Hxyz
453  * Note that Entropy H is equal to 1/N times the log2Likelihood,
454  * where N is the size of the database.
455  * Remember that we return N times I(x;y|z)
456  */
457 
458  // if the score has already been computed, get its value
459  const IdSet< ALLOC > idset_xyz(var_x, var_y, vars_z, false, false);
460  if (__use_ICache) {
461  try {
462  return __ICache.score(idset_xyz);
463  } catch (const NotFound&) {}
464  }
465 
466  // compute the score
467 
468  // here, we distinguish nodesets with conditioning nodes from those
469  // without conditioning nodes
470  double score;
471  if (!vars_z.empty()) {
472  std::vector< NodeId, ALLOC< NodeId > > vars(vars_z);
473  // std::sort(vars.begin(), vars.end());
474  vars.push_back(var_x);
475  vars.push_back(var_y);
476  const double NHxyz = -__NH.score(IdSet< ALLOC >(vars, false, true));
477 
478  vars.pop_back();
479  const double NHxz = -__NH.score(IdSet< ALLOC >(vars, false, true));
480 
481  vars.pop_back();
482  vars.push_back(var_y);
483  const double NHyz = -__NH.score(IdSet< ALLOC >(vars, false, true));
484 
485  vars.pop_back();
486  const double NHz = -__NH.score(IdSet< ALLOC >(vars, false, true));
487 
488  const double NHxz_NHyz = NHxz + NHyz;
489  double NHz_NHxyz = NHz + NHxyz;
490 
491  // avoid numeric instability due to rounding errors
492  double ratio = 1;
493  if (NHxz_NHyz > 0) {
494  ratio = (NHxz_NHyz - NHz_NHxyz) / NHxz_NHyz;
495  } else if (NHz_NHxyz > 0) {
496  ratio = (NHxz_NHyz - NHz_NHxyz) / NHz_NHxyz;
497  }
498  if (ratio < 0) ratio = -ratio;
499  if (ratio < __threshold) {
500  NHz_NHxyz = NHxz_NHyz; // ensure that the score is equal to 0
501  }
502 
503  score = NHxz_NHyz - NHz_NHxyz;
504  } else {
505  const double NHxy = -__NH.score(
506  IdSet< ALLOC >(var_x, var_y, __empty_conditioning_set, true, false));
507  const double NHx = -__NH.score(var_x);
508  const double NHy = -__NH.score(var_y);
509 
510  double NHx_NHy = NHx + NHy;
511 
512  // avoid numeric instability due to rounding errors
513  double ratio = 1;
514  if (NHx_NHy > 0) {
515  ratio = (NHx_NHy - NHxy) / NHx_NHy;
516  } else if (NHxy > 0) {
517  ratio = (NHx_NHy - NHxy) / NHxy;
518  }
519  if (ratio < 0) ratio = -ratio;
520  if (ratio < __threshold) {
521  NHx_NHy = NHxy; // ensure that the score is equal to 0
522  }
523 
524  score = NHx_NHy - NHxy;
525  }
526 
527 
528  // shall we put the score into the cache?
529  if (__use_ICache) { __ICache.insert(idset_xyz, score); }
530 
531  return score;
532  }
533 
534 
536  template < template < typename > class ALLOC >
537  INLINE double CorrectedMutualInformation< ALLOC >::__NI_score(
538  NodeId var_x,
539  NodeId var_y,
540  NodeId var_z,
541  const std::vector< NodeId, ALLOC< NodeId > >& ui_ids) {
542  // conditional 3-point mutual information formula:
543  // I(x;y;z|{ui}) = I(x;y|{ui}) - I(x;y|z,{ui})
544  std::vector< NodeId, ALLOC< NodeId > > uiz_ids = ui_ids;
545  uiz_ids.push_back(var_z);
546  return __NI_score(var_x, var_y, ui_ids) - __NI_score(var_x, var_y, uiz_ids);
547  }
548 
549 
551  template < template < typename > class ALLOC >
552  double CorrectedMutualInformation< ALLOC >::__K_score(
553  NodeId var1,
554  NodeId var2,
555  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
556  // if no penalty, return 0
557  if (__kmode == KModeTypes::NoCorr) return 0.0;
558 
559 
560  // If using the K cache, verify whether the set isn't already known
561  IdSet< ALLOC > idset;
562  if (__use_KCache) {
563  idset = std::move(IdSet< ALLOC >(var1, var2, conditioning_ids, false));
564  try {
565  return __KCache.score(idset);
566  } catch (const NotFound&) {}
567  }
568 
569  // compute the score
570  double score;
571  size_t rx, ry, rui;
572  switch (__kmode) {
573  case KModeTypes::MDL: {
574  const auto& database = __NH.database();
575  const auto& node2cols = __NH.nodeId2Columns();
576 
577  rui = 1;
578  if (!node2cols.empty()) {
579  rx = database.domainSize(node2cols.second(var1));
580  ry = database.domainSize(node2cols.second(var2));
581  for (const NodeId i : conditioning_ids) {
582  rui *= database.domainSize(node2cols.second(i));
583  }
584  } else {
585  rx = database.domainSize(var1);
586  ry = database.domainSize(var2);
587  for (const NodeId i : conditioning_ids) {
588  rui *= database.domainSize(i);
589  }
590  }
591 
592  // compute the size of the database, including the a priori
593  if (!__use_KCache) {
594  idset = std::move(IdSet< ALLOC >(var1, var2, conditioning_ids, false));
595  }
596  const double N = __score_MDL.N(idset);
597 
598  score = 0.5 * (rx - 1) * (ry - 1) * rui * std::log2(N);
599  } break;
600 
601  case KModeTypes::NML:
602  score = __k_NML.score(var1, var2, conditioning_ids);
603  break;
604 
605  default:
606  GUM_ERROR(NotImplementedYet,
607  "CorrectedMutualInformation mode does "
608  "not support yet this correction");
609  }
610 
611  // shall we put the score into the cache?
612  if (__use_KCache) { __KCache.insert(idset, score); }
613  return score;
614  }
615 
616 
618  template < template < typename > class ALLOC >
619  INLINE double CorrectedMutualInformation< ALLOC >::__K_score(
620  NodeId var1,
621  NodeId var2,
622  NodeId var3,
623  const std::vector< NodeId, ALLOC< NodeId > >& ui_ids) {
624  // k(x;y;z|ui) = k(x;y|ui,z) - k(x;y|ui)
625  std::vector< NodeId, ALLOC< NodeId > > uiz_ids = ui_ids;
626  uiz_ids.push_back(var3);
627  return __K_score(var1, var2, uiz_ids) - __K_score(var1, var2, ui_ids);
628  }
629 
630 
631  } /* namespace learning */
632 
633 } /* namespace gum */
634 
635 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
void useNML()
use the kNML penalty function
void clearRanges()
reset the ranges to the one range corresponding to the whole database
void useCnrCache(bool on_off)
turn on/off the use of the CnrCache (the cache for the Cnr formula)
virtual void useCache(bool on_off)
turn on/off the use of all the caches
void useHCache(bool on_off)
turn on/off the use of the HCache (the cache for the entropies)
virtual ~CorrectedMutualInformation()
destructor
virtual void clear()
clears all the data structures from memory
void clearICache()
clears the ICache (the mutual information cache)
CorrectedMutualInformation< ALLOC > & operator=(const CorrectedMutualInformation< ALLOC > &from)
copy operator
ALLOC< NodeId > allocator_type
type for the allocators passed in arguments of methods
allocator_type getAllocator() const
returns the allocator used by the score
STL namespace.
void useKCache(bool on_off)
turn on/off the use of the KCache (the cache for the penalties)
virtual CorrectedMutualInformation< ALLOC > * clone() const
virtual copy constructor
virtual void clearCache()
clears all the current caches
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
virtual std::size_t minNbRowsPerThread() const
returns the minimum number of rows that each thread should process
const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > & ranges() const
returns the current ranges
double score(NodeId var1, NodeId var2)
returns the 2-point mutual information corresponding to a given nodeset
void setRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
sets new ranges to perform the countings used by the mutual information
CorrectedMutualInformation(const DBRowGeneratorParser< ALLOC > &parser, const Apriori< ALLOC > &apriori, const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > &ranges, const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > &nodeId2columns=Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), const allocator_type &alloc=allocator_type())
default constructor
void clearKCache()
clears the KCache (the cache for the penalties)
void clearHCache()
clears the HCache (the cache for the entropies)
virtual std::size_t nbThreads() const
returns the number of threads used to parse the database
virtual void setMinNbRowsPerThread(const std::size_t nb) const
changes the minimum number of rows a thread should process in a multithreading context ...
void useMDL()
use the MDL penalty function
void useICache(bool on_off)
turn on/off the use of the ICache (the mutual information cache)
void clearCnrCache()
clears the CnrCache (the cache for the Cnr formula)
Size NodeId
Type for node ids.
Definition: graphElements.h:98
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
void useNoCorr()
use no correction/penalty function
virtual void setMaxNbThreads(std::size_t nb) const
changes the max number of threads used to parse the database