aGrUM  0.16.0
imddi_tpl.h
Go to the documentation of this file.
1 
29 // =======================================================
30 #include <agrum/core/math/math.h>
32 #include <agrum/core/types.h>
33 // =======================================================
36 // =======================================================
38 // =======================================================
39 
40 
41 namespace gum {
42 
43  // ############################################################################
44  // Constructor & destructor.
45  // ############################################################################
46 
47  // ============================================================================
48  // Variable Learner constructor
49  // ============================================================================
50  template < TESTNAME AttributeSelection, bool isScalar >
53  double attributeSelectionThreshold,
54  double pairSelectionThreshold,
55  Set< const DiscreteVariable* > attributeListe,
56  const DiscreteVariable* learnedValue) :
57  IncrementalGraphLearner< AttributeSelection, isScalar >(
58  target, attributeListe, learnedValue),
59  __lg(&(this->_model), pairSelectionThreshold), __nbTotalObservation(0),
60  __attributeSelectionThreshold(attributeSelectionThreshold) {
61  GUM_CONSTRUCTOR(IMDDI);
62  __addLeaf(this->_root);
63  }
64 
65  // ============================================================================
66  // Reward Learner constructor
67  // ============================================================================
68  template < TESTNAME AttributeSelection, bool isScalar >
71  double attributeSelectionThreshold,
72  double pairSelectionThreshold,
73  Set< const DiscreteVariable* > attributeListe) :
74  IncrementalGraphLearner< AttributeSelection, isScalar >(
75  target, attributeListe, new LabelizedVariable("Reward", "", 2)),
76  __lg(&(this->_model), pairSelectionThreshold), __nbTotalObservation(0),
77  __attributeSelectionThreshold(attributeSelectionThreshold) {
78  GUM_CONSTRUCTOR(IMDDI);
79  __addLeaf(this->_root);
80  }
81 
82  // ============================================================================
83  // Reward Learner constructor
84  // ============================================================================
85  template < TESTNAME AttributeSelection, bool isScalar >
87  GUM_DESTRUCTOR(IMDDI);
89  __leafMap.beginSafe();
90  leafIter != __leafMap.endSafe();
91  ++leafIter)
92  delete leafIter.val();
93  }
94 
95 
96  // ############################################################################
 97  // Incremental methods
98  // ############################################################################
99 
100  template < TESTNAME AttributeSelection, bool isScalar >
102  const Observation* obs) {
105  }
106 
107  template < TESTNAME AttributeSelection, bool isScalar >
109  const Observation* newObs, NodeId currentNodeId) {
110  IncrementalGraphLearner< AttributeSelection,
111  isScalar >::_updateNodeWithObservation(newObs,
112  currentNodeId);
113  if (this->_nodeVarMap[currentNodeId] == this->_value)
114  __lg.updateLeaf(__leafMap[currentNodeId]);
115  }
116 
117 
118  // ============================================================================
119  // Updates the tree after a new observation has been added
120  // ============================================================================
121  template < TESTNAME AttributeSelection, bool isScalar >
123  __varOrder.clear();
124 
 125  // First we initialize the node set which will give us the scores
126  Set< NodeId > currentNodeSet;
127  currentNodeSet.insert(this->_root);
128 
129  // Then we initialize the pool of variables to consider
130  VariableSelector vs(this->_setOfVars);
131  for (vs.begin(); vs.hasNext(); vs.next()) {
132  __updateScore(vs.current(), this->_root, vs);
133  }
134 
135  // Then, until there's no node remaining
136  while (!vs.isEmpty()) {
137  // We select the best var
138  const DiscreteVariable* selectedVar = vs.select();
139  __varOrder.insert(selectedVar);
140 
141  // Then we decide if we update each node according to this var
142  __updateNodeSet(currentNodeSet, selectedVar, vs);
143  }
144 
 145  // If there are remaining nodes that are not leaves after we establish the
 146  // var order,
 147  // these nodes are turned into leaves.
148  for (SetIteratorSafe< NodeId > nodeIter = currentNodeSet.beginSafe();
149  nodeIter != currentNodeSet.endSafe();
150  ++nodeIter)
151  this->_convertNode2Leaf(*nodeIter);
152 
153 
154  if (__lg.needsUpdate()) __lg.update();
155  }
156 
157 
158  // ############################################################################
159  // Updating methods
160  // ############################################################################
161 
162 
163  // ###################################################################
164  // Select the most relevant variable
165  //
 166  // First parameter is the set of variables among which the most
 167  // relevant one is chosen
 168  // Second parameter is the set of nodes that will attribute a score
 169  // to each variable so that we choose the best.
170  // ###################################################################
171  template < TESTNAME AttributeSelection, bool isScalar >
173  const DiscreteVariable* var, NodeId nody, VariableSelector& vs) {
174  if (!this->_nodeId2Database[nody]->isTestRelevant(var)) return;
175  double weight = (double)this->_nodeId2Database[nody]->nbObservation()
176  / (double)this->__nbTotalObservation;
177  vs.updateScore(var,
178  weight * this->_nodeId2Database[nody]->testValue(var),
179  weight * this->_nodeId2Database[nody]->testOtherCriterion(var));
180  }
181 
182  template < TESTNAME AttributeSelection, bool isScalar >
184  const DiscreteVariable* var, NodeId nody, VariableSelector& vs) {
185  if (!this->_nodeId2Database[nody]->isTestRelevant(var)) return;
186  double weight = (double)this->_nodeId2Database[nody]->nbObservation()
187  / (double)this->__nbTotalObservation;
188  vs.downdateScore(var,
189  weight * this->_nodeId2Database[nody]->testValue(var),
190  weight
191  * this->_nodeId2Database[nody]->testOtherCriterion(var));
192  }
193 
194 
195  // ============================================================================
 196  // For each node in the given set, this method checks whether or not
 197  // we should install the given variable as a test.
198  // If so, the node is updated
199  // ============================================================================
200  template < TESTNAME AttributeSelection, bool isScalar >
202  Set< NodeId >& nodeSet,
203  const DiscreteVariable* selectedVar,
204  VariableSelector& vs) {
205  Set< NodeId > oldNodeSet(nodeSet);
206  nodeSet.clear();
207  for (SetIteratorSafe< NodeId > nodeIter = oldNodeSet.beginSafe();
208  nodeIter != oldNodeSet.endSafe();
209  ++nodeIter) {
210  if (this->_nodeId2Database[*nodeIter]->isTestRelevant(selectedVar)
211  && this->_nodeId2Database[*nodeIter]->testValue(selectedVar)
213  this->_transpose(*nodeIter, selectedVar);
214 
215  // Then we subtract the from the score given to each variables the
216  // quantity given by this node
217  for (vs.begin(); vs.hasNext(); vs.next()) {
218  __downdateScore(vs.current(), *nodeIter, vs);
219  }
220 
221  // And finally we add all its child to the new set of nodes
222  // and updates the remaining var's score
223  for (Idx modality = 0;
224  modality < this->_nodeVarMap[*nodeIter]->domainSize();
225  ++modality) {
226  NodeId sonId = this->_nodeSonsMap[*nodeIter][modality];
227  nodeSet << sonId;
228 
229  for (vs.begin(); vs.hasNext(); vs.next()) {
230  __updateScore(vs.current(), sonId, vs);
231  }
232  }
233  } else {
234  nodeSet << *nodeIter;
235  }
236  }
237  }
238 
239 
240  // ============================================================================
241  // Insert a new node with given associated database, var and maybe sons
242  // ============================================================================
243  template < TESTNAME AttributeSelection, bool isScalar >
246  const DiscreteVariable* boundVar,
247  Set< const Observation* >* obsSet) {
248  NodeId currentNodeId =
250  nDB, boundVar, obsSet);
251 
252  __addLeaf(currentNodeId);
253 
254  return currentNodeId;
255  }
256 
257 
258  // ============================================================================
259  // Changes var associated to a node
260  // ============================================================================
261  template < TESTNAME AttributeSelection, bool isScalar >
263  NodeId currentNodeId, const DiscreteVariable* desiredVar) {
264  if (this->_nodeVarMap[currentNodeId] == this->_value)
265  __removeLeaf(currentNodeId);
266 
268  currentNodeId, desiredVar);
269 
270  if (desiredVar == this->_value) __addLeaf(currentNodeId);
271  }
272 
273 
274  // ============================================================================
275  // Remove node from graph
276  // ============================================================================
277  template < TESTNAME AttributeSelection, bool isScalar >
279  if (this->_nodeVarMap[currentNodeId] == this->_value)
280  __removeLeaf(currentNodeId);
282  currentNodeId);
283  }
284 
285 
286  // ============================================================================
287  // Add leaf to aggregator
288  // ============================================================================
289  template < TESTNAME AttributeSelection, bool isScalar >
291  __leafMap.insert(currentNodeId,
293  currentNodeId,
294  this->_nodeId2Database[currentNodeId],
295  &(this->_valueAssumed)));
296  __lg.addLeaf(__leafMap[currentNodeId]);
297  }
298 
299 
300  // ============================================================================
301  // Remove leaf from aggregator
302  // ============================================================================
303  template < TESTNAME AttributeSelection, bool isScalar >
305  __lg.removeLeaf(__leafMap[currentNodeId]);
306  delete __leafMap[currentNodeId];
307  __leafMap.erase(currentNodeId);
308  }
309 
310 
311  // ============================================================================
312  // Computes the Reduced and Ordered Function Graph associated to this ordered
313  // tree
314  // ============================================================================
315  template < TESTNAME AttributeSelection, bool isScalar >
317  // if( __lg.needsUpdate() || this->_needUpdate ){
319  this->_needUpdate = false;
320  // }
321  }
322 
323 
324  // ============================================================================
325  // Performs the leaves merging
326  // ============================================================================
327  template < TESTNAME AttributeSelection, bool isScalar >
329  // *******************************************************************************************************
 330  // Update of the leaf aggregator
331  __lg.update();
332 
333  // *******************************************************************************************************
 334  // Reinitialization of the decision graph
335  this->_target->clear();
336  for (auto varIter = __varOrder.beginSafe(); varIter != __varOrder.endSafe();
337  ++varIter)
338  this->_target->add(**varIter);
339  this->_target->add(*this->_value);
340 
342 
343  // *******************************************************************************************************
 344  // Insertion of the leaves
348  treeNode2leaf.cbeginSafe();
349  treeNodeIter != treeNode2leaf.cendSafe();
350  ++treeNodeIter) {
351  if (!leaf2DGNode.exists(treeNodeIter.val()))
352  leaf2DGNode.insert(treeNodeIter.val(),
353  __insertLeafInFunctionGraph(treeNodeIter.val(),
355 
356  toTarget.insert(treeNodeIter.key(), leaf2DGNode[treeNodeIter.val()]);
357  }
358 
359  // *******************************************************************************************************
 360  // Insertion of the internal nodes (checking for merging
 361  // opportunities)
363  __varOrder.rbeginSafe();
364  varIter != __varOrder.rendSafe();
365  --varIter) {
366  for (Link< NodeId >* curNodeIter = this->_var2Node[*varIter]->list();
367  curNodeIter;
368  curNodeIter = curNodeIter->nextLink()) {
369  NodeId* sonsMap = static_cast< NodeId* >(
370  SOA_ALLOCATE(sizeof(NodeId) * (*varIter)->domainSize()));
371  for (Idx modality = 0; modality < (*varIter)->domainSize(); ++modality)
372  sonsMap[modality] =
373  toTarget[this->_nodeSonsMap[curNodeIter->element()][modality]];
374  toTarget.insert(
375  curNodeIter->element(),
376  this->_target->manager()->addInternalNode(*varIter, sonsMap));
377  }
378  }
379 
380  // *******************************************************************************************************
381  // Polish
382  this->_target->manager()->setRootNode(toTarget[this->_root]);
383  this->_target->manager()->clean();
384  }
385 
386 
387  // ============================================================================
 388  // Inserts a leaf into the target function graph (scalar case)
389  // ============================================================================
390  template < TESTNAME AttributeSelection, bool isScalar >
393  double value = 0.0;
394  for (Idx moda = 0; moda < leaf->nbModa(); moda++) {
395  value += (double)leaf->effectif(moda) * this->_valueAssumed.atPos(moda);
396  }
397  if (leaf->total()) value /= (double)leaf->total();
398  return this->_target->manager()->addTerminalNode(value);
399  }
400 
401 
402  // ============================================================================
 403  // Inserts a leaf into the target function graph (non-scalar case)
404  // ============================================================================
405  template < TESTNAME AttributeSelection, bool isScalar >
408  NodeId* sonsMap = static_cast< NodeId* >(
409  SOA_ALLOCATE(sizeof(NodeId) * this->_value->domainSize()));
410  for (Idx modality = 0; modality < this->_value->domainSize(); ++modality) {
411  double newVal = 0.0;
412  if (leaf->total())
413  newVal = (double)leaf->effectif(modality) / (double)leaf->total();
414  sonsMap[modality] = this->_target->manager()->addTerminalNode(newVal);
415  }
416  return this->_target->manager()->addInternalNode(this->_value, sonsMap);
417  }
418 } // namespace gum
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
HashTable< NodeId, NodeId *> _nodeSonsMap
A table that, for each node, maps every modality of its associated variable to the corresponding son node...
Safe iterators for Sequence.
Definition: sequence.h:1206
void updateGraph()
Updates the tree after a new observation has been added.
Definition: imddi_tpl.h:122
LeafAggregator __lg
Definition: imddi.h:170
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
NodeId __insertLeafInFunctionGraph(AbstractLeaf *, Int2Type< true >)
Inserts a leaf into the target function graph (scalar case).
Definition: imddi_tpl.h:391
HashTable< const DiscreteVariable *, LinkedList< NodeId > *> _var2Node
Associates to any variable the list of all nodes associated to this variable.
void __removeLeaf(NodeId)
Removes a leaf from the aggregator.
Definition: imddi_tpl.h:304
virtual double effectif(Idx) const =0
Gaves the leaf effectif for given modality.
Set< const DiscreteVariable *> _setOfVars
void addObservation(const Observation *)
Adds a new observation to the structure.
Definition: imddi_tpl.h:101
void _updateNodeWithObservation(const Observation *newObs, NodeId currentNodeId)
Updates the given node's statistics with the new observation.
Definition: imddi_tpl.h:108
void clean()
Removes var without nodes in the diagram.
HashTable< NodeId, NodeDatabase< AttributeSelection, isScalar > *> _nodeId2Database
This hashtable binds every node to an associated NodeDatabase which handles every observation that co...
void setRootNode(const NodeId &root)
Sets root node of decision diagram.
void _removeNode(NodeId removedNodeId)
Removes a node from the internal graph.
Definition: imddi_tpl.h:278
class LabelizedVariable
Safe iterators for the Set classDevelopers may consider using Set<x>::iterator_safe instead of SetIte...
Definition: set.h:811
void __updateScore(const DiscreteVariable *, NodeId, VariableSelector &vs)
Computes the score of the given variables for the given node.
Definition: imddi_tpl.h:172
<agrum/FMDP/learning/datastructure/incrementalGraphLearner>
const_iterator_safe cbeginSafe() const
Returns the safe const_iterator pointing to the beginning of the hashtable.
void updateFunctionGraph()
Computes the reduced and ordered function graph associated to this ordered tree.
Definition: imddi_tpl.h:316
<agrum/FMDP/learning/datastructure/leaves/abstractLeaf.h>
Definition: abstractLeaf.h:53
Safe Const Iterators for hashtables.
Definition: hashTable.h:1918
void downdateScore(const DiscreteVariable *var, double score, double secondaryscore)
The set of remaining vars to select among.
virtual void _convertNode2Leaf(NodeId)
Turns the given node into a leaf if not already so.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
NodeId addInternalNode(const DiscreteVariable *var)
Inserts a new non terminal node in graph.
Base class for discrete random variable.
Safe Iterators for hashtables.
Definition: hashTable.h:2220
void __rebuildFunctionGraph()
Rebuilds the target function graph from the ordered tree.
Definition: imddi_tpl.h:328
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
MultiDimFunctionGraph< double > * _target
The final diagram we&#39;re building.
HashTable< NodeId, AbstractLeaf *> __leafMap
Definition: imddi.h:172
virtual NodeId _insertLeafNode(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, Set< const Observation * > *obsSet)
inserts a new leaf node in internal graphs
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Sequence< const DiscreteVariable *> __varOrder
Definition: imddi.h:168
void _chgNodeBoundVar(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the variable bound to a node.
Definition: imddi_tpl.h:262
Representation of a setA Set is a structure that contains arbitrary elements.
Definition: set.h:165
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
Definition: set_tpl.h:502
virtual Size domainSize() const =0
void begin()
The set of remaining vars to select among.
virtual Idx nbModa() const =0
virtual void add(const DiscreteVariable &v)
Adds a new var to the variables of the multidimensional matrix.
virtual void addObservation(const Observation *obs)
Inserts a new observation.
const const_iterator_safe & cendSafe() const noexcept
Returns the safe const_iterator pointing to the end of the hashtable.
<agrum/FMDP/learning/datastructure/leaves/concreteLeaf.h>
Definition: concreteLeaf.h:58
virtual void _transpose(NodeId, const DiscreteVariable *)
Installs given variable to the given node, ensuring that the variable is not present in its subtree...
~IMDDI()
Default destructor.
Definition: imddi_tpl.h:86
bool hasNext()
The set of remaining vars to select among.
void __addLeaf(NodeId)
Adds a leaf to the aggregator.
Definition: imddi_tpl.h:290
virtual void _removeNode(NodeId removedNodeId)
Removes a node from the internal graph.
void __downdateScore(const DiscreteVariable *, NodeId, VariableSelector &vs)
Removes the given node's contribution from the score of the given variable.
Definition: imddi_tpl.h:183
const DiscreteVariable * select()
Select the most relevant variable.
HashTable< NodeId, AbstractLeaf *> leavesMap()
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
IMDDI(MultiDimFunctionGraph< double > *target, double attributeSelectionThreshold, double pairSelectionThreshold, Set< const DiscreteVariable * > attributeListe, const DiscreteVariable *learnedValue)
Variable Learner constructor.
Definition: imddi_tpl.h:51
<agrum/FMDP/planning/FunctionGraph/variableselector.h>
virtual void _chgNodeBoundVar(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
Definition: set_tpl.h:488
bool isEmpty()
The set of remaining vars to select among.
void updateScore(const DiscreteVariable *var, double score, double secondaryscore)
The set of remaining vars to select among.
NodeId _root
The root of the ordered tree.
double __attributeSelectionThreshold
The threshold above which we consider variables to be dependant.
Definition: imddi.h:178
NodeId addTerminalNode(const GUM_SCALAR &value)
Adds a value to the MultiDimFunctionGraph.
Size Idx
Type for indexes.
Definition: types.h:53
void removeLeaf(AbstractLeaf *)
void next()
The set of remaining vars to select among.
void clear()
Removes all the elements, if any, from the set.
Definition: set_tpl.h:375
MultiDimFunctionGraphManager< GUM_SCALAR, TerminalNodePolicy > * manager()
Returns a const reference to the manager of this diagram.
NodeGraphPart _model
The source of nodeId.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
virtual double total() const =0
bool updateLeaf(AbstractLeaf *)
NodeId _insertLeafNode(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, Set< const Observation * > *sonsMap)
Inserts a new leaf node with the given database, variable and observation set.
Definition: imddi_tpl.h:244
void addLeaf(AbstractLeaf *)
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const DiscreteVariable * current()
The set of remaining vars to select among.
HashTable< NodeId, const DiscreteVariable *> _nodeVarMap
Gives for any node its associated variable.
Size NodeId
Type for node ids.
Definition: graphElements.h:98
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:613
void __updateNodeSet(Set< NodeId > &, const DiscreteVariable *, VariableSelector &)
For each node in the given set, this method checks whether or not we should install the given vari...
Definition: imddi_tpl.h:201
void clear()
Clears the function graph.
<agrum/FMDP/learning/datastructure/nodeDatabase.h>
Definition: nodeDatabase.h:58
#define SOA_ALLOCATE(x)
Idx __nbTotalObservation
The total number of observation added to this tree.
Definition: imddi.h:175