aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
recordCounter.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief The class that computes countings of observations from the database.
24  *
25  * This class is the one to be called by scores and independence tests to
26  * compute countings of observations from tabular databases.
27  *
28  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
29  */
30 #ifndef GUM_LEARNING_RECORD_COUNTER_H
31 #define GUM_LEARNING_RECORD_COUNTER_H
32 
33 #include <vector>
34 #include <utility>
35 #include <sstream>
36 #include <string>
37 
38 #include <agrum/agrum.h>
39 #include <agrum/tools/core/bijection.h>
40 #include <agrum/tools/core/sequence.h>
41 #include <agrum/tools/core/OMPThreads.h>
42 #include <agrum/tools/core/threadData.h>
43 #include <agrum/tools/graphs/DAG.h>
44 #include <agrum/tools/database/DBRowGeneratorParser.h>
45 #include <agrum/tools/stattests/idCondSet.h>
46 
47 
48 namespace gum {
49 
50  namespace learning {
51 
52  /** @class RecordCounter
53  * @brief The class that computes countings of observations from the database.
54  * @headerfile recordCounter.h <agrum/tools/stattests/recordCounter.h>
55  * @ingroup learning_scores
56  *
57  * This class is the one to be called by scores and independence tests to
58  * compute the countings of observations from tabular datasets they need.
59  * The countings are performed the following way:
60  * when asked for the countings over a set X = {X_1,...,X_n} of
61  * variables, the RecordCounter first checks whether it already contains
62  * some countings over a set Y of variables containing X. If this is the
63  * case, then it extracts from the countings over Y those over X (this is
64  * usually way faster than determining the countings by parsing the database).
65  * Otherwise, it determines the countings over X by parsing the database
66  * in parallel. Only the result of the last database-parsed countings
67  * is available for the subset counting determination. As an example, if
68  * we create a RecordCounter and ask it the countings over {A,B,C}, it will
69  * parse the database and provide the countings. Then, if we ask it countings
70  * over B, it will use the table over {A,B,C} to produce the countings we
71  * look for. Then, asking for countings over {A,C} will be performed the same
72  * way. Now, asking countings over {B,C,D} will require another database
73  * parsing. Finally, if we ask for countings over A, a new database parsing
74  * will be performed because only the countings over {B,C,D} are now contained
75  * in the RecordCounter.
76  *
77  * @par Here is an example of how to use the RecordCounter class:
78  * @code
79  * // here, write the code to construct your database, e.g.:
80  * gum::learning::DBInitializerFromCSV<> initializer( "file.csv" );
81  * const auto& var_names = initializer.variableNames();
82  * const std::size_t nb_vars = var_names.size();
83  * gum::learning::DBTranslatorSet<> translator_set;
84  * gum::learning::DBTranslator4LabelizedVariable<> translator;
85  * for (std::size_t i = 0; i < nb_vars; ++i) {
86  * translator_set.insertTranslator(translator, i);
87  * }
88  * gum::learning::DatabaseTable<> database(translator_set);
89  *
90  * // create the parser of the database
91  * gum::learning::DBRowGeneratorSet<> genset;
92  * gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset);
93  *
94  * // create the record counter
95  * gum::learning::RecordCounter<> counter(parser);
96  *
97  * // get the counts:
98  * gum::learning::IdCondSet<> ids ( 0, std::vector<gum::NodeId> {2,1} );
99  * const std::vector< double >& counts1 = counter.counts ( ids );
100  *
101  * // change the rows from which we compute the counts:
102  * // they should now be made on rows [500,600) U [1050,1125) U [100,150)
103  * std::vector<std::pair<std::size_t,std::size_t>> new_ranges
104  * { std::pair<std::size_t,std::size_t>(500,600),
105  * std::pair<std::size_t,std::size_t>(1050,1125),
106  * std::pair<std::size_t,std::size_t>(100,150) };
107  * counter.setRanges ( new_ranges );
108  * const std::vector< double >& counts2 = counter.counts ( ids );
109  * @endcode
110  */
111  template < template < typename > class ALLOC = std::allocator >
112  class RecordCounter {
113  public:
114  /// type for the allocators passed in arguments of methods
115  using allocator_type = ALLOC< NodeId >;
116 
117  // ##########################################################################
118  /// @name Constructors / Destructors
119  // ##########################################################################
120  /// @{
121 
122  /// constructor performing the countings only over some ranges of rows
123  /** @param parser the parser used to parse the database
124  * @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
125  * indices. The countings are then performed only on the union of the
126  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when performing
127  * cross validation tasks, in which part of the database should be ignored.
128  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
129  * the whole database.
130  * @param nodeId2columns a mapping from the ids of the nodes in the
131  * graphical model to the corresponding column in the DatabaseTable
132  * parsed by the parser. This enables estimating from a database in
133  * which variable A corresponds to the 2nd column the parameters of a BN
134  * in which variable A has a NodeId of 5. An empty nodeId2columns
135  * bijection means that the mapping is an identity, i.e., the value of a
136  * NodeId is equal to the index of the column in the DatabaseTable.
137  * @param alloc the allocator used to allocate the structures within the
138  * RecordCounter.
139  * @warning If nodeId2columns is not empty, then only the counts over the
140  * ids belonging to this bijection can be computed: applying method
141  * counts() over other ids will raise exception NotFound. */
142  RecordCounter(
143  const DBRowGeneratorParser< ALLOC >& parser,
144  const std::vector< std::pair< std::size_t, std::size_t >,
145  ALLOC< std::pair< std::size_t, std::size_t > > >&
146  ranges,
147  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
148  nodeId2columns
149  = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
150  const allocator_type& alloc = allocator_type());
151 
152  /// constructor performing the countings over all the rows of the database
153  /** @param parser the parser used to parse the database
154  * @param nodeId2columns a mapping from the ids of the nodes in the
155  * graphical model to the corresponding column in the DatabaseTable
156  * parsed by the parser. This enables estimating from a database in
157  * which variable A corresponds to the 2nd column the parameters of a BN
158  * in which variable A has a NodeId of 5. An empty nodeId2columns
159  * bijection means that the mapping is an identity, i.e., the value of a
160  * NodeId is equal to the index of the column in the DatabaseTable.
161  * @param alloc the allocator used to allocate the structures within the
162  * RecordCounter.
163  * @warning If nodeId2columns is not empty, then only the counts over the
164  * ids belonging to this bijection can be computed: applying method
165  * counts() over other ids will raise exception NotFound. */
166  RecordCounter(const DBRowGeneratorParser< ALLOC >& parser,
167  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
168  nodeId2columns
169  = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
170  const allocator_type& alloc = allocator_type());
171 
172  /// copy constructor
173  RecordCounter(const RecordCounter< ALLOC >& from);
174 
175  /// copy constructor with a given allocator
176  RecordCounter(const RecordCounter< ALLOC >& from,
177  const allocator_type& alloc);
178 
179  /// move constructor
180  RecordCounter(RecordCounter< ALLOC >&& from);
181 
182  /// move constructor with a given allocator
183  RecordCounter(RecordCounter< ALLOC >&& from, const allocator_type& alloc);
184 
185  /// virtual copy constructor
186  virtual RecordCounter< ALLOC >* clone() const;
187 
188  /// virtual copy constructor with a given allocator
189  virtual RecordCounter< ALLOC >* clone(const allocator_type& alloc) const;
190 
191  /// destructor
192  virtual ~RecordCounter();
193 
194  /// @}
195 
196 
197  // ##########################################################################
198  /// @name Operators
199  // ##########################################################################
200 
201  /// @{
202 
203  /// copy operator
204  RecordCounter< ALLOC >& operator=(const RecordCounter< ALLOC >& from);
205 
206  /// move operator
207  RecordCounter< ALLOC >& operator=(RecordCounter< ALLOC >&& from);
208 
209  /// @}
210 
211 
212  // ##########################################################################
213  /// @name Accessors / Modifiers
214  // ##########################################################################
215 
216  /// @{
217 
218  /// clears all the last database-parsed countings from memory
219  void clear();
220 
221  /// changes the max number of threads used to parse the database
222  void setMaxNbThreads(const std::size_t nb) const;
223 
224  /// returns the number of threads used to parse the database
225  std::size_t nbThreads() const;
226 
227  /** @brief changes the minimum number of rows a thread should process in a
228  * multithreading context
229  *
230  * When method counts() executes several threads to perform countings on the
231  * rows of the database, the MinNbRowsPerThread indicates how many rows each
232  * thread should at least process. This is used to compute the number of
233  * threads actually run. This number is equal to the min between the max
234  * number of threads allowed and the number of records in the database
235  * divided by nb. */
236  void setMinNbRowsPerThread(const std::size_t nb) const;
237 
238  /// returns the minimum number of rows that each thread should process
239  std::size_t minNbRowsPerThread() const;
240 
241  /// returns the counts over all the variables in an IdCondSet
242  /** @param ids the idset of the variables over which we perform countings.
243  * @param check_discrete_vars The record counter can only produce correct
244  * results on sets of discrete variables. By default, the method does not
245  * check whether the variables corresponding to the IdCondSet are actually
246  * discrete. If check_discrete_vars is set to true, then this check is
247  * performed before computing the counting vector. In this case, if a
248  * variable is not discrete, a TypeError exception is raised.
249  * @return a vector containing the multidimensional contingency table
250  * over all the variables corresponding to the ids passed in argument
251  * (both at the left hand side and right hand side of the conditioning
252  * bar of the IdCondSet). The first dimension is that of the first variable
253  * in the IdCondSet, i.e., when its value increases by 1, the offset in the
254  * output vector also increases by 1. The second dimension is that of the
255  * second variable in the IdCondSet, i.e., when its value increases by 1, the
256  * offset in the output vector increases by the domain size of the first
257  * variable. For the third variable, the offset corresponds to the product
258  * of the domain sizes of the first two variables, and so on.
259  * @warning The vector returned by the function may differ from one
260  * call to another. So, care must be taken. E.g., code like:
261  * @code
262  * const std::vector< double, ALLOC<double> >&
263  * counts = counter.counts(ids);
264  * counts = counter.counts(other_ids);
265  * @endcode
266  * may be erroneous because the two calls to method counts() may
267  * return references to different vectors. The correct way of using method
268  * counts() is always to call it declaring a new reference variable:
269  * @code
270  * const std::vector< double, ALLOC<double> >& counts =
271  * counter.counts(ids);
272  * const std::vector< double, ALLOC<double> >& other_counts =
273  * counter.counts(other_ids);
274  * @endcode
275  * @throw TypeError is raised if check_discrete_vars is set to true (i.e.,
276  * we check that all variables in the IdCondSet are discrete) and if at least
277  * one variable is not of a discrete nature.
278  */
279  const std::vector< double, ALLOC< double > >&
280  counts(const IdCondSet< ALLOC >& ids,
281  const bool check_discrete_vars = false);
282 
283  /// sets new ranges to perform the countings
284  /** @param new_ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
285  * indices. The countings are then performed only on the union of the
286  * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when performing
287  * cross validation tasks, in which part of the database should be ignored.
288  * An empty set of ranges is equivalent to an interval [X,Y) ranging over
289  * the whole database. */
290  template < template < typename > class XALLOC >
291  void setRanges(
292  const std::vector< std::pair< std::size_t, std::size_t >,
293  XALLOC< std::pair< std::size_t, std::size_t > > >&
294  new_ranges);
295 
296  /// resets the ranges to the one range corresponding to the whole database
297  void clearRanges();
298 
299  /// returns the current ranges
300  const std::vector< std::pair< std::size_t, std::size_t >,
301  ALLOC< std::pair< std::size_t, std::size_t > > >&
302  ranges() const;
303 
304  /// assigns a new Bayes net to all the counter's generators depending on a BN
305  /** Typically, generators based on EM or K-means depend on a model to
306  * compute their outputs correctly. Method setBayesNet enables
307  * updating their BN model. */
308  template < typename GUM_SCALAR >
309  void setBayesNet(const BayesNet< GUM_SCALAR >& new_bn);
310 
311  /// returns the allocator used
312  allocator_type getAllocator() const;
313 
314  /// returns the mapping from ids to column positions in the database
315  /** @warning An empty nodeId2Columns bijection means that the mapping is
316  * an identity, i.e., the value of a NodeId is equal to the index of the
317  * column in the DatabaseTable. */
318  const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
319  nodeId2Columns() const;
320 
321  /// returns the database on which we perform the counts
322  const DatabaseTable< ALLOC >& database() const;
323 
324  /// @}
325 
326 
327 #ifndef DOXYGEN_SHOULD_SKIP_THIS
328 
329  private:
330  // the parsers used by the threads
331  std::vector< ThreadData< DBRowGeneratorParser< ALLOC > >,
332  ALLOC< ThreadData< DBRowGeneratorParser< ALLOC > > > >
333  parsers__;
334 
335  // the set of ranges of the database's rows indices over which the user
336  // wishes to perform the countings
337  std::vector< std::pair< std::size_t, std::size_t >,
338  ALLOC< std::pair< std::size_t, std::size_t > > >
339  ranges__;
340 
341  // the ranges actually used by the threads: there is a hopefully clever
342  // algorithm that splits the rows ranges into another set of ranges that
343  // are assigned to the threads. For instance, if the database has 1000
344  // rows and there are 10 threads, each one will be assigned a set of 100
345  // rows. These sets are precisely what is stored in the field below
346  mutable std::vector< std::pair< std::size_t, std::size_t >,
347  ALLOC< std::pair< std::size_t, std::size_t > > >
348  thread_ranges__;
349 
350  // the mapping from the NodeIds of the variables to the indices of the
351  // columns in the database
352  Bijection< NodeId, std::size_t, ALLOC< std::size_t > > nodeId2columns__;
353 
354  // the last database-parsed countings
355  std::vector< double, ALLOC< double > > last_DB_countings__;
356 
357  // the ids of the nodes for the last database-parsed countings
358  IdCondSet< ALLOC > last_DB_ids__;
359 
360  // the last countings deduced from last_DB_countings__
361  std::vector< double, ALLOC< double > > last_nonDB_countings__;
362 
363  // the ids of the nodes of last countings deduced from last_DB_countings__
364  IdCondSet< ALLOC > last_nonDB_ids__;
365 
366  // the maximal number of threads that the record counter can use
367  mutable std::size_t max_nb_threads__{
368  std::size_t(gum::getMaxNumberOfThreads())};
369 
370  // the min number of rows that a thread should process in a
371  // multithreading context
372  mutable std::size_t min_nb_rows_per_thread__{100};
373 
374  // returns a mapping from the nodes ids to the columns of the database
375  // for a given sequence of ids. This is especially convenient when
376  // nodeId2columns__ is empty (which means that there is an identity mapping)
377  HashTable< NodeId, std::size_t >
378  getNodeIds2Columns__(const IdCondSet< ALLOC >& ids) const;
379 
380  /// extracts some new countings from previously computed ones
381  std::vector< double, ALLOC< double > >& extractFromCountings__(
382  const IdCondSet< ALLOC >& subset_ids,
383  const IdCondSet< ALLOC >& superset_ids,
384  const std::vector< double, ALLOC< double > >& superset_vect);
385 
386  /// parses the database to produce new countings
387  std::vector< double, ALLOC< double > >&
388  countFromDatabase__(const IdCondSet< ALLOC >& ids);
389 
390  /// the method used by threads to produce countings by parsing the database
391  void threadedCount__(
392  const std::size_t range_begin,
393  const std::size_t range_end,
394  DBRowGeneratorParser< ALLOC >& parser,
395  const std::vector< std::pair< std::size_t, std::size_t >,
396  ALLOC< std::pair< std::size_t, std::size_t > > >&
397  cols_and_offsets,
398  std::vector< double, ALLOC< double > >& countings);
399 
400  /// checks that the ranges passed in argument are ok or raises an exception
401  /** A range is ok if its upper bound is strictly higher than its lower
402  * bound and the latter is also lower than or equal to the number of rows
403  * in the database. */
404  template < template < typename > class XALLOC >
405  void checkRanges__(
406  const std::vector< std::pair< std::size_t, std::size_t >,
407  XALLOC< std::pair< std::size_t, std::size_t > > >&
408  new_ranges) const;
409 
410  /// checks that all the variables of an idset are discrete
411  /** @throw TypeError is raised if at least one variable in ids is
412  * of a continuous nature. */
413  void checkDiscreteVariables__(const IdCondSet< ALLOC >& ids) const;
414 
415  /// computes and raises the exception when some variables are continuous
416  /** This method is used by checkDiscreteVariables__ to determine the
417  * appropriate message to include in the TypeError exception raised when
418  * some variables over which we should perform countings are continuous. */
419  void raiseCheckException__(
420  const std::vector< std::string, ALLOC< std::string > >& bad_vars) const;
421 
422  /// sets the ranges within which each thread will perform its computations
423  void dispatchRangesToThreads__();
424 
425 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
426  };
427 
428  } /* namespace learning */
429 
430 } /* namespace gum */
431 
432 /// always include the templated implementations
433 #include <agrum/tools/stattests/recordCounter_tpl.h>
434 
435 #endif /* GUM_LEARNING_RECORD_COUNTER_H */