aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
barrenNodesFinder.cpp
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Detect barren nodes for inference in Bayesian networks
25
*/
26
#
include
<
limits
>
27
28
#
include
<
agrum
/
BN
/
algorithms
/
barrenNodesFinder
.
h
>
29
30
#
ifdef
GUM_NO_INLINE
31
#
include
<
agrum
/
BN
/
algorithms
/
barrenNodesFinder_inl
.
h
>
32
#
endif
// GUM_NO_INLINE
33
34
namespace
gum
{
35
36
/// returns the set of barren nodes in the messages sent in a junction tree
37
ArcProperty
<
NodeSet
>
38
BarrenNodesFinder
::
barrenNodes
(
const
CliqueGraph
&
junction_tree
) {
39
// assign a mark to all the nodes
40
// and mark all the observed nodes and their ancestors as non-barren
41
NodeProperty
<
Size
>
mark
(
dag__
->
size
());
42
{
43
for
(
const
auto
node
: *
dag__
)
44
mark
.
insert
(
node
, 0);
// for the moment, 0 = possibly barren
45
46
// mark all the observed nodes and their ancestors as non barren
47
// std::numeric_limits<unsigned int>::max () will be necessarily non
48
// barren
49
// later on
50
Sequence
<
NodeId
>
observed_anc
(
dag__
->
size
());
51
const
Size
non_barren
=
std
::
numeric_limits
<
Size
>::
max
();
52
for
(
const
auto
node
: *
observed_nodes__
)
53
observed_anc
.
insert
(
node
);
54
for
(
Idx
i
= 0;
i
<
observed_anc
.
size
(); ++
i
) {
55
const
NodeId
node
=
observed_anc
[
i
];
56
if
(!
mark
[
node
]) {
57
mark
[
node
] =
non_barren
;
58
for
(
const
auto
par
:
dag__
->
parents
(
node
)) {
59
if
(!
mark
[
par
] && !
observed_anc
.
exists
(
par
)) {
60
observed_anc
.
insert
(
par
);
61
}
62
}
63
}
64
}
65
}
66
67
// create the data structure that will contain the result of the
68
// method. By default, we assume that, for each pair of adjacent cliques,
69
// all
70
// the nodes that do not belong to their separator are possibly barren and,
71
// by sweeping the dag, we will remove the nodes that were determined
72
// above as non-barren. Structure result will assign to each (ordered) pair
73
// of adjacent cliques its set of barren nodes.
74
ArcProperty
<
NodeSet
>
result
;
75
for
(
const
auto
&
edge
:
junction_tree
.
edges
()) {
76
const
NodeSet
&
separator
=
junction_tree
.
separator
(
edge
);
77
78
NodeSet
non_barren1
=
junction_tree
.
clique
(
edge
.
first
());
79
for
(
auto
iter
=
non_barren1
.
beginSafe
();
iter
!=
non_barren1
.
endSafe
();
80
++
iter
) {
81
if
(
mark
[*
iter
] ||
separator
.
exists
(*
iter
)) {
non_barren1
.
erase
(
iter
); }
82
}
83
result
.
insert
(
Arc
(
edge
.
first
(),
edge
.
second
()),
std
::
move
(
non_barren1
));
84
85
NodeSet
non_barren2
=
junction_tree
.
clique
(
edge
.
second
());
86
for
(
auto
iter
=
non_barren2
.
beginSafe
();
iter
!=
non_barren2
.
endSafe
();
87
++
iter
) {
88
if
(
mark
[*
iter
] ||
separator
.
exists
(*
iter
)) {
non_barren2
.
erase
(
iter
); }
89
}
90
result
.
insert
(
Arc
(
edge
.
second
(),
edge
.
first
()),
std
::
move
(
non_barren2
));
91
}
92
93
// for each node in the DAG, indicate which are the arcs in the result
94
// structure whose separator contain it: the separators are actually the
95
// targets of the queries.
96
NodeProperty
<
ArcSet
>
node2arc
;
97
for
(
const
auto
node
: *
dag__
)
98
node2arc
.
insert
(
node
,
ArcSet
());
99
for
(
const
auto
&
elt
:
result
) {
100
const
Arc
&
arc
=
elt
.
first
;
101
if
(!
result
[
arc
].
empty
()) {
// no need to further process cliques
102
const
NodeSet
&
separator
=
// with no barren nodes
103
junction_tree
.
separator
(
Edge
(
arc
.
tail
(),
arc
.
head
()));
104
105
for
(
const
auto
node
:
separator
) {
106
node2arc
[
node
].
insert
(
arc
);
107
}
108
}
109
}
110
111
// To determine the set of non-barren nodes w.r.t. a given single node
112
// query, we rely on the fact that those are precisely all the ancestors of
113
// this single node. To mutualize the computations, we will thus sweep the
114
// DAG from top to bottom and exploit the fact that the set of ancestors of
115
// the child of a given node A contain the ancestors of A. Therefore, we
116
// will
117
// determine sets of paths in the DAG and, for each path, compute the set of
118
// its barren nodes from the source to the destination of the path. The
119
// optimal set of paths, i.e., that which will minimize computations, is
120
// obtained by solving a "minimum path cover in directed acyclic graphs".
121
// But
122
// such an algorithm is too costly for the gain we can get from it, so we
123
// will
124
// rely on a simple heuristics.
125
126
// To compute the heuristics, we proceed as follows:
127
// 1/ we mark to 1 all the nodes that are ancestors of at least one (key)
128
// node
129
// with a non-empty arcset in node2arc and we extract from those the
130
// roots, i.e., those nodes whose set of parents, if any, have all been
131
// identified as non-barren by being marked as
132
// std::numeric_limits<unsigned int>::max (). Such nodes are
133
// thus the top of the graph to sweep.
134
// 2/ create a copy of the subgraph of the DAG w.r.t. the 1-marked nodes
135
// and, for each node, if the node has several parents and children,
136
// keep only one arc from one of the parents to the child with the
137
// smallest
138
// number of parents, and try to create a matching between parents and
139
// children and add one arc for each edge of this matching. This will
140
// allow
141
// us to create distinct paths in the DAG. Whenever a child has no more
142
// parents, it becomes the root of a new path.
143
// 3/ the sweeping will be performed from the roots of all these paths.
144
145
// perform step 1/
146
NodeSet
path_roots
;
147
{
148
List
<
NodeId
>
nodes_to_mark
;
149
for
(
const
auto
&
elt
:
node2arc
) {
150
if
(!
elt
.
second
.
empty
()) {
// only process nodes with assigned arcs
151
nodes_to_mark
.
insert
(
elt
.
first
);
152
}
153
}
154
while
(!
nodes_to_mark
.
empty
()) {
155
NodeId
node
=
nodes_to_mark
.
front
();
156
nodes_to_mark
.
popFront
();
157
158
if
(!
mark
[
node
]) {
// mark the node and all its ancestors
159
mark
[
node
] = 1;
160
Size
nb_par
= 0;
161
for
(
auto
par
:
dag__
->
parents
(
node
)) {
162
Size
parent_mark
=
mark
[
par
];
163
if
(
parent_mark
!=
std
::
numeric_limits
<
Size
>::
max
()) {
164
++
nb_par
;
165
if
(
parent_mark
== 0) {
nodes_to_mark
.
insert
(
par
); }
166
}
167
}
168
169
if
(
nb_par
== 0) {
path_roots
.
insert
(
node
); }
170
}
171
}
172
}
173
174
// perform step 2/
175
DAG
sweep_dag
= *
dag__
;
176
for
(
const
auto
node
: *
dag__
) {
// keep only nodes marked with 1
177
if
(
mark
[
node
] != 1) {
sweep_dag
.
eraseNode
(
node
); }
178
}
179
for
(
const
auto
node
:
sweep_dag
) {
180
const
Size
nb_parents
=
sweep_dag
.
parents
(
node
).
size
();
181
const
Size
nb_children
=
sweep_dag
.
children
(
node
).
size
();
182
if
((
nb_parents
> 1) || (
nb_children
> 1)) {
183
// perform the matching
184
const
auto
&
parents
=
sweep_dag
.
parents
(
node
);
185
186
// if there is no child, remove all the parents except the first one
187
if
(
nb_children
== 0) {
188
auto
iter_par
=
parents
.
beginSafe
();
189
for
(++
iter_par
;
iter_par
!=
parents
.
endSafe
(); ++
iter_par
) {
190
sweep_dag
.
eraseArc
(
Arc
(*
iter_par
,
node
));
191
}
192
}
else
{
193
// find the child with the smallest number of parents
194
const
auto
&
children
=
sweep_dag
.
children
(
node
);
195
NodeId
smallest_child
= 0;
196
Size
smallest_nb_par
=
std
::
numeric_limits
<
Size
>::
max
();
197
for
(
const
auto
child
:
children
) {
198
const
auto
new_nb
=
sweep_dag
.
parents
(
child
).
size
();
199
if
(
new_nb
<
smallest_nb_par
) {
200
smallest_nb_par
=
new_nb
;
201
smallest_child
=
child
;
202
}
203
}
204
205
// if there is no parent, just keep the link with the smallest child
206
// and remove all the other arcs
207
if
(
nb_parents
== 0) {
208
for
(
auto
iter
=
children
.
beginSafe
();
iter
!=
children
.
endSafe
();
209
++
iter
) {
210
if
(*
iter
!=
smallest_child
) {
211
if
(
sweep_dag
.
parents
(*
iter
).
size
() == 1) {
212
path_roots
.
insert
(*
iter
);
213
}
214
sweep_dag
.
eraseArc
(
Arc
(
node
, *
iter
));
215
}
216
}
217
}
else
{
218
auto
nb_match
=
Size
(
std
::
min
(
nb_parents
,
nb_children
) - 1);
219
auto
iter_par
=
parents
.
beginSafe
();
220
++
iter_par
;
// skip the first parent, whose arc with node will
221
// remain
222
auto
iter_child
=
children
.
beginSafe
();
223
for
(
Idx
i
= 0;
i
<
nb_match
; ++
i
, ++
iter_par
, ++
iter_child
) {
224
if
(*
iter_child
==
smallest_child
) { ++
iter_child
; }
225
sweep_dag
.
addArc
(*
iter_par
, *
iter_child
);
226
sweep_dag
.
eraseArc
(
Arc
(*
iter_par
,
node
));
227
sweep_dag
.
eraseArc
(
Arc
(
node
, *
iter_child
));
228
}
229
for
(;
iter_par
!=
parents
.
endSafe
(); ++
iter_par
) {
230
sweep_dag
.
eraseArc
(
Arc
(*
iter_par
,
node
));
231
}
232
for
(;
iter_child
!=
children
.
endSafe
(); ++
iter_child
) {
233
if
(*
iter_child
!=
smallest_child
) {
234
if
(
sweep_dag
.
parents
(*
iter_child
).
size
() == 1) {
235
path_roots
.
insert
(*
iter_child
);
236
}
237
sweep_dag
.
eraseArc
(
Arc
(
node
, *
iter_child
));
238
}
239
}
240
}
241
}
242
}
243
}
244
245
// step 3: sweep the paths from the roots of sweep_dag
246
// here, the idea is that, for each path of sweep_dag, the mark we put
247
// to the ancestors is a given number, say N, that increases from path
248
// to path. Hence, for a given path, all the nodes that are marked with a
249
// number at least as high as N are non-barren, the others being barren.
250
Idx
mark_id
= 2;
251
for
(
NodeId
path
:
path_roots
) {
252
// perform the sweeping from the path
253
while
(
true
) {
254
// mark all the ancestors of the node
255
List
<
NodeId
>
to_mark
{
path
};
256
while
(!
to_mark
.
empty
()) {
257
NodeId
node
=
to_mark
.
front
();
258
to_mark
.
popFront
();
259
if
(
mark
[
node
] <
mark_id
) {
260
mark
[
node
] =
mark_id
;
261
for
(
const
auto
par
:
dag__
->
parents
(
node
)) {
262
if
(
mark
[
par
] <
mark_id
) {
to_mark
.
insert
(
par
); }
263
}
264
}
265
}
266
267
// now, get all the arcs that contained node "path" in their separator.
268
// this node acts as a query target and, therefore, its ancestors
269
// shall be non-barren.
270
const
ArcSet
&
arcs
=
node2arc
[
path
];
271
for
(
const
auto
&
arc
:
arcs
) {
272
NodeSet
&
barren
=
result
[
arc
];
273
for
(
auto
iter
=
barren
.
beginSafe
();
iter
!=
barren
.
endSafe
(); ++
iter
) {
274
if
(
mark
[*
iter
] >=
mark_id
) {
275
// this indicates a non-barren node
276
barren
.
erase
(
iter
);
277
}
278
}
279
}
280
281
// go to the next sweeping node
282
const
NodeSet
&
sweep_children
=
sweep_dag
.
children
(
path
);
283
if
(
sweep_children
.
size
()) {
284
path
= *(
sweep_children
.
begin
());
285
}
else
{
286
// here, the path has ended, so we shall go to the next path
287
++
mark_id
;
288
break
;
289
}
290
}
291
}
292
293
return
result
;
294
}
295
296
/// returns the set of barren nodes
297
NodeSet
BarrenNodesFinder
::
barrenNodes
() {
298
// mark all the nodes in the dag as barren (true)
299
NodeProperty
<
bool
>
barren_mark
=
dag__
->
nodesProperty
(
true
);
300
301
// mark all the ancestors of the evidence and targets as non-barren
302
List
<
NodeId
>
nodes_to_examine
;
303
int
nb_non_barren
= 0;
304
for
(
const
auto
node
: *
observed_nodes__
)
305
nodes_to_examine
.
insert
(
node
);
306
for
(
const
auto
node
: *
target_nodes__
)
307
nodes_to_examine
.
insert
(
node
);
308
309
while
(!
nodes_to_examine
.
empty
()) {
310
const
NodeId
node
=
nodes_to_examine
.
front
();
311
nodes_to_examine
.
popFront
();
312
if
(
barren_mark
[
node
]) {
313
barren_mark
[
node
] =
false
;
314
++
nb_non_barren
;
315
for
(
const
auto
par
:
dag__
->
parents
(
node
))
316
nodes_to_examine
.
insert
(
par
);
317
}
318
}
319
320
// here, all the nodes marked true are barren
321
NodeSet
barren_nodes
(
dag__
->
sizeNodes
() -
nb_non_barren
);
322
for
(
const
auto
&
marked_pair
:
barren_mark
)
323
if
(
marked_pair
.
second
)
barren_nodes
.
insert
(
marked_pair
.
first
);
324
325
return
barren_nodes
;
326
}
327
328
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669