aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
correctedMutualInformation.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief The class computing n times the corrected mutual information,
25  * as used in the 3off2 algorithm
26  *
27  * @author Quentin FALCAND, Christophe GONZALES(@AMU) and Pierre-Henri
28  * WUILLEMIN(@LIP6).
29  */
30 #ifndef GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H
31 #define GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H
32 
33 #include <agrum/config.h>
34 #include <agrum/tools/core/math/math_utils.h>
35 #include <vector>
36 
37 #include <agrum/tools/stattests/kNML.h>
38 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>
39 #include <agrum/BN/learning/scores_and_tests/scoreMDL.h>
40 
41 namespace gum {
42 
43  namespace learning {
44 
45  /**
46  * @class CorrectedMutualInformation
47  * @brief The class computing n times the corrected mutual information,
48  * as used in the 3off2 algorithm
49  *
50  * This class handles the computations and storage of the mutual information
51  * values used in 3off2 and potential corrections.
    *
    * The correction (penalty) applied is chosen with useMDL(), useNML() or
    * useNoCorr(); the mode member is initialised to KModeTypes::MDL. Four
    * caches (I, H, K and Cnr) can be switched on/off and cleared
    * independently; all four "use cache" flags are initialised to true.
    *
    * NOTE(review): this listing is an extract in which some declaration lines
    * are absent (the embedded listing numbers skip 56, 59, 89, 94, 96, 115,
    * 118, 155-156, 159, 302 and 367). Presumably these are the class head, the
    * allocator_type alias, the heads/parameters of the two constructors, the
    * copy/move assignment operators, getAllocator() and the empty conditioning
    * set member -- confirm against the original header before editing code.
52  *
53  * @ingroup learning_scores
54  */
55  template < template < typename > class ALLOC = std::allocator >
57  public:
58  /// type for the allocators passed in arguments of methods
60 
61  // ##########################################################################
62  /// @name Constructors / Destructors
63  // ##########################################################################
64  /// @{
65 
66  /// default constructor, with the countings restricted to a set of row ranges
67  /** @param parser the parser used to parse the database
68  * @param apriori An apriori that we add to the computation of
69  * the score (this should come from expert knowledge): this consists in
70  * adding numbers to countings in the contingency tables
71  * @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
72  * indices. The countings are then performed only on the union of the
73  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
74  * cross validation tasks, in which part of the database should be ignored.
75  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
76  * the whole database.
77  * @param nodeId2Columns a mapping from the ids of the nodes in the
78  * graphical model to the corresponding column in the DatabaseTable
79  * parsed by the parser. This enables estimating from a database in
80  * which variable A corresponds to the 2nd column the parameters of a BN
81  * in which variable A has a NodeId of 5. An empty nodeId2Columns
82  * bijection means that the mapping is an identity, i.e., the value of a
83  * NodeId is equal to the index of the column in the DatabaseTable.
84  * @param alloc the allocator used to allocate the structures within the
85  * Score.
86  * @warning If nodeId2columns is not empty, then only the scores over the
87  * ids belonging to this bijection can be computed: applying method
88  * score() over other ids will raise exception NotFound. */
90  const DBRowGeneratorParser< ALLOC >& parser,
91  const Apriori< ALLOC >& apriori,
92  const std::vector< std::pair< std::size_t, std::size_t >,
93  ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
95  = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
97 
98  /// default constructor (this overload takes no ranges argument)
99  /** @param parser the parser used to parse the database
100  * @param apriori An apriori that we add to the computation of
101  * the score (this should come from expert knowledge): this consists in
102  * adding numbers to countings in the contingency tables
103  * @param nodeId2Columns a mapping from the ids of the nodes in the
104  * graphical model to the corresponding column in the DatabaseTable
105  * parsed by the parser. This enables estimating from a database in
106  * which variable A corresponds to the 2nd column the parameters of a BN
107  * in which variable A has a NodeId of 5. An empty nodeId2Columns
108  * bijection means that the mapping is an identity, i.e., the value of a
109  * NodeId is equal to the index of the column in the DatabaseTable.
110  * @param alloc the allocator used to allocate the structures within the
111  * Score.
112  * @warning If nodeId2columns is not empty, then only the scores over the
113  * ids belonging to this bijection can be computed: applying method
114  * score() over other ids will raise exception NotFound. */
116  const DBRowGeneratorParser< ALLOC >& parser,
117  const Apriori< ALLOC >& apriori,
119  = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
120  const allocator_type& alloc = allocator_type());
121 
122  /// copy constructor
123  CorrectedMutualInformation(const CorrectedMutualInformation< ALLOC >& from);
124 
125  /// copy constructor with a given allocator
126  CorrectedMutualInformation(const CorrectedMutualInformation< ALLOC >& from,
127  const allocator_type& alloc);
128 
129  /// move constructor
130  CorrectedMutualInformation(CorrectedMutualInformation< ALLOC >&& from);
131 
132  /// move constructor with a given allocator
133  CorrectedMutualInformation(CorrectedMutualInformation< ALLOC >&& from,
134  const allocator_type& alloc);
135 
136  /// virtual copy constructor
137  virtual CorrectedMutualInformation< ALLOC >* clone() const;
138 
139  /// virtual copy constructor with a given allocator
140  virtual CorrectedMutualInformation< ALLOC >* clone(const allocator_type& alloc) const;
141 
142  /// destructor
143  virtual ~CorrectedMutualInformation();
144 
145  /// @}
146 
147 
148  // ##########################################################################
149  /// @name Operators
150  // ##########################################################################
151 
152  /// @{
153 
154  /// copy operator
157 
158  /// move operator
160 
161  /// @}
162 
163 
164  // ##########################################################################
165  /// @name caching functions
166  // ##########################################################################
167  /// @{
168 
169  /// clears all the data structures from memory
170  virtual void clear();
171 
172  /// clears all the current caches
173  /** There are 4 caches in the CorrectedMutualInformation class:
174  * # The I cache is intended to cache the computations of the mutual
175  * informations used by 3off2
176  * # the H cache is intended to store the results of the computations
177  * of the entropies used in the mutual information formula
178  * # the K cache is intended to store the penalties computed so far
179  * # the Cnr cache is intended to store the results of the computations
180  * of the Cnr formula used by the kNML penalty */
181  virtual void clearCache();
182 
183  /// turn on/off the use of all the caches
184  /** There are 4 caches in the CorrectedMutualInformation class:
185  * # The I cache is intended to cache the computations of the mutual
186  * informations used by 3off2
187  * # the H cache is intended to store the results of the computations
188  * of the entropies used in the mutual information formula
189  * # the K cache is intended to store the penalties computed so far
190  * # the Cnr cache is intended to store the results of the computations
191  * of the Cnr formula used by the kNML penalty
    * @param on_off true to turn the caches on, false to turn them off */
192  virtual void useCache(bool on_off);
193 
194  /// turn on/off the use of the ICache (the mutual information cache)
195  void useICache(bool on_off);
196 
197  /// clears the ICache (the mutual information cache)
198  void clearICache();
199 
200  /// turn on/off the use of the HCache (the cache for the entropies)
201  void useHCache(bool on_off);
202 
203  /// clears the HCache (the cache for the entropies)
204  void clearHCache();
205 
206  /// turn on/off the use of the KCache (the cache for the penalties)
207  void useKCache(bool on_off);
208 
209  /// clears the KCache (the cache for the penalties)
210  void clearKCache();
211 
212  /// turn on/off the use of the CnrCache (the cache for the Cnr formula)
213  void useCnrCache(bool on_off);
214 
215  /// clears the CnrCache (the cache for the Cnr formula)
216  void clearCnrCache();
217 
218  /// @}
219 
220 
221  // ##########################################################################
222  /// @name score functions
223  // ##########################################################################
224  /// @{
225 
226  /// returns the 2-point mutual information for the nodeset (var1, var2)
227  double score(NodeId var1, NodeId var2);
228 
229  /// returns the 2-point mutual information for the nodeset (var1, var2) with the conditioning ids conditioning_ids
230  double score(NodeId var1,
231  NodeId var2,
232  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids);
233 
234  /// returns the 3-point mutual information for the nodeset (var1, var2, var3)
235  double score(NodeId var1, NodeId var2, NodeId var3);
236 
237  /// returns the 3-point mutual information for the nodeset (var1, var2, var3) with the conditioning ids conditioning_ids
238  double score(NodeId var1,
239  NodeId var2,
240  NodeId var3,
241  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids);
242 
243  /// @}
244 
245 
246  // ##########################################################################
247  /// @name Accessors / Modifiers
248  // ##########################################################################
249  /// @{
250 
251  /// use the MDL penalty function
252  void useMDL();
253 
254  /// use the kNML penalty function
255  void useNML();
256 
257  /// use no correction/penalty function
258  void useNoCorr();
259 
    // NOTE(review): setMaxNbThreads() and setMinNbRowsPerThread() below are
    // declared const although they are setters -- presumably they only forward
    // to internal counting objects; confirm in the _tpl.h implementation.
260  /// changes the max number of threads used to parse the database
261  virtual void setMaxNbThreads(std::size_t nb) const;
262 
263  /// returns the number of threads used to parse the database
264  virtual std::size_t nbThreads() const;
265 
266  /** @brief changes the minimum number of rows a thread should process in a
267  * multithreading context
268  *
269  * When computing score, several threads are used by record counters to
270  * perform countings on the rows of the database, the MinNbRowsPerThread
271  * method indicates how many rows each thread should at least process.
272  * This is used to compute the number of threads actually run. This number
273  * is equal to the min between the max number of threads allowed and the
274  * number of records in the database divided by nb. */
275  virtual void setMinNbRowsPerThread(const std::size_t nb) const;
276 
277  /// returns the minimum number of rows that each thread should process
278  virtual std::size_t minNbRowsPerThread() const;
279 
280  /// sets new ranges to perform the countings used by the mutual information
281  /** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
282  * indices. The countings are then performed only on the union of the
283  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
284  * cross validation tasks, in which part of the database should be ignored.
285  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
286  * the whole database.
    * @note in the declaration below the parameter is named new_ranges, and
    * its vector may use an allocator template XALLOC different from ALLOC. */
287  template < template < typename > class XALLOC >
288  void setRanges(
289  const std::vector< std::pair< std::size_t, std::size_t >,
290  XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);
291 
292  /// reset the ranges to the one range corresponding to the whole database
293  void clearRanges();
294 
295  /// returns the current ranges
296  const std::vector< std::pair< std::size_t, std::size_t >,
297  ALLOC< std::pair< std::size_t, std::size_t > > >&
298  ranges() const;
299 
300 
301  /// returns the allocator used by the score
303 
304  /// @}
305 
306 
307  /// the description type for the complexity correction
308  enum class KModeTypes
309  {
310  MDL,   ///< the MDL penalty function (see useMDL())
311  NML,   ///< the kNML penalty function (see useNML())
312  NoCorr ///< no correction/penalty function (see useNoCorr())
313  };
314 
315 
316 #ifndef DOXYGEN_SHOULD_SKIP_THIS
317 
318  private:
319  /// The object to compute N times Entropy H used by mutual information I
320  /* Note that the log2-likelihood is equal to N times the entropy H */
321  ScoreLog2Likelihood< ALLOC > _NH_;
322 
323  /// the object computing the NML k score
324  KNML< ALLOC > _k_NML_;
325 
326  /** @brief a score MDL used to compute the size N of the database,
327  * including the a priori */
328  ScoreMDL< ALLOC > _score_MDL_;
329 
330  /// the mode used for the correction (initialised to the MDL penalty)
331  KModeTypes _kmode_{KModeTypes::MDL};
332 
333 
334  /// a Boolean indicating whether we wish to use the I cache
335  /** The I cache is the cache used to store N times the values of
336  * mutual informations */
337  bool _use_ICache_{true};
338 
339  /// a Boolean indicating whether we wish to use the H cache
340  /** The H cache is the cache for storing N times the entropy. Mutual
341  * information is computed as a summation/subtraction of entropies. The
342  * latter are cached directly within the _NH_ instance. */
343  bool _use_HCache_{true};
344 
345  /// a Boolean indicating whether we wish to use the K cache
346  /** The K cache is used to cache K-scores, which corresponds to
347  * summations/subtractions of kNML individual values. The cache for the
348  * latter is called the Cnr cache because it uses Cnr values */
349  bool _use_KCache_{true};
350 
351  /// a Boolean indicating whether we wish to use the Cnr cache
352  /** When using the kNML class, the computation of the K-scores
353  * consists of summations/subtractions of kNML scores. The latter
354  * essentially amount to computing Cnr values. Those can be
355  * cached directly within the _k_NML_ instance */
356  bool _use_CnrCache_{true};
357 
358 
359  /// the ICache (see _use_ICache_)
360  ScoringCache< ALLOC > _ICache_;
361 
362  /// the KCache (see _use_KCache_)
363  ScoringCache< ALLOC > _KCache_;
364 
365 
366  /// an empty conditioning set
368 
369  /// a constant used to prevent numerical instabilities
370  const double _threshold_{1e-10};
371 
372 
373  /// returns the 2-point mutual information for (var_x, var_y) and the node set vars_z
374  double _NI_score_(NodeId var_x,
375  NodeId var_y,
376  const std::vector< NodeId, ALLOC< NodeId > >& vars_z);
377 
378  /// returns the 3-point mutual information for (var_x, var_y, var_z) and the node set vars_ui
379  double _NI_score_(NodeId var_x,
380  NodeId var_y,
381  NodeId var_z,
382  const std::vector< NodeId, ALLOC< NodeId > >& vars_ui);
383 
384  /// computes the complexity correction for the 2-point mutual information
385  double _K_score_(NodeId var_x,
386  NodeId var_y,
387  const std::vector< NodeId, ALLOC< NodeId > >& vars_z);
388 
389  /// computes the complexity correction for the 3-point mutual information
390  double _K_score_(NodeId var_x,
391  NodeId var_y,
392  NodeId var_z,
393  const std::vector< NodeId, ALLOC< NodeId > >& vars_ui);
394 
395 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
396  };
397 
398  } /* namespace learning */
399 
400 } /* namespace gum */
401 
402 
// Explicit instantiation declaration of the class for its default template
// argument (ALLOC = std::allocator); it is skipped when
// GUM_NO_EXTERN_TEMPLATE_CLASS is defined.
403 #ifndef GUM_NO_EXTERN_TEMPLATE_CLASS
404 extern template class gum::learning::CorrectedMutualInformation<>;
405 #endif
406 
407 
408 // always include the template implementation
409 #include <agrum/tools/stattests/correctedMutualInformation_tpl.h>
410 
411 #endif /* GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H */
void useNML()
use the kNML penalty function
void clearRanges()
reset the ranges to the one range corresponding to the whole database
KModeTypes
the description type for the complexity correction
void useCnrCache(bool on_off)
turn on/off the use of the CnrCache (the cache for the Cnr formula)
virtual void useCache(bool on_off)
turn on/off the use of all the caches
void useHCache(bool on_off)
turn on/off the use of the HCache (the cache for the entropies)
The class computing n times the corrected mutual information, as used in the 3off2 algorithm...
double score(NodeId var1, NodeId var2, const std::vector< NodeId, ALLOC< NodeId > > &conditioning_ids)
returns the 2-point mutual information corresponding to a given nodeset
virtual ~CorrectedMutualInformation()
destructor
virtual void clear()
clears all the data structures from memory
void clearICache()
clears the ICache (the mutual information cache)
CorrectedMutualInformation< ALLOC > & operator=(const CorrectedMutualInformation< ALLOC > &from)
copy operator
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
allocator_type getAllocator() const
returns the allocator used by the score
void useKCache(bool on_off)
turn on/off the use of the KCache (the cache for the penalties)
virtual CorrectedMutualInformation< ALLOC > * clone() const
virtual copy constructor
virtual void clearCache()
clears all the current caches
CorrectedMutualInformation(const CorrectedMutualInformation< ALLOC > &from)
copy constructor
virtual CorrectedMutualInformation< ALLOC > * clone(const allocator_type &alloc) const
virtual copy constructor with a given allocator
virtual std::size_t minNbRowsPerThread() const
returns the minimum of rows that each thread should process
CorrectedMutualInformation(const DBRowGeneratorParser< ALLOC > &parser, const Apriori< ALLOC > &apriori, const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > &nodeId2columns=Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), const allocator_type &alloc=allocator_type())
default constructor
const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > & ranges() const
returns the current ranges
double score(NodeId var1, NodeId var2)
returns the 2-point mutual information corresponding to a given nodeset
void setRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
sets new ranges to perform the countings used by the mutual information
CorrectedMutualInformation(CorrectedMutualInformation< ALLOC > &&from)
move constructor
double score(NodeId var1, NodeId var2, NodeId var3, const std::vector< NodeId, ALLOC< NodeId > > &conditioning_ids)
returns the 3-point mutual information corresponding to a given nodeset
CorrectedMutualInformation(const DBRowGeneratorParser< ALLOC > &parser, const Apriori< ALLOC > &apriori, const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > &ranges, const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > &nodeId2columns=Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), const allocator_type &alloc=allocator_type())
default constructor
double score(NodeId var1, NodeId var2, NodeId var3)
returns the 3-point mutual information corresponding to a given nodeset
CorrectedMutualInformation(CorrectedMutualInformation< ALLOC > &&from, const allocator_type &alloc)
move constructor with a given allocator
void clearKCache()
clears the KCache (the cache for the penalties)
CorrectedMutualInformation(const CorrectedMutualInformation< ALLOC > &from, const allocator_type &alloc)
copy constructor with a given allocator
void clearHCache()
clears the HCache (the cache for the entropies)
virtual std::size_t nbThreads() const
returns the number of threads used to parse the database
virtual void setMinNbRowsPerThread(const std::size_t nb) const
changes the minimum number of rows a thread should process in a multithreading context ...
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
void useMDL()
use the MDL penalty function
void useICache(bool on_off)
turn on/off the use of the ICache (the mutual information cache)
void clearCnrCache()
clears the CnrCache (the cache for the Cnr formula)
CorrectedMutualInformation< ALLOC > & operator=(CorrectedMutualInformation< ALLOC > &&from)
move operator
void useNoCorr()
use no correction/penalty function
virtual void setMaxNbThreads(std::size_t nb) const
changes the max number of threads used to parse the database