aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
recordCounter.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief The class that computes countings of observations from the database.
24  *
25  * This class is the one to be called by scores and independence tests to
26  * compute countings of observations from tabular databases.
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 #ifndef GUM_LEARNING_RECORD_COUNTER_H
31 #define GUM_LEARNING_RECORD_COUNTER_H
32 
33 #include <vector>
34 #include <utility>
35 #include <sstream>
36 #include <string>
37 
38 #include <agrum/agrum.h>
39 #include <agrum/tools/core/bijection.h>
40 #include <agrum/tools/core/sequence.h>
41 #include <agrum/tools/core/OMPThreads.h>
42 #include <agrum/tools/core/threadData.h>
43 #include <agrum/tools/graphs/DAG.h>
44 #include <agrum/tools/database/DBRowGeneratorParser.h>
45 #include <agrum/tools/stattests/idCondSet.h>
46 
47 
48 namespace gum {
49 
50  namespace learning {
51 
52  /** @class RecordCounter
53  * @brief The class that computes countings of observations from the database.
54  * @headerfile recordCounter.h <agrum/BN/learning/scores_and_tests/recordCounter.h>
55  * @ingroup learning_scores
56  *
57  * This class is the one to be called by scores and independence tests to
58  * compute the countings of observations from tabular datasets they need.
59  * The countings are performed the following way:
60  * when asked for the countings over a set X = {X_1,...,X_n} of
61  * variables, the RecordCounter first checks whether it already contains
62  * some countings over a set Y of variables containing X. If this is the
63  * case, then it extracts from the countings over Y those over X (this is
64  * usually way faster than determining the countings by parsing the database).
65  * Otherwise, it determines the countings over X by parsing in a parallel
66  * way the database. Only the result of the last database-parsed countings
67  * is available for the subset counting determination. As an example, if
68  * we create a RecordCounter and ask it the countings over {A,B,C}, it will
69  * parse the database and provide the countings. Then, if we ask it countings
70  * over B, it will use the table over {A,B,C} to produce the countings we
71  * look for. Then, asking for countings over {A,C} will be performed the same
72  * way. Now, asking countings over {B,C,D} will require another database
73  * parsing. Finally, if we ask for countings over A, a new database parsing
74  * will be performed because only the countings over {B,C,D} are now contained
75  * in the RecordCounter.
76  *
77  * @par Here is an example of how to use the RecordCounter class:
78  * @code
79  * // here, write the code to construct your database, e.g.:
80  * gum::learning::DBInitializerFromCSV<> initializer( "file.csv" );
81  * const auto& var_names = initializer.variableNames();
82  * const std::size_t nb_vars = var_names.size();
83  * gum::learning::DBTranslatorSet<> translator_set;
84  * gum::learning::DBTranslator4ContinuousVariable<> translator;
85  * for (std::size_t i = 0; i < nb_vars; ++i) {
86  * translator_set.insertTranslator(translator, i);
87  * }
88  * gum::learning::DatabaseTable<> database(translator_set);
89  *
90  * // create the parser of the database
91  * gum::learning::DBRowGeneratorSet<> genset;
92  * gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset);
93  *
94  * // create the record counter
95  * gum::learning::RecordCounter<> counter(parser);
96  *
97  * // get the counts:
98  * gum::learning::IdCondSet<> ids ( 0, gum::vector<gum::NodeId> {2,1} );
99  * const std::vector< double >& counts1 = counter.counts ( ids );
100  *
101  * // change the rows from which we compute the counts:
102  * // they should now be made on rows [500,600) U [1050,1125) U [100,150)
103  * std::vector<std::pair<std::size_t,std::size_t>> new_ranges
104  * { std::pair<std::size_t,std::size_t>(500,600),
105  * std::pair<std::size_t,std::size_t>(1050,1125),
106  * std::pair<std::size_t,std::size_t>(100,150) };
107  * counter.setRanges ( new_ranges );
108  * const std::vector< double >& counts2 = counter.counts ( ids );
109  * @endcode
110  */
111  template < template < typename > class ALLOC = std::allocator >
112  class RecordCounter {
113  public:
114  /// type for the allocators passed in arguments of methods
115  using allocator_type = ALLOC< NodeId >;
116 
117  // ##########################################################################
118  /// @name Constructors / Destructors
119  // ##########################################################################
120  /// @{
121 
122  /// default constructor
123  /** @param parser the parser used to parse the database
124  * @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
125  * indices. The countings are then performed only on the union of the
126  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
127  * cross validation tasks, in which part of the database should be ignored.
128  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
129  * the whole database.
130  * @param nodeId2Columns a mapping from the ids of the nodes in the
131  * graphical model to the corresponding column in the DatabaseTable
132  * parsed by the parser. This enables estimating from a database in
133  * which variable A corresponds to the 2nd column the parameters of a BN
134  * in which variable A has a NodeId of 5. An empty nodeId2Columns
135  * bijection means that the mapping is an identity, i.e., the value of a
136  * NodeId is equal to the index of the column in the DatabaseTable.
137  * @param alloc the allocator used to allocate the structures within the
138  * RecordCounter.
139  * @warning If nodeId2columns is not empty, then only the counts over the
140  * ids belonging to this bijection can be computed: applying method
141  * counts() over other ids will raise exception NotFound. */
142  RecordCounter(const DBRowGeneratorParser< ALLOC >& parser,
143  const std::vector< std::pair< std::size_t, std::size_t >,
144  ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
145  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns
146  = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
147  const allocator_type& alloc = allocator_type());
148 
149  /// default constructor
150  /** @param parser the parser used to parse the database
151  * @param nodeId2Columns a mapping from the ids of the nodes in the
152  * graphical model to the corresponding column in the DatabaseTable
153  * parsed by the parser. This enables estimating from a database in
154  * which variable A corresponds to the 2nd column the parameters of a BN
155  * in which variable A has a NodeId of 5. An empty nodeId2Columns
156  * bijection means that the mapping is an identity, i.e., the value of a
157  * NodeId is equal to the index of the column in the DatabaseTable.
158  * @param alloc the allocator used to allocate the structures within the
159  * RecordCounter.
160  * @warning If nodeId2columns is not empty, then only the counts over the
161  * ids belonging to this bijection can be computed: applying method
162  * counts() over other ids will raise exception NotFound. */
163  RecordCounter(const DBRowGeneratorParser< ALLOC >& parser,
164  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns
165  = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
166  const allocator_type& alloc = allocator_type());
167 
168  /// copy constructor
169  RecordCounter(const RecordCounter< ALLOC >& from);
170 
171  /// copy constructor with a given allocator
172  RecordCounter(const RecordCounter< ALLOC >& from, const allocator_type& alloc);
173 
174  /// move constructor
175  RecordCounter(RecordCounter< ALLOC >&& from);
176 
177  /// move constructor with a given allocator
178  RecordCounter(RecordCounter< ALLOC >&& from, const allocator_type& alloc);
179 
180  /// virtual copy constructor
181  virtual RecordCounter< ALLOC >* clone() const;
182 
183  /// virtual copy constructor with a given allocator
184  virtual RecordCounter< ALLOC >* clone(const allocator_type& alloc) const;
185 
186  /// destructor
187  virtual ~RecordCounter();
188 
189  /// @}
190 
191 
192  // ##########################################################################
193  /// @name Operators
194  // ##########################################################################
195 
196  /// @{
197 
198  /// copy operator
199  RecordCounter< ALLOC >& operator=(const RecordCounter< ALLOC >& from);
200 
201  /// move operator
202  RecordCounter< ALLOC >& operator=(RecordCounter< ALLOC >&& from);
203 
204  /// @}
205 
206 
207  // ##########################################################################
208  /// @name Accessors / Modifiers
209  // ##########################################################################
210 
211  /// @{
212 
213  /// clears all the last database-parsed countings from memory
214  void clear();
215 
216  /// changes the max number of threads used to parse the database
217  void setMaxNbThreads(const std::size_t nb) const;
218 
219  /// returns the number of threads used to parse the database
220  std::size_t nbThreads() const;
221 
222  /** @brief changes the number min of rows a thread should process in a
223  * multithreading context
224  *
225  * When Method counts executes several threads to perform countings on the
226  * rows of the database, the MinNbRowsPerThread indicates how many rows each
227  * thread should at least process. This is used to compute the number of
228  * threads actually run. This number is equal to the min between the max
229  * number of threads allowed and the number of records in the database
230  * divided by nb. */
231  void setMinNbRowsPerThread(const std::size_t nb) const;
232 
233  /// returns the minimum of rows that each thread should process
234  std::size_t minNbRowsPerThread() const;
235 
236  /// returns the counts over all the variables in an IdCondSet
237  /** @param ids the idset of the variables over which we perform countings.
238  * @param check_discrete_vars The record counter can only produce correct
239  * results on sets of discrete variables. By default, the method does not
240  * check whether the variables corresponding to the IdCondSet are actually
241  * discrete. If check_discrete_vars is set to true, then this check is
242  * performed before computing the counting vector. In this case, if a
243  * variable is not discrete, a TypeError exception is raised.
244  * @return a vector containing the multidimensional contingency table
245  * over all the variables corresponding to the ids passed in argument
246  * (both at the left hand side and right hand side of the conditioning
247  * bar of the IdCondSet). The first dimension is that of the first variable
248  * in the IdCondSet, i.e., when its value increases by 1, the offset in the
249  * output vector also increases by 1. The second dimension is that of the
250  * second variable in the IdCondSet, i.e., when its value increases by 1, the
251  * offset in the ouput vector increases by the domain size of the first
252  * variable. For the third variable, the offset corresponds to the product
253  * of the domain sizes of the first two variables, and so on.
254  * @warning The vector returned by the function may differ from one
255  * call to another. So, care must be taken. E,g. a code like:
256  * @code
257  * const std::vector< double, ALLOC<double> >&
258  * counts = counter.counts(ids);
259  * counts = counter.counts(other_ids);
260  * @endcode
261  * may be erroneous because the two calls to method counts() may
262  * return references to different vectors. The correct way of using method
263  * counts() is always to call it declaring a new reference variable:
264  * @code
265  * const std::vector< double, ALLOC<double> >& counts =
266  * counter.counts(ids);
267  * const std::vector< double, ALLOC<double> >& other_counts =
268  * counter.counts(other_ids);
269  * @endcode
270  * @throw TypeError is raised if check_discrete_vars is set to true (i.e.,
271  * we check that all variables in the IdCondSet are discrete) and if at least
272  * one variable is not of a discrete nature.
273  */
274  const std::vector< double, ALLOC< double > >& counts(const IdCondSet< ALLOC >& ids,
275  const bool check_discrete_vars = false);
276 
277  /// sets new ranges to perform the countings
278  /** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
279  * indices. The countings are then performed only on the union of the
280  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
281  * cross validation tasks, in which part of the database should be ignored.
282  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
283  * the whole database. */
284  template < template < typename > class XALLOC >
285  void setRanges(
286  const std::vector< std::pair< std::size_t, std::size_t >,
287  XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);
288 
289  /// reset the ranges to the one range corresponding to the whole database
290  void clearRanges();
291 
292  /// returns the current ranges
293  const std::vector< std::pair< std::size_t, std::size_t >,
294  ALLOC< std::pair< std::size_t, std::size_t > > >&
295  ranges() const;
296 
297  /// assign a new Bayes net to all the counter's generators depending on a BN
298  /** Typically, generators based on EM or K-means depend on a model to
299  * compute correctly their outputs. Method setBayesNet enables to
300  * update their BN model. */
301  template < typename GUM_SCALAR >
302  void setBayesNet(const BayesNet< GUM_SCALAR >& new_bn);
303 
304  /// returns the allocator used
305  allocator_type getAllocator() const;
306 
307  /// returns the mapping from ids to column positions in the database
308  /** @warning An empty nodeId2Columns bijection means that the mapping is
309  * an identity, i.e., the value of a NodeId is equal to the index of the
310  * column in the DatabaseTable. */
311  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2Columns() const;
312 
313  /// returns the database on which we perform the counts
314  const DatabaseTable< ALLOC >& database() const;
315 
316  /// @}
317 
318 
319 #ifndef DOXYGEN_SHOULD_SKIP_THIS
320 
321  private:
322  // the parsers used by the threads
323  std::vector< ThreadData< DBRowGeneratorParser< ALLOC > >,
324  ALLOC< ThreadData< DBRowGeneratorParser< ALLOC > > > >
325  _parsers_;
326 
327  // the set of ranges of the database's rows indices over which the user
328  // wishes to perform the countings
329  std::vector< std::pair< std::size_t, std::size_t >,
330  ALLOC< std::pair< std::size_t, std::size_t > > >
331  _ranges_;
332 
333  // the ranges actually used by the threads: there is a hopefully clever
334  // algorithm that split the rows ranges into another set of ranges that
335  // are assigned to the threads. For instance, if the database has 1000
336  // rows and there are 10 threads, each one will be assed a set of 100
337  // rows. These sets are precisely what are stored in the field below
338  mutable std::vector< std::pair< std::size_t, std::size_t >,
339  ALLOC< std::pair< std::size_t, std::size_t > > >
340  _thread_ranges_;
341 
342  // the mapping from the NodeIds of the variables to the indices of the
343  // columns in the database
344  Bijection< NodeId, std::size_t, ALLOC< std::size_t > > _nodeId2columns_;
345 
346  // the last database-parsed countings
347  std::vector< double, ALLOC< double > > _last_DB_countings_;
348 
349  // the ids of the nodes for the last database-parsed countings
350  IdCondSet< ALLOC > _last_DB_ids_;
351 
352  // the last countings deduced from _last_DB_countings_
353  std::vector< double, ALLOC< double > > _last_nonDB_countings_;
354 
355  // the ids of the nodes of last countings deduced from _last_DB_countings_
356  IdCondSet< ALLOC > _last_nonDB_ids_;
357 
358  // the maximal number of threads that the record counter can use
359  mutable std::size_t _max_nb_threads_{std::size_t(gum::getMaxNumberOfThreads())};
360 
361  // the min number of rows that a thread should process in a
362  // multithreading context
363  mutable std::size_t _min_nb_rows_per_thread_{100};
364 
365  // returns a mapping from the nodes ids to the columns of the database
366  // for a given sequence of ids. This is especially convenient when
367  // _nodeId2columns_ is empty (which means that there is an identity mapping)
368  HashTable< NodeId, std::size_t > _getNodeIds2Columns_(const IdCondSet< ALLOC >& ids) const;
369 
370  /// extracts some new countings from previously computed ones
371  std::vector< double, ALLOC< double > >&
372  _extractFromCountings_(const IdCondSet< ALLOC >& subset_ids,
373  const IdCondSet< ALLOC >& superset_ids,
374  const std::vector< double, ALLOC< double > >& superset_vect);
375 
376  /// parse the database to produce new countings
377  std::vector< double, ALLOC< double > >& _countFromDatabase_(const IdCondSet< ALLOC >& ids);
378 
379  /// the method used by threads to produce countings by parsing the database
380  void _threadedCount_(
381  const std::size_t range_begin,
382  const std::size_t range_end,
383  DBRowGeneratorParser< ALLOC >& parser,
384  const std::vector< std::pair< std::size_t, std::size_t >,
385  ALLOC< std::pair< std::size_t, std::size_t > > >& cols_and_offsets,
386  std::vector< double, ALLOC< double > >& countings);
387 
388  /// checks that the ranges passed in argument are ok or raise an exception
389  /** A range is ok if its upper bound is strictly higher than its lower
390  * bound and the latter is also lower than or equal to the number of rows
391  * in the database. */
392  template < template < typename > class XALLOC >
393  void _checkRanges_(
394  const std::vector< std::pair< std::size_t, std::size_t >,
395  XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges) const;
396 
397  /// check that the variables at indices [beg,end) of an idset are discrete
398  /** @throw TypeError is raised if at least one variable in ids is
399  * of a continuous nature. */
400  void _checkDiscreteVariables_(const IdCondSet< ALLOC >& ids) const;
401 
402  /// compute and raise the exception when some variables are continuous
403  /** This method is used by _checkDiscreteVariables_ to determine the
404  * appropriate message to include in the TypeError exception raised when
405  * some variables over which we should perform countings are continuous. */
406  void _raiseCheckException_(
407  const std::vector< std::string, ALLOC< std::string > >& bad_vars) const;
408 
409  /// sets the ranges within which each thread will perform its computations
410  void _dispatchRangesToThreads_();
411 
412 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
413  };
414 
415  } /* namespace learning */
416 
417 } /* namespace gum */
418 
419 /// always include the templated implementations
420 #include <agrum/tools/stattests/recordCounter_tpl.h>
421 
422 #endif /* GUM_LEARNING_RECORD_COUNTER_H */