aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
iti_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
 * @file iti_tpl.h
 * @brief Template implementations of the ITI datastructure learner.
 *
 * @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
 * GONZALES(@AMU)
 */
29
// =======================================================
30
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
31
#
include
<
agrum
/
tools
/
core
/
priorityQueue
.
h
>
32
#
include
<
agrum
/
tools
/
core
/
types
.
h
>
33
// =======================================================
34
#
include
<
agrum
/
FMDP
/
learning
/
core
/
chiSquare
.
h
>
35
#
include
<
agrum
/
FMDP
/
learning
/
datastructure
/
iti
.
h
>
36
// =======================================================
37
#
include
<
agrum
/
tools
/
variables
/
labelizedVariable
.
h
>
38
// =======================================================
39
40
41
namespace
gum
{
42
43
// ==========================================================================
44
/// @name Constructor & destructor.
45
// ==========================================================================
46
47
// ###################################################################
48
/**
49
* ITI constructor for functions describing the behaviour of one variable
50
* according to a set of other variables, such as conditional probabilities
51
* @param target : the MultiDimFunctionGraph in which we load the structure
52
* @param attributeSelectionThreshold : threshold under which a node is not
53
* installed (pre-pruning)
54
* @param temporaryAPIfix : Issue in API in regard to IMDDI
55
* @param attributeListe : Set of vars on which we rely to explain the
56
* behaviour of learned variable
57
* @param learnedValue : the variable from which we try to learn the behaviour
58
*/
59
// ###################################################################
60
template < TESTNAME AttributeSelection, bool isScalar >
ITI< AttributeSelection, isScalar >::ITI(
   MultiDimFunctionGraph< double >* target,
   double                           attributeSelectionThreshold,
   Set< const DiscreteVariable* >   attributeListe,
   const DiscreteVariable*          learnedValue) :
    IncrementalGraphLearner< AttributeSelection, isScalar >(target,
                                                            attributeListe,
                                                            learnedValue),
    nbTotalObservation__(0),
    attributeSelectionThreshold__(attributeSelectionThreshold) {
  GUM_CONSTRUCTOR(ITI);
  // The base class creates the root node; it has seen no observation yet,
  // so it starts out non-stale (nothing to re-evaluate on it).
  staleTable__.insert(this->root_, false);
}
74
75
// ###################################################################
76
/**
77
* ITI constructor for real functions. We try to predict the output of a
78
* function f given a set of variable
79
* @param target : the MultiDimFunctionGraph in which we load the structure
80
* @param attributeSelectionThreshold : threshold under which a node is not
81
* installed (pre-pruning)
82
* @param temporaryAPIfix : Issue in API in regard to IMDDI
83
* @param attributeListeSet of vars on which we rely to explain the
84
* behaviour of learned function
85
*/
86
// ###################################################################
87
template < TESTNAME AttributeSelection, bool isScalar >
ITI< AttributeSelection, isScalar >::ITI(
   MultiDimFunctionGraph< double >* target,
   double                           attributeSelectionThreshold,
   Set< const DiscreteVariable* >   attributeListe) :
    IncrementalGraphLearner< AttributeSelection, isScalar >(
       target,
       attributeListe,
       // No learned variable is given in this overload: a synthetic binary
       // "Reward" variable stands in for the real-valued function output.
       new LabelizedVariable("Reward", "", 2)),
    nbTotalObservation__(0),
    attributeSelectionThreshold__(attributeSelectionThreshold) {
  GUM_CONSTRUCTOR(ITI);
  // The base class creates the root node; it has seen no observation yet,
  // so it starts out non-stale (nothing to re-evaluate on it).
  staleTable__.insert(this->root_, false);
}
101
102
103
// ==========================================================================
104
/// @name New Observation insertion methods
105
// ==========================================================================
106
107
// ############################################################################
108
/**
109
* Inserts a new observation
110
* @param the new observation to learn
111
*/
112
// ############################################################################
113
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
114
void
115
ITI
<
AttributeSelection
,
isScalar
>::
addObservation
(
const
Observation
*
obs
) {
116
nbTotalObservation__
++;
117
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
addObservation
(
obs
);
118
}
119
120
// ############################################################################
121
/**
122
* Will update internal graph's NodeDatabase of given node with the new
123
* observation
124
* @param newObs
125
* @param currentNodeId
126
*/
127
// ############################################################################
128
template < TESTNAME AttributeSelection, bool isScalar >
void ITI< AttributeSelection, isScalar >::updateNodeWithObservation_(
   const Observation* newObs,
   NodeId             currentNodeId) {
  // Let the base learner update this node's statistics database with the
  // new observation ...
  IncrementalGraphLearner< AttributeSelection,
                           isScalar >::updateNodeWithObservation_(newObs,
                                                                  currentNodeId);
  // ... then flag the node as stale so the next updateGraph() pass
  // re-evaluates which test variable should be installed on it.
  staleTable__[currentNodeId] = true;
}
137
138
139
// ============================================================================
140
/// @name Graph Structure update methods
141
// ============================================================================
142
143
// ############################################################################
144
/// Updates the internal graph after a new observation has been added
145
// ############################################################################
146
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
147
void
ITI
<
AttributeSelection
,
isScalar
>::
updateGraph
() {
148
std
::
vector
<
NodeId
>
filo
;
149
filo
.
push_back
(
this
->
root_
);
150
HashTable
<
NodeId
,
Set
<
const
DiscreteVariable
* >* >
potentialVars
;
151
potentialVars
.
insert
(
this
->
root_
,
152
new
Set
<
const
DiscreteVariable
* >(
this
->
setOfVars_
));
153
154
155
while
(!
filo
.
empty
()) {
156
NodeId
currentNodeId
=
filo
.
back
();
157
filo
.
pop_back
();
158
159
// First we look for the best var to install on the node
160
double
bestValue
=
attributeSelectionThreshold__
;
161
Set
<
const
DiscreteVariable
* >
bestVars
;
162
163
for
(
auto
varIter
=
potentialVars
[
currentNodeId
]->
cbeginSafe
();
164
varIter
!=
potentialVars
[
currentNodeId
]->
cendSafe
();
165
++
varIter
)
166
if
(
this
->
nodeId2Database_
[
currentNodeId
]->
isTestRelevant
(*
varIter
)) {
167
double
varValue
168
=
this
->
nodeId2Database_
[
currentNodeId
]->
testValue
(*
varIter
);
169
if
(
varValue
>=
bestValue
) {
170
if
(
varValue
>
bestValue
) {
171
bestValue
=
varValue
;
172
bestVars
.
clear
();
173
}
174
bestVars
.
insert
(*
varIter
);
175
}
176
}
177
178
// Then We installed Variable a test on that node
179
this
->
updateNode_
(
currentNodeId
,
bestVars
);
180
181
// The we move on the children if needed
182
if
(
this
->
nodeVarMap_
[
currentNodeId
] !=
this
->
value_
) {
183
for
(
Idx
moda
= 0;
moda
<
this
->
nodeVarMap_
[
currentNodeId
]->
domainSize
();
184
moda
++) {
185
Set
<
const
DiscreteVariable
* >*
itsPotentialVars
186
=
new
Set
<
const
DiscreteVariable
* >(*
potentialVars
[
currentNodeId
]);
187
itsPotentialVars
->
erase
(
this
->
nodeVarMap_
[
currentNodeId
]);
188
NodeId
sonId
=
this
->
nodeSonsMap_
[
currentNodeId
][
moda
];
189
if
(
staleTable__
[
sonId
]) {
190
filo
.
push_back
(
sonId
);
191
potentialVars
.
insert
(
sonId
,
itsPotentialVars
);
192
}
193
}
194
}
195
}
196
197
for
(
HashTableIteratorSafe
<
NodeId
,
Set
<
const
DiscreteVariable
* >* >
nodeIter
198
=
potentialVars
.
beginSafe
();
199
nodeIter
!=
potentialVars
.
endSafe
();
200
++
nodeIter
)
201
delete
nodeIter
.
val
();
202
}
203
204
205
// ############################################################################
206
/**
207
* Inserts a new node in the internal graph
208
* @param nDB : the associated database
209
* @param boundVar : the associated variable
210
* @return the newly created node's id
211
*/
212
// ############################################################################
213
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
214
NodeId
ITI
<
AttributeSelection
,
isScalar
>::
insertNode_
(
215
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
216
const
DiscreteVariable
*
boundVar
) {
217
NodeId
n
218
=
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertNode_
(
219
nDB
,
220
boundVar
);
221
staleTable__
.
insert
(
n
,
true
);
222
return
n
;
223
}
224
225
226
// ############################################################################
227
/**
228
* Changes the associated variable of a node
229
* @param chgedNodeId : the node to change
230
* @param desiredVar : its new associated variable
231
*/
232
// ############################################################################
233
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
234
void
ITI
<
AttributeSelection
,
isScalar
>::
chgNodeBoundVar_
(
235
NodeId
currentNodeId
,
236
const
DiscreteVariable
*
desiredVar
) {
237
if
(
this
->
nodeVarMap_
[
currentNodeId
] !=
desiredVar
) {
238
staleTable__
[
currentNodeId
] =
true
;
239
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
chgNodeBoundVar_
(
240
currentNodeId
,
241
desiredVar
);
242
}
243
}
244
245
246
// ############################################################################
247
/**
248
* Removes a node from the internal graph
249
* @param removedNodeId : the node to remove
250
*/
251
// ############################################################################
252
template < TESTNAME AttributeSelection, bool isScalar >
void ITI< AttributeSelection, isScalar >::removeNode_(NodeId currentNodeId) {
  // Let the base learner tear the node down first ...
  IncrementalGraphLearner< AttributeSelection, isScalar >::removeNode_(
     currentNodeId);
  // ... then drop its staleness entry so the table tracks live nodes only.
  staleTable__.erase(currentNodeId);
}
258
259
260
// ============================================================================
261
/// @name Function Graph Updating methods
262
// ============================================================================
263
264
// ############################################################################
265
/// Updates target to currently learned graph structure
266
// ############################################################################
267
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
268
void
ITI
<
AttributeSelection
,
isScalar
>::
updateFunctionGraph
() {
269
this
->
target_
->
clear
();
270
this
->
target_
->
manager
()->
setRootNode
(
271
this
->
insertNodeInFunctionGraph__
(
this
->
root_
));
272
}
273
274
275
// ############################################################################
276
/**
277
* Inserts an internal node in the target
278
* @param the source node in internal graph
279
* @return the matching node id in the target
280
*/
281
// ############################################################################
282
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
283
NodeId
ITI
<
AttributeSelection
,
isScalar
>::
insertNodeInFunctionGraph__
(
284
NodeId
currentNodeId
) {
285
if
(
this
->
nodeVarMap_
[
currentNodeId
] ==
this
->
value_
) {
286
NodeId
nody
=
insertTerminalNode__
(
currentNodeId
);
287
return
nody
;
288
}
289
290
if
(!
this
->
target_
->
variablesSequence
().
exists
(
291
this
->
nodeVarMap_
[
currentNodeId
])) {
292
this
->
target_
->
add
(*(
this
->
nodeVarMap_
[
currentNodeId
]));
293
}
294
295
NodeId
nody
=
this
->
target_
->
manager
()->
addInternalNode
(
296
this
->
nodeVarMap_
[
currentNodeId
]);
297
for
(
Idx
moda
= 0;
moda
<
this
->
nodeVarMap_
[
currentNodeId
]->
domainSize
();
298
++
moda
) {
299
NodeId
son
=
this
->
insertNodeInFunctionGraph__
(
300
this
->
nodeSonsMap_
[
currentNodeId
][
moda
]);
301
this
->
target_
->
manager
()->
setSon
(
nody
,
moda
,
son
);
302
}
303
304
return
nody
;
305
}
306
307
308
// ############################################################################
309
/**
310
* Insert a terminal node in the target.
311
* This function is called if we're learning a real value function.
312
* Inserts then a single value in target.
313
* @param the source node in the learned graph
314
* @return the matching node in the target
315
*/
316
// ############################################################################
317
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
318
NodeId
ITI
<
AttributeSelection
,
isScalar
>::
insertTerminalNode__
(
319
NodeId
currentNodeId
,
320
Int2Type
<
false
>) {
321
if
(!
this
->
target_
->
variablesSequence
().
exists
(
this
->
value_
))
322
this
->
target_
->
add
(*(
this
->
value_
));
323
324
Size
tot
=
this
->
nodeId2Database_
[
currentNodeId
]->
nbObservation
();
325
if
(
tot
==
Size
(0))
return
this
->
target_
->
manager
()->
addTerminalNode
(0.0);
326
327
NodeId
*
sonsMap
=
static_cast
<
NodeId
* >(
328
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
this
->
value_
->
domainSize
()));
329
for
(
Idx
modality
= 0;
modality
<
this
->
value_
->
domainSize
(); ++
modality
) {
330
double
newVal
= 0.0;
331
newVal
= (
double
)
this
->
nodeId2Database_
[
currentNodeId
]->
effectif
(
modality
)
332
/ (
double
)
tot
;
333
sonsMap
[
modality
] =
this
->
target_
->
manager
()->
addTerminalNode
(
newVal
);
334
}
335
NodeId
nody
=
this
->
target_
->
manager
()->
addInternalNode
(
this
->
value_
,
sonsMap
);
336
return
nody
;
337
}
338
339
340
// ############################################################################
341
/**
342
* Insert a terminal node in the target.
343
* This function is called if we're learning the behaviour of a variable.
344
* Inserts then this variable and the relevant value beneath into target.
345
* @param the source node in the learned graph
346
* @return the matching node in the target
347
*/
348
// ############################################################################
349
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
350
NodeId
ITI
<
AttributeSelection
,
isScalar
>::
insertTerminalNode__
(
351
NodeId
currentNodeId
,
352
Int2Type
<
true
>) {
353
double
value
= 0.0;
354
for
(
auto
valIter
=
this
->
nodeId2Database_
[
currentNodeId
]->
cbeginValues
();
355
valIter
!=
this
->
nodeId2Database_
[
currentNodeId
]->
cendValues
();
356
++
valIter
) {
357
value
+= (
double
)
valIter
.
key
() *
valIter
.
val
();
358
}
359
if
(
this
->
nodeId2Database_
[
currentNodeId
]->
nbObservation
())
360
value
/= (
double
)
this
->
nodeId2Database_
[
currentNodeId
]->
nbObservation
();
361
NodeId
nody
=
this
->
target_
->
manager
()->
addTerminalNode
(
value
);
362
return
nody
;
363
}
364
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669