47 #define RECAST(x) reinterpret_cast< const MultiDimFunctionGraph< GUM_SCALAR >* >(x) 63 template <
typename GUM_SCALAR >
66 GUM_SCALAR discountFactor,
69 _discountFactor(discountFactor),
70 _operator(opi), _verbose(verbose) {
73 __threshold = epsilon;
75 _optimalPolicy =
nullptr;
81 template <
typename GUM_SCALAR >
85 if (_vFunction) {
delete _vFunction; }
87 if (_optimalPolicy)
delete _optimalPolicy;
104 template <
typename GUM_SCALAR >
108 if (!_optimalPolicy || _optimalPolicy->root() == 0)
109 return "NO OPTIMAL POLICY CALCULATED YET";
115 std::stringstream output;
116 std::stringstream terminalStream;
117 std::stringstream nonTerminalStream;
118 std::stringstream arcstream;
121 output << std::endl <<
"digraph \" OPTIMAL POLICY \" {" << std::endl;
124 terminalStream <<
"node [shape = box];" << std::endl;
125 nonTerminalStream <<
"node [shape = ellipse];" << std::endl;
128 std::string tab =
"\t";
134 std::queue< NodeId > fifo;
137 fifo.push(_optimalPolicy->root());
138 visited << _optimalPolicy->root();
143 while (!fifo.empty()) {
145 NodeId currentNodeId = fifo.front();
149 if (_optimalPolicy->isTerminalNode(currentNodeId)) {
151 ActionSet ase = _optimalPolicy->nodeValue(currentNodeId);
154 terminalStream << tab << currentNodeId <<
";" << tab << currentNodeId
155 <<
" [label=\"" << currentNodeId <<
" - ";
161 terminalStream << _fmdp->actionName(*valIter) <<
" ";
164 terminalStream <<
"\"];" << std::endl;
171 const InternalNode* currentNode = _optimalPolicy->node(currentNodeId);
174 nonTerminalStream << tab << currentNodeId <<
";" << tab << currentNodeId
175 <<
" [label=\"" << currentNodeId <<
" - " 176 << currentNode->
nodeVar()->
name() <<
"\"];" << std::endl;
180 for (
Idx sonIter = 0; sonIter < currentNode->
nbSons(); ++sonIter) {
181 if (!visited.
exists(currentNode->
son(sonIter))) {
182 fifo.push(currentNode->
son(sonIter));
183 visited << currentNode->
son(sonIter);
185 if (!sonMap.
exists(currentNode->
son(sonIter)))
187 sonMap[currentNode->
son(sonIter)]->addLink(sonIter);
193 arcstream << tab << currentNodeId <<
" -> " << sonIter.
key()
198 if (modaIter->
nextLink()) arcstream <<
", ";
201 arcstream <<
"\",color=\"#00ff00\"];" << std::endl;
202 delete sonIter.val();
208 output << terminalStream.str() << std::endl
209 << nonTerminalStream.str() << std::endl
210 << arcstream.str() << std::endl
228 template <
typename GUM_SCALAR >
233 __threshold *= (1 - _discountFactor) / (2 * _discountFactor);
236 for (
auto varIter = _fmdp->beginVariables(); varIter != _fmdp->endVariables();
241 _vFunction = _operator->getFunctionInstance();
242 _optimalPolicy = _operator->getAggregatorInstance();
250 template <
typename GUM_SCALAR >
253 this->_initVFunction();
261 GUM_SCALAR gap = __threshold + 1;
262 while ((gap > __threshold) && (nbIte < nbStep)) {
270 _operator->subtract(newVFunction, _vFunction);
274 if (gap < fabs(deltaV->
value())) gap = fabs(deltaV->
value());
278 std::cout <<
" ------------------- Fin itération n° " << nbIte << std::endl
279 <<
" Gap : " << gap <<
" - " << __threshold << std::endl;
284 _vFunction = newVFunction;
297 template <
typename GUM_SCALAR >
299 _vFunction->copy(*(
RECAST(_fmdp->reward())));
314 template <
typename GUM_SCALAR >
320 _operator->getFunctionInstance();
325 std::vector< MultiDimFunctionGraph< GUM_SCALAR >* > qActionsSet;
326 for (
auto actionIter = _fmdp->beginActions();
327 actionIter != _fmdp->endActions();
330 this->_evalQaction(newVFunction, *actionIter);
331 qActionsSet.push_back(qAction);
338 newVFunction = this->_maximiseQactions(qActionsSet);
342 newVFunction = this->_addReward(newVFunction);
351 template <
typename GUM_SCALAR >
360 return _operator->regress(Vold, actionId, this->_fmdp, this->_elVarSeq);
367 template <
typename GUM_SCALAR >
372 qActionsSet.pop_back();
374 while (!qActionsSet.empty()) {
376 qActionsSet.pop_back();
377 newVFunction = _operator->maximize(newVFunction, qAction);
387 template <
typename GUM_SCALAR >
392 qActionsSet.pop_back();
394 while (!qActionsSet.empty()) {
396 qActionsSet.pop_back();
397 newVFunction = _operator->minimize(newVFunction, qAction);
407 template <
typename GUM_SCALAR >
413 _operator->getFunctionInstance();
419 newVFunction = _operator->
add(newVFunction,
RECAST(_fmdp->reward(actionId)));
436 template <
typename GUM_SCALAR >
441 _operator->getFunctionInstance();
444 std::vector< MultiDimFunctionGraph< ArgMaxSet< GUM_SCALAR, Idx >,
449 for (
auto actionIter = _fmdp->beginActions();
450 actionIter != _fmdp->endActions();
453 this->_evalQaction(newVFunction, *actionIter);
455 qAction = this->_addReward(qAction);
457 argMaxQActionsSet.push_back(_makeArgMax(qAction, *actionIter));
466 argMaxVFunction = _argmaximiseQactions(argMaxQActionsSet);
471 _extractOptimalPolicy(argMaxVFunction);
480 template <
typename GUM_SCALAR >
485 amcpy = _operator->getArgMaxFunctionInstance();
492 amcpy->
add(**varIter);
496 __recurArgMaxCopy(qAction->
root(), actionId, qAction, amcpy, src2dest));
506 template <
typename GUM_SCALAR >
514 if (visitedNodes.
exists(currentNodeId))
return visitedNodes[currentNodeId];
519 nody = argMaxCpy->manager()->addTerminalNode(leaf);
525 sonsMap[moda] = __recurArgMaxCopy(
526 currentNode->
son(moda), actionId, src, argMaxCpy, visitedNodes);
528 argMaxCpy->manager()->addInternalNode(currentNode->
nodeVar(), sonsMap);
530 visitedNodes.
insert(currentNodeId, nody);
538 template <
typename GUM_SCALAR >
545 newVFunction = qActionsSet.back();
546 qActionsSet.pop_back();
548 while (!qActionsSet.empty()) {
550 qAction = qActionsSet.back();
551 qActionsSet.pop_back();
552 newVFunction = _operator->argmaximize(newVFunction, qAction);
563 template <
typename GUM_SCALAR >
567 argMaxOptimalValueFunction) {
568 _optimalPolicy->clear();
572 argMaxOptimalValueFunction->variablesSequence().beginSafe();
573 varIter != argMaxOptimalValueFunction->variablesSequence().endSafe();
575 _optimalPolicy->add(**varIter);
578 _optimalPolicy->manager()->setRootNode(__recurExtractOptPol(
579 argMaxOptimalValueFunction->root(), argMaxOptimalValueFunction, src2dest));
581 delete argMaxOptimalValueFunction;
588 template <
typename GUM_SCALAR >
594 if (visitedNodes.
exists(currentNodeId))
return visitedNodes[currentNodeId];
597 if (argMaxOptVFunc->isTerminalNode(currentNodeId)) {
599 __transferActionIds(argMaxOptVFunc->nodeValue(currentNodeId), leaf);
600 nody = _optimalPolicy->manager()->addTerminalNode(leaf);
602 const InternalNode* currentNode = argMaxOptVFunc->node(currentNodeId);
606 sonsMap[moda] = __recurExtractOptPol(
607 currentNode->
son(moda), argMaxOptVFunc, visitedNodes);
608 nody = _optimalPolicy->manager()->addInternalNode(currentNode->
nodeVar(),
611 visitedNodes.
insert(currentNodeId, nody);
618 template <
typename GUM_SCALAR >
SequenceIteratorSafe< GUM_SCALAR_SEQ > endSafe() const
Iterator end.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
<agrum/FMDP/planning/structuredPlaner.h>
void nextValue() const
Increments the constant safe iterator.
bool isTerminalNode(const NodeId &node) const
Indicates whether the given node is terminal.
void setRootNode(const NodeId &root)
Sets root node of decision diagram.
void beginValues() const
Initializes the constant safe iterator on terminal nodes.
A class to store the optimal actions.
void copyAndMultiplyByScalar(const MultiDimFunctionGraph< GUM_SCALAR, TerminalNodePolicy > &src, GUM_SCALAR gamma)
Copies the src diagram and multiplies every value by the given scalar.
const iterator_safe & endSafe() noexcept
Returns the safe iterator pointing to the end of the hashtable.
const InternalNode * node(NodeId n) const
Returns the InternalNode structure associated with the given NodeId.
const DiscreteVariable * nodeVar() const
Returns the node variable.
Idx nbSons() const
Returns the number of sons.
#define RECAST(x)
Only for the purpose of shorter lines and hence more comprehensible code.
void copyAndReassign(const MultiDimFunctionGraph< GUM_SCALAR, TerminalNodePolicy > &src, const Bijection< const DiscreteVariable *, const DiscreteVariable * > &reassign)
Copies the src diagram's structure into this diagram.
<agrum/FMDP/SDyna/IOperatorStrategy.h>
NodeId son(Idx modality) const
Returns the son at a given index.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
const Key & key(const Key &key) const
Returns a reference on a given key.
This class is used to implement a factored decision process.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
The class for generic Hash Tables.
Class to handle efficiently argMaxSet.
SequenceIteratorSafe< Idx > endSafe() const
Iterator end.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const T & element() const
Returns the element stored in this link.
virtual Size domainSize() const =0
StructuredPlaner(IOperatorStrategy< GUM_SCALAR > *opi, GUM_SCALAR discountFactor, GUM_SCALAR epsilon, bool verbose)
Default constructor.
virtual void add(const DiscreteVariable &v)
Adds a new var to the variables of the multidimensional matrix.
bool exists(const Key &k) const
Indicates whether a given element belongs to the set.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const GUM_SCALAR & nodeValue(NodeId n) const
Returns the value associated with the given node.
virtual std::string label(Idx i) const =0
Gets the label at index i. This method is pure virtual.
Structure used to represent a node internal structure.
bool hasValue() const
Indicates if the constant safe iterator has reached the end of the terminal nodes list.
Link of a chain list allocated using the SmallObjectAllocator.
const DiscreteVariable * main2prime(const DiscreteVariable *mainVar) const
Returns the primed variable associated with the given main variable.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const NodeId & root() const
Returns the id of the root node from the diagram.
SequenceIteratorSafe< GUM_SCALAR_SEQ > beginSafe() const
Iterator beginning.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
SequenceIteratorSafe< Idx > beginSafe() const
Iterator beginning.
Implementation of a Terminal Node Policy that maps a node id to a set of values.
virtual const Sequence< const DiscreteVariable *> & variablesSequence() const override
Returns a const ref to the sequence of DiscreteVariable*.
Chain list allocated using the SmallObjectAllocator.
iterator_safe beginSafe()
Returns the safe iterator pointing to the beginning of the hashtable.
Size Idx
Type for indexes.
MultiDimFunctionGraphManager< GUM_SCALAR, TerminalNodePolicy > * manager()
Returns a pointer to the manager of this diagram.
const GUM_SCALAR & value() const
Returns the value of the current terminal nodes pointed by the constant safe iterator.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
const std::string & name() const
Returns the name of the variable.
const Link< T > * nextLink() const
Returns next link.
Size NodeId
Type for node ids.