imddi_tpl.h
/**
 *
 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 * info_at_agrum_dot_org
 *
 * This library is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */

/**
 * @file
 * @brief Template implementations of the IMDDI data structure learner.
 *
 * @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
 * GONZALES(@AMU)
 */
// =======================================================
#include <agrum/tools/core/math/math_utils.h>
#include <agrum/tools/core/priorityQueue.h>
#include <agrum/tools/core/types.h>
// =======================================================
#include <agrum/FMDP/learning/core/chiSquare.h>
#include <agrum/FMDP/learning/datastructure/imddi.h>
// =======================================================
#include <agrum/tools/variables/labelizedVariable.h>
// =======================================================


namespace gum {

  // ############################################################################
  // Constructor & destructor.
  // ############################################################################

  // ============================================================================
  // Variable Learner constructor
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  IMDDI< AttributeSelection, isScalar >::IMDDI(MultiDimFunctionGraph< double >* target,
    // ...
    _addLeaf_(this->root_);
  }

  // ============================================================================
  // Reward Learner constructor
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
      target,
      // ...
      new LabelizedVariable("Reward", "", 2)),
    // ...
    _addLeaf_(this->root_);
  }

  // ============================================================================
  // Destructor
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
       ++leafIter)
    delete leafIter.val();
  }

  // ############################################################################
  // Incremental methods
  // ############################################################################

  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
  }

  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
        newObs,
        currentNodeId);
    // ...
  }


  // ============================================================================
  // Updates the tree after a new observation has been added
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
    _varOrder_.clear();

    // First we initialize the node set which will give us the scores
    // ...
    currentNodeSet.insert(this->root_);

    // Then we initialize the pool of variables to consider
    // ...
    for (vs.begin(); vs.hasNext(); vs.next()) {
      _updateScore_(vs.current(), this->root_, vs);
    }

    // Then, until there is no variable remaining
    while (!vs.isEmpty()) {
      // We select the best var
      // ...

      // Then we decide if we update each node according to this var
      // ...
    }

    // If some remaining nodes are not leaves after we establish the
    // variable order, these nodes are turned into leaves.
    // ...
         ++nodeIter)
      this->convertNode2Leaf_(*nodeIter);


    if (_lg_.needsUpdate()) _lg_.update();
  }
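
  // A minimal sketch of the elided selection loop above, assuming only what is
  // visible in this file: on each pass the VariableSelector vs hands back the
  // best-scoring variable, which is appended to _varOrder_ and removed from
  // vs; the working node set is then rewritten (presumably via _updateNodeSet_
  // below) so that the retained nodes test this variable and their sons become
  // the new candidates. The loop ends once vs is empty, i.e. once every
  // candidate variable has been placed in the order.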


  // ############################################################################
  // Updating methods
  // ############################################################################


  // ###################################################################
  // Select the most relevant variable
  //
  // The first parameter is the set of variables among which the most
  // relevant one is chosen.
  // The second parameter is the set of nodes that will attribute a score
  // to each variable so that we can choose the best one.
  // ###################################################################
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
      NodeId            nody,
      VariableSelector& vs) {
    if (!this->nodeId2Database_[nody]->isTestRelevant(var)) return;
    double weight
        = (double)this->nodeId2Database_[nody]->nbObservation() / (double)this->_nbTotalObservation_;
    // ...
  }

  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
      NodeId            nody,
      VariableSelector& vs) {
    if (!this->nodeId2Database_[nody]->isTestRelevant(var)) return;
    double weight
        = (double)this->nodeId2Database_[nody]->nbObservation() / (double)this->_nbTotalObservation_;
    // ...
  }
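
  // Illustrative note on the weight above (the numbers are hypothetical): if
  // this node's database holds 25 of the 100 observations seen so far by the
  // learner, then weight = 25.0 / 100.0 = 0.25, so the relevance measure the
  // node computes for var is presumably counted with a quarter of its value
  // when vs aggregates the per-variable scores.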


  // ============================================================================
  // For each node in the given set, this method checks whether or not
  // we should install the given variable as a test.
  // If so, the node is updated.
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
      VariableSelector& vs) {
    // ...
    nodeSet.clear();
    // ...
         ++nodeIter) {
      // ...

        // Then we subtract from the score given to each variable the
        // quantity contributed by this node
        for (vs.begin(); vs.hasNext(); vs.next()) {
          // ...
        }

        // And finally we add all its children to the new set of nodes
        // and update the remaining variables' scores
        for (Idx modality = 0; modality < this->nodeVarMap_[*nodeIter]->domainSize(); ++modality) {
          // ...
          nodeSet << sonId;

          for (vs.begin(); vs.hasNext(); vs.next()) {
            // ...
          }
        }
      } else {
        nodeSet << *nodeIter;
      }
    }
  }
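
  // Illustrative walk-through of the update above (the situation is
  // hypothetical): if the selected variable is a relevant test at a node n
  // with, say, three sons, n is updated so that it now tests that variable,
  // n's own contribution is subtracted from the score of every remaining
  // variable, and n's three sons replace n in the working node set, each
  // adding its own contribution to those scores. A node where the test is not
  // installed simply stays in the set unchanged.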


  // ============================================================================
  // Insert a new node with given associated database, var and maybe sons
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
      const DiscreteVariable*    boundVar,
      Set< const Observation* >* obsSet) {
    // ...
        boundVar,
        obsSet);

    // ...
    return currentNodeId;
  }


  // ============================================================================
  // Changes var associated to a node
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
      const DiscreteVariable* desiredVar) {
    // ...

    // ...
        desiredVar);

    if (desiredVar == this->value_) _addLeaf_(currentNodeId);
  }


  // ============================================================================
  // Remove node from graph
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
  }


  // ============================================================================
  // Add leaf to aggregator
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
        &(this->valueAssumed_)));
    // ...
  }


  // ============================================================================
  // Remove leaf from aggregator
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
    delete _leafMap_[currentNodeId];
    // ...
  }


  // ============================================================================
  // Computes the Reduced and Ordered Function Graph associated to this ordered
  // tree
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
    // if( _lg_.needsUpdate() || this->needUpdate_ ){
    // ...
    this->needUpdate_ = false;
    // }
  }


  // ============================================================================
  // Performs the leaf merging, then rebuilds the target function graph
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
    // *******************************************************************************************************
    // Update the leaf aggregator
    _lg_.update();

    // *******************************************************************************************************
    // Reset the decision graph
    this->target_->clear();
    // ...
      this->target_->add(**varIter);
    this->target_->add(*this->value_);

    // ...

    // *******************************************************************************************************
    // Insert the leaves
    // ...
         ++treeNodeIter) {
      // ...
    }

    // *******************************************************************************************************
    // Insert the internal nodes (checking for possible merges)
    // ...
         --varIter) {
      for (Link< NodeId >* curNodeIter = this->var2Node_[*varIter]->list(); curNodeIter;
           // ...
        NodeId* sonsMap
            = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * (*varIter)->domainSize()));
        for (Idx modality = 0; modality < (*varIter)->domainSize(); ++modality)
          // ...
      }
    }

    // *******************************************************************************************************
    // Polish
    this->target_->manager()->setRootNode(toTarget[this->root_]);
    this->target_->manager()->clean();
  }


  // ============================================================================
  // Inserts into the target the terminal node corresponding to a leaf
  // (scalar value case)
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
      Int2Type< true >) {
    double value = 0.0;
    for (Idx moda = 0; moda < leaf->nbModa(); moda++) {
      value += (double)leaf->effectif(moda) * this->valueAssumed_.atPos(moda);
    }
    if (leaf->total()) value /= (double)leaf->total();
    return this->target_->manager()->addTerminalNode(value);
  }
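
  // Worked example for the scalar case above (the numbers are hypothetical): a
  // leaf with effectifs {2, 3} over assumed values {0.0, 10.0} has total() == 5,
  // so value = (2 * 0.0 + 3 * 10.0) / 5 = 6.0 and a single terminal node
  // holding 6.0 is added to the target function graph.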


  // ============================================================================
  // Inserts into the target the terminal nodes corresponding to a leaf
  // (distribution over the learned value variable)
  // ============================================================================
  template < TESTNAME AttributeSelection, bool isScalar >
  // ...
      Int2Type< false >) {
    NodeId* sonsMap
        = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * this->value_->domainSize()));
    for (Idx modality = 0; modality < this->value_->domainSize(); ++modality) {
      double newVal = 0.0;
      if (leaf->total()) newVal = (double)leaf->effectif(modality) / (double)leaf->total();
      // ...
    }
    return this->target_->manager()->addInternalNode(this->value_, sonsMap);
  }
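
  // Worked example for the non-scalar case above (the numbers are
  // hypothetical): a leaf with effectifs {2, 3} and total() == 5 gives
  // newVal = 0.4 and 0.6 for the two modalities; these values presumably fill
  // the terminal sons referenced by sonsMap, under an internal node bound to
  // the learned value variable.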
}   // namespace gum