aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
greedyHillClimbing_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/** @file
23
* @brief The greedy hill climbing learning algorithm (for directed graphs)
24
*
25
* @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
26
*/
27
28
#
include
<
agrum
/
BN
/
learning
/
paramUtils
/
DAG2BNLearner
.
h
>
29
#
include
<
agrum
/
BN
/
learning
/
structureUtils
/
graphChange
.
h
>
30
31
namespace
gum
{
32
33
namespace
learning
{
34
35
/// learns the structure of a Bayes net
36
template
<
typename
GRAPH_CHANGES_SELECTOR
>
37
DAG
GreedyHillClimbing
::
learnStructure
(
GRAPH_CHANGES_SELECTOR
&
selector
,
38
DAG
dag
) {
39
selector
.
setGraph
(
dag
);
40
41
unsigned
int
nb_changes_applied
= 1;
42
double
delta_score
;
43
44
initApproximationScheme
();
45
46
// a vector that indicates which queues have valid scores, i.e., scores
47
// that were not invalidated by previously applied changes
48
std
::
vector
<
bool
>
impacted_queues
(
dag
.
size
(),
false
);
49
50
do
{
51
nb_changes_applied
= 0;
52
delta_score
= 0;
53
54
std
::
vector
<
std
::
pair
<
NodeId
,
double
> >
ordered_queues
55
=
selector
.
nodesSortedByBestScore
();
56
57
for
(
Idx
j
= 0;
j
<
dag
.
size
(); ++
j
) {
58
Idx
i
=
ordered_queues
[
j
].
first
;
59
60
if
(!(
selector
.
empty
(
i
)) && (
selector
.
bestScore
(
i
) > 0)) {
61
// pick up the best change
62
const
GraphChange
&
change
=
selector
.
bestChange
(
i
);
63
64
// perform the change
65
switch
(
change
.
type
()) {
66
case
GraphChangeType
::
ARC_ADDITION
:
67
if
(!
impacted_queues
[
change
.
node2
()]
68
&&
selector
.
isChangeValid
(
change
)) {
69
delta_score
+=
selector
.
bestScore
(
i
);
70
dag
.
addArc
(
change
.
node1
(),
change
.
node2
());
71
impacted_queues
[
change
.
node2
()] =
true
;
72
selector
.
applyChangeWithoutScoreUpdate
(
change
);
73
++
nb_changes_applied
;
74
}
75
76
break
;
77
78
case
GraphChangeType
::
ARC_DELETION
:
79
if
(!
impacted_queues
[
change
.
node2
()]
80
&&
selector
.
isChangeValid
(
change
)) {
81
delta_score
+=
selector
.
bestScore
(
i
);
82
dag
.
eraseArc
(
Arc
(
change
.
node1
(),
change
.
node2
()));
83
impacted_queues
[
change
.
node2
()] =
true
;
84
selector
.
applyChangeWithoutScoreUpdate
(
change
);
85
++
nb_changes_applied
;
86
}
87
88
break
;
89
90
case
GraphChangeType
::
ARC_REVERSAL
:
91
if
((!
impacted_queues
[
change
.
node1
()])
92
&& (!
impacted_queues
[
change
.
node2
()])
93
&&
selector
.
isChangeValid
(
change
)) {
94
delta_score
+=
selector
.
bestScore
(
i
);
95
dag
.
eraseArc
(
Arc
(
change
.
node1
(),
change
.
node2
()));
96
dag
.
addArc
(
change
.
node2
(),
change
.
node1
());
97
impacted_queues
[
change
.
node1
()] =
true
;
98
impacted_queues
[
change
.
node2
()] =
true
;
99
selector
.
applyChangeWithoutScoreUpdate
(
change
);
100
++
nb_changes_applied
;
101
}
102
103
break
;
104
105
default
:
106
GUM_ERROR
(
OperationNotAllowed
,
107
"edge modifications are not supported by local search"
);
108
}
109
}
110
}
111
112
selector
.
updateScoresAfterAppliedChanges
();
113
114
// reset the impacted queue and applied changes structures
115
for
(
auto
iter
=
impacted_queues
.
begin
();
iter
!=
impacted_queues
.
end
();
116
++
iter
) {
117
*
iter
=
false
;
118
}
119
120
updateApproximationScheme
(
nb_changes_applied
);
121
122
}
while
(
nb_changes_applied
&&
continueApproximationScheme
(
delta_score
));
123
124
stopApproximationScheme
();
// just to be sure of the approximationScheme
125
// has
126
// been notified of the end of looop
127
128
return
dag
;
129
}
130
131
/// learns the structure and the parameters of a BN
132
template
<
typename
GUM_SCALAR
,
133
typename
GRAPH_CHANGES_SELECTOR
,
134
typename
PARAM_ESTIMATOR
>
135
BayesNet
<
GUM_SCALAR
>
136
GreedyHillClimbing
::
learnBN
(
GRAPH_CHANGES_SELECTOR
&
selector
,
137
PARAM_ESTIMATOR
&
estimator
,
138
DAG
initial_dag
) {
139
return
DAG2BNLearner
<>::
createBN
<
GUM_SCALAR
>(
140
estimator
,
141
learnStructure
(
selector
,
initial_dag
));
142
}
143
144
}
/* namespace learning */
145
146
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::learning::genericBNLearner::Database::Database
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)
Definition:
genericBNLearner_tpl.h:31