aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
incrementalGraphLearner_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief
25
*
26
* @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
27
* GONZALES(@AMU)
28
*/
29
// =======================================================
30
#
include
<
queue
>
31
// =======================================================
32
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
33
#
include
<
agrum
/
tools
/
core
/
multiPriorityQueue
.
h
>
34
#
include
<
agrum
/
tools
/
core
/
types
.
h
>
35
// =======================================================
36
#
include
<
agrum
/
FMDP
/
learning
/
core
/
chiSquare
.
h
>
37
#
include
<
agrum
/
FMDP
/
learning
/
datastructure
/
incrementalGraphLearner
.
h
>
38
// =======================================================
39
#
include
<
agrum
/
tools
/
variables
/
discreteVariable
.
h
>
40
// =======================================================
41
42
namespace
gum
{
43
44
// ============================================================================
45
/// @name Constructor & destructor.
46
// ============================================================================
47
48
// ############################################################################
49
/**
 * Default constructor
 * @param target : the output diagram usable by the outside
 * @param attributesSet : set of variables from which we try to describe the
 * learned function
 * @param learnVariable : if we try to learn the behaviour of a variable
 * given another set of variables, this is that variable. If we are
 * learning a real-valued function, this is just a computational trick
 * (and is to be deprecated)
 */
59
// ############################################################################
60
template
< TESTNAME AttributeSelection,
bool
isScalar >
61
IncrementalGraphLearner< AttributeSelection, isScalar >::IncrementalGraphLearner(
62
MultiDimFunctionGraph
<
double
>*
target
,
63
Set
<
const
DiscreteVariable
* >
varList
,
64
const
DiscreteVariable
*
value
) :
65
target_
(
target
),
66
setOfVars_
(
varList
),
value_
(
value
) {
67
GUM_CONSTRUCTOR
(
IncrementalGraphLearner
);
68
69
for
(
auto
varIter
=
setOfVars_
.
cbeginSafe
();
varIter
!=
setOfVars_
.
cendSafe
(); ++
varIter
)
70
var2Node_
.
insert
(*
varIter
,
new
LinkedList
<
NodeId
>());
71
var2Node_
.
insert
(
value_
,
new
LinkedList
<
NodeId
>());
72
73
model_
.
addNode
();
74
this
->
root_
75
=
insertLeafNode_
(
new
NodeDatabase
<
AttributeSelection
,
isScalar
>(&
setOfVars_
,
value_
),
76
value_
,
77
new
Set
<
const
Observation
* >());
78
}
79
80
81
// ############################################################################
82
/// Default destructor
83
// ############################################################################
84
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
85
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::~
IncrementalGraphLearner
() {
86
for
(
auto
nodeIter
=
nodeId2Database_
.
beginSafe
();
nodeIter
!=
nodeId2Database_
.
endSafe
();
87
++
nodeIter
)
88
delete
nodeIter
.
val
();
89
90
for
(
auto
nodeIter
=
nodeSonsMap_
.
beginSafe
();
nodeIter
!=
nodeSonsMap_
.
endSafe
(); ++
nodeIter
)
91
SOA_DEALLOCATE
(
nodeIter
.
val
(),
sizeof
(
NodeId
) *
nodeVarMap_
[
nodeIter
.
key
()]->
domainSize
());
92
93
for
(
auto
varIter
=
var2Node_
.
beginSafe
();
varIter
!=
var2Node_
.
endSafe
(); ++
varIter
)
94
delete
varIter
.
val
();
95
96
for
(
auto
nodeIter
=
leafDatabase_
.
beginSafe
();
nodeIter
!=
leafDatabase_
.
endSafe
(); ++
nodeIter
)
97
delete
nodeIter
.
val
();
98
99
_clearValue_
();
100
101
GUM_DESTRUCTOR
(
IncrementalGraphLearner
);
102
}
103
104
105
// ============================================================================
106
/// @name New Observation insertion methods
107
// ============================================================================
108
109
// ############################################################################
110
/**
111
* Inserts a new observation
112
* @param the new observation to learn
113
*/
114
// ############################################################################
115
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
116
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
addObservation
(
117
const
Observation
*
newObs
) {
118
_assumeValue_
(
newObs
);
119
120
// The we go across the tree
121
NodeId
currentNodeId
=
root_
;
122
123
while
(
nodeSonsMap_
.
exists
(
currentNodeId
)) {
124
// On each encountered node, we update the database
125
updateNodeWithObservation_
(
newObs
,
currentNodeId
);
126
127
// The we select the next to go throught
128
currentNodeId
=
nodeSonsMap_
[
currentNodeId
][
_branchObs_
(
newObs
,
nodeVarMap_
[
currentNodeId
])];
129
}
130
131
// On final insertion into the leave we reach
132
updateNodeWithObservation_
(
newObs
,
currentNodeId
);
133
leafDatabase_
[
currentNodeId
]->
insert
(
newObs
);
134
}
135
136
137
// ============================================================================
138
/// @name New Observation insertion methods
139
// ============================================================================
140
141
// ############################################################################
142
/// If a new modality appears to exists for given variable,
143
/// call this method to turn every associated node to this variable into leaf.
144
/// Graph has then indeed to be revised
145
// ############################################################################
146
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
147
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
updateVar
(
148
const
DiscreteVariable
*
var
) {
149
Link
<
NodeId
>*
nodIter
=
var2Node_
[
var
]->
list
();
150
Link
<
NodeId
>*
nni
=
nullptr
;
151
while
(
nodIter
) {
152
nni
=
nodIter
->
nextLink
();
153
convertNode2Leaf_
(
nodIter
->
element
());
154
nodIter
=
nni
;
155
}
156
}
157
158
159
// ############################################################################
160
/**
 * From the given set of nodes, selects one at random and installs it
 * on the given node. Checks first, of course, whether the node's current
 * variable is already in that set.
 * @param nody : the node we update
 * @param bestVar : the set of interesting vars to be installed here
 */
167
// ############################################################################
168
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
169
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
updateNode_
(
170
NodeId
updatedNode
,
171
Set
<
const
DiscreteVariable
* >&
varsOfInterest
) {
172
// If this node has no interesting variable, we turn it into a leaf
173
if
(
varsOfInterest
.
empty
()) {
174
convertNode2Leaf_
(
updatedNode
);
175
return
;
176
}
177
178
// If this node has already one of the best variable intalled as test, we
179
// move on
180
if
(
nodeVarMap_
.
exists
(
updatedNode
) &&
varsOfInterest
.
exists
(
nodeVarMap_
[
updatedNode
])) {
181
return
;
182
}
183
184
// In any other case we have to install variable as best test
185
Idx
randy
= (
Idx
)(
std
::
rand
() /
RAND_MAX
) *
varsOfInterest
.
size
(),
basc
= 0;
186
SetConstIteratorSafe
<
const
DiscreteVariable
* >
varIter
;
187
for
(
varIter
=
varsOfInterest
.
cbeginSafe
(),
basc
= 0;
188
varIter
!=
varsOfInterest
.
cendSafe
() &&
basc
<
randy
;
189
++
varIter
,
basc
++)
190
;
191
192
transpose_
(
updatedNode
, *
varIter
);
193
}
194
195
196
// ############################################################################
197
/// Turns the given node into a leaf if not already so
198
// ############################################################################
199
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
200
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
convertNode2Leaf_
(
201
NodeId
currentNodeId
) {
202
if
(
nodeVarMap_
[
currentNodeId
] !=
value_
) {
203
leafDatabase_
.
insert
(
currentNodeId
,
new
Set
<
const
Observation
* >());
204
205
// Resolving potential sons issue
206
for
(
Idx
modality
= 0;
modality
<
nodeVarMap_
[
currentNodeId
]->
domainSize
(); ++
modality
) {
207
NodeId
sonId
=
nodeSonsMap_
[
currentNodeId
][
modality
];
208
convertNode2Leaf_
(
sonId
);
209
(*
leafDatabase_
[
currentNodeId
]) = (*
leafDatabase_
[
currentNodeId
]) + *(
leafDatabase_
[
sonId
]);
210
removeNode_
(
sonId
);
211
}
212
213
SOA_DEALLOCATE
(
nodeSonsMap_
[
currentNodeId
],
214
sizeof
(
NodeId
) *
nodeVarMap_
[
currentNodeId
]->
domainSize
());
215
nodeSonsMap_
.
erase
(
currentNodeId
);
216
217
chgNodeBoundVar_
(
currentNodeId
,
value_
);
218
}
219
}
220
221
222
// ############################################################################
223
/// Installs given variable to the given node, ensuring that the variable
224
/// is not present in its subtree
225
// ############################################################################
226
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
227
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
transpose_
(
228
NodeId
currentNodeId
,
229
const
DiscreteVariable
*
desiredVar
) {
230
// **************************************************************************************
231
// Si le noeud courant contient déjà la variable qu'on souhaite lui amener
232
// Il n'y a rien à faire
233
if
(
nodeVarMap_
[
currentNodeId
] ==
desiredVar
) {
return
; }
234
235
// **************************************************************************************
236
// Si le noeud courant est terminal,
237
// Il faut artificiellement insérer un noeud liant à la variable
238
if
(
nodeVarMap_
[
currentNodeId
] ==
value_
) {
239
// We turned this leaf into an internal node.
240
// This mean that we'll need to install children leaves for each value of
241
// desiredVar
242
243
// First We must prepare these new leaves NodeDatabases and Sets<const
244
// Observation*>
245
NodeDatabase
<
AttributeSelection
,
isScalar
>**
dbMap
246
=
static_cast
<
NodeDatabase
<
AttributeSelection
,
isScalar
>** >(
SOA_ALLOCATE
(
247
sizeof
(
NodeDatabase
<
AttributeSelection
,
isScalar
>*) *
desiredVar
->
domainSize
()));
248
Set
<
const
Observation
* >**
obsetMap
=
static_cast
<
Set
<
const
Observation
* >** >(
249
SOA_ALLOCATE
(
sizeof
(
Set
<
const
Observation
* >*) *
desiredVar
->
domainSize
()));
250
for
(
Idx
modality
= 0;
modality
<
desiredVar
->
domainSize
(); ++
modality
) {
251
dbMap
[
modality
] =
new
NodeDatabase
<
AttributeSelection
,
isScalar
>(&
setOfVars_
,
value_
);
252
obsetMap
[
modality
] =
new
Set
<
const
Observation
* >();
253
}
254
for
(
SetIteratorSafe
<
const
Observation
* >
obsIter
255
=
leafDatabase_
[
currentNodeId
]->
beginSafe
();
256
leafDatabase_
[
currentNodeId
]->
endSafe
() !=
obsIter
;
257
++
obsIter
) {
258
dbMap
[
_branchObs_
(*
obsIter
,
desiredVar
)]->
addObservation
(*
obsIter
);
259
obsetMap
[
_branchObs_
(*
obsIter
,
desiredVar
)]->
insert
(*
obsIter
);
260
}
261
262
// Then we can install each new leaves (and put in place the sonsMap)
263
NodeId
*
sonsMap
264
=
static_cast
<
NodeId
* >(
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
desiredVar
->
domainSize
()));
265
for
(
Idx
modality
= 0;
modality
<
desiredVar
->
domainSize
(); ++
modality
)
266
sonsMap
[
modality
] =
insertLeafNode_
(
dbMap
[
modality
],
value_
,
obsetMap
[
modality
]);
267
268
// Some necessary clean up
269
SOA_DEALLOCATE
(
dbMap
,
270
sizeof
(
NodeDatabase
<
AttributeSelection
,
isScalar
>*)
271
*
desiredVar
->
domainSize
());
272
SOA_DEALLOCATE
(
obsetMap
,
sizeof
(
Set
<
const
Observation
* >*) *
desiredVar
->
domainSize
());
273
274
// And finally we can turn the node into an internal node associated to
275
// desiredVar
276
chgNodeBoundVar_
(
currentNodeId
,
desiredVar
);
277
nodeSonsMap_
.
insert
(
currentNodeId
,
sonsMap
);
278
279
return
;
280
}
281
282
// *************************************************************************************
283
// Remains the general case where currentNodeId is an internal node.
284
285
// First we ensure that children node use desiredVar as variable
286
for
(
Idx
modality
= 0;
modality
<
nodeVarMap_
[
currentNodeId
]->
domainSize
(); ++
modality
)
287
transpose_
(
nodeSonsMap_
[
currentNodeId
][
modality
],
desiredVar
);
288
289
// Sequence<NodeDatabase<AttributeSelection, isScalar>*>
290
// sonsNodeDatabase =
291
// nodeId2Database_[currentNodeId]->splitOnVar(desiredVar);
292
NodeId
*
sonsMap
293
=
static_cast
<
NodeId
* >(
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
desiredVar
->
domainSize
()));
294
295
// Then we create the new mapping
296
for
(
Idx
desiredVarModality
= 0;
desiredVarModality
<
desiredVar
->
domainSize
();
297
++
desiredVarModality
) {
298
NodeId
*
grandSonsMap
=
static_cast
<
NodeId
* >(
299
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
nodeVarMap_
[
currentNodeId
]->
domainSize
()));
300
NodeDatabase
<
AttributeSelection
,
isScalar
>*
sonDB
301
=
new
NodeDatabase
<
AttributeSelection
,
isScalar
>(&
setOfVars_
,
value_
);
302
for
(
Idx
currentVarModality
= 0;
303
currentVarModality
<
nodeVarMap_
[
currentNodeId
]->
domainSize
();
304
++
currentVarModality
) {
305
grandSonsMap
[
currentVarModality
]
306
=
nodeSonsMap_
[
nodeSonsMap_
[
currentNodeId
][
currentVarModality
]][
desiredVarModality
];
307
sonDB
->
operator
+=((*
nodeId2Database_
[
grandSonsMap
[
currentVarModality
]]));
308
}
309
310
sonsMap
[
desiredVarModality
]
311
=
insertInternalNode_
(
sonDB
,
nodeVarMap_
[
currentNodeId
],
grandSonsMap
);
312
}
313
314
// Finally we clean the old remaining nodes
315
for
(
Idx
currentVarModality
= 0;
currentVarModality
<
nodeVarMap_
[
currentNodeId
]->
domainSize
();
316
++
currentVarModality
) {
317
removeNode_
(
nodeSonsMap_
[
currentNodeId
][
currentVarModality
]);
318
}
319
320
// We suppress the old sons map and remap to the new one
321
SOA_DEALLOCATE
(
nodeSonsMap_
[
currentNodeId
],
322
sizeof
(
NodeId
) *
nodeVarMap_
[
currentNodeId
]->
domainSize
());
323
nodeSonsMap_
[
currentNodeId
] =
sonsMap
;
324
325
chgNodeBoundVar_
(
currentNodeId
,
desiredVar
);
326
}
327
328
329
// ############################################################################
330
/**
331
* inserts a new node in internal graph
332
* @param nDB : the associated database
333
* @param boundVar : the associated variable
334
* @return the newly created node's id
335
*/
336
// ############################################################################
337
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
338
NodeId
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertNode_
(
339
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
340
const
DiscreteVariable
*
boundVar
) {
341
NodeId
newNodeId
=
model_
.
addNode
();
342
nodeVarMap_
.
insert
(
newNodeId
,
boundVar
);
343
nodeId2Database_
.
insert
(
newNodeId
,
nDB
);
344
var2Node_
[
boundVar
]->
addLink
(
newNodeId
);
345
346
needUpdate_
=
true
;
347
348
return
newNodeId
;
349
}
350
351
352
// ############################################################################
353
/**
354
* inserts a new internal node in internal graph
355
* @param nDB : the associated database
356
* @param boundVar : the associated variable
357
* @param sonsMap : a table giving node's sons node
358
* @return the newly created node's id
359
*/
360
// ############################################################################
361
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
362
NodeId
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertInternalNode_
(
363
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
364
const
DiscreteVariable
*
boundVar
,
365
NodeId
*
sonsMap
) {
366
NodeId
newNodeId
=
this
->
insertNode_
(
nDB
,
boundVar
);
367
nodeSonsMap_
.
insert
(
newNodeId
,
sonsMap
);
368
return
newNodeId
;
369
}
370
371
372
// ############################################################################
373
/**
 * inserts a new leaf node in the internal graphs
 * @param nDB : the associated database
 * @param boundVar : the associated variable
 * @param obsSet : the set of observations this leaf retains
 * @return the newly created node's id
 */
380
// ############################################################################
381
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
382
NodeId
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertLeafNode_
(
383
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
384
const
DiscreteVariable
*
boundVar
,
385
Set
<
const
Observation
* >*
obsSet
) {
386
NodeId
newNodeId
=
this
->
insertNode_
(
nDB
,
boundVar
);
387
leafDatabase_
.
insert
(
newNodeId
,
obsSet
);
388
return
newNodeId
;
389
}
390
391
392
// ############################################################################
393
/**
394
* Changes the associated variable of a node
395
* @param chgedNodeId : the node to change
396
* @param desiredVar : its new associated variable
397
*/
398
// ############################################################################
399
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
400
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
chgNodeBoundVar_
(
401
NodeId
currentNodeId
,
402
const
DiscreteVariable
*
desiredVar
) {
403
if
(
nodeVarMap_
[
currentNodeId
] ==
desiredVar
)
return
;
404
405
var2Node_
[
nodeVarMap_
[
currentNodeId
]]->
searchAndRemoveLink
(
currentNodeId
);
406
var2Node_
[
desiredVar
]->
addLink
(
currentNodeId
);
407
nodeVarMap_
[
currentNodeId
] =
desiredVar
;
408
409
if
(
nodeVarMap_
[
currentNodeId
] !=
value_
&&
leafDatabase_
.
exists
(
currentNodeId
)) {
410
delete
leafDatabase_
[
currentNodeId
];
411
leafDatabase_
.
erase
(
currentNodeId
);
412
}
413
414
if
(
nodeVarMap_
[
currentNodeId
] ==
value_
&& !
leafDatabase_
.
exists
(
currentNodeId
)) {
415
leafDatabase_
.
insert
(
currentNodeId
,
new
Set
<
const
Observation
* >());
416
}
417
418
needUpdate_
=
true
;
419
}
420
421
422
// ############################################################################
423
/**
424
* Removes a node from the internal graph
425
* @param removedNodeId : the node to remove
426
*/
427
// ############################################################################
428
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
429
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
removeNode_
(
NodeId
currentNodeId
) {
430
// Retriat de l'id
431
model_
.
eraseNode
(
currentNodeId
);
432
433
// Retrait du vecteur fils
434
if
(
nodeSonsMap_
.
exists
(
currentNodeId
)) {
435
SOA_DEALLOCATE
(
nodeSonsMap_
[
currentNodeId
],
436
sizeof
(
NodeId
) *
nodeVarMap_
[
currentNodeId
]->
domainSize
());
437
nodeSonsMap_
.
erase
(
currentNodeId
);
438
}
439
440
if
(
leafDatabase_
.
exists
(
currentNodeId
)) {
441
delete
leafDatabase_
[
currentNodeId
];
442
leafDatabase_
.
erase
(
currentNodeId
);
443
}
444
445
// Retrait de la variable
446
var2Node_
[
nodeVarMap_
[
currentNodeId
]]->
searchAndRemoveLink
(
currentNodeId
);
447
nodeVarMap_
.
erase
(
currentNodeId
);
448
449
// Retrait du NodeDatabase
450
delete
nodeId2Database_
[
currentNodeId
];
451
nodeId2Database_
.
erase
(
currentNodeId
);
452
453
needUpdate_
=
true
;
454
}
455
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643