aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
localSearchWithTabuList_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/** @file
 * @brief The local search with tabu list learning algorithm (for directed
 * graphs)
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */

#
include
<
agrum
/
BN
/
learning
/
paramUtils
/
DAG2BNLearner
.
h
>
30
#
include
<
agrum
/
BN
/
learning
/
structureUtils
/
graphChange
.
h
>
31
32
namespace
gum
{
33
34
namespace
learning
{
35
36
/// learns the structure of a Bayes net
37
template
<
typename
GRAPH_CHANGES_SELECTOR
>
38
DAG
LocalSearchWithTabuList
::
learnStructure
(
GRAPH_CHANGES_SELECTOR
&
selector
,
DAG
dag
) {
39
selector
.
setGraph
(
dag
);
40
41
unsigned
int
nb_changes_applied
= 0;
42
Idx
applied_change_with_positive_score
= 0;
43
Idx
current_N
= 0;
44
45
initApproximationScheme
();
46
47
// a vector that indicates which queues have valid scores, i.e., scores
48
// that were not invalidated by previously applied changes
49
std
::
vector
<
bool
>
impacted_queues
(
dag
.
size
(),
false
);
50
51
// the best dag found so far with its score
52
DAG
best_dag
=
dag
;
53
double
best_score
= 0;
54
double
current_score
= 0;
55
double
delta_score
= 0;
56
57
do
{
58
applied_change_with_positive_score
= 0;
59
delta_score
= 0;
60
61
std
::
vector
<
std
::
pair
<
NodeId
,
double
> >
ordered_queues
62
=
selector
.
nodesSortedByBestScore
();
63
64
for
(
Idx
j
= 0;
j
<
dag
.
size
(); ++
j
) {
65
NodeId
i
=
ordered_queues
[
j
].
first
;
66
67
if
(!
selector
.
empty
(
i
) && (!
nb_changes_applied
|| (
selector
.
bestScore
(
i
) > 0))) {
68
// pick up the best change
69
const
GraphChange
&
change
=
selector
.
bestChange
(
i
);
70
71
// perform the change
72
switch
(
change
.
type
()) {
73
case
GraphChangeType
::
ARC_ADDITION
:
74
if
(!
impacted_queues
[
change
.
node2
()] &&
selector
.
isChangeValid
(
change
)) {
75
if
(
selector
.
bestScore
(
i
) > 0) {
76
++
applied_change_with_positive_score
;
77
}
else
if
(
current_score
>
best_score
) {
78
best_score
=
current_score
;
79
best_dag
=
dag
;
80
}
81
82
// std::cout << "apply arc addition " << change.node1()
83
// << " -> " << change.node2()
84
// << " delta = " << selector.bestScore( i )
85
// << std::endl;
86
87
delta_score
+=
selector
.
bestScore
(
i
);
88
current_score
+=
selector
.
bestScore
(
i
);
89
dag
.
addArc
(
change
.
node1
(),
change
.
node2
());
90
impacted_queues
[
change
.
node2
()] =
true
;
91
selector
.
applyChangeWithoutScoreUpdate
(
change
);
92
++
nb_changes_applied
;
93
}
94
95
break
;
96
97
case
GraphChangeType
::
ARC_DELETION
:
98
if
(!
impacted_queues
[
change
.
node2
()] &&
selector
.
isChangeValid
(
change
)) {
99
if
(
selector
.
bestScore
(
i
) > 0) {
100
++
applied_change_with_positive_score
;
101
}
else
if
(
current_score
>
best_score
) {
102
best_score
=
current_score
;
103
best_dag
=
dag
;
104
}
105
106
// std::cout << "apply arc deletion " << change.node1()
107
// << " -> " << change.node2()
108
// << " delta = " << selector.bestScore( i )
109
// << std::endl;
110
111
delta_score
+=
selector
.
bestScore
(
i
);
112
current_score
+=
selector
.
bestScore
(
i
);
113
dag
.
eraseArc
(
Arc
(
change
.
node1
(),
change
.
node2
()));
114
impacted_queues
[
change
.
node2
()] =
true
;
115
selector
.
applyChangeWithoutScoreUpdate
(
change
);
116
++
nb_changes_applied
;
117
}
118
119
break
;
120
121
case
GraphChangeType
::
ARC_REVERSAL
:
122
if
((!
impacted_queues
[
change
.
node1
()]) && (!
impacted_queues
[
change
.
node2
()])
123
&&
selector
.
isChangeValid
(
change
)) {
124
if
(
selector
.
bestScore
(
i
) > 0) {
125
++
applied_change_with_positive_score
;
126
}
else
if
(
current_score
>
best_score
) {
127
best_score
=
current_score
;
128
best_dag
=
dag
;
129
}
130
131
// std::cout << "apply arc reversal " << change.node1()
132
// << " -> " << change.node2()
133
// << " delta = " << selector.bestScore( i )
134
// << std::endl;
135
136
delta_score
+=
selector
.
bestScore
(
i
);
137
current_score
+=
selector
.
bestScore
(
i
);
138
dag
.
eraseArc
(
Arc
(
change
.
node1
(),
change
.
node2
()));
139
dag
.
addArc
(
change
.
node2
(),
change
.
node1
());
140
impacted_queues
[
change
.
node1
()] =
true
;
141
impacted_queues
[
change
.
node2
()] =
true
;
142
selector
.
applyChangeWithoutScoreUpdate
(
change
);
143
++
nb_changes_applied
;
144
}
145
146
break
;
147
148
default
:
149
GUM_ERROR
(
OperationNotAllowed
,
150
"edge modifications are not "
151
"supported by local search"
);
152
}
153
154
break
;
155
}
156
}
157
158
selector
.
updateScoresAfterAppliedChanges
();
159
160
// reset the impacted queue and applied changes structures
161
for
(
auto
iter
=
impacted_queues
.
begin
();
iter
!=
impacted_queues
.
end
(); ++
iter
) {
162
*
iter
=
false
;
163
}
164
165
updateApproximationScheme
(
nb_changes_applied
);
166
167
// update current_N
168
if
(
applied_change_with_positive_score
) {
169
current_N
= 0;
170
nb_changes_applied
= 0;
171
}
else
{
172
++
current_N
;
173
}
174
175
// std::cout << "current N = " << current_N << std::endl;
176
}
while
((
current_N
<=
_MaxNbDecreasing_
) &&
continueApproximationScheme
(
delta_score
));
177
178
stopApproximationScheme
();
// just to be sure of the
179
// approximationScheme has
180
// been notified of the end of looop
181
182
if
(
current_score
>
best_score
) {
183
return
dag
;
184
}
else
{
185
return
best_dag
;
186
}
187
}
188
189
/// learns the structure and the parameters of a BN
190
template
<
typename
GUM_SCALAR
,
typename
GRAPH_CHANGES_SELECTOR
,
typename
PARAM_ESTIMATOR
>
191
BayesNet
<
GUM_SCALAR
>
LocalSearchWithTabuList
::
learnBN
(
GRAPH_CHANGES_SELECTOR
&
selector
,
192
PARAM_ESTIMATOR
&
estimator
,
193
DAG
initial_dag
) {
194
return
DAG2BNLearner
<>::
createBN
<
GUM_SCALAR
>(
estimator
,
195
learnStructure
(
selector
,
initial_dag
));
196
}
197
198
}
/* namespace learning */
199
200
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31