aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
BayesBall_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of the BayesBall class.
25
*/
26
27
namespace
gum
{
28
29
30
// update a set of potentials, keeping only those d-connected with
31
// query variables
32
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
33
void
BayesBall
::
relevantPotentials
(
const
IBayesNet
<
GUM_SCALAR
>&
bn
,
34
const
NodeSet
&
query
,
35
const
NodeSet
&
hardEvidence
,
36
const
NodeSet
&
softEvidence
,
37
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
potentials
) {
38
const
DAG
&
dag
=
bn
.
dag
();
39
40
// create the marks (top = first and bottom = second)
41
NodeProperty
<
std
::
pair
<
bool
,
bool
> >
marks
;
42
marks
.
resize
(
dag
.
size
());
43
const
std
::
pair
<
bool
,
bool
>
empty_mark
(
false
,
false
);
44
45
/// for relevant potentials: indicate which tables contain a variable
46
/// (nodeId)
47
HashTable
<
NodeId
,
Set
<
const
TABLE
<
GUM_SCALAR
>* > >
node2potentials
;
48
for
(
const
auto
pot
:
potentials
) {
49
const
Sequence
<
const
DiscreteVariable
* >&
vars
=
pot
->
variablesSequence
();
50
for
(
const
auto
var
:
vars
) {
51
const
NodeId
id
=
bn
.
nodeId
(*
var
);
52
if
(!
node2potentials
.
exists
(
id
)) {
53
node2potentials
.
insert
(
id
,
Set
<
const
TABLE
<
GUM_SCALAR
>* >());
54
}
55
node2potentials
[
id
].
insert
(
pot
);
56
}
57
}
58
59
// indicate that we will send the ball to all the query nodes (as children):
60
// in list nodes_to_visit, the first element is the next node to send the
61
// ball to and the Boolean indicates whether we shall reach it from one of
62
// its children (true) or from one parent (false)
63
List
<
std
::
pair
<
NodeId
,
bool
> >
nodes_to_visit
;
64
for
(
const
auto
node
:
query
) {
65
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
node
,
true
));
66
}
67
68
// perform the bouncing ball until _node2potentials_ becomes empty (which
69
// means that we have reached all the potentials and, therefore, those
70
// are d-connected to query) or until there is no node in the graph to send
71
// the ball to
72
while
(!
nodes_to_visit
.
empty
() && !
node2potentials
.
empty
()) {
73
// get the next node to visit
74
NodeId
node
=
nodes_to_visit
.
front
().
first
;
75
76
// if the marks of the node do not exist, create them
77
if
(!
marks
.
exists
(
node
))
marks
.
insert
(
node
,
empty_mark
);
78
79
// if the node belongs to the query, update _node2potentials_: remove all
80
// the potentials containing the node
81
if
(
node2potentials
.
exists
(
node
)) {
82
auto
&
pot_set
=
node2potentials
[
node
];
83
for
(
const
auto
pot
:
pot_set
) {
84
const
auto
&
vars
=
pot
->
variablesSequence
();
85
for
(
const
auto
var
:
vars
) {
86
const
NodeId
id
=
bn
.
nodeId
(*
var
);
87
if
(
id
!=
node
) {
88
node2potentials
[
id
].
erase
(
pot
);
89
if
(
node2potentials
[
id
].
empty
()) {
node2potentials
.
erase
(
id
); }
90
}
91
}
92
}
93
node2potentials
.
erase
(
node
);
94
95
// if _node2potentials_ is empty, no need to go on: all the potentials
96
// are d-connected to the query
97
if
(
node2potentials
.
empty
())
return
;
98
}
99
100
101
// bounce the ball toward the neighbors
102
if
(
nodes_to_visit
.
front
().
second
) {
// visit from a child
103
nodes_to_visit
.
popFront
();
104
105
if
(
hardEvidence
.
exists
(
node
)) {
continue
; }
106
107
if
(!
marks
[
node
].
first
) {
108
marks
[
node
].
first
=
true
;
// top marked
109
for
(
const
auto
par
:
dag
.
parents
(
node
)) {
110
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
par
,
true
));
111
}
112
}
113
114
if
(!
marks
[
node
].
second
) {
115
marks
[
node
].
second
=
true
;
// bottom marked
116
for
(
const
auto
chi
:
dag
.
children
(
node
)) {
117
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
chi
,
false
));
118
}
119
}
120
}
else
{
// visit from a parent
121
nodes_to_visit
.
popFront
();
122
123
const
bool
is_hard_evidence
=
hardEvidence
.
exists
(
node
);
124
const
bool
is_evidence
=
is_hard_evidence
||
softEvidence
.
exists
(
node
);
125
126
if
(
is_evidence
&& !
marks
[
node
].
first
) {
127
marks
[
node
].
first
=
true
;
128
129
for
(
const
auto
par
:
dag
.
parents
(
node
)) {
130
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
par
,
true
));
131
}
132
}
133
134
if
(!
is_hard_evidence
&& !
marks
[
node
].
second
) {
135
marks
[
node
].
second
=
true
;
136
137
for
(
const
auto
chi
:
dag
.
children
(
node
)) {
138
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
chi
,
false
));
139
}
140
}
141
}
142
}
143
144
145
// here, all the potentials that belong to _node2potentials_ are d-separated
146
// from the query
147
for
(
const
auto
elt
:
node2potentials
) {
148
for
(
const
auto
pot
:
elt
.
second
) {
149
potentials
.
erase
(
pot
);
150
}
151
}
152
}
153
154
155
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643