aGrUM  0.14.2
iti_tpl.h
Go to the documentation of this file.
1 /*********************************************************************************
2  * Copyright (C) 2005 by Pierre-Henri WUILLEMIN et Christophe GONZALES *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  *********************************************************************************/
26 // =======================================================
27 #include <agrum/core/math/math.h>
29 #include <agrum/core/types.h>
30 // =======================================================
33 // =======================================================
35 // =======================================================
36 
37 
38 namespace gum {
39 
40  // ==========================================================================
42  // ==========================================================================
43 
44  // ###################################################################
56  // ###################################################################
57  template < TESTNAME AttributeSelection, bool isScalar >
60  double attributeSelectionThreshold,
61  Set< const DiscreteVariable* > attributeListe,
62  const DiscreteVariable* learnedValue) :
63  IncrementalGraphLearner< AttributeSelection, isScalar >(
64  target, attributeListe, learnedValue),
65  __nbTotalObservation(0),
66  __attributeSelectionThreshold(attributeSelectionThreshold) {
67  GUM_CONSTRUCTOR(ITI);
68  __staleTable.insert(this->_root, false);
69  }
70 
71  // ###################################################################
82  // ###################################################################
83  template < TESTNAME AttributeSelection, bool isScalar >
86  double attributeSelectionThreshold,
87  Set< const DiscreteVariable* > attributeListe) :
88  IncrementalGraphLearner< AttributeSelection, isScalar >(
89  target, attributeListe, new LabelizedVariable("Reward", "", 2)),
91  __attributeSelectionThreshold(attributeSelectionThreshold) {
92  GUM_CONSTRUCTOR(ITI);
93  __staleTable.insert(this->_root, false);
94  }
95 
96 
97  // ==========================================================================
99  // ==========================================================================
100 
101  // ############################################################################
106  // ############################################################################
107  template < TESTNAME AttributeSelection, bool isScalar >
108  void
112  }
113 
114  // ############################################################################
121  // ############################################################################
122  template < TESTNAME AttributeSelection, bool isScalar >
124  const Observation* newObs, NodeId currentNodeId) {
125  IncrementalGraphLearner< AttributeSelection,
126  isScalar >::_updateNodeWithObservation(newObs,
127  currentNodeId);
128  __staleTable[currentNodeId] = true;
129  }
130 
131 
132  // ============================================================================
134  // ============================================================================
135 
136  // ############################################################################
138  // ############################################################################
// updateGraph(): re-evaluates the internal tree starting from the root.
// `filo` is used as a LIFO stack (push_back / back / pop_back), so nodes are
// visited depth-first. For each visited node, every relevant candidate variable
// is scored with testValue(); the variables tied for the highest score are kept
// only if that score reaches __attributeSelectionThreshold (comparison is >=).
// The resulting set is handed to _updateNode. Then, if the node is an internal
// test node (not bound to this->_value), its sons flagged in __staleTable are
// pushed, each with a copy of the parent's candidate set minus the variable
// just installed.
139  template < TESTNAME AttributeSelection, bool isScalar >
141  std::vector< NodeId > filo;
142  filo.push_back(this->_root);
144  potentialVars.insert(this->_root,
146 
147 
148  while (!filo.empty()) {
149  NodeId currentNodeId = filo.back();
150  filo.pop_back();
151 
152  // First, look for the best variable(s) to install on the node: keep every
152  // relevant variable tied for the highest testValue, provided that value is
152  // at least __attributeSelectionThreshold (bestValue starts at the threshold).
153  double bestValue = __attributeSelectionThreshold;
155 
156  for (auto varIter = potentialVars[currentNodeId]->cbeginSafe();
157  varIter != potentialVars[currentNodeId]->cendSafe();
158  ++varIter)
159  if (this->_nodeId2Database[currentNodeId]->isTestRelevant(*varIter)) {
160  double varValue =
161  this->_nodeId2Database[currentNodeId]->testValue(*varIter);
162  if (varValue >= bestValue) {
163  if (varValue > bestValue) {
164  bestValue = varValue;
165  bestVars.clear();
166  }
167  bestVars.insert(*varIter);
168  }
169  }
170 
171  // Then install a test variable on this node, chosen by _updateNode from bestVars
172  this->_updateNode(currentNodeId, bestVars);
173 
174  // Then move on to the children, but only if the node is now an internal test
174  // node (i.e. not bound to the learned value this->_value)
175  if (this->_nodeVarMap[currentNodeId] != this->_value) {
176  for (Idx moda = 0; moda < this->_nodeVarMap[currentNodeId]->domainSize();
177  moda++) {
// NOTE(review): a new Set is heap-allocated for every modality, but it is only
// stored in potentialVars (and therefore freed by the cleanup loop below) when
// __staleTable[sonId] is true. For non-stale sons the pointer is dropped and the
// Set leaks. Allocating it inside the `if (__staleTable[sonId])` branch (or
// using a value type) would avoid this.
178  Set< const DiscreteVariable* >* itsPotentialVars =
179  new Set< const DiscreteVariable* >(*potentialVars[currentNodeId]);
180  itsPotentialVars->erase(this->_nodeVarMap[currentNodeId]);
181  NodeId sonId = this->_nodeSonsMap[currentNodeId][moda];
// NOTE(review): none of the lines visible here resets __staleTable[...] to
// false after a node is processed; if nothing else does, stale sons will be
// re-evaluated on every call to updateGraph(). Confirm against the full file.
182  if (__staleTable[sonId]) {
183  filo.push_back(sonId);
184  potentialVars.insert(sonId, itsPotentialVars);
185  }
186  }
187  }
188  }
189 
191  nodeIter = potentialVars.beginSafe();
192  nodeIter != potentialVars.endSafe();
193  ++nodeIter)
194  delete nodeIter.val();  // frees only the Sets that were stored in potentialVars
195  }
196 
197 
198  // ############################################################################
205  // ############################################################################
206  template < TESTNAME AttributeSelection, bool isScalar >
209  const DiscreteVariable* boundVar) {
210  NodeId n =
212  nDB, boundVar);
213  __staleTable.insert(n, true);
214  return n;
215  }
216 
217 
218  // ############################################################################
224  // ############################################################################
225  template < TESTNAME AttributeSelection, bool isScalar >
227  NodeId currentNodeId, const DiscreteVariable* desiredVar) {
228  if (this->_nodeVarMap[currentNodeId] != desiredVar) {
229  __staleTable[currentNodeId] = true;
231  currentNodeId, desiredVar);
232  }
233  }
234 
235 
236  // ############################################################################
241  // ############################################################################
242  template < TESTNAME AttributeSelection, bool isScalar >
245  currentNodeId);
246  __staleTable.erase(currentNodeId);
247  }
248 
249 
250  // ============================================================================
252  // ============================================================================
253 
254  // ############################################################################
256  // ############################################################################
257  template < TESTNAME AttributeSelection, bool isScalar >
259  this->_target->clear();
260  this->_target->manager()->setRootNode(
261  this->__insertNodeInFunctionGraph(this->_root));
262  }
263 
264 
265  // ############################################################################
271  // ############################################################################
272  template < TESTNAME AttributeSelection, bool isScalar >
274  NodeId currentNodeId) {
275  if (this->_nodeVarMap[currentNodeId] == this->_value) {
276  NodeId nody = __insertTerminalNode(currentNodeId);
277  return nody;
278  }
279 
280  if (!this->_target->variablesSequence().exists(
281  this->_nodeVarMap[currentNodeId])) {
282  this->_target->add(*(this->_nodeVarMap[currentNodeId]));
283  }
284 
285  NodeId nody =
286  this->_target->manager()->addInternalNode(this->_nodeVarMap[currentNodeId]);
287  for (Idx moda = 0; moda < this->_nodeVarMap[currentNodeId]->domainSize();
288  ++moda) {
290  this->_nodeSonsMap[currentNodeId][moda]);
291  this->_target->manager()->setSon(nody, moda, son);
292  }
293 
294  return nody;
295  }
296 
297 
298  // ############################################################################
306  // ############################################################################
// __insertTerminalNode, non-scalar case (Int2Type<false>): turns the leaf
// currentNodeId into target structure. It first makes sure this->_value is one of
// the target's variables. It then creates an internal node on this->_value whose
// son for each modality is a terminal holding
// effectif(modality) / nbObservation(), i.e. the observed frequency of that
// modality at this leaf. Returns the id of the created node in the target.
307  template < TESTNAME AttributeSelection, bool isScalar >
309  NodeId currentNodeId, Int2Type< false >) {
310  if (!this->_target->variablesSequence().exists(this->_value))
311  this->_target->add(*(this->_value));
312 
313  Size tot = this->_nodeId2Database[currentNodeId]->nbObservation();
// With no observation at this leaf, a single terminal 0.0 is returned instead of
// an internal node on this->_value (this->_value has already been added to the
// target above in that case).
314  if (tot == Size(0)) return this->_target->manager()->addTerminalNode(0.0);
315 
// NOTE(review): sonsMap is obtained from SOA_ALLOCATE and handed to
// addInternalNode without being released here; presumably the manager takes
// ownership of the array — TODO confirm in MultiDimFunctionGraphManager.
316  NodeId* sonsMap = static_cast< NodeId* >(
317  SOA_ALLOCATE(sizeof(NodeId) * this->_value->domainSize()));
318  for (Idx modality = 0; modality < this->_value->domainSize(); ++modality) {
319  double newVal = 0.0;
320  newVal = (double)this->_nodeId2Database[currentNodeId]->effectif(modality)
321  / (double)tot;
322  sonsMap[modality] = this->_target->manager()->addTerminalNode(newVal);
323  }
324  NodeId nody = this->_target->manager()->addInternalNode(this->_value, sonsMap);
325  return nody;
326  }
327 
328 
329  // ############################################################################
337  // ############################################################################
// __insertTerminalNode, scalar case (Int2Type<true>): collapses the leaf
// currentNodeId into a single terminal node of the target. The stored value is
// sum(key * val) over the node database's (key, val) value entries. When
// nbObservation() is non-zero, that sum is divided by nbObservation(). If val is
// an occurrence count (assumption, TODO confirm in NodeDatabase), this is the
// mean observed value. Returns the id of the created terminal.
338  template < TESTNAME AttributeSelection, bool isScalar >
340  NodeId currentNodeId, Int2Type< true >) {
341  double value = 0.0;
342  for (auto valIter = this->_nodeId2Database[currentNodeId]->cbeginValues();
343  valIter != this->_nodeId2Database[currentNodeId]->cendValues();
344  ++valIter) {
345  value += (double)valIter.key() * valIter.val();
346  }
347  if (this->_nodeId2Database[currentNodeId]->nbObservation())  // guard against division by zero
348  value /= (double)this->_nodeId2Database[currentNodeId]->nbObservation();
349  NodeId nody = this->_target->manager()->addTerminalNode(value);
350  return nody;
351  }
352 } // namespace gum
Useful macros for maths.
HashTable< NodeId, NodeId *> _nodeSonsMap
A table giving, for any node, a table mapping to its sons; the index is the modality of the associated variable...
void updateFunctionGraph()
Updates target to currently learned graph structure.
Definition: iti_tpl.h:258
NodeId _insertNode(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar)
inserts a new node in internal graph
Definition: iti_tpl.h:207
Provides basic types used in aGrUM.
void setSon(const NodeId &node, const Idx &modality, const NodeId &sonNode)
Sets nodes son for given modality to designated son node.
Set< const DiscreteVariable *> _setOfVars
NodeId __insertNodeInFunctionGraph(NodeId src)
Inserts an internal node in the target.
Definition: iti_tpl.h:273
HashTable< NodeId, NodeDatabase< AttributeSelection, isScalar > *> _nodeId2Database
This hashtable binds every node to an associated NodeDatabase which handles every observation that co...
double __attributeSelectionThreshold
The threshold above which we consider variables to be dependent.
Definition: iti.h:259
Learn a graphical representation of a function as a decision tree.
Definition: iti.h:59
void setRootNode(const NodeId &root)
Sets root node of decision diagram.
void _updateNode(NodeId nody, Set< const DiscreteVariable * > &bestVars)
From the given set of variables, randomly selects one and installs it on the given node.
class LabelizedVariable
const iterator_safe & endSafe() noexcept
Returns the safe iterator pointing to the end of the hashtable.
<agrum/FMDP/learning/datastructure/incrementalGraphLearner>
void erase(const Key &key)
Removes a given element from the hash table.
virtual NodeId _insertNode(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar)
inserts a new node in internal graph
HashTable< NodeId, bool > __staleTable
Hashtable indicating whether the given node has been modified (upon receiving a new example or through a transpo...
Definition: iti.h:253
Headers of the ITI class.
void updateGraph()
Updates the internal graph after a new observation has been added.
Definition: iti_tpl.h:140
NodeId addInternalNode(const DiscreteVariable *var)
Inserts a new non terminal node in graph.
void erase(const Key &k)
Erases an element from the set.
Definition: set_tpl.h:653
Base class for discrete random variable.
Safe Iterators for hashtables.
Definition: hashTable.h:2217
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
MultiDimFunctionGraph< double > * _target
The final diagram we're building.
Headers of the ChiSquare class.
The class for generic Hash Tables.
Definition: hashTable.h:676
Idx __nbTotalObservation
The total number of observation added to this tree.
Definition: iti.h:256
void _updateNodeWithObservation(const Observation *newObs, NodeId currentNodeId)
Updates the internal graph's NodeDatabase of the given node with the new observation. ...
Definition: iti_tpl.h:123
Representation of a set. A Set is a structure that contains arbitrary elements.
Definition: set.h:162
ITI(MultiDimFunctionGraph< double > *target, double attributeSelectionThreshold, Set< const DiscreteVariable * > attributeListe, const DiscreteVariable *learnedValue)
ITI constructor for functions describing the behaviour of one variable according to a set of other va...
Definition: iti_tpl.h:58
virtual Size domainSize() const =0
virtual void add(const DiscreteVariable &v)
Adds a new var to the variables of the multidimensional matrix.
virtual void addObservation(const Observation *obs)
Inserts a new observation.
const const_iterator_safe & cendSafe() const noexcept
Returns the safe const_iterator pointing to the end of the hashtable.
void _removeNode(NodeId removedNodeId)
Removes a node from the internal graph.
Definition: iti_tpl.h:243
virtual void _removeNode(NodeId removedNodeId)
Removes a node from the internal graph.
priority queues (in which an element cannot appear more than once)
virtual const Sequence< const DiscreteVariable *> & variablesSequence() const override
Returns a const ref to the sequence of DiscreteVariable*.
virtual void _chgNodeBoundVar(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
iterator_safe beginSafe()
Returns the safe iterator pointing to the beginning of the hashtable.
NodeId __insertTerminalNode(NodeId src)
Insert a terminal node in the target.
Definition: iti.h:205
NodeId _root
The root of the ordered tree.
void _chgNodeBoundVar(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
Definition: iti_tpl.h:226
NodeId addTerminalNode(const GUM_SCALAR &value)
Adds a value to the MultiDimFunctionGraph.
Size Idx
Type for indexes.
Definition: types.h:50
void clear()
Removes all the elements, if any, from the set.
Definition: set_tpl.h:372
MultiDimFunctionGraphManager< GUM_SCALAR, TerminalNodePolicy > * manager()
Returns a pointer to the manager of this diagram.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:45
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
Base class for labelized discrete random variables.
HashTable< NodeId, const DiscreteVariable *> _nodeVarMap
Gives for any node its associated variable.
void addObservation(const Observation *obs)
Inserts a new observation.
Definition: iti_tpl.h:109
Size NodeId
Type for node ids.
Definition: graphElements.h:97
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:610
void clear()
Clears the function graph.
<agrum/FMDP/learning/datastructure/nodeDatabase.h>
Definition: nodeDatabase.h:55
#define SOA_ALLOCATE(x)