aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
searchStrategy_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Inline implementation of the SearchStrategy class.
25
*
26
* @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27
*/
28
#
include
<
agrum
/
PRM
/
gspan
/
searchStrategy
.
h
>
29
30
namespace
gum
{
31
namespace
prm
{
32
namespace
gspan
{
33
34
template
<
typename
GUM_SCALAR >
35
double
SearchStrategy<
GUM_SCALAR
>::
computeCost_
(
const
Pattern
&
p
) {
36
double
cost
= 0;
37
const
Sequence
<
PRMInstance
<
GUM_SCALAR
>* >&
seq
38
= *(
this
->
tree_
->
data
(
p
).
iso_map
.
begin
().
val
());
39
Sequence
<
PRMClassElement
<
GUM_SCALAR
>* >
input_set
;
40
41
for
(
const
auto
inst
:
seq
) {
42
for
(
const
auto
input
:
inst
->
type
().
slotChains
())
43
for
(
const
auto
inst2
:
inst
->
getInstances
(
input
->
id
()))
44
if
((!
seq
.
exists
(
inst2
))
45
&& (!
input_set
.
exists
(&(
inst2
->
get
(
input
->
lastElt
().
safeName
()))))) {
46
cost
+=
std
::
log
(
input
->
type
().
variable
().
domainSize
());
47
input_set
.
insert
(&(
inst2
->
get
(
input
->
lastElt
().
safeName
())));
48
}
49
50
for
(
auto
vec
=
inst
->
beginInvRef
();
vec
!=
inst
->
endInvRef
(); ++
vec
)
51
for
(
const
auto
inverse
: *
vec
.
val
())
52
if
(!
seq
.
exists
(
inverse
.
first
)) {
53
cost
+=
std
::
log
(
inst
->
get
(
vec
.
key
()).
type
().
variable
().
domainSize
());
54
break
;
55
}
56
}
57
58
return
cost
;
59
}
60
61
template
<
typename
GUM_SCALAR >
62
void
StrictSearch<
GUM_SCALAR
>::
_buildPatternGraph_
(
63
typename
StrictSearch
<
GUM_SCALAR
>::
PData
&
data
,
64
Set
<
Potential
<
GUM_SCALAR
>* >&
pool
,
65
const
Sequence
<
PRMInstance
<
GUM_SCALAR
>* >&
match
) {
66
for
(
const
auto
inst
:
match
) {
67
for
(
const
auto
&
elt
: *
inst
) {
68
// Adding the node
69
NodeId
id
=
data
.
graph
.
addNode
();
70
data
.
node2attr
.
insert
(
id
,
_str_
(
inst
,
elt
.
second
));
71
data
.
mod
.
insert
(
id
,
elt
.
second
->
type
()->
domainSize
());
72
data
.
vars
.
insert
(
id
, &
elt
.
second
->
type
().
variable
());
73
pool
.
insert
(
const_cast
<
Potential
<
GUM_SCALAR
>* >(&(
elt
.
second
->
cpf
())));
74
}
75
}
76
77
// Second we add edges and nodes to inners or outputs
78
for
(
const
auto
inst
:
match
)
79
for
(
const
auto
&
elt
: *
inst
) {
80
NodeId
node
=
data
.
node2attr
.
first
(
_str_
(
inst
,
elt
.
second
));
81
bool
found
=
false
;
// If this is set at true, then node is an outer node
82
83
// Children existing in the instance type's DAG
84
for
(
const
auto
chld
:
inst
->
type
().
containerDag
().
children
(
elt
.
second
->
id
())) {
85
data
.
graph
.
addEdge
(
node
,
data
.
node2attr
.
first
(
_str_
(
inst
,
inst
->
get
(
chld
))));
86
}
87
88
// Parents existing in the instance type's DAG
89
for
(
const
auto
par
:
inst
->
type
().
containerDag
().
parents
(
elt
.
second
->
id
())) {
90
switch
(
inst
->
type
().
get
(
par
).
elt_type
()) {
91
case
PRMClassElement
<
GUM_SCALAR
>::
prm_attribute
:
92
case
PRMClassElement
<
GUM_SCALAR
>::
prm_aggregate
: {
93
data
.
graph
.
addEdge
(
node
,
data
.
node2attr
.
first
(
_str_
(
inst
,
inst
->
get
(
par
))));
94
break
;
95
}
96
97
case
PRMClassElement
<
GUM_SCALAR
>::
prm_slotchain
: {
98
for
(
const
auto
inst2
:
inst
->
getInstances
(
par
))
99
if
(
match
.
exists
(
inst2
))
100
data
.
graph
.
addEdge
(
node
,
101
data
.
node2attr
.
first
(
102
_str_
(
inst2
,
103
static_cast
<
const
PRMSlotChain
<
GUM_SCALAR
>& >(
104
inst
->
type
().
get
(
par
)))));
105
106
break
;
107
}
108
109
default
: {
/* Do nothing */
110
}
111
}
112
}
113
114
// Referring PRMAttribute<GUM_SCALAR>
115
if
(
inst
->
hasRefAttr
(
elt
.
second
->
id
())) {
116
const
std
::
vector
<
std
::
pair
<
PRMInstance
<
GUM_SCALAR
>*,
std
::
string
> >&
ref_attr
117
=
inst
->
getRefAttr
(
elt
.
second
->
id
());
118
119
for
(
auto
pair
=
ref_attr
.
begin
();
pair
!=
ref_attr
.
end
(); ++
pair
) {
120
if
(
match
.
exists
(
pair
->
first
)) {
121
NodeId
id
=
pair
->
first
->
type
().
get
(
pair
->
second
).
id
();
122
123
for
(
const
auto
child
:
pair
->
first
->
type
().
containerDag
().
children
(
id
))
124
data
.
graph
.
addEdge
(
125
node
,
126
data
.
node2attr
.
first
(
_str_
(
pair
->
first
,
pair
->
first
->
get
(
child
))));
127
}
else
{
128
found
=
true
;
129
}
130
}
131
}
132
133
if
(
found
)
134
data
.
outputs
.
insert
(
node
);
135
else
136
data
.
inners
.
insert
(
node
);
137
}
138
}
139
140
template
<
typename
GUM_SCALAR
>
141
std
::
pair
<
Size
,
Size
>
StrictSearch
<
GUM_SCALAR
>::
_elimination_cost_
(
142
typename
StrictSearch
<
GUM_SCALAR
>::
PData
&
data
,
143
Set
<
Potential
<
GUM_SCALAR
>* >&
pool
) {
144
List
<
NodeSet
>
partial_order
;
145
146
if
(
data
.
inners
.
size
())
partial_order
.
insert
(
data
.
inners
);
147
148
if
(
data
.
outputs
.
size
())
partial_order
.
insert
(
data
.
outputs
);
149
150
PartialOrderedTriangulation
t
(&(
data
.
graph
), &(
data
.
mod
), &
partial_order
);
151
const
std
::
vector
<
NodeId
>&
elim_order
=
t
.
eliminationOrder
();
152
Size
max
(0),
max_count
(1);
153
Set
<
Potential
<
GUM_SCALAR
>* >
trash
;
154
Potential
<
GUM_SCALAR
>*
pot
= 0;
155
156
for
(
size_t
idx
= 0;
idx
<
data
.
inners
.
size
(); ++
idx
) {
157
pot
=
new
Potential
<
GUM_SCALAR
>(
new
MultiDimSparse
<
GUM_SCALAR
>(0));
158
pot
->
add
(*(
data
.
vars
.
second
(
elim_order
[
idx
])));
159
trash
.
insert
(
pot
);
160
Set
<
Potential
<
GUM_SCALAR
>* >
toRemove
;
161
162
for
(
const
auto
p
:
pool
)
163
if
(
p
->
contains
(*(
data
.
vars
.
second
(
elim_order
[
idx
])))) {
164
for
(
auto
var
=
p
->
variablesSequence
().
begin
();
var
!=
p
->
variablesSequence
().
end
();
165
++
var
) {
166
try
{
167
pot
->
add
(**
var
);
168
}
catch
(
DuplicateElement
&) {}
169
}
170
171
toRemove
.
insert
(
p
);
172
}
173
174
if
(
pot
->
domainSize
() >
max
) {
175
max
=
pot
->
domainSize
();
176
max_count
= 1;
177
}
else
if
(
pot
->
domainSize
() ==
max
) {
178
++
max_count
;
179
}
180
181
for
(
const
auto
p
:
toRemove
)
182
pool
.
erase
(
p
);
183
184
pot
->
erase
(*(
data
.
vars
.
second
(
elim_order
[
idx
])));
185
}
186
187
for
(
const
auto
pot
:
trash
)
188
delete
pot
;
189
190
return
std
::
make_pair
(
max
,
max_count
);
191
}
192
193
// The SearchStrategy class
194
template
<
typename
GUM_SCALAR
>
195
INLINE
SearchStrategy
<
GUM_SCALAR
>::
SearchStrategy
() :
tree_
(0) {
196
GUM_CONSTRUCTOR
(
SearchStrategy
);
197
}
198
199
template
<
typename
GUM_SCALAR
>
200
INLINE
201
SearchStrategy
<
GUM_SCALAR
>::
SearchStrategy
(
const
SearchStrategy
<
GUM_SCALAR
>&
from
) :
202
tree_
(
from
.
tree_
) {
203
GUM_CONS_CPY
(
SearchStrategy
);
204
}
205
206
template
<
typename
GUM_SCALAR
>
207
INLINE
SearchStrategy
<
GUM_SCALAR
>::~
SearchStrategy
() {
208
GUM_DESTRUCTOR
(
SearchStrategy
);
209
}
210
211
template
<
typename
GUM_SCALAR
>
212
INLINE
SearchStrategy
<
GUM_SCALAR
>&
213
SearchStrategy
<
GUM_SCALAR
>::
operator
=(
const
SearchStrategy
<
GUM_SCALAR
>&
from
) {
214
this
->
tree_
=
from
.
tree_
;
215
return
*
this
;
216
}
217
218
template
<
typename
GUM_SCALAR
>
219
INLINE
void
SearchStrategy
<
GUM_SCALAR
>::
setTree
(
DFSTree
<
GUM_SCALAR
>*
tree
) {
220
this
->
tree_
=
tree
;
221
}
222
223
// FrequenceSearch
224
225
// The FrequenceSearch class
226
template
<
typename
GUM_SCALAR
>
227
INLINE
FrequenceSearch
<
GUM_SCALAR
>::
FrequenceSearch
(
Size
freq
) :
228
SearchStrategy
<
GUM_SCALAR
>(),
_freq_
(
freq
) {
229
GUM_CONSTRUCTOR
(
FrequenceSearch
);
230
}
231
232
template
<
typename
GUM_SCALAR
>
233
INLINE
234
FrequenceSearch
<
GUM_SCALAR
>::
FrequenceSearch
(
const
FrequenceSearch
<
GUM_SCALAR
>&
from
) :
235
SearchStrategy
<
GUM_SCALAR
>(
from
),
236
_freq_
(
from
.
_freq_
) {
237
GUM_CONS_CPY
(
FrequenceSearch
);
238
}
239
240
template
<
typename
GUM_SCALAR
>
241
INLINE
FrequenceSearch
<
GUM_SCALAR
>::~
FrequenceSearch
() {
242
GUM_DESTRUCTOR
(
FrequenceSearch
);
243
}
244
245
template
<
typename
GUM_SCALAR
>
246
INLINE
FrequenceSearch
<
GUM_SCALAR
>&
247
FrequenceSearch
<
GUM_SCALAR
>::
operator
=(
const
FrequenceSearch
<
GUM_SCALAR
>&
from
) {
248
_freq_
=
from
.
_freq_
;
249
return
*
this
;
250
}
251
252
template
<
typename
GUM_SCALAR
>
253
INLINE
bool
FrequenceSearch
<
GUM_SCALAR
>::
accept_root
(
const
Pattern
*
r
) {
254
return
this
->
tree_
->
frequency
(*
r
) >=
_freq_
;
255
}
256
257
template
<
typename
GUM_SCALAR
>
258
INLINE
bool
259
FrequenceSearch
<
GUM_SCALAR
>::
accept_growth
(
const
Pattern
*
parent
,
260
const
Pattern
*
child
,
261
const
EdgeGrowth
<
GUM_SCALAR
>&
growh
) {
262
return
this
->
tree_
->
frequency
(*
child
) >=
_freq_
;
263
}
264
265
template
<
typename
GUM_SCALAR
>
266
INLINE
bool
FrequenceSearch
<
GUM_SCALAR
>::
operator
()(
gspan
::
Pattern
*
i
,
gspan
::
Pattern
*
j
) {
267
// We want a descending order
268
return
this
->
tree_
->
frequency
(*
i
) >
this
->
tree_
->
frequency
(*
j
);
269
}
270
271
template
<
typename
GUM_SCALAR
>
272
INLINE
bool
FrequenceSearch
<
GUM_SCALAR
>::
operator
()(
LabelData
*
i
,
LabelData
*
j
) {
273
return
(
this
->
tree_
->
graph
().
size
(
i
) >
this
->
tree_
->
graph
().
size
(
j
));
274
}
275
276
// StrictSearch
277
278
// The StrictSearch class
279
template
<
typename
GUM_SCALAR
>
280
INLINE
StrictSearch
<
GUM_SCALAR
>::
StrictSearch
(
Size
freq
) :
281
SearchStrategy
<
GUM_SCALAR
>(),
_freq_
(
freq
),
_dot_
(
"."
) {
282
GUM_CONSTRUCTOR
(
StrictSearch
);
283
}
284
285
template
<
typename
GUM_SCALAR
>
286
INLINE
StrictSearch
<
GUM_SCALAR
>::
StrictSearch
(
const
StrictSearch
<
GUM_SCALAR
>&
from
) :
287
SearchStrategy
<
GUM_SCALAR
>(
from
),
_freq_
(
from
.
_freq_
) {
288
GUM_CONS_CPY
(
StrictSearch
);
289
}
290
291
template
<
typename
GUM_SCALAR
>
292
INLINE
StrictSearch
<
GUM_SCALAR
>::~
StrictSearch
() {
293
GUM_DESTRUCTOR
(
StrictSearch
);
294
}
295
296
template
<
typename
GUM_SCALAR
>
297
INLINE
StrictSearch
<
GUM_SCALAR
>&
298
StrictSearch
<
GUM_SCALAR
>::
operator
=(
const
StrictSearch
<
GUM_SCALAR
>&
from
) {
299
_freq_
=
from
.
_freq_
;
300
return
*
this
;
301
}
302
303
template
<
typename
GUM_SCALAR
>
304
INLINE
bool
StrictSearch
<
GUM_SCALAR
>::
accept_root
(
const
Pattern
*
r
) {
305
return
(
this
->
tree_
->
frequency
(*
r
) >=
_freq_
);
306
}
307
308
template
<
typename
GUM_SCALAR
>
309
INLINE
bool
310
StrictSearch
<
GUM_SCALAR
>::
accept_growth
(
const
Pattern
*
parent
,
311
const
Pattern
*
child
,
312
const
EdgeGrowth
<
GUM_SCALAR
>&
growth
) {
313
return
_inner_cost_
(
child
) +
this
->
tree_
->
frequency
(*
child
) *
_outer_cost_
(
child
)
314
<
this
->
tree_
->
frequency
(*
child
) *
_outer_cost_
(
parent
);
315
}
316
317
template
<
typename
GUM_SCALAR
>
318
INLINE
bool
StrictSearch
<
GUM_SCALAR
>::
operator
()(
gspan
::
Pattern
*
i
,
gspan
::
Pattern
*
j
) {
319
return
_inner_cost_
(
i
) +
this
->
tree_
->
frequency
(*
i
) *
_outer_cost_
(
i
)
320
<
_inner_cost_
(
j
) +
this
->
tree_
->
frequency
(*
j
) *
_outer_cost_
(
j
);
321
}
322
323
template
<
typename
GUM_SCALAR
>
324
INLINE
bool
StrictSearch
<
GUM_SCALAR
>::
operator
()(
LabelData
*
i
,
LabelData
*
j
) {
325
return
i
->
tree_width
*
this
->
tree_
->
graph
().
size
(
i
)
326
<
j
->
tree_width
*
this
->
tree_
->
graph
().
size
(
j
);
327
}
328
329
template
<
typename
GUM_SCALAR
>
330
INLINE
double
StrictSearch
<
GUM_SCALAR
>::
_inner_cost_
(
const
Pattern
*
p
) {
331
try
{
332
return
_map_
[
p
].
first
;
333
}
catch
(
NotFound
&) {
334
_compute_costs_
(
p
);
335
return
_map_
[
p
].
first
;
336
}
337
}
338
339
template
<
typename
GUM_SCALAR
>
340
INLINE
double
StrictSearch
<
GUM_SCALAR
>::
_outer_cost_
(
const
Pattern
*
p
) {
341
try
{
342
return
_map_
[
p
].
second
;
343
}
catch
(
NotFound
&) {
344
_compute_costs_
(
p
);
345
return
_map_
[
p
].
second
;
346
}
347
}
348
349
template
<
typename
GUM_SCALAR
>
350
INLINE
std
::
string
351
StrictSearch
<
GUM_SCALAR
>::
_str_
(
const
PRMInstance
<
GUM_SCALAR
>*
i
,
352
const
PRMAttribute
<
GUM_SCALAR
>*
a
)
const
{
353
return
i
->
name
() +
_dot_
+
a
->
safeName
();
354
}
355
356
template
<
typename
GUM_SCALAR
>
357
INLINE
std
::
string
358
StrictSearch
<
GUM_SCALAR
>::
_str_
(
const
PRMInstance
<
GUM_SCALAR
>*
i
,
359
const
PRMAttribute
<
GUM_SCALAR
>&
a
)
const
{
360
return
i
->
name
() +
_dot_
+
a
.
safeName
();
361
}
362
363
template
<
typename
GUM_SCALAR
>
364
INLINE
std
::
string
365
StrictSearch
<
GUM_SCALAR
>::
_str_
(
const
PRMInstance
<
GUM_SCALAR
>*
i
,
366
const
PRMSlotChain
<
GUM_SCALAR
>&
a
)
const
{
367
return
i
->
name
() +
_dot_
+
a
.
lastElt
().
safeName
();
368
}
369
370
template
<
typename
GUM_SCALAR
>
371
INLINE
void
StrictSearch
<
GUM_SCALAR
>::
_compute_costs_
(
const
Pattern
*
p
) {
372
typename
StrictSearch
<
GUM_SCALAR
>::
PData
data
;
373
Set
<
Potential
<
GUM_SCALAR
>* >
pool
;
374
_buildPatternGraph_
(
data
,
pool
, *(
this
->
tree_
->
data
(*
p
).
iso_map
.
begin
().
val
()));
375
double
inner
=
std
::
log
(
_elimination_cost_
(
data
,
pool
).
first
);
376
double
outer
=
this
->
computeCost_
(*
p
);
377
_map_
.
insert
(
p
,
std
::
make_pair
(
inner
,
outer
));
378
}
379
380
// TreeWidthSearch
381
382
template
<
typename
GUM_SCALAR
>
383
INLINE
TreeWidthSearch
<
GUM_SCALAR
>::
TreeWidthSearch
() :
SearchStrategy
<
GUM_SCALAR
>() {
384
GUM_CONSTRUCTOR
(
TreeWidthSearch
);
385
}
386
387
template
<
typename
GUM_SCALAR
>
388
INLINE
389
TreeWidthSearch
<
GUM_SCALAR
>::
TreeWidthSearch
(
const
TreeWidthSearch
<
GUM_SCALAR
>&
from
) :
390
SearchStrategy
<
GUM_SCALAR
>(
from
) {
391
GUM_CONS_CPY
(
TreeWidthSearch
);
392
}
393
394
template
<
typename
GUM_SCALAR
>
395
INLINE
TreeWidthSearch
<
GUM_SCALAR
>::~
TreeWidthSearch
() {
396
GUM_DESTRUCTOR
(
TreeWidthSearch
);
397
}
398
399
template
<
typename
GUM_SCALAR
>
400
INLINE
TreeWidthSearch
<
GUM_SCALAR
>&
401
TreeWidthSearch
<
GUM_SCALAR
>::
operator
=(
const
TreeWidthSearch
<
GUM_SCALAR
>&
from
) {
402
return
*
this
;
403
}
404
405
template
<
typename
GUM_SCALAR
>
406
INLINE
double
TreeWidthSearch
<
GUM_SCALAR
>::
cost
(
const
Pattern
&
p
) {
407
try
{
408
return
_map_
[&
p
];
409
}
catch
(
NotFound
&) {
410
_map_
.
insert
(&
p
,
this
->
computeCost_
(
p
));
411
return
_map_
[&
p
];
412
}
413
}
414
415
template
<
typename
GUM_SCALAR
>
416
INLINE
bool
TreeWidthSearch
<
GUM_SCALAR
>::
accept_root
(
const
Pattern
*
r
) {
417
Size
tree_width
= 0;
418
419
for
(
const
auto
n
:
r
->
nodes
())
420
tree_width
+=
r
->
label
(
n
).
tree_width
;
421
422
return
tree_width
>=
cost
(*
r
);
423
}
424
425
template
<
typename
GUM_SCALAR
>
426
INLINE
bool
427
TreeWidthSearch
<
GUM_SCALAR
>::
accept_growth
(
const
Pattern
*
parent
,
428
const
Pattern
*
child
,
429
const
EdgeGrowth
<
GUM_SCALAR
>&
growth
) {
430
return
cost
(*
parent
) >=
cost
(*
child
);
431
}
432
433
template
<
typename
GUM_SCALAR
>
434
INLINE
bool
TreeWidthSearch
<
GUM_SCALAR
>::
operator
()(
gspan
::
Pattern
*
i
,
gspan
::
Pattern
*
j
) {
435
return
cost
(*
i
) <
cost
(*
j
);
436
}
437
438
template
<
typename
GUM_SCALAR
>
439
INLINE
bool
TreeWidthSearch
<
GUM_SCALAR
>::
operator
()(
LabelData
*
i
,
LabelData
*
j
) {
440
return
i
->
tree_width
<
j
->
tree_width
;
441
}
442
443
}
/* namespace gspan */
444
}
/* namespace prm */
445
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::prm::ParamScopeData::ParamScopeData
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)
Definition:
PRMClass_tpl.h:1032
gum::prm::gspan::operator<<
INLINE std::ostream & operator<<(std::ostream &out, const EdgeData< GUM_SCALAR > &data)
Print a EdgeData<GUM_SCALAR> in out.
Definition:
interfaceGraph_tpl.h:370