aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
estimator_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of Estimator for approximate inference in Bayesian
 * networks
 *
 * @author Paul ALAM & Pierre-Henri WUILLEMIN(@LIP6)
 */

namespace
gum
{
31
32
template
<
typename
GUM_SCALAR >
33
Estimator< GUM_SCALAR >::Estimator() {
34
GUM_CONSTRUCTOR(Estimator);
35
wtotal_ = (GUM_SCALAR)0.;
36
ntotal_ = (Size)0;
37
bn_ =
nullptr
;
38
}
39
40
41
template
<
typename
GUM_SCALAR
>
42
Estimator
<
GUM_SCALAR
>::
Estimator
(
const
IBayesNet
<
GUM_SCALAR
>*
bn
) :
Estimator
() {
43
bn_
=
bn
;
44
45
for
(
gum
::
NodeGraphPartIterator
iter
=
bn
->
nodes
().
begin
();
iter
!=
bn
->
nodes
().
end
(); ++
iter
)
46
estimator_
.
insert
(
bn
->
variable
(*
iter
).
name
(),
47
std
::
vector
<
GUM_SCALAR
>(
bn
->
variable
(*
iter
).
domainSize
(), 0.0));
48
49
GUM_CONSTRUCTOR
(
Estimator
);
50
}
51
52
53
template
<
typename
GUM_SCALAR
>
54
INLINE
Estimator
<
GUM_SCALAR
>::~
Estimator
() {
55
GUM_DESTRUCTOR
(
Estimator
);
56
// remove all the posteriors computed
57
clear
();
58
}
59
60
61
/* adds all potential target variables from a given BN to the Estimator */
62
63
template
<
typename
GUM_SCALAR
>
64
void
Estimator
<
GUM_SCALAR
>::
setFromBN
(
const
IBayesNet
<
GUM_SCALAR
>*
bn
,
65
const
NodeSet
&
hardEvidence
) {
66
for
(
gum
::
NodeGraphPartIterator
iter
=
bn
->
nodes
().
begin
();
iter
!=
bn
->
nodes
().
end
(); ++
iter
) {
67
auto
v
=
bn
->
variable
(*
iter
).
name
();
68
69
if
(!
hardEvidence
.
contains
(*
iter
)) {
70
if
(
estimator_
.
exists
(
v
))
71
estimator_
[
v
]
72
=
std
::
vector
<
GUM_SCALAR
>(
bn
->
variable
(*
iter
).
domainSize
(), (
GUM_SCALAR
)0.0);
73
else
74
estimator_
.
insert
(
75
v
,
76
std
::
vector
<
GUM_SCALAR
>(
bn
->
variable
(*
iter
).
domainSize
(), (
GUM_SCALAR
)0.0));
77
}
78
}
79
}
80
81
// we multiply the posteriors obtained by LoopyBeliefPropagation by the it's
82
// number of iterations
83
template
<
typename
GUM_SCALAR
>
84
void
Estimator
<
GUM_SCALAR
>::
setFromLBP
(
LoopyBeliefPropagation
<
GUM_SCALAR
>*
lbp
,
85
const
NodeSet
&
hardEvidence
,
86
GUM_SCALAR
virtualLBPSize
) {
87
for
(
const
auto
&
node
:
lbp
->
BN
().
nodes
()) {
88
if
(!
hardEvidence
.
contains
(
node
)) {
89
std
::
vector
<
GUM_SCALAR
>
v
;
90
auto
p
=
lbp
->
posterior
(
node
);
91
gum
::
Instantiation
inst
(
p
);
92
93
for
(
inst
.
setFirst
(); !
inst
.
end
(); ++
inst
) {
94
v
.
push_back
(
p
[
inst
] *
virtualLBPSize
);
95
}
96
97
estimator_
.
insert
(
lbp
->
BN
().
variable
(
node
).
name
(),
v
);
98
}
99
}
100
ntotal_
= (
Size
)
virtualLBPSize
;
101
wtotal_
=
virtualLBPSize
;
102
}
103
104
/*update the Estimator given an instantiation I with weight bias w*/
105
106
template
<
typename
GUM_SCALAR
>
107
void
Estimator
<
GUM_SCALAR
>::
update
(
Instantiation
I
,
GUM_SCALAR
w
) {
108
wtotal_
+=
w
;
109
ntotal_
+= (
Size
)1;
110
111
for
(
Idx
i
= 0;
i
<
I
.
nbrDim
();
i
++) {
112
if
(
estimator_
.
exists
(
I
.
variable
(
i
).
name
()))
estimator_
[
I
.
variable
(
i
).
name
()][
I
.
val
(
i
)] +=
w
;
113
}
114
}
115
116
/* returns the approximation CPT of a variable */
117
118
template
<
typename
GUM_SCALAR
>
119
const
Potential
<
GUM_SCALAR
>&
Estimator
<
GUM_SCALAR
>::
posterior
(
const
DiscreteVariable
&
var
) {
120
Potential
<
GUM_SCALAR
>*
p
=
nullptr
;
121
122
if
(!
estimator_
.
exists
(
var
.
name
()))
GUM_ERROR
(
NotFound
,
"Target variable not found"
)
123
124
// check if we have already computed the posterior
125
if
(
_target_posteriors_
.
exists
(
var
.
name
())) {
126
p
=
_target_posteriors_
[
var
.
name
()];
127
}
else
{
128
p
=
new
Potential
<
GUM_SCALAR
>();
129
*
p
<<
var
;
130
_target_posteriors_
.
insert
(
var
.
name
(),
p
);
131
}
132
133
p
->
fillWith
(
estimator_
[
var
.
name
()]);
134
p
->
normalize
();
135
return
*
p
;
136
}
137
138
139
/* expected value considering a Bernouilli variable with parameter val */
140
141
template
<
typename
GUM_SCALAR
>
142
GUM_SCALAR
Estimator
<
GUM_SCALAR
>::
EV
(
std
::
string
name
,
Idx
val
) {
143
return
estimator_
[
name
][
val
] /
wtotal_
;
144
}
145
146
147
/* variance considering a Bernouilli variable with parameter val */
148
149
template
<
typename
GUM_SCALAR
>
150
GUM_SCALAR
Estimator
<
GUM_SCALAR
>::
variance
(
std
::
string
name
,
Idx
val
) {
151
GUM_SCALAR
p
=
EV
(
name
,
val
);
152
return
p
* (1 -
p
);
153
}
154
155
156
/* returns maximum length of confidence intervals for each variable, each
157
* parameter */
158
159
template
<
typename
GUM_SCALAR
>
160
GUM_SCALAR
Estimator
<
GUM_SCALAR
>::
confidence
() {
161
GUM_SCALAR
ic_max
= 0;
162
163
for
(
auto
iter
=
estimator_
.
begin
();
iter
!=
estimator_
.
end
(); ++
iter
) {
164
for
(
Idx
i
= 0;
i
<
iter
.
val
().
size
();
i
++) {
165
GUM_SCALAR
ic
=
GUM_SCALAR
(2 * 1.96 *
std
::
sqrt
(
variance
(
iter
.
key
(),
i
) / (
ntotal_
- 1)));
166
if
(
ic
>
ic_max
)
ic_max
=
ic
;
167
}
168
}
169
170
return
ic_max
;
171
}
172
173
template
<
typename
GUM_SCALAR
>
174
void
Estimator
<
GUM_SCALAR
>::
clear
() {
175
estimator_
.
clear
();
176
wtotal_
= (
GUM_SCALAR
)0;
177
ntotal_
=
Size
(0);
178
for
(
const
auto
&
pot
:
_target_posteriors_
)
179
delete
pot
.
second
;
180
_target_posteriors_
.
clear
();
181
}
182
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643