aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
iti_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Template Implementations of the ITI datastructure learner
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
27  * GONZALES(@AMU)
28  */
29 // =======================================================
30 #include <agrum/tools/core/math/math_utils.h>
31 #include <agrum/tools/core/priorityQueue.h>
32 #include <agrum/tools/core/types.h>
33 // =======================================================
34 #include <agrum/FMDP/learning/core/chiSquare.h>
35 #include <agrum/FMDP/learning/datastructure/iti.h>
36 // =======================================================
37 #include <agrum/tools/variables/labelizedVariable.h>
38 // =======================================================
39 
40 
41 namespace gum {
42 
43  // ==========================================================================
44  /// @name Constructor & destructor.
45  // ==========================================================================
46 
47  // ###################################################################
48  /**
49  * ITI constructor for functions describing the behaviour of one variable
50  * according to a set of other variables, such as conditional probabilities
51  * @param target : the MultiDimFunctionGraph in which we load the structure
52  * @param attributeSelectionThreshold : threshold under which a node is not
53  * installed (pre-pruning)
54  * @param temporaryAPIfix : Issue in API in regard to IMDDI
55  * @param attributeListe : Set of vars on which we rely to explain the
56  * behaviour of learned variable
57  * @param learnedValue : the variable from which we try to learn the behaviour
58  */
59  // ###################################################################
// ITI constructor (variable-behaviour version): builds the learner around
// `target` and marks the freshly created root node as not stale in
// staleTable__.
// NOTE(review): the embedded source-line numbers jump 62 -> 68 -> 72, so the
// remaining constructor parameters and the base-class initializer list are
// missing from this extraction; consult the original file before editing.
60  template < TESTNAME AttributeSelection, bool isScalar >
61  ITI< AttributeSelection, isScalar >::ITI(
62  MultiDimFunctionGraph< double >* target,
68  learnedValue),
72  staleTable__.insert(this->root_, false);
73  }
74 
75  // ###################################################################
76  /**
77  * ITI constructor for real functions. We try to predict the output of a
78  * function f given a set of variable
79  * @param target : the MultiDimFunctionGraph in which we load the structure
80  * @param attributeSelectionThreshold : threshold under which a node is not
81  * installed (pre-pruning)
82  * @param temporaryAPIfix : Issue in API in regard to IMDDI
83  * @param attributeListe : Set of vars on which we rely to explain the
84  * behaviour of learned function
85  */
86  // ###################################################################
// ITI constructor (real-function version): like the variable constructor,
// but the learned value is a fresh binary LabelizedVariable named "Reward";
// the root node starts out not stale.
// NOTE(review): the embedded source-line numbers jump 87 -> 89 -> 93 -> 95
// -> 99 -- the signature head and part of the initializer list are missing
// from this extraction; consult the original file before editing.
87  template < TESTNAME AttributeSelection, bool isScalar >
89  MultiDimFunctionGraph< double >* target,
93  target,
95  new LabelizedVariable("Reward", "", 2)),
99  staleTable__.insert(this->root_, false);
100  }
101 
102 
103  // ==========================================================================
104  /// @name New Observation insertion methods
105  // ==========================================================================
106 
107  // ############################################################################
108  /**
109  * Inserts a new observation
110  * @param the new observation to learn
111  */
112  // ############################################################################
// addObservation(): public entry point that inserts a new observation into
// the learned tree (see the doc comment above).
// NOTE(review): the embedded source-line numbers jump 114 -> 118, so the
// signature tail and the whole function body are missing from this
// extraction.
113  template < TESTNAME AttributeSelection, bool isScalar >
114  void
118  }
119 
120  // ############################################################################
121  /**
122  * Will update internal graph's NodeDatabase of given node with the new
123  * observation
124  * @param newObs
125  * @param currentNodeId
126  */
127  // ############################################################################
// updateNodeWithObservation(): records `newObs` for node `currentNodeId`
// (presumably in its NodeDatabase -- the call site is truncated) and flags
// the node as stale so updateGraph() will revisit it.
// NOTE(review): the embedded source-line numbers jump 128 -> 130 -> 134, so
// the signature head and the database-update call are missing from this
// extraction.
128  template < TESTNAME AttributeSelection, bool isScalar >
130  const Observation* newObs,
134  currentNodeId);
135  staleTable__[currentNodeId] = true;
136  }
137 
138 
139  // ============================================================================
140  /// @name Graph Structure update methods
141  // ============================================================================
142 
143  // ############################################################################
144  /// Updates the internal graph after a new observation has been added
145  // ############################################################################
// updateGraph(): revisits the tree after new observations.  It walks the
// tree depth-first via an explicit stack (`filo`), keeps for each node the
// set of still-usable variables (`potentialVars`), picks the best-scoring
// variable (ties kept in `bestVars`), installs a test with it, and descends
// only into sons flagged stale in staleTable__.  The potentialVars sets are
// heap-allocated and deleted in the final loop.
// NOTE(review): the embedded source-line numbers show many gaps (150, 156,
// 160, 163-164, 166, 168, 171, 174, 179, 185, 187-188, 190-191, 197-199),
// so the node pop, score computation, test installation and son handling
// are only partially visible; do not edit without the original file.
146  template < TESTNAME AttributeSelection, bool isScalar >
148  std::vector< NodeId > filo;
149  filo.push_back(this->root_);
151  potentialVars.insert(this->root_,
152  new Set< const DiscreteVariable* >(this->setOfVars_));
153 
154 
155  while (!filo.empty()) {
157  filo.pop_back();
158 
159  // First we look for the best var to install on the node
161  Set< const DiscreteVariable* > bestVars;
162 
165  ++varIter)
167  double varValue
169  if (varValue >= bestValue) {
170  if (varValue > bestValue) {
172  bestVars.clear();
173  }
175  }
176  }
177 
178  // Then we install a test on that node, using the best variable found above
180 
181  // Then we move on to the children, if needed
182  if (this->nodeVarMap_[currentNodeId] != this->value_) {
183  for (Idx moda = 0; moda < this->nodeVarMap_[currentNodeId]->domainSize();
184  moda++) {
186  = new Set< const DiscreteVariable* >(*potentialVars[currentNodeId]);
189  if (staleTable__[sonId]) {
192  }
193  }
194  }
195  }
196 
// Clean up the heap-allocated potential-variable sets.
200  ++nodeIter)
201  delete nodeIter.val();
202  }
203 
204 
205  // ############################################################################
206  /**
207  * inserts a new node in the internal graph
208  * @param nDB : the associated database
209  * @param boundVar : the associated variable
210  * @return the newly created node's id
211  */
212  // ############################################################################
// insertNode__(): creates a node bound to `boundVar` with associated
// database `nDB` (via a delegated insertion whose line is truncated),
// registers the new node as stale in staleTable__, and returns its id.
// NOTE(review): the embedded source-line numbers jump 213 -> 216 and
// 217 -> 219, so the signature head and the delegated insertion call are
// missing from this extraction.
213  template < TESTNAME AttributeSelection, bool isScalar >
216  const DiscreteVariable* boundVar) {
217  NodeId n
219  nDB,
220  boundVar);
221  staleTable__.insert(n, true);
222  return n;
223  }
224 
225 
226  // ############################################################################
227  /**
228  * Changes the associated variable of a node
229  * @param chgedNodeId : the node to change
230  * @param desiredVar : its new associated variable
231  */
232  // ############################################################################
// chgNodeBoundVar(): rebinds node `currentNodeId` to `desiredVar`.  The node
// is marked stale only when the variable actually changes; identical
// variables make this a no-op.
// NOTE(review): the embedded source-line numbers jump 233 -> 236 and
// 238 -> 241, so the signature head and the delegated rebinding call are
// missing from this extraction.
233  template < TESTNAME AttributeSelection, bool isScalar >
236  const DiscreteVariable* desiredVar) {
237  if (this->nodeVarMap_[currentNodeId] != desiredVar) {
238  staleTable__[currentNodeId] = true;
241  desiredVar);
242  }
243  }
244 
245 
246  // ############################################################################
247  /**
248  * Removes a node from the internal graph
249  * @param removedNodeId : the node to remove
250  */
251  // ############################################################################
// removeNode(): removes `currentNodeId` from the internal graph.
// NOTE(review): the embedded source-line numbers jump 252 -> 255 -> 257, so
// the signature, the delegated removal call, and presumably the
// staleTable__ erase are missing from this extraction -- verify against the
// original file.
252  template < TESTNAME AttributeSelection, bool isScalar >
255  currentNodeId);
257  }
258 
259 
260  // ============================================================================
261  /// @name Function Graph Updating methods
262  // ============================================================================
263 
264  // ############################################################################
265  /// Updates target to currently learned graph structure
266  // ############################################################################
// updateFunctionGraph(): synchronizes the target MultiDimFunctionGraph with
// the currently learned structure -- clears the target, then rebuilds it by
// recursively inserting nodes from root_ and setting the result as the
// target's root.
// NOTE(review): the embedded source-line numbers jump 267 -> 269, so the
// function signature itself is missing from this extraction.
267  template < TESTNAME AttributeSelection, bool isScalar >
269  this->target_->clear();
270  this->target_->manager()->setRootNode(
271  this->insertNodeInFunctionGraph__(this->root_));
272  }
273 
274 
275  // ############################################################################
276  /**
277  * Inserts an internal node in the target
278  * @param the source node in internal graph
279  * @return the matching node id in the target
280  */
281  // ############################################################################
// insertNodeInFunctionGraph__(): recursively copies internal node
// `currentNodeId` into the target.  Leaves (nodes bound to value_) are
// handled by the terminal-node insertion (call truncated); for internal
// nodes the bound variable is added to the target's variable sequence if
// absent, an internal node `nody` is created (call truncated), and one son
// per modality is inserted recursively.
// NOTE(review): the embedded source-line numbers show gaps (283-284, 286,
// 295, 299-300): the signature, the terminal-node call, the internal-node
// creation and the recursive son insertion are missing from this
// extraction.
282  template < TESTNAME AttributeSelection, bool isScalar >
285  if (this->nodeVarMap_[currentNodeId] == this->value_) {
287  return nody;
288  }
289 
290  if (!this->target_->variablesSequence().exists(
291  this->nodeVarMap_[currentNodeId])) {
292  this->target_->add(*(this->nodeVarMap_[currentNodeId]));
293  }
294 
296  this->nodeVarMap_[currentNodeId]);
297  for (Idx moda = 0; moda < this->nodeVarMap_[currentNodeId]->domainSize();
298  ++moda) {
301  this->target_->manager()->setSon(nody, moda, son);
302  }
303 
304  return nody;
305  }
306 
307 
308  // ############################################################################
309  /**
310  * Insert a terminal node in the target.
311  * This overload (Int2Type< false >) is used when learning the behaviour of
312  * a variable: it inserts that variable, with the distribution over its values, into the target.
313  * @param the source node in the learned graph
314  * @return the matching node in the target
315  */
316  // ############################################################################
// insertTerminalNode__(Int2Type< false >): terminal insertion for the
// variable-behaviour case.  Ensures value_ is in the target's variable
// sequence, returns a 0.0 terminal node when no observation was recorded
// (tot == 0), and otherwise builds one son per modality of value_ holding a
// value divided by tot (the numerator line is truncated -- presumably the
// per-modality observation count, i.e. an empirical frequency).  sonsMap is
// allocated with SOA_ALLOCATE.
// NOTE(review): the embedded source-line numbers show gaps (318-319, 324,
// 331, 333, 335): the signature, the computation of `tot`, the modality
// counts and the creation of `nody` are missing from this extraction.
317  template < TESTNAME AttributeSelection, bool isScalar >
320  Int2Type< false >) {
321  if (!this->target_->variablesSequence().exists(this->value_))
322  this->target_->add(*(this->value_));
323 
325  if (tot == Size(0)) return this->target_->manager()->addTerminalNode(0.0);
326 
327  NodeId* sonsMap = static_cast< NodeId* >(
328  SOA_ALLOCATE(sizeof(NodeId) * this->value_->domainSize()));
329  for (Idx modality = 0; modality < this->value_->domainSize(); ++modality) {
330  double newVal = 0.0;
332  / (double)tot;
334  }
336  return nody;
337  }
338 
339 
340  // ############################################################################
341  /**
342  * Insert a terminal node in the target.
343  * This overload (Int2Type< true >) is used when learning a real-valued
344  * function: it inserts a single (expected) value in the target.
345  * @param the source node in the learned graph
346  * @return the matching node in the target
347  */
348  // ############################################################################
// insertTerminalNode__(Int2Type< true >): terminal insertion for the scalar
// (real-function) case.  Accumulates sum(key * count) over the node's value
// database, divides by the number of observations to get the empirical mean,
// and returns the resulting terminal node (creation line truncated).
// NOTE(review): the embedded source-line numbers show gaps (350-351, 355,
// 359, 361): the signature, the loop's end condition and the creation of
// `nody` are missing from this extraction.
349  template < TESTNAME AttributeSelection, bool isScalar >
352  Int2Type< true >) {
353  double value = 0.0;
354  for (auto valIter = this->nodeId2Database_[currentNodeId]->cbeginValues();
356  ++valIter) {
357  value += (double)valIter.key() * valIter.val();
358  }
360  value /= (double)this->nodeId2Database_[currentNodeId]->nbObservation();
362  return nody;
363  }
364 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669