aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
dSeparation_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

/**
 * @file
 * @brief d-separation analysis (as described in Koller & Friedman, 2009)
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */

namespace
gum
{
31
32
33
// update a set of potentials, keeping only those d-connected with
34
// query variables given evidence
35
template
<
typename
GUM_SCALAR
,
template
<
typename
>
class
TABLE
>
36
void
dSeparation
::
relevantPotentials
(
const
IBayesNet
<
GUM_SCALAR
>&
bn
,
37
const
NodeSet
&
query
,
38
const
NodeSet
&
hardEvidence
,
39
const
NodeSet
&
softEvidence
,
40
Set
<
const
TABLE
<
GUM_SCALAR
>* >&
potentials
) {
41
const
DAG
&
dag
=
bn
.
dag
();
42
43
// mark the set of ancestors of the evidence
44
NodeSet
ev_ancestors
(
dag
.
size
());
45
{
46
List
<
NodeId
>
anc_to_visit
;
47
for
(
const
auto
node
:
hardEvidence
)
48
anc_to_visit
.
insert
(
node
);
49
for
(
const
auto
node
:
softEvidence
)
50
anc_to_visit
.
insert
(
node
);
51
while
(!
anc_to_visit
.
empty
()) {
52
const
NodeId
node
=
anc_to_visit
.
front
();
53
anc_to_visit
.
popFront
();
54
55
if
(!
ev_ancestors
.
exists
(
node
)) {
56
ev_ancestors
.
insert
(
node
);
57
for
(
const
auto
par
:
dag
.
parents
(
node
)) {
58
anc_to_visit
.
insert
(
par
);
59
}
60
}
61
}
62
}
63
64
// create the marks indicating that we have visited a node
65
NodeSet
visited_from_child
(
dag
.
size
());
66
NodeSet
visited_from_parent
(
dag
.
size
());
67
68
/// for relevant potentials: indicate which tables contain a variable
69
/// (nodeId)
70
HashTable
<
NodeId
,
Set
<
const
TABLE
<
GUM_SCALAR
>* > >
node2potentials
;
71
for
(
const
auto
pot
:
potentials
) {
72
const
Sequence
<
const
DiscreteVariable
* >&
vars
=
pot
->
variablesSequence
();
73
for
(
const
auto
var
:
vars
) {
74
const
NodeId
id
=
bn
.
nodeId
(*
var
);
75
if
(!
node2potentials
.
exists
(
id
)) {
76
node2potentials
.
insert
(
id
,
Set
<
const
TABLE
<
GUM_SCALAR
>* >());
77
}
78
node2potentials
[
id
].
insert
(
pot
);
79
}
80
}
81
82
// indicate that we will send the ball to all the query nodes (as children):
83
// in list nodes_to_visit, the first element is the next node to send the
84
// ball to and the Boolean indicates whether we shall reach it from one of
85
// its children (true) or from one parent (false)
86
List
<
std
::
pair
<
NodeId
,
bool
> >
nodes_to_visit
;
87
for
(
const
auto
node
:
query
) {
88
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
node
,
true
));
89
}
90
91
// perform the bouncing ball until there is no node in the graph to send
92
// the ball to
93
while
(!
nodes_to_visit
.
empty
() && !
node2potentials
.
empty
()) {
94
// get the next node to visit
95
const
NodeId
node
=
nodes_to_visit
.
front
().
first
;
96
const
bool
direction
=
nodes_to_visit
.
front
().
second
;
97
nodes_to_visit
.
popFront
();
98
99
// check if the node has not already been visited in the same direction
100
bool
already_visited
;
101
if
(
direction
) {
102
already_visited
=
visited_from_child
.
exists
(
node
);
103
if
(!
already_visited
) {
visited_from_child
.
insert
(
node
); }
104
}
else
{
105
already_visited
=
visited_from_parent
.
exists
(
node
);
106
if
(!
already_visited
) {
visited_from_parent
.
insert
(
node
); }
107
}
108
109
// if the node belongs to the query, update _node2potentials_: remove all
110
// the potentials containing the node
111
if
(
node2potentials
.
exists
(
node
)) {
112
auto
&
pot_set
=
node2potentials
[
node
];
113
for
(
const
auto
pot
:
pot_set
) {
114
const
auto
&
vars
=
pot
->
variablesSequence
();
115
for
(
const
auto
var
:
vars
) {
116
const
NodeId
id
=
bn
.
nodeId
(*
var
);
117
if
(
id
!=
node
) {
118
node2potentials
[
id
].
erase
(
pot
);
119
if
(
node2potentials
[
id
].
empty
()) {
node2potentials
.
erase
(
id
); }
120
}
121
}
122
}
123
node2potentials
.
erase
(
node
);
124
125
// if _node2potentials_ is empty, no need to go on: all the potentials
126
// are d-connected to the query
127
if
(
node2potentials
.
empty
())
return
;
128
}
129
130
// if this is the first time we meet the node, then visit it
131
if
(!
already_visited
) {
132
// mark the node as reachable if this is not a hard evidence
133
const
bool
is_hard_evidence
=
hardEvidence
.
exists
(
node
);
134
135
// bounce the ball toward the neighbors
136
if
(
direction
&& !
is_hard_evidence
) {
// visit from a child
137
// visit the parents
138
for
(
const
auto
par
:
dag
.
parents
(
node
)) {
139
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
par
,
true
));
140
}
141
142
// visit the children
143
for
(
const
auto
chi
:
dag
.
children
(
node
)) {
144
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
chi
,
false
));
145
}
146
}
else
{
// visit from a parent
147
if
(!
hardEvidence
.
exists
(
node
)) {
148
// visit the children
149
for
(
const
auto
chi
:
dag
.
children
(
node
)) {
150
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
chi
,
false
));
151
}
152
}
153
if
(
ev_ancestors
.
exists
(
node
)) {
154
// visit the parents
155
for
(
const
auto
par
:
dag
.
parents
(
node
)) {
156
nodes_to_visit
.
insert
(
std
::
pair
<
NodeId
,
bool
>(
par
,
true
));
157
}
158
}
159
}
160
}
161
}
162
163
// here, all the potentials that belong to _node2potentials_ are d-separated
164
// from the query
165
for
(
const
auto
elt
:
node2potentials
) {
166
for
(
const
auto
pot
:
elt
.
second
) {
167
potentials
.
erase
(
pot
);
168
}
169
}
170
}
171
172
173
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643