aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
structuredPlaner_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Template implementation of FMDP/planning/StructuredPlaner.h classes.
25
*
26
* @author Pierre-Henri WUILLEMIN(@LIP6) and Jean-Christophe MAGNAN and Christophe
27
* GONZALES(@AMU)
28
*/
29
30
// =========================================================================
31
#
include
<
queue
>
32
#
include
<
vector
>
33
//#include <algorithm>
34
//#include <utility>
35
// =========================================================================
36
#
include
<
agrum
/
tools
/
core
/
math
/
math_utils
.
h
>
37
#
include
<
agrum
/
tools
/
core
/
functors
.
h
>
38
// =========================================================================
39
#
include
<
agrum
/
tools
/
multidim
/
implementations
/
multiDimFunctionGraph
.
h
>
40
#
include
<
agrum
/
tools
/
multidim
/
instantiation
.
h
>
41
#
include
<
agrum
/
tools
/
multidim
/
potential
.
h
>
42
// =========================================================================
43
#
include
<
agrum
/
FMDP
/
planning
/
structuredPlaner
.
h
>
44
// =========================================================================
45
46
/// For shorter line and hence more comprehensive code purposes only
47
// Casts the pointer x to `const MultiDimFunctionGraph< GUM_SCALAR >*`.
// GUM_SCALAR is not a macro parameter: it is resolved at the point where
// RECAST is expanded, i.e. inside the templates of this file.
// NOTE(review): reinterpret_cast performs no type check, so x must really
// point to a MultiDimFunctionGraph< GUM_SCALAR > -- confirm at each call site.
#define RECAST(x) reinterpret_cast< const MultiDimFunctionGraph< GUM_SCALAR >* >(x)
48
49
namespace
gum
{
50
51
52
/* **************************************************************************************************
53
* **/
54
/* ** **/
55
/* ** Constructors / Destructors **/
56
/* ** **/
57
/* **************************************************************************************************
58
* **/
59
60
// ===========================================================================
61
// Default constructor
62
// ===========================================================================
63
template
<
typename
GUM_SCALAR >
64
INLINE StructuredPlaner<
GUM_SCALAR
>::
StructuredPlaner
(
65
IOperatorStrategy
<
GUM_SCALAR
>*
opi
,
66
GUM_SCALAR
discountFactor
,
67
GUM_SCALAR
epsilon
,
68
bool
verbose
) :
69
discountFactor_
(
discountFactor
),
70
operator_
(
opi
),
verbose_
(
verbose
) {
71
GUM_CONSTRUCTOR
(
StructuredPlaner
);
72
73
threshold__
=
epsilon
;
74
vFunction_
=
nullptr
;
75
optimalPolicy_
=
nullptr
;
76
}
77
78
// ===========================================================================
79
// Default destructor
80
// ===========================================================================
81
template
<
typename
GUM_SCALAR
>
82
INLINE
StructuredPlaner
<
GUM_SCALAR
>::~
StructuredPlaner
() {
83
GUM_DESTRUCTOR
(
StructuredPlaner
);
84
85
if
(
vFunction_
) {
delete
vFunction_
; }
86
87
if
(
optimalPolicy_
)
delete
optimalPolicy_
;
88
89
delete
operator_
;
90
}
91
92
93
/* **************************************************************************************************
94
* **/
95
/* ** **/
96
/* ** Datastructure access methods **/
97
/* ** **/
98
/* **************************************************************************************************
99
* **/
100
101
// ===========================================================================
102
// Returns a string describing the optimal policy in dot (Graphviz) format
103
// ===========================================================================
104
template
<
typename
GUM_SCALAR
>
105
std
::
string
StructuredPlaner
<
GUM_SCALAR
>::
optimalPolicy2String
() {
106
// ************************************************************************
107
// Discarding the case where no \pi* have been computed
108
if
(!
optimalPolicy_
||
optimalPolicy_
->
root
() == 0)
109
return
"NO OPTIMAL POLICY CALCULATED YET"
;
110
111
// ************************************************************************
112
// Initialisation
113
114
// Declaration of the needed string stream
115
std
::
stringstream
output
;
116
std
::
stringstream
terminalStream
;
117
std
::
stringstream
nonTerminalStream
;
118
std
::
stringstream
arcstream
;
119
120
// First line for the toDot
121
output
<<
std
::
endl
<<
"digraph \" OPTIMAL POLICY \" {"
<<
std
::
endl
;
122
123
// Form line for the internal node stream en the terminal node stream
124
terminalStream
<<
"node [shape = box];"
<<
std
::
endl
;
125
nonTerminalStream
<<
"node [shape = ellipse];"
<<
std
::
endl
;
126
127
// For somme clarity in the final string
128
std
::
string
tab
=
"\t"
;
129
130
// To know if we already checked a node or not
131
Set
<
NodeId
>
visited
;
132
133
// FIFO of nodes to visit
134
std
::
queue
<
NodeId
>
fifo
;
135
136
// Loading the FIFO
137
fifo
.
push
(
optimalPolicy_
->
root
());
138
visited
<<
optimalPolicy_
->
root
();
139
140
141
// ************************************************************************
142
// Main loop
143
while
(!
fifo
.
empty
()) {
144
// Node to visit
145
NodeId
currentNodeId
=
fifo
.
front
();
146
fifo
.
pop
();
147
148
// Checking if it is terminal
149
if
(
optimalPolicy_
->
isTerminalNode
(
currentNodeId
)) {
150
// Get back the associated ActionSet
151
ActionSet
ase
=
optimalPolicy_
->
nodeValue
(
currentNodeId
);
152
153
// Creating a line for this node
154
terminalStream
<<
tab
<<
currentNodeId
<<
";"
<<
tab
<<
currentNodeId
155
<<
" [label=\""
<<
currentNodeId
<<
" - "
;
156
157
// Enumerating and adding to the line the associated optimal actions
158
for
(
SequenceIteratorSafe
<
Idx
>
valIter
=
ase
.
beginSafe
();
159
valIter
!=
ase
.
endSafe
();
160
++
valIter
)
161
terminalStream
<<
fmdp_
->
actionName
(*
valIter
) <<
" "
;
162
163
// Terminating line
164
terminalStream
<<
"\"];"
<<
std
::
endl
;
165
continue
;
166
}
167
168
// Either wise
169
{
170
// Geting back the associated internal node
171
const
InternalNode
*
currentNode
=
optimalPolicy_
->
node
(
currentNodeId
);
172
173
// Creating a line in internalnode stream for this node
174
nonTerminalStream
<<
tab
<<
currentNodeId
<<
";"
<<
tab
<<
currentNodeId
175
<<
" [label=\""
<<
currentNodeId
<<
" - "
176
<<
currentNode
->
nodeVar
()->
name
() <<
"\"];"
<<
std
::
endl
;
177
178
// Going through the sons and agregating them according the the sons Ids
179
HashTable
<
NodeId
,
LinkedList
<
Idx
>* >
sonMap
;
180
for
(
Idx
sonIter
= 0;
sonIter
<
currentNode
->
nbSons
(); ++
sonIter
) {
181
if
(!
visited
.
exists
(
currentNode
->
son
(
sonIter
))) {
182
fifo
.
push
(
currentNode
->
son
(
sonIter
));
183
visited
<<
currentNode
->
son
(
sonIter
);
184
}
185
if
(!
sonMap
.
exists
(
currentNode
->
son
(
sonIter
)))
186
sonMap
.
insert
(
currentNode
->
son
(
sonIter
),
new
LinkedList
<
Idx
>());
187
sonMap
[
currentNode
->
son
(
sonIter
)]->
addLink
(
sonIter
);
188
}
189
190
// Adding to the arc stram
191
for
(
auto
sonIter
=
sonMap
.
beginSafe
();
sonIter
!=
sonMap
.
endSafe
();
192
++
sonIter
) {
193
arcstream
<<
tab
<<
currentNodeId
<<
" -> "
<<
sonIter
.
key
()
194
<<
" [label=\" "
;
195
Link
<
Idx
>*
modaIter
=
sonIter
.
val
()->
list
();
196
while
(
modaIter
) {
197
arcstream
<<
currentNode
->
nodeVar
()->
label
(
modaIter
->
element
());
198
if
(
modaIter
->
nextLink
())
arcstream
<<
", "
;
199
modaIter
=
modaIter
->
nextLink
();
200
}
201
arcstream
<<
"\",color=\"#00ff00\"];"
<<
std
::
endl
;
202
delete
sonIter
.
val
();
203
}
204
}
205
}
206
207
// Terminating
208
output
<<
terminalStream
.
str
() <<
std
::
endl
209
<<
nonTerminalStream
.
str
() <<
std
::
endl
210
<<
arcstream
.
str
() <<
std
::
endl
211
<<
"}"
<<
std
::
endl
;
212
213
return
output
.
str
();
214
}
215
216
217
/* **************************************************************************************************
218
* **/
219
/* ** **/
220
/* ** Planning Methods **/
221
/* ** **/
222
/* **************************************************************************************************
223
* **/
224
225
// ===========================================================================
226
// Initializes data structure needed for making the planning
227
// ===========================================================================
228
template
<
typename
GUM_SCALAR
>
229
void
StructuredPlaner
<
GUM_SCALAR
>::
initialize
(
const
FMDP
<
GUM_SCALAR
>*
fmdp
) {
230
fmdp_
=
fmdp
;
231
232
// Determination of the threshold value
233
threshold__
*= (1 -
discountFactor_
) / (2 *
discountFactor_
);
234
235
// Establishement of sequence of variable elemination
236
for
(
auto
varIter
=
fmdp_
->
beginVariables
();
varIter
!=
fmdp_
->
endVariables
();
237
++
varIter
)
238
elVarSeq_
<<
fmdp_
->
main2prime
(*
varIter
);
239
240
// Initialisation of the value function
241
vFunction_
=
operator_
->
getFunctionInstance
();
242
optimalPolicy_
=
operator_
->
getAggregatorInstance
();
243
firstTime__
=
true
;
244
}
245
246
247
// ===========================================================================
248
// Performs a value iteration
249
// ===========================================================================
250
template
<
typename
GUM_SCALAR
>
251
void
StructuredPlaner
<
GUM_SCALAR
>::
makePlanning
(
Idx
nbStep
) {
252
if
(
firstTime__
) {
253
this
->
initVFunction_
();
254
firstTime__
=
false
;
255
}
256
257
// *****************************************************************************************
258
// Main loop
259
// *****************************************************************************************
260
Idx
nbIte
= 0;
261
GUM_SCALAR
gap
=
threshold__
+ 1;
262
while
((
gap
>
threshold__
) && (
nbIte
<
nbStep
)) {
263
++
nbIte
;
264
265
MultiDimFunctionGraph
<
GUM_SCALAR
>*
newVFunction
=
this
->
valueIteration_
();
266
267
// *****************************************************************************************
268
// Then we compare new value function and the old one
269
MultiDimFunctionGraph
<
GUM_SCALAR
>*
deltaV
270
=
operator_
->
subtract
(
newVFunction
,
vFunction_
);
271
gap
= 0;
272
273
for
(
deltaV
->
beginValues
();
deltaV
->
hasValue
();
deltaV
->
nextValue
())
274
if
(
gap
<
fabs
(
deltaV
->
value
()))
gap
=
fabs
(
deltaV
->
value
());
275
delete
deltaV
;
276
277
if
(
verbose_
)
278
std
::
cout
<<
" ------------------- Fin itération n° "
<<
nbIte
<<
std
::
endl
279
<<
" Gap : "
<<
gap
<<
" - "
<<
threshold__
<<
std
::
endl
;
280
281
// *****************************************************************************************
282
// And eventually we update pointers for next loop
283
delete
vFunction_
;
284
vFunction_
=
newVFunction
;
285
}
286
287
// *****************************************************************************************
288
// Policy matching value function research
289
// *****************************************************************************************
290
this
->
evalPolicy_
();
291
}
292
293
294
// ===========================================================================
295
// Initializes the value function with the reward of the FMDP
296
// ===========================================================================
297
template
<
typename
GUM_SCALAR
>
298
void
StructuredPlaner
<
GUM_SCALAR
>::
initVFunction_
() {
299
vFunction_
->
copy
(*(
RECAST
(
fmdp_
->
reward
())));
300
}
301
302
/* **************************************************************************************************
303
* **/
304
/* ** **/
305
/* ** Value Iteration Methods **/
306
/* ** **/
307
/* **************************************************************************************************
308
* **/
309
310
311
// ===========================================================================
312
// Performs a single step of value iteration
313
// ===========================================================================
314
template
<
typename
GUM_SCALAR
>
315
MultiDimFunctionGraph
<
GUM_SCALAR
>*
316
StructuredPlaner
<
GUM_SCALAR
>::
valueIteration_
() {
317
// *****************************************************************************************
318
// Loop reset
319
MultiDimFunctionGraph
<
GUM_SCALAR
>*
newVFunction
320
=
operator_
->
getFunctionInstance
();
321
newVFunction
->
copyAndReassign
(*
vFunction_
,
fmdp_
->
mapMainPrime
());
322
323
// *****************************************************************************************
324
// For each action
325
std
::
vector
<
MultiDimFunctionGraph
<
GUM_SCALAR
>* >
qActionsSet
;
326
for
(
auto
actionIter
=
fmdp_
->
beginActions
();
327
actionIter
!=
fmdp_
->
endActions
();
328
++
actionIter
) {
329
MultiDimFunctionGraph
<
GUM_SCALAR
>*
qAction
330
=
this
->
evalQaction_
(
newVFunction
, *
actionIter
);
331
qActionsSet
.
push_back
(
qAction
);
332
}
333
delete
newVFunction
;
334
335
// *****************************************************************************************
336
// Next to evaluate main value function, we take maximise over all action
337
// value, ...
338
newVFunction
=
this
->
maximiseQactions_
(
qActionsSet
);
339
340
// *******************************************************************************************
341
// Next, we eval the new function value
342
newVFunction
=
this
->
addReward_
(
newVFunction
);
343
344
return
newVFunction
;
345
}
346
347
348
// ===========================================================================
349
// Evals the q function for current fmdp action
350
// ===========================================================================
351
template
<
typename
GUM_SCALAR
>
352
MultiDimFunctionGraph
<
GUM_SCALAR
>*
353
StructuredPlaner
<
GUM_SCALAR
>::
evalQaction_
(
354
const
MultiDimFunctionGraph
<
GUM_SCALAR
>*
Vold
,
355
Idx
actionId
) {
356
// ******************************************************************************
357
// Initialisation :
358
// Creating a copy of last Vfunction to deduce from the new Qaction
359
// And finding the first var to eleminate (the one at the end)
360
361
return
operator_
->
regress
(
Vold
,
actionId
,
this
->
fmdp_
,
this
->
elVarSeq_
);
362
}
363
364
365
// ===========================================================================
366
// Maximises over the Q-action functions to obtain the vFunction
367
// ===========================================================================
368
template
<
typename
GUM_SCALAR
>
369
MultiDimFunctionGraph
<
GUM_SCALAR
>*
370
StructuredPlaner
<
GUM_SCALAR
>::
maximiseQactions_
(
371
std
::
vector
<
MultiDimFunctionGraph
<
GUM_SCALAR
>* >&
qActionsSet
) {
372
MultiDimFunctionGraph
<
GUM_SCALAR
>*
newVFunction
=
qActionsSet
.
back
();
373
qActionsSet
.
pop_back
();
374
375
while
(!
qActionsSet
.
empty
()) {
376
MultiDimFunctionGraph
<
GUM_SCALAR
>*
qAction
=
qActionsSet
.
back
();
377
qActionsSet
.
pop_back
();
378
newVFunction
=
operator_
->
maximize
(
newVFunction
,
qAction
);
379
}
380
381
return
newVFunction
;
382
}
383
384
385
// ===========================================================================
386
// Minimises over the given functions to obtain the resulting function
387
// ===========================================================================
388
template
<
typename
GUM_SCALAR
>
389
MultiDimFunctionGraph
<
GUM_SCALAR
>*
390
StructuredPlaner
<
GUM_SCALAR
>::
minimiseFunctions_
(
391
std
::
vector
<
MultiDimFunctionGraph
<
GUM_SCALAR
>* >&
qActionsSet
) {
392
MultiDimFunctionGraph
<
GUM_SCALAR
>*
newVFunction
=
qActionsSet
.
back
();
393
qActionsSet
.
pop_back
();
394
395
while
(!
qActionsSet
.
empty
()) {
396
MultiDimFunctionGraph
<
GUM_SCALAR
>*
qAction
=
qActionsSet
.
back
();
397
qActionsSet
.
pop_back
();
398
newVFunction
=
operator_
->
minimize
(
newVFunction
,
qAction
);
399
}
400
401
return
newVFunction
;
402
}
403
404
405
// ===========================================================================
406
// Updates the value function by multiplying by discount and adding reward
407
// ===========================================================================
408
template
<
typename
GUM_SCALAR
>
409
MultiDimFunctionGraph
<
GUM_SCALAR
>*
StructuredPlaner
<
GUM_SCALAR
>::
addReward_
(
410
MultiDimFunctionGraph
<
GUM_SCALAR
>*
Vold
,
411
Idx
actionId
) {
412
// *****************************************************************************************
413
// ... we multiply the result by the discount factor, ...
414
MultiDimFunctionGraph
<
GUM_SCALAR
>*
newVFunction
415
=
operator_
->
getFunctionInstance
();
416
newVFunction
->
copyAndMultiplyByScalar
(*
Vold
,
this
->
discountFactor_
);
417
delete
Vold
;
418
419
// *****************************************************************************************
420
// ... and finally add reward
421
newVFunction
=
operator_
->
add
(
newVFunction
,
RECAST
(
fmdp_
->
reward
(
actionId
)));
422
423
return
newVFunction
;
424
}
425
426
427
/* **************************************************************************************************
428
* **/
429
/* ** **/
430
/* ** Optimal Policy Evaluation Methods **/
431
/* ** **/
432
/* **************************************************************************************************
433
* **/
434
435
// ===========================================================================
436
// Evals the policy corresponding to the given value function
437
// ===========================================================================
438
template
<
typename
GUM_SCALAR
>
439
void
StructuredPlaner
<
GUM_SCALAR
>::
evalPolicy_
() {
440
// *****************************************************************************************
441
// Loop reset
442
MultiDimFunctionGraph
<
GUM_SCALAR
>*
newVFunction
443
=
operator_
->
getFunctionInstance
();
444
newVFunction
->
copyAndReassign
(*
vFunction_
,
fmdp_
->
mapMainPrime
());
445
446
std
::
vector
<
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
447
SetTerminalNodePolicy
>* >
448
argMaxQActionsSet
;
449
// *****************************************************************************************
450
// For each action
451
for
(
auto
actionIter
=
fmdp_
->
beginActions
();
452
actionIter
!=
fmdp_
->
endActions
();
453
++
actionIter
) {
454
MultiDimFunctionGraph
<
GUM_SCALAR
>*
qAction
455
=
this
->
evalQaction_
(
newVFunction
, *
actionIter
);
456
457
qAction
=
this
->
addReward_
(
qAction
);
458
459
argMaxQActionsSet
.
push_back
(
makeArgMax_
(
qAction
, *
actionIter
));
460
}
461
delete
newVFunction
;
462
463
464
// *****************************************************************************************
465
// Next to evaluate main value function, we take maximise over all action
466
// value, ...
467
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
SetTerminalNodePolicy
>*
468
argMaxVFunction
469
=
argmaximiseQactions_
(
argMaxQActionsSet
);
470
471
// *****************************************************************************************
472
// Next to evaluate main value function, we take maximise over all action
473
// value, ...
474
extractOptimalPolicy_
(
argMaxVFunction
);
475
}
476
477
478
// ===========================================================================
479
// Creates a copy of given in parameter decision Graph and replaces leaves
480
// of that Graph by a pair containing value of the leaf and action to which
481
// is bind this Graph (given in parameter).
482
// ===========================================================================
483
template
<
typename
GUM_SCALAR
>
484
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
SetTerminalNodePolicy
>*
485
StructuredPlaner
<
GUM_SCALAR
>::
makeArgMax_
(
486
const
MultiDimFunctionGraph
<
GUM_SCALAR
>*
qAction
,
487
Idx
actionId
) {
488
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
SetTerminalNodePolicy
>*
489
amcpy
490
=
operator_
->
getArgMaxFunctionInstance
();
491
492
// Insertion des nouvelles variables
493
for
(
SequenceIteratorSafe
<
const
DiscreteVariable
* >
varIter
494
=
qAction
->
variablesSequence
().
beginSafe
();
495
varIter
!=
qAction
->
variablesSequence
().
endSafe
();
496
++
varIter
)
497
amcpy
->
add
(**
varIter
);
498
499
HashTable
<
NodeId
,
NodeId
>
src2dest
;
500
amcpy
->
manager
()->
setRootNode
(
501
recurArgMaxCopy__
(
qAction
->
root
(),
actionId
,
qAction
,
amcpy
,
src2dest
));
502
503
delete
qAction
;
504
return
amcpy
;
505
}
506
507
508
// ==========================================================================
509
// Recursion part for makeArgMax_
510
// ==========================================================================
511
template
<
typename
GUM_SCALAR
>
512
NodeId
StructuredPlaner
<
GUM_SCALAR
>::
recurArgMaxCopy__
(
513
NodeId
currentNodeId
,
514
Idx
actionId
,
515
const
MultiDimFunctionGraph
<
GUM_SCALAR
>*
src
,
516
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
SetTerminalNodePolicy
>*
517
argMaxCpy
,
518
HashTable
<
NodeId
,
NodeId
>&
visitedNodes
) {
519
if
(
visitedNodes
.
exists
(
currentNodeId
))
return
visitedNodes
[
currentNodeId
];
520
521
NodeId
nody
;
522
if
(
src
->
isTerminalNode
(
currentNodeId
)) {
523
ArgMaxSet
<
GUM_SCALAR
,
Idx
>
leaf
(
src
->
nodeValue
(
currentNodeId
),
actionId
);
524
nody
=
argMaxCpy
->
manager
()->
addTerminalNode
(
leaf
);
525
}
else
{
526
const
InternalNode
*
currentNode
=
src
->
node
(
currentNodeId
);
527
NodeId
*
sonsMap
=
static_cast
<
NodeId
* >(
528
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
currentNode
->
nodeVar
()->
domainSize
()));
529
for
(
Idx
moda
= 0;
moda
<
currentNode
->
nodeVar
()->
domainSize
(); ++
moda
)
530
sonsMap
[
moda
] =
recurArgMaxCopy__
(
currentNode
->
son
(
moda
),
531
actionId
,
532
src
,
533
argMaxCpy
,
534
visitedNodes
);
535
nody
536
=
argMaxCpy
->
manager
()->
addInternalNode
(
currentNode
->
nodeVar
(),
sonsMap
);
537
}
538
visitedNodes
.
insert
(
currentNodeId
,
nody
);
539
return
nody
;
540
}
541
542
543
// ===========================================================================
544
// Performs argmax_a Q(s,a)
545
// ===========================================================================
546
template
<
typename
GUM_SCALAR
>
547
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
SetTerminalNodePolicy
>*
548
StructuredPlaner
<
GUM_SCALAR
>::
argmaximiseQactions_
(
549
std
::
vector
<
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
550
SetTerminalNodePolicy
>* >&
551
qActionsSet
) {
552
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
SetTerminalNodePolicy
>*
553
newVFunction
554
=
qActionsSet
.
back
();
555
qActionsSet
.
pop_back
();
556
557
while
(!
qActionsSet
.
empty
()) {
558
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
SetTerminalNodePolicy
>*
559
qAction
560
=
qActionsSet
.
back
();
561
qActionsSet
.
pop_back
();
562
newVFunction
=
operator_
->
argmaximize
(
newVFunction
,
qAction
);
563
}
564
565
return
newVFunction
;
566
}
567
568
// ===========================================================================
569
// From the argmax value function given in parameter, rebuilds the optimal
570
// policy: a graph with the same structure whose leaves hold the set of
571
// optimal actions. The graph given in parameter is deleted afterwards.
572
// ===========================================================================
573
template
<
typename
GUM_SCALAR
>
574
void
StructuredPlaner
<
GUM_SCALAR
>::
extractOptimalPolicy_
(
575
const
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
576
SetTerminalNodePolicy
>*
577
argMaxOptimalValueFunction
) {
578
optimalPolicy_
->
clear
();
579
580
// Insertion des nouvelles variables
581
for
(
SequenceIteratorSafe
<
const
DiscreteVariable
* >
varIter
582
=
argMaxOptimalValueFunction
->
variablesSequence
().
beginSafe
();
583
varIter
!=
argMaxOptimalValueFunction
->
variablesSequence
().
endSafe
();
584
++
varIter
)
585
optimalPolicy_
->
add
(**
varIter
);
586
587
HashTable
<
NodeId
,
NodeId
>
src2dest
;
588
optimalPolicy_
->
manager
()->
setRootNode
(
589
recurExtractOptPol__
(
argMaxOptimalValueFunction
->
root
(),
590
argMaxOptimalValueFunction
,
591
src2dest
));
592
593
delete
argMaxOptimalValueFunction
;
594
}
595
596
597
// ==========================================================================
598
// Recursion part for extractOptimalPolicy_
599
// ==========================================================================
600
template
<
typename
GUM_SCALAR
>
601
NodeId
StructuredPlaner
<
GUM_SCALAR
>::
recurExtractOptPol__
(
602
NodeId
currentNodeId
,
603
const
MultiDimFunctionGraph
<
ArgMaxSet
<
GUM_SCALAR
,
Idx
>,
604
SetTerminalNodePolicy
>*
argMaxOptVFunc
,
605
HashTable
<
NodeId
,
NodeId
>&
visitedNodes
) {
606
if
(
visitedNodes
.
exists
(
currentNodeId
))
return
visitedNodes
[
currentNodeId
];
607
608
NodeId
nody
;
609
if
(
argMaxOptVFunc
->
isTerminalNode
(
currentNodeId
)) {
610
ActionSet
leaf
;
611
transferActionIds__
(
argMaxOptVFunc
->
nodeValue
(
currentNodeId
),
leaf
);
612
nody
=
optimalPolicy_
->
manager
()->
addTerminalNode
(
leaf
);
613
}
else
{
614
const
InternalNode
*
currentNode
=
argMaxOptVFunc
->
node
(
currentNodeId
);
615
NodeId
*
sonsMap
=
static_cast
<
NodeId
* >(
616
SOA_ALLOCATE
(
sizeof
(
NodeId
) *
currentNode
->
nodeVar
()->
domainSize
()));
617
for
(
Idx
moda
= 0;
moda
<
currentNode
->
nodeVar
()->
domainSize
(); ++
moda
)
618
sonsMap
[
moda
] =
recurExtractOptPol__
(
currentNode
->
son
(
moda
),
619
argMaxOptVFunc
,
620
visitedNodes
);
621
nody
=
optimalPolicy_
->
manager
()->
addInternalNode
(
currentNode
->
nodeVar
(),
622
sonsMap
);
623
}
624
visitedNodes
.
insert
(
currentNodeId
,
nody
);
625
return
nody
;
626
}
627
628
// ==========================================================================
629
// Extract from an ArgMaxSet the associated ActionSet
630
// ==========================================================================
631
template
<
typename
GUM_SCALAR
>
632
void
StructuredPlaner
<
GUM_SCALAR
>::
transferActionIds__
(
633
const
ArgMaxSet
<
GUM_SCALAR
,
Idx
>&
src
,
634
ActionSet
&
dest
) {
635
for
(
auto
idi
=
src
.
beginSafe
();
idi
!=
src
.
endSafe
(); ++
idi
)
636
dest
+= *
idi
;
637
}
638
639
640
}
// end of namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
RECAST
#define RECAST(x)
Definition:
fmdp_tpl.h:36