aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
imddi_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Template implementations of the IMDDI datastructure learner
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
27  * GONZALES(@AMU)
28  */
29 // =======================================================
30 #include <agrum/tools/core/math/math_utils.h>
31 #include <agrum/tools/core/priorityQueue.h>
32 #include <agrum/tools/core/types.h>
33 // =======================================================
34 #include <agrum/FMDP/learning/core/chiSquare.h>
35 #include <agrum/FMDP/learning/datastructure/imddi.h>
36 // =======================================================
37 #include <agrum/tools/variables/labelizedVariable.h>
38 // =======================================================
39 
40 
41 namespace gum {
42 
43  // ############################################################################
44  // Constructor & destructor.
45  // ############################################################################
46 
47  // ============================================================================
48  // Variable Learner constructor
49  // ============================================================================
50  template < TESTNAME AttributeSelection, bool isScalar >
// Variable-learner constructor: builds an IMDDI that incrementally learns
// the function graph of a given target variable.
// NOTE(review): this listing is truncated (parts of the parameter list and
// of the member-initializer list are missing from this view) — verify any
// change against the complete source file.
51  IMDDI< AttributeSelection, isScalar >::IMDDI(
52  MultiDimFunctionGraph< double >* target,
59  learnedValue),
// The root starts out as a leaf: register it with the leaf aggregator.
63  addLeaf__(this->root_);
64  }
65 
66  // ============================================================================
67  // Reward Learner constructor
68  // ============================================================================
69  template < TESTNAME AttributeSelection, bool isScalar >
// Reward-learner constructor: like the variable-learner constructor, but
// the learned value is a binary "Reward" LabelizedVariable created here.
// NOTE(review): truncated listing — the constructor name line and most of
// the initializer list are missing from this view.
71  MultiDimFunctionGraph< double >* target,
76  target,
78  new LabelizedVariable("Reward", "", 2)),
// The root starts out as a leaf: register it with the leaf aggregator.
82  addLeaf__(this->root_);
83  }
84 
85  // ============================================================================
86  // Reward Learner constructor
87  // ============================================================================
88  template < TESTNAME AttributeSelection, bool isScalar >
// Destructor: deletes every leaf object stored in leafMap__.
// NOTE(review): truncated listing — the destructor signature and the loop
// header are partially missing from this view.
92  = leafMap__.beginSafe();
94  ++leafIter)
95  delete leafIter.val();
96  }
97 
98 
99  // ############################################################################
100  // Incremental methods
101  // ############################################################################
102 
103  template < TESTNAME AttributeSelection, bool isScalar >
// Adds a new observation to the learner.
// NOTE(review): truncated listing — the function name line and the body
// are missing from this view; presumably this forwards to the recursive
// helper below starting from the root — confirm in the full file.
105  const Observation* obs) {
108  }
109 
110  template < TESTNAME AttributeSelection, bool isScalar >
// Recursive helper: routes newObs down the tree starting at currentNodeId.
// The test against this->value_ identifies leaf nodes (nodes bound to the
// learned value variable).
// NOTE(review): truncated listing — several body lines are missing here.
112  const Observation* newObs,
116  currentNodeId);
117  if (this->nodeVarMap_[currentNodeId] == this->value_)
119  }
120 
121 
122  // ============================================================================
123  // Updates the tree after a new observation has been added
124  // ============================================================================
125  template < TESTNAME AttributeSelection, bool isScalar >
// Updates the tree after a new observation has been added: rebuilds
// varOrder__ by repeatedly selecting the best-scoring variable, then turns
// any node left over after the variable order is fixed into a leaf.
// NOTE(review): truncated listing — the selector construction and the
// bodies of the selection loop are partially missing from this view.
127  varOrder__.clear();
128 
129  // First we initialize the node set which will give us the scores
131  currentNodeSet.insert(this->root_);
132 
133  // Then we initialize the pool of variables to consider
135  for (vs.begin(); vs.hasNext(); vs.next()) {
136  updateScore__(vs.current(), this->root_, vs);
137  }
138 
139  // Then, until there's no node remaining
140  while (!vs.isEmpty()) {
141  // We select the best var
144 
145  // Then we decide if we update each node according to this var
147  }
148 
149  // If there are remaining nodes that are not leaves after we establish the
150  // var order
151  // these nodes are turned into leaves.
154  ++nodeIter)
155  this->convertNode2Leaf_(*nodeIter);
156 
157 
// Refresh the leaf aggregator only when it reports a pending update.
158  if (lg__.needsUpdate()) lg__.update();
159  }
160 
161 
162  // ############################################################################
163  // Updating methods
164  // ############################################################################
165 
166 
167  // ###################################################################
168  // Select the most relevant variable
169  //
170  // First parameter is the set of variables among which the most
171  // relevant one is chosen
172  // Second parameter is the set of nodes that will attribute a score
173  // to each variable so that we choose the best.
174  // ###################################################################
175  template < TESTNAME AttributeSelection, bool isScalar >
// Adds node nody's weighted contribution to the selector's score for var.
// The weight is this node's share of all observations seen so far.
// NOTE(review): truncated listing — the function name line and the call
// that actually updates the selector (using `weight`) are missing here.
177  const DiscreteVariable* var,
178  NodeId nody,
179  VariableSelector& vs) {
// Skip variables that are not a relevant test for this node's database.
180  if (!this->nodeId2Database_[nody]->isTestRelevant(var)) return;
181  double weight = (double)this->nodeId2Database_[nody]->nbObservation()
182  / (double)this->nbTotalObservation__;
186  }
187 
188  template < TESTNAME AttributeSelection, bool isScalar >
// Inverse of updateScore__: removes node nody's weighted contribution from
// the selector's score for var (used when a node leaves the scoring set).
// NOTE(review): truncated listing — the downdating call itself is missing
// from this view; only the weight computation is visible.
190  const DiscreteVariable* var,
191  NodeId nody,
192  VariableSelector& vs) {
// Skip variables that are not a relevant test for this node's database.
193  if (!this->nodeId2Database_[nody]->isTestRelevant(var)) return;
194  double weight = (double)this->nodeId2Database_[nody]->nbObservation()
195  / (double)this->nbTotalObservation__;
198  weight
200  }
201 
202 
203  // ============================================================================
204  // For each node in the given set, this method checks whether or not
205  // we should install the given variable as a test.
206  // If so, the node is updated
207  // ============================================================================
208  template < TESTNAME AttributeSelection, bool isScalar >
// For each node in nodeSet, decides whether the given variable should be
// installed there as a test. If installed, the node's score contribution
// is withdrawn and its sons join the next node set; otherwise the node is
// kept as-is in the set.
// NOTE(review): truncated listing — the iteration header, the install
// decision and the per-variable score update calls are missing here.
210  Set< NodeId >& nodeSet,
212  VariableSelector& vs) {
214  nodeSet.clear();
217  ++nodeIter) {
222 
223  // Then we subtract from the score given to each variable the
224  // quantity contributed by this node
225  for (vs.begin(); vs.hasNext(); vs.next()) {
227  }
228 
229  // And finally we add all its children to the new set of nodes
230  // and update the remaining vars' scores
231  for (Idx modality = 0;
233  ++modality) {
235  nodeSet << sonId;
236 
237  for (vs.begin(); vs.hasNext(); vs.next()) {
239  }
240  }
241  } else {
// Variable not installed here: the node stays in the working set.
242  nodeSet << *nodeIter;
243  }
244  }
245  }
246 
247 
248  // ============================================================================
249  // Insert a new node with given associated database, var and maybe sons
250  // ============================================================================
251  template < TESTNAME AttributeSelection, bool isScalar >
// Inserts a new node bound to boundVar, with its associated database nDB
// and observation set; returns the id of the freshly created node.
// NOTE(review): truncated listing — the signature line and the delegated
// insertion call are partially missing from this view.
254  const DiscreteVariable* boundVar,
255  Set< const Observation* >* obsSet) {
258  nDB,
259  boundVar,
260  obsSet);
261 
263 
264  return currentNodeId;
265  }
266 
267 
268  // ============================================================================
269  // Changes var associated to a node
270  // ============================================================================
271  template < TESTNAME AttributeSelection, bool isScalar >
// Changes the variable bound to a node, keeping the leaf bookkeeping
// consistent: if the node was a leaf (bound to this->value_) it must be
// withdrawn from the aggregator first, and re-registered afterwards if
// desiredVar makes it a leaf again.
// NOTE(review): truncated listing — the leaf-removal call and the base
// class delegation are partially missing from this view.
274  const DiscreteVariable* desiredVar) {
275  if (this->nodeVarMap_[currentNodeId] == this->value_)
277 
280  desiredVar);
281 
282  if (desiredVar == this->value_) addLeaf__(currentNodeId);
283  }
284 
285 
286  // ============================================================================
287  // Remove node from graph
288  // ============================================================================
289  template < TESTNAME AttributeSelection, bool isScalar >
// Removes a node from the graph; a leaf node (bound to this->value_) is
// presumably first withdrawn from the leaf aggregator — confirm in the
// full file, the relevant lines are missing from this truncated view.
291  if (this->nodeVarMap_[currentNodeId] == this->value_)
294  currentNodeId);
295  }
296 
297 
298  // ============================================================================
299  // Add leaf to aggregator
300  // ============================================================================
301  template < TESTNAME AttributeSelection, bool isScalar >
// Registers a node as a leaf with the aggregator, building its leaf
// representation from this->valueAssumed_.
// NOTE(review): truncated listing — most of this function is missing from
// this view; only the tail of the leaf construction call is visible.
307  &(this->valueAssumed_)));
309  }
310 
311 
312  // ============================================================================
313  // Remove leaf from aggregator
314  // ============================================================================
315  template < TESTNAME AttributeSelection, bool isScalar >
// Unregisters a leaf: deletes the leaf object stored in leafMap__ for the
// node. NOTE(review): truncated listing — the aggregator notification and
// the map erasure lines are missing from this view.
318  delete leafMap__[currentNodeId];
320  }
321 
322 
323  // ============================================================================
324  // Computes the Reduced and Ordered Function Graph associated to this ordered
325  // tree
326  // ============================================================================
327  template < TESTNAME AttributeSelection, bool isScalar >
// Computes the reduced and ordered function graph associated with the
// learned tree, then clears the pending-update flag.
// NOTE(review): truncated listing — the rebuild call itself is missing
// from this view. The commented-out guard shows the rebuild was once
// conditional on lg__.needsUpdate() || this->needUpdate_.
329  // if( lg__.needsUpdate() || this->needUpdate_ ){
331  this->needUpdate_ = false;
332  // }
333  }
334 
335 
336  // ============================================================================
337  // Performs the leaves merging
338  // ============================================================================
339  template < TESTNAME AttributeSelection, bool isScalar >
// Rebuilds this->target_ (the MultiDimFunctionGraph) from the learned
// tree: refreshes the leaf aggregator, re-registers the variables and the
// learned value, inserts the leaves, then the internal nodes variable by
// variable, and finally sets the root node and cleans the manager.
// NOTE(review): truncated listing — several loop headers and the actual
// node-insertion calls are missing from this view.
341  // *******************************************************************************************************
342  // Update of the leaf aggregator
343  lg__.update();
344 
345  // *******************************************************************************************************
346  // Reinitialization of the decision graph
347  this->target_->clear();
349  ++varIter)
350  this->target_->add(**varIter);
351  this->target_->add(*this->value_);
352 
354 
355  // *******************************************************************************************************
356  // Insertion of the leaves
362  ++treeNodeIter) {
366  Int2Type< isScalar >()));
367 
369  }
370 
371  // *******************************************************************************************************
372  // Insertion of the internal nodes (checking for possible
373  // merges)
377  --varIter) {
378  for (Link< NodeId >* curNodeIter = this->var2Node_[*varIter]->list();
379  curNodeIter;
// sonsMap is allocated from the project's SOA_ALLOCATE pool; presumably
// ownership passes to the internal-node insertion — confirm in full file.
381  NodeId* sonsMap = static_cast< NodeId* >(
382  SOA_ALLOCATE(sizeof(NodeId) * (*varIter)->domainSize()));
383  for (Idx modality = 0; modality < (*varIter)->domainSize(); ++modality)
387  curNodeIter->element(),
389  }
390  }
391 
392  // *******************************************************************************************************
393  // Polish
394  this->target_->manager()->setRootNode(toTarget[this->root_]);
395  this->target_->manager()->clean();
396  }
397 
398 
399  // ============================================================================
400  // Performs the leaves merging
401  // ============================================================================
402  template < TESTNAME AttributeSelection, bool isScalar >
// Scalar case (Int2Type< true >): inserts the leaf as a single terminal
// node whose value is the frequency-weighted average of the values assumed
// at the leaf.
// NOTE(review): the signature line is missing from this truncated view.
405  Int2Type< true >) {
406  double value = 0.0;
// Accumulate effectif(moda) * assumed value over every modality.
407  for (Idx moda = 0; moda < leaf->nbModa(); moda++) {
408  value += (double)leaf->effectif(moda) * this->valueAssumed_.atPos(moda);
409  }
// Guard against division by zero when the leaf holds no observation.
410  if (leaf->total()) value /= (double)leaf->total();
411  return this->target_->manager()->addTerminalNode(value);
412  }
413 
414 
415  // ============================================================================
416  // Performs the leaves merging
417  // ============================================================================
418  template < TESTNAME AttributeSelection, bool isScalar >
// Non-scalar case (Int2Type< false >): inserts the leaf as an internal
// node over this->value_, one son per modality holding that modality's
// observed relative frequency.
// NOTE(review): the signature line and the statement storing each son
// into sonsMap (listing line 428) are missing from this truncated view.
421  Int2Type< false >) {
422  NodeId* sonsMap = static_cast< NodeId* >(
423  SOA_ALLOCATE(sizeof(NodeId) * this->value_->domainSize()));
424  for (Idx modality = 0; modality < this->value_->domainSize(); ++modality) {
425  double newVal = 0.0;
// Guard against division by zero when the leaf holds no observation.
426  if (leaf->total())
427  newVal = (double)leaf->effectif(modality) / (double)leaf->total();
429  }
// Presumably addInternalNode takes ownership of sonsMap — confirm.
430  return this->target_->manager()->addInternalNode(this->value_, sonsMap);
431  }
432 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669