aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
binaryJoinTreeConverterDefault.cpp
Go to the documentation of this file.
1
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
20
21
22
/** @file
 * @brief An algorithm for converting a join tree into a binary join tree
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */
27
28
#
include
<
agrum
/
agrum
.
h
>
29
30
#
include
<
agrum
/
tools
/
core
/
priorityQueue
.
h
>
31
#
include
<
agrum
/
tools
/
graphs
/
algorithms
/
binaryJoinTreeConverterDefault
.
h
>
32
33
namespace
gum
{
34
35
/// default constructor
36
BinaryJoinTreeConverterDefault
::
BinaryJoinTreeConverterDefault
() {
37
GUM_CONSTRUCTOR
(
BinaryJoinTreeConverterDefault
);
38
}
39
40
/// destructor
41
BinaryJoinTreeConverterDefault
::~
BinaryJoinTreeConverterDefault
() {
42
GUM_DESTRUCTOR
(
BinaryJoinTreeConverterDefault
);
43
}
44
45
/** @brief a function used to mark the nodes belonging to a given
46
* connected component */
47
void
BinaryJoinTreeConverterDefault
::
_markConnectedComponent_
(
const
CliqueGraph
&
JT
,
48
NodeId
root
,
49
NodeProperty
<
bool
>&
mark
)
const
{
50
// we mark the nodes in a depth first search manner. To avoid a recursive
51
// algorithm, use a vector to simulate a stack of nodes to inspect.
52
// stack => depth first search
53
std
::
vector
<
NodeId
>
nodes_to_inspect
;
54
nodes_to_inspect
.
reserve
(
JT
.
sizeNodes
());
55
56
// the idea to populate the marks is to use the stack: root is
57
// put onto the stack. Then, while the stack is not empty, remove
58
// the top of the stack and mark it and put into the stack its
59
// adjacent nodes.
60
nodes_to_inspect
.
push_back
(
root
);
61
62
while
(!
nodes_to_inspect
.
empty
()) {
63
// process the top of the stack
64
NodeId
current_node
=
nodes_to_inspect
.
back
();
65
nodes_to_inspect
.
pop_back
();
66
67
// only process the node if it has not been processed yet (actually,
68
// this should not occur unless the clique graph is not singly connected)
69
70
if
(!
mark
[
current_node
]) {
71
mark
[
current_node
] =
true
;
72
73
// put the neighbors onto the stack
74
for
(
const
auto
neigh
:
JT
.
neighbours
(
current_node
))
75
if
(!
mark
[
neigh
])
nodes_to_inspect
.
push_back
(
neigh
);
76
}
77
}
78
}
79
80
/// returns the domain size of the union of two cliques
81
float
BinaryJoinTreeConverterDefault
::
_combinedSize_
(
82
const
NodeSet
&
nodes1
,
83
const
NodeSet
&
nodes2
,
84
const
NodeProperty
<
Size
>&
domain_sizes
)
const
{
85
float
result
= 1;
86
87
for
(
const
auto
node
:
nodes1
)
88
result
*=
domain_sizes
[
node
];
89
90
for
(
const
auto
node
:
nodes2
)
91
if
(!
nodes1
.
exists
(
node
))
result
*=
domain_sizes
[
node
];
92
93
return
result
;
94
}
95
96
/// returns all the roots considered for all the connected components
97
const
NodeSet
&
BinaryJoinTreeConverterDefault
::
roots
()
const
{
return
_roots_
; }
98
99
/// convert a clique and its adjacent cliques into a binary join tree
100
void
BinaryJoinTreeConverterDefault
::
_convertClique_
(
101
CliqueGraph
&
JT
,
102
NodeId
clique
,
103
NodeId
from
,
104
const
NodeProperty
<
Size
>&
domain_sizes
)
const
{
105
// get the neighbors of clique. If there are fewer than 3 neighbors,
106
// there is nothing to do
107
const
NodeSet
&
neighbors
=
JT
.
neighbours
(
clique
);
108
109
if
(
neighbors
.
size
() <= 2)
return
;
110
111
if
((
neighbors
.
size
() == 3) && (
clique
!=
from
))
return
;
112
113
// here we need to transform the neighbors into a binary tree
114
// create a vector with all the ids of the cliques to combine
115
std
::
vector
<
NodeId
>
cliques
;
116
cliques
.
reserve
(
neighbors
.
size
());
117
118
for
(
const
auto
nei
:
neighbors
)
119
if
(
nei
!=
from
)
cliques
.
push_back
(
nei
);
120
121
// create a vector indicating wether the elements in cliques contain
122
// relevant information or not (during the execution of the for
123
// loop below, a cell of vector cliques may actually contain only
124
// trash data)
125
std
::
vector
<
bool
>
is_cliques_relevant
(
cliques
.
size
(),
true
);
126
127
// for each pair of cliques (i,j), compute the size of the clique that would
128
// result from the combination of clique i with clique j and store the
129
// result
130
// into a priorityQueue
131
std
::
pair
<
NodeId
,
NodeId
>
pair
;
132
133
PriorityQueue
<
std
::
pair
<
NodeId
,
NodeId
>,
float
>
queue
;
134
135
for
(
NodeId
i
= 0;
i
<
cliques
.
size
(); ++
i
) {
136
pair
.
first
=
i
;
137
const
NodeSet
&
nodes1
=
JT
.
separator
(
cliques
[
i
],
clique
);
138
139
for
(
NodeId
j
=
i
+ 1;
j
<
cliques
.
size
(); ++
j
) {
140
pair
.
second
=
j
;
141
queue
.
insert
(
pair
,
_combinedSize_
(
nodes1
,
JT
.
separator
(
cliques
[
j
],
clique
),
domain_sizes
));
142
}
143
}
144
145
// now parse the priority queue: the top element (i,j) gives the combination
146
// to perform. When the result R has been computed, substitute i by R,
147
// remove
148
// table j and recompute all the priorities of all the pairs (R,k) still
149
// available.
150
for
(
NodeId
k
= 2;
k
<
cliques
.
size
(); ++
k
) {
151
// get the combination to perform and do it
152
pair
=
queue
.
pop
();
153
NodeId
ti
=
pair
.
first
;
154
NodeId
tj
=
pair
.
second
;
155
156
// create a new clique that will become adjacent to ti and tj
157
// and remove the edges between ti, tj and clique
158
const
NodeSet
&
nodes1
=
JT
.
separator
(
cliques
[
ti
],
clique
);
159
const
NodeSet
&
nodes2
=
JT
.
separator
(
cliques
[
tj
],
clique
);
160
NodeId
new_node
=
JT
.
addNode
(
nodes1
+
nodes2
);
161
JT
.
addEdge
(
cliques
[
ti
],
new_node
);
162
JT
.
addEdge
(
cliques
[
tj
],
new_node
);
163
JT
.
addEdge
(
clique
,
new_node
);
164
JT
.
eraseEdge
(
Edge
(
cliques
[
ti
],
clique
));
165
JT
.
eraseEdge
(
Edge
(
cliques
[
tj
],
clique
));
166
167
// substitute cliques[pair.first] by the result
168
cliques
[
ti
] =
new_node
;
169
is_cliques_relevant
[
tj
] =
false
;
// now tj is no more a neighbor of clique
170
171
// remove all the pairs involving tj in the priority queue
172
173
for
(
NodeId
ind
= 0;
ind
<
tj
; ++
ind
) {
174
if
(
is_cliques_relevant
[
ind
]) {
175
pair
.
first
=
ind
;
176
queue
.
erase
(
pair
);
177
}
178
}
179
180
pair
.
first
=
tj
;
181
182
for
(
NodeId
ind
=
tj
+ 1;
ind
<
cliques
.
size
(); ++
ind
) {
183
if
(
is_cliques_relevant
[
ind
]) {
184
pair
.
second
=
ind
;
185
queue
.
erase
(
pair
);
186
}
187
}
188
189
// update the "combined" size of all the pairs involving "new_node"
190
{
191
const
NodeSet
&
nodes1
=
JT
.
separator
(
cliques
[
ti
],
clique
);
192
pair
.
second
=
ti
;
193
float
newsize
;
194
195
for
(
NodeId
ind
= 0;
ind
<
ti
; ++
ind
) {
196
if
(
is_cliques_relevant
[
ind
]) {
197
pair
.
first
=
ind
;
198
newsize
=
_combinedSize_
(
nodes1
,
JT
.
separator
(
cliques
[
ind
],
clique
),
domain_sizes
);
199
queue
.
setPriority
(
pair
,
newsize
);
200
}
201
}
202
203
pair
.
first
=
ti
;
204
205
for
(
NodeId
ind
=
ti
+ 1;
ind
<
cliques
.
size
(); ++
ind
) {
206
if
(
is_cliques_relevant
[
ind
]) {
207
pair
.
second
=
ind
;
208
newsize
=
_combinedSize_
(
nodes1
,
JT
.
separator
(
cliques
[
ind
],
clique
),
domain_sizes
);
209
queue
.
setPriority
(
pair
,
newsize
);
210
}
211
}
212
}
213
}
214
}
215
216
/// convert a whole connected component into a binary join tree
217
void
BinaryJoinTreeConverterDefault
::
_convertConnectedComponent_
(
218
CliqueGraph
&
JT
,
219
NodeId
current_node
,
220
NodeId
from
,
221
const
NodeProperty
<
Size
>&
domain_sizes
,
222
NodeProperty
<
bool
>&
mark
)
const
{
223
// first, indicate that the node has been marked (this avoids looping
224
// if JT is not a tree
225
mark
[
current_node
] =
true
;
226
227
// parse all the neighbors except nodes already converted and convert them
228
for
(
const
auto
neigh
:
JT
.
neighbours
(
current_node
)) {
229
if
(!
mark
[
neigh
]) {
230
_convertConnectedComponent_
(
JT
,
neigh
,
current_node
,
domain_sizes
,
mark
);
231
}
232
}
233
234
// convert the current node
235
_convertClique_
(
JT
,
current_node
,
from
,
domain_sizes
);
236
}
237
238
/// computes the binary join tree
239
CliqueGraph
BinaryJoinTreeConverterDefault
::
convert
(
const
CliqueGraph
&
JT
,
240
const
NodeProperty
<
Size
>&
domain_sizes
,
241
const
NodeSet
&
specified_roots
) {
242
// first, we copy the current clique graph. By default, this is what we
243
// will return
244
CliqueGraph
binJT
=
JT
;
245
246
// check that there is no connected component without a root. In such a
247
// case,
248
// assign an arbitrary root to it
249
_roots_
=
specified_roots
;
250
251
NodeProperty
<
bool
>
mark
=
JT
.
nodesProperty
(
false
,
JT
.
sizeNodes
());
252
253
// for each specified root, populate its connected component
254
for
(
const
auto
root
:
specified_roots
) {
255
// check that the root has not already been marked
256
// in this case, this means that more than one root has been specified
257
// for a given connected component
258
if
(
mark
[
root
])
259
GUM_ERROR
(
InvalidNode
,
260
"several roots have been specified for a given "
261
"connected component (last : "
262
<<
root
<<
")"
);
263
264
_markConnectedComponent_
(
JT
,
root
,
mark
);
265
}
266
267
// check that all nodes have been marked. If this is not the case, then
268
// this means that we need to add new roots
269
for
(
const
auto
&
elt
:
mark
)
270
if
(!
elt
.
second
) {
271
_roots_
<<
elt
.
first
;
272
_markConnectedComponent_
(
JT
,
elt
.
first
,
mark
);
273
}
274
275
// here, we know that each connected component has one and only one root.
276
// Now we can apply a recursive collect algorithm starting from root
277
// that transforms each clique with more than 3 neighbors into a set of
278
// cliques having at most 3 neighbors.
279
NodeProperty
<
bool
>
mark2
=
JT
.
nodesProperty
(
false
,
JT
.
sizeNodes
());
280
281
for
(
const
auto
root
:
_roots_
)
282
_convertConnectedComponent_
(
binJT
,
root
,
root
,
domain_sizes
,
mark2
);
283
284
// binJT is now a binary join tree, so we can return it
285
return
binJT
;
286
}
287
288
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643