aGrUM  0.14.2
greedyHillClimbing_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}@lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
28 
29 namespace gum {
30 
31  namespace learning {
32 
34  template < typename GRAPH_CHANGES_SELECTOR >
35  DAG GreedyHillClimbing::learnStructure(GRAPH_CHANGES_SELECTOR& selector,
36  DAG dag) {
37  selector.setGraph(dag);
38 
39  unsigned int nb_changes_applied = 1;
40  double delta_score;
41 
43 
44  // a vector that indicates which queues have valid scores, i.e., scores
45  // that were not invalidated by previously applied changes
46  std::vector< bool > impacted_queues(dag.size(), false);
47 
48  do {
49  nb_changes_applied = 0;
50  delta_score = 0;
51 
52  std::vector< std::pair< NodeId, double > > ordered_queues =
53  selector.nodesSortedByBestScore();
54 
55  for (Idx j = 0; j < dag.size(); ++j) {
56  Idx i = ordered_queues[j].first;
57 
58  if (!(selector.empty(i)) && (selector.bestScore(i) > 0)) {
59  // pick up the best change
60  const GraphChange& change = selector.bestChange(i);
61 
62  // perform the change
63  switch (change.type()) {
65  if (!impacted_queues[change.node2()]
66  && selector.isChangeValid(change)) {
67  delta_score += selector.bestScore(i);
68  dag.addArc(change.node1(), change.node2());
69  impacted_queues[change.node2()] = true;
70  selector.applyChangeWithoutScoreUpdate(change);
71  ++nb_changes_applied;
72  }
73 
74  break;
75 
77  if (!impacted_queues[change.node2()]
78  && selector.isChangeValid(change)) {
79  delta_score += selector.bestScore(i);
80  dag.eraseArc(Arc(change.node1(), change.node2()));
81  impacted_queues[change.node2()] = true;
82  selector.applyChangeWithoutScoreUpdate(change);
83  ++nb_changes_applied;
84  }
85 
86  break;
87 
89  if ((!impacted_queues[change.node1()])
90  && (!impacted_queues[change.node2()])
91  && selector.isChangeValid(change)) {
92  delta_score += selector.bestScore(i);
93  dag.eraseArc(Arc(change.node1(), change.node2()));
94  dag.addArc(change.node2(), change.node1());
95  impacted_queues[change.node1()] = true;
96  impacted_queues[change.node2()] = true;
97  selector.applyChangeWithoutScoreUpdate(change);
98  ++nb_changes_applied;
99  }
100 
101  break;
102 
103  default:
105  "edge modifications are not supported by local search");
106  }
107  }
108  }
109 
110  selector.updateScoresAfterAppliedChanges();
111 
112  // reset the impacted queue and applied changes structures
113  for (auto iter = impacted_queues.begin(); iter != impacted_queues.end();
114  ++iter) {
115  *iter = false;
116  }
117 
118  updateApproximationScheme(nb_changes_applied);
119 
120  } while (nb_changes_applied && continueApproximationScheme(delta_score));
121 
122  stopApproximationScheme(); // just to be sure of the approximationScheme
123  // has
124  // been notified of the end of looop
125 
126  return dag;
127  }
128 
130  template < typename GUM_SCALAR,
131  typename GRAPH_CHANGES_SELECTOR,
132  typename PARAM_ESTIMATOR >
134  GreedyHillClimbing::learnBN(GRAPH_CHANGES_SELECTOR& selector,
135  PARAM_ESTIMATOR& estimator,
136  DAG initial_dag) {
137  return DAG2BNLearner<>::createBN< GUM_SCALAR >(
138  estimator, learnStructure(selector, initial_dag));
139  }
140 
141  } /* namespace learning */
142 
143 } /* namespace gum */
A class that, given a structure and a parameter estimator returns a full Bayes net.
Class representing a Bayesian Network.
Definition: BayesNet.h:76
the classes to account for structure changes in a graph
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Size size() const
alias for sizeNodes
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
void initApproximationScheme()
Initialise the scheme.
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
bool continueApproximationScheme(double error)
Update the scheme w.r.t the new error.
GraphChangeType type() const noexcept
returns the type of the operation
The base class for all directed edges. This class is used as a basis for manipulating all directed edges...
void stopApproximationScheme()
Stop the approximation scheme.
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:40
NodeId node2() const noexcept
returns the second node involved in the modification
Size Idx
Type for indexes.
Definition: types.h:50
A class that, given a structure and a parameter estimator returns a full Bayes net.
Definition: DAG2BNLearner.h:49
Base class for dag.
Definition: DAG.h:99
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52
void updateApproximationScheme(unsigned int incr=1)
Update the scheme w.r.t the new error and increment steps.
NodeId node1() const noexcept
returns the first node involved in the modification