aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
pseudoCount_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/** @file
 * @brief the class for computing pseudo-counts (database countings to which
 * an informative apriori may be added) used for learning
 *
 * @author Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 */
27
#
include
<
agrum
/
tools
/
stattests
/
pseudoCount
.
h
>
28
#
include
<
agrum
/
tools
/
stattests
/
idCondSet
.
h
>
29
30
#
ifndef
DOXYGEN_SHOULD_SKIP_THIS
31
32
namespace
gum
{
33
34
namespace
learning
{
35
36
/// returns the allocator used by the independence test
37
template
<
template
<
typename
>
class
ALLOC
>
38
INLINE
typename
PseudoCount
<
ALLOC
>::
allocator_type
39
PseudoCount
<
ALLOC
>::
getAllocator
()
const
{
40
return
counter_
.
getAllocator
();
41
}
42
43
/// default constructor
44
template
<
template
<
typename
>
class
ALLOC
>
45
INLINE
PseudoCount
<
ALLOC
>::
PseudoCount
(
46
const
DBRowGeneratorParser
<
ALLOC
>&
parser
,
47
const
Apriori
<
ALLOC
>&
apriori
,
48
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
49
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
ranges
,
50
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
nodeId2columns
,
51
const
typename
PseudoCount
<
ALLOC
>::
allocator_type
&
alloc
) :
52
apriori_
(
apriori
.
clone
(
alloc
)),
53
counter_
(
parser
,
ranges
,
nodeId2columns
,
alloc
) {
54
GUM_CONSTRUCTOR
(
PseudoCount
);
55
}
56
57
58
/// default constructor
59
template
<
template
<
typename
>
class
ALLOC
>
60
INLINE
PseudoCount
<
ALLOC
>::
PseudoCount
(
61
const
DBRowGeneratorParser
<
ALLOC
>&
parser
,
62
const
Apriori
<
ALLOC
>&
apriori
,
63
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
nodeId2columns
,
64
const
typename
PseudoCount
<
ALLOC
>::
allocator_type
&
alloc
) :
65
apriori_
(
apriori
.
clone
(
alloc
)),
66
counter_
(
parser
,
nodeId2columns
,
alloc
) {
67
GUM_CONSTRUCTOR
(
PseudoCount
);
68
}
69
70
71
/// copy constructor with a given allocator
72
template
<
template
<
typename
>
class
ALLOC
>
73
INLINE
PseudoCount
<
ALLOC
>::
PseudoCount
(
74
const
PseudoCount
<
ALLOC
>&
from
,
75
const
typename
PseudoCount
<
ALLOC
>::
allocator_type
&
alloc
) :
76
apriori_
(
from
.
apriori_
->
clone
(
alloc
)),
77
counter_
(
from
.
counter_
,
alloc
) {
78
GUM_CONS_CPY
(
PseudoCount
);
79
}
80
81
82
/// copy constructor
83
template
<
template
<
typename
>
class
ALLOC
>
84
INLINE
PseudoCount
<
ALLOC
>::
PseudoCount
(
const
PseudoCount
<
ALLOC
>&
from
) :
85
PseudoCount
(
from
,
from
.
getAllocator
()) {}
86
87
88
/// move constructor
89
template
<
template
<
typename
>
class
ALLOC
>
90
INLINE
PseudoCount
<
ALLOC
>::
PseudoCount
(
91
PseudoCount
<
ALLOC
>&&
from
,
92
const
typename
PseudoCount
<
ALLOC
>::
allocator_type
&
alloc
) :
93
apriori_
(
from
.
apriori_
),
94
counter_
(
std
::
move
(
from
.
counter_
),
alloc
) {
95
from
.
apriori_
=
nullptr
;
96
GUM_CONS_MOV
(
PseudoCount
);
97
}
98
99
100
/// move constructor
101
template
<
template
<
typename
>
class
ALLOC
>
102
INLINE
PseudoCount
<
ALLOC
>::
PseudoCount
(
PseudoCount
<
ALLOC
>&&
from
) :
103
PseudoCount
(
std
::
move
(
from
),
from
.
getAllocator
()) {}
104
105
106
/// destructor
107
template
<
template
<
typename
>
class
ALLOC
>
108
INLINE
PseudoCount
<
ALLOC
>::~
PseudoCount
() {
109
if
(
apriori_
!=
nullptr
) {
110
ALLOC
<
Apriori
<
ALLOC
> >
allocator
(
this
->
getAllocator
());
111
allocator
.
destroy
(
apriori_
);
112
allocator
.
deallocate
(
apriori_
, 1);
113
}
114
GUM_DESTRUCTOR
(
PseudoCount
);
115
}
116
117
118
/// copy operator
119
template
<
template
<
typename
>
class
ALLOC
>
120
PseudoCount
<
ALLOC
>&
PseudoCount
<
ALLOC
>::
operator
=(
const
PseudoCount
<
ALLOC
>&
from
) {
121
if
(
this
!= &
from
) {
122
Apriori
<
ALLOC
>*
new_apriori
=
from
.
apriori_
->
clone
();
123
RecordCounter
<
ALLOC
>
new_counter
=
from
.
counter_
;
124
125
if
(
apriori_
!=
nullptr
) {
126
ALLOC
<
Apriori
<
ALLOC
> >
allocator
(
this
->
getAllocator
());
127
allocator
.
destroy
(
apriori_
);
128
allocator
.
deallocate
(
apriori_
, 1);
129
}
130
131
apriori_
=
new_apriori
;
132
counter_
=
std
::
move
(
new_counter
);
133
}
134
return
*
this
;
135
}
136
137
138
/// move operator
139
template
<
template
<
typename
>
class
ALLOC
>
140
PseudoCount
<
ALLOC
>&
PseudoCount
<
ALLOC
>::
operator
=(
PseudoCount
<
ALLOC
>&&
from
) {
141
if
(
this
!= &
from
) {
142
std
::
swap
(
apriori_
,
from
.
apriori_
);
143
144
counter_
=
std
::
move
(
from
.
counter_
);
145
}
146
return
*
this
;
147
}
148
149
150
/// changes the max number of threads used to parse the database
151
template
<
template
<
typename
>
class
ALLOC
>
152
INLINE
void
PseudoCount
<
ALLOC
>::
setMaxNbThreads
(
std
::
size_t
nb
)
const
{
153
counter_
.
setMaxNbThreads
(
nb
);
154
}
155
156
157
/// returns the number of threads used to parse the database
158
template
<
template
<
typename
>
class
ALLOC
>
159
INLINE
std
::
size_t
PseudoCount
<
ALLOC
>::
nbThreads
()
const
{
160
return
counter_
.
nbThreads
();
161
}
162
163
164
/** @brief changes the number min of rows a thread should process in a
165
* multithreading context */
166
template
<
template
<
typename
>
class
ALLOC
>
167
INLINE
void
PseudoCount
<
ALLOC
>::
setMinNbRowsPerThread
(
const
std
::
size_t
nb
)
const
{
168
counter_
.
setMinNbRowsPerThread
(
nb
);
169
}
170
171
172
/// returns the minimum of rows that each thread should process
173
template
<
template
<
typename
>
class
ALLOC
>
174
INLINE
std
::
size_t
PseudoCount
<
ALLOC
>::
minNbRowsPerThread
()
const
{
175
return
counter_
.
minNbRowsPerThread
();
176
}
177
178
179
/// sets new ranges to perform the countings used by the score
180
/** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
181
* indices. The countings are then performed only on the union of the
182
* rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
183
* cross validation tasks, in which part of the database should be ignored.
184
* An empty set of ranges is equivalent to an interval [X,Y) ranging over
185
* the whole database. */
186
template
<
template
<
typename
>
class
ALLOC
>
187
template
<
template
<
typename
>
class
XALLOC
>
188
void
PseudoCount
<
ALLOC
>::
setRanges
(
189
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
190
XALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
new_ranges
) {
191
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
192
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >
193
old_ranges
=
ranges
();
194
counter_
.
setRanges
(
new_ranges
);
195
if
(
old_ranges
!=
ranges
())
clear
();
196
}
197
198
199
/// reset the ranges to the one range corresponding to the whole database
200
template
<
template
<
typename
>
class
ALLOC
>
201
void
PseudoCount
<
ALLOC
>::
clearRanges
() {
202
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
203
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >
204
old_ranges
=
ranges
();
205
counter_
.
clearRanges
();
206
}
207
208
209
/// returns the current ranges
210
template
<
template
<
typename
>
class
ALLOC
>
211
INLINE
const
std
::
vector
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
>,
212
ALLOC
<
std
::
pair
<
std
::
size_t
,
std
::
size_t
> > >&
213
PseudoCount
<
ALLOC
>::
ranges
()
const
{
214
return
counter_
.
ranges
();
215
}
216
217
218
/// clears all the data structures from memory
219
template
<
template
<
typename
>
class
ALLOC
>
220
INLINE
void
PseudoCount
<
ALLOC
>::
clear
() {
221
counter_
.
clear
();
222
}
223
224
225
/// return the mapping between the columns of the database and the node ids
226
template
<
template
<
typename
>
class
ALLOC
>
227
INLINE
const
Bijection
<
NodeId
,
std
::
size_t
,
ALLOC
<
std
::
size_t
> >&
228
PseudoCount
<
ALLOC
>::
nodeId2Columns
()
const
{
229
return
counter_
.
nodeId2Columns
();
230
}
231
232
233
/// return the database used by the score
234
template
<
template
<
typename
>
class
ALLOC
>
235
INLINE
const
DatabaseTable
<
ALLOC
>&
PseudoCount
<
ALLOC
>::
database
()
const
{
236
return
counter_
.
database
();
237
}
238
239
240
/// returns a counting vector where variables are marginalized from N_xyz
241
/** @param node_2_marginalize indicates which node(s) shall be marginalized:
242
* - 0 means that X should be marginalized
243
* - 1 means that Y should be marginalized
244
* - 2 means that Z should be marginalized
245
*/
246
template
<
template
<
typename
>
class
ALLOC
>
247
std
::
vector
<
double
,
ALLOC
<
double
> >
248
PseudoCount
<
ALLOC
>::
get
(
const
std
::
vector
<
NodeId
,
ALLOC
<
NodeId
> >&
ids
) {
249
IdCondSet
<
ALLOC
>
idset
(
ids
,
false
,
true
);
250
std
::
vector
<
double
,
ALLOC
<
double
> >
N_xyz
(
this
->
counter_
.
counts
(
idset
,
true
));
251
const
bool
informative_external_apriori
=
this
->
apriori_
->
isInformative
();
252
if
(
informative_external_apriori
)
this
->
apriori_
->
addAllApriori
(
idset
,
N_xyz
);
253
return
N_xyz
;
254
}
255
256
}
/* namespace learning */
257
258
}
/* namespace gum */
259
260
#
endif
/* DOXYGEN_SHOULD_SKIP_THIS */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31