// aGrUM 0.20.2 -- a C++ library for (probabilistic) graphical models
// BayesBall_tpl.h
/**
 *
 *   Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of the BayesBall class.
 */

namespace
gum
{
28
29
30
// update a set of potentials, keeping only those d-connected with
31
// query variables
32
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
33
void
34
BayesBall
::
relevantPotentials
(
const
IBayesNet
<
GUM_SCALAR
>&
bn
,
35
const
NodeSet
&
query
,
36
const
NodeSet
&
hardEvidence
,
37
const
NodeSet
&
softEvidence
,
38
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
potentials
) {
39
const
DAG
&
dag
=
bn
.
dag
();
40
41
// create the marks (top = first and bottom = second)
42
NodeProperty
<
std
::
pair
<
bool
,
bool
> >
marks
;
43
marks
.
resize
(
dag
.
size
());
44
const
std
::
pair
<
bool
,
bool
>
empty_mark
(
false
,
false
);
45
46
/// for relevant potentials: indicate which tables contain a variable
47
/// (nodeId)
48
HashTable
<
NodeId
,
Set
<
const
TABLE
<
GUM_SCALAR
>* > >
node2potentials
;
49
for
(
const
auto
pot
:
potentials
) {
50
const
Sequence
<
const
DiscreteVariable
* >&
vars
=
pot
->
variablesSequence
();
51
for
(
const
auto
var
:
vars
) {
52
const
NodeId
id
=
bn
.
nodeId
(*
var
);
53
if
(!
node2potentials
.
exists
(
id
)) {
54
node2potentials
.
insert
(
id
,
Set
<
const
TABLE
<
GUM_SCALAR
>* >());
55
}
56
node2potentials
[
id
].
insert
(
pot
);
57
}
58
}
59
60
// indicate that we will send the ball to all the query nodes (as children):
61
// in list nodes_to_visit, the first element is the next node to send the
62
// ball to and the Boolean indicates whether we shall reach it from one of
63
// its children (true) or from one parent (false)
64
List
<
std
::
pair
<
NodeId
,
bool
> >
nodes_to_visit
;
65
for
(
const
auto
node
:
query
) {
66
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
node
,
true
));
67
}
68
69
// perform the bouncing ball until node2potentials__ becomes empty (which
70
// means that we have reached all the potentials and, therefore, those
71
// are d-connected to query) or until there is no node in the graph to send
72
// the ball to
73
while
(!
nodes_to_visit
.
empty
() && !
node2potentials
.
empty
()) {
74
// get the next node to visit
75
NodeId
node
=
nodes_to_visit
.
front
().
first
;
76
77
// if the marks of the node do not exist, create them
78
if
(!
marks
.
exists
(
node
))
marks
.
insert
(
node
,
empty_mark
);
79
80
// if the node belongs to the query, update node2potentials__: remove all
81
// the potentials containing the node
82
if
(
node2potentials
.
exists
(
node
)) {
83
auto
&
pot_set
=
node2potentials
[
node
];
84
for
(
const
auto
pot
:
pot_set
) {
85
const
auto
&
vars
=
pot
->
variablesSequence
();
86
for
(
const
auto
var
:
vars
) {
87
const
NodeId
id
=
bn
.
nodeId
(*
var
);
88
if
(
id
!=
node
) {
89
node2potentials
[
id
].
erase
(
pot
);
90
if
(
node2potentials
[
id
].
empty
()) {
node2potentials
.
erase
(
id
); }
91
}
92
}
93
}
94
node2potentials
.
erase
(
node
);
95
96
// if node2potentials__ is empty, no need to go on: all the potentials
97
// are d-connected to the query
98
if
(
node2potentials
.
empty
())
return
;
99
}
100
101
102
// bounce the ball toward the neighbors
103
if
(
nodes_to_visit
.
front
().
second
) {
// visit from a child
104
nodes_to_visit
.
popFront
();
105
106
if
(
hardEvidence
.
exists
(
node
)) {
continue
; }
107
108
if
(!
marks
[
node
].
first
) {
109
marks
[
node
].
first
=
true
;
// top marked
110
for
(
const
auto
par
:
dag
.
parents
(
node
)) {
111
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
par
,
true
));
112
}
113
}
114
115
if
(!
marks
[
node
].
second
) {
116
marks
[
node
].
second
=
true
;
// bottom marked
117
for
(
const
auto
chi
:
dag
.
children
(
node
)) {
118
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
chi
,
false
));
119
}
120
}
121
}
else
{
// visit from a parent
122
nodes_to_visit
.
popFront
();
123
124
const
bool
is_hard_evidence
=
hardEvidence
.
exists
(
node
);
125
const
bool
is_evidence
=
is_hard_evidence
||
softEvidence
.
exists
(
node
);
126
127
if
(
is_evidence
&& !
marks
[
node
].
first
) {
128
marks
[
node
].
first
=
true
;
129
130
for
(
const
auto
par
:
dag
.
parents
(
node
)) {
131
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
par
,
true
));
132
}
133
}
134
135
if
(!
is_hard_evidence
&& !
marks
[
node
].
second
) {
136
marks
[
node
].
second
=
true
;
137
138
for
(
const
auto
chi
:
dag
.
children
(
node
)) {
139
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
chi
,
false
));
140
}
141
}
142
}
143
}
144
145
146
// here, all the potentials that belong to node2potentials__ are d-separated
147
// from the query
148
for
(
const
auto
elt
:
node2potentials
) {
149
for
(
const
auto
pot
:
elt
.
second
) {
150
potentials
.
erase
(
pot
);
151
}
152
}
153
}
154
155
156
}
/* namespace gum */
// (Doxygen cross-reference) gum::Set::emplace:
//   INLINE void emplace(Args&&... args) -- defined in set_tpl.h:669