aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
samplingInference_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of the non pure virtual methods of class
25
* SamplingInference.
26
*
27
* @author Paul ALAM & Pierre-Henri WUILLEMIN(@LIP6)
28
*/
29
30
#
include
<
agrum
/
BN
/
BayesNetFragment
.
h
>
31
#
include
<
agrum
/
BN
/
algorithms
/
barrenNodesFinder
.
h
>
32
#
include
<
agrum
/
BN
/
algorithms
/
dSeparation
.
h
>
33
#
include
<
agrum
/
BN
/
inference
/
tools
/
samplingInference
.
h
>
34
35
36
// Default settings of the approximation scheme, installed by the
// SamplingInference constructor through setEpsilon / setMinEpsilonRate /
// setMaxIter / setVerbosity / setPeriodSize / setMaxTime respectively.
// NOTE(review): these are unprefixed object-like macros defined in this
// _tpl.h header and never #undef'd here, so they leak into every translation
// unit that includes it; consider constexpr constants in a detail namespace.
#define DEFAULT_MAXITER 10000000
#define DEFAULT_PERIOD_SIZE 100
#define DEFAULT_VERBOSITY false
// passed to setMaxTime(); presumably a duration in seconds -- TODO confirm
#define DEFAULT_TIMEOUT 6000
#define DEFAULT_EPSILON 1e-2
#define DEFAULT_MIN_EPSILON_RATE 1e-5
42
43
44
namespace
gum
{
45
46
template
<
typename
GUM_SCALAR >
47
SamplingInference< GUM_SCALAR >::SamplingInference(
const
IBayesNet< GUM_SCALAR >* bn) :
48
ApproximateInference< GUM_SCALAR >(bn), _estimator_(), _samplingBN_(
nullptr
) {
49
this
->setEpsilon(
DEFAULT_EPSILON
);
50
this
->setMinEpsilonRate(
DEFAULT_MIN_EPSILON_RATE
);
51
this
->setMaxIter(
DEFAULT_MAXITER
);
52
this
->setVerbosity(
DEFAULT_VERBOSITY
);
53
this
->setPeriodSize(
DEFAULT_PERIOD_SIZE
);
54
this
->setMaxTime(
DEFAULT_TIMEOUT
);
55
GUM_CONSTRUCTOR(SamplingInference);
56
}
57
58
59
template
<
typename
GUM_SCALAR
>
60
SamplingInference
<
GUM_SCALAR
>::~
SamplingInference
() {
61
GUM_DESTRUCTOR
(
SamplingInference
);
62
if
(
_samplingBN_
!=
nullptr
) {
63
if
(
isContextualized
) {
// otherwise _samplingBN_==&BN()
64
delete
_samplingBN_
;
65
}
66
}
67
}
68
69
template
<
typename
GUM_SCALAR
>
70
INLINE
const
IBayesNet
<
GUM_SCALAR
>&
SamplingInference
<
GUM_SCALAR
>::
samplingBN
() {
71
this
->
prepareInference
();
72
if
(
_samplingBN_
==
nullptr
)
73
return
this
->
BN
();
74
else
75
return
*
_samplingBN_
;
76
}
77
template
<
typename
GUM_SCALAR
>
78
void
SamplingInference
<
GUM_SCALAR
>::
setEstimatorFromBN_
() {
79
_estimator_
.
setFromBN
(&
samplingBN
(),
this
->
hardEvidenceNodes
());
80
this
->
isSetEstimator
=
true
;
81
}
82
83
template
<
typename
GUM_SCALAR
>
84
void
SamplingInference
<
GUM_SCALAR
>::
setEstimatorFromLBP_
(
85
LoopyBeliefPropagation
<
GUM_SCALAR
>*
lbp
,
86
GUM_SCALAR
virtualLBPSize
) {
87
_estimator_
.
setFromLBP
(
lbp
,
this
->
hardEvidenceNodes
(),
virtualLBPSize
);
88
this
->
isSetEstimator
=
true
;
89
}
90
91
92
template
<
typename
GUM_SCALAR
>
93
const
Potential
<
GUM_SCALAR
>&
SamplingInference
<
GUM_SCALAR
>::
currentPosterior
(
NodeId
id
) {
94
return
_estimator_
.
posterior
(
this
->
BN
().
variable
(
id
));
95
}
96
97
template
<
typename
GUM_SCALAR
>
98
const
Potential
<
GUM_SCALAR
>&
99
SamplingInference
<
GUM_SCALAR
>::
currentPosterior
(
const
std
::
string
&
name
) {
100
return
currentPosterior
(
this
->
BN
().
idFromName
(
name
));
101
}
102
103
template
<
typename
GUM_SCALAR
>
104
const
Potential
<
GUM_SCALAR
>&
SamplingInference
<
GUM_SCALAR
>::
posterior_
(
NodeId
id
) {
105
return
_estimator_
.
posterior
(
this
->
BN
().
variable
(
id
));
106
}
107
108
template
<
typename
GUM_SCALAR
>
109
void
SamplingInference
<
GUM_SCALAR
>::
contextualize
() {
110
// Finding Barren nodes
111
112
BarrenNodesFinder
barr_nodes
=
BarrenNodesFinder
(&
this
->
BN
().
dag
());
113
barr_nodes
.
setTargets
(&
this
->
targets
());
114
barr_nodes
.
setEvidence
(&
this
->
hardEvidenceNodes
());
115
const
NodeSet
&
barren
=
barr_nodes
.
barrenNodes
();
116
117
// creating BN fragment
118
_samplingBN_
=
new
BayesNetFragment
<
GUM_SCALAR
>(
this
->
BN
());
119
for
(
const
auto
elmt
:
this
->
BN
().
dag
().
asNodeSet
() -
barren
)
120
_samplingBN_
->
installNode
(
elmt
);
121
122
// D-separated nodes
123
124
dSeparation
dsep
=
gum
::
dSeparation
();
125
NodeSet
requisite
;
126
dsep
.
requisiteNodes
(
this
->
BN
().
dag
(),
127
this
->
BN
().
nodes
().
asNodeSet
(),
// no target for approximateInference
128
this
->
hardEvidenceNodes
(),
129
this
->
softEvidenceNodes
(),
// should be empty
130
requisite
);
131
requisite
+=
this
->
hardEvidenceNodes
();
132
133
auto
nonRequisite
=
this
->
BN
().
dag
().
asNodeSet
() -
requisite
;
134
135
for
(
const
auto
elmt
:
nonRequisite
)
136
_samplingBN_
->
uninstallNode
(
elmt
);
137
for
(
const
auto
hard
:
this
->
hardEvidenceNodes
()) {
138
gum
::
Instantiation
I
;
139
I
.
add
(
this
->
BN
().
variable
(
hard
));
140
I
.
chgVal
(
this
->
BN
().
variable
(
hard
),
this
->
hardEvidence
()[
hard
]);
141
142
for
(
const
auto
&
child
:
this
->
BN
().
children
(
hard
)) {
143
_samplingBN_
->
installCPT
(
child
,
this
->
BN
().
cpt
(
child
).
extract
(
I
));
144
}
145
}
146
147
this
->
isContextualized
=
true
;
148
this
->
onContextualize_
(
_samplingBN_
);
149
}
150
151
152
template
<
typename
GUM_SCALAR
>
153
void
SamplingInference
<
GUM_SCALAR
>::
makeInference_
() {
154
if
(!
isSetEstimator
)
this
->
setEstimatorFromBN_
();
155
loopApproxInference_
();
156
}
157
158
template
<
typename
GUM_SCALAR
>
159
void
SamplingInference
<
GUM_SCALAR
>::
loopApproxInference_
() {
160
//@todo This should be in _prepareInference_
161
if
(!
isContextualized
) {
this
->
contextualize
(); }
162
163
this
->
initApproximationScheme
();
164
gum
::
Instantiation
Ip
;
165
GUM_SCALAR
w
= .0;
//
166
167
// Burn in
168
Ip
=
this
->
burnIn_
();
169
do
{
170
Ip
=
this
->
draw_
(&
w
,
Ip
);
171
_estimator_
.
update
(
Ip
,
w
);
172
this
->
updateApproximationScheme
();
173
}
while
(
this
->
continueApproximationScheme
(
_estimator_
.
confidence
()));
174
175
this
->
isSetEstimator
=
false
;
176
}
177
178
179
template
<
typename
GUM_SCALAR
>
180
void
SamplingInference
<
GUM_SCALAR
>::
addVarSample_
(
NodeId
nod
,
Instantiation
*
I
) {
181
gum
::
Instantiation
Itop
=
gum
::
Instantiation
(*
I
);
182
183
I
->
add
(
samplingBN
().
variable
(
nod
));
184
I
->
chgVal
(
samplingBN
().
variable
(
nod
),
samplingBN
().
cpt
(
nod
).
extract
(
Itop
).
draw
());
185
}
186
187
template
<
typename
GUM_SCALAR
>
188
void
SamplingInference
<
GUM_SCALAR
>::
onContextualize_
(
BayesNetFragment
<
GUM_SCALAR
>*
bn
) {}
189
190
191
template
<
typename
GUM_SCALAR
>
192
void
SamplingInference
<
GUM_SCALAR
>::
onEvidenceAdded_
(
const
NodeId
id
,
bool
isHardEvidence
) {
193
if
(!
isHardEvidence
) {
194
GUM_ERROR
(
FatalError
,
"Approximated inference only accept hard evidence"
)
195
}
196
}
197
198
template
<
typename
GUM_SCALAR
>
199
void
SamplingInference
<
GUM_SCALAR
>::
onEvidenceErased_
(
const
NodeId
id
,
bool
isHardEvidence
) {}
200
201
template
<
typename
GUM_SCALAR
>
202
void
SamplingInference
<
GUM_SCALAR
>::
onAllEvidenceErased_
(
bool
contains_hard_evidence
) {}
203
204
template
<
typename
GUM_SCALAR
>
205
void
SamplingInference
<
GUM_SCALAR
>::
onEvidenceChanged_
(
const
NodeId
id
,
206
bool
hasChangedSoftHard
) {
207
if
(
hasChangedSoftHard
) {
208
GUM_ERROR
(
FatalError
,
"Approximated inference only accept hard evidence"
)
209
}
210
}
211
212
template
<
typename
GUM_SCALAR
>
213
void
SamplingInference
<
GUM_SCALAR
>::
onModelChanged_
(
const
GraphicalModel
*
bn
) {}
214
215
template
<
typename
GUM_SCALAR
>
216
void
SamplingInference
<
GUM_SCALAR
>::
updateOutdatedStructure_
() {}
217
218
template
<
typename
GUM_SCALAR
>
219
void
SamplingInference
<
GUM_SCALAR
>::
updateOutdatedPotentials_
() {}
220
221
template
<
typename
GUM_SCALAR
>
222
void
SamplingInference
<
GUM_SCALAR
>::
onMarginalTargetAdded_
(
const
NodeId
id
) {}
223
224
template
<
typename
GUM_SCALAR
>
225
void
SamplingInference
<
GUM_SCALAR
>::
onMarginalTargetErased_
(
const
NodeId
id
) {}
226
227
template
<
typename
GUM_SCALAR
>
228
void
SamplingInference
<
GUM_SCALAR
>::
onAllMarginalTargetsAdded_
() {}
229
230
template
<
typename
GUM_SCALAR
>
231
void
SamplingInference
<
GUM_SCALAR
>::
onAllMarginalTargetsErased_
() {}
232
233
template
<
typename
GUM_SCALAR
>
234
void
SamplingInference
<
GUM_SCALAR
>::
onStateChanged_
() {
235
if
(
this
->
isInferenceReady
()) {
236
_estimator_
.
clear
();
237
this
->
initApproximationScheme
();
238
}
239
}
240
}
// namespace gum
DEFAULT_MAXITER
#define DEFAULT_MAXITER
Definition:
samplingInference_tpl.h:36
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
DEFAULT_EPSILON
#define DEFAULT_EPSILON
Definition:
samplingInference_tpl.h:40
DEFAULT_PERIOD_SIZE
#define DEFAULT_PERIOD_SIZE
Definition:
samplingInference_tpl.h:37
DEFAULT_TIMEOUT
#define DEFAULT_TIMEOUT
Definition:
samplingInference_tpl.h:39
DEFAULT_MIN_EPSILON_RATE
#define DEFAULT_MIN_EPSILON_RATE
Definition:
samplingInference_tpl.h:41
DEFAULT_VERBOSITY
#define DEFAULT_VERBOSITY
Definition:
samplingInference_tpl.h:38