aGrUM  0.16.0
greedyHillClimbing_tpl.h
Go to the documentation of this file.
1 
31 
32 namespace gum {
33 
34  namespace learning {
35 
37  template < typename GRAPH_CHANGES_SELECTOR >
38  DAG GreedyHillClimbing::learnStructure(GRAPH_CHANGES_SELECTOR& selector,
39  DAG dag) {
40  selector.setGraph(dag);
41 
42  unsigned int nb_changes_applied = 1;
43  double delta_score;
44 
46 
47  // a vector that indicates which queues have valid scores, i.e., scores
48  // that were not invalidated by previously applied changes
49  std::vector< bool > impacted_queues(dag.size(), false);
50 
51  do {
52  nb_changes_applied = 0;
53  delta_score = 0;
54 
55  std::vector< std::pair< NodeId, double > > ordered_queues =
56  selector.nodesSortedByBestScore();
57 
58  for (Idx j = 0; j < dag.size(); ++j) {
59  Idx i = ordered_queues[j].first;
60 
61  if (!(selector.empty(i)) && (selector.bestScore(i) > 0)) {
62  // pick up the best change
63  const GraphChange& change = selector.bestChange(i);
64 
65  // perform the change
66  switch (change.type()) {
68  if (!impacted_queues[change.node2()]
69  && selector.isChangeValid(change)) {
70  delta_score += selector.bestScore(i);
71  dag.addArc(change.node1(), change.node2());
72  impacted_queues[change.node2()] = true;
73  selector.applyChangeWithoutScoreUpdate(change);
74  ++nb_changes_applied;
75  }
76 
77  break;
78 
80  if (!impacted_queues[change.node2()]
81  && selector.isChangeValid(change)) {
82  delta_score += selector.bestScore(i);
83  dag.eraseArc(Arc(change.node1(), change.node2()));
84  impacted_queues[change.node2()] = true;
85  selector.applyChangeWithoutScoreUpdate(change);
86  ++nb_changes_applied;
87  }
88 
89  break;
90 
92  if ((!impacted_queues[change.node1()])
93  && (!impacted_queues[change.node2()])
94  && selector.isChangeValid(change)) {
95  delta_score += selector.bestScore(i);
96  dag.eraseArc(Arc(change.node1(), change.node2()));
97  dag.addArc(change.node2(), change.node1());
98  impacted_queues[change.node1()] = true;
99  impacted_queues[change.node2()] = true;
100  selector.applyChangeWithoutScoreUpdate(change);
101  ++nb_changes_applied;
102  }
103 
104  break;
105 
106  default:
108  "edge modifications are not supported by local search");
109  }
110  }
111  }
112 
113  selector.updateScoresAfterAppliedChanges();
114 
115  // reset the impacted queue and applied changes structures
116  for (auto iter = impacted_queues.begin(); iter != impacted_queues.end();
117  ++iter) {
118  *iter = false;
119  }
120 
121  updateApproximationScheme(nb_changes_applied);
122 
123  } while (nb_changes_applied && continueApproximationScheme(delta_score));
124 
125  stopApproximationScheme(); // just to be sure of the approximationScheme
126  // has
127  // been notified of the end of looop
128 
129  return dag;
130  }
131 
133  template < typename GUM_SCALAR,
134  typename GRAPH_CHANGES_SELECTOR,
135  typename PARAM_ESTIMATOR >
137  GreedyHillClimbing::learnBN(GRAPH_CHANGES_SELECTOR& selector,
138  PARAM_ESTIMATOR& estimator,
139  DAG initial_dag) {
140  return DAG2BNLearner<>::createBN< GUM_SCALAR >(
141  estimator, learnStructure(selector, initial_dag));
142  }
143 
144  } /* namespace learning */
145 
146 } /* namespace gum */
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Class representing a Bayesian Network.
Definition: BayesNet.h:78
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Size size() const
alias for sizeNodes
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
void initApproximationScheme()
Initialise the scheme.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
bool continueApproximationScheme(double error)
Update the scheme w.r.t the new error.
GraphChangeType type() const noexcept
returns the type of the operation
The base class for all directed edges. This class is used as a basis for manipulating all directed edges.
void stopApproximationScheme()
Stop the approximation scheme.
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:43
NodeId node2() const noexcept
returns the second node involved in the modification
Size Idx
Type for indexes.
Definition: types.h:53
A class that, given a structure and a parameter estimator returns a full Bayes net.
Definition: DAG2BNLearner.h:52
Base class for dag.
Definition: DAG.h:102
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
void updateApproximationScheme(unsigned int incr=1)
Update the scheme w.r.t the new error and increment steps.
NodeId node1() const noexcept
returns the first node involved in the modification