aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
iti_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file Template Implementations of the ITI datastructure learner
24  * @brief
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
27  * GONZALES(@AMU)
28  */
29 // =======================================================
30 #include <agrum/tools/core/math/math_utils.h>
31 #include <agrum/tools/core/priorityQueue.h>
32 #include <agrum/tools/core/types.h>
33 // =======================================================
34 #include <agrum/FMDP/learning/core/chiSquare.h>
35 #include <agrum/FMDP/learning/datastructure/iti.h>
36 // =======================================================
37 #include <agrum/tools/variables/labelizedVariable.h>
38 // =======================================================
39 
40 
41 namespace gum {
42 
43  // ==========================================================================
44  /// @name Constructor & destructor.
45  // ==========================================================================
46 
47  // ###################################################################
48  /**
49  * ITI constructor for functions describing the behaviour of one variable
50  * according to a set of other variables, such as conditional probabilities
51  * @param target : the MultiDimFunctionGraph in which we load the structure
52  * @param attributeSelectionThreshold : threshold under which a node is not
53  * installed (pre-pruning)
54  * @param temporaryAPIfix : Issue in API in regard to IMDDI
55  * @param attributeListe : Set of vars on which we rely to explain the
56  * behaviour of learned variable
57  * @param learnedValue : the variable from which we try to learn the behaviour
58  */
59  // ###################################################################
60  template < TESTNAME AttributeSelection, bool isScalar >
61  ITI< AttributeSelection, isScalar >::ITI(MultiDimFunctionGraph< double >* target,
// NOTE(review): the extraction jumps from source line 61 to 68 — the rest of the
// constructor's parameter list, its base-class/member initializer list, and the
// opening of the body were dropped. Verify against upstream iti_tpl.h.
// The freshly created root starts out non-stale: no observation has been seen yet,
// so there is nothing to re-evaluate on the next updateGraph() pass.
68  _staleTable_.insert(this->root_, false);
69  }
70 
71  // ###################################################################
72  /**
73  * ITI constructor for real functions. We try to predict the output of a
74  * function f given a set of variable
75  * @param target : the MultiDimFunctionGraph in which we load the structure
76  * @param attributeSelectionThreshold : threshold under which a node is not
77  * installed (pre-pruning)
78  * @param temporaryAPIfix : Issue in API in regard to IMDDI
79  * @param attributeListe : Set of vars on which we rely to explain the
80  * behaviour of learned function
81  */
82  // ###################################################################
83  template < TESTNAME AttributeSelection, bool isScalar >
// NOTE(review): lines 84-87, 89 and 91-92 of the original file are missing from
// this extraction — the constructor signature and most of the initializer list
// were dropped. What remains suggests this overload forwards `target` and a
// freshly allocated binary LabelizedVariable("Reward") to a delegated/base
// constructor; confirm against upstream iti_tpl.h.
88  target,
90  new LabelizedVariable("Reward", "", 2)),
// As in the other constructor, the root node is initially marked non-stale.
93  _staleTable_.insert(this->root_, false);
94  }
95 
96 
97  // ==========================================================================
98  /// @name New Observation insertion methods
99  // ==========================================================================
100 
101  // ############################################################################
102  /**
103  * Inserts a new observation
104  * @param the new observation to learn
105  */
106  // ############################################################################
// NOTE(review): only the template header and closing brace of this method survived
// the extraction (source lines 108-110 missing). Per the preceding Doxygen block it
// inserts a new observation into the learner — presumably ITI::addObservation;
// verify against upstream iti_tpl.h.
107  template < TESTNAME AttributeSelection, bool isScalar >
111  }
112 
113  // ############################################################################
114  /**
115  * Will update internal graph's NodeDatabase of given node with the new
116  * observation
117  * @param newObs
118  * @param currentNodeId
119  */
120  // ############################################################################
121  template < TESTNAME AttributeSelection, bool isScalar >
// NOTE(review): source lines 122-124 are missing — the method signature and the
// start of the call that receives (newObs, currentNodeId) were dropped; the
// surviving arguments suggest the node's NodeDatabase is updated with the new
// observation. Verify against upstream iti_tpl.h.
125  newObs,
126  currentNodeId);
// Receiving a new observation invalidates this node's cached statistics, so it
// is flagged stale for the next updateGraph() pass.
127  _staleTable_[currentNodeId] = true;
128  }
129 
130 
131  // ============================================================================
132  /// @name Graph Structure update methods
133  // ============================================================================
134 
135  // ############################################################################
136  /// Updates the internal graph after a new observation has been added
137  // ############################################################################
// Re-evaluates the learned tree top-down after new observations: for every
// reachable node it re-selects the best test variable, then recurses into the
// children of nodes whose statistics went stale.
// NOTE(review): this extraction is missing many source lines (142, 147, 151,
// 154-155, 157-158, 161, 164, 169, 174, 176-177, 179-181, 186-188) — variable
// declarations such as `currentNodeId`, `varValue`, `bestValue`, `sonId` and the
// iteration/installation statements were dropped. Verify against upstream iti_tpl.h.
138  template < TESTNAME AttributeSelection, bool isScalar >
// `filo` is an explicit LIFO work-list (depth-first traversal), seeded with the root.
140  std::vector< NodeId > filo;
141  filo.push_back(this->root_);
// Each node is associated with the set of variables still available for testing
// below it; the root may use every variable the learner knows about.
143  potentialVars.insert(this->root_, new Set< const DiscreteVariable* >(this->setOfVars_));
144 
145 
146  while (!filo.empty()) {
148  filo.pop_back();
149 
150  // First we look for the best var to install on the node
152  Set< const DiscreteVariable* > bestVars;
153 
156  ++varIter)
// Keep every variable tied for the best score; a strictly better score resets
// the candidate set.
159  if (varValue >= bestValue) {
160  if (varValue > bestValue) {
162  bestVars.clear();
163  }
165  }
166  }
167 
168  // Then we install a test on this node for the chosen variable
170 
171  // Then we move on to the children if needed
// A node bound to the learned value variable is a leaf: no children to visit.
172  if (this->nodeVarMap_[currentNodeId] != this->value_) {
173  for (Idx moda = 0; moda < this->nodeVarMap_[currentNodeId]->domainSize(); moda++) {
// Each child inherits a copy of the parent's remaining candidate variables.
175  = new Set< const DiscreteVariable* >(*potentialVars[currentNodeId]);
// Only stale children need re-examination; clean subtrees are skipped.
178  if (_staleTable_[sonId]) {
181  }
182  }
183  }
184  }
185 
// Release the per-node candidate sets allocated during the traversal.
189  ++nodeIter)
190  delete nodeIter.val();
191  }
192 
193 
194  // ############################################################################
195  /**
196  * inserts a new node in internal graphs
197  * @param nDB : the associated database
198  * @param boundVar : the associated variable
199  * @return the newly created node's id
200  */
201  // ############################################################################
202  template < TESTNAME AttributeSelection, bool isScalar >
// NOTE(review): source lines 203-204 and 206 are missing — the method name, the
// first parameter (the NodeDatabase per the preceding Doxygen block) and the
// statement creating node `n` were dropped. Verify against upstream iti_tpl.h.
205  const DiscreteVariable* boundVar) {
// A brand-new node has never been evaluated, so it is born stale.
207  _staleTable_.insert(n, true);
208  return n;
209  }
210 
211 
212  // ############################################################################
213  /**
214  * Changes the associated variable of a node
215  * @param chgedNodeId : the node to change
216  * @param desiredVar : its new associated variable
217  */
218  // ############################################################################
219  template < TESTNAME AttributeSelection, bool isScalar >
// NOTE(review): source lines 220 and 224 are missing — the method name/first
// parameter and the start of the delegated call that receives `desiredVar` were
// dropped. Verify against upstream iti_tpl.h.
221  const DiscreteVariable* desiredVar) {
// Rebinding is a no-op when the node already tests the desired variable; only a
// real change marks the node stale and forwards the rebind.
222  if (this->nodeVarMap_[currentNodeId] != desiredVar) {
223  _staleTable_[currentNodeId] = true;
225  desiredVar);
226  }
227  }
228 
229 
230  // ############################################################################
231  /**
232  * Removes a node from the internal graph
233  * @param removedNodeId : the node to remove
234  */
235  // ############################################################################
// NOTE(review): only the template header and closing brace survived the
// extraction (source lines 237-239 missing). Per the preceding Doxygen block this
// removes `removedNodeId` from the internal graph — presumably also erasing its
// _staleTable_ entry; verify against upstream iti_tpl.h.
236  template < TESTNAME AttributeSelection, bool isScalar >
240  }
241 
242 
243  // ============================================================================
244  /// @name Function Graph Updating methods
245  // ============================================================================
246 
247  // ############################################################################
248  /// Updates target to currently learned graph structure
249  // ############################################################################
250  template < TESTNAME AttributeSelection, bool isScalar >
// NOTE(review): source lines 251 and 253 are missing — the method signature and
// the statement that rebuilds/sets the target's root were dropped. What remains
// shows the target MultiDimFunctionGraph is cleared before being repopulated
// from the learned structure. Verify against upstream iti_tpl.h.
252  this->target_->clear();
254  }
255 
256 
257  // ############################################################################
258  /**
259  * Inserts an internal node in the target
260  * @param the source node in internal graph
261  * @return the matching node id in the target
262  */
263  // ############################################################################
// Recursively mirrors one internal node into the target function graph,
// returning the id of the matching target node.
// NOTE(review): source lines 265, 267, 275 and 277 are missing — the method
// signature, the terminal-node call that produces `nody` in the leaf branch, the
// statement creating the internal `nody`, and the recursive call producing `son`
// were dropped. Verify against upstream iti_tpl.h.
264  template < TESTNAME AttributeSelection, bool isScalar >
// Leaf case: a node bound to the learned value variable becomes a terminal
// construct in the target.
266  if (this->nodeVarMap_[currentNodeId] == this->value_) {
268  return nody;
269  }
270 
// Make sure the node's test variable is registered in the target's variable
// sequence before installing a node over it.
271  if (!this->target_->variablesSequence().exists(this->nodeVarMap_[currentNodeId])) {
272  this->target_->add(*(this->nodeVarMap_[currentNodeId]));
273  }
274 
// Recurse on every modality and wire each resulting subgraph as a son.
276  for (Idx moda = 0; moda < this->nodeVarMap_[currentNodeId]->domainSize(); ++moda) {
278  this->target_->manager()->setSon(nody, moda, son);
279  }
280 
281  return nody;
282  }
283 
284 
285  // ############################################################################
286  /**
287  * Insert a terminal node in the target.
288  * This function is called if we're learning the behaviour of a variable.
289  * Inserts then this variable and its observed distribution beneath into target.
290  * @param the source node in the learned graph
291  * @return the matching node in the target
292  */
293  // ############################################################################
// Int2Type<false> overload: the learner models a variable's behaviour, so the
// leaf is expanded into the value variable with one probability terminal per
// modality (frequency / total observations at this node).
// NOTE(review): source lines 295, 300, 308 and 310 are missing — the method
// signature, the computation of `tot`, the per-modality terminal creation filling
// `sonsMap`, and the statement building `nody` from `sonsMap` were dropped.
// Verify against upstream iti_tpl.h.
294  template < TESTNAME AttributeSelection, bool isScalar >
296  Int2Type< false >) {
// Register the learned variable in the target before building a node over it.
297  if (!this->target_->variablesSequence().exists(this->value_))
298  this->target_->add(*(this->value_));
299 
// With zero observations at this node there is no distribution to encode;
// fall back to a single 0.0 terminal.
301  if (tot == Size(0)) return this->target_->manager()->addTerminalNode(0.0);
302 
// One target terminal per modality of the learned variable.
303  NodeId* sonsMap
304  = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * this->value_->domainSize()));
305  for (Idx modality = 0; modality < this->value_->domainSize(); ++modality) {
306  double newVal = 0.0;
// Empirical probability of this modality at the current node.
307  newVal = (double)this->nodeId2Database_[currentNodeId]->effectif(modality) / (double)tot;
309  }
311  return nody;
312  }
313 
314 
315  // ############################################################################
316  /**
317  * Insert a terminal node in the target.
318  * This function is called if we're learning a real value function.
319  * Inserts then a single (expected) value in target.
320  * @param the source node in the learned graph
321  * @return the matching node in the target
322  */
323  // ############################################################################
// Int2Type<true> overload: the learner models a scalar function, so the leaf
// collapses to a single terminal holding the empirical mean of the observed
// values at this node.
// NOTE(review): source lines 325, 329, 333 and 335 are missing — the method
// signature, the loop's end condition (presumably cendValues()), and the call
// creating terminal `nody` from `value` were dropped. Verify against upstream
// iti_tpl.h.
324  template < TESTNAME AttributeSelection, bool isScalar >
326  Int2Type< true >) {
327  double value = 0.0;
// Accumulate sum(observed value * its frequency) over the node's database.
328  for (auto valIter = this->nodeId2Database_[currentNodeId]->cbeginValues();
330  ++valIter) {
331  value += (double)valIter.key() * valIter.val();
332  }
// Divide by the observation count to obtain the mean.
// NOTE(review): no visible guard against nbObservation() == 0 here, unlike the
// Int2Type<false> overload's `tot == 0` check — confirm a zero-observation leaf
// cannot reach this path.
334  value /= (double)this->nodeId2Database_[currentNodeId]->nbObservation();
336  return nody;
337  }
338 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643