aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
incrementalGraphLearner.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Headers of the interface specifying functions to be implemented by any
25  * incremental learner.
26  *
27  * @author Jean-Christophe MAGNAN
28  */
29 
30 // =========================================================================
31 #ifndef GUM_INCREMENTAL_GRAPH_LEARNER_H
32 #define GUM_INCREMENTAL_GRAPH_LEARNER_H
33 // =========================================================================
34 // =========================================================================
35 #include <agrum/tools/multidim/implementations/multiDimFunctionGraph.h>
36 // =========================================================================
37 #include <agrum/FMDP/learning/core/templateStrategy.h>
38 #include <agrum/FMDP/learning/datastructure/IVisitableGraphLearner.h>
39 #include <agrum/FMDP/learning/datastructure/nodeDatabase.h>
40 // =========================================================================
41 #include <agrum/tools/multidim/utils/FunctionGraphUtilities/link.h>
42 // =========================================================================
43 
44 namespace gum {
45 
46  /**
47  * @class IncrementalGraphLearner incrementalGraphLearner.h
48  * <agrum/FMDP/learning/datastructure/incrementalGraphLearner>
49  * @brief
50  * @ingroup fmdp_group
51  *
52  * Abstract class for incrementally learning a graphical representation of a
53  * function.
54  * Can handle both functions of real values and functions explaining the
55  * behaviour
56  * of a variable given a set of other variables (as typically in conditional
57  * probabilities)
58  *
59  * Maintains two graphs in memory, one which is incrementally updated and the
60  * other one
61  * which is updated on demand and is usable by the outside.
62  *
63  */
64  template < TESTNAME AttributeSelection, bool isScalar = false >
66  typedef typename ValueSelect< isScalar, double, Idx >::type ValueType;
67 
68  public:
69  // ###################################################################
70  /// @name Constructor & destructor.
71  // ###################################################################
72  /// @{
73 
74  // ==========================================================================
75  /**
76  * Default constructor
77  * @param target : the output diagram usable by the outside
78  * @param attributesSet : set of variables from which we try to describe the
79  * learned function
80  * @param learnVariable : if we try to learn the behaviour of a variable
81  * given another set of variables, this is the one. If we are
82  * learning a function of real values, this is just a computational trick
83  * (and is to be deprecated)
84  */
85  // ==========================================================================
86  IncrementalGraphLearner(MultiDimFunctionGraph< double >* target,
87  Set< const DiscreteVariable* > attributesSet,
88  const DiscreteVariable* learnVariable);
89 
90  // ==========================================================================
91  /// Default destructor
92  // ==========================================================================
93  virtual ~IncrementalGraphLearner();
94 
95  private:
96  // ==========================================================================
97  /// Dispatcher: forwards to the clearValue__ overload selected by isScalar
98  // ==========================================================================
99  void clearValue__() { clearValue__(Int2Type< isScalar >()); }
100 
101  // ==========================================================================
102  /// Scalar case (isScalar == true), i.e. learning a function of real values:
103  /// value_ has to be wiped out upon destruction, so it is deleted here (to be deprecated)
104  // ==========================================================================
105  void clearValue__(Int2Type< true >) { delete value_; }
106 
107  // ==========================================================================
108  /// Non-scalar case (isScalar == false), i.e. learning the behaviour of a
109  /// variable: this overload intentionally does nothing
110  // ==========================================================================
111  void clearValue__(Int2Type< false >) {}
112 
113  /// @}
114 
115 
116  // ###################################################################
117  /// @name New Observation insertion methods
118  // ###################################################################
119  /// @{
120  public:
121  // ==========================================================================
122  /**
123  * Inserts a new observation
124  */
125  // ==========================================================================
126  virtual void addObservation(const Observation* obs);
127 
128  private:
129  // ==========================================================================
130  /**
131  * Adds to valueAssumed_ the value the studied variable takes in obs, if absent
132  */
133  // ==========================================================================
134  void assumeValue__(const Observation* obs) {
135  assumeValue__(obs, Int2Type< isScalar >());   // dispatch on isScalar
136  }
137  void assumeValue__(const Observation* obs, Int2Type< true >) {   // scalar: obs->reward()
138  if (!valueAssumed_.exists(obs->reward())) valueAssumed_ << obs->reward();
139  }
140  void assumeValue__(const Observation* obs, Int2Type< false >) {   // non-scalar: modality of value_
141  if (!valueAssumed_.exists(obs->modality(value_)))
142  valueAssumed_ << obs->modality(value_);
143  }
144 
145 
146  // ==========================================================================
147  /**
148  * Seek modality assumed in obs for given var
149  */
150  // ==========================================================================
152  return branchObs__(obs, var, Int2Type< isScalar >());
153  }
155  const DiscreteVariable* var,
156  Int2Type< true >) {
157  return obs->rModality(var);
158  }
160  const DiscreteVariable* var,
161  Int2Type< false >) {
162  return obs->modality(var);
163  }
164 
165  protected:
166  // ==========================================================================
167  /**
168  * Will update internal graph's NodeDatabase of given node with the new
169  * observation
170  * @param newObs : the observation passed on to the node's NodeDatabase
171  * @param currentNodeId : id of the node, used as key in nodeId2Database_
172  */
173  // ==========================================================================
174  virtual void updateNodeWithObservation_(const Observation* newObs,
175  NodeId currentNodeId) {
176  nodeId2Database_[currentNodeId]->addObservation(newObs);
177  }
178 
179  /// @}
180 
181  // ###################################################################
182  /// @name Graph Structure update methods
183  // ###################################################################
184  /// @{
185 
186  public:
187  // ==========================================================================
188  /// If a new modality appears to exist for the given variable,
189  /// call this method to turn every node associated with this variable into
190  /// a leaf.
191  /// The graph then has to be revised
192  // ==========================================================================
193  virtual void updateVar(const DiscreteVariable*);
194 
195  // ==========================================================================
196  /// Updates the tree after a new observation has been added
197  // ==========================================================================
198  virtual void updateGraph() = 0;
199 
200  protected:
201  // ==========================================================================
202  /**
203  * From the given set of variables, randomly selects one and installs it
204  * on the given node. Checks of course if node's current variable is not in
205  * that
206  * set first.
207  * @param nody : the node we update
208  * @param bestVars : the set of interesting vars to be installed here
209  */
210  // ==========================================================================
211  void updateNode_(NodeId nody, Set< const DiscreteVariable* >& bestVars);
212 
213  // ==========================================================================
214  /// Turns the given node into a leaf if not already so
215  // ==========================================================================
216  virtual void convertNode2Leaf_(NodeId);
217 
218  // ==========================================================================
219  /// Installs given variable to the given node, ensuring that the variable
220  /// is not present in its subtree
221  // ==========================================================================
222  virtual void transpose_(NodeId, const DiscreteVariable*);
223 
224  // ==========================================================================
225  /**
226  * inserts a new node in internal graph
227  * @param nDB : the associated database
228  * @param boundVar : the associated variable
229  * @return the newly created node's id
230  */
231  // ==========================================================================
233  const DiscreteVariable* boundVar);
234 
235  // ==========================================================================
236  /**
237  * inserts a new internal node in internal graph
238  * @param nDB : the associated database
239  * @param boundVar : the associated variable
240  * @param sonsMap : a table giving node's sons node
241  * @return the newly created node's id
242  */
243  // ==========================================================================
244  virtual NodeId
246  const DiscreteVariable* boundVar,
247  NodeId* sonsMap);
248 
249  // ==========================================================================
250  /**
251  * inserts a new leaf node in internal graph
252  * @param nDB : the associated database
253  * @param boundVar : the associated variable
254  * @param obsSet : the set of observation this leaf retains
255  * @return the newly created node's id
256  */
257  // ==========================================================================
258  virtual NodeId
260  const DiscreteVariable* boundVar,
261  Set< const Observation* >* obsSet);
262 
263  // ==========================================================================
264  /**
265  * Changes the associated variable of a node
266  * @param chgedNodeId : the node to change
267  * @param desiredVar : its new associated variable
268  */
269  // ==========================================================================
270  virtual void chgNodeBoundVar_(NodeId chgedNodeId,
271  const DiscreteVariable* desiredVar);
272 
273  // ==========================================================================
274  /**
275  * Removes a node from the internal graph
276  * @param removedNodeId : the node to remove
277  */
278  // ==========================================================================
279  virtual void removeNode_(NodeId removedNodeId);
280 
281  /// @}
282 
283 
284  // ###################################################################
285  /// @name Function Graph Updating methods
286  // ###################################################################
287  /// @{
288  public:
289  // ==========================================================================
290  /// Updates target to currently learned graph structure
291  // ==========================================================================
292  virtual void updateFunctionGraph() = 0;
293 
294  /// @}
295 
296 
297  public:
298  // ==========================================================================
299  /// Returns the number of entries in nodeVarMap_
300  // ==========================================================================
301  Size size() { return nodeVarMap_.size(); }
302 
303 
304  // ###################################################################
305  /// @name Visit Methods
306  // ###################################################################
307  /// @{
308  public:
309  // ==========================================================================
310  /// Returns the node id stored in root_ (the root of the ordered tree)
311  // ==========================================================================
312  NodeId root() const { return this->root_; }
313 
314  // ==========================================================================
315  /// True when nodeSonsMap_ holds no entry for ni, i.e. the node has no sons table
316  // ==========================================================================
317  bool isTerminal(NodeId ni) const { return !this->nodeSonsMap_.exists(ni); }
318 
319  // ==========================================================================
320  ///
321  // ==========================================================================
323  return this->nodeVarMap_[ni];
324  }
325 
326  // ==========================================================================
327  ///
328  // ==========================================================================
330  return this->nodeSonsMap_[ni][modality];
331  }
332 
333  // ==========================================================================
334  ///
335  // ==========================================================================
337  return this->nodeId2Database_[ni]->nbObservation();
338  }
339 
340  // ==========================================================================
341  /// Adds every variable held in setOfVars_ to the given function graph
342  // ==========================================================================
343  virtual void insertSetOfVars(MultiDimFunctionGraph< double >* ret) const {
344  // visit the set through a safe iterator, handing each variable to ret
345  auto varIter = setOfVars_.beginSafe();
346  for (; varIter != setOfVars_.endSafe(); ++varIter) {
347  ret->add(**varIter);
348  }
349  }
350  /// @}
351 
352  protected:
353  /// @}
354 
355  // ###################################################################
356  /// @name Model handling datastructures
357  // ###################################################################
358  /// @{
359 
360  // ==========================================================================
361  /// The source of nodeId
362  // ==========================================================================
364 
365  // ==========================================================================
366  /// The root of the ordered tree
367  // ==========================================================================
369 
370  // ==========================================================================
371  /// Gives for any node its associated variable
372  // ==========================================================================
374 
375  // ==========================================================================
376  /// A table giving for any node a table mapping to its son
377  /// idx is the modality of associated variable
378  // ==========================================================================
380 
381  // ==========================================================================
382  /// Associates to any variable the list of all nodes associated to
383  /// this variable
384  // ==========================================================================
386 
387  // ==========================================================================
388  /// This hashtable binds every node to an associated NodeDatabase
389  /// which handles every observation that concerns that node
390  // ==========================================================================
393 
394  // ==========================================================================
395  /// This hashtable binds to every leaf an associated set of all
396  /// the observations compatible with it
397  // ==========================================================================
399 
400  /// @}
401 
402 
403  /// The final diagram we're building
405 
407 
410 
412  };
413 
414 
415 } /* namespace gum */
416 
417 #include <agrum/FMDP/learning/datastructure/incrementalGraphLearner_tpl.h>
418 
419 #endif // GUM_INCREMENTAL_GRAPH_LEARNER_H
virtual void updateFunctionGraph()=0
Updates target to currently learned graph structure.
void updateNode_(NodeId nody, Set< const DiscreteVariable * > &bestVars)
From the given sets of node, selects randomly one and installs it on given node.
IncrementalGraphLearner(MultiDimFunctionGraph< double > *target, Set< const DiscreteVariable * > attributesSet, const DiscreteVariable *learnVariable)
Default constructor.
MultiDimFunctionGraph< double > * target_
The final diagram we're building.
virtual void convertNode2Leaf_(NodeId)
Turns the given node into a leaf if not already so.
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
void assumeValue__(const Observation *obs)
Get value assumed by studied variable for current observation.
virtual ~IncrementalGraphLearner()
Default destructor.
virtual void removeNode_(NodeId removedNodeId)
Removes a node from the internal graph.
NodeId root_
The root of the ordered tree.
virtual NodeId insertInternalNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, NodeId *sonsMap)
inserts a new internal node in internal graph
virtual void transpose_(NodeId, const DiscreteVariable *)
Installs given variable to the given node, ensuring that the variable is not present in its subtree...
virtual void addObservation(const Observation *obs)
Inserts a new observation.
virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
virtual void updateGraph()=0
Updates the tree after a new observation has been added.
Set< const DiscreteVariable *> setOfVars_
virtual NodeId insertNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar)
inserts a new node in internal graph
HashTable< NodeId, NodeId *> nodeSonsMap_
A table giving for any node a table mapping to its son idx is the modality of associated variable...
void assumeValue__(const Observation *obs, Int2Type< false >)
Inserts a new observation.
NodeGraphPart model_
The source of nodeId.
HashTable< NodeId, Set< const Observation *> *> leafDatabase_
This hashtable binds to every leaf an associated set of all the observations compatible with it...
void clearValue__()
Template function dispatcher.
const DiscreteVariable * nodeVar(NodeId ni) const
Idx branchObs__(const Observation *obs, const DiscreteVariable *var, Int2Type< false >)
Inserts a new observation.
virtual void insertSetOfVars(MultiDimFunctionGraph< double > *ret) const
NodeId nodeSon(NodeId ni, Idx modality) const
void clearValue__(Int2Type< false >)
In case where we're learning function of variable behaviour, this should do nothing.
virtual void updateNodeWithObservation_(const Observation *newObs, NodeId currentNodeId)
Will update internal graph's NodeDatabase of given node with the new observation. ...
Idx branchObs__(const Observation *obs, const DiscreteVariable *var)
Seek modality assumed in obs for given var.
HashTable< NodeId, const DiscreteVariable *> nodeVarMap_
Gives for any node its associated variable.
virtual void updateVar(const DiscreteVariable *)
If a new modality appears to exist for given variable, call this method to turn every associated nod...
virtual NodeId insertLeafNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, Set< const Observation * > *obsSet)
inserts a new leaf node in internal graph