aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
incrementalGraphLearner_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief
25
*
26
* @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
27
* GONZALES(@AMU)
28
*/
29
// =======================================================
30
#
include
<
queue
>
31
// =======================================================
32
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
33
#
include
<
agrum
/
tools
/
core
/
multiPriorityQueue
.
h
>
34
#
include
<
agrum
/
tools
/
core
/
types
.
h
>
35
// =======================================================
36
#
include
<
agrum
/
FMDP
/
learning
/
core
/
chiSquare
.
h
>
37
#
include
<
agrum
/
FMDP
/
learning
/
datastructure
/
incrementalGraphLearner
.
h
>
38
// =======================================================
39
#
include
<
agrum
/
tools
/
variables
/
discreteVariable
.
h
>
40
// =======================================================
41
42
namespace
gum
{
43
44
// ============================================================================
45
/// @name Constructor & destructor.
46
// ============================================================================
47
48
// ############################################################################
49
/**
50
* Default constructor
51
* @param target : the output diagram usable by the outside
* @param attributesSet : set of variables from which we try to describe the
* learned function
* @param learnVariable : if we try to learn the behaviour of a variable
* given another set of variables, this is that variable. If we are
* learning a function of real value, this is just a computational trick
* (and is to be deprecated)
*/
59
// ############################################################################
60
template < TESTNAME AttributeSelection, bool isScalar >
IncrementalGraphLearner< AttributeSelection, isScalar >::IncrementalGraphLearner(
   MultiDimFunctionGraph< double >* target,
   Set< const DiscreteVariable* >   varList,
   const DiscreteVariable*          value) :
    target_(target),
    setOfVars_(varList), value_(value) {
  GUM_CONSTRUCTOR(IncrementalGraphLearner);

  // One (initially empty) list of bound nodes per candidate attribute ...
  for (auto varIter = setOfVars_.cbeginSafe(); varIter != setOfVars_.cendSafe();
       ++varIter)
    var2Node_.insert(*varIter, new LinkedList< NodeId >());
  // ... plus one for the learned variable itself, which tags leaf nodes.
  var2Node_.insert(value_, new LinkedList< NodeId >());

  // NOTE(review): this first addNode() reserves an id that is never used as
  // root_ (insertLeafNode_ allocates its own id below) — presumably so that
  // id 0 never denotes a real tree node; confirm against gum::DiGraph.
  model_.addNode();
  // The learned tree starts as a single leaf bound to value_, with a fresh
  // empty database and an empty observation set.
  this->root_ = insertLeafNode_(
     new NodeDatabase< AttributeSelection, isScalar >(&setOfVars_, value_),
     value_,
     new Set< const Observation* >());
}
80
81
82
// ############################################################################
83
/// Default destructor
84
// ############################################################################
85
template < TESTNAME AttributeSelection, bool isScalar >
IncrementalGraphLearner< AttributeSelection,
                         isScalar >::~IncrementalGraphLearner() {
  // Destroy every per-node statistics database.
  for (auto nodeIter = nodeId2Database_.beginSafe();
       nodeIter != nodeId2Database_.endSafe();
       ++nodeIter)
    delete nodeIter.val();

  // Free every sons table. This must happen while nodeVarMap_ is still
  // intact: the table's byte size is sizeof(NodeId) * domainSize() of the
  // variable bound to the owning node.
  for (auto nodeIter = nodeSonsMap_.beginSafe();
       nodeIter != nodeSonsMap_.endSafe();
       ++nodeIter)
    SOA_DEALLOCATE(nodeIter.val(),
                   sizeof(NodeId) * nodeVarMap_[nodeIter.key()]->domainSize());

  // Destroy the per-variable node lists.
  for (auto varIter = var2Node_.beginSafe(); varIter != var2Node_.endSafe();
       ++varIter)
    delete varIter.val();

  // Destroy the per-leaf observation sets (the Observation objects
  // themselves are not owned here).
  for (auto nodeIter = leafDatabase_.beginSafe();
       nodeIter != leafDatabase_.endSafe();
       ++nodeIter)
    delete nodeIter.val();

  // Scalar/non-scalar specific cleanup of the value bookkeeping.
  clearValue__();

  GUM_DESTRUCTOR(IncrementalGraphLearner);
}
112
113
114
// ============================================================================
115
/// @name New Observation insertion methods
116
// ============================================================================
117
118
// ############################################################################
119
/**
120
* Inserts a new observation
121
* @param the new observation to learn
122
*/
123
// ############################################################################
124
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
125
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
addObservation
(
126
const
Observation
*
newObs
) {
127
assumeValue__
(
newObs
);
128
129
// The we go across the tree
130
NodeId
currentNodeId
=
root_
;
131
132
while
(
nodeSonsMap_
.
exists
(
currentNodeId
)) {
133
// On each encountered node, we update the database
134
updateNodeWithObservation_
(
newObs
,
currentNodeId
);
135
136
// The we select the next to go throught
137
currentNodeId
138
=
nodeSonsMap_
[
currentNodeId
]
139
[
branchObs__
(
newObs
,
nodeVarMap_
[
currentNodeId
])];
140
}
141
142
// On final insertion into the leave we reach
143
updateNodeWithObservation_
(
newObs
,
currentNodeId
);
144
leafDatabase_
[
currentNodeId
]->
insert
(
newObs
);
145
}
146
147
148
// ============================================================================
149
/// @name New Observation insertion methods
150
// ============================================================================
151
152
// ############################################################################
153
/// If a new modality appears to exists for given variable,
154
/// call this method to turn every associated node to this variable into leaf.
155
/// Graph has then indeed to be revised
156
// ############################################################################
157
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
158
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
updateVar
(
159
const
DiscreteVariable
*
var
) {
160
Link
<
NodeId
>*
nodIter
=
var2Node_
[
var
]->
list
();
161
Link
<
NodeId
>*
nni
=
nullptr
;
162
while
(
nodIter
) {
163
nni
=
nodIter
->
nextLink
();
164
convertNode2Leaf_
(
nodIter
->
element
());
165
nodIter
=
nni
;
166
}
167
}
168
169
170
// ############################################################################
171
/**
 * From the given set of variables, selects one at random and installs it
 * on the given node. Checks of course that the node's current variable is
 * not in that set first.
 * @param updatedNode : the node we update
 * @param varsOfInterest : the set of interesting vars to be installed here
 */
178
// ############################################################################
179
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
180
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
updateNode_
(
181
NodeId
updatedNode
,
182
Set
<
const
DiscreteVariable
* >&
varsOfInterest
) {
183
// If this node has no interesting variable, we turn it into a leaf
184
if
(
varsOfInterest
.
empty
()) {
185
convertNode2Leaf_
(
updatedNode
);
186
return
;
187
}
188
189
// If this node has already one of the best variable intalled as test, we
190
// move on
191
if
(
nodeVarMap_
.
exists
(
updatedNode
)
192
&&
varsOfInterest
.
exists
(
nodeVarMap_
[
updatedNode
])) {
193
return
;
194
}
195
196
// In any other case we have to install variable as best test
197
Idx
randy
= (
Idx
)(
std
::
rand
() /
RAND_MAX
) *
varsOfInterest
.
size
(),
basc
= 0;
198
SetConstIteratorSafe
<
const
DiscreteVariable
* >
varIter
;
199
for
(
varIter
=
varsOfInterest
.
cbeginSafe
(),
basc
= 0;
200
varIter
!=
varsOfInterest
.
cendSafe
() &&
basc
<
randy
;
201
++
varIter
,
basc
++)
202
;
203
204
transpose_
(
updatedNode
, *
varIter
);
205
}
206
207
208
// ############################################################################
209
/// Turns the given node into a leaf if not already so
210
// ############################################################################
211
template < TESTNAME AttributeSelection, bool isScalar >
void IncrementalGraphLearner< AttributeSelection, isScalar >::convertNode2Leaf_(
   NodeId currentNodeId) {
  // Leaves are the nodes bound to value_; anything else is internal and
  // must be collapsed.
  if (nodeVarMap_[currentNodeId] != value_) {
    leafDatabase_.insert(currentNodeId, new Set< const Observation* >());

    // Resolving potential sons issue: recursively collapse each son into a
    // leaf, merge its observation set into ours, then discard the son.
    for (Idx modality = 0; modality < nodeVarMap_[currentNodeId]->domainSize();
         ++modality) {
      NodeId sonId = nodeSonsMap_[currentNodeId][modality];
      convertNode2Leaf_(sonId);
      (*leafDatabase_[currentNodeId])
         = (*leafDatabase_[currentNodeId]) + *(leafDatabase_[sonId]);
      // removeNode_ frees the son's databases and unlinks it everywhere.
      removeNode_(sonId);
    }

    // Free the sons table BEFORE rebinding: its byte size depends on the
    // still-current variable's domainSize().
    SOA_DEALLOCATE(nodeSonsMap_[currentNodeId],
                   sizeof(NodeId) * nodeVarMap_[currentNodeId]->domainSize());
    nodeSonsMap_.erase(currentNodeId);

    // Rebind the node to value_, which marks it as a leaf.
    chgNodeBoundVar_(currentNodeId, value_);
  }
}
234
235
236
// ############################################################################
237
/// Installs given variable to the given node, ensuring that the variable
238
/// is not present in its subtree
239
// ############################################################################
240
template < TESTNAME AttributeSelection, bool isScalar >
void IncrementalGraphLearner< AttributeSelection, isScalar >::transpose_(
   NodeId                  currentNodeId,
   const DiscreteVariable* desiredVar) {
  // **************************************************************************************
  // If the current node already tests the desired variable,
  // there is nothing to do
  if (nodeVarMap_[currentNodeId] == desiredVar) { return; }

  // **************************************************************************************
  // If the current node is a leaf,
  // we must artificially insert a node testing the variable
  if (nodeVarMap_[currentNodeId] == value_) {
    // We turned this leaf into an internal node.
    // This means that we'll need to install children leaves for each value of
    // desiredVar

    // First we must prepare these new leaves' NodeDatabases and Sets<const
    // Observation*>, one per modality of desiredVar
    NodeDatabase< AttributeSelection, isScalar >** dbMap
       = static_cast< NodeDatabase< AttributeSelection, isScalar >** >(
          SOA_ALLOCATE(sizeof(NodeDatabase< AttributeSelection, isScalar >*)
                       * desiredVar->domainSize()));
    Set< const Observation* >** obsetMap
       = static_cast< Set< const Observation* >** >(SOA_ALLOCATE(
          sizeof(Set< const Observation* >*) * desiredVar->domainSize()));
    for (Idx modality = 0; modality < desiredVar->domainSize(); ++modality) {
      dbMap[modality]
         = new NodeDatabase< AttributeSelection, isScalar >(&setOfVars_, value_);
      obsetMap[modality] = new Set< const Observation* >();
    }
    // Dispatch the leaf's retained observations among the future children,
    // according to their value for desiredVar
    for (SetIteratorSafe< const Observation* > obsIter
         = leafDatabase_[currentNodeId]->beginSafe();
         leafDatabase_[currentNodeId]->endSafe() != obsIter;
         ++obsIter) {
      dbMap[branchObs__(*obsIter, desiredVar)]->addObservation(*obsIter);
      obsetMap[branchObs__(*obsIter, desiredVar)]->insert(*obsIter);
    }

    // Then we can install each new leaf (and put in place the sonsMap)
    NodeId* sonsMap = static_cast< NodeId* >(
       SOA_ALLOCATE(sizeof(NodeId) * desiredVar->domainSize()));
    for (Idx modality = 0; modality < desiredVar->domainSize(); ++modality)
      sonsMap[modality]
         = insertLeafNode_(dbMap[modality], value_, obsetMap[modality]);

    // Some necessary clean up: only the temporary pointer arrays are freed,
    // their contents are now owned by the new leaves
    SOA_DEALLOCATE(dbMap,
                   sizeof(NodeDatabase< AttributeSelection, isScalar >*)
                      * desiredVar->domainSize());
    SOA_DEALLOCATE(obsetMap,
                   sizeof(Set< const Observation* >*)
                      * desiredVar->domainSize());

    // And finally we can turn the node into an internal node associated to
    // desiredVar
    chgNodeBoundVar_(currentNodeId, desiredVar);
    nodeSonsMap_.insert(currentNodeId, sonsMap);

    return;
  }

  // *************************************************************************************
  // Remains the general case where currentNodeId is an internal node.

  // First we ensure that children nodes use desiredVar as variable
  for (Idx modality = 0; modality < nodeVarMap_[currentNodeId]->domainSize();
       ++modality)
    transpose_(nodeSonsMap_[currentNodeId][modality], desiredVar);

  //   Sequence<NodeDatabase<AttributeSelection, isScalar>*>
  //   sonsNodeDatabase =
  //   nodeId2Database_[currentNodeId]->splitOnVar(desiredVar);
  NodeId* sonsMap = static_cast< NodeId* >(
     SOA_ALLOCATE(sizeof(NodeId) * desiredVar->domainSize()));

  // Then we create the new mapping: swap the two top levels of the subtree,
  // so that desiredVar is tested at this node and the old variable below
  for (Idx desiredVarModality = 0; desiredVarModality < desiredVar->domainSize();
       ++desiredVarModality) {
    NodeId* grandSonsMap = static_cast< NodeId* >(
       SOA_ALLOCATE(sizeof(NodeId) * nodeVarMap_[currentNodeId]->domainSize()));
    NodeDatabase< AttributeSelection, isScalar >* sonDB
       = new NodeDatabase< AttributeSelection, isScalar >(&setOfVars_, value_);
    for (Idx currentVarModality = 0;
         currentVarModality < nodeVarMap_[currentNodeId]->domainSize();
         ++currentVarModality) {
      // The new grandson is the old grandson reached by taking the two
      // branch choices in the opposite order
      grandSonsMap[currentVarModality]
         = nodeSonsMap_[nodeSonsMap_[currentNodeId][currentVarModality]]
                       [desiredVarModality];
      sonDB->operator+=((*nodeId2Database_[grandSonsMap[currentVarModality]]));
    }

    sonsMap[desiredVarModality]
       = insertInternalNode_(sonDB, nodeVarMap_[currentNodeId], grandSonsMap);
  }

  // Finally we clean the old remaining nodes (the former sons)
  for (Idx currentVarModality = 0;
       currentVarModality < nodeVarMap_[currentNodeId]->domainSize();
       ++currentVarModality) {
    removeNode_(nodeSonsMap_[currentNodeId][currentVarModality]);
  }

  // We suppress the old sons map and remap to the new one. The
  // deallocation is sized by the OLD variable, so it must precede the
  // rebinding below.
  SOA_DEALLOCATE(nodeSonsMap_[currentNodeId],
                 sizeof(NodeId) * nodeVarMap_[currentNodeId]->domainSize());
  nodeSonsMap_[currentNodeId] = sonsMap;

  chgNodeBoundVar_(currentNodeId, desiredVar);
}
350
351
352
// ############################################################################
353
/**
354
* inserts a new node in internal graph
355
* @param nDB : the associated database
356
* @param boundVar : the associated variable
357
* @return the newly created node's id
358
*/
359
// ############################################################################
360
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
361
NodeId
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertNode_
(
362
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
363
const
DiscreteVariable
*
boundVar
) {
364
NodeId
newNodeId
=
model_
.
addNode
();
365
nodeVarMap_
.
insert
(
newNodeId
,
boundVar
);
366
nodeId2Database_
.
insert
(
newNodeId
,
nDB
);
367
var2Node_
[
boundVar
]->
addLink
(
newNodeId
);
368
369
needUpdate_
=
true
;
370
371
return
newNodeId
;
372
}
373
374
375
// ############################################################################
376
/**
377
* inserts a new internal node in internal graph
378
* @param nDB : the associated database
379
* @param boundVar : the associated variable
380
* @param sonsMap : a table giving node's sons node
381
* @return the newly created node's id
382
*/
383
// ############################################################################
384
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
385
NodeId
386
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertInternalNode_
(
387
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
388
const
DiscreteVariable
*
boundVar
,
389
NodeId
*
sonsMap
) {
390
NodeId
newNodeId
=
this
->
insertNode_
(
nDB
,
boundVar
);
391
nodeSonsMap_
.
insert
(
newNodeId
,
sonsMap
);
392
return
newNodeId
;
393
}
394
395
396
// ############################################################################
397
/**
 * inserts a new leaf node in internal graphs
399
* @param nDB : the associated database
400
* @param boundVar : the associated variable
401
* @param obsSet : the set of observation this leaf retains
402
* @return the newly created node's id
403
*/
404
// ############################################################################
405
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
406
NodeId
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
insertLeafNode_
(
407
NodeDatabase
<
AttributeSelection
,
isScalar
>*
nDB
,
408
const
DiscreteVariable
*
boundVar
,
409
Set
<
const
Observation
* >*
obsSet
) {
410
NodeId
newNodeId
=
this
->
insertNode_
(
nDB
,
boundVar
);
411
leafDatabase_
.
insert
(
newNodeId
,
obsSet
);
412
return
newNodeId
;
413
}
414
415
416
// ############################################################################
417
/**
418
* Changes the associated variable of a node
419
* @param chgedNodeId : the node to change
420
* @param desiredVar : its new associated variable
421
*/
422
// ############################################################################
423
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
424
void
IncrementalGraphLearner
<
AttributeSelection
,
isScalar
>::
chgNodeBoundVar_
(
425
NodeId
currentNodeId
,
426
const
DiscreteVariable
*
desiredVar
) {
427
if
(
nodeVarMap_
[
currentNodeId
] ==
desiredVar
)
return
;
428
429
var2Node_
[
nodeVarMap_
[
currentNodeId
]]->
searchAndRemoveLink
(
currentNodeId
);
430
var2Node_
[
desiredVar
]->
addLink
(
currentNodeId
);
431
nodeVarMap_
[
currentNodeId
] =
desiredVar
;
432
433
if
(
nodeVarMap_
[
currentNodeId
] !=
value_
434
&&
leafDatabase_
.
exists
(
currentNodeId
)) {
435
delete
leafDatabase_
[
currentNodeId
];
436
leafDatabase_
.
erase
(
currentNodeId
);
437
}
438
439
if
(
nodeVarMap_
[
currentNodeId
] ==
value_
440
&& !
leafDatabase_
.
exists
(
currentNodeId
)) {
441
leafDatabase_
.
insert
(
currentNodeId
,
new
Set
<
const
Observation
* >());
442
}
443
444
needUpdate_
=
true
;
445
}
446
447
448
// ############################################################################
449
/**
450
* Removes a node from the internal graph
451
* @param removedNodeId : the node to remove
452
*/
453
// ############################################################################
454
template < TESTNAME AttributeSelection, bool isScalar >
void IncrementalGraphLearner< AttributeSelection, isScalar >::removeNode_(
   NodeId currentNodeId) {
  // Remove the id from the graph
  model_.eraseNode(currentNodeId);

  // Remove the sons table (internal nodes only). Must be done while
  // nodeVarMap_ still holds the binding: the table's byte size depends on
  // the bound variable's domainSize().
  if (nodeSonsMap_.exists(currentNodeId)) {
    SOA_DEALLOCATE(nodeSonsMap_[currentNodeId],
                   sizeof(NodeId) * nodeVarMap_[currentNodeId]->domainSize());
    nodeSonsMap_.erase(currentNodeId);
  }

  // Remove the observation set (leaf nodes only)
  if (leafDatabase_.exists(currentNodeId)) {
    delete leafDatabase_[currentNodeId];
    leafDatabase_.erase(currentNodeId);
  }

  // Unregister from the bound variable's node list, then drop the binding
  // itself (this order is required: the lookup reads nodeVarMap_)
  var2Node_[nodeVarMap_[currentNodeId]]->searchAndRemoveLink(currentNodeId);
  nodeVarMap_.erase(currentNodeId);

  // Destroy the node's statistics database
  delete nodeId2Database_[currentNodeId];
  nodeId2Database_.erase(currentNodeId);

  // The learned structure changed, so the output diagram is stale.
  needUpdate_ = true;
}
482
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669