aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
correctedMutualInformation.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief The class computing n times the corrected mutual information,
25  * as used in the 3off2 algorithm
26  *
27  * @author Quentin FALCAND, Christophe GONZALES(@AMU) and Pierre-Henri
28  * WUILLEMIN(@LIP6).
29  */
30 #ifndef GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H
31 #define GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H
32 
33 #include <agrum/config.h>
34 #include <agrum/tools/core/math/math_utils.h>
35 #include <vector>
36 
37 #include <agrum/tools/stattests/kNML.h>
38 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>
39 #include <agrum/BN/learning/scores_and_tests/scoreMDL.h>
40 
41 namespace gum {
42 
43  namespace learning {
44 
45  /**
46  * @class CorrectedMutualInformation
47  * @brief The class computing n times the corrected mutual information,
48  * as used in the 3off2 algorithm
49  *
50  * This class handles the computations and storage of the mutual information
51  * values used in 3off2 and potential corrections.
52  *
53  * @ingroup learning_scores
54  */
55  template < template < typename > class ALLOC = std::allocator >
56  class CorrectedMutualInformation {
57  public:
58  /// type for the allocators passed in arguments of methods
59  using allocator_type = ALLOC< NodeId >;
60 
61  // ##########################################################################
62  /// @name Constructors / Destructors
63  // ##########################################################################
64  /// @{
65 
66  /// default constructor, restricting the countings to the given ranges
67  /** @param parser the parser used to parse the database
68  * @param apriori An apriori that we add to the computation of
69  * the score (this should come from expert knowledge): this consists in
70  * adding numbers to countings in the contingency tables
71  * @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
72  * indices. The countings are then performed only on the union of the
73  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
74  * cross validation tasks, in which part of the database should be ignored.
75  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
76  * the whole database.
77  * @param nodeId2columns a mapping from the ids of the nodes in the
78  * graphical model to the corresponding column in the DatabaseTable
79  * parsed by the parser. This enables estimating from a database in
80  * which variable A corresponds to the 2nd column the parameters of a BN
81  * in which variable A has a NodeId of 5. An empty nodeId2columns
82  * bijection means that the mapping is an identity, i.e., the value of a
83  * NodeId is equal to the index of the column in the DatabaseTable.
84  * @param alloc the allocator used to allocate the structures within the
85  * Score.
86  * @warning If nodeId2columns is not empty, then only the scores over the
87  * ids belonging to this bijection can be computed: applying method
88  * score() over other ids will raise exception NotFound. */
89  CorrectedMutualInformation(
90  const DBRowGeneratorParser< ALLOC >& parser,
91  const Apriori< ALLOC >& apriori,
92  const std::vector< std::pair< std::size_t, std::size_t >,
93  ALLOC< std::pair< std::size_t, std::size_t > > >&
94  ranges,
95  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
96  nodeId2columns
97  = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
98  const allocator_type& alloc = allocator_type());
99 
100  /// default constructor, performing the countings over the whole database
101  /** @param parser the parser used to parse the database
102  * @param apriori An apriori that we add to the computation of
103  * the score (this should come from expert knowledge): this consists in
104  * adding numbers to countings in the contingency tables
105  * @param nodeId2columns a mapping from the ids of the nodes in the
106  * graphical model to the corresponding column in the DatabaseTable
107  * parsed by the parser. This enables estimating from a database in
108  * which variable A corresponds to the 2nd column the parameters of a BN
109  * in which variable A has a NodeId of 5. An empty nodeId2columns
110  * bijection means that the mapping is an identity, i.e., the value of a
111  * NodeId is equal to the index of the column in the DatabaseTable.
112  * @param alloc the allocator used to allocate the structures within the
113  * Score.
114  * @warning If nodeId2columns is not empty, then only the scores over the
115  * ids belonging to this bijection can be computed: applying method
116  * score() over other ids will raise exception NotFound. */
117  CorrectedMutualInformation(
118  const DBRowGeneratorParser< ALLOC >& parser,
119  const Apriori< ALLOC >& apriori,
120  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
121  nodeId2columns
122  = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
123  const allocator_type& alloc = allocator_type());
124 
125  /// copy constructor
126  CorrectedMutualInformation(const CorrectedMutualInformation< ALLOC >& from);
127 
128  /// copy constructor with a given allocator
129  CorrectedMutualInformation(const CorrectedMutualInformation< ALLOC >& from,
130  const allocator_type& alloc);
131 
132  /// move constructor
133  CorrectedMutualInformation(CorrectedMutualInformation< ALLOC >&& from);
134 
135  /// move constructor with a given allocator
136  CorrectedMutualInformation(CorrectedMutualInformation< ALLOC >&& from,
137  const allocator_type& alloc);
138 
139  /// virtual copy constructor
140  virtual CorrectedMutualInformation< ALLOC >* clone() const;
141 
142  /// virtual copy constructor with a given allocator
143  virtual CorrectedMutualInformation< ALLOC >*
144  clone(const allocator_type& alloc) const;
145 
146  /// destructor
147  virtual ~CorrectedMutualInformation();
148 
149  /// @}
150 
151 
152  // ##########################################################################
153  /// @name Operators
154  // ##########################################################################
155 
156  /// @{
157 
158  /// copy operator
159  CorrectedMutualInformation< ALLOC >&
160  operator=(const CorrectedMutualInformation< ALLOC >& from);
161 
162  /// move operator
163  CorrectedMutualInformation< ALLOC >&
164  operator=(CorrectedMutualInformation< ALLOC >&& from);
165 
166  /// @}
167 
168 
169  // ##########################################################################
170  /// @name caching functions
171  // ##########################################################################
172  /// @{
173 
174  /// clears all the data structures from memory
175  virtual void clear();
176 
177  /// clears all the current caches
178  /** There are 4 caches in the CorrectedMutualInformation class:
179  * # The I cache is intended to cache the computations of the mutual
180  * informations used by 3off2
181  * # the H cache is intended to store the results of the computations
182  * of the entropies used in the mutual information formula
183  * # the K cache is intended to store the penalties computed so far
184  * # the Cnr cache is intended to store the results of the computations
185  * of the Cnr formula used by the kNML penalty */
186  virtual void clearCache();
187 
188  /// turn on/off the use of all the caches
189  /** There are 4 caches in the CorrectedMutualInformation class:
190  * # The I cache is intended to cache the computations of the mutual
191  * informations used by 3off2
192  * # the H cache is intended to store the results of the computations
193  * of the entropies used in the mutual information formula
194  * # the K cache is intended to store the penalties computed so far
195  * # the Cnr cache is intended to store the results of the computations
196  * of the Cnr formula used by the kNML penalty */
197  virtual void useCache(bool on_off);
198 
199  /// turn on/off the use of the ICache (the mutual information cache)
200  void useICache(bool on_off);
201 
202  /// clears the ICache (the mutual information cache)
203  void clearICache();
204 
205  /// turn on/off the use of the HCache (the cache for the entropies)
206  void useHCache(bool on_off);
207 
208  /// clears the HCache (the cache for the entropies)
209  void clearHCache();
210 
211  /// turn on/off the use of the KCache (the cache for the penalties)
212  void useKCache(bool on_off);
213 
214  /// clears the KCache (the cache for the penalties)
215  void clearKCache();
216 
217  /// turn on/off the use of the CnrCache (the cache for the Cnr formula)
218  void useCnrCache(bool on_off);
219 
220  /// clears the CnrCache (the cache for the Cnr formula)
221  void clearCnrCache();
222 
223  /// @}
224 
225 
226  // ##########################################################################
227  /// @name score functions
228  // ##########################################################################
229  /// @{
230 
231  /// returns the 2-point mutual information corresponding to a given nodeset
232  double score(NodeId var1, NodeId var2);
233 
234  /// returns the 2-point mutual information of var1 and var2 given the conditioning set conditioning_ids
235  double score(NodeId var1,
236  NodeId var2,
237  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids);
238 
239  /// returns the 3-point mutual information corresponding to a given nodeset
240  double score(NodeId var1, NodeId var2, NodeId var3);
241 
242  /// returns the 3-point mutual information of var1, var2 and var3 given the conditioning set conditioning_ids
243  double score(NodeId var1,
244  NodeId var2,
245  NodeId var3,
246  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids);
247 
248  /// @}
249 
250 
251  // ##########################################################################
252  /// @name Accessors / Modifiers
253  // ##########################################################################
254  /// @{
255 
256  /// use the MDL penalty function
257  void useMDL();
258 
259  /// use the kNML penalty function
260  void useNML();
261 
262  /// use no correction/penalty function
263  void useNoCorr();
264 
265  /// changes the max number of threads used to parse the database
266  virtual void setMaxNbThreads(std::size_t nb) const;
267 
268  /// returns the number of threads used to parse the database
269  virtual std::size_t nbThreads() const;
270 
271  /** @brief changes the minimum number of rows a thread should process in a
272  * multithreading context
273  *
274  * When computing score, several threads are used by record counters to
275  * perform countings on the rows of the database, the MinNbRowsPerThread
276  * method indicates how many rows each thread should at least process.
277  * This is used to compute the number of threads actually run. This number
278  * is equal to the min between the max number of threads allowed and the
279  * number of records in the database divided by nb. */
280  virtual void setMinNbRowsPerThread(const std::size_t nb) const;
281 
282  /// returns the minimum number of rows that each thread should process
283  virtual std::size_t minNbRowsPerThread() const;
284 
285  /// sets new ranges to perform the countings used by the mutual information
286  /** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
287  * indices. The countings are then performed only on the union of the
288  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
289  * cross validation tasks, in which part of the database should be ignored.
290  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
291  * the whole database. */
292  template < template < typename > class XALLOC >
293  void setRanges(
294  const std::vector< std::pair< std::size_t, std::size_t >,
295  XALLOC< std::pair< std::size_t, std::size_t > > >&
296  new_ranges);
297 
298  /// reset the ranges to the one range corresponding to the whole database
299  void clearRanges();
300 
301  /// returns the current ranges
302  const std::vector< std::pair< std::size_t, std::size_t >,
303  ALLOC< std::pair< std::size_t, std::size_t > > >&
304  ranges() const;
305 
306 
307  /// returns the allocator used by the score
308  allocator_type getAllocator() const;
309 
310  /// @}
311 
312 
313  /// the description type for the complexity correction (penalty) of the mutual information
314  enum class KModeTypes
315  {
316  MDL,  ///< the MDL penalty function
317  NML,  ///< the kNML penalty function
318  NoCorr  ///< no correction/penalty function
319  };
320 
321 
322 #ifndef DOXYGEN_SHOULD_SKIP_THIS
323 
324  private:
325  /// The object to compute N times Entropy H used by mutual information I
326  /* Note that the log2-likelihood is equal to N times the entropy H */
327  ScoreLog2Likelihood< ALLOC > NH__;
328 
329  /// the object computing the NML k score
330  KNML< ALLOC > k_NML__;
331 
332  /** @brief a score MDL used to compute the size N of the database,
333  * including the a priori */
334  ScoreMDL< ALLOC > score_MDL__;
335 
336  /// the mode used for the correction
337  KModeTypes kmode__{KModeTypes::MDL};
338 
339 
340  /// a Boolean indicating whether we wish to use the I cache
341  /** The I cache is the cache used to store N times the values of
342  * mutual informations */
343  bool use_ICache__{true};
344 
345  /// a Boolean indicating whether we wish to use the H cache
346  /** The H cache is the cache for storing N times the entropy. Mutual
347  * information is computed as a summation/subtraction of entropies. The
348  * latter are cached directly within the NH__ instance. */
349  bool use_HCache__{true};
350 
351  /// a Boolean indicating whether we wish to use the K cache
352  /** The K cache is used to cache K-scores, which corresponds to
353  * summations/subtractions of kNML individual values. The cache for the
354  * latter is called the Cnr cache because it uses Cnr values */
355  bool use_KCache__{true};
356 
357  /// a Boolean indicating whether we wish to use the Cnr cache
358  /** When using the kNML class, the computation of the K-scores
359  * consists of summations/subtractions of kNML scores. The latter
360  * essentially amount to computing Cnr values. Those can be
361  * cached directly within the k_NML__ instance */
362  bool use_CnrCache__{true};
363 
364 
365  /// the ICache
366  ScoringCache< ALLOC > ICache__;
367 
368  /// the KCache
369  ScoringCache< ALLOC > KCache__;
370 
371 
372  /// an empty conditioning set (NOTE(review): declaration restored from a gap in the listing -- confirm the identifier against correctedMutualInformation_tpl.h)
373  const std::vector< NodeId, ALLOC< NodeId > > empty_conditioning_set__;
374 
375  /// a constant used to prevent numerical instabilities
376  const double threshold__{1e-10};
377 
378 
379  /// returns the 2-point mutual information corresponding to a given nodeset
380  double NI_score__(NodeId var_x,
381  NodeId var_y,
382  const std::vector< NodeId, ALLOC< NodeId > >& vars_z);
383 
384  /// returns the 3-point mutual information corresponding to a given nodeset
385  double NI_score__(NodeId var_x,
386  NodeId var_y,
387  NodeId var_z,
388  const std::vector< NodeId, ALLOC< NodeId > >& vars_ui);
389 
390  /// computes the complexity correction for the mutual information
391  double K_score__(NodeId var_x,
392  NodeId var_y,
393  const std::vector< NodeId, ALLOC< NodeId > >& vars_z);
394 
395  /// computes the complexity correction for the mutual information
396  double K_score__(NodeId var_x,
397  NodeId var_y,
398  NodeId var_z,
399  const std::vector< NodeId, ALLOC< NodeId > >& vars_ui);
400 
401 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
402  };
403 
404  } /* namespace learning */
405 
406 } /* namespace gum */
407 
408 
409 #ifndef GUM_NO_EXTERN_TEMPLATE_CLASS
410 extern template class gum::learning::CorrectedMutualInformation<>;
411 #endif
412 
413 
414 // always include the template implementation
415 #include <agrum/tools/stattests/correctedMutualInformation_tpl.h>
416 
417 #endif /* GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H */
void useNML()
use the kNML penalty function
void clearRanges()
reset the ranges to the one range corresponding to the whole database
KModeTypes
the description type for the complexity correction
void useCnrCache(bool on_off)
turn on/off the use of the CnrCache (the cache for the Cnr formula)
virtual void useCache(bool on_off)
turn on/off the use of all the caches
void useHCache(bool on_off)
turn on/off the use of the HCache (the cache for the entropies)
The class computing n times the corrected mutual information, as used in the 3off2 algorithm...
double score(NodeId var1, NodeId var2, const std::vector< NodeId, ALLOC< NodeId > > &conditioning_ids)
returns the 2-point mutual information corresponding to a given nodeset
virtual ~CorrectedMutualInformation()
destructor
virtual void clear()
clears all the data structures from memory
void clearICache()
clears the ICache (the mutual information cache)
CorrectedMutualInformation< ALLOC > & operator=(const CorrectedMutualInformation< ALLOC > &from)
copy operator
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
allocator_type getAllocator() const
returns the allocator used by the score
void useKCache(bool on_off)
turn on/off the use of the KCache (the cache for the penalties)
virtual CorrectedMutualInformation< ALLOC > * clone() const
virtual copy constructor
virtual void clearCache()
clears all the current caches
CorrectedMutualInformation(const CorrectedMutualInformation< ALLOC > &from)
copy constructor
virtual CorrectedMutualInformation< ALLOC > * clone(const allocator_type &alloc) const
virtual copy constructor with a given allocator
virtual std::size_t minNbRowsPerThread() const
returns the minimum of rows that each thread should process
CorrectedMutualInformation(const DBRowGeneratorParser< ALLOC > &parser, const Apriori< ALLOC > &apriori, const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > &nodeId2columns=Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), const allocator_type &alloc=allocator_type())
default constructor
const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > & ranges() const
returns the current ranges
double score(NodeId var1, NodeId var2)
returns the 2-point mutual information corresponding to a given nodeset
void setRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
sets new ranges to perform the countings used by the mutual information
CorrectedMutualInformation(CorrectedMutualInformation< ALLOC > &&from)
move constructor
double score(NodeId var1, NodeId var2, NodeId var3, const std::vector< NodeId, ALLOC< NodeId > > &conditioning_ids)
returns the 3-point mutual information corresponding to a given nodeset
CorrectedMutualInformation(const DBRowGeneratorParser< ALLOC > &parser, const Apriori< ALLOC > &apriori, const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > &ranges, const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > &nodeId2columns=Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), const allocator_type &alloc=allocator_type())
default constructor
double score(NodeId var1, NodeId var2, NodeId var3)
returns the 3-point mutual information corresponding to a given nodeset
CorrectedMutualInformation(CorrectedMutualInformation< ALLOC > &&from, const allocator_type &alloc)
move constructor with a given allocator
void clearKCache()
clears the KCache (the cache for the penalties)
CorrectedMutualInformation(const CorrectedMutualInformation< ALLOC > &from, const allocator_type &alloc)
copy constructor with a given allocator
void clearHCache()
clears the HCache (the cache for the entropies)
virtual std::size_t nbThreads() const
returns the number of threads used to parse the database
virtual void setMinNbRowsPerThread(const std::size_t nb) const
changes the number min of rows a thread should process in a multithreading context ...
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
void useMDL()
use the MDL penalty function
void useICache(bool on_off)
turn on/off the use of the ICache (the mutual information cache)
void clearCnrCache()
clears the CnrCache (the cache for the Cnr formula)
CorrectedMutualInformation< ALLOC > & operator=(CorrectedMutualInformation< ALLOC > &&from)
move operator
void useNoCorr()
use no correction/penalty function
virtual void setMaxNbThreads(std::size_t nb) const
changes the max number of threads used to parse the database