aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
estimator_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
 * @file
 * @brief Implementation of Estimator for approximate inference in Bayesian
 * networks
 *
 * @author Paul ALAM & Pierre-Henri WUILLEMIN(@LIP6)
 */
29
30
namespace
gum
{
31
32
template
<
typename
GUM_SCALAR >
33
Estimator< GUM_SCALAR >::Estimator() {
34
GUM_CONSTRUCTOR(Estimator);
35
wtotal_ = (GUM_SCALAR)0.;
36
ntotal_ = (Size)0;
37
bn_ =
nullptr
;
38
}
39
40
41
template
<
typename
GUM_SCALAR
>
42
Estimator
<
GUM_SCALAR
>::
Estimator
(
const
IBayesNet
<
GUM_SCALAR
>*
bn
) :
43
Estimator
() {
44
bn_
=
bn
;
45
46
for
(
gum
::
NodeGraphPartIterator
iter
=
bn
->
nodes
().
begin
();
47
iter
!=
bn
->
nodes
().
end
();
48
++
iter
)
49
estimator_
.
insert
(
50
bn
->
variable
(*
iter
).
name
(),
51
std
::
vector
<
GUM_SCALAR
>(
bn
->
variable
(*
iter
).
domainSize
(), 0.0));
52
53
GUM_CONSTRUCTOR
(
Estimator
);
54
}
55
56
57
template
<
typename
GUM_SCALAR
>
58
INLINE
Estimator
<
GUM_SCALAR
>::~
Estimator
() {
59
GUM_DESTRUCTOR
(
Estimator
);
60
// remove all the posteriors computed
61
clear
();
62
}
63
64
65
/* adds all potential target variables from a given BN to the Estimator */
66
67
template
<
typename
GUM_SCALAR
>
68
void
Estimator
<
GUM_SCALAR
>::
setFromBN
(
const
IBayesNet
<
GUM_SCALAR
>*
bn
,
69
const
NodeSet
&
hardEvidence
) {
70
for
(
gum
::
NodeGraphPartIterator
iter
=
bn
->
nodes
().
begin
();
71
iter
!=
bn
->
nodes
().
end
();
72
++
iter
) {
73
auto
v
=
bn
->
variable
(*
iter
).
name
();
74
75
if
(!
hardEvidence
.
contains
(*
iter
)) {
76
if
(
estimator_
.
exists
(
v
))
77
estimator_
[
v
]
78
=
std
::
vector
<
GUM_SCALAR
>(
bn
->
variable
(*
iter
).
domainSize
(),
79
(
GUM_SCALAR
)0.0);
80
else
81
estimator_
.
insert
(
82
v
,
83
std
::
vector
<
GUM_SCALAR
>(
bn
->
variable
(*
iter
).
domainSize
(),
84
(
GUM_SCALAR
)0.0));
85
}
86
}
87
}
88
89
// we multiply the posteriors obtained by LoopyBeliefPropagation by the it's
90
// number of iterations
91
template
<
typename
GUM_SCALAR
>
92
void
93
Estimator
<
GUM_SCALAR
>::
setFromLBP
(
LoopyBeliefPropagation
<
GUM_SCALAR
>*
lbp
,
94
const
NodeSet
&
hardEvidence
,
95
GUM_SCALAR
virtualLBPSize
) {
96
for
(
const
auto
&
node
:
lbp
->
BN
().
nodes
()) {
97
if
(!
hardEvidence
.
contains
(
node
)) {
98
std
::
vector
<
GUM_SCALAR
>
v
;
99
auto
p
=
lbp
->
posterior
(
node
);
100
gum
::
Instantiation
inst
(
p
);
101
102
for
(
inst
.
setFirst
(); !
inst
.
end
(); ++
inst
) {
103
v
.
push_back
(
p
[
inst
] *
virtualLBPSize
);
104
}
105
106
estimator_
.
insert
(
lbp
->
BN
().
variable
(
node
).
name
(),
v
);
107
}
108
}
109
ntotal_
= (
Size
)
virtualLBPSize
;
110
wtotal_
=
virtualLBPSize
;
111
}
112
113
/*update the Estimator given an instantiation I with weight bias w*/
114
115
template
<
typename
GUM_SCALAR
>
116
void
Estimator
<
GUM_SCALAR
>::
update
(
Instantiation
I
,
GUM_SCALAR
w
) {
117
wtotal_
+=
w
;
118
ntotal_
+= (
Size
)1;
119
120
for
(
Idx
i
= 0;
i
<
I
.
nbrDim
();
i
++) {
121
if
(
estimator_
.
exists
(
I
.
variable
(
i
).
name
()))
122
estimator_
[
I
.
variable
(
i
).
name
()][
I
.
val
(
i
)] +=
w
;
123
}
124
}
125
126
/* returns the approximation CPT of a variable */
127
128
template
<
typename
GUM_SCALAR
>
129
const
Potential
<
GUM_SCALAR
>&
130
Estimator
<
GUM_SCALAR
>::
posterior
(
const
DiscreteVariable
&
var
) {
131
Potential
<
GUM_SCALAR
>*
p
=
nullptr
;
132
133
if
(!
estimator_
.
exists
(
var
.
name
()))
134
GUM_ERROR
(
NotFound
,
"Target variable not found"
);
135
136
// check if we have already computed the posterior
137
if
(
target_posteriors__
.
exists
(
var
.
name
())) {
138
p
=
target_posteriors__
[
var
.
name
()];
139
}
else
{
140
p
=
new
Potential
<
GUM_SCALAR
>();
141
*
p
<<
var
;
142
target_posteriors__
.
insert
(
var
.
name
(),
p
);
143
}
144
145
p
->
fillWith
(
estimator_
[
var
.
name
()]);
146
p
->
normalize
();
147
return
*
p
;
148
}
149
150
151
/* expected value considering a Bernouilli variable with parameter val */
152
153
template
<
typename
GUM_SCALAR
>
154
GUM_SCALAR
Estimator
<
GUM_SCALAR
>::
EV
(
std
::
string
name
,
Idx
val
) {
155
return
estimator_
[
name
][
val
] /
wtotal_
;
156
}
157
158
159
/* variance considering a Bernouilli variable with parameter val */
160
161
template
<
typename
GUM_SCALAR
>
162
GUM_SCALAR
Estimator
<
GUM_SCALAR
>::
variance
(
std
::
string
name
,
Idx
val
) {
163
GUM_SCALAR
p
=
EV
(
name
,
val
);
164
return
p
* (1 -
p
);
165
}
166
167
168
/* returns maximum length of confidence intervals for each variable, each
169
* parameter */
170
171
template
<
typename
GUM_SCALAR
>
172
GUM_SCALAR
Estimator
<
GUM_SCALAR
>::
confidence
() {
173
GUM_SCALAR
ic_max
= 0;
174
175
for
(
auto
iter
=
estimator_
.
begin
();
iter
!=
estimator_
.
end
(); ++
iter
) {
176
for
(
Idx
i
= 0;
i
<
iter
.
val
().
size
();
i
++) {
177
GUM_SCALAR
ic
=
GUM_SCALAR
(
178
2 * 1.96 *
std
::
sqrt
(
variance
(
iter
.
key
(),
i
) / (
ntotal_
- 1)));
179
if
(
ic
>
ic_max
)
ic_max
=
ic
;
180
}
181
}
182
183
return
ic_max
;
184
}
185
186
template
<
typename
GUM_SCALAR
>
187
void
Estimator
<
GUM_SCALAR
>::
clear
() {
188
estimator_
.
clear
();
189
wtotal_
= (
GUM_SCALAR
)0;
190
ntotal_
=
Size
(0);
191
for
(
const
auto
&
pot
:
target_posteriors__
)
192
delete
pot
.
second
;
193
target_posteriors__
.
clear
();
194
}
195
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669