aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
BayesNetFragment_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
20
21
22
/**
 * @file
 * @brief Template implementation of BN/BayesNetFragment.h classes.
 *
 * @author Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 */
28
#
include
<
agrum
/
BN
/
BayesNet
.
h
>
29
#
include
<
agrum
/
BN
/
BayesNetFragment
.
h
>
30
#
include
<
agrum
/
tools
/
multidim
/
potential
.
h
>
31
32
namespace
gum
{
33
template
<
typename
GUM_SCALAR >
34
BayesNetFragment< GUM_SCALAR >::BayesNetFragment(
const
IBayesNet< GUM_SCALAR >& bn) :
35
DiGraphListener(&bn.dag()), _bn_(bn) {
36
GUM_CONSTRUCTOR(BayesNetFragment);
37
}
38
39
template
<
typename
GUM_SCALAR
>
40
BayesNetFragment
<
GUM_SCALAR
>::~
BayesNetFragment
() {
41
GUM_DESTRUCTOR
(
BayesNetFragment
);
42
43
for
(
auto
node
:
nodes
())
44
if
(
_localCPTs_
.
exists
(
node
))
uninstallCPT_
(
node
);
45
}
46
47
//============================================================
48
// signals to keep consistency with the referred BayesNet
49
template
<
typename
GUM_SCALAR
>
50
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
whenNodeAdded
(
const
void
*
src
,
NodeId
id
) {
51
// nothing to do
52
}
53
template
<
typename
GUM_SCALAR
>
54
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
whenNodeDeleted
(
const
void
*
src
,
NodeId
id
) {
55
uninstallNode
(
id
);
56
}
57
template
<
typename
GUM_SCALAR
>
58
INLINE
void
59
BayesNetFragment
<
GUM_SCALAR
>::
whenArcAdded
(
const
void
*
src
,
NodeId
from
,
NodeId
to
) {
60
// nothing to do
61
}
62
template
<
typename
GUM_SCALAR
>
63
INLINE
void
64
BayesNetFragment
<
GUM_SCALAR
>::
whenArcDeleted
(
const
void
*
src
,
NodeId
from
,
NodeId
to
) {
65
if
(
dag
().
existsArc
(
from
,
to
))
uninstallArc_
(
from
,
to
);
66
}
67
68
//============================================================
69
// IBayesNet interface : BayesNetFragment here is a decorator for the bn
70
71
template
<
typename
GUM_SCALAR
>
72
INLINE
const
Potential
<
GUM_SCALAR
>&
BayesNetFragment
<
GUM_SCALAR
>::
cpt
(
NodeId
id
)
const
{
73
if
(!
isInstalledNode
(
id
))
GUM_ERROR
(
NotFound
,
"NodeId "
<<
id
<<
" is not installed"
)
74
75
if
(
_localCPTs_
.
exists
(
id
))
76
return
*
_localCPTs_
[
id
];
77
else
78
return
_bn_
.
cpt
(
id
);
79
}
80
81
template
<
typename
GUM_SCALAR
>
82
INLINE
const
VariableNodeMap
&
BayesNetFragment
<
GUM_SCALAR
>::
variableNodeMap
()
const
{
83
return
this
->
_bn_
.
variableNodeMap
();
84
}
85
86
template
<
typename
GUM_SCALAR
>
87
INLINE
const
DiscreteVariable
&
BayesNetFragment
<
GUM_SCALAR
>::
variable
(
NodeId
id
)
const
{
88
if
(!
isInstalledNode
(
id
))
GUM_ERROR
(
NotFound
,
"NodeId "
<<
id
<<
" is not installed"
)
89
90
return
_bn_
.
variable
(
id
);
91
}
92
93
template
<
typename
GUM_SCALAR
>
94
INLINE
NodeId
BayesNetFragment
<
GUM_SCALAR
>::
nodeId
(
const
DiscreteVariable
&
var
)
const
{
95
NodeId
id
=
_bn_
.
nodeId
(
var
);
96
97
if
(!
isInstalledNode
(
id
))
GUM_ERROR
(
NotFound
,
"variable "
<<
var
.
name
() <<
" is not installed"
)
98
99
return
id
;
100
}
101
102
template
<
typename
GUM_SCALAR
>
103
INLINE
NodeId
BayesNetFragment
<
GUM_SCALAR
>::
idFromName
(
const
std
::
string
&
name
)
const
{
104
NodeId
id
=
_bn_
.
idFromName
(
name
);
105
106
if
(!
isInstalledNode
(
id
))
GUM_ERROR
(
NotFound
,
"variable "
<<
name
<<
" is not installed"
)
107
108
return
id
;
109
}
110
111
template
<
typename
GUM_SCALAR
>
112
INLINE
const
DiscreteVariable
&
113
BayesNetFragment
<
GUM_SCALAR
>::
variableFromName
(
const
std
::
string
&
name
)
const
{
114
NodeId
id
=
idFromName
(
name
);
115
116
if
(!
isInstalledNode
(
id
))
GUM_ERROR
(
NotFound
,
"variable "
<<
name
<<
" is not installed"
)
117
118
return
_bn_
.
variable
(
id
);
119
}
120
121
//============================================================
122
// specific API for BayesNetFragment
123
template
<
typename
GUM_SCALAR
>
124
INLINE
bool
BayesNetFragment
<
GUM_SCALAR
>::
isInstalledNode
(
NodeId
id
)
const
{
125
return
dag
().
existsNode
(
id
);
126
}
127
128
template
<
typename
GUM_SCALAR
>
129
void
BayesNetFragment
<
GUM_SCALAR
>::
installNode
(
NodeId
id
) {
130
if
(!
_bn_
.
dag
().
existsNode
(
id
))
131
GUM_ERROR
(
NotFound
,
"Node "
<<
id
<<
" does not exist in referred BayesNet"
)
132
133
if
(!
isInstalledNode
(
id
)) {
134
this
->
dag_
.
addNodeWithId
(
id
);
135
136
// adding arcs with id as a tail
137
for
(
auto
pa
:
this
->
_bn_
.
parents
(
id
)) {
138
if
(
isInstalledNode
(
pa
))
this
->
dag_
.
addArc
(
pa
,
id
);
139
}
140
141
// adding arcs with id as a head
142
for
(
auto
son
:
this
->
_bn_
.
children
(
id
))
143
if
(
isInstalledNode
(
son
))
this
->
dag_
.
addArc
(
id
,
son
);
144
}
145
}
146
147
template
<
typename
GUM_SCALAR
>
148
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
installAscendants
(
NodeId
id
) {
149
installNode
(
id
);
150
151
// bn is a dag => this will have an end ...
152
for
(
auto
pa
:
this
->
_bn_
.
parents
(
id
))
153
installAscendants
(
pa
);
154
}
155
156
template
<
typename
GUM_SCALAR
>
157
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
uninstallNode
(
NodeId
id
) {
158
if
(
isInstalledNode
(
id
)) {
159
uninstallCPT
(
id
);
160
this
->
dag_
.
eraseNode
(
id
);
161
}
162
}
163
164
template
<
typename
GUM_SCALAR
>
165
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
uninstallArc_
(
NodeId
from
,
NodeId
to
) {
166
this
->
dag_
.
eraseArc
(
Arc
(
from
,
to
));
167
}
168
169
template
<
typename
GUM_SCALAR
>
170
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
installArc_
(
NodeId
from
,
NodeId
to
) {
171
this
->
dag_
.
addArc
(
from
,
to
);
172
}
173
174
template
<
typename
GUM_SCALAR
>
175
void
BayesNetFragment
<
GUM_SCALAR
>::
installCPT_
(
NodeId
id
,
const
Potential
<
GUM_SCALAR
>&
pot
) {
176
// topology
177
const
auto
&
parents
=
this
->
parents
(
id
);
178
for
(
auto
node_it
=
parents
.
beginSafe
();
node_it
!=
parents
.
endSafe
();
179
++
node_it
)
// safe iterator needed here
180
uninstallArc_
(*
node_it
,
id
);
181
182
for
(
Idx
i
= 1;
i
<
pot
.
nbrDim
();
i
++) {
183
NodeId
parent
=
_bn_
.
idFromName
(
pot
.
variable
(
i
).
name
());
184
185
if
(
isInstalledNode
(
parent
))
installArc_
(
parent
,
id
);
186
}
187
188
// local cpt
189
if
(
_localCPTs_
.
exists
(
id
))
uninstallCPT_
(
id
);
190
191
_localCPTs_
.
insert
(
id
,
new
gum
::
Potential
<
GUM_SCALAR
>(
pot
));
192
}
193
194
template
<
typename
GUM_SCALAR
>
195
void
BayesNetFragment
<
GUM_SCALAR
>::
installCPT
(
NodeId
id
,
const
Potential
<
GUM_SCALAR
>&
pot
) {
196
if
(!
dag
().
existsNode
(
id
))
197
GUM_ERROR
(
NotFound
,
"Node "
<<
id
<<
" is not installed in the fragment"
)
198
199
if
(&(
pot
.
variable
(0)) != &(
variable
(
id
))) {
200
GUM_ERROR
(
OperationNotAllowed
,
201
"The potential is not a marginal for _bn_.variable <"
<<
variable
(
id
).
name
()
202
<<
">"
);
203
}
204
205
const
NodeSet
&
parents
=
_bn_
.
parents
(
id
);
206
207
for
(
Idx
i
= 1;
i
<
pot
.
nbrDim
();
i
++) {
208
if
(!
parents
.
contains
(
_bn_
.
idFromName
(
pot
.
variable
(
i
).
name
())))
209
GUM_ERROR
(
OperationNotAllowed
,
210
"Variable <"
<<
pot
.
variable
(
i
).
name
() <<
"> is not in the parents of node "
211
<<
id
);
212
}
213
214
installCPT_
(
id
,
pot
);
215
}
216
217
template
<
typename
GUM_SCALAR
>
218
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
uninstallCPT_
(
NodeId
id
) {
219
delete
_localCPTs_
[
id
];
220
_localCPTs_
.
erase
(
id
);
221
}
222
223
template
<
typename
GUM_SCALAR
>
224
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
uninstallCPT
(
NodeId
id
) {
225
if
(
_localCPTs_
.
exists
(
id
)) {
226
uninstallCPT_
(
id
);
227
228
// re-create arcs from referred potential
229
const
Potential
<
GUM_SCALAR
>&
pot
=
cpt
(
id
);
230
231
for
(
Idx
i
= 1;
i
<
pot
.
nbrDim
();
i
++) {
232
NodeId
parent
=
_bn_
.
idFromName
(
pot
.
variable
(
i
).
name
());
233
234
if
(
isInstalledNode
(
parent
))
installArc_
(
parent
,
id
);
235
}
236
}
237
}
238
239
template
<
typename
GUM_SCALAR
>
240
void
BayesNetFragment
<
GUM_SCALAR
>::
installMarginal
(
NodeId
id
,
241
const
Potential
<
GUM_SCALAR
>&
pot
) {
242
if
(!
isInstalledNode
(
id
)) {
243
GUM_ERROR
(
NotFound
,
"The node "
<<
id
<<
" is not part of this fragment"
)
244
}
245
246
if
(
pot
.
nbrDim
() > 1) {
247
GUM_ERROR
(
OperationNotAllowed
,
"The potential is not a marginal :"
<<
pot
)
248
}
249
250
if
(&(
pot
.
variable
(0)) != &(
_bn_
.
variable
(
id
))) {
251
GUM_ERROR
(
OperationNotAllowed
,
252
"The potential is not a marginal for _bn_.variable <"
<<
_bn_
.
variable
(
id
).
name
()
253
<<
">"
);
254
}
255
256
installCPT_
(
id
,
pot
);
257
}
258
259
template
<
typename
GUM_SCALAR
>
260
bool
BayesNetFragment
<
GUM_SCALAR
>::
checkConsistency
(
NodeId
id
)
const
{
261
if
(!
isInstalledNode
(
id
))
262
GUM_ERROR
(
NotFound
,
"The node "
<<
id
<<
" is not part of this fragment"
)
263
264
const
auto
&
cpt
=
this
->
cpt
(
id
);
265
NodeSet
cpt_parents
;
266
267
for
(
Idx
i
= 1;
i
<
cpt
.
nbrDim
();
i
++) {
268
cpt_parents
.
insert
(
_bn_
.
idFromName
(
cpt
.
variable
(
i
).
name
()));
269
}
270
271
return
(
this
->
parents
(
id
) ==
cpt_parents
);
272
}
273
274
template
<
typename
GUM_SCALAR
>
275
INLINE
bool
BayesNetFragment
<
GUM_SCALAR
>::
checkConsistency
()
const
{
276
for
(
auto
node
:
nodes
())
277
if
(!
checkConsistency
(
node
))
return
false
;
278
279
return
true
;
280
}
281
282
template
<
typename
GUM_SCALAR
>
283
std
::
string
BayesNetFragment
<
GUM_SCALAR
>::
toDot
()
const
{
284
std
::
stringstream
output
;
285
output
<<
"digraph \""
;
286
287
std
::
string
bn_name
;
288
289
static
std
::
string
inFragmentStyle
=
"fillcolor=\"#ffffaa\","
290
"color=\"#000000\","
291
"fontcolor=\"#000000\""
;
292
static
std
::
string
styleWithLocalCPT
=
"fillcolor=\"#ffddaa\","
293
"color=\"#000000\","
294
"fontcolor=\"#000000\""
;
295
static
std
::
string
notConsistantStyle
=
"fillcolor=\"#ff0000\","
296
"color=\"#000000\","
297
"fontcolor=\"#ffff00\""
;
298
static
std
::
string
outFragmentStyle
=
"fillcolor=\"#f0f0f0\","
299
"color=\"#f0f0f0\","
300
"fontcolor=\"#000000\""
;
301
302
try
{
303
bn_name
=
_bn_
.
property
(
"name"
);
304
}
catch
(
NotFound
&) {
bn_name
=
"no_name"
; }
305
306
bn_name
=
"Fragment of "
+
bn_name
;
307
308
output
<<
bn_name
<<
"\" {"
<<
std
::
endl
;
309
output
<<
" graph [bgcolor=transparent,label=\""
<<
bn_name
<<
"\"];"
<<
std
::
endl
;
310
output
<<
" node [style=filled];"
<<
std
::
endl
<<
std
::
endl
;
311
312
for
(
auto
node
:
_bn_
.
nodes
()) {
313
output
<<
"\""
<<
_bn_
.
variable
(
node
).
name
() <<
"\" [comment=\""
<<
node
<<
":"
314
<<
_bn_
.
variable
(
node
) <<
", \""
;
315
316
if
(
isInstalledNode
(
node
)) {
317
if
(!
checkConsistency
(
node
)) {
318
output
<<
notConsistantStyle
;
319
}
else
if
(
_localCPTs_
.
exists
(
node
))
320
output
<<
styleWithLocalCPT
;
321
else
322
output
<<
inFragmentStyle
;
323
}
else
324
output
<<
outFragmentStyle
;
325
326
output
<<
"];"
<<
std
::
endl
;
327
}
328
329
output
<<
std
::
endl
;
330
331
std
::
string
tab
=
" "
;
332
333
for
(
auto
node
:
_bn_
.
nodes
()) {
334
if
(
_bn_
.
children
(
node
).
size
() > 0) {
335
for
(
auto
child
:
_bn_
.
children
(
node
)) {
336
output
<<
tab
<<
"\""
<<
_bn_
.
variable
(
node
).
name
() <<
"\" -> "
337
<<
"\""
<<
_bn_
.
variable
(
child
).
name
() <<
"\" ["
;
338
339
if
(
dag
().
existsArc
(
Arc
(
node
,
child
)))
340
output
<<
inFragmentStyle
;
341
else
342
output
<<
outFragmentStyle
;
343
344
output
<<
"];"
<<
std
::
endl
;
345
}
346
}
347
}
348
349
output
<<
"}"
<<
std
::
endl
;
350
351
return
output
.
str
();
352
}
353
354
template
<
typename
GUM_SCALAR
>
355
gum
::
BayesNet
<
GUM_SCALAR
>
BayesNetFragment
<
GUM_SCALAR
>::
toBN
()
const
{
356
if
(!
checkConsistency
()) {
357
GUM_ERROR
(
OperationNotAllowed
,
"The fragment contains un-consistent node(s)"
)
358
}
359
gum
::
BayesNet
<
GUM_SCALAR
>
res
;
360
for
(
const
auto
nod
:
nodes
()) {
361
res
.
add
(
variable
(
nod
),
nod
);
362
}
363
for
(
const
auto
&
arc
:
dag
().
arcs
()) {
364
res
.
addArc
(
arc
.
tail
(),
arc
.
head
());
365
}
366
for
(
const
auto
nod
:
nodes
()) {
367
res
.
cpt
(
nod
).
fillWith
(
cpt
(
nod
));
368
}
369
370
return
res
;
371
}
372
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643