aGrUM  0.14.2
correctedMutualInformation.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
27 #ifndef GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H
28 #define GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H
29 
30 #include <agrum/config.h>
31 #include <agrum/core/math/math.h>
32 #include <vector>
33 
37 
38 namespace gum {
39 
40  namespace learning {
41 
52  template < template < typename > class ALLOC = std::allocator >
54  public:
56  using allocator_type = ALLOC< NodeId >;  // type for the allocators passed in arguments of methods
57 
58  // ##########################################################################
60  // ##########################################################################
62 
64 
87  const DBRowGeneratorParser< ALLOC >& parser,
88  const Apriori< ALLOC >& apriori,
89  const std::vector< std::pair< std::size_t, std::size_t >,
90  ALLOC< std::pair< std::size_t, std::size_t > > >&
91  ranges,
92  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
93  nodeId2columns =
94  Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
95  const allocator_type& alloc = allocator_type());
96 
98 
115  const DBRowGeneratorParser< ALLOC >& parser,
116  const Apriori< ALLOC >& apriori,
117  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
118  nodeId2columns =
119  Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
120  const allocator_type& alloc = allocator_type());
121 
124 
127  const allocator_type& alloc);
128 
131 
134  const allocator_type& alloc);
135 
138 
141  clone(const allocator_type& alloc) const;
142 
144  virtual ~CorrectedMutualInformation();  // destructor (virtual)
145 
147 
148 
149  // ##########################################################################
151  // ##########################################################################
152 
154 
158 
162 
164 
165 
166  // ##########################################################################
168  // ##########################################################################
170 
172  virtual void clear();  // clears all the data structures from memory
173 
175 
183  virtual void clearCache();  // clears all the current caches
184 
186 
194  virtual void useCache(bool on_off);  // turn on/off the use of all the caches
195 
197  void useICache(bool on_off);  // turn on/off the use of the ICache (the mutual information cache)
198 
200  void clearICache();  // clears the ICache (the mutual information cache)
201 
203  void useHCache(bool on_off);  // turn on/off the use of the HCache (the cache for the entropies)
204 
206  void clearHCache();  // clears the HCache (the cache for the entropies)
207 
209  void useKCache(bool on_off);  // turn on/off the use of the KCache (the cache for the penalties)
210 
212  void clearKCache();  // clears the KCache (the cache for the penalties)
213 
215  void useCnrCache(bool on_off);  // turn on/off the use of the CnrCache (the cache for the Cnr formula)
216 
218  void clearCnrCache();  // clears the CnrCache (the cache for the Cnr formula)
219 
221 
222 
223  // ##########################################################################
225  // ##########################################################################
227 
229  double score(NodeId var1, NodeId var2);  // returns the 2-point mutual information corresponding to a given nodeset
230 
232  double score(NodeId var1,
233  NodeId var2,
234  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids);  // NOTE(review): presumably the 2-point mutual information of var1 and var2 given conditioning_ids -- confirm
235 
237  double score(NodeId var1, NodeId var2, NodeId var3);  // NOTE(review): three-variable overload, presumably a 3-point mutual information -- confirm
238 
240  double score(NodeId var1,
241  NodeId var2,
242  NodeId var3,
243  const std::vector< NodeId, ALLOC< NodeId > >& conditioning_ids);  // NOTE(review): three-variable overload with a conditioning set -- confirm exact semantics
244 
246 
247 
248  // ##########################################################################
250  // ##########################################################################
252 
254  void useMDL();  // use the MDL penalty function
255 
257  void useNML();  // use the kNML penalty function
258 
260  void useNoCorr();  // use no correction/penalty function
261 
263  virtual void setMaxNbThreads(std::size_t nb) const;  // changes the max number of threads used to parse the database (note: declared const although it is a setter)
264 
266  virtual std::size_t nbThreads() const;  // returns the number of threads used to parse the database
267 
277  virtual void setMinNbRowsPerThread(const std::size_t nb) const;  // changes the minimum number of rows a thread should process in a multithreading context (declared const)
278 
280  virtual std::size_t minNbRowsPerThread() const;  // returns the minimum number of rows that each thread should process
281 
283 
289  template < template < typename > class XALLOC >
290  void setRanges(
291  const std::vector< std::pair< std::size_t, std::size_t >,
292  XALLOC< std::pair< std::size_t, std::size_t > > >&
293  new_ranges);  // sets new ranges to perform the countings used by the mutual information; the vector may use a different allocator template (XALLOC)
294 
296  void clearRanges();  // reset the ranges to the one range corresponding to the whole database
297 
299  const std::vector< std::pair< std::size_t, std::size_t >,
300  ALLOC< std::pair< std::size_t, std::size_t > > >&
301  ranges() const;  // returns the current ranges
302 
303 
306 
308 
309 
311  enum class KModeTypes { MDL, NML, NoCorr };  // the description type for the complexity correction
312 
313 
314 #ifndef DOXYGEN_SHOULD_SKIP_THIS
315 
316  private:
318  /* Note that the log2-likelihood is equal to N times the entropy H */
320 
322  KNML< ALLOC > __k_NML;  // KNML: the class for computing the NML penalty used by 3off2
323 
326  ScoreMDL< ALLOC > __score_MDL;  // ScoreMDL: the class for computing MDL scores; NOTE(review): presumably backs the MDL penalty -- confirm
327 
329  KModeTypes __kmode{KModeTypes::MDL};  // selected complexity-correction mode; initialized to MDL
330 
331 
333 
335  bool __use_ICache{true};  // flag for the ICache (the mutual information cache); initialized to true
336 
338 
341  bool __use_HCache{true};  // flag for the HCache (the cache for the entropies); initialized to true
342 
344 
347  bool __use_KCache{true};  // flag for the KCache (the cache for the penalties); initialized to true
348 
350 
354  bool __use_CnrCache{true};  // flag for the CnrCache (the cache for the Cnr formula); initialized to true
355 
356 
358  ScoringCache< ALLOC > __ICache;  // the ICache (the mutual information cache)
359 
361  ScoringCache< ALLOC > __KCache;  // the KCache (the cache for the penalties)
362 
363 
365  const std::vector< NodeId, ALLOC< NodeId > > __empty_conditioning_set;  // constant vector of NodeId; NOTE(review): presumably stays empty and is passed when there is no conditioning set -- confirm
366 
368  const double __threshold{1e-10};  // constant 1e-10; NOTE(review): purpose not visible in this header, presumably a numeric tolerance -- confirm
369 
370 
372  double __NI_score(NodeId var_x,
373  NodeId var_y,
374  const std::vector< NodeId, ALLOC< NodeId > >& vars_z);  // NOTE(review): presumably the uncorrected N*I term for (var_x, var_y) given vars_z -- confirm in the implementation
375 
377  double __NI_score(NodeId var_x,
378  NodeId var_y,
379  NodeId var_z,
380  const std::vector< NodeId, ALLOC< NodeId > >& vars_ui);  // NOTE(review): three-variable counterpart of __NI_score, conditioned on vars_ui -- confirm
381 
383  double __K_score(NodeId var_x,
384  NodeId var_y,
385  const std::vector< NodeId, ALLOC< NodeId > >& vars_z);  // NOTE(review): presumably the penalty (K) term for (var_x, var_y) given vars_z -- confirm
386 
388  double __K_score(NodeId var_x,
389  NodeId var_y,
390  NodeId var_z,
391  const std::vector< NodeId, ALLOC< NodeId > >& vars_ui);  // NOTE(review): three-variable counterpart of __K_score, conditioned on vars_ui -- confirm
392 
393 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
394  };
395 
396  } /* namespace learning */
397 
398 } /* namespace gum */
399 
400 
401 #ifndef GUM_NO_EXTERN_TEMPLATE_CLASS
402 extern template class gum::learning::CorrectedMutualInformation<>;  // explicit instantiation declaration for the default ALLOC (std::allocator), skipped when GUM_NO_EXTERN_TEMPLATE_CLASS is defined
403 #endif
404 
405 
406 // always include the template implementation
408 
409 #endif /* GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H */
void useNML()
use the kNML penalty function
Useful macros for maths.
void clearRanges()
reset the ranges to the one range corresponding to the whole database
KModeTypes
the description type for the complexity correction
void useCnrCache(bool on_off)
turn on/off the use of the CnrCache (the cache for the Cnr formula)
virtual void useCache(bool on_off)
turn on/off the use of all the caches
void useHCache(bool on_off)
turn on/off the use of the HCache (the cache for the entropies)
The class computing n times the corrected mutual information, as used in the 3off2 algorithm...
virtual ~CorrectedMutualInformation()
destructor
virtual void clear()
clears all the data structures from memory
the class for computing Log2-likelihood scores
void clearICache()
clears the ICache (the mutual information cache)
CorrectedMutualInformation< ALLOC > & operator=(const CorrectedMutualInformation< ALLOC > &from)
copy operator
ALLOC< NodeId > allocator_type
type for the allocators passed in arguments of methods
allocator_type getAllocator() const
returns the allocator used by the score
void useKCache(bool on_off)
turn on/off the use of the KCache (the cache for the penalties)
the base class for all a priori
Definition: apriori.h:47
virtual CorrectedMutualInformation< ALLOC > * clone() const
virtual copy constructor
virtual void clearCache()
clears all the current caches
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
virtual std::size_t minNbRowsPerThread() const
returns the minimum number of rows that each thread should process
the class for computing Log2-likelihood scores
const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > & ranges() const
returns the current ranges
double score(NodeId var1, NodeId var2)
returns the 2-point mutual information corresponding to a given nodeset
void setRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
sets new ranges to perform the countings used by the mutual information
The class computing n times the corrected mutual information, as used in the 3off2 algorithm...
the class for computing BIC scores
Definition: scoreBIC.h:49
Set of pairs of elements with fast search for both elements.
Definition: bijection.h:1803
CorrectedMutualInformation(const DBRowGeneratorParser< ALLOC > &parser, const Apriori< ALLOC > &apriori, const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > &ranges, const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > &nodeId2columns=Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), const allocator_type &alloc=allocator_type())
default constructor
void clearKCache()
clears the KCache (the cache for the penalties)
the class for computing MDL scores
a cache for caching scores and independence test results. Caching previously computed scores or the re...
Definition: scoringCache.h:57
The class for the NML penalty used in 3off2.
void clearHCache()
clears the HCache (the cache for the entropies)
virtual std::size_t nbThreads() const
returns the number of threads used to parse the database
the class for computing the NML penalty used by 3off2
Definition: kNML.h:47
virtual void setMinNbRowsPerThread(const std::size_t nb) const
changes the minimum number of rows a thread should process in a multithreading context ...
void useMDL()
use the MDL penalty function
void useICache(bool on_off)
turn on/off the use of the ICache (the mutual information cache)
void clearCnrCache()
clears the CnrCache (the cache for the Cnr formula)
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
Size NodeId
Type for node ids.
Definition: graphElements.h:97
void useNoCorr()
use no correction/penalty function
virtual void setMaxNbThreads(std::size_t nb) const
changes the max number of threads used to parse the database