// Fragment of an accessor whose signature is not visible in this listing:
// the only visible statement returns the allocator held by __NH.
// NOTE(review): the leading integers (27, 34, 37) look like line numbers
// carried over from a rendered listing rather than code -- confirm against
// the real source file.
27 #ifndef DOXYGEN_SHOULD_SKIP_THIS 34 template <
template <
typename >
class ALLOC >
37 return __NH.getAllocator();
// Constructor fragment (the function name line is not visible here).
// Visible parameters: a DBRowGeneratorParser, an Apriori, a vector of
// (size_t, size_t) pairs called `ranges`, and a NodeId <-> size_t Bijection.
// NOTE(review): `nodeId2columns` and `alloc` are used below but their
// declarations are not visible in this listing.
42 template <
template <
typename >
class ALLOC >
44 const DBRowGeneratorParser< ALLOC >& parser,
45 const Apriori< ALLOC >& apriori,
46 const std::vector< std::pair< std::size_t, std::size_t >,
47 ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
48 const Bijection<
NodeId, std::size_t, ALLOC< std::size_t > >&
// The three scoring members receive exactly the same arguments.
51 __NH(parser, apriori, ranges, nodeId2columns, alloc),
52 __k_NML(parser, apriori, ranges, nodeId2columns, alloc),
53 __score_MDL(parser, apriori, ranges, nodeId2columns, alloc),
// Both caches are built from the allocator only.
54 __ICache(alloc), __KCache(alloc) {
// Constructor fragment without a `ranges` argument (the function name line
// is not visible here). Visible parameters: a DBRowGeneratorParser, an
// Apriori and a NodeId <-> size_t Bijection.
// NOTE(review): `nodeId2columns` and `alloc` are used below but their
// declarations are not visible in this listing.
60 template <
template <
typename >
class ALLOC >
62 const DBRowGeneratorParser< ALLOC >& parser,
63 const Apriori< ALLOC >& apriori,
64 const Bijection<
NodeId, std::size_t, ALLOC< std::size_t > >&
// The three scoring members receive exactly the same arguments.
67 __NH(parser, apriori, nodeId2columns, alloc),
68 __k_NML(parser, apriori, nodeId2columns, alloc),
// NOTE(review): the initializer list continues past this line in the
// original; the remainder is not visible here.
69 __score_MDL(parser, apriori, nodeId2columns, alloc), __ICache(alloc),
// Copy-constructor fragment taking `from` by const reference together with
// an allocator (`alloc` is used below; its declaration is not visible here).
76 template <
template <
typename >
class ALLOC >
78 const CorrectedMutualInformation< ALLOC >& from,
// Members that take an allocator are copied from `from` with `alloc`.
80 __NH(from.__NH, alloc),
81 __k_NML(from.__k_NML, alloc), __score_MDL(from.__score_MDL, alloc),
// The mode and the four cache-usage flags are copied as plain values.
82 __kmode(from.__kmode), __use_ICache(from.__use_ICache),
83 __use_HCache(from.__use_HCache), __use_KCache(from.__use_KCache),
84 __use_CnrCache(from.__use_CnrCache), __ICache(from.__ICache, alloc),
85 __KCache(from.__KCache, alloc) {
// Copy-constructor fragment: only the `from` parameter and the start of the
// initializer list (the trailing ':') are visible.
// NOTE(review): what the initializer list contains cannot be told from this
// listing.
91 template <
template <
typename >
class ALLOC >
93 const CorrectedMutualInformation< ALLOC >& from) :
// Move-constructor fragment taking `from` by rvalue reference together with
// an allocator (`alloc` is used below; its declaration is not visible here).
98 template <
template <
typename >
class ALLOC >
100 CorrectedMutualInformation< ALLOC >&& from,
// The scoring members are moved out of `from`.
102 __NH(
std::move(from.__NH), alloc),
103 __k_NML(
std::move(from.__k_NML), alloc),
104 __score_MDL(
std::move(from.__score_MDL), alloc), __kmode(from.__kmode),
// The cache-usage flags are copied as plain values.
105 __use_ICache(from.__use_ICache), __use_HCache(from.__use_HCache),
106 __use_KCache(from.__use_KCache), __use_CnrCache(from.__use_CnrCache),
// Both caches are moved out of `from`.
107 __ICache(
std::move(from.__ICache), alloc),
108 __KCache(
std::move(from.__KCache), alloc) {
// Move-constructor fragment: only the `from` parameter and the start of the
// initializer list (the trailing ':') are visible.
// NOTE(review): what the initializer list contains cannot be told from this
// listing.
114 template <
template <
typename >
class ALLOC >
116 CorrectedMutualInformation< ALLOC >&& from) :
// Fragment of a function returning a pointer to a newly built
// CorrectedMutualInformation (the name and parameter list are not visible;
// `alloc` is used below).
121 template <
template <
typename >
class ALLOC >
122 CorrectedMutualInformation< ALLOC >*
// Instantiate ALLOC for this class and obtain storage for one object.
126 ALLOC< CorrectedMutualInformation< ALLOC > > allocator(alloc);
127 CorrectedMutualInformation< ALLOC >* new_score = allocator.allocate(1);
// Construct a copy of *this in that storage, passing `alloc` along.
129 allocator.construct(new_score, *
this, alloc);
// NOTE(review): presumably the cleanup path when construction throws; the
// surrounding try/catch is not visible in this listing -- confirm.
131 allocator.deallocate(new_score, 1);
// Fragment of a function returning a CorrectedMutualInformation pointer.
// Only the return type is visible; name, parameters and body are missing
// from this listing.
140 template <
template <
typename >
class ALLOC >
141 CorrectedMutualInformation< ALLOC >*
148 template <
template <
typename >
class ALLOC >
// Copy assignment: copies the scoring members, the mode, the cache-usage
// flags and both caches from `from`.
// NOTE(review): no `__NH = from.__NH;` is visible here, whereas the move
// overload in this file does assign __NH. The listing skips from 158 to 161,
// so the statement may simply have been dropped -- confirm against the source.
156 template <
template <
typename >
class ALLOC >
158 operator=(
const CorrectedMutualInformation< ALLOC >& from) {
161 __k_NML = from.__k_NML;
162 __score_MDL = from.__score_MDL;
163 __kmode = from.__kmode;
164 __use_ICache = from.__use_ICache;
165 __use_HCache = from.__use_HCache;
166 __use_KCache = from.__use_KCache;
167 __use_CnrCache = from.__use_CnrCache;
168 __ICache = from.__ICache;
169 __KCache = from.__KCache;
// Move assignment: the scoring members and both caches are moved out of
// `from`; the mode and the cache-usage flags are copied as plain values.
// NOTE(review): any self-assignment guard and the return statement are not
// visible in this listing.
176 template <
template <
typename >
class ALLOC >
178 operator=(CorrectedMutualInformation< ALLOC >&& from) {
180 __NH = std::move(from.__NH);
181 __k_NML = std::move(from.__k_NML);
182 __score_MDL = std::move(from.__score_MDL);
183 __kmode = from.__kmode;
184 __use_ICache = from.__use_ICache;
185 __use_HCache = from.__use_HCache;
186 __use_KCache = from.__use_KCache;
187 __use_CnrCache = from.__use_CnrCache;
188 __ICache = std::move(from.__ICache);
189 __KCache = std::move(from.__KCache);
196 template <
template <
typename >
class ALLOC >
// Fragment of a toggle for the I cache (signature not visible; `on_off` is
// the value being applied).
206 template <
template <
typename >
class ALLOC >
// Turning the cache off also empties it.
208 if (!on_off) __ICache.clear();
209 __use_ICache = on_off;
// Fragment of a toggle for the cache owned by __NH (signature not visible;
// `on_off` is the value being applied).
214 template <
template <
typename >
class ALLOC >
// Turning the cache off also clears it inside __NH.
216 if (!on_off) __NH.clearCache();
217 __use_HCache = on_off;
// The setting is forwarded to __NH as well as recorded locally.
218 __NH.useCache(on_off);
// Fragment of a toggle for the K cache (signature not visible; `on_off` is
// the value being applied).
223 template <
template <
typename >
class ALLOC >
// Turning the cache off also empties it.
225 if (!on_off) __KCache.clear();
226 __use_KCache = on_off;
// Fragment of a toggle for the cache owned by __k_NML (signature not
// visible; `on_off` is the value being applied).
231 template <
template <
typename >
class ALLOC >
// Turning the cache off also clears it inside __k_NML.
233 if (!on_off) __k_NML.clearCache();
234 __use_CnrCache = on_off;
// The setting is forwarded to __k_NML as well as recorded locally.
235 __k_NML.useCache(on_off);
240 template <
template <
typename >
class ALLOC >
// Fragment of a cache-clearing member (signature not visible). The only
// statement visible here clears the cache held by __k_NML.
// NOTE(review): the listing skips numbers around this line, so further
// statements may exist in the original -- confirm.
250 template <
template <
typename >
class ALLOC >
253 __k_NML.clearCache();
260 template <
template <
typename >
class ALLOC >
267 template <
template <
typename >
class ALLOC >
274 template <
template <
typename >
class ALLOC >
// Fragment of a cache-clearing member (signature not visible). The only
// statement visible here clears the cache held by __k_NML.
281 template <
template <
typename >
class ALLOC >
283 __k_NML.clearCache();
// Fragment of a setter (signature not visible) that forwards the same
// thread limit `nb` to all three scoring members.
288 template <
template <
typename >
class ALLOC >
291 __NH.setMaxNbThreads(nb);
292 __k_NML.setMaxNbThreads(nb);
293 __score_MDL.setMaxNbThreads(nb);
// Fragment of a getter (signature not visible): the thread count is read
// from __NH only.
298 template <
template <
typename >
class ALLOC >
300 return __NH.nbThreads();
// Fragment of a setter that forwards the same per-thread row minimum `nb`
// to all three scoring members.
// NOTE(review): the function is declared `const` yet calls setters on
// members; presumably those members are declared `mutable` -- confirm in
// the class definition.
306 template <
template <
typename >
class ALLOC >
308 const std::size_t nb)
const {
309 __NH.setMinNbRowsPerThread(nb);
310 __k_NML.setMinNbRowsPerThread(nb);
311 __score_MDL.setMinNbRowsPerThread(nb);
// Fragment of a getter (signature not visible): the per-thread row minimum
// is read from __NH only.
316 template <
template <
typename >
class ALLOC >
319 return __NH.minNbRowsPerThread();
// Fragment of a member template (name not visible) that accepts ranges
// stored with a different allocator template, XALLOC.
330 template <
template <
typename >
class ALLOC >
331 template <
template <
typename >
class XALLOC >
// Incoming ranges: a vector of (size_t, size_t) pairs using XALLOC.
333 const std::vector< std::pair< std::size_t, std::size_t >,
334 XALLOC< std::pair< std::size_t, std::size_t > > >&
// A vector with the same element type but using this class's ALLOC.
// NOTE(review): its name and how it is filled are not visible here; the
// calls below use `new_ranges` -- confirm it is this object.
336 std::vector< std::pair< std::size_t, std::size_t >,
337 ALLOC< std::pair< std::size_t, std::size_t > > >
// The same ranges are pushed to all three scoring members.
340 __NH.setRanges(new_ranges);
341 __k_NML.setRanges(new_ranges);
342 __score_MDL.setRanges(new_ranges);
// Fragment of a range-clearing member (signature not visible).
349 template <
template <
typename >
class ALLOC >
// NOTE(review): a vector-of-pairs type appears here, but the declaration it
// belongs to is cut off in this listing.
351 std::vector< std::pair< std::size_t, std::size_t >,
352 ALLOC< std::pair< std::size_t, std::size_t > > >
// NOTE(review): only __k_NML and __score_MDL are cleared in the visible
// lines; a matching call on __NH may have been dropped by the listing.
355 __k_NML.clearRanges();
356 __score_MDL.clearRanges();
// Fragment of a getter (name not visible) returning, by const reference,
// the vector of (size_t, size_t) ranges held by __NH.
362 template <
template <
typename >
class ALLOC >
363 INLINE
const std::vector< std::pair< std::size_t, std::size_t >,
364 ALLOC< std::pair< std::size_t, std::size_t > > >&
366 return __NH.ranges();
371 template <
template <
typename >
class ALLOC >
379 template <
template <
typename >
class ALLOC >
387 template <
template <
typename >
class ALLOC >
// Fragment of the two-variable score overload (signature not visible): it
// delegates to the overload taking a conditioning set, passing the member
// __empty_conditioning_set.
395 template <
template <
typename >
class ALLOC >
398 return score(var1, var2, __empty_conditioning_set);
// Fragment of the score overload for two variables given a conditioning set
// (the var1/var2 parameter lines are not visible). The result is the
// __NI_score term minus the __K_score term for the same arguments.
403 template <
template <
typename >
class ALLOC >
407 const std::vector<
NodeId, ALLOC< NodeId > >& conditioning_ids) {
408 return __NI_score(var1, var2, conditioning_ids)
409 - __K_score(var1, var2, conditioning_ids);
// Fragment of the three-variable score overload (signature not visible): it
// delegates to the overload taking a conditioning set, passing the member
// __empty_conditioning_set.
414 template <
template <
typename >
class ALLOC >
418 return score(var1, var2, var3, __empty_conditioning_set);
// Fragment of the score overload for three variables given a conditioning
// set (the var1/var2/var3 parameter lines are not visible).
// The two terms are ADDED here, whereas the two-variable overload subtracts.
// This matches the four-argument __K_score in this file, which returns
// K(ui + var3) - K(ui), the opposite order to the four-argument __NI_score
// (NI(ui) - NI(ui + var_z)).
423 template <
template <
typename >
class ALLOC >
428 const std::vector<
NodeId, ALLOC< NodeId > >& conditioning_ids) {
429 return __NI_score(var1, var2, var3, conditioning_ids)
430 + __K_score(var1, var2, var3, conditioning_ids);
// Computes the uncorrected term for (var_x, var_y) given vars_z from scores
// returned by __NH, optionally memoised in __ICache.
// NOTE(review): this listing omits several original lines (the var_x/var_y
// parameters, `try`/`if`/`else` keywords, closing braces, the declarations
// of `score` and `ratio`, and the return). Comments below describe only what
// is visible.
435 template <
template <
typename >
class ALLOC >
436 double CorrectedMutualInformation< ALLOC >::__NI_score(
439 const std::vector<
NodeId, ALLOC< NodeId > >& vars_z) {
// Key under which the result is looked up and stored in __ICache.
456 const IdSet< ALLOC > idset_xyz(var_x, var_y, vars_z,
false,
false);
// Cache lookup: a hit returns immediately.
459 return __ICache.score(idset_xyz);
460 }
// A NotFound from the lookup is swallowed so the value is computed below.
catch (
const NotFound&) {}
// Conditional case: non-empty conditioning set.
468 if (!vars_z.empty()) {
469 std::vector< NodeId, ALLOC< NodeId > > vars(vars_z);
471 vars.push_back(var_x);
472 vars.push_back(var_y);
// Each NH* value is the negated __NH score of the current content of `vars`.
// NOTE(review): the four values only differ if `vars` is modified between
// the calls; those statements are not visible in this listing -- confirm.
473 const double NHxyz = -__NH.score(IdSet< ALLOC >(vars,
false,
true));
476 const double NHxz = -__NH.score(IdSet< ALLOC >(vars,
false,
true));
479 vars.push_back(var_y);
480 const double NHyz = -__NH.score(IdSet< ALLOC >(vars,
false,
true));
483 const double NHz = -__NH.score(IdSet< ALLOC >(vars,
false,
true));
485 const double NHxz_NHyz = NHxz + NHyz;
486 double NHz_NHxyz = NHz + NHxyz;
// Relative difference between the two sums, normalised by whichever sum is
// used as denominator (the first branch's condition is not visible).
491 ratio = (NHxz_NHyz - NHz_NHxyz) / NHxz_NHyz;
492 }
else if (NHz_NHxyz > 0) {
493 ratio = (NHxz_NHyz - NHz_NHxyz) / NHz_NHxyz;
495 if (ratio < 0) ratio = -ratio;
// Below __threshold the two sums are made equal, so the score becomes 0.
496 if (ratio < __threshold) {
497 NHz_NHxyz = NHxz_NHyz;
500 score = NHxz_NHyz - NHz_NHxyz;
// Presumably the unconditional case (the `else` is not visible): same
// scheme using the joint (x, y) and the two single-variable scores.
502 const double NHxy = -__NH.score(
503 IdSet< ALLOC >(var_x, var_y, __empty_conditioning_set,
true,
false));
504 const double NHx = -__NH.score(var_x);
505 const double NHy = -__NH.score(var_y);
507 double NHx_NHy = NHx + NHy;
512 ratio = (NHx_NHy - NHxy) / NHx_NHy;
513 }
else if (NHxy > 0) {
514 ratio = (NHx_NHy - NHxy) / NHxy;
516 if (ratio < 0) ratio = -ratio;
// NOTE(review): the body of this threshold test is not visible here.
517 if (ratio < __threshold) {
521 score = NHx_NHy - NHxy;
// Store the result only when the I cache is enabled.
526 if (__use_ICache) { __ICache.insert(idset_xyz, score); }
// Three-variable variant: the difference between the (var_x, var_y) term
// given ui_ids and the same term given ui_ids extended with var_z.
// NOTE(review): the var_x/var_y/var_z parameter lines are not visible in
// this listing.
533 template <
template <
typename >
class ALLOC >
534 INLINE
double CorrectedMutualInformation< ALLOC >::__NI_score(
538 const std::vector<
NodeId, ALLOC< NodeId > >& ui_ids) {
// Copy the conditioning set and append var_z to it.
541 std::vector< NodeId, ALLOC< NodeId > > uiz_ids = ui_ids;
542 uiz_ids.push_back(var_z);
543 return __NI_score(var_x, var_y, ui_ids) - __NI_score(var_x, var_y, uiz_ids);
// Computes the correction term for (var1, var2) given conditioning_ids,
// optionally memoised in __KCache.
// NOTE(review): this listing omits several original lines (the var1/var2
// parameters, `try`/`if`/`else`/`switch` keywords, closing braces, the
// declarations of `score`, `rx`, `ry`, `rui`, and the return). Comments
// below describe only what is visible.
548 template <
template <
typename >
class ALLOC >
549 double CorrectedMutualInformation< ALLOC >::__K_score(
552 const std::vector<
NodeId, ALLOC< NodeId > >& conditioning_ids) {
// Key used for the cache lookup and the final insert.
558 IdSet< ALLOC > idset;
560 idset = std::move(IdSet< ALLOC >(var1, var2, conditioning_ids,
false));
// Cache lookup: a hit returns immediately.
562 return __KCache.score(idset);
563 }
// A NotFound from the lookup is swallowed so the value is computed below.
catch (
const NotFound&) {}
571 const auto& database = __NH.database();
572 const auto& node2cols = __NH.nodeId2Columns();
// With a non-empty node-to-column mapping, node ids are translated through
// node2cols.second() before asking the database for domain sizes.
575 if (!node2cols.empty()) {
576 rx = database.domainSize(node2cols.second(var1));
577 ry = database.domainSize(node2cols.second(var2));
// rui accumulates the product of the conditioning variables' domain sizes.
578 for (
const NodeId i : conditioning_ids) {
579 rui *= database.domainSize(node2cols.second(i));
// Presumably the `else` branch (not visible): ids are used directly.
582 rx = database.domainSize(var1);
583 ry = database.domainSize(var2);
584 for (
const NodeId i : conditioning_ids) {
585 rui *= database.domainSize(i);
591 idset = std::move(IdSet< ALLOC >(var1, var2, conditioning_ids,
false));
// N is obtained from __score_MDL for this id set.
593 const double N = __score_MDL.N(idset);
// Correction computed from the domain sizes and log2(N).
595 score = 0.5 * (rx - 1) * (ry - 1) * rui * std::log2(N);
// Alternative correction delegated to __k_NML.
// NOTE(review): which mode selects which formula is not visible here
// (presumably a switch on __kmode) -- confirm.
599 score = __k_NML.score(var1, var2, conditioning_ids);
// Error message for an unsupported mode; the statement raising it is not
// visible in this listing.
604 "CorrectedMutualInformation mode does " 605 "not support yet this correction");
// Store the result only when the K cache is enabled.
609 if (__use_KCache) { __KCache.insert(idset, score); }
// Three-variable variant: the (var1, var2) correction given ui_ids extended
// with var3, minus the correction given ui_ids alone. Note the operand order
// is the reverse of the four-argument __NI_score in this file.
// NOTE(review): the var1/var2/var3 parameter lines are not visible in this
// listing.
615 template <
template <
typename >
class ALLOC >
616 INLINE
double CorrectedMutualInformation< ALLOC >::__K_score(
620 const std::vector<
NodeId, ALLOC< NodeId > >& ui_ids) {
// Copy the conditioning set and append var3 to it.
622 std::vector< NodeId, ALLOC< NodeId > > uiz_ids = ui_ids;
623 uiz_ids.push_back(var3);
624 return __K_score(var1, var2, uiz_ids) - __K_score(var1, var2, ui_ids);
gum is the global namespace for all aGrUM entities
Size NodeId
Type for node ids.
#define GUM_ERROR(type, msg)