44 #define RECAST(x) reinterpret_cast< const MultiDimFunctionGraph< GUM_SCALAR >* >(x) 60 template <
typename GUM_SCALAR >
63 GUM_SCALAR discountFactor,
66 _discountFactor(discountFactor),
67 _operator(opi), _verbose(verbose) {
70 __threshold = epsilon;
72 _optimalPolicy =
nullptr;
78 template <
typename GUM_SCALAR >
82 if (_vFunction) {
delete _vFunction; }
84 if (_optimalPolicy)
delete _optimalPolicy;
101 template <
typename GUM_SCALAR >
105 if (!_optimalPolicy || _optimalPolicy->root() == 0)
106 return "NO OPTIMAL POLICY CALCULATED YET";
112 std::stringstream output;
113 std::stringstream terminalStream;
114 std::stringstream nonTerminalStream;
115 std::stringstream arcstream;
118 output << std::endl <<
"digraph \" OPTIMAL POLICY \" {" << std::endl;
121 terminalStream <<
"node [shape = box];" << std::endl;
122 nonTerminalStream <<
"node [shape = ellipse];" << std::endl;
125 std::string tab =
"\t";
131 std::queue< NodeId > fifo;
134 fifo.push(_optimalPolicy->root());
135 visited << _optimalPolicy->root();
140 while (!fifo.empty()) {
142 NodeId currentNodeId = fifo.front();
146 if (_optimalPolicy->isTerminalNode(currentNodeId)) {
148 ActionSet ase = _optimalPolicy->nodeValue(currentNodeId);
151 terminalStream << tab << currentNodeId <<
";" << tab << currentNodeId
152 <<
" [label=\"" << currentNodeId <<
" - ";
158 terminalStream << _fmdp->actionName(*valIter) <<
" ";
161 terminalStream <<
"\"];" << std::endl;
168 const InternalNode* currentNode = _optimalPolicy->node(currentNodeId);
171 nonTerminalStream << tab << currentNodeId <<
";" << tab << currentNodeId
172 <<
" [label=\"" << currentNodeId <<
" - " 173 << currentNode->
nodeVar()->
name() <<
"\"];" << std::endl;
177 for (
Idx sonIter = 0; sonIter < currentNode->
nbSons(); ++sonIter) {
178 if (!visited.
exists(currentNode->
son(sonIter))) {
179 fifo.push(currentNode->
son(sonIter));
180 visited << currentNode->
son(sonIter);
182 if (!sonMap.
exists(currentNode->
son(sonIter)))
184 sonMap[currentNode->
son(sonIter)]->addLink(sonIter);
190 arcstream << tab << currentNodeId <<
" -> " << sonIter.
key()
195 if (modaIter->
nextLink()) arcstream <<
", ";
198 arcstream <<
"\",color=\"#00ff00\"];" << std::endl;
199 delete sonIter.val();
205 output << terminalStream.str() << std::endl
206 << nonTerminalStream.str() << std::endl
207 << arcstream.str() << std::endl
225 template <
typename GUM_SCALAR >
230 __threshold *= (1 - _discountFactor) / (2 * _discountFactor);
233 for (
auto varIter = _fmdp->beginVariables(); varIter != _fmdp->endVariables();
238 _vFunction = _operator->getFunctionInstance();
239 _optimalPolicy = _operator->getAggregatorInstance();
247 template <
typename GUM_SCALAR >
250 this->_initVFunction();
258 GUM_SCALAR gap = __threshold + 1;
259 while ((gap > __threshold) && (nbIte < nbStep)) {
267 _operator->subtract(newVFunction, _vFunction);
271 if (gap < fabs(deltaV->
value())) gap = fabs(deltaV->
value());
275 std::cout <<
" ------------------- Fin itération n° " << nbIte << std::endl
276 <<
" Gap : " << gap <<
" - " << __threshold << std::endl;
281 _vFunction = newVFunction;
294 template <
typename GUM_SCALAR >
296 _vFunction->copy(*(
RECAST(_fmdp->reward())));
311 template <
typename GUM_SCALAR >
317 _operator->getFunctionInstance();
322 std::vector< MultiDimFunctionGraph< GUM_SCALAR >* > qActionsSet;
323 for (
auto actionIter = _fmdp->beginActions();
324 actionIter != _fmdp->endActions();
327 this->_evalQaction(newVFunction, *actionIter);
328 qActionsSet.push_back(qAction);
335 newVFunction = this->_maximiseQactions(qActionsSet);
339 newVFunction = this->_addReward(newVFunction);
348 template <
typename GUM_SCALAR >
357 return _operator->regress(Vold, actionId, this->_fmdp, this->_elVarSeq);
364 template <
typename GUM_SCALAR >
369 qActionsSet.pop_back();
371 while (!qActionsSet.empty()) {
373 qActionsSet.pop_back();
374 newVFunction = _operator->maximize(newVFunction, qAction);
384 template <
typename GUM_SCALAR >
389 qActionsSet.pop_back();
391 while (!qActionsSet.empty()) {
393 qActionsSet.pop_back();
394 newVFunction = _operator->minimize(newVFunction, qAction);
404 template <
typename GUM_SCALAR >
410 _operator->getFunctionInstance();
416 newVFunction = _operator->
add(newVFunction,
RECAST(_fmdp->reward(actionId)));
433 template <
typename GUM_SCALAR >
438 _operator->getFunctionInstance();
441 std::vector< MultiDimFunctionGraph< ArgMaxSet< GUM_SCALAR, Idx >,
446 for (
auto actionIter = _fmdp->beginActions();
447 actionIter != _fmdp->endActions();
450 this->_evalQaction(newVFunction, *actionIter);
452 qAction = this->_addReward(qAction);
454 argMaxQActionsSet.push_back(_makeArgMax(qAction, *actionIter));
463 argMaxVFunction = _argmaximiseQactions(argMaxQActionsSet);
468 _extractOptimalPolicy(argMaxVFunction);
477 template <
typename GUM_SCALAR >
482 amcpy = _operator->getArgMaxFunctionInstance();
489 amcpy->
add(**varIter);
493 __recurArgMaxCopy(qAction->
root(), actionId, qAction, amcpy, src2dest));
503 template <
typename GUM_SCALAR >
511 if (visitedNodes.
exists(currentNodeId))
return visitedNodes[currentNodeId];
516 nody = argMaxCpy->manager()->addTerminalNode(leaf);
522 sonsMap[moda] = __recurArgMaxCopy(
523 currentNode->
son(moda), actionId, src, argMaxCpy, visitedNodes);
525 argMaxCpy->manager()->addInternalNode(currentNode->
nodeVar(), sonsMap);
527 visitedNodes.
insert(currentNodeId, nody);
535 template <
typename GUM_SCALAR >
542 newVFunction = qActionsSet.back();
543 qActionsSet.pop_back();
545 while (!qActionsSet.empty()) {
547 qAction = qActionsSet.back();
548 qActionsSet.pop_back();
549 newVFunction = _operator->argmaximize(newVFunction, qAction);
560 template <
typename GUM_SCALAR >
564 argMaxOptimalValueFunction) {
565 _optimalPolicy->clear();
569 argMaxOptimalValueFunction->variablesSequence().beginSafe();
570 varIter != argMaxOptimalValueFunction->variablesSequence().endSafe();
572 _optimalPolicy->add(**varIter);
575 _optimalPolicy->manager()->setRootNode(__recurExtractOptPol(
576 argMaxOptimalValueFunction->root(), argMaxOptimalValueFunction, src2dest));
578 delete argMaxOptimalValueFunction;
585 template <
typename GUM_SCALAR >
591 if (visitedNodes.
exists(currentNodeId))
return visitedNodes[currentNodeId];
594 if (argMaxOptVFunc->isTerminalNode(currentNodeId)) {
596 __transferActionIds(argMaxOptVFunc->nodeValue(currentNodeId), leaf);
597 nody = _optimalPolicy->manager()->addTerminalNode(leaf);
599 const InternalNode* currentNode = argMaxOptVFunc->node(currentNodeId);
603 sonsMap[moda] = __recurExtractOptPol(
604 currentNode->
son(moda), argMaxOptVFunc, visitedNodes);
605 nody = _optimalPolicy->manager()->addInternalNode(currentNode->
nodeVar(),
608 visitedNodes.
insert(currentNodeId, nody);
615 template <
typename GUM_SCALAR >
SequenceIteratorSafe< GUM_SCALAR_SEQ > endSafe() const
Iterator end.
<agrum/FMDP/planning/structuredPlaner.h>
void nextValue() const
Increments the constant safe iterator.
bool isTerminalNode(const NodeId &node) const
Indicates if given node is terminal or not.
void setRootNode(const NodeId &root)
Sets root node of decision diagram.
void beginValues() const
Initializes the constant safe iterator on terminal nodes.
A class to store the optimal actions.
void copyAndMultiplyByScalar(const MultiDimFunctionGraph< GUM_SCALAR, TerminalNodePolicy > &src, GUM_SCALAR gamma)
Copies the src diagram and multiplies every value by the given scalar.
const iterator_safe & endSafe() noexcept
Returns the safe iterator pointing to the end of the hashtable.
const InternalNode * node(NodeId n) const
Returns internalNode structure associated to that nodeId.
const DiscreteVariable * nodeVar() const
Returns the node variable.
Idx nbSons() const
Returns the number of sons.
#define RECAST(x)
Used only to shorten lines and hence make the code more comprehensible.
void copyAndReassign(const MultiDimFunctionGraph< GUM_SCALAR, TerminalNodePolicy > &src, const Bijection< const DiscreteVariable *, const DiscreteVariable * > &reassign)
Copies the structure of the src diagram into this diagram.
<agrum/FMDP/SDyna/IOperatorStrategy.h>
NodeId son(Idx modality) const
Returns the son at a given index.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
const Key & key(const Key &key) const
Returns a reference on a given key.
This class is used to implement factored decision process.
gum is the global namespace for all aGrUM entities
This file contains several function objects that are not (yet) defined in the STL.
The class for generic Hash Tables.
Class to efficiently handle argMaxSet.
SequenceIteratorSafe< Idx > endSafe() const
Iterator end.
Headers of the StructuredPlaner planer class.
const T & element() const
Returns the element stored in this link.
virtual Size domainSize() const =0
StructuredPlaner(IOperatorStrategy< GUM_SCALAR > *opi, GUM_SCALAR discountFactor, GUM_SCALAR epsilon, bool verbose)
Default constructor.
virtual void add(const DiscreteVariable &v)
Adds a new var to the variables of the multidimensional matrix.
bool exists(const Key &k) const
Indicates whether a given element belongs to the set.
Header of the Potential class.
const GUM_SCALAR & nodeValue(NodeId n) const
Returns value associated to given node.
virtual std::string label(Idx i) const =0
Gets the i-th label. This method is pure virtual.
Structure used to represent a node internal structure.
bool hasValue() const
Indicates if the constant safe iterator has reached the end of the terminal nodes list.
Link of a chain list allocated using the SmallObjectAllocator.
const DiscreteVariable * main2prime(const DiscreteVariable *mainVar) const
Returns the primed variable associated with the given main variable.
Header files of gum::Instantiation.
const NodeId & root() const
Returns the id of the root node from the diagram.
SequenceIteratorSafe< GUM_SCALAR_SEQ > beginSafe() const
Iterator beginning.
Headers of MultiDimFunctionGraph.
SequenceIteratorSafe< Idx > beginSafe() const
Iterator beginning.
Implementation of a Terminal Node Policy that maps nodeid to a set of value.
virtual const Sequence< const DiscreteVariable *> & variablesSequence() const override
Returns a const ref to the sequence of DiscreteVariable*.
Chain list allocated using the SmallObjectAllocator.
iterator_safe beginSafe()
Returns the safe iterator pointing to the beginning of the hashtable.
Size Idx
Type for indexes.
MultiDimFunctionGraphManager< GUM_SCALAR, TerminalNodePolicy > * manager()
Returns a pointer to the manager of this diagram.
const GUM_SCALAR & value() const
Returns the value of the current terminal nodes pointed by the constant safe iterator.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
const std::string & name() const
Returns the name of the variable.
const Link< T > * nextLink() const
Returns next link.
Size NodeId
Type for node ids.