aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
localSearchWithTabuList_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/** @file
 * @brief The local search with tabu list learning algorithm (for directed
 * graphs)
 *
 * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
 */

#
include
<
agrum
/
BN
/
learning
/
paramUtils
/
DAG2BNLearner
.
h
>
30
#
include
<
agrum
/
BN
/
learning
/
structureUtils
/
graphChange
.
h
>
31
32
namespace
gum
{
33
34
namespace
learning
{
35
36
/// learns the structure of a Bayes net
37
template
<
typename
GRAPH_CHANGES_SELECTOR
>
38
DAG
LocalSearchWithTabuList
::
learnStructure
(
GRAPH_CHANGES_SELECTOR
&
selector
,
39
DAG
dag
) {
40
selector
.
setGraph
(
dag
);
41
42
unsigned
int
nb_changes_applied
= 0;
43
Idx
applied_change_with_positive_score
= 0;
44
Idx
current_N
= 0;
45
46
initApproximationScheme
();
47
48
// a vector that indicates which queues have valid scores, i.e., scores
49
// that were not invalidated by previously applied changes
50
std
::
vector
<
bool
>
impacted_queues
(
dag
.
size
(),
false
);
51
52
// the best dag found so far with its score
53
DAG
best_dag
=
dag
;
54
double
best_score
= 0;
55
double
current_score
= 0;
56
double
delta_score
= 0;
57
58
do
{
59
applied_change_with_positive_score
= 0;
60
delta_score
= 0;
61
62
std
::
vector
<
std
::
pair
<
NodeId
,
double
> >
ordered_queues
63
=
selector
.
nodesSortedByBestScore
();
64
65
for
(
Idx
j
= 0;
j
<
dag
.
size
(); ++
j
) {
66
NodeId
i
=
ordered_queues
[
j
].
first
;
67
68
if
(!
selector
.
empty
(
i
)
69
&& (!
nb_changes_applied
|| (
selector
.
bestScore
(
i
) > 0))) {
70
// pick up the best change
71
const
GraphChange
&
change
=
selector
.
bestChange
(
i
);
72
73
// perform the change
74
switch
(
change
.
type
()) {
75
case
GraphChangeType
::
ARC_ADDITION
:
76
if
(!
impacted_queues
[
change
.
node2
()]
77
&&
selector
.
isChangeValid
(
change
)) {
78
if
(
selector
.
bestScore
(
i
) > 0) {
79
++
applied_change_with_positive_score
;
80
}
else
if
(
current_score
>
best_score
) {
81
best_score
=
current_score
;
82
best_dag
=
dag
;
83
}
84
85
// std::cout << "apply arc addition " << change.node1()
86
// << " -> " << change.node2()
87
// << " delta = " << selector.bestScore( i )
88
// << std::endl;
89
90
delta_score
+=
selector
.
bestScore
(
i
);
91
current_score
+=
selector
.
bestScore
(
i
);
92
dag
.
addArc
(
change
.
node1
(),
change
.
node2
());
93
impacted_queues
[
change
.
node2
()] =
true
;
94
selector
.
applyChangeWithoutScoreUpdate
(
change
);
95
++
nb_changes_applied
;
96
}
97
98
break
;
99
100
case
GraphChangeType
::
ARC_DELETION
:
101
if
(!
impacted_queues
[
change
.
node2
()]
102
&&
selector
.
isChangeValid
(
change
)) {
103
if
(
selector
.
bestScore
(
i
) > 0) {
104
++
applied_change_with_positive_score
;
105
}
else
if
(
current_score
>
best_score
) {
106
best_score
=
current_score
;
107
best_dag
=
dag
;
108
}
109
110
// std::cout << "apply arc deletion " << change.node1()
111
// << " -> " << change.node2()
112
// << " delta = " << selector.bestScore( i )
113
// << std::endl;
114
115
delta_score
+=
selector
.
bestScore
(
i
);
116
current_score
+=
selector
.
bestScore
(
i
);
117
dag
.
eraseArc
(
Arc
(
change
.
node1
(),
change
.
node2
()));
118
impacted_queues
[
change
.
node2
()] =
true
;
119
selector
.
applyChangeWithoutScoreUpdate
(
change
);
120
++
nb_changes_applied
;
121
}
122
123
break
;
124
125
case
GraphChangeType
::
ARC_REVERSAL
:
126
if
((!
impacted_queues
[
change
.
node1
()])
127
&& (!
impacted_queues
[
change
.
node2
()])
128
&&
selector
.
isChangeValid
(
change
)) {
129
if
(
selector
.
bestScore
(
i
) > 0) {
130
++
applied_change_with_positive_score
;
131
}
else
if
(
current_score
>
best_score
) {
132
best_score
=
current_score
;
133
best_dag
=
dag
;
134
}
135
136
// std::cout << "apply arc reversal " << change.node1()
137
// << " -> " << change.node2()
138
// << " delta = " << selector.bestScore( i )
139
// << std::endl;
140
141
delta_score
+=
selector
.
bestScore
(
i
);
142
current_score
+=
selector
.
bestScore
(
i
);
143
dag
.
eraseArc
(
Arc
(
change
.
node1
(),
change
.
node2
()));
144
dag
.
addArc
(
change
.
node2
(),
change
.
node1
());
145
impacted_queues
[
change
.
node1
()] =
true
;
146
impacted_queues
[
change
.
node2
()] =
true
;
147
selector
.
applyChangeWithoutScoreUpdate
(
change
);
148
++
nb_changes_applied
;
149
}
150
151
break
;
152
153
default
:
154
GUM_ERROR
(
OperationNotAllowed
,
155
"edge modifications are not "
156
"supported by local search"
);
157
}
158
159
break
;
160
}
161
}
162
163
selector
.
updateScoresAfterAppliedChanges
();
164
165
// reset the impacted queue and applied changes structures
166
for
(
auto
iter
=
impacted_queues
.
begin
();
iter
!=
impacted_queues
.
end
();
167
++
iter
) {
168
*
iter
=
false
;
169
}
170
171
updateApproximationScheme
(
nb_changes_applied
);
172
173
// update current_N
174
if
(
applied_change_with_positive_score
) {
175
current_N
= 0;
176
nb_changes_applied
= 0;
177
}
else
{
178
++
current_N
;
179
}
180
181
// std::cout << "current N = " << current_N << std::endl;
182
}
while
((
current_N
<=
MaxNbDecreasing__
)
183
&&
continueApproximationScheme
(
delta_score
));
184
185
stopApproximationScheme
();
// just to be sure of the
186
// approximationScheme has
187
// been notified of the end of looop
188
189
if
(
current_score
>
best_score
) {
190
return
dag
;
191
}
else
{
192
return
best_dag
;
193
}
194
}
195
196
/// learns the structure and the parameters of a BN
197
template
<
typename
GUM_SCALAR
,
198
typename
GRAPH_CHANGES_SELECTOR
,
199
typename
PARAM_ESTIMATOR
>
200
BayesNet
<
GUM_SCALAR
>
201
LocalSearchWithTabuList
::
learnBN
(
GRAPH_CHANGES_SELECTOR
&
selector
,
202
PARAM_ESTIMATOR
&
estimator
,
203
DAG
initial_dag
) {
204
return
DAG2BNLearner
<>::
createBN
<
GUM_SCALAR
>(
205
estimator
,
206
learnStructure
(
selector
,
initial_dag
));
207
}
208
209
}
/* namespace learning */
210
211
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31