aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
localSearchWithTabuList_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief The local search with tabu list learning algorithm (for directed
24  * graphs)
25  *
26  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 
29 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
30 #include <agrum/BN/learning/structureUtils/graphChange.h>
31 
32 namespace gum {
33 
34  namespace learning {
35 
36  /// learns the structure of a Bayes net
37  template < typename GRAPH_CHANGES_SELECTOR >
39  DAG dag) {
41 
42  unsigned int nb_changes_applied = 0;
44  Idx current_N = 0;
45 
47 
48  // a vector that indicates which queues have valid scores, i.e., scores
49  // that were not invalidated by previously applied changes
50  std::vector< bool > impacted_queues(dag.size(), false);
51 
52  // the best dag found so far with its score
53  DAG best_dag = dag;
54  double best_score = 0;
55  double current_score = 0;
56  double delta_score = 0;
57 
58  do {
60  delta_score = 0;
61 
62  std::vector< std::pair< NodeId, double > > ordered_queues
64 
65  for (Idx j = 0; j < dag.size(); ++j) {
67 
68  if (!selector.empty(i)
69  && (!nb_changes_applied || (selector.bestScore(i) > 0))) {
70  // pick up the best change
72 
73  // perform the change
74  switch (change.type()) {
78  if (selector.bestScore(i) > 0) {
80  } else if (current_score > best_score) {
82  best_dag = dag;
83  }
84 
85  // std::cout << "apply arc addition " << change.node1()
86  // << " -> " << change.node2()
87  // << " delta = " << selector.bestScore( i )
88  // << std::endl;
89 
93  impacted_queues[change.node2()] = true;
96  }
97 
98  break;
99 
103  if (selector.bestScore(i) > 0) {
105  } else if (current_score > best_score) {
107  best_dag = dag;
108  }
109 
110  // std::cout << "apply arc deletion " << change.node1()
111  // << " -> " << change.node2()
112  // << " delta = " << selector.bestScore( i )
113  // << std::endl;
114 
118  impacted_queues[change.node2()] = true;
121  }
122 
123  break;
124 
126  if ((!impacted_queues[change.node1()])
127  && (!impacted_queues[change.node2()])
129  if (selector.bestScore(i) > 0) {
131  } else if (current_score > best_score) {
133  best_dag = dag;
134  }
135 
136  // std::cout << "apply arc reversal " << change.node1()
137  // << " -> " << change.node2()
138  // << " delta = " << selector.bestScore( i )
139  // << std::endl;
140 
145  impacted_queues[change.node1()] = true;
146  impacted_queues[change.node2()] = true;
149  }
150 
151  break;
152 
153  default:
155  "edge modifications are not "
156  "supported by local search");
157  }
158 
159  break;
160  }
161  }
162 
164 
165  // reset the impacted queue and applied changes structures
166  for (auto iter = impacted_queues.begin(); iter != impacted_queues.end();
167  ++iter) {
168  *iter = false;
169  }
170 
172 
173  // update current_N
175  current_N = 0;
176  nb_changes_applied = 0;
177  } else {
178  ++current_N;
179  }
180 
181  // std::cout << "current N = " << current_N << std::endl;
182  } while ((current_N <= MaxNbDecreasing__)
184 
185  stopApproximationScheme(); // just to be sure of the
186  // approximationScheme has
187  // been notified of the end of the loop
188 
189  if (current_score > best_score) {
190  return dag;
191  } else {
192  return best_dag;
193  }
194  }
195 
196  /// learns the structure and the parameters of a BN
197  template < typename GUM_SCALAR,
198  typename GRAPH_CHANGES_SELECTOR,
199  typename PARAM_ESTIMATOR >
203  DAG initial_dag) {
204  return DAG2BNLearner<>::createBN< GUM_SCALAR >(
205  estimator,
207  }
208 
209  } /* namespace learning */
210 
211 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
Database(const std::string &filename, const BayesNet< GUM_SCALAR > &bn, const std::vector< std::string > &missing_symbols)