aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
aggregatorDecomposition_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2019 Pierre-Henri WUILLEMIN & Christophe GONZALES(@AMU)
4
* {prenom.nom}_at_lip6.fr
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Inline implementation of AggregatorDecomposition.
25
*
26
* @author Gaspard Ducamp
27
*
28
*/
29
#include <agrum/BN/inference/tools/aggregatorDecomposition.h>

#include <list>
#include <memory>
#include <string>
#include <typeinfo>
33
namespace
gum
{
34
35
template
<
typename
GUM_SCALAR >
36
INLINE AggregatorDecomposition<
GUM_SCALAR
>::
AggregatorDecomposition
() {
37
arity__
= 2;
38
GUM_CONSTRUCTOR
(
AggregatorDecomposition
);
39
}
40
41
template
<
typename
GUM_SCALAR
>
42
AggregatorDecomposition
<
GUM_SCALAR
>::~
AggregatorDecomposition
() {
43
GUM_DESTRUCTOR
(
AggregatorDecomposition
);
44
}
45
46
template
<
typename
GUM_SCALAR
>
47
BayesNet
<
GUM_SCALAR
>&
48
AggregatorDecomposition
<
GUM_SCALAR
>::
getDecomposedAggregator
(
49
BayesNet
<
GUM_SCALAR
>&
bn
) {
50
for
(
NodeId
node
:
bn
.
nodes
().
asNodeSet
()) {
51
std
::
string
description
=
bn
.
cpt
(
node
).
toString
();
52
auto
p
=
dynamic_cast
<
53
const
gum
::
aggregator
::
MultiDimAggregator
<
GUM_SCALAR
>* >(
54
bn
.
cpt
(
node
).
content
());
55
if
(
p
!=
nullptr
&&
p
->
isDecomposable
()) {
decomposeAggregator_
(
bn
,
node
); }
56
}
57
return
bn
;
58
}
59
60
template
<
typename
GUM_SCALAR
>
61
NodeId
AggregatorDecomposition
<
GUM_SCALAR
>::
addAggregator_
(
62
BayesNet
<
GUM_SCALAR
>&
bn
,
63
std
::
string
aggType
,
64
const
DiscreteVariable
&
var
,
65
Idx
value
) {
66
if
(
toLower
(
aggType
) ==
"min"
) {
67
return
bn
.
addMIN
(
var
);
68
}
else
if
(
toLower
(
aggType
) ==
"max"
) {
69
return
bn
.
addMAX
(
var
);
70
}
else
if
(
toLower
(
aggType
) ==
"count"
) {
71
return
bn
.
addCOUNT
(
var
,
value
);
72
}
else
if
(
toLower
(
aggType
) ==
"exists"
) {
73
return
bn
.
addEXISTS
(
var
,
value
);
74
}
else
if
(
toLower
(
aggType
) ==
"or"
) {
75
return
bn
.
addOR
(
var
);
76
}
else
if
(
toLower
(
aggType
) ==
"and"
) {
77
return
bn
.
addAND
(
var
);
78
}
else
if
(
toLower
(
aggType
) ==
"forall"
) {
79
return
bn
.
addFORALL
(
var
);
80
}
else
if
(
toLower
(
aggType
) ==
"amplitude"
) {
81
return
bn
.
addAMPLITUDE
(
var
);
82
}
else
if
(
toLower
(
aggType
) ==
"median"
) {
83
return
bn
.
addMEDIAN
(
var
);
84
}
else
if
(
toLower
(
aggType
) ==
"sum"
) {
85
return
bn
.
addSUM
(
var
);
86
}
else
{
87
std
::
string
msg
=
"Unknown aggregate: "
;
88
msg
.
append
(
aggType
);
89
GUM_ERROR
(
NotFound
,
msg
);
90
}
91
}
92
93
template
<
typename
GUM_SCALAR
>
94
BayesNet
<
GUM_SCALAR
>&
95
AggregatorDecomposition
<
GUM_SCALAR
>::
decomposeAggregator_
(
96
BayesNet
<
GUM_SCALAR
>&
bn
,
97
NodeId
initialAggregator
) {
98
auto
p
99
=
static_cast
<
const
gum
::
aggregator
::
MultiDimAggregator
<
GUM_SCALAR
>* >(
100
bn
.
cpt
(
initialAggregator
).
content
());
101
auto
newAgg
=
bn
.
variable
(
initialAggregator
).
clone
();
102
103
Set
<
NodeId
>
parents
=
bn
.
parents
(
initialAggregator
);
104
105
std
::
list
<
NodeId
>
orderedParents
= {};
106
107
for
(
const
auto
&
elt
:
parents
) {
108
orderedParents
.
push_back
(
elt
);
109
}
110
111
orderedParents
.
sort
();
112
113
Set
<
NodeId
>
newAggs
=
Set
<
NodeId
>();
114
List
<
NodeId
>
newAggParents
;
115
116
gum
::
Size
arity
=
getMaximumArity
();
117
gum
::
Size
q
= 0;
118
gum
::
Size
i
= 0;
119
120
long
minVal
= 0;
121
long
maxVal
= 0;
122
123
int
j
= 1;
124
125
std
::
string
newName
=
std
::
string
(
bn
.
variable
(
initialAggregator
).
name
()) +
"_"
126
+
std
::
to_string
(
j
);
127
std
::
string
aggType
=
p
->
aggregatorName
();
128
129
for
(
auto
parent
:
parents
) {
130
bn
.
eraseArc
(
parent
,
initialAggregator
);
131
}
132
133
/*
134
* We are constructing the new aggregator with a clone of the former
135
*/
136
newAgg
->
setName
(
newName
);
137
newAgg
->
setDescription
(
aggType
);
138
139
// for(Set<NodeId>::iterator it = parents.begin(); it!= parents.end(); ++it){
140
for
(
auto
it
=
orderedParents
.
begin
();
it
!=
orderedParents
.
end
(); ++
it
) {
141
if
(
q
<
parents
.
size
() -
parents
.
size
() %
arity
) {
142
if
(
i
==
arity
) {
143
i
= 0;
144
j
++;
145
146
if
(
newAgg
->
varType
() ==
VarType
::
Labelized
) {
147
addAggregator_
(
bn
,
aggType
, *
newAgg
,
p
->
domainSize
());
148
}
else
if
(
newAgg
->
varType
() ==
VarType
::
Range
) {
149
static_cast
<
RangeVariable
* >(
newAgg
)->
setMinVal
(
minVal
);
150
static_cast
<
RangeVariable
* >(
newAgg
)->
setMaxVal
(
maxVal
);
151
addAggregator_
(
bn
,
aggType
, *
newAgg
, 0);
152
}
else
{
153
GUM_ERROR
(
OperationNotAllowed
,
154
"Decomposition is not available for type : "
+
aggType
);
155
}
156
157
/*
158
* Adding arcs in the new node from its parents and adding thoses into
159
* the temporary potential
160
*/
161
for
(
NodeId
node
:
newAggParents
) {
162
bn
.
addArc
(
node
,
bn
.
idFromName
(
newName
));
163
}
164
165
/*
166
* Adding the new aggregator in t
167
*/
168
newAggs
.
insert
(
bn
.
idFromName
(
newName
));
169
170
newAggParents
.
clear
();
171
172
minVal
= 0;
173
maxVal
= 0;
174
175
newName
=
std
::
string
(
bn
.
variable
(
initialAggregator
).
name
()) +
"_"
176
+
std
::
to_string
(
j
);
177
178
delete
(
newAgg
);
179
newAgg
=
bn
.
variable
(
initialAggregator
).
clone
();
180
newAgg
->
setName
(
newName
);
181
newAgg
->
setDescription
(
aggType
);
182
183
if
(
bn
.
variable
(*
it
).
varType
() ==
VarType
::
Range
) {
184
minVal
185
+=
static_cast
<
const
RangeVariable
& >(
bn
.
variable
(*
it
)).
minVal
();
186
maxVal
187
+=
static_cast
<
const
RangeVariable
& >(
bn
.
variable
(*
it
)).
maxVal
();
188
}
189
190
newAggParents
.
push_back
(*
it
);
191
i
++;
192
}
else
{
193
if
(
bn
.
variable
(*
it
).
varType
() ==
VarType
::
Range
) {
194
minVal
195
+=
static_cast
<
const
RangeVariable
& >(
bn
.
variable
(*
it
)).
minVal
();
196
maxVal
197
+=
static_cast
<
const
RangeVariable
& >(
bn
.
variable
(*
it
)).
maxVal
();
198
}
199
200
newAggParents
.
push_back
(*
it
);
201
i
++;
202
}
203
}
else
{
204
newAggs
.
insert
(*
it
);
205
}
206
q
++;
207
}
208
209
if
(
newAgg
->
varType
() ==
VarType
::
Labelized
) {
210
addAggregator_
(
bn
,
aggType
, *
newAgg
,
p
->
domainSize
());
211
}
else
if
(
newAgg
->
varType
() ==
VarType
::
Range
) {
212
static_cast
<
RangeVariable
* >(
newAgg
)->
setMinVal
(
minVal
);
213
static_cast
<
RangeVariable
* >(
newAgg
)->
setMaxVal
(
maxVal
);
214
addAggregator_
(
bn
,
aggType
, *
newAgg
, 0);
215
}
else
{
216
GUM_ERROR
(
OperationNotAllowed
,
217
"Decomposition is not available for type : "
+
aggType
);
218
}
219
220
newAggs
.
insert
(
bn
.
idFromName
(
newName
));
221
222
for
(
NodeId
node
:
newAggParents
) {
223
bn
.
addArc
(
node
,
bn
.
idFromName
(
newName
));
224
}
225
226
Set
<
NodeId
>
final
=
addDepthLayer_
(
bn
,
newAggs
,
initialAggregator
,
j
);
227
228
for
(
auto
agg
:
final
) {
229
bn
.
addArc
(
agg
,
initialAggregator
);
230
}
231
232
delete
(
newAgg
);
233
return
bn
;
234
}
235
236
template
<
typename
GUM_SCALAR
>
237
Set
<
NodeId
>
AggregatorDecomposition
<
GUM_SCALAR
>::
addDepthLayer_
(
238
BayesNet
<
GUM_SCALAR
>&
bn
,
239
Set
<
NodeId
>
nodes
,
240
NodeId
initialAggregator
,
241
int
&
j
) {
242
auto
p
243
=
static_cast
<
const
gum
::
aggregator
::
MultiDimAggregator
<
GUM_SCALAR
>* >(
244
bn
.
cpt
(
initialAggregator
).
content
());
245
246
gum
::
Size
arity
=
getMaximumArity
();
247
std
::
string
aggType
=
p
->
aggregatorName
();
248
249
if
(
nodes
.
size
() <=
arity
) {
250
return
nodes
;
251
}
else
{
252
auto
newAgg
=
bn
.
variable
(
initialAggregator
).
clone
();
253
254
Set
<
NodeId
>
newAggs
=
Set
<
NodeId
>();
255
256
List
<
NodeId
>
newAggParents
;
257
258
std
::
list
<
NodeId
>
orderedParents
= {};
259
260
for
(
const
auto
&
elt
:
nodes
) {
261
orderedParents
.
push_back
(
elt
);
262
}
263
264
orderedParents
.
sort
();
265
266
gum
::
Size
i
= 0;
267
gum
::
Size
q
= 0;
268
long
minVal
= 0;
269
long
maxVal
= 0;
270
271
j
++;
272
273
std
::
string
newName
=
std
::
string
(
bn
.
variable
(
initialAggregator
).
name
())
274
+
"_"
+
std
::
to_string
(
j
);
275
276
newAgg
->
setName
(
newName
);
277
newAgg
->
setDescription
(
aggType
);
278
279
// for(Set<NodeId>::iterator it = nodes.begin(); it!= nodes.end(); ++it){
280
for
(
auto
it
=
orderedParents
.
begin
();
it
!=
orderedParents
.
end
(); ++
it
) {
281
if
(
q
<
nodes
.
size
() -
nodes
.
size
() %
arity
) {
282
if
(
i
==
arity
) {
283
i
= 0;
284
j
++;
285
286
if
(
newAgg
->
varType
() ==
VarType
::
Labelized
) {
287
addAggregator_
(
bn
,
aggType
, *
newAgg
,
p
->
domainSize
());
288
}
else
if
(
newAgg
->
varType
() ==
VarType
::
Range
) {
289
static_cast
<
RangeVariable
* >(
newAgg
)->
setMinVal
(
minVal
);
290
static_cast
<
RangeVariable
* >(
newAgg
)->
setMaxVal
(
maxVal
);
291
addAggregator_
(
bn
,
aggType
, *
newAgg
, 0);
292
}
else
{
293
GUM_ERROR
(
OperationNotAllowed
,
294
"Decomposition is not available for type : "
+
aggType
);
295
}
296
297
for
(
NodeId
node
:
newAggParents
) {
298
bn
.
addArc
(
node
,
bn
.
idFromName
(
newName
));
299
}
300
301
newAggs
.
insert
(
bn
.
idFromName
(
newName
));
302
303
newAggParents
.
clear
();
304
305
minVal
= 0;
306
maxVal
= 0;
307
308
newName
=
std
::
string
(
bn
.
variable
(
initialAggregator
).
name
()) +
"_"
309
+
std
::
to_string
(
j
);
310
311
delete
(
newAgg
);
312
newAgg
=
bn
.
variable
(
initialAggregator
).
clone
();
313
newAgg
->
setName
(
newName
);
314
newAgg
->
setDescription
(
aggType
);
315
316
if
(
bn
.
variable
(*
it
).
varType
() ==
VarType
::
Range
) {
317
minVal
318
+=
static_cast
<
const
RangeVariable
& >(
bn
.
variable
(*
it
)).
minVal
();
319
maxVal
320
+=
static_cast
<
const
RangeVariable
& >(
bn
.
variable
(*
it
)).
maxVal
();
321
}
322
323
newAggParents
.
push_back
(*
it
);
324
i
++;
325
}
else
{
326
if
(
bn
.
variable
(*
it
).
varType
() ==
VarType
::
Range
) {
327
minVal
328
+=
static_cast
<
const
RangeVariable
& >(
bn
.
variable
(*
it
)).
minVal
();
329
maxVal
330
+=
static_cast
<
const
RangeVariable
& >(
bn
.
variable
(*
it
)).
maxVal
();
331
}
332
333
newAggParents
.
push_back
(*
it
);
334
i
++;
335
}
336
}
else
{
337
newAggs
.
insert
(*
it
);
338
}
339
q
++;
340
}
341
342
if
(
newAgg
->
varType
() ==
VarType
::
Labelized
) {
343
addAggregator_
(
bn
,
aggType
, *
newAgg
,
p
->
domainSize
());
344
}
else
if
(
newAgg
->
varType
() ==
VarType
::
Range
) {
345
static_cast
<
RangeVariable
* >(
newAgg
)->
setMinVal
(
minVal
);
346
static_cast
<
RangeVariable
* >(
newAgg
)->
setMaxVal
(
maxVal
);
347
addAggregator_
(
bn
,
aggType
, *
newAgg
, 0);
348
}
else
{
349
GUM_ERROR
(
OperationNotAllowed
,
350
"Decomposition is not available for type : "
+
aggType
);
351
}
352
353
newAggs
.
insert
(
bn
.
idFromName
(
newName
));
354
355
for
(
NodeId
node
:
newAggParents
) {
356
bn
.
addArc
(
node
,
bn
.
idFromName
(
newName
));
357
}
358
359
delete
(
newAgg
);
360
return
addDepthLayer_
(
bn
,
newAggs
,
initialAggregator
,
j
);
361
}
362
}
363
364
365
template
<
typename
GUM_SCALAR
>
366
INLINE
void
367
AggregatorDecomposition
<
GUM_SCALAR
>::
setMaximumArity
(
gum
::
Size
arity
) {
368
if
(
arity
< 2) {
369
GUM_ERROR
(
OperationNotAllowed
,
"Maximum arity should be at least 2"
);
370
}
371
arity__
=
arity
;
372
}
373
374
template
<
typename
GUM_SCALAR
>
375
gum
::
Size
AggregatorDecomposition
<
GUM_SCALAR
>::
getMaximumArity
() {
376
return
arity__
;
377
}
378
379
template
<
typename
GUM_SCALAR
>
380
INLINE
std
::
string
AggregatorDecomposition
<
GUM_SCALAR
>::
name
()
const
{
381
return
"aggregator decomposition"
;
382
}
383
384
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669