aGrUM  0.16.0
localSearchWithTabuList_tpl.h
Go to the documentation of this file.
1 
32 
33 namespace gum {
34 
35  namespace learning {
36 
/// Learns a DAG by local search that tolerates a bounded number of
/// non-improving passes.
///
/// Starting from @p dag, each pass asks @p selector for the nodes sorted by
/// best score and considers the best change (arc addition, deletion or
/// reversal) of the first eligible node; the change is applied only if the
/// selector reports it as valid. Scores are accumulated as deltas relative to
/// the initial DAG (best_score and current_score both start at 0). The search
/// stops once more than __MaxNbDecreasing consecutive passes applied no
/// positive-score change, or when continueApproximationScheme() returns false.
///
/// @param selector provides candidate graph changes and their scores; it is
///        bound to @p dag via setGraph() and told about every applied change.
/// @param dag the initial structure (taken by value and modified locally).
/// @return the final DAG if its cumulative score is strictly higher than the
///         best saved snapshot, otherwise that snapshot.
/// @note Unsupported change types (edge modifications) fall into the switch's
///       default branch. The error-raising line is elided from this listing
///       (presumably GUM_ERROR; confirm the exception type).
38  template < typename GRAPH_CHANGES_SELECTOR >
39  DAG LocalSearchWithTabuList::learnStructure(GRAPH_CHANGES_SELECTOR& selector,
40  DAG dag) {
41  selector.setGraph(dag);
42 
43  unsigned int nb_changes_applied = 0;
44  Idx applied_change_with_positive_score = 0;
// number of consecutive passes without any positive-score change
45  Idx current_N = 0;
46 
// NOTE(review): source line 47 is missing from this listing. The
// cross-reference list mentions initApproximationScheme(), which is presumably
// called here; confirm against the original file.
48 
49  // for each node, whether its queue was impacted by a change applied during
50  // the current pass (its queued scores are then stale until the selector's
// scores are updated at the end of the pass)
51  std::vector< bool > impacted_queues(dag.size(), false);
52 
53  // the best dag found so far with its score
54  DAG best_dag = dag;
55  double best_score = 0;
56  double current_score = 0;
57  double delta_score = 0;
58 
59  do {
60  applied_change_with_positive_score = 0;
61  delta_score = 0;
62 
63  std::vector< std::pair< NodeId, double > > ordered_queues =
64  selector.nodesSortedByBestScore();
65 
66  for (Idx j = 0; j < dag.size(); ++j) {
67  NodeId i = ordered_queues[j].first;
68 
// A queue is eligible if it has a candidate change and either
// nb_changes_applied is 0 (the very start, or after a pass that applied a
// positive-score change; see the reset below) or its best change has a
// strictly positive score. So at most one non-positive change can be applied
// between two positive-score passes.
69  if (!selector.empty(i)
70  && (!nb_changes_applied || (selector.bestScore(i) > 0))) {
71  // pick up the best change
72  const GraphChange& change = selector.bestChange(i);
73 
74  // perform the change
75  switch (change.type()) {
// Arc addition. Its case label (source line 76) is elided from this listing;
// the commented-out trace below identifies the branch.
77  if (!impacted_queues[change.node2()]
78  && selector.isChangeValid(change)) {
// Before taking a non-positive step, save the current DAG if it beats the
// best one seen so far.
79  if (selector.bestScore(i) > 0) {
80  ++applied_change_with_positive_score;
81  } else if (current_score > best_score) {
82  best_score = current_score;
83  best_dag = dag;
84  }
85 
86  // std::cout << "apply arc addition " << change.node1()
87  // << " -> " << change.node2()
88  // << " delta = " << selector.bestScore( i )
89  // << std::endl;
90 
91  delta_score += selector.bestScore(i);
92  current_score += selector.bestScore(i);
93  dag.addArc(change.node1(), change.node2());
94  impacted_queues[change.node2()] = true;
95  selector.applyChangeWithoutScoreUpdate(change);
96  ++nb_changes_applied;
97  }
98 
99  break;
100 
// Arc deletion. Its case label (source line 101) is elided from this listing.
102  if (!impacted_queues[change.node2()]
103  && selector.isChangeValid(change)) {
// same best-snapshot bookkeeping as for arc addition
104  if (selector.bestScore(i) > 0) {
105  ++applied_change_with_positive_score;
106  } else if (current_score > best_score) {
107  best_score = current_score;
108  best_dag = dag;
109  }
110 
111  // std::cout << "apply arc deletion " << change.node1()
112  // << " -> " << change.node2()
113  // << " delta = " << selector.bestScore( i )
114  // << std::endl;
115 
116  delta_score += selector.bestScore(i);
117  current_score += selector.bestScore(i);
118  dag.eraseArc(Arc(change.node1(), change.node2()));
119  impacted_queues[change.node2()] = true;
120  selector.applyChangeWithoutScoreUpdate(change);
121  ++nb_changes_applied;
122  }
123 
124  break;
125 
// Arc reversal. Its case label (source line 126) is elided from this listing.
// A reversal changes the parents of both endpoints, so both queues must be
// unimpacted.
127  if ((!impacted_queues[change.node1()])
128  && (!impacted_queues[change.node2()])
129  && selector.isChangeValid(change)) {
130  if (selector.bestScore(i) > 0) {
131  ++applied_change_with_positive_score;
132  } else if (current_score > best_score) {
133  best_score = current_score;
134  best_dag = dag;
135  }
136 
137  // std::cout << "apply arc reversal " << change.node1()
138  // << " -> " << change.node2()
139  // << " delta = " << selector.bestScore( i )
140  // << std::endl;
141 
142  delta_score += selector.bestScore(i);
143  current_score += selector.bestScore(i);
144  dag.eraseArc(Arc(change.node1(), change.node2()));
145  dag.addArc(change.node2(), change.node1());
146  impacted_queues[change.node1()] = true;
147  impacted_queues[change.node2()] = true;
148  selector.applyChangeWithoutScoreUpdate(change);
149  ++nb_changes_applied;
150  }
151 
152  break;
153 
154  default:
// NOTE(review): source line 155 (the call that raises the error and receives
// the message below) is missing from this listing. It is presumably
// GUM_ERROR(...); confirm the exception type.
156  "edge modifications are not "
157  "supported by local search");
158  }
159 
// NOTE(review): this break leaves the for loop after the first eligible
// queue, whether or not its change passed isChangeValid(). So at most one
// change is applied per pass, and impacted_queues never blocks a change in
// the code shown here.
160  break;
161  }
162  }
163 
164  selector.updateScoresAfterAppliedChanges();
165 
166  // reset the impacted-queue flags for the next pass (nb_changes_applied is
// reset separately below, and only after a pass with a positive-score change)
167  for (auto iter = impacted_queues.begin(); iter != impacted_queues.end();
168  ++iter) {
169  *iter = false;
170  }
171 
172  updateApproximationScheme(nb_changes_applied);
173 
174  // update current_N
175  if (applied_change_with_positive_score) {
176  current_N = 0;
177  nb_changes_applied = 0;
178  } else {
179  ++current_N;
180  }
181 
182  // std::cout << "current N = " << current_N << std::endl;
// Stop once current_N exceeds __MaxNbDecreasing, or when the approximation
// scheme (fed with this pass's summed score delta) says to stop.
183  } while ((current_N <= __MaxNbDecreasing)
184  && continueApproximationScheme(delta_score));
185 
186  stopApproximationScheme(); // just to be sure that the
187  // approximation scheme has
188  // been notified of the end of the loop
189 
// Return the final DAG only if its cumulative score strictly beats the saved
// snapshot; ties go to best_dag.
190  if (current_score > best_score) {
191  return dag;
192  } else {
193  return best_dag;
194  }
195  }
196 
/// Learns both the structure and the parameters of a Bayes net.
///
/// The structure is obtained with learnStructure(selector, initial_dag). The
/// resulting DAG, together with @p estimator, is then passed to
/// DAG2BNLearner<>::createBN< GUM_SCALAR >() to build the network.
///
/// @param selector the graph-changes selector used by the structure search.
/// @param estimator the parameter estimator handed to createBN.
/// @param initial_dag the starting structure for the search.
/// @return the learned network (declared as BayesNet< GUM_SCALAR > according
///         to the cross-reference list; the return-type line is elided from
///         this listing).
198  template < typename GUM_SCALAR,
199  typename GRAPH_CHANGES_SELECTOR,
200  typename PARAM_ESTIMATOR >
// NOTE(review): source line 201 (the return type) is missing from this listing.
202  LocalSearchWithTabuList::learnBN(GRAPH_CHANGES_SELECTOR& selector,
203  PARAM_ESTIMATOR& estimator,
204  DAG initial_dag) {
205  return DAG2BNLearner<>::createBN< GUM_SCALAR >(
206  estimator, learnStructure(selector, initial_dag));
207  }
208 
209  } /* namespace learning */
210 
211 } /* namespace gum */
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Class representing a Bayesian Network.
Definition: BayesNet.h:78
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Size size() const
alias for sizeNodes
void initApproximationScheme()
Initialise the scheme.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
bool continueApproximationScheme(double error)
Update the scheme w.r.t. the new error.
GraphChangeType type() const noexcept
returns the type of the operation
The base class for all directed edges. This class is used as a basis for manipulating all directed edge...
void stopApproximationScheme()
Stop the approximation scheme.
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:43
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
Size __MaxNbDecreasing
the max number of changes decreasing the score that we allow to apply
NodeId node2() const noexcept
returns the second node involved in the modification
Size Idx
Type for indexes.
Definition: types.h:53
A class that, given a structure and a parameter estimator, returns a full Bayes net.
Definition: DAG2BNLearner.h:52
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
Base class for dag.
Definition: DAG.h:102
Size NodeId
Type for node ids.
Definition: graphElements.h:98
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
void updateApproximationScheme(unsigned int incr=1)
Update the scheme w.r.t. the new error and increment the number of steps.
NodeId node1() const noexcept
returns the first node involved in the modification