aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
iti_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file Template Implementations of the ITI datastructure learner
24
* @brief
25
*
26
* @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
27
* GONZALES(@AMU)
28
*/
29
// =======================================================
30
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
31
#
include
<
agrum
/
tools
/
core
/
priorityQueue
.
h
>
32
#
include
<
agrum
/
tools
/
core
/
types
.
h
>
33
// =======================================================
34
#
include
<
agrum
/
FMDP
/
learning
/
core
/
chiSquare
.
h
>
35
#
include
<
agrum
/
FMDP
/
learning
/
datastructure
/
iti
.
h
>
36
// =======================================================
37
#
include
<
agrum
/
tools
/
variables
/
labelizedVariable
.
h
>
38
// =======================================================
39
40
41
namespace
gum
{
42
43
// ==========================================================================
44
/// @name Constructor & destructor.
45
// ==========================================================================
46
47
// ###################################################################
48
/**
49
* ITI constructor for functions describing the behaviour of one variable
50
* according to a set of other variables, such as conditional probabilities
51
* @param target : the MultiDimFunctionGraph in which we load the structure
52
* @param attributeSelectionThreshold : threshold under which a node is not
53
* installed (pre-pruning)
54
* @param temporaryAPIfix : Issue in API in regard to IMDDI
55
* @param attributeListe : Set of vars on which we rely to explain the
56
* behaviour of learned variable
57
* @param learnedValue : the variable from which we try to learn the behaviour
58
*/
59
// ###################################################################
60
template
< TESTNAME AttributeSelection,
bool
isScalar >
61
ITI< AttributeSelection, isScalar >::ITI(
MultiDimFunctionGraph
<
double
>*
target
,
62
double
attributeSelectionThreshold
,
63
Set
<
const
DiscreteVariable
* >
attributeListe
,
64
const
DiscreteVariable
*
learnedValue
) :
65
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>(
target
,
attributeListe
,
learnedValue
),
66
_nbTotalObservation_
(0),
_attributeSelectionThreshold_
(
attributeSelectionThreshold
) {
67
GUM_CONSTRUCTOR
(
ITI
);
68
_staleTable_
.
insert
(
this
->
root_
,
false
);
69
}
70
71
// ###################################################################
72
/**
73
* ITI constructor for real functions. We try to predict the output of a
74
* function f given a set of variable
75
* @param target : the MultiDimFunctionGraph in which we load the structure
76
* @param attributeSelectionThreshold : threshold under which a node is not
77
* installed (pre-pruning)
78
* @param temporaryAPIfix : Issue in API in regard to IMDDI
79
* @param attributeListeSet of vars on which we rely to explain the
80
* behaviour of learned function
81
*/
82
// ###################################################################
83
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
84
ITI
<
AttributeSelection
,
isScalar
>::
ITI
(
MultiDimFunctionGraph
<
double
>*
target
,
85
double
attributeSelectionThreshold
,
86
Set
<
const
DiscreteVariable
* >
attributeListe
) :
87
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>(
88
target
,
89
attributeListe
,
90
new
LabelizedVariable
(
"Reward"
,
""
, 2)),
91
_nbTotalObservation_
(0),
_attributeSelectionThreshold_
(
attributeSelectionThreshold
) {
92
GUM_CONSTRUCTOR
(
ITI
);
93
_staleTable_
.
insert
(
this
->
root_
,
false
);
94
}
95
96
97
// ==========================================================================
98
/// @name New Observation insertion methods
99
// ==========================================================================
100
101
// ############################################################################
102
/**
103
* Inserts a new observation
104
* @param the new observation to learn
105
*/
106
// ############################################################################
107
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
108
void
ITI
<
AttributeSelection
,
isScalar
>::
addObservation
(
const
Observation
*
obs
) {
109
_nbTotalObservation_
++;
110
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
addObservation
(
obs
);
111
}
112
113
// ############################################################################
114
/**
115
* Will update internal graph's NodeDatabase of given node with the new
116
* observation
117
* @param newObs
118
* @param currentNodeId
119
*/
120
// ############################################################################
121
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
122
void
ITI
<
AttributeSelection
,
isScalar
>::
updateNodeWithObservation_
(
const
Observation
*
newObs
,
123
NodeId
currentNodeId
) {
124
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
updateNodeWithObservation_
(
125
newObs
,
126
currentNodeId
);
127
_staleTable_
[
currentNodeId
] =
true
;
128
}
129
130
131
// ============================================================================
132
/// @name Graph Structure update methods
133
// ============================================================================
134
135
// ############################################################################
136
/// Updates the internal graph after a new observation has been added
137
// ############################################################################
138
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
139
void
ITI
<
AttributeSelection
,
isScalar
>::
updateGraph
() {
140
std
::
vector
<
NodeId
>
filo
;
141
filo
.
push_back
(
this
->
root_
);
142
HashTable
<
NodeId
,
Set
<
const
DiscreteVariable
* >* >
potentialVars
;
143
potentialVars
.
insert
(
this
->
root_
,
new
Set
<
const
DiscreteVariable
* >(
this
->
setOfVars_
));
144
145
146
while
(!
filo
.
empty
()) {
147
NodeId
currentNodeId
=
filo
.
back
();
148
filo
.
pop_back
();
149
150
// First we look for the best var to install on the node
151
double
bestValue
=
_attributeSelectionThreshold_
;
152
Set
<
const
DiscreteVariable
* >
bestVars
;
153
154
for
(
auto
varIter
=
potentialVars
[
currentNodeId
]->
cbeginSafe
();
155
varIter
!=
potentialVars
[
currentNodeId
]->
cendSafe
();
156
++
varIter
)
157
if
(
this
->
nodeId2Database_
[
currentNodeId
]->
isTestRelevant
(*
varIter
)) {
158
double
varValue
=
this
->
nodeId2Database_
[
currentNodeId
]->
testValue
(*
varIter
);
159
if
(
varValue
>=
bestValue
) {
160
if
(
varValue
>
bestValue
) {
161
bestValue
=
varValue
;
162
bestVars
.
clear
();
163
}
164
bestVars
.
insert
(*
varIter
);
165
}
166
}
167
168
// Then We installed Variable a test on that node
169
this
->
updateNode_
(
currentNodeId
,
bestVars
);
170
171
// The we move on the children if needed
172
if
(
this
->
nodeVarMap_
[
currentNodeId
] !=
this
->
value_
) {
173
for
(
Idx
moda
= 0;
moda
<
this
->
nodeVarMap_
[
currentNodeId
]->
domainSize
();
moda
++) {
174
Set
<
const
DiscreteVariable
* >*
itsPotentialVars
175
=
new
Set
<
const
DiscreteVariable
* >(*
potentialVars
[
currentNodeId
]);
176
itsPotentialVars
->
erase
(
this
->
nodeVarMap_
[
currentNodeId
]);
177
NodeId
sonId
=
this
->
nodeSonsMap_
[
currentNodeId
][
moda
];
178
if
(
_staleTable_
[
sonId
]) {
179
filo
.
push_back
(
sonId
);
180
potentialVars
.
insert
(
sonId
,
itsPotentialVars
);
181
}
182
}
183
}
184
}
185
186
for
(
HashTableIteratorSafe
<
NodeId
,
Set
<
const
DiscreteVariable
* >* >
nodeIter
187
=
potentialVars
.
beginSafe
();
188
nodeIter
!=
potentialVars
.
endSafe
();
189
++
nodeIter
)
190
delete
nodeIter
.
val
();
191
}
192
193
194
// ############################################################################
195
/**
196
* Inserts a new node in the internal graphs
197
* @param nDB : the associated database
198
* @param boundVar : the associated variable
199
* @return the newly created node's id
200
*/
201
// ############################################################################
202
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
203
NodeId
ITI
<
AttributeSelection
,
isScalar
>::
insertNode_
(
204
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
205
const
DiscreteVariable
*
boundVar
) {
206
NodeId
n
=
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertNode_
(
nDB
,
boundVar
);
207
_staleTable_
.
insert
(
n
,
true
);
208
return
n
;
209
}
210
211
212
// ############################################################################
213
/**
214
* Changes the associated variable of a node
215
* @param chgedNodeId : the node to change
216
* @param desiredVar : its new associated variable
217
*/
218
// ############################################################################
219
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
220
void
ITI
<
AttributeSelection
,
isScalar
>::
chgNodeBoundVar_
(
NodeId
currentNodeId
,
221
const
DiscreteVariable
*
desiredVar
) {
222
if
(
this
->
nodeVarMap_
[
currentNodeId
] !=
desiredVar
) {
223
_staleTable_
[
currentNodeId
] =
true
;
224
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
chgNodeBoundVar_
(
currentNodeId
,
225
desiredVar
);
226
}
227
}
228
229
230
// ############################################################################
231
/**
232
* Removes a node from the internal graph
233
* @param removedNodeId : the node to remove
234
*/
235
// ############################################################################
236
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
237
void
ITI
<
AttributeSelection
,
isScalar
>::
removeNode_
(
NodeId
currentNodeId
) {
238
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
removeNode_
(
currentNodeId
);
239
_staleTable_
.
erase
(
currentNodeId
);
240
}
241
242
243
// ============================================================================
244
/// @name Function Graph Updating methods
245
// ============================================================================
246
247
// ############################################################################
248
/// Updates target to currently learned graph structure
249
// ############################################################################
250
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
251
void
ITI
<
AttributeSelection
,
isScalar
>::
updateFunctionGraph
() {
252
this
->
target_
->
clear
();
253
this
->
target_
->
manager
()->
setRootNode
(
this
->
_insertNodeInFunctionGraph_
(
this
->
root_
));
254
}
255
256
257
// ############################################################################
258
/**
259
* Inserts an internal node in the target
260
* @param the source node in internal graph
261
* @return the matching node id in the target
262
*/
263
// ############################################################################
264
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
265
NodeId
ITI
<
AttributeSelection
,
isScalar
>::
_insertNodeInFunctionGraph_
(
NodeId
currentNodeId
) {
266
if
(
this
->
nodeVarMap_
[
currentNodeId
] ==
this
->
value_
) {
267
NodeId
nody
=
_insertTerminalNode_
(
currentNodeId
);
268
return
nody
;
269
}
270
271
if
(!
this
->
target_
->
variablesSequence
().
exists
(
this
->
nodeVarMap_
[
currentNodeId
])) {
272
this
->
target_
->
add
(*(
this
->
nodeVarMap_
[
currentNodeId
]));
273
}
274
275
NodeId
nody
=
this
->
target_
->
manager
()->
addInternalNode
(
this
->
nodeVarMap_
[
currentNodeId
]);
276
for
(
Idx
moda
= 0;
moda
<
this
->
nodeVarMap_
[
currentNodeId
]->
domainSize
(); ++
moda
) {
277
NodeId
son
=
this
->
_insertNodeInFunctionGraph_
(
this
->
nodeSonsMap_
[
currentNodeId
][
moda
]);
278
this
->
target_
->
manager
()->
setSon
(
nody
,
moda
,
son
);
279
}
280
281
return
nody
;
282
}
283
284
285
// ############################################################################
286
/**
287
* Insert a terminal node in the target.
288
* This function is called if we're learning a real value function.
289
* Inserts then a single value in target.
290
* @param the source node in the learned graph
291
* @return the matching node in the target
292
*/
293
// ############################################################################
294
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
295
NodeId
ITI
<
AttributeSelection
,
isScalar
>::
_insertTerminalNode_
(
NodeId
currentNodeId
,
296
Int2Type
<
false
>) {
297
if
(!
this
->
target_
->
variablesSequence
().
exists
(
this
->
value_
))
298
this
->
target_
->
add
(*(
this
->
value_
));
299
300
Size
tot
=
this
->
nodeId2Database_
[
currentNodeId
]->
nbObservation
();
301
if
(
tot
==
Size
(0))
return
this
->
target_
->
manager
()->
addTerminalNode
(0.0);
302
303
NodeId
*
sonsMap
304
=
static_cast
<
NodeId
* >(
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
this
->
value_
->
domainSize
()));
305
for
(
Idx
modality
= 0;
modality
<
this
->
value_
->
domainSize
(); ++
modality
) {
306
double
newVal
= 0.0;
307
newVal
= (
double
)
this
->
nodeId2Database_
[
currentNodeId
]->
effectif
(
modality
) / (
double
)
tot
;
308
sonsMap
[
modality
] =
this
->
target_
->
manager
()->
addTerminalNode
(
newVal
);
309
}
310
NodeId
nody
=
this
->
target_
->
manager
()->
addInternalNode
(
this
->
value_
,
sonsMap
);
311
return
nody
;
312
}
313
314
315
// ############################################################################
316
/**
317
* Insert a terminal node in the target.
318
* This function is called if we're learning the behaviour of a variable.
319
* Inserts then this variable and the relevant value beneath into target.
320
* @param the source node in the learned graph
321
* @return the matching node in the target
322
*/
323
// ############################################################################
324
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
325
NodeId
ITI
<
AttributeSelection
,
isScalar
>::
_insertTerminalNode_
(
NodeId
currentNodeId
,
326
Int2Type
<
true
>) {
327
double
value
= 0.0;
328
for
(
auto
valIter
=
this
->
nodeId2Database_
[
currentNodeId
]->
cbeginValues
();
329
valIter
!=
this
->
nodeId2Database_
[
currentNodeId
]->
cendValues
();
330
++
valIter
) {
331
value
+= (
double
)
valIter
.
key
() *
valIter
.
val
();
332
}
333
if
(
this
->
nodeId2Database_
[
currentNodeId
]->
nbObservation
())
334
value
/= (
double
)
this
->
nodeId2Database_
[
currentNodeId
]->
nbObservation
();
335
NodeId
nody
=
this
->
target_
->
manager
()->
addTerminalNode
(
value
);
336
return
nody
;
337
}
338
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643