aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
correctedMutualInformation_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief The class computing n times the corrected mutual information,
25  * as used in the 3off2 algorithm
26  *
27  * @author Quentin FALCAND, Christophe GONZALES(@AMU) and Pierre-Henri
28  * WUILLEMIN(@LIP6).
29  */
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 
32 namespace gum {
33 
34  namespace learning {
35 
36  /// returns the allocator used by the score
37  template < template < typename > class ALLOC >
38  typename CorrectedMutualInformation< ALLOC >::allocator_type
40  return _NH_.getAllocator();
41  }
42 
43 
44  /// default constructor
45  template < template < typename > class ALLOC >
48  const Apriori< ALLOC >& apriori,
49  const std::vector< std::pair< std::size_t, std::size_t >,
50  ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
56  _KCache_(alloc) {
58  }
59 
60 
61  /// default constructor
62  template < template < typename > class ALLOC >
65  const Apriori< ALLOC >& apriori,
72  }
73 
74 
75  /// copy constructor with a given allocator
76  template < template < typename > class ALLOC >
80  _NH_(from._NH_, alloc),
86  }
87 
88 
89  /// copy constructor
90  template < template < typename > class ALLOC >
94 
95 
96  /// move constructor with a given allocator
97  template < template < typename > class ALLOC >
101  _NH_(std::move(from._NH_), alloc),
107  }
108 
109 
110  /// move constructor
111  template < template < typename > class ALLOC >
115 
116 
117  /// virtual copy constructor with a given allocator
118  template < template < typename > class ALLOC >
120  const typename CorrectedMutualInformation< ALLOC >::allocator_type& alloc) const {
123  try {
125  } catch (...) {
127  throw;
128  }
129 
130  return new_score;
131  }
132 
133 
134  /// virtual copy constructor
135  template < template < typename > class ALLOC >
137  return clone(this->getAllocator());
138  }
139 
140 
141  /// destructor
142  template < template < typename > class ALLOC >
144  // for debugging purposes
146  }
147 
148 
149  /// copy assignment operator
150  template < template < typename > class ALLOC >
153  if (this != &from) {
154  _NH_ = from._NH_;
155  _k_NML_ = from._k_NML_;
157  _kmode_ = from._kmode_;
164  }
165  return *this;
166  }
167 
168 
169  /// move assignment operator
170  template < template < typename > class ALLOC >
173  if (this != &from) {
174  _NH_ = std::move(from._NH_);
177  _kmode_ = from._kmode_;
184  }
185  return *this;
186  }
187 
188 
189  /// turn on/off the use of all the caches
190  template < template < typename > class ALLOC >
192  useICache(on_off);
193  useHCache(on_off);
194  useKCache(on_off);
196  }
197 
198 
199  /// turn on/off the use of the I cache
200  template < template < typename > class ALLOC >
202  if (!on_off) _ICache_.clear();
204  }
205 
206 
207  /// turn on/off the use of the H cache
208  template < template < typename > class ALLOC >
210  if (!on_off) _NH_.clearCache();
213  }
214 
215 
216  /// turn on/off the use of the K cache
217  template < template < typename > class ALLOC >
219  if (!on_off) _KCache_.clear();
221  }
222 
223 
224  /// turn on/off the use of the Cnr cache
225  template < template < typename > class ALLOC >
227  if (!on_off) _k_NML_.clearCache();
230  }
231 
232 
233  /// clears all the data structures from memory
234  template < template < typename > class ALLOC >
236  _NH_.clear();
237  _k_NML_.clear();
238  _score_MDL_.clear();
239  clearCache();
240  }
241 
242 
243  /// clears the current cache (clear nodesets as well)
244  template < template < typename > class ALLOC >
246  _NH_.clearCache();
248  _ICache_.clear();
249  _KCache_.clear();
250  }
251 
252 
253  /// clears the ICache (the mutual information cache)
254  template < template < typename > class ALLOC >
256  _ICache_.clear();
257  }
258 
259 
260  /// clears the HCache (the cache for the entropies)
261  template < template < typename > class ALLOC >
263  _NH_.clearCache();
264  }
265 
266 
267  /// clears the KCache (the cache for the penalties)
268  template < template < typename > class ALLOC >
270  _KCache_.clear();
271  }
272 
273 
274  /// clears the CnrCache (the cache for the Cnr formula)
275  template < template < typename > class ALLOC >
278  }
279 
280 
281  /// changes the max number of threads used to parse the database
282  template < template < typename > class ALLOC >
287  }
288 
289 
290  /// returns the number of threads used to parse the database
291  template < template < typename > class ALLOC >
293  return _NH_.nbThreads();
294  }
295 
296 
297  /** @brief changes the minimum number of rows a thread should process in a
298  * multithreading context */
299  template < template < typename > class ALLOC >
304  }
305 
306 
307  /// returns the minimum number of rows that each thread should process
308  template < template < typename > class ALLOC >
310  return _NH_.minNbRowsPerThread();
311  }
312 
313 
314  /// sets new ranges to perform the countings used by the score
315  /** @param new_ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database row
316  * indices. The countings are then performed only on the union of the
317  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when performing
318  * cross validation tasks, in which part of the database should be ignored.
319  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
320  * the whole database. */
321  template < template < typename > class ALLOC >
322  template < template < typename > class XALLOC >
324  const std::vector< std::pair< std::size_t, std::size_t >,
325  XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges) {
326  std::vector< std::pair< std::size_t, std::size_t >,
327  ALLOC< std::pair< std::size_t, std::size_t > > >
328  old_ranges = ranges(); // by-value copy of the current ranges, kept for the comparison below
329 
333 
334  if (old_ranges != ranges()) clear(); // the ranges actually changed: call clear() on this object
335  }
336 
337 
338  /// reset the ranges to the one range corresponding to the whole database
339  template < template < typename > class ALLOC >
341  std::vector< std::pair< std::size_t, std::size_t >,
342  ALLOC< std::pair< std::size_t, std::size_t > > >
343  old_ranges = ranges(); // by-value copy of the current ranges, kept for the comparison below
344  _NH_.clearRanges();
347  if (old_ranges != ranges()) clear(); // call clear() only if the reset really modified the ranges
348  }
349 
350 
351  /// returns the current ranges
352  template < template < typename > class ALLOC >
353  INLINE const std::vector< std::pair< std::size_t, std::size_t >,
354  ALLOC< std::pair< std::size_t, std::size_t > > >&
356  return _NH_.ranges();
357  }
358 
359 
360  /// use the MDL penalty function
361  template < template < typename > class ALLOC >
363  clearCache();
365  }
366 
367 
368  /// use the kNML penalty function
369  template < template < typename > class ALLOC >
371  clearCache();
373  }
374 
375 
376  /// use no correction/penalty function
377  template < template < typename > class ALLOC >
379  clearCache();
381  }
382 
383 
384  /// returns the 2-point mutual information corresponding to a given nodeset
385  template < template < typename > class ALLOC >
388  }
389 
390 
391  /// returns the 2-point mutual information of var1 and var2 given conditioning_ids
392  template < template < typename > class ALLOC >
394  NodeId var1,
395  NodeId var2,
396  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
398  }
399 
400 
401  /// returns the 3-point mutual information corresponding to a given nodeset
402  template < template < typename > class ALLOC >
403  INLINE double
406  }
407 
408 
409  /// returns the 3-point mutual information of var1, var2 and var3 given conditioning_ids
410  template < template < typename > class ALLOC >
412  NodeId var1,
413  NodeId var2,
414  NodeId var3,
415  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
418  }
419 
420 
421  /// return N times the mutual information for conditioned pairs of variables
     /// @param var_x the first variable of the pair
     /// @param var_y the second variable of the pair
     /// @param vars_z the conditioning variables; when empty, the unconditional
     /// formula I(x;y) is used instead of I(x;y|z)
     /// @return N times I(x;y|z), where N is the size of the database
422  template < template < typename > class ALLOC >
424  NodeId var_x,
425  NodeId var_y,
426  const std::vector< NodeId, ALLOC< NodeId > >& vars_z) {
427  /*
428  * We have a few partial entropies to compute in order to have the
429  * 2-point mutual information:
430  * I(x;y) = H(x) + H(y) - H(x,y)
431  * correspondingly
432  * I(x;y) = Hx + Hy - Hxy
433  * or
434  * I(x;y|z) = H(x,z) + H(y,z) - H(z) - H(x,y,z)
435  * correspondingly
436  * I(x;y|z) = Hxz + Hyz - Hz - Hxyz
437  * Note that Entropy H is equal to -1/N times the log2Likelihood,
438  * where N is the size of the database. This is why each _NH_.score(...)
     * below is negated: the NH* variables hold N times the entropy.
     * NOTE(review): this assumes _NH_.score() returns the log2-likelihood,
     * as the sentence above implies -- confirm against the type of _NH_.
439  * Remember that we return N times I(x;y|z)
440  */
441 
442  // if the score has already been computed, get its value
443  const IdCondSet< ALLOC > idset_xyz(var_x, var_y, vars_z, false, false);
444  if (_use_ICache_) {
445  try {
446  return _ICache_.score(idset_xyz);
447  } catch (const NotFound&) {} // cache miss: deliberately ignored, the score is computed below
448  }
449 
450  // compute the score
451 
452  // here, we distinguish nodesets with conditioning nodes from those
453  // without conditioning nodes
454  double score;
455  if (!vars_z.empty()) {
457  // std::sort(vars.begin(), vars.end());
460  const double NHxyz = -_NH_.score(IdCondSet< ALLOC >(vars, false, true));
461 
462  vars.pop_back();
463  const double NHxz = -_NH_.score(IdCondSet< ALLOC >(vars, false, true));
464 
465  vars.pop_back();
467  const double NHyz = -_NH_.score(IdCondSet< ALLOC >(vars, false, true));
468 
469  vars.pop_back();
470  const double NHz = -_NH_.score(IdCondSet< ALLOC >(vars, false, true));
471 
     // positive and negative parts of N * I(x;y|z) = (NHxz + NHyz) - (NHz + NHxyz)
472  const double NHxz_NHyz = NHxz + NHyz;
473  double NHz_NHxyz = NHz + NHxyz;
474 
475  // avoid numeric instability due to rounding errors
     // (when |ratio| is below _threshold_, both sums are made equal)
476  double ratio = 1;
477  if (NHxz_NHyz > 0) {
479  } else if (NHz_NHxyz > 0) {
481  }
482  if (ratio < 0) ratio = -ratio;
483  if (ratio < _threshold_) {
484  NHz_NHxyz = NHxz_NHyz; // ensure that the score is equal to 0
485  }
486 
488  } else {
489  const double NHxy
491  const double NHx = -_NH_.score(var_x);
492  const double NHy = -_NH_.score(var_y);
493 
494  double NHx_NHy = NHx + NHy;
495 
496  // avoid numeric instability due to rounding errors
     // ratio is the difference (NHx + NHy) - NHxy divided by whichever of
     // the two terms is positive; it stays at 1 if neither is positive
497  double ratio = 1;
498  if (NHx_NHy > 0) {
499  ratio = (NHx_NHy - NHxy) / NHx_NHy;
500  } else if (NHxy > 0) {
501  ratio = (NHx_NHy - NHxy) / NHxy;
502  }
503  if (ratio < 0) ratio = -ratio;
504  if (ratio < _threshold_) {
505  NHx_NHy = NHxy; // ensure that the score is equal to 0
506  }
507 
508  score = NHx_NHy - NHxy; // N * I(x;y) = N*Hx + N*Hy - N*Hxy
509  }
510 
511 
512  // shall we put the score into the cache?
514 
515  return score;
516  }
517 
518 
519  /// return N times the mutual information for conditioned triples of variables
520  template < template < typename > class ALLOC >
522  NodeId var_x,
523  NodeId var_y,
524  NodeId var_z,
525  const std::vector< NodeId, ALLOC< NodeId > >& ui_ids) {
526  // conditional 3-point mutual information formula:
527  // I(x;y;z|{ui}) = I(x;y|{ui}) - I(x;y|z,{ui})
531  }
532 
533 
534  /// 2pt penalty: returns the correction term selected by _kmode_ for the
     /// pair (var1, var2) given conditioning_ids (0 when no correction is used)
535  template < template < typename > class ALLOC >
537  NodeId var1,
538  NodeId var2,
539  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids) {
540  // if no penalty, return 0
541  if (_kmode_ == KModeTypes::NoCorr) return 0.0;
542 
543 
544  // If using the K cache, verify whether the set isn't already known
545  IdCondSet< ALLOC > idset;
546  if (_use_KCache_) {
548  try {
549  return _KCache_.score(idset);
550  } catch (const NotFound&) {} // cache miss: deliberately ignored, the penalty is computed below
551  }
552 
553  // compute the score
554  double score;
555  size_t rx, ry, rui;
556  switch (_kmode_) {
557  case KModeTypes::MDL: {
558  const auto& database = _NH_.database();
559  const auto& node2cols = _NH_.nodeId2Columns();
560 
561  rui = 1; // multiplied below by one factor per conditioning variable
562  if (!node2cols.empty()) {
565  for (const NodeId i: conditioning_ids) {
567  }
568  } else {
571  for (const NodeId i: conditioning_ids) {
572  rui *= database.domainSize(i); // empty node2cols: the node id is passed directly to the database
573  }
574  }
575 
576  // compute the size of the database, including the a priori
577  if (!_use_KCache_) {
579  }
580  const double N = _score_MDL_.N(idset);
581 
     // MDL penalty: 1/2 * (rx - 1) * (ry - 1) * rui * log2(N)
     // NOTE(review): rx and ry are assigned on lines not visible here;
     // presumably the domain sizes of var1 and var2 -- confirm
582  score = 0.5 * (rx - 1) * (ry - 1) * rui * std::log2(N);
583  } break;
584 
585  case KModeTypes::NML:
587  break;
588 
589  default:
     // NOTE(review): presumably an error is raised here with the message below
591  "CorrectedMutualInformation mode does "
592  "not support yet this correction");
593  }
594 
595  // shall we put the score into the cache?
597  return score;
598  }
599 
600 
601  /// 3pt penalty
602  template < template < typename > class ALLOC >
604  NodeId var1,
605  NodeId var2,
606  NodeId var3,
607  const std::vector< NodeId, ALLOC< NodeId > >& ui_ids) {
608  // k(x;y;z|ui) = k(x;y|ui,z) - k(x;y|ui)
612  }
613 
614 
615  } /* namespace learning */
616 
617 } /* namespace gum */
618 
619 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)