aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
imddi_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
 * @file imddi_tpl.h
 * @brief Template implementations of the IMDDI datastructure learner.
 *
 * @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
 * GONZALES(@AMU)
 */
29
// =======================================================
30
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
31
#
include
<
agrum
/
tools
/
core
/
priorityQueue
.
h
>
32
#
include
<
agrum
/
tools
/
core
/
types
.
h
>
33
// =======================================================
34
#
include
<
agrum
/
FMDP
/
learning
/
core
/
chiSquare
.
h
>
35
#
include
<
agrum
/
FMDP
/
learning
/
datastructure
/
imddi
.
h
>
36
// =======================================================
37
#
include
<
agrum
/
tools
/
variables
/
labelizedVariable
.
h
>
38
// =======================================================
39
40
41
namespace
gum
{
42
43
// ############################################################################
44
// Constructor & destructor.
45
// ############################################################################
46
47
// ============================================================================
48
// Variable Learner constructor
49
// ============================================================================
50
/**
 * Variable-learner constructor: learns the structure for a given variable.
 *
 * @param target the function graph in which the learned model is built
 * @param attributeSelectionThreshold score a variable must exceed on a node
 *        to be installed as a test there (used in _updateNodeSet_)
 * @param pairSelectionThreshold threshold handed to the leaf aggregator _lg_
 *        for deciding on leaf pair merges
 * @param attributeListe the set of candidate attribute variables
 * @param learnedValue the variable whose distribution is learned
 */
template < TESTNAME AttributeSelection, bool isScalar >
IMDDI< AttributeSelection, isScalar >::IMDDI(MultiDimFunctionGraph< double >* target,
                                             double attributeSelectionThreshold,
                                             double pairSelectionThreshold,
                                             Set< const DiscreteVariable* > attributeListe,
                                             const DiscreteVariable* learnedValue) :
    IncrementalGraphLearner< AttributeSelection, isScalar >(target, attributeListe, learnedValue),
    // The leaf aggregator watches the model graph and merges similar leaves.
    _lg_(&(this->model_), pairSelectionThreshold), _nbTotalObservation_(0),
    _attributeSelectionThreshold_(attributeSelectionThreshold) {
  GUM_CONSTRUCTOR(IMDDI);
  // The initial tree is a single leaf: register the root with the aggregator.
  _addLeaf_(this->root_);
}
62
63
// ============================================================================
64
// Reward Learner constructor
65
// ============================================================================
66
// ============================================================================
// Reward Learner constructor
// ============================================================================
/**
 * Reward-learner constructor: same as the variable-learner constructor but
 * creates its own binary "Reward" variable as the learned value.
 *
 * NOTE(review): ownership of the `new LabelizedVariable` presumably passes to
 * the IncrementalGraphLearner base class — verify it is released there.
 *
 * @param target the function graph in which the learned model is built
 * @param attributeSelectionThreshold score a variable must exceed on a node
 *        to be installed as a test there (used in _updateNodeSet_)
 * @param pairSelectionThreshold threshold handed to the leaf aggregator _lg_
 *        for deciding on leaf pair merges
 * @param attributeListe the set of candidate attribute variables
 */
template < TESTNAME AttributeSelection, bool isScalar >
IMDDI< AttributeSelection, isScalar >::IMDDI(MultiDimFunctionGraph< double >* target,
                                             double attributeSelectionThreshold,
                                             double pairSelectionThreshold,
                                             Set< const DiscreteVariable* > attributeListe) :
    IncrementalGraphLearner< AttributeSelection, isScalar >(
        target,
        attributeListe,
        new LabelizedVariable("Reward", "", 2)),
    _lg_(&(this->model_), pairSelectionThreshold), _nbTotalObservation_(0),
    _attributeSelectionThreshold_(attributeSelectionThreshold) {
  GUM_CONSTRUCTOR(IMDDI);
  // The initial tree is a single leaf: register the root with the aggregator.
  _addLeaf_(this->root_);
}
80
81
// ============================================================================
// Destructor
// ============================================================================
/**
 * Destructor: releases every AbstractLeaf this learner allocated in _addLeaf_
 * (the leaf map owns its values).
 */
template < TESTNAME AttributeSelection, bool isScalar >
IMDDI< AttributeSelection, isScalar >::~IMDDI() {
  GUM_DESTRUCTOR(IMDDI);
  for (HashTableIteratorSafe< NodeId, AbstractLeaf* > leafIter = _leafMap_.beginSafe();
       leafIter != _leafMap_.endSafe();
       ++leafIter)
    delete leafIter.val();
}
92
93
94
// ############################################################################
95
// Incrementals methods
96
// ############################################################################
97
98
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
99
void
IMDDI
<
AttributeSelection
,
isScalar
>::
addObservation
(
const
Observation
*
obs
) {
100
_nbTotalObservation_
++;
101
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
addObservation
(
obs
);
102
}
103
104
/**
 * Routes an observation into a node's statistics.
 *
 * @param newObs the observation to account for
 * @param currentNodeId the node whose database must be updated
 */
template < TESTNAME AttributeSelection, bool isScalar >
void IMDDI< AttributeSelection, isScalar >::updateNodeWithObservation_(const Observation* newObs,
                                                                       NodeId currentNodeId) {
  // Let the base class update the node's statistics database first.
  IncrementalGraphLearner< AttributeSelection, isScalar >::updateNodeWithObservation_(
      newObs,
      currentNodeId);
  // A node bound to the learned value is a leaf: notify the leaf aggregator
  // so its merge scores stay up to date.
  if (this->nodeVarMap_[currentNodeId] == this->value_) _lg_.updateLeaf(_leafMap_[currentNodeId]);
}
112
113
114
// ============================================================================
// Updates the tree after a new observation has been added
// ============================================================================
/**
 * Rebuilds the variable ordering and restructures the tree accordingly.
 *
 * Starting from the root, repeatedly selects the best-scoring variable,
 * appends it to _varOrder_, and installs it as a test on every frontier node
 * where its score exceeds the selection threshold (see _updateNodeSet_).
 * Frontier nodes left over once all variables are placed become leaves.
 */
template < TESTNAME AttributeSelection, bool isScalar >
void IMDDI< AttributeSelection, isScalar >::updateGraph() {
  _varOrder_.clear();

  // First we initialize the node set which will give us the scores
  Set< NodeId > currentNodeSet;
  currentNodeSet.insert(this->root_);

  // Then we initialize the pool of variables to consider
  VariableSelector vs(this->setOfVars_);
  for (vs.begin(); vs.hasNext(); vs.next()) {
    _updateScore_(vs.current(), this->root_, vs);
  }

  // Then, until there's no variable remaining in the selector
  while (!vs.isEmpty()) {
    // We select the best var
    const DiscreteVariable* selectedVar = vs.select();
    _varOrder_.insert(selectedVar);

    // Then we decide if we update each node according to this var
    _updateNodeSet_(currentNodeSet, selectedVar, vs);
  }

  // If there are remaining nodes that are not leaves after we establish the
  // var order, these nodes are turned into leaves.
  for (SetIteratorSafe< NodeId > nodeIter = currentNodeSet.beginSafe();
       nodeIter != currentNodeSet.endSafe();
       ++nodeIter)
    this->convertNode2Leaf_(*nodeIter);


  // Refresh the leaf aggregator if the restructuring invalidated it.
  if (_lg_.needsUpdate()) _lg_.update();
}
152
153
154
// ############################################################################
155
// Updating methods
156
// ############################################################################
157
158
159
// ###################################################################
160
// Select the most relevant variable
161
//
162
// First parameter is the set of variables among which the most
163
// relevant one is choosed
164
// Second parameter is the set of node the will attribute a score
165
// to each variable so that we choose the best.
166
// ###################################################################
167
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
168
void
IMDDI
<
AttributeSelection
,
isScalar
>::
_updateScore_
(
const
DiscreteVariable
*
var
,
169
NodeId
nody
,
170
VariableSelector
&
vs
) {
171
if
(!
this
->
nodeId2Database_
[
nody
]->
isTestRelevant
(
var
))
return
;
172
double
weight
173
= (
double
)
this
->
nodeId2Database_
[
nody
]->
nbObservation
() / (
double
)
this
->
_nbTotalObservation_
;
174
vs
.
updateScore
(
var
,
175
weight
*
this
->
nodeId2Database_
[
nody
]->
testValue
(
var
),
176
weight
*
this
->
nodeId2Database_
[
nody
]->
testOtherCriterion
(
var
));
177
}
178
179
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
180
void
IMDDI
<
AttributeSelection
,
isScalar
>::
_downdateScore_
(
const
DiscreteVariable
*
var
,
181
NodeId
nody
,
182
VariableSelector
&
vs
) {
183
if
(!
this
->
nodeId2Database_
[
nody
]->
isTestRelevant
(
var
))
return
;
184
double
weight
185
= (
double
)
this
->
nodeId2Database_
[
nody
]->
nbObservation
() / (
double
)
this
->
_nbTotalObservation_
;
186
vs
.
downdateScore
(
var
,
187
weight
*
this
->
nodeId2Database_
[
nody
]->
testValue
(
var
),
188
weight
*
this
->
nodeId2Database_
[
nody
]->
testOtherCriterion
(
var
));
189
}
190
191
192
// ============================================================================
// For each node in the given set, this method checks whether or not
// we should install the given variable as a test.
// If so, the node is updated and replaced in the set by its new sons.
// ============================================================================
template < TESTNAME AttributeSelection, bool isScalar >
void IMDDI< AttributeSelection, isScalar >::_updateNodeSet_(Set< NodeId >& nodeSet,
                                                            const DiscreteVariable* selectedVar,
                                                            VariableSelector& vs) {
  // Work on a snapshot: nodeSet is rebuilt as the new frontier.
  Set< NodeId > oldNodeSet(nodeSet);
  nodeSet.clear();
  for (SetIteratorSafe< NodeId > nodeIter = oldNodeSet.beginSafe();
       nodeIter != oldNodeSet.endSafe();
       ++nodeIter) {
    // Install the variable as a test only if it is relevant here and its
    // score on this node clears the selection threshold.
    if (this->nodeId2Database_[*nodeIter]->isTestRelevant(selectedVar)
        && this->nodeId2Database_[*nodeIter]->testValue(selectedVar)
               > _attributeSelectionThreshold_) {
      this->transpose_(*nodeIter, selectedVar);

      // Then we subtract from the score given to each variable the
      // quantity contributed by this node
      for (vs.begin(); vs.hasNext(); vs.next()) {
        _downdateScore_(vs.current(), *nodeIter, vs);
      }

      // And finally we add all its children to the new set of nodes
      // and update the remaining vars' scores
      for (Idx modality = 0; modality < this->nodeVarMap_[*nodeIter]->domainSize(); ++modality) {
        NodeId sonId = this->nodeSonsMap_[*nodeIter][modality];
        nodeSet << sonId;

        for (vs.begin(); vs.hasNext(); vs.next()) {
          _updateScore_(vs.current(), sonId, vs);
        }
      }
    } else {
      // Not split on this variable: the node stays on the frontier.
      nodeSet << *nodeIter;
    }
  }
}
232
233
234
// ============================================================================
235
// Insert a new node with given associated database, var and maybe sons
236
// ============================================================================
237
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
238
NodeId
IMDDI
<
AttributeSelection
,
isScalar
>::
insertLeafNode_
(
239
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
240
const
DiscreteVariable
*
boundVar
,
241
Set
<
const
Observation
* >*
obsSet
) {
242
NodeId
currentNodeId
243
=
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertLeafNode_
(
nDB
,
244
boundVar
,
245
obsSet
);
246
247
_addLeaf_
(
currentNodeId
);
248
249
return
currentNodeId
;
250
}
251
252
253
// ============================================================================
// Changes var associated to a node
// ============================================================================
template < TESTNAME AttributeSelection, bool isScalar >
void IMDDI< AttributeSelection, isScalar >::chgNodeBoundVar_(NodeId currentNodeId,
                                                             const DiscreteVariable* desiredVar) {
  // If the node was bound to the learned value it was a leaf: unregister it
  // from the aggregator BEFORE the base class rebinds it.
  if (this->nodeVarMap_[currentNodeId] == this->value_) _removeLeaf_(currentNodeId);

  IncrementalGraphLearner< AttributeSelection, isScalar >::chgNodeBoundVar_(currentNodeId,
                                                                            desiredVar);

  // Conversely, binding the node to the learned value turns it into a leaf
  // that must now be registered.
  if (desiredVar == this->value_) _addLeaf_(currentNodeId);
}
266
267
268
// ============================================================================
269
// Remove node from graph
270
// ============================================================================
271
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
272
void
IMDDI
<
AttributeSelection
,
isScalar
>::
removeNode_
(
NodeId
currentNodeId
) {
273
if
(
this
->
nodeVarMap_
[
currentNodeId
] ==
this
->
value_
)
_removeLeaf_
(
currentNodeId
);
274
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
removeNode_
(
currentNodeId
);
275
}
276
277
278
// ============================================================================
279
// Add leaf to aggregator
280
// ============================================================================
281
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
282
void
IMDDI
<
AttributeSelection
,
isScalar
>::
_addLeaf_
(
NodeId
currentNodeId
) {
283
_leafMap_
.
insert
(
284
currentNodeId
,
285
new
ConcreteLeaf
<
AttributeSelection
,
isScalar
>(
currentNodeId
,
286
this
->
nodeId2Database_
[
currentNodeId
],
287
&(
this
->
valueAssumed_
)));
288
_lg_
.
addLeaf
(
_leafMap_
[
currentNodeId
]);
289
}
290
291
292
// ============================================================================
293
// Remove leaf from aggregator
294
// ============================================================================
295
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
296
void
IMDDI
<
AttributeSelection
,
isScalar
>::
_removeLeaf_
(
NodeId
currentNodeId
) {
297
_lg_
.
removeLeaf
(
_leafMap_
[
currentNodeId
]);
298
delete
_leafMap_
[
currentNodeId
];
299
_leafMap_
.
erase
(
currentNodeId
);
300
}
301
302
303
// ============================================================================
// Computes the Reduced and Ordered Function Graph associated to this ordered
// tree
// ============================================================================
template < TESTNAME AttributeSelection, bool isScalar >
void IMDDI< AttributeSelection, isScalar >::updateFunctionGraph() {
  // NOTE(review): the guard below is disabled, so the target graph is
  // currently rebuilt unconditionally on every call — confirm this is
  // intentional before re-enabling it.
  // if( _lg_.needsUpdate() || this->needUpdate_ ){
  _rebuildFunctionGraph_();
  this->needUpdate_ = false;
  // }
}
314
315
316
// ============================================================================
// Rebuilds the target function graph from the current tree and leaf merges
// ============================================================================
template < TESTNAME AttributeSelection, bool isScalar >
void IMDDI< AttributeSelection, isScalar >::_rebuildFunctionGraph_() {
  // *******************************************************************************************************
  // Bring the leaf aggregator up to date
  _lg_.update();

  // *******************************************************************************************************
  // Reset the target decision graph and re-declare its variables in the
  // learned order, with the learned value last
  this->target_->clear();
  for (auto varIter = _varOrder_.beginSafe(); varIter != _varOrder_.endSafe(); ++varIter)
    this->target_->add(**varIter);
  this->target_->add(*this->value_);

  // Maps each tree node onto its counterpart in the target graph
  HashTable< NodeId, NodeId > toTarget;

  // *******************************************************************************************************
  // Insert the leaves: merged tree leaves share a single target node
  HashTable< NodeId, AbstractLeaf* > treeNode2leaf = _lg_.leavesMap();
  HashTable< AbstractLeaf*, NodeId > leaf2DGNode;
  for (HashTableConstIteratorSafe< NodeId, AbstractLeaf* > treeNodeIter
       = treeNode2leaf.cbeginSafe();
       treeNodeIter != treeNode2leaf.cendSafe();
       ++treeNodeIter) {
    if (!leaf2DGNode.exists(treeNodeIter.val()))
      leaf2DGNode.insert(treeNodeIter.val(),
                         _insertLeafInFunctionGraph_(treeNodeIter.val(), Int2Type< isScalar >()));

    toTarget.insert(treeNodeIter.key(), leaf2DGNode[treeNodeIter.val()]);
  }

  // *******************************************************************************************************
  // Insert the internal nodes, walking the variable order backwards so that
  // every son already exists in toTarget (merging is checked by the manager)
  for (SequenceIteratorSafe< const DiscreteVariable* > varIter = _varOrder_.rbeginSafe();
       varIter != _varOrder_.rendSafe();
       --varIter) {
    for (Link< NodeId >* curNodeIter = this->var2Node_[*varIter]->list(); curNodeIter;
         curNodeIter = curNodeIter->nextLink()) {
      // Sons table handed over to the target graph's manager
      NodeId* sonsMap
          = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * (*varIter)->domainSize()));
      for (Idx modality = 0; modality < (*varIter)->domainSize(); ++modality)
        sonsMap[modality] = toTarget[this->nodeSonsMap_[curNodeIter->element()][modality]];
      toTarget.insert(curNodeIter->element(),
                      this->target_->manager()->addInternalNode(*varIter, sonsMap));
    }
  }

  // *******************************************************************************************************
  // Polish: install the root and let the manager clean the graph up
  this->target_->manager()->setRootNode(toTarget[this->root_]);
  this->target_->manager()->clean();
}
371
372
373
// ============================================================================
374
// Performs the leaves merging
375
// ============================================================================
376
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
377
NodeId
IMDDI
<
AttributeSelection
,
isScalar
>::
_insertLeafInFunctionGraph_
(
AbstractLeaf
*
leaf
,
378
Int2Type
<
true
>) {
379
double
value
= 0.0;
380
for
(
Idx
moda
= 0;
moda
<
leaf
->
nbModa
();
moda
++) {
381
value
+= (
double
)
leaf
->
effectif
(
moda
) *
this
->
valueAssumed_
.
atPos
(
moda
);
382
}
383
if
(
leaf
->
total
())
value
/= (
double
)
leaf
->
total
();
384
return
this
->
target_
->
manager
()->
addTerminalNode
(
value
);
385
}
386
387
388
// ============================================================================
389
// Performs the leaves merging
390
// ============================================================================
391
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
392
NodeId
IMDDI
<
AttributeSelection
,
isScalar
>::
_insertLeafInFunctionGraph_
(
AbstractLeaf
*
leaf
,
393
Int2Type
<
false
>) {
394
NodeId
*
sonsMap
395
=
static_cast
<
NodeId
* >(
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
this
->
value_
->
domainSize
()));
396
for
(
Idx
modality
= 0;
modality
<
this
->
value_
->
domainSize
(); ++
modality
) {
397
double
newVal
= 0.0;
398
if
(
leaf
->
total
())
newVal
= (
double
)
leaf
->
effectif
(
modality
) / (
double
)
leaf
->
total
();
399
sonsMap
[
modality
] =
this
->
target_
->
manager
()->
addTerminalNode
(
newVal
);
400
}
401
return
this
->
target_
->
manager
()->
addInternalNode
(
this
->
value_
,
sonsMap
);
402
}
403
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643