aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
searchStrategy_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Template implementation of the SearchStrategy class and of its
 *        subclasses FrequenceSearch, StrictSearch and TreeWidthSearch.
 *
 * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
 */
#
include
<
agrum
/
PRM
/
gspan
/
searchStrategy
.
h
>
29
30
namespace
gum
{
31
namespace
prm
{
32
namespace
gspan
{
33
34
template
<
typename
GUM_SCALAR >
35
double
SearchStrategy<
GUM_SCALAR
>::
computeCost_
(
const
Pattern
&
p
) {
36
double
cost
= 0;
37
const
Sequence
<
PRMInstance
<
GUM_SCALAR
>* >&
seq
38
= *(
this
->
tree_
->
data
(
p
).
iso_map
.
begin
().
val
());
39
Sequence
<
PRMClassElement
<
GUM_SCALAR
>* >
input_set
;
40
41
for
(
const
auto
inst
:
seq
) {
42
for
(
const
auto
input
:
inst
->
type
().
slotChains
())
43
for
(
const
auto
inst2
:
inst
->
getInstances
(
input
->
id
()))
44
if
((!
seq
.
exists
(
inst2
))
45
&& (!
input_set
.
exists
(
46
&(
inst2
->
get
(
input
->
lastElt
().
safeName
()))))) {
47
cost
+=
std
::
log
(
input
->
type
().
variable
().
domainSize
());
48
input_set
.
insert
(&(
inst2
->
get
(
input
->
lastElt
().
safeName
())));
49
}
50
51
for
(
auto
vec
=
inst
->
beginInvRef
();
vec
!=
inst
->
endInvRef
(); ++
vec
)
52
for
(
const
auto
inverse
: *
vec
.
val
())
53
if
(!
seq
.
exists
(
inverse
.
first
)) {
54
cost
+=
std
::
log
(
55
inst
->
get
(
vec
.
key
()).
type
().
variable
().
domainSize
());
56
break
;
57
}
58
}
59
60
return
cost
;
61
}
62
63
template
<
typename
GUM_SCALAR >
64
void
StrictSearch<
GUM_SCALAR
>::
buildPatternGraph__
(
65
typename
StrictSearch
<
GUM_SCALAR
>::
PData
&
data
,
66
Set
<
Potential
<
GUM_SCALAR
>* >&
pool
,
67
const
Sequence
<
PRMInstance
<
GUM_SCALAR
>* >&
match
) {
68
for
(
const
auto
inst
:
match
) {
69
for
(
const
auto
&
elt
: *
inst
) {
70
// Adding the node
71
NodeId
id
=
data
.
graph
.
addNode
();
72
data
.
node2attr
.
insert
(
id
,
str__
(
inst
,
elt
.
second
));
73
data
.
mod
.
insert
(
id
,
elt
.
second
->
type
()->
domainSize
());
74
data
.
vars
.
insert
(
id
, &
elt
.
second
->
type
().
variable
());
75
pool
.
insert
(
76
const_cast
<
Potential
<
GUM_SCALAR
>* >(&(
elt
.
second
->
cpf
())));
77
}
78
}
79
80
// Second we add edges and nodes to inners or outputs
81
for
(
const
auto
inst
:
match
)
82
for
(
const
auto
&
elt
: *
inst
) {
83
NodeId
node
=
data
.
node2attr
.
first
(
str__
(
inst
,
elt
.
second
));
84
bool
found
85
=
false
;
// If this is set at true, then node is an outer node
86
87
// Children existing in the instance type's DAG
88
for
(
const
auto
chld
:
89
inst
->
type
().
containerDag
().
children
(
elt
.
second
->
id
())) {
90
data
.
graph
.
addEdge
(
91
node
,
92
data
.
node2attr
.
first
(
str__
(
inst
,
inst
->
get
(
chld
))));
93
}
94
95
// Parents existing in the instance type's DAG
96
for
(
const
auto
par
:
97
inst
->
type
().
containerDag
().
parents
(
elt
.
second
->
id
())) {
98
switch
(
inst
->
type
().
get
(
par
).
elt_type
()) {
99
case
PRMClassElement
<
GUM_SCALAR
>::
prm_attribute
:
100
case
PRMClassElement
<
GUM_SCALAR
>::
prm_aggregate
: {
101
data
.
graph
.
addEdge
(
102
node
,
103
data
.
node2attr
.
first
(
str__
(
inst
,
inst
->
get
(
par
))));
104
break
;
105
}
106
107
case
PRMClassElement
<
GUM_SCALAR
>::
prm_slotchain
: {
108
for
(
const
auto
inst2
:
inst
->
getInstances
(
par
))
109
if
(
match
.
exists
(
inst2
))
110
data
.
graph
.
addEdge
(
111
node
,
112
data
.
node2attr
.
first
(
113
str__
(
inst2
,
114
static_cast
<
const
PRMSlotChain
<
GUM_SCALAR
>& >(
115
inst
->
type
().
get
(
par
)))));
116
117
break
;
118
}
119
120
default
: {
/* Do nothing */
121
}
122
}
123
}
124
125
// Referring PRMAttribute<GUM_SCALAR>
126
if
(
inst
->
hasRefAttr
(
elt
.
second
->
id
())) {
127
const
std
::
vector
<
128
std
::
pair
<
PRMInstance
<
GUM_SCALAR
>*,
std
::
string
> >&
ref_attr
129
=
inst
->
getRefAttr
(
elt
.
second
->
id
());
130
131
for
(
auto
pair
=
ref_attr
.
begin
();
pair
!=
ref_attr
.
end
(); ++
pair
) {
132
if
(
match
.
exists
(
pair
->
first
)) {
133
NodeId
id
=
pair
->
first
->
type
().
get
(
pair
->
second
).
id
();
134
135
for
(
const
auto
child
:
136
pair
->
first
->
type
().
containerDag
().
children
(
id
))
137
data
.
graph
.
addEdge
(
138
node
,
139
data
.
node2attr
.
first
(
140
str__
(
pair
->
first
,
pair
->
first
->
get
(
child
))));
141
}
else
{
142
found
=
true
;
143
}
144
}
145
}
146
147
if
(
found
)
148
data
.
outputs
.
insert
(
node
);
149
else
150
data
.
inners
.
insert
(
node
);
151
}
152
}
153
154
template
<
typename
GUM_SCALAR
>
155
std
::
pair
<
Size
,
Size
>
StrictSearch
<
GUM_SCALAR
>::
elimination_cost__
(
156
typename
StrictSearch
<
GUM_SCALAR
>::
PData
&
data
,
157
Set
<
Potential
<
GUM_SCALAR
>* >&
pool
) {
158
List
<
NodeSet
>
partial_order
;
159
160
if
(
data
.
inners
.
size
())
partial_order
.
insert
(
data
.
inners
);
161
162
if
(
data
.
outputs
.
size
())
partial_order
.
insert
(
data
.
outputs
);
163
164
PartialOrderedTriangulation
t
(&(
data
.
graph
), &(
data
.
mod
), &
partial_order
);
165
const
std
::
vector
<
NodeId
>&
elim_order
=
t
.
eliminationOrder
();
166
Size
max
(0),
max_count
(1);
167
Set
<
Potential
<
GUM_SCALAR
>* >
trash
;
168
Potential
<
GUM_SCALAR
>*
pot
= 0;
169
170
for
(
size_t
idx
= 0;
idx
<
data
.
inners
.
size
(); ++
idx
) {
171
pot
=
new
Potential
<
GUM_SCALAR
>(
new
MultiDimSparse
<
GUM_SCALAR
>(0));
172
pot
->
add
(*(
data
.
vars
.
second
(
elim_order
[
idx
])));
173
trash
.
insert
(
pot
);
174
Set
<
Potential
<
GUM_SCALAR
>* >
toRemove
;
175
176
for
(
const
auto
p
:
pool
)
177
if
(
p
->
contains
(*(
data
.
vars
.
second
(
elim_order
[
idx
])))) {
178
for
(
auto
var
=
p
->
variablesSequence
().
begin
();
179
var
!=
p
->
variablesSequence
().
end
();
180
++
var
) {
181
try
{
182
pot
->
add
(**
var
);
183
}
catch
(
DuplicateElement
&) {}
184
}
185
186
toRemove
.
insert
(
p
);
187
}
188
189
if
(
pot
->
domainSize
() >
max
) {
190
max
=
pot
->
domainSize
();
191
max_count
= 1;
192
}
else
if
(
pot
->
domainSize
() ==
max
) {
193
++
max_count
;
194
}
195
196
for
(
const
auto
p
:
toRemove
)
197
pool
.
erase
(
p
);
198
199
pot
->
erase
(*(
data
.
vars
.
second
(
elim_order
[
idx
])));
200
}
201
202
for
(
const
auto
pot
:
trash
)
203
delete
pot
;
204
205
return
std
::
make_pair
(
max
,
max_count
);
206
}
207
208
// The SearchStrategy class
209
template
<
typename
GUM_SCALAR
>
210
INLINE
SearchStrategy
<
GUM_SCALAR
>::
SearchStrategy
() :
tree_
(0) {
211
GUM_CONSTRUCTOR
(
SearchStrategy
);
212
}
213
214
template
<
typename
GUM_SCALAR
>
215
INLINE
SearchStrategy
<
GUM_SCALAR
>::
SearchStrategy
(
216
const
SearchStrategy
<
GUM_SCALAR
>&
from
) :
217
tree_
(
from
.
tree_
) {
218
GUM_CONS_CPY
(
SearchStrategy
);
219
}
220
221
template
<
typename
GUM_SCALAR
>
222
INLINE
SearchStrategy
<
GUM_SCALAR
>::~
SearchStrategy
() {
223
GUM_DESTRUCTOR
(
SearchStrategy
);
224
}
225
226
template
<
typename
GUM_SCALAR
>
227
INLINE
SearchStrategy
<
GUM_SCALAR
>&
SearchStrategy
<
GUM_SCALAR
>::
operator
=(
228
const
SearchStrategy
<
GUM_SCALAR
>&
from
) {
229
this
->
tree_
=
from
.
tree_
;
230
return
*
this
;
231
}
232
233
template
<
typename
GUM_SCALAR
>
234
INLINE
void
235
SearchStrategy
<
GUM_SCALAR
>::
setTree
(
DFSTree
<
GUM_SCALAR
>*
tree
) {
236
this
->
tree_
=
tree
;
237
}
238
239
// FrequenceSearch
240
241
// The FrequenceSearch class
242
template
<
typename
GUM_SCALAR
>
243
INLINE
FrequenceSearch
<
GUM_SCALAR
>::
FrequenceSearch
(
Size
freq
) :
244
SearchStrategy
<
GUM_SCALAR
>(),
freq__
(
freq
) {
245
GUM_CONSTRUCTOR
(
FrequenceSearch
);
246
}
247
248
template
<
typename
GUM_SCALAR
>
249
INLINE
FrequenceSearch
<
GUM_SCALAR
>::
FrequenceSearch
(
250
const
FrequenceSearch
<
GUM_SCALAR
>&
from
) :
251
SearchStrategy
<
GUM_SCALAR
>(
from
),
252
freq__
(
from
.
freq__
) {
253
GUM_CONS_CPY
(
FrequenceSearch
);
254
}
255
256
template
<
typename
GUM_SCALAR
>
257
INLINE
FrequenceSearch
<
GUM_SCALAR
>::~
FrequenceSearch
() {
258
GUM_DESTRUCTOR
(
FrequenceSearch
);
259
}
260
261
template
<
typename
GUM_SCALAR
>
262
INLINE
FrequenceSearch
<
GUM_SCALAR
>&
263
FrequenceSearch
<
GUM_SCALAR
>::
operator
=(
264
const
FrequenceSearch
<
GUM_SCALAR
>&
from
) {
265
freq__
=
from
.
freq__
;
266
return
*
this
;
267
}
268
269
template
<
typename
GUM_SCALAR
>
270
INLINE
bool
FrequenceSearch
<
GUM_SCALAR
>::
accept_root
(
const
Pattern
*
r
) {
271
return
this
->
tree_
->
frequency
(*
r
) >=
freq__
;
272
}
273
274
template
<
typename
GUM_SCALAR
>
275
INLINE
bool
FrequenceSearch
<
GUM_SCALAR
>::
accept_growth
(
276
const
Pattern
*
parent
,
277
const
Pattern
*
child
,
278
const
EdgeGrowth
<
GUM_SCALAR
>&
growh
) {
279
return
this
->
tree_
->
frequency
(*
child
) >=
freq__
;
280
}
281
282
template
<
typename
GUM_SCALAR
>
283
INLINE
bool
FrequenceSearch
<
GUM_SCALAR
>::
operator
()(
gspan
::
Pattern
*
i
,
284
gspan
::
Pattern
*
j
) {
285
// We want a descending order
286
return
this
->
tree_
->
frequency
(*
i
) >
this
->
tree_
->
frequency
(*
j
);
287
}
288
289
template
<
typename
GUM_SCALAR
>
290
INLINE
bool
FrequenceSearch
<
GUM_SCALAR
>::
operator
()(
LabelData
*
i
,
291
LabelData
*
j
) {
292
return
(
this
->
tree_
->
graph
().
size
(
i
) >
this
->
tree_
->
graph
().
size
(
j
));
293
}
294
295
// StrictSearch
296
297
// The StrictSearch class
298
template
<
typename
GUM_SCALAR
>
299
INLINE
StrictSearch
<
GUM_SCALAR
>::
StrictSearch
(
Size
freq
) :
300
SearchStrategy
<
GUM_SCALAR
>(),
freq__
(
freq
),
dot__
(
"."
) {
301
GUM_CONSTRUCTOR
(
StrictSearch
);
302
}
303
304
template
<
typename
GUM_SCALAR
>
305
INLINE
StrictSearch
<
GUM_SCALAR
>::
StrictSearch
(
306
const
StrictSearch
<
GUM_SCALAR
>&
from
) :
307
SearchStrategy
<
GUM_SCALAR
>(
from
),
308
freq__
(
from
.
freq__
) {
309
GUM_CONS_CPY
(
StrictSearch
);
310
}
311
312
template
<
typename
GUM_SCALAR
>
313
INLINE
StrictSearch
<
GUM_SCALAR
>::~
StrictSearch
() {
314
GUM_DESTRUCTOR
(
StrictSearch
);
315
}
316
317
template
<
typename
GUM_SCALAR
>
318
INLINE
StrictSearch
<
GUM_SCALAR
>&
StrictSearch
<
GUM_SCALAR
>::
operator
=(
319
const
StrictSearch
<
GUM_SCALAR
>&
from
) {
320
freq__
=
from
.
freq__
;
321
return
*
this
;
322
}
323
324
template
<
typename
GUM_SCALAR
>
325
INLINE
bool
StrictSearch
<
GUM_SCALAR
>::
accept_root
(
const
Pattern
*
r
) {
326
return
(
this
->
tree_
->
frequency
(*
r
) >=
freq__
);
327
}
328
329
template
<
typename
GUM_SCALAR
>
330
INLINE
bool
StrictSearch
<
GUM_SCALAR
>::
accept_growth
(
331
const
Pattern
*
parent
,
332
const
Pattern
*
child
,
333
const
EdgeGrowth
<
GUM_SCALAR
>&
growth
) {
334
return
inner_cost__
(
child
)
335
+
this
->
tree_
->
frequency
(*
child
) *
outer_cost__
(
child
)
336
<
this
->
tree_
->
frequency
(*
child
) *
outer_cost__
(
parent
);
337
}
338
339
template
<
typename
GUM_SCALAR
>
340
INLINE
bool
StrictSearch
<
GUM_SCALAR
>::
operator
()(
gspan
::
Pattern
*
i
,
341
gspan
::
Pattern
*
j
) {
342
return
inner_cost__
(
i
) +
this
->
tree_
->
frequency
(*
i
) *
outer_cost__
(
i
)
343
<
inner_cost__
(
j
) +
this
->
tree_
->
frequency
(*
j
) *
outer_cost__
(
j
);
344
}
345
346
template
<
typename
GUM_SCALAR
>
347
INLINE
bool
StrictSearch
<
GUM_SCALAR
>::
operator
()(
LabelData
*
i
,
348
LabelData
*
j
) {
349
return
i
->
tree_width
*
this
->
tree_
->
graph
().
size
(
i
)
350
<
j
->
tree_width
*
this
->
tree_
->
graph
().
size
(
j
);
351
}
352
353
template
<
typename
GUM_SCALAR
>
354
INLINE
double
StrictSearch
<
GUM_SCALAR
>::
inner_cost__
(
const
Pattern
*
p
) {
355
try
{
356
return
map__
[
p
].
first
;
357
}
catch
(
NotFound
&) {
358
compute_costs__
(
p
);
359
return
map__
[
p
].
first
;
360
}
361
}
362
363
template
<
typename
GUM_SCALAR
>
364
INLINE
double
StrictSearch
<
GUM_SCALAR
>::
outer_cost__
(
const
Pattern
*
p
) {
365
try
{
366
return
map__
[
p
].
second
;
367
}
catch
(
NotFound
&) {
368
compute_costs__
(
p
);
369
return
map__
[
p
].
second
;
370
}
371
}
372
373
template
<
typename
GUM_SCALAR
>
374
INLINE
std
::
string
StrictSearch
<
GUM_SCALAR
>::
str__
(
375
const
PRMInstance
<
GUM_SCALAR
>*
i
,
376
const
PRMAttribute
<
GUM_SCALAR
>*
a
)
const
{
377
return
i
->
name
() +
dot__
+
a
->
safeName
();
378
}
379
380
template
<
typename
GUM_SCALAR
>
381
INLINE
std
::
string
StrictSearch
<
GUM_SCALAR
>::
str__
(
382
const
PRMInstance
<
GUM_SCALAR
>*
i
,
383
const
PRMAttribute
<
GUM_SCALAR
>&
a
)
const
{
384
return
i
->
name
() +
dot__
+
a
.
safeName
();
385
}
386
387
template
<
typename
GUM_SCALAR
>
388
INLINE
std
::
string
StrictSearch
<
GUM_SCALAR
>::
str__
(
389
const
PRMInstance
<
GUM_SCALAR
>*
i
,
390
const
PRMSlotChain
<
GUM_SCALAR
>&
a
)
const
{
391
return
i
->
name
() +
dot__
+
a
.
lastElt
().
safeName
();
392
}
393
394
template
<
typename
GUM_SCALAR
>
395
INLINE
void
StrictSearch
<
GUM_SCALAR
>::
compute_costs__
(
const
Pattern
*
p
) {
396
typename
StrictSearch
<
GUM_SCALAR
>::
PData
data
;
397
Set
<
Potential
<
GUM_SCALAR
>* >
pool
;
398
buildPatternGraph__
(
data
,
399
pool
,
400
*(
this
->
tree_
->
data
(*
p
).
iso_map
.
begin
().
val
()));
401
double
inner
=
std
::
log
(
elimination_cost__
(
data
,
pool
).
first
);
402
double
outer
=
this
->
computeCost_
(*
p
);
403
map__
.
insert
(
p
,
std
::
make_pair
(
inner
,
outer
));
404
}
405
406
// TreeWidthSearch
407
408
template
<
typename
GUM_SCALAR
>
409
INLINE
TreeWidthSearch
<
GUM_SCALAR
>::
TreeWidthSearch
() :
410
SearchStrategy
<
GUM_SCALAR
>() {
411
GUM_CONSTRUCTOR
(
TreeWidthSearch
);
412
}
413
414
template
<
typename
GUM_SCALAR
>
415
INLINE
TreeWidthSearch
<
GUM_SCALAR
>::
TreeWidthSearch
(
416
const
TreeWidthSearch
<
GUM_SCALAR
>&
from
) :
417
SearchStrategy
<
GUM_SCALAR
>(
from
) {
418
GUM_CONS_CPY
(
TreeWidthSearch
);
419
}
420
421
template
<
typename
GUM_SCALAR
>
422
INLINE
TreeWidthSearch
<
GUM_SCALAR
>::~
TreeWidthSearch
() {
423
GUM_DESTRUCTOR
(
TreeWidthSearch
);
424
}
425
426
template
<
typename
GUM_SCALAR
>
427
INLINE
TreeWidthSearch
<
GUM_SCALAR
>&
428
TreeWidthSearch
<
GUM_SCALAR
>::
operator
=(
429
const
TreeWidthSearch
<
GUM_SCALAR
>&
from
) {
430
return
*
this
;
431
}
432
433
template
<
typename
GUM_SCALAR
>
434
INLINE
double
TreeWidthSearch
<
GUM_SCALAR
>::
cost
(
const
Pattern
&
p
) {
435
try
{
436
return
map__
[&
p
];
437
}
catch
(
NotFound
&) {
438
map__
.
insert
(&
p
,
this
->
computeCost_
(
p
));
439
return
map__
[&
p
];
440
}
441
}
442
443
template
<
typename
GUM_SCALAR
>
444
INLINE
bool
TreeWidthSearch
<
GUM_SCALAR
>::
accept_root
(
const
Pattern
*
r
) {
445
Size
tree_width
= 0;
446
447
for
(
const
auto
n
:
r
->
nodes
())
448
tree_width
+=
r
->
label
(
n
).
tree_width
;
449
450
return
tree_width
>=
cost
(*
r
);
451
}
452
453
template
<
typename
GUM_SCALAR
>
454
INLINE
bool
TreeWidthSearch
<
GUM_SCALAR
>::
accept_growth
(
455
const
Pattern
*
parent
,
456
const
Pattern
*
child
,
457
const
EdgeGrowth
<
GUM_SCALAR
>&
growth
) {
458
return
cost
(*
parent
) >=
cost
(*
child
);
459
}
460
461
template
<
typename
GUM_SCALAR
>
462
INLINE
bool
TreeWidthSearch
<
GUM_SCALAR
>::
operator
()(
gspan
::
Pattern
*
i
,
463
gspan
::
Pattern
*
j
) {
464
return
cost
(*
i
) <
cost
(*
j
);
465
}
466
467
template
<
typename
GUM_SCALAR
>
468
INLINE
bool
TreeWidthSearch
<
GUM_SCALAR
>::
operator
()(
LabelData
*
i
,
469
LabelData
*
j
) {
470
return
i
->
tree_width
<
j
->
tree_width
;
471
}
472
473
}
/* namespace gspan */
474
}
/* namespace prm */
475
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::prm::ParamScopeData::ParamScopeData
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)
Definition:
PRMClass_tpl.h:1101
gum::prm::gspan::operator<<
INLINE std::ostream & operator<<(std::ostream &out, const EdgeData< GUM_SCALAR > &data)
Print a EdgeData<GUM_SCALAR> in out.
Definition:
interfaceGraph_tpl.h:393