aGrUM  0.14.2
localSearchWithTabuList_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}@lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
29 
30 namespace gum {
31 
32  namespace learning {
33 
35  template < typename GRAPH_CHANGES_SELECTOR >
36  DAG LocalSearchWithTabuList::learnStructure(GRAPH_CHANGES_SELECTOR& selector,
37  DAG dag) {
38  selector.setGraph(dag);
39 
40  unsigned int nb_changes_applied = 0;
41  Idx applied_change_with_positive_score = 0;
42  Idx current_N = 0;
43 
45 
46  // a vector that indicates which queues have valid scores, i.e., scores
47  // that were not invalidated by previously applied changes
48  std::vector< bool > impacted_queues(dag.size(), false);
49 
50  // the best dag found so far with its score
51  DAG best_dag = dag;
52  double best_score = 0;
53  double current_score = 0;
54  double delta_score = 0;
55 
56  do {
57  applied_change_with_positive_score = 0;
58  delta_score = 0;
59 
60  std::vector< std::pair< NodeId, double > > ordered_queues =
61  selector.nodesSortedByBestScore();
62 
63  for (Idx j = 0; j < dag.size(); ++j) {
64  NodeId i = ordered_queues[j].first;
65 
66  if (!selector.empty(i)
67  && (!nb_changes_applied || (selector.bestScore(i) > 0))) {
68  // pick up the best change
69  const GraphChange& change = selector.bestChange(i);
70 
71  // perform the change
72  switch (change.type()) {
74  if (!impacted_queues[change.node2()]
75  && selector.isChangeValid(change)) {
76  if (selector.bestScore(i) > 0) {
77  ++applied_change_with_positive_score;
78  } else if (current_score > best_score) {
79  best_score = current_score;
80  best_dag = dag;
81  }
82 
83  // std::cout << "apply arc addition " << change.node1()
84  // << " -> " << change.node2()
85  // << " delta = " << selector.bestScore( i )
86  // << std::endl;
87 
88  delta_score += selector.bestScore(i);
89  current_score += selector.bestScore(i);
90  dag.addArc(change.node1(), change.node2());
91  impacted_queues[change.node2()] = true;
92  selector.applyChangeWithoutScoreUpdate(change);
93  ++nb_changes_applied;
94  }
95 
96  break;
97 
99  if (!impacted_queues[change.node2()]
100  && selector.isChangeValid(change)) {
101  if (selector.bestScore(i) > 0) {
102  ++applied_change_with_positive_score;
103  } else if (current_score > best_score) {
104  best_score = current_score;
105  best_dag = dag;
106  }
107 
108  // std::cout << "apply arc deletion " << change.node1()
109  // << " -> " << change.node2()
110  // << " delta = " << selector.bestScore( i )
111  // << std::endl;
112 
113  delta_score += selector.bestScore(i);
114  current_score += selector.bestScore(i);
115  dag.eraseArc(Arc(change.node1(), change.node2()));
116  impacted_queues[change.node2()] = true;
117  selector.applyChangeWithoutScoreUpdate(change);
118  ++nb_changes_applied;
119  }
120 
121  break;
122 
124  if ((!impacted_queues[change.node1()])
125  && (!impacted_queues[change.node2()])
126  && selector.isChangeValid(change)) {
127  if (selector.bestScore(i) > 0) {
128  ++applied_change_with_positive_score;
129  } else if (current_score > best_score) {
130  best_score = current_score;
131  best_dag = dag;
132  }
133 
134  // std::cout << "apply arc reversal " << change.node1()
135  // << " -> " << change.node2()
136  // << " delta = " << selector.bestScore( i )
137  // << std::endl;
138 
139  delta_score += selector.bestScore(i);
140  current_score += selector.bestScore(i);
141  dag.eraseArc(Arc(change.node1(), change.node2()));
142  dag.addArc(change.node2(), change.node1());
143  impacted_queues[change.node1()] = true;
144  impacted_queues[change.node2()] = true;
145  selector.applyChangeWithoutScoreUpdate(change);
146  ++nb_changes_applied;
147  }
148 
149  break;
150 
151  default:
153  "edge modifications are not "
154  "supported by local search");
155  }
156 
157  break;
158  }
159  }
160 
161  selector.updateScoresAfterAppliedChanges();
162 
163  // reset the impacted queue and applied changes structures
164  for (auto iter = impacted_queues.begin(); iter != impacted_queues.end();
165  ++iter) {
166  *iter = false;
167  }
168 
169  updateApproximationScheme(nb_changes_applied);
170 
171  // update current_N
172  if (applied_change_with_positive_score) {
173  current_N = 0;
174  nb_changes_applied = 0;
175  } else {
176  ++current_N;
177  }
178 
179  // std::cout << "current N = " << current_N << std::endl;
180  } while ((current_N <= __MaxNbDecreasing)
181  && continueApproximationScheme(delta_score));
182 
183  stopApproximationScheme(); // just to be sure of the
184  // approximationScheme has
185  // been notified of the end of looop
186 
187  if (current_score > best_score) {
188  return dag;
189  } else {
190  return best_dag;
191  }
192  }
193 
195  template < typename GUM_SCALAR,
196  typename GRAPH_CHANGES_SELECTOR,
197  typename PARAM_ESTIMATOR >
199  LocalSearchWithTabuList::learnBN(GRAPH_CHANGES_SELECTOR& selector,
200  PARAM_ESTIMATOR& estimator,
201  DAG initial_dag) {
202  return DAG2BNLearner<>::createBN< GUM_SCALAR >(
203  estimator, learnStructure(selector, initial_dag));
204  }
205 
206  } /* namespace learning */
207 
208 } /* namespace gum */
A class that, given a structure and a parameter estimator returns a full Bayes net.
Class representing a Bayesian Network.
Definition: BayesNet.h:76
the classes to account for structure changes in a graph
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Size size() const
alias for sizeNodes
void initApproximationScheme()
Initialise the scheme.
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
bool continueApproximationScheme(double error)
Update the scheme w.r.t the new error.
GraphChangeType type() const noexcept
returns the type of the operation
The base class for all directed edges. This class is used as a basis for manipulating all directed edges...
void stopApproximationScheme()
Stop the approximation scheme.
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:40
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
Size __MaxNbDecreasing
the max number of changes decreasing the score that we allow to apply
NodeId node2() const noexcept
returns the second node involved in the modification
Size Idx
Type for indexes.
Definition: types.h:50
A class that, given a structure and a parameter estimator returns a full Bayes net.
Definition: DAG2BNLearner.h:49
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
Base class for dag.
Definition: DAG.h:99
Size NodeId
Type for node ids.
Definition: graphElements.h:97
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52
void updateApproximationScheme(unsigned int incr=1)
Update the scheme w.r.t the new error and increment steps.
NodeId node1() const noexcept
returns the first node involved in the modification