30 #ifndef GUM_LEARNING_RECORD_COUNTER_H 31 #define GUM_LEARNING_RECORD_COUNTER_H 38 #include <agrum/agrum.h> 39 #include <agrum/tools/core/bijection.h> 40 #include <agrum/tools/core/sequence.h> 41 #include <agrum/tools/core/OMPThreads.h> 42 #include <agrum/tools/core/threadData.h> 43 #include <agrum/tools/graphs/DAG.h> 44 #include <agrum/tools/database/DBRowGeneratorParser.h> 45 #include <agrum/tools/stattests/idCondSet.h> 111 template <
template <
typename >
class ALLOC = std::allocator >
/// @class RecordCounter
/// @brief Class template, parameterized by the allocator template @c ALLOC,
/// whose public counts() method returns a std::vector<double> of counts for
/// a set of node ids (IdCondSet). Its constructors take a
/// DBRowGeneratorParser, and it can report the DatabaseTable it works on.
///
/// NOTE(review): this text is a lossy extraction of the original header, not
/// compilable source.
///  - Every line starts with its original line number.
///  - Many original lines are absent, including access specifiers, several
///    declaration names, parameter names and data-member names (each gap is
///    flagged below), and the class's closing "};".
/// Restore it from the upstream aGrUM sources instead of hand-patching.
112 class RecordCounter {
/// Allocator type used by the class: @c ALLOC instantiated on @c NodeId.
115 using allocator_type = ALLOC< NodeId >;
/// Constructor taking:
///  - a parser;
///  - a vector of (std::size_t, std::size_t) pairs;
///  - an optional NodeId -> std::size_t bijection (defaults to an empty
///    Bijection);
///  - an allocator.
/// NOTE(review): presumably the pairs delimit row ranges of the database to
/// count over, and the bijection maps nodes to database columns (see
/// nodeId2Columns()). Confirm in recordCounter_tpl.h.
/// NOTE(review): the numbering skips 116-142, 146 and 148. The
/// "RecordCounter(" token and the names of the 2nd and 3rd parameters are
/// therefore missing from this listing.
143 const DBRowGeneratorParser< ALLOC >& parser,
144 const std::vector< std::pair< std::size_t, std::size_t >,
145 ALLOC< std::pair< std::size_t, std::size_t > > >&
147 const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
149 = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
150 const allocator_type& alloc = allocator_type());
/// Constructor taking:
///  - a parser;
///  - an optional NodeId -> std::size_t bijection (defaults to an empty
///    Bijection);
///  - an allocator.
/// Unlike the previous constructor, it takes no vector of pairs.
/// NOTE(review): the bijection's parameter name (original line 168) is
/// missing.
166 RecordCounter(
const DBRowGeneratorParser< ALLOC >& parser,
167 const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
169 = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
170 const allocator_type& alloc = allocator_type());
/// Copy constructor.
173 RecordCounter(
const RecordCounter< ALLOC >& from);
/// Copy constructor with an explicitly supplied allocator.
176 RecordCounter(
const RecordCounter< ALLOC >& from,
177 const allocator_type& alloc);
/// Move constructor.
180 RecordCounter(RecordCounter< ALLOC >&& from);
/// Move constructor with an explicitly supplied allocator.
183 RecordCounter(RecordCounter< ALLOC >&& from,
const allocator_type& alloc);
/// Virtual copy: presumably returns a newly created copy of *this.
/// NOTE(review): returns a raw pointer, so the caller presumably owns the
/// result. Check recordCounter_tpl.h for how it is allocated before
/// choosing how to release it.
186 virtual RecordCounter< ALLOC >* clone()
const;
/// Virtual copy using the given allocator; same ownership caveat as clone().
189 virtual RecordCounter< ALLOC >* clone(
const allocator_type& alloc)
const;
/// Destructor. It is virtual, so deleting through a base pointer is safe.
192 virtual ~RecordCounter();
/// Copy assignment.
204 RecordCounter< ALLOC >& operator=(
const RecordCounter< ALLOC >& from);
/// Move assignment.
207 RecordCounter< ALLOC >& operator=(RecordCounter< ALLOC >&& from);
/// Sets a maximal number of threads to @p nb.
/// The method is declared const. The matching data member max_nb_threads__
/// (see below) is declared mutable; presumably this is the member written.
/// Confirm in recordCounter_tpl.h.
222 void setMaxNbThreads(
const std::size_t nb)
const;
/// Returns a number of threads; presumably the value of max_nb_threads__.
225 std::size_t nbThreads()
const;
/// Sets a minimal number of rows per thread to @p nb.
/// The method is declared const; presumably it writes the mutable member
/// min_nb_rows_per_thread__, whose default is 100.
236 void setMinNbRowsPerThread(
const std::size_t nb)
const;
/// Returns the minimal number of rows per thread; presumably the value of
/// min_nb_rows_per_thread__.
239 std::size_t minNbRowsPerThread()
const;
/// Returns a vector of counts (std::vector<double>) for the ids in @p ids.
/// @param ids the node ids whose counts are requested.
/// @param check_discrete_vars defaults to false. Presumably, when true, the
///        variables are first validated via checkDiscreteVariables__().
///        Confirm in recordCounter_tpl.h.
/// NOTE(review): the result is a const reference, and the class holds
/// last_DB_countings__ / last_nonDB_countings__ members. The reference
/// presumably points into that internal state, so a later non-const call
/// may overwrite it. Copy the vector if it must outlive such a call.
279 const std::vector<
double, ALLOC<
double > >&
280 counts(
const IdCondSet< ALLOC >& ids,
281 const bool check_discrete_vars =
false);
/// NOTE(review): incomplete declaration. This is a member function template
/// over an allocator template XALLOC, and the vector type below belongs to
/// it. Original lines 291 and 294 are missing, so the function's name and
/// whether the vector is its parameter or return type cannot be recovered
/// from this listing.
290 template <
template <
typename >
class XALLOC >
292 const std::vector< std::pair< std::size_t, std::size_t >,
293 XALLOC< std::pair< std::size_t, std::size_t > > >&
/// NOTE(review): incomplete declaration. Only a const reference to a vector
/// of (std::size_t, std::size_t) pairs survives; the declared name and the
/// rest of the line (original lines 302 and following) are missing.
300 const std::vector< std::pair< std::size_t, std::size_t >,
301 ALLOC< std::pair< std::size_t, std::size_t > > >&
/// Member function template taking a BayesNet<GUM_SCALAR> (@p new_bn).
/// NOTE(review): its effect is defined in recordCounter_tpl.h and cannot be
/// determined from this header.
308 template <
typename GUM_SCALAR >
309 void setBayesNet(
const BayesNet< GUM_SCALAR >& new_bn);
/// Returns the allocator, as an allocator_type.
312 allocator_type getAllocator()
const;
/// Returns a const reference to a NodeId -> std::size_t bijection;
/// presumably the nodeId2columns__ member.
318 const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
319 nodeId2Columns()
const;
/// Returns a const reference to a DatabaseTable<ALLOC>; presumably the table
/// read by the parser(s). Confirm in recordCounter_tpl.h.
322 const DatabaseTable< ALLOC >& database()
const;
// NOTE(review): the section opened on the next line is hidden from Doxygen
// by DOXYGEN_SHOULD_SKIP_THIS. Its matching #endif and the access specifier
// (original lines 328-330) are missing.
// The next line also starts a data member of type
// std::vector<ThreadData<DBRowGeneratorParser<ALLOC>>>, presumably one
// parser per thread. That member's name (original line 333) is missing.
327 #ifndef DOXYGEN_SHOULD_SKIP_THIS 331 std::vector< ThreadData< DBRowGeneratorParser< ALLOC > >,
332 ALLOC< ThreadData< DBRowGeneratorParser< ALLOC > > > >
/// Data member: a vector of (std::size_t, std::size_t) pairs.
/// NOTE(review): its name (original line 339) is missing.
337 std::vector< std::pair< std::size_t, std::size_t >,
338 ALLOC< std::pair< std::size_t, std::size_t > > >
/// Mutable data member (writable from const methods): a vector of
/// (std::size_t, std::size_t) pairs.
/// NOTE(review): its name (original line 348) is missing.
346 mutable std::vector< std::pair< std::size_t, std::size_t >,
347 ALLOC< std::pair< std::size_t, std::size_t > > >
/// NodeId -> std::size_t bijection; presumably what nodeId2Columns() returns.
352 Bijection< NodeId, std::size_t, ALLOC< std::size_t > > nodeId2columns__;
/// Count vector and the ids it was computed for. Judging by the names, these
/// cache the last counting obtained from the database.
355 std::vector<
double, ALLOC<
double > > last_DB_countings__;
358 IdCondSet< ALLOC > last_DB_ids__;
/// Count vector and the ids it was computed for. Judging by the names, these
/// cache the last counting that did not come directly from the database.
361 std::vector<
double, ALLOC<
double > > last_nonDB_countings__;
364 IdCondSet< ALLOC > last_nonDB_ids__;
/// Maximal number of threads, initialized from gum::getMaxNumberOfThreads().
/// It is mutable, presumably so that the const setMaxNbThreads() can
/// update it.
367 mutable std::size_t max_nb_threads__{
368 std::size_t(gum::getMaxNumberOfThreads())};
/// Minimal number of rows per thread (default 100). It is mutable,
/// presumably so that the const setMinNbRowsPerThread() can update it.
372 mutable std::size_t min_nb_rows_per_thread__{100};
/// Returns a NodeId -> std::size_t hash table for the given ids; presumably
/// mapping each node to its database column.
377 HashTable< NodeId, std::size_t >
378 getNodeIds2Columns__(
const IdCondSet< ALLOC >& ids)
const;
/// Returns a count vector for @p subset_ids derived from @p superset_vect.
/// @p superset_vect was computed for @p superset_ids. Presumably the result
/// is obtained by summing out the extra variables; confirm in the _tpl.h.
381 std::vector<
double, ALLOC<
double > >& extractFromCountings__(
382 const IdCondSet< ALLOC >& subset_ids,
383 const IdCondSet< ALLOC >& superset_ids,
384 const std::vector<
double, ALLOC<
double > >& superset_vect);
/// Returns a (non-const) reference to a count vector for @p ids; judging by
/// the name, it is computed by scanning the database.
387 std::vector<
double, ALLOC<
double > >&
388 countFromDatabase__(
const IdCondSet< ALLOC >& ids);
/// Counting work unit. It takes:
///  - a [range_begin, range_end) index pair;
///  - a parser;
///  - a vector of (std::size_t, std::size_t) pairs;
///  - the output vector @p countings.
/// Presumably it is run by one thread per range. The half-open interpretation
/// of the range is an assumption; confirm in the _tpl.h.
391 void threadedCount__(
392 const std::size_t range_begin,
393 const std::size_t range_end,
394 DBRowGeneratorParser< ALLOC >& parser,
395 const std::vector< std::pair< std::size_t, std::size_t >,
396 ALLOC< std::pair< std::size_t, std::size_t > > >&
// NOTE(review): the name of the preceding parameter (original line 397) is
// missing.
398 std::vector<
double, ALLOC<
double > >& countings);
/// NOTE(review): incomplete declaration. This is a private member function
/// template over an allocator template XALLOC. Its name (original line 405)
/// and the remainder after the vector type below (lines 408-412) are
/// missing.
404 template <
template <
typename >
class XALLOC >
406 const std::vector< std::pair< std::size_t, std::size_t >,
407 XALLOC< std::pair< std::size_t, std::size_t > > >&
/// Const check on the variables of @p ids; presumably verifies they are
/// discrete (see counts()'s check_discrete_vars flag).
413 void checkDiscreteVariables__(
const IdCondSet< ALLOC >& ids)
const;
/// Const helper taking the names in @p bad_vars. Judging by the name, it
/// raises an exception that reports them.
419 void raiseCheckException__(
420 const std::vector< std::string, ALLOC< std::string > >& bad_vars)
const;
/// Judging by the name, assigns the counting ranges to the worker threads.
423 void dispatchRangesToThreads__();
// NOTE(review): original lines 424-432 are missing from this listing.
// They presumably held the class's closing "};" and the #endif matching
// DOXYGEN_SHOULD_SKIP_THIS.
433 #include <agrum/tools/stattests/recordCounter_tpl.h>