aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
imddi_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file Template Implementations of the IMDDI datastructure learner
24
* @brief
25
*
26
* @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
27
* GONZALES(@AMU)
28
*/
29
// =======================================================
30
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
31
#
include
<
agrum
/
tools
/
core
/
priorityQueue
.
h
>
32
#
include
<
agrum
/
tools
/
core
/
types
.
h
>
33
// =======================================================
34
#
include
<
agrum
/
FMDP
/
learning
/
core
/
chiSquare
.
h
>
35
#
include
<
agrum
/
FMDP
/
learning
/
datastructure
/
imddi
.
h
>
36
// =======================================================
37
#
include
<
agrum
/
tools
/
variables
/
labelizedVariable
.
h
>
38
// =======================================================
39
40
41
namespace
gum
{
42
43
// ############################################################################
44
// Constructor & destructor.
45
// ############################################################################
46
47
// ============================================================================
// Variable Learner constructor
// ============================================================================
/**
 * Builds an IMDDI learner for a variable's transition function.
 *
 * @param target the function graph in which the learned model is built
 * @param attributeSelectionThreshold minimal test score for a variable to be
 *        installed as an internal node (see updateNodeSet__)
 * @param pairSelectionThreshold threshold forwarded to the leaf aggregator
 *        lg__ (used when deciding whether leaves should be merged)
 * @param attributeListe the set of candidate attributes (test variables);
 *        NOTE(review): passed by value — presumably intentional, confirm
 * @param learnedValue the variable whose value is being learned
 */
template < TESTNAME AttributeSelection, bool isScalar >
IMDDI< AttributeSelection, isScalar >::IMDDI(
   MultiDimFunctionGraph< double >* target,
   double                           attributeSelectionThreshold,
   double                           pairSelectionThreshold,
   Set< const DiscreteVariable* >   attributeListe,
   const DiscreteVariable*          learnedValue) :
    IncrementalGraphLearner< AttributeSelection, isScalar >(target,
                                                            attributeListe,
                                                            learnedValue),
    lg__(&(this->model_), pairSelectionThreshold), nbTotalObservation__(0),
    attributeSelectionThreshold__(attributeSelectionThreshold) {
  GUM_CONSTRUCTOR(IMDDI);
  // The root of the learning tree starts out as a leaf, tracked by the
  // leaf aggregator from the very beginning.
  addLeaf__(this->root_);
}
65
66
// ============================================================================
// Reward Learner constructor
// ============================================================================
/**
 * Builds an IMDDI learner for a reward function.
 *
 * Same as the variable-learner constructor, except that the learned value is
 * a freshly allocated binary "Reward" variable instead of a caller-supplied
 * one.
 * NOTE(review): ownership of the allocated LabelizedVariable is presumably
 * taken by the IncrementalGraphLearner base class — confirm to rule out a
 * leak.
 *
 * @param target the function graph in which the learned model is built
 * @param attributeSelectionThreshold minimal test score for a variable to be
 *        installed as an internal node
 * @param pairSelectionThreshold threshold forwarded to the leaf aggregator
 *        lg__
 * @param attributeListe the set of candidate attributes (test variables)
 */
template < TESTNAME AttributeSelection, bool isScalar >
IMDDI< AttributeSelection, isScalar >::IMDDI(
   MultiDimFunctionGraph< double >* target,
   double                           attributeSelectionThreshold,
   double                           pairSelectionThreshold,
   Set< const DiscreteVariable* >   attributeListe) :
    IncrementalGraphLearner< AttributeSelection, isScalar >(
       target,
       attributeListe,
       new LabelizedVariable("Reward", "", 2)),
    lg__(&(this->model_), pairSelectionThreshold), nbTotalObservation__(0),
    attributeSelectionThreshold__(attributeSelectionThreshold) {
  GUM_CONSTRUCTOR(IMDDI);
  // The root starts out as a leaf, tracked by the leaf aggregator.
  addLeaf__(this->root_);
}
84
85
// ============================================================================
// Destructor
// ============================================================================
/**
 * Destructor.
 *
 * Frees every AbstractLeaf that was allocated by addLeaf__ and is still
 * referenced by leafMap__.
 */
template < TESTNAME AttributeSelection, bool isScalar >
IMDDI< AttributeSelection, isScalar >::~IMDDI() {
  GUM_DESTRUCTOR(IMDDI);
  // leafMap__ owns the leaves: delete each mapped value (the table itself
  // is cleaned up by its own destructor).
  for (HashTableIteratorSafe< NodeId, AbstractLeaf* > leafIter
       = leafMap__.beginSafe();
       leafIter != leafMap__.endSafe();
       ++leafIter)
    delete leafIter.val();
}
97
98
99
// ############################################################################
100
// Incrementals methods
101
// ############################################################################
102
103
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
104
void
IMDDI
<
AttributeSelection
,
isScalar
>::
addObservation
(
105
const
Observation
*
obs
) {
106
nbTotalObservation__
++;
107
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
addObservation
(
obs
);
108
}
109
110
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
111
void
IMDDI
<
AttributeSelection
,
isScalar
>::
updateNodeWithObservation_
(
112
const
Observation
*
newObs
,
113
NodeId
currentNodeId
) {
114
IncrementalGraphLearner
<
AttributeSelection
,
115
isScalar
>::
updateNodeWithObservation_
(
newObs
,
116
currentNodeId
);
117
if
(
this
->
nodeVarMap_
[
currentNodeId
] ==
this
->
value_
)
118
lg__
.
updateLeaf
(
leafMap__
[
currentNodeId
]);
119
}
120
121
122
// ============================================================================
123
// Updates the tree after a new observation has been added
124
// ============================================================================
125
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
126
void
IMDDI
<
AttributeSelection
,
isScalar
>::
updateGraph
() {
127
varOrder__
.
clear
();
128
129
// First xe initialize the node set which will give us the scores
130
Set
<
NodeId
>
currentNodeSet
;
131
currentNodeSet
.
insert
(
this
->
root_
);
132
133
// Then we initialize the pool of variables to consider
134
VariableSelector
vs
(
this
->
setOfVars_
);
135
for
(
vs
.
begin
();
vs
.
hasNext
();
vs
.
next
()) {
136
updateScore__
(
vs
.
current
(),
this
->
root_
,
vs
);
137
}
138
139
// Then, until there's no node remaining
140
while
(!
vs
.
isEmpty
()) {
141
// We select the best var
142
const
DiscreteVariable
*
selectedVar
=
vs
.
select
();
143
varOrder__
.
insert
(
selectedVar
);
144
145
// Then we decide if we update each node according to this var
146
updateNodeSet__
(
currentNodeSet
,
selectedVar
,
vs
);
147
}
148
149
// If there are remaining node that are not leaves after we establish the
150
// var order
151
// these nodes are turned into leaf.
152
for
(
SetIteratorSafe
<
NodeId
>
nodeIter
=
currentNodeSet
.
beginSafe
();
153
nodeIter
!=
currentNodeSet
.
endSafe
();
154
++
nodeIter
)
155
this
->
convertNode2Leaf_
(*
nodeIter
);
156
157
158
if
(
lg__
.
needsUpdate
())
lg__
.
update
();
159
}
160
161
162
// ############################################################################
163
// Updating methods
164
// ############################################################################
165
166
167
// ###################################################################
168
// Select the most relevant variable
169
//
170
// First parameter is the set of variables among which the most
171
// relevant one is choosed
172
// Second parameter is the set of node the will attribute a score
173
// to each variable so that we choose the best.
174
// ###################################################################
175
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
176
void
IMDDI
<
AttributeSelection
,
isScalar
>::
updateScore__
(
177
const
DiscreteVariable
*
var
,
178
NodeId
nody
,
179
VariableSelector
&
vs
) {
180
if
(!
this
->
nodeId2Database_
[
nody
]->
isTestRelevant
(
var
))
return
;
181
double
weight
= (
double
)
this
->
nodeId2Database_
[
nody
]->
nbObservation
()
182
/ (
double
)
this
->
nbTotalObservation__
;
183
vs
.
updateScore
(
var
,
184
weight
*
this
->
nodeId2Database_
[
nody
]->
testValue
(
var
),
185
weight
*
this
->
nodeId2Database_
[
nody
]->
testOtherCriterion
(
var
));
186
}
187
188
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
189
void
IMDDI
<
AttributeSelection
,
isScalar
>::
downdateScore__
(
190
const
DiscreteVariable
*
var
,
191
NodeId
nody
,
192
VariableSelector
&
vs
) {
193
if
(!
this
->
nodeId2Database_
[
nody
]->
isTestRelevant
(
var
))
return
;
194
double
weight
= (
double
)
this
->
nodeId2Database_
[
nody
]->
nbObservation
()
195
/ (
double
)
this
->
nbTotalObservation__
;
196
vs
.
downdateScore
(
var
,
197
weight
*
this
->
nodeId2Database_
[
nody
]->
testValue
(
var
),
198
weight
199
*
this
->
nodeId2Database_
[
nody
]->
testOtherCriterion
(
var
));
200
}
201
202
203
// ============================================================================
// For each node in the given set, this method checks whether or not
// we should install the given variable as a test.
// If so, the node is updated.
// ============================================================================
/**
 * @param nodeSet in/out: the current frontier of nodes; replaced by the new
 *        frontier (children of split nodes plus unsplit nodes)
 * @param selectedVar the variable just chosen by the selector
 * @param vs the selector whose scores are kept consistent with the split
 */
template < TESTNAME AttributeSelection, bool isScalar >
void IMDDI< AttributeSelection, isScalar >::updateNodeSet__(
   Set< NodeId >&          nodeSet,
   const DiscreteVariable* selectedVar,
   VariableSelector&       vs) {
  // Work on a copy: nodeSet is rebuilt from scratch below.
  Set< NodeId > oldNodeSet(nodeSet);
  nodeSet.clear();
  for (SetIteratorSafe< NodeId > nodeIter = oldNodeSet.beginSafe();
       nodeIter != oldNodeSet.endSafe();
       ++nodeIter) {
    // Split this node on selectedVar only if the test is both relevant and
    // scores above the selection threshold.
    if (this->nodeId2Database_[*nodeIter]->isTestRelevant(selectedVar)
        && this->nodeId2Database_[*nodeIter]->testValue(selectedVar)
              > attributeSelectionThreshold__) {
      this->transpose_(*nodeIter, selectedVar);

      // Subtract from each variable's score the contribution that this
      // node provided (it is no longer on the frontier).
      for (vs.begin(); vs.hasNext(); vs.next()) {
        downdateScore__(vs.current(), *nodeIter, vs);
      }

      // Finally, add all of its children to the new frontier and update
      // the remaining variables' scores with their contributions.
      for (Idx modality = 0;
           modality < this->nodeVarMap_[*nodeIter]->domainSize();
           ++modality) {
        NodeId sonId = this->nodeSonsMap_[*nodeIter][modality];
        nodeSet << sonId;

        for (vs.begin(); vs.hasNext(); vs.next()) {
          updateScore__(vs.current(), sonId, vs);
        }
      }
    } else {
      // Not split: the node stays on the frontier unchanged.
      nodeSet << *nodeIter;
    }
  }
}
246
247
248
// ============================================================================
249
// Insert a new node with given associated database, var and maybe sons
250
// ============================================================================
251
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
252
NodeId
IMDDI
<
AttributeSelection
,
isScalar
>::
insertLeafNode_
(
253
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
254
const
DiscreteVariable
*
boundVar
,
255
Set
<
const
Observation
* >*
obsSet
) {
256
NodeId
currentNodeId
257
=
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertLeafNode_
(
258
nDB
,
259
boundVar
,
260
obsSet
);
261
262
addLeaf__
(
currentNodeId
);
263
264
return
currentNodeId
;
265
}
266
267
268
// ============================================================================
269
// Changes var associated to a node
270
// ============================================================================
271
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
272
void
IMDDI
<
AttributeSelection
,
isScalar
>::
chgNodeBoundVar_
(
273
NodeId
currentNodeId
,
274
const
DiscreteVariable
*
desiredVar
) {
275
if
(
this
->
nodeVarMap_
[
currentNodeId
] ==
this
->
value_
)
276
removeLeaf__
(
currentNodeId
);
277
278
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
chgNodeBoundVar_
(
279
currentNodeId
,
280
desiredVar
);
281
282
if
(
desiredVar
==
this
->
value_
)
addLeaf__
(
currentNodeId
);
283
}
284
285
286
// ============================================================================
287
// Remove node from graph
288
// ============================================================================
289
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
290
void
IMDDI
<
AttributeSelection
,
isScalar
>::
removeNode_
(
NodeId
currentNodeId
) {
291
if
(
this
->
nodeVarMap_
[
currentNodeId
] ==
this
->
value_
)
292
removeLeaf__
(
currentNodeId
);
293
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
removeNode_
(
294
currentNodeId
);
295
}
296
297
298
// ============================================================================
299
// Add leaf to aggregator
300
// ============================================================================
301
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
302
void
IMDDI
<
AttributeSelection
,
isScalar
>::
addLeaf__
(
NodeId
currentNodeId
) {
303
leafMap__
.
insert
(
currentNodeId
,
304
new
ConcreteLeaf
<
AttributeSelection
,
isScalar
>(
305
currentNodeId
,
306
this
->
nodeId2Database_
[
currentNodeId
],
307
&(
this
->
valueAssumed_
)));
308
lg__
.
addLeaf
(
leafMap__
[
currentNodeId
]);
309
}
310
311
312
// ============================================================================
313
// Remove leaf from aggregator
314
// ============================================================================
315
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
316
void
IMDDI
<
AttributeSelection
,
isScalar
>::
removeLeaf__
(
NodeId
currentNodeId
) {
317
lg__
.
removeLeaf
(
leafMap__
[
currentNodeId
]);
318
delete
leafMap__
[
currentNodeId
];
319
leafMap__
.
erase
(
currentNodeId
);
320
}
321
322
323
// ============================================================================
324
// Computes the Reduced and Ordered Function Graph associated to this ordered
325
// tree
326
// ============================================================================
327
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
328
void
IMDDI
<
AttributeSelection
,
isScalar
>::
updateFunctionGraph
() {
329
// if( lg__.needsUpdate() || this->needUpdate_ ){
330
rebuildFunctionGraph__
();
331
this
->
needUpdate_
=
false
;
332
// }
333
}
334
335
336
// ============================================================================
// Performs the leaves merging and rebuilds the target function graph from
// the ordered tree.
// ============================================================================
template < TESTNAME AttributeSelection, bool isScalar >
void IMDDI< AttributeSelection, isScalar >::rebuildFunctionGraph__() {
  // *******************************************************************************************************
  // Bring the leaf aggregator up to date.
  lg__.update();

  // *******************************************************************************************************
  // Reinitialize the target decision graph: clear it, then re-declare the
  // variables in the learned order, ending with the learned value itself.
  this->target_->clear();
  for (auto varIter = varOrder__.beginSafe(); varIter != varOrder__.endSafe();
       ++varIter)
    this->target_->add(**varIter);
  this->target_->add(*this->value_);

  // Maps a node of the learning tree to its counterpart in the target graph.
  HashTable< NodeId, NodeId > toTarget;

  // *******************************************************************************************************
  // Insert the leaves: leaves merged by the aggregator share a single
  // target-graph node (hence the leaf2DGNode memoization).
  HashTable< NodeId, AbstractLeaf* > treeNode2leaf = lg__.leavesMap();
  HashTable< AbstractLeaf*, NodeId > leaf2DGNode;
  for (HashTableConstIteratorSafe< NodeId, AbstractLeaf* > treeNodeIter
       = treeNode2leaf.cbeginSafe();
       treeNodeIter != treeNode2leaf.cendSafe();
       ++treeNodeIter) {
    if (!leaf2DGNode.exists(treeNodeIter.val()))
      leaf2DGNode.insert(treeNodeIter.val(),
                         insertLeafInFunctionGraph__(treeNodeIter.val(),
                                                     Int2Type< isScalar >()));

    toTarget.insert(treeNodeIter.key(), leaf2DGNode[treeNodeIter.val()]);
  }

  // *******************************************************************************************************
  // Insert the internal nodes, walking the variable order backwards so that
  // every son already exists in toTarget (merging opportunities are handled
  // by addInternalNode).
  for (SequenceIteratorSafe< const DiscreteVariable* > varIter
       = varOrder__.rbeginSafe();
       varIter != varOrder__.rendSafe();
       --varIter) {
    for (Link< NodeId >* curNodeIter = this->var2Node_[*varIter]->list();
         curNodeIter;
         curNodeIter = curNodeIter->nextLink()) {
      // Son map allocated with SOA_ALLOCATE; presumably addInternalNode
      // takes ownership — TODO confirm against the manager's contract.
      NodeId* sonsMap = static_cast< NodeId* >(
         SOA_ALLOCATE(sizeof(NodeId) * (*varIter)->domainSize()));
      for (Idx modality = 0; modality < (*varIter)->domainSize(); ++modality)
        sonsMap[modality]
           = toTarget[this->nodeSonsMap_[curNodeIter->element()][modality]];
      toTarget.insert(
         curNodeIter->element(),
         this->target_->manager()->addInternalNode(*varIter, sonsMap));
    }
  }

  // *******************************************************************************************************
  // Polish: set the root and let the manager clean up unreachable nodes.
  this->target_->manager()->setRootNode(toTarget[this->root_]);
  this->target_->manager()->clean();
}
397
398
399
// ============================================================================
400
// Performs the leaves merging
401
// ============================================================================
402
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
403
NodeId
IMDDI
<
AttributeSelection
,
isScalar
>::
insertLeafInFunctionGraph__
(
404
AbstractLeaf
*
leaf
,
405
Int2Type
<
true
>) {
406
double
value
= 0.0;
407
for
(
Idx
moda
= 0;
moda
<
leaf
->
nbModa
();
moda
++) {
408
value
+= (
double
)
leaf
->
effectif
(
moda
) *
this
->
valueAssumed_
.
atPos
(
moda
);
409
}
410
if
(
leaf
->
total
())
value
/= (
double
)
leaf
->
total
();
411
return
this
->
target_
->
manager
()->
addTerminalNode
(
value
);
412
}
413
414
415
// ============================================================================
416
// Performs the leaves merging
417
// ============================================================================
418
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
419
NodeId
IMDDI
<
AttributeSelection
,
isScalar
>::
insertLeafInFunctionGraph__
(
420
AbstractLeaf
*
leaf
,
421
Int2Type
<
false
>) {
422
NodeId
*
sonsMap
=
static_cast
<
NodeId
* >(
423
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
this
->
value_
->
domainSize
()));
424
for
(
Idx
modality
= 0;
modality
<
this
->
value_
->
domainSize
(); ++
modality
) {
425
double
newVal
= 0.0;
426
if
(
leaf
->
total
())
427
newVal
= (
double
)
leaf
->
effectif
(
modality
) / (
double
)
leaf
->
total
();
428
sonsMap
[
modality
] =
this
->
target_
->
manager
()->
addTerminalNode
(
newVal
);
429
}
430
return
this
->
target_
->
manager
()->
addInternalNode
(
this
->
value_
,
sonsMap
);
431
}
432
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669