25 #ifndef DOXYGEN_SHOULD_SKIP_THIS 32 template <
template <
typename >
class ALLOC >
35 return _counter.getAllocator();
40 template <
template <
typename >
class ALLOC >
42 const DBRowGeneratorParser< ALLOC >& parser,
43 const Apriori< ALLOC >& apriori,
44 const std::vector< std::pair< std::size_t, std::size_t >,
45 ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
46 const Bijection<
NodeId, std::size_t, ALLOC< std::size_t > >&
49 _apriori(apriori.clone(alloc)),
50 _counter(parser, ranges, nodeId2columns, alloc), _cache(alloc) {
56 template <
template <
typename >
class ALLOC >
58 const DBRowGeneratorParser< ALLOC >& parser,
59 const Apriori< ALLOC >& apriori,
60 const Bijection<
NodeId, std::size_t, ALLOC< std::size_t > >&
70 template <
template <
typename >
class ALLOC >
72 const IndependenceTest< ALLOC >& from,
82 template <
template <
typename >
class ALLOC >
84 const IndependenceTest< ALLOC >& from) :
89 template <
template <
typename >
class ALLOC >
91 IndependenceTest< ALLOC >&& from,
96 from._apriori =
nullptr;
102 template <
template <
typename >
class ALLOC >
104 IndependenceTest< ALLOC >&& from) :
109 template <
template <
typename >
class ALLOC >
112 ALLOC< Apriori< ALLOC > > allocator(this->
getAllocator());
121 template <
template <
typename >
class ALLOC >
123 operator=(
const IndependenceTest< ALLOC >& from) {
125 Apriori< ALLOC >* new_apriori = from._apriori->clone();
126 RecordCounter< ALLOC > new_counter = from._counter;
127 ScoringCache< ALLOC > new_cache = from._cache;
130 ALLOC< Apriori< ALLOC > > allocator(this->
getAllocator());
137 _cache = std::move(new_cache);
146 template <
template <
typename >
class ALLOC >
148 operator=(IndependenceTest< ALLOC >&& from) {
152 _counter = std::move(from._counter);
153 _cache = std::move(from._cache);
161 template <
template <
typename >
class ALLOC >
168 template <
template <
typename >
class ALLOC >
176 template <
template <
typename >
class ALLOC >
178 const std::size_t nb)
const {
184 template <
template <
typename >
class ALLOC >
186 return _counter.minNbRowsPerThread();
197 template <
template <
typename >
class ALLOC >
198 template <
template <
typename >
class XALLOC >
200 const std::vector< std::pair< std::size_t, std::size_t >,
201 XALLOC< std::pair< std::size_t, std::size_t > > >&
203 std::vector< std::pair< std::size_t, std::size_t >,
204 ALLOC< std::pair< std::size_t, std::size_t > > >
212 template <
template <
typename >
class ALLOC >
214 std::vector< std::pair< std::size_t, std::size_t >,
215 ALLOC< std::pair< std::size_t, std::size_t > > >
223 template <
template <
typename >
class ALLOC >
224 INLINE
const std::vector< std::pair< std::size_t, std::size_t >,
225 ALLOC< std::pair< std::size_t, std::size_t > > >&
232 template <
template <
typename >
class ALLOC >
235 IdSet< ALLOC > idset(
239 return _cache.score(idset);
240 }
catch (NotFound&) {}
241 double the_score =
_score(idset);
242 _cache.insert(std::move(idset), the_score);
245 return _score(std::move(idset));
251 template <
template <
typename >
class ALLOC >
255 const std::vector<
NodeId, ALLOC< NodeId > >& rhs_ids) {
256 IdSet< ALLOC > idset(
257 var1, var2, rhs_ids,
false,
false, this->
getAllocator());
260 return _cache.score(idset);
261 }
catch (NotFound&) {}
262 double the_score =
_score(idset);
263 _cache.insert(std::move(idset), the_score);
272 template <
template <
typename >
class ALLOC >
280 template <
template <
typename >
class ALLOC >
287 template <
template <
typename >
class ALLOC >
294 template <
template <
typename >
class ALLOC >
295 INLINE
const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >&
302 template <
template <
typename >
class ALLOC >
303 INLINE
const DatabaseTable< ALLOC >&
315 template <
template <
typename >
class ALLOC >
317 const std::size_t node_2_marginalize,
318 const std::size_t X_size,
319 const std::size_t Y_size,
320 const std::size_t Z_size,
321 const std::vector<
double, ALLOC< double > >& N_xyz)
const {
323 std::size_t out_size = Z_size;
324 if (node_2_marginalize == std::size_t(0))
326 else if (node_2_marginalize == std::size_t(1))
330 std::vector< double, ALLOC< double > > res(out_size, 0.0);
333 if (node_2_marginalize == std::size_t(0)) {
334 for (std::size_t yz = std::size_t(0), xyz = std::size_t(0); yz < out_size;
336 for (std::size_t x = std::size_t(0); x < X_size; ++x, ++xyz) {
337 res[yz] += N_xyz[xyz];
340 }
else if (node_2_marginalize == std::size_t(1)) {
341 for (std::size_t z = std::size_t(0),
342 xyz = std::size_t(0),
343 beg_xz = std::size_t(0);
345 ++z, beg_xz += X_size) {
346 for (std::size_t y = std::size_t(0); y < Y_size; ++y) {
347 for (std::size_t x = std::size_t(0), xz = beg_xz; x < X_size;
349 res[xz] += N_xyz[xyz];
353 }
else if (node_2_marginalize == std::size_t(2)) {
354 const std::size_t XY_size = X_size * Y_size;
355 for (std::size_t z = std::size_t(0), xyz = std::size_t(0); z < out_size;
357 for (std::size_t xy = std::size_t(0); xy < XY_size; ++xy, ++xyz) {
358 res[z] += N_xyz[xyz];
363 "_marginalize not implemented for nodeset " 364 << node_2_marginalize);
ScoringCache< ALLOC > _cache
the scoring cache
std::vector< double, ALLOC< double > > _marginalize(const std::size_t node_2_marginalize, const std::size_t X_size, const std::size_t Y_size, const std::size_t Z_size, const std::vector< double, ALLOC< double > > &N_xyz) const
returns a counting vector where variables are marginalized from N_xyz
ALLOC< NodeId > allocator_type
type for the allocators passed in arguments of methods
RecordCounter< ALLOC > _counter
the record counter used for the countings over discrete variables
virtual std::size_t minNbRowsPerThread() const
returns the minimum number of rows that each thread should process
const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > & ranges() const
returns the current ranges
void swap(HashTable< LpCol, double > *&a, HashTable< LpCol, double > *&b)
Swap the addresses of two pointers to hashTables.
virtual void clearCache()
clears the current cache
const DatabaseTable< ALLOC > & database() const
return the database used by the score
const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > & nodeId2Columns() const
return the mapping between the columns of the database and the node ids
virtual void setMinNbRowsPerThread(const std::size_t nb) const
changes the minimum number of rows a thread should process in a multithreading context ...
gum is the global namespace for all aGrUM entities
virtual double _score(const IdSet< ALLOC > &idset)=0
returns the score for a given IdSet
const std::vector< NodeId, ALLOC< NodeId > > _empty_ids
an empty vector
IndependenceTest(const DBRowGeneratorParser< ALLOC > &parser, const Apriori< ALLOC > &external_apriori, const std::vector< std::pair< std::size_t, std::size_t >, ALLOC< std::pair< std::size_t, std::size_t > > > &ranges, const Bijection< NodeId, std::size_t, ALLOC< std::size_t > > &nodeId2columns=Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), const allocator_type &alloc=allocator_type())
default constructor
allocator_type getAllocator() const
returns the allocator used by the score
void clearRanges()
reset the ranges to the one range corresponding to the whole database
virtual IndependenceTest< ALLOC > * clone() const =0
virtual copy constructor
virtual void setMaxNbThreads(std::size_t nb) const
changes the max number of threads used to parse the database
IndependenceTest< ALLOC > & operator=(const IndependenceTest< ALLOC > &from)
copy operator
double score(const NodeId var1, const NodeId var2)
returns the score of a pair of nodes
virtual ~IndependenceTest()
destructor
bool _use_cache
a Boolean indicating whether we wish to use the cache
virtual void useCache(const bool on_off)
turn on/off the use of a cache of the previously computed score
virtual std::size_t nbThreads() const
returns the number of threads used to parse the database
Size NodeId
Type for node ids.
virtual void clear()
clears all the data structures from memory, including the cache
#define GUM_ERROR(type, msg)
void setRanges(const std::vector< std::pair< std::size_t, std::size_t >, XALLOC< std::pair< std::size_t, std::size_t > > > &new_ranges)
sets new ranges to perform the countings used by the independence test
Apriori< ALLOC > * _apriori
the expert knowledge a priori we add to the contingency tables