aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
BayesNetFragment_tpl.h
Go to the documentation of this file.
1
/**
 *
 *   Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Template implementation of BN/BayesNetFragment.h classes.
 *
 * @author Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 */
28
#
include
<
agrum
/
BN
/
BayesNet
.
h
>
29
#
include
<
agrum
/
BN
/
BayesNetFragment
.
h
>
30
#
include
<
agrum
/
tools
/
multidim
/
potential
.
h
>
31
32
namespace
gum
{
33
template
<
typename
GUM_SCALAR >
34
BayesNetFragment< GUM_SCALAR >::BayesNetFragment(
35
const
IBayesNet< GUM_SCALAR >& bn) :
36
DiGraphListener(&bn.dag()),
37
bn__(bn) {
38
GUM_CONSTRUCTOR(BayesNetFragment);
39
}
40
41
template
<
typename
GUM_SCALAR
>
42
BayesNetFragment
<
GUM_SCALAR
>::~
BayesNetFragment
() {
43
GUM_DESTRUCTOR
(
BayesNetFragment
);
44
45
for
(
auto
node
:
nodes
())
46
if
(
localCPTs__
.
exists
(
node
))
uninstallCPT_
(
node
);
47
}
48
49
//============================================================
50
// signals to keep consistency with the referred BayesNet
51
template
<
typename
GUM_SCALAR
>
52
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
whenNodeAdded
(
const
void
*
src
,
53
NodeId
id
) {
54
// nothing to do
55
}
56
template
<
typename
GUM_SCALAR
>
57
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
whenNodeDeleted
(
const
void
*
src
,
58
NodeId
id
) {
59
uninstallNode
(
id
);
60
}
61
template
<
typename
GUM_SCALAR
>
62
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
whenArcAdded
(
const
void
*
src
,
63
NodeId
from
,
64
NodeId
to
) {
65
// nothing to do
66
}
67
template
<
typename
GUM_SCALAR
>
68
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
whenArcDeleted
(
const
void
*
src
,
69
NodeId
from
,
70
NodeId
to
) {
71
if
(
dag
().
existsArc
(
from
,
to
))
uninstallArc_
(
from
,
to
);
72
}
73
74
//============================================================
75
// IBayesNet interface : BayesNetFragment here is a decorator for the bn
76
77
template
<
typename
GUM_SCALAR
>
78
INLINE
const
Potential
<
GUM_SCALAR
>&
79
BayesNetFragment
<
GUM_SCALAR
>::
cpt
(
NodeId
id
)
const
{
80
if
(!
isInstalledNode
(
id
))
81
GUM_ERROR
(
NotFound
,
"NodeId "
<<
id
<<
" is not installed"
);
82
83
if
(
localCPTs__
.
exists
(
id
))
84
return
*
localCPTs__
[
id
];
85
else
86
return
bn__
.
cpt
(
id
);
87
}
88
89
template
<
typename
GUM_SCALAR
>
90
INLINE
const
VariableNodeMap
&
91
BayesNetFragment
<
GUM_SCALAR
>::
variableNodeMap
()
const
{
92
return
this
->
bn__
.
variableNodeMap
();
93
}
94
95
template
<
typename
GUM_SCALAR
>
96
INLINE
const
DiscreteVariable
&
97
BayesNetFragment
<
GUM_SCALAR
>::
variable
(
NodeId
id
)
const
{
98
if
(!
isInstalledNode
(
id
))
99
GUM_ERROR
(
NotFound
,
"NodeId "
<<
id
<<
" is not installed"
);
100
101
return
bn__
.
variable
(
id
);
102
}
103
104
template
<
typename
GUM_SCALAR
>
105
INLINE
NodeId
106
BayesNetFragment
<
GUM_SCALAR
>::
nodeId
(
const
DiscreteVariable
&
var
)
const
{
107
NodeId
id
=
bn__
.
nodeId
(
var
);
108
109
if
(!
isInstalledNode
(
id
))
110
GUM_ERROR
(
NotFound
,
"variable "
<<
var
.
name
() <<
" is not installed"
);
111
112
return
id
;
113
}
114
115
template
<
typename
GUM_SCALAR
>
116
INLINE
NodeId
117
BayesNetFragment
<
GUM_SCALAR
>::
idFromName
(
const
std
::
string
&
name
)
const
{
118
NodeId
id
=
bn__
.
idFromName
(
name
);
119
120
if
(!
isInstalledNode
(
id
))
121
GUM_ERROR
(
NotFound
,
"variable "
<<
name
<<
" is not installed"
);
122
123
return
id
;
124
}
125
126
template
<
typename
GUM_SCALAR
>
127
INLINE
const
DiscreteVariable
&
BayesNetFragment
<
GUM_SCALAR
>::
variableFromName
(
128
const
std
::
string
&
name
)
const
{
129
NodeId
id
=
idFromName
(
name
);
130
131
if
(!
isInstalledNode
(
id
))
132
GUM_ERROR
(
NotFound
,
"variable "
<<
name
<<
" is not installed"
);
133
134
return
bn__
.
variable
(
id
);
135
}
136
137
//============================================================
138
// specific API for BayesNetFragment
139
template
<
typename
GUM_SCALAR
>
140
INLINE
bool
BayesNetFragment
<
GUM_SCALAR
>::
isInstalledNode
(
NodeId
id
)
const
{
141
return
dag
().
existsNode
(
id
);
142
}
143
144
template
<
typename
GUM_SCALAR
>
145
void
BayesNetFragment
<
GUM_SCALAR
>::
installNode
(
NodeId
id
) {
146
if
(!
bn__
.
dag
().
existsNode
(
id
))
147
GUM_ERROR
(
NotFound
,
"Node "
<<
id
<<
" does not exist in referred BayesNet"
);
148
149
if
(!
isInstalledNode
(
id
)) {
150
this
->
dag_
.
addNodeWithId
(
id
);
151
152
// adding arcs with id as a tail
153
for
(
auto
pa
:
this
->
bn__
.
parents
(
id
)) {
154
if
(
isInstalledNode
(
pa
))
this
->
dag_
.
addArc
(
pa
,
id
);
155
}
156
157
// adding arcs with id as a head
158
for
(
auto
son
:
this
->
bn__
.
children
(
id
))
159
if
(
isInstalledNode
(
son
))
this
->
dag_
.
addArc
(
id
,
son
);
160
}
161
}
162
163
template
<
typename
GUM_SCALAR
>
164
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
installAscendants
(
NodeId
id
) {
165
installNode
(
id
);
166
167
// bn is a dag => this will have an end ...
168
for
(
auto
pa
:
this
->
bn__
.
parents
(
id
))
169
installAscendants
(
pa
);
170
}
171
172
template
<
typename
GUM_SCALAR
>
173
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
uninstallNode
(
NodeId
id
) {
174
if
(
isInstalledNode
(
id
)) {
175
uninstallCPT
(
id
);
176
this
->
dag_
.
eraseNode
(
id
);
177
}
178
}
179
180
template
<
typename
GUM_SCALAR
>
181
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
uninstallArc_
(
NodeId
from
,
182
NodeId
to
) {
183
this
->
dag_
.
eraseArc
(
Arc
(
from
,
to
));
184
}
185
186
template
<
typename
GUM_SCALAR
>
187
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
installArc_
(
NodeId
from
,
NodeId
to
) {
188
this
->
dag_
.
addArc
(
from
,
to
);
189
}
190
191
template
<
typename
GUM_SCALAR
>
192
void
BayesNetFragment
<
GUM_SCALAR
>::
installCPT_
(
193
NodeId
id
,
194
const
Potential
<
GUM_SCALAR
>&
pot
) {
195
// topology
196
const
auto
&
parents
=
this
->
parents
(
id
);
197
for
(
auto
node_it
=
parents
.
beginSafe
();
node_it
!=
parents
.
endSafe
();
198
++
node_it
)
// safe iterator needed here
199
uninstallArc_
(*
node_it
,
id
);
200
201
for
(
Idx
i
= 1;
i
<
pot
.
nbrDim
();
i
++) {
202
NodeId
parent
=
bn__
.
idFromName
(
pot
.
variable
(
i
).
name
());
203
204
if
(
isInstalledNode
(
parent
))
installArc_
(
parent
,
id
);
205
}
206
207
// local cpt
208
if
(
localCPTs__
.
exists
(
id
))
uninstallCPT_
(
id
);
209
210
localCPTs__
.
insert
(
id
,
new
gum
::
Potential
<
GUM_SCALAR
>(
pot
));
211
}
212
213
template
<
typename
GUM_SCALAR
>
214
void
BayesNetFragment
<
GUM_SCALAR
>::
installCPT
(
215
NodeId
id
,
216
const
Potential
<
GUM_SCALAR
>&
pot
) {
217
if
(!
dag
().
existsNode
(
id
))
218
GUM_ERROR
(
NotFound
,
"Node "
<<
id
<<
" is not installed in the fragment"
);
219
220
if
(&(
pot
.
variable
(0)) != &(
variable
(
id
))) {
221
GUM_ERROR
(
OperationNotAllowed
,
222
"The potential is not a marginal for bn__.variable <"
223
<<
variable
(
id
).
name
() <<
">"
);
224
}
225
226
const
NodeSet
&
parents
=
bn__
.
parents
(
id
);
227
228
for
(
Idx
i
= 1;
i
<
pot
.
nbrDim
();
i
++) {
229
if
(!
parents
.
contains
(
bn__
.
idFromName
(
pot
.
variable
(
i
).
name
())))
230
GUM_ERROR
(
OperationNotAllowed
,
231
"Variable <"
<<
pot
.
variable
(
i
).
name
()
232
<<
"> is not in the parents of node "
<<
id
);
233
}
234
235
installCPT_
(
id
,
pot
);
236
}
237
238
template
<
typename
GUM_SCALAR
>
239
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
uninstallCPT_
(
NodeId
id
) {
240
delete
localCPTs__
[
id
];
241
localCPTs__
.
erase
(
id
);
242
}
243
244
template
<
typename
GUM_SCALAR
>
245
INLINE
void
BayesNetFragment
<
GUM_SCALAR
>::
uninstallCPT
(
NodeId
id
) {
246
if
(
localCPTs__
.
exists
(
id
)) {
247
uninstallCPT_
(
id
);
248
249
// re-create arcs from referred potential
250
const
Potential
<
GUM_SCALAR
>&
pot
=
cpt
(
id
);
251
252
for
(
Idx
i
= 1;
i
<
pot
.
nbrDim
();
i
++) {
253
NodeId
parent
=
bn__
.
idFromName
(
pot
.
variable
(
i
).
name
());
254
255
if
(
isInstalledNode
(
parent
))
installArc_
(
parent
,
id
);
256
}
257
}
258
}
259
260
template
<
typename
GUM_SCALAR
>
261
void
BayesNetFragment
<
GUM_SCALAR
>::
installMarginal
(
262
NodeId
id
,
263
const
Potential
<
GUM_SCALAR
>&
pot
) {
264
if
(!
isInstalledNode
(
id
)) {
265
GUM_ERROR
(
NotFound
,
"The node "
<<
id
<<
" is not part of this fragment"
);
266
}
267
268
if
(
pot
.
nbrDim
() > 1) {
269
GUM_ERROR
(
OperationNotAllowed
,
"The potential is not a marginal :"
<<
pot
);
270
}
271
272
if
(&(
pot
.
variable
(0)) != &(
bn__
.
variable
(
id
))) {
273
GUM_ERROR
(
OperationNotAllowed
,
274
"The potential is not a marginal for bn__.variable <"
275
<<
bn__
.
variable
(
id
).
name
() <<
">"
);
276
}
277
278
installCPT_
(
id
,
pot
);
279
}
280
281
template
<
typename
GUM_SCALAR
>
282
bool
BayesNetFragment
<
GUM_SCALAR
>::
checkConsistency
(
NodeId
id
)
const
{
283
if
(!
isInstalledNode
(
id
))
284
GUM_ERROR
(
NotFound
,
"The node "
<<
id
<<
" is not part of this fragment"
);
285
286
const
auto
&
cpt
=
this
->
cpt
(
id
);
287
NodeSet
cpt_parents
;
288
289
for
(
Idx
i
= 1;
i
<
cpt
.
nbrDim
();
i
++) {
290
cpt_parents
.
insert
(
bn__
.
idFromName
(
cpt
.
variable
(
i
).
name
()));
291
}
292
293
return
(
this
->
parents
(
id
) ==
cpt_parents
);
294
}
295
296
template
<
typename
GUM_SCALAR
>
297
INLINE
bool
BayesNetFragment
<
GUM_SCALAR
>::
checkConsistency
()
const
{
298
for
(
auto
node
:
nodes
())
299
if
(!
checkConsistency
(
node
))
return
false
;
300
301
return
true
;
302
}
303
304
template
<
typename
GUM_SCALAR
>
305
std
::
string
BayesNetFragment
<
GUM_SCALAR
>::
toDot
()
const
{
306
std
::
stringstream
output
;
307
output
<<
"digraph \""
;
308
309
std
::
string
bn_name
;
310
311
static
std
::
string
inFragmentStyle
=
"fillcolor=\"#ffffaa\","
312
"color=\"#000000\","
313
"fontcolor=\"#000000\""
;
314
static
std
::
string
styleWithLocalCPT
=
"fillcolor=\"#ffddaa\","
315
"color=\"#000000\","
316
"fontcolor=\"#000000\""
;
317
static
std
::
string
notConsistantStyle
=
"fillcolor=\"#ff0000\","
318
"color=\"#000000\","
319
"fontcolor=\"#ffff00\""
;
320
static
std
::
string
outFragmentStyle
=
"fillcolor=\"#f0f0f0\","
321
"color=\"#f0f0f0\","
322
"fontcolor=\"#000000\""
;
323
324
try
{
325
bn_name
=
bn__
.
property
(
"name"
);
326
}
catch
(
NotFound
&) {
bn_name
=
"no_name"
; }
327
328
bn_name
=
"Fragment of "
+
bn_name
;
329
330
output
<<
bn_name
<<
"\" {"
<<
std
::
endl
;
331
output
<<
" graph [bgcolor=transparent,label=\""
<<
bn_name
<<
"\"];"
332
<<
std
::
endl
;
333
output
<<
" node [style=filled];"
<<
std
::
endl
<<
std
::
endl
;
334
335
for
(
auto
node
:
bn__
.
nodes
()) {
336
output
<<
"\""
<<
bn__
.
variable
(
node
).
name
() <<
"\" [comment=\""
<<
node
337
<<
":"
<<
bn__
.
variable
(
node
) <<
", \""
;
338
339
if
(
isInstalledNode
(
node
)) {
340
if
(!
checkConsistency
(
node
)) {
341
output
<<
notConsistantStyle
;
342
}
else
if
(
localCPTs__
.
exists
(
node
))
343
output
<<
styleWithLocalCPT
;
344
else
345
output
<<
inFragmentStyle
;
346
}
else
347
output
<<
outFragmentStyle
;
348
349
output
<<
"];"
<<
std
::
endl
;
350
}
351
352
output
<<
std
::
endl
;
353
354
std
::
string
tab
=
" "
;
355
356
for
(
auto
node
:
bn__
.
nodes
()) {
357
if
(
bn__
.
children
(
node
).
size
() > 0) {
358
for
(
auto
child
:
bn__
.
children
(
node
)) {
359
output
<<
tab
<<
"\""
<<
bn__
.
variable
(
node
).
name
() <<
"\" -> "
360
<<
"\""
<<
bn__
.
variable
(
child
).
name
() <<
"\" ["
;
361
362
if
(
dag
().
existsArc
(
Arc
(
node
,
child
)))
363
output
<<
inFragmentStyle
;
364
else
365
output
<<
outFragmentStyle
;
366
367
output
<<
"];"
<<
std
::
endl
;
368
}
369
}
370
}
371
372
output
<<
"}"
<<
std
::
endl
;
373
374
return
output
.
str
();
375
}
376
377
template
<
typename
GUM_SCALAR
>
378
gum
::
BayesNet
<
GUM_SCALAR
>
BayesNetFragment
<
GUM_SCALAR
>::
toBN
()
const
{
379
if
(!
checkConsistency
()) {
380
GUM_ERROR
(
OperationNotAllowed
,
"The fragment contains un-consistent node(s)"
)
381
}
382
gum
::
BayesNet
<
GUM_SCALAR
>
res
;
383
for
(
const
auto
nod
:
nodes
()) {
384
res
.
add
(
variable
(
nod
),
nod
);
385
}
386
for
(
const
auto
&
arc
:
dag
().
arcs
()) {
387
res
.
addArc
(
arc
.
tail
(),
arc
.
head
());
388
}
389
for
(
const
auto
nod
:
nodes
()) {
390
res
.
cpt
(
nod
).
fillWith
(
cpt
(
nod
));
391
}
392
393
return
res
;
394
}
395
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669