28 #include <agrum/PRM/PRMFactory.h> 33 #include <agrum/tools/core/math/formula.h> 35 #include <agrum/tools/variables/discretizedVariable.h> 36 #include <agrum/tools/variables/rangeVariable.h> 38 #include <agrum/PRM/elements/PRMFormAttribute.h> 39 #include <agrum/PRM/elements/PRMFuncAttribute.h> 45 template <
typename GUM_SCALAR >
46 INLINE
void PRMFactory< GUM_SCALAR >::startClass(
const std::string& name,
47 const std::string& extends,
48 const Set< std::string >* implements,
49 bool delayInheritance) {
50 std::string real_name = _addPrefix_(name);
51 if (_prm_->_classMap_.exists(real_name) || _prm_->_interfaceMap_.exists(real_name)) {
52 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.")
54 PRMClass< GUM_SCALAR >* c =
nullptr;
55 PRMClass< GUM_SCALAR >* mother =
nullptr;
56 Set< PRMInterface< GUM_SCALAR >* > impl;
58 if (implements != 0) {
59 for (
const auto& imp: *implements) {
60 impl.insert(_retrieveInterface_(imp));
64 if (extends !=
"") { mother = _retrieveClass_(extends); }
66 if ((extends ==
"") && impl.empty()) {
67 c =
new PRMClass< GUM_SCALAR >(real_name);
68 }
else if ((extends !=
"") && impl.empty()) {
69 c =
new PRMClass< GUM_SCALAR >(real_name, *mother, delayInheritance);
70 }
else if ((extends ==
"") && (!impl.empty())) {
71 c =
new PRMClass< GUM_SCALAR >(real_name, impl, delayInheritance);
72 }
else if ((extends !=
"") && (!impl.empty())) {
73 c =
new PRMClass< GUM_SCALAR >(real_name, *mother, impl, delayInheritance);
76 _prm_->_classMap_.insert(c->name(), c);
77 _prm_->_classes_.insert(c);
81 template <
typename GUM_SCALAR >
82 INLINE
void PRMFactory< GUM_SCALAR >::continueClass(
const std::string& name) {
83 std::string real_name = _addPrefix_(name);
84 if (!(_prm_->_classMap_.exists(real_name))) {
85 std::stringstream msg;
86 msg <<
"'" << real_name <<
"' not found";
87 GUM_ERROR(NotFound, msg.str())
89 _stack_.push_back(&(_prm_->getClass(real_name)));
92 template <
typename GUM_SCALAR >
93 INLINE
void PRMFactory< GUM_SCALAR >::endClass(
bool checkImplementations) {
94 PRMClass< GUM_SCALAR >* c
95 =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
97 if (checkImplementations) { _checkInterfaceImplementation_(c); }
102 template <
typename GUM_SCALAR >
104 PRMFactory< GUM_SCALAR >::_checkInterfaceImplementation_(PRMClass< GUM_SCALAR >* c) {
106 for (
const auto& i: c->implements()) {
108 for (
const auto& node: i->containerDag().nodes()) {
109 std::string name = i->get(node).name();
111 switch (i->get(node).elt_type()) {
112 case PRMClassElement< GUM_SCALAR >::prm_aggregate:
113 case PRMClassElement< GUM_SCALAR >::prm_attribute: {
114 if ((c->get(name).elt_type() == PRMClassElement< GUM_SCALAR >::prm_attribute)
115 || (c->get(name).elt_type()
116 == PRMClassElement< GUM_SCALAR >::prm_aggregate)) {
117 if (!c->get(name).type().isSubTypeOf(i->get(name).type())) {
118 std::stringstream msg;
119 msg <<
"class " << c->name() <<
" does not respect interface ";
120 GUM_ERROR(PRMTypeError, msg.str() + i->name())
123 std::stringstream msg;
124 msg <<
"class " << c->name() <<
" does not respect interface ";
125 GUM_ERROR(PRMTypeError, msg.str() + i->name())
131 case PRMClassElement< GUM_SCALAR >::prm_refslot: {
132 if (c->get(name).elt_type() == PRMClassElement< GUM_SCALAR >::prm_refslot) {
133 const PRMReferenceSlot< GUM_SCALAR >& ref_i
134 =
static_cast<
const PRMReferenceSlot< GUM_SCALAR >& >(i->get(name));
135 const PRMReferenceSlot< GUM_SCALAR >& ref_this
136 =
static_cast<
const PRMReferenceSlot< GUM_SCALAR >& >(c->get(name));
138 if (!ref_this.slotType().isSubTypeOf(ref_i.slotType())) {
139 std::stringstream msg;
140 msg <<
"class " << c->name() <<
" does not respect interface ";
141 GUM_ERROR(PRMTypeError, msg.str() + i->name())
144 std::stringstream msg;
145 msg <<
"class " << c->name() <<
" does not respect interface ";
146 GUM_ERROR(PRMTypeError, msg.str() + i->name())
152 case PRMClassElement< GUM_SCALAR >::prm_slotchain: {
158 std::string msg =
"unexpected ClassElement<GUM_SCALAR> in interface ";
159 GUM_ERROR(FatalError, msg + i->name())
163 }
catch (NotFound&) {
164 std::stringstream msg;
165 msg <<
"class " << c->name() <<
" does not respect interface ";
166 GUM_ERROR(PRMTypeError, msg.str() + i->name())
169 }
catch (NotFound&) {
175 template <
typename GUM_SCALAR >
176 INLINE
void PRMFactory< GUM_SCALAR >::startInterface(
const std::string& name,
177 const std::string& extends,
178 bool delayInheritance) {
179 std::string real_name = _addPrefix_(name);
180 if (_prm_->_classMap_.exists(real_name) || _prm_->_interfaceMap_.exists(real_name)) {
181 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.")
183 PRMInterface< GUM_SCALAR >* i =
nullptr;
184 PRMInterface< GUM_SCALAR >* super =
nullptr;
186 if (extends !=
"") { super = _retrieveInterface_(extends); }
188 if (super !=
nullptr) {
189 i =
new PRMInterface< GUM_SCALAR >(real_name, *super, delayInheritance);
191 i =
new PRMInterface< GUM_SCALAR >(real_name);
194 _prm_->_interfaceMap_.insert(i->name(), i);
195 _prm_->_interfaces_.insert(i);
196 _stack_.push_back(i);
199 template <
typename GUM_SCALAR >
200 INLINE
void PRMFactory< GUM_SCALAR >::continueInterface(
const std::string& name) {
201 std::string real_name = _addPrefix_(name);
202 if (!_prm_->_interfaceMap_.exists(real_name)) {
203 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' not found.")
206 PRMInterface< GUM_SCALAR >* i = _retrieveInterface_(real_name);
207 _stack_.push_back(i);
210 template <
typename GUM_SCALAR >
211 INLINE
void PRMFactory< GUM_SCALAR >::addAttribute(PRMAttribute< GUM_SCALAR >* attr) {
212 PRMClass< GUM_SCALAR >* c
213 =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
216 const Sequence<
const DiscreteVariable* >& vars = attr->cpf().variablesSequence();
218 for (
const auto& node: c->containerDag().nodes()) {
220 if (vars.exists(&(c->get(node).type().variable()))) {
223 if (&(attr->type().variable()) != &(c->get(node).type().variable())) {
224 c->addArc(c->get(node).safeName(), attr->safeName());
227 }
catch (OperationNotAllowed&) {}
230 if (count != attr->cpf().variablesSequence().size()) {
231 GUM_ERROR(NotFound,
"unable to found all parents of this attribute")
235 template <
typename GUM_SCALAR >
236 INLINE
void PRMFactory< GUM_SCALAR >::_addParent_(PRMClassElementContainer< GUM_SCALAR >* c,
237 PRMAttribute< GUM_SCALAR >* a,
238 const std::string& name) {
240 PRMClassElement< GUM_SCALAR >& elt = c->get(name);
242 switch (elt.elt_type()) {
243 case PRMClassElement< GUM_SCALAR >::prm_refslot: {
244 GUM_ERROR(OperationNotAllowed,
245 "can not add a reference slot as a parent of an attribute")
249 case PRMClassElement< GUM_SCALAR >::prm_slotchain: {
250 if (
static_cast< PRMSlotChain< GUM_SCALAR >& >(elt).isMultiple()) {
251 GUM_ERROR(OperationNotAllowed,
"can not add a multiple slot chain to an attribute")
254 c->addArc(name, a->name());
259 case PRMClassElement< GUM_SCALAR >::prm_attribute:
260 case PRMClassElement< GUM_SCALAR >::prm_aggregate: {
261 c->addArc(name, a->name());
266 GUM_ERROR(FatalError,
"unknown ClassElement<GUM_SCALAR>")
269 }
catch (NotFound&) {
271 PRMSlotChain< GUM_SCALAR >* sc = _buildSlotChain_(c, name);
274 std::string msg =
"found no ClassElement<GUM_SCALAR> with the given name ";
275 GUM_ERROR(NotFound, msg + name)
276 }
else if (!sc->isMultiple()) {
278 c->addArc(sc->name(), a->name());
281 GUM_ERROR(OperationNotAllowed,
282 "Impossible to add a multiple reference slot as" 283 " direct parent of an PRMAttribute<GUM_SCALAR>.");
289 template <
typename GUM_SCALAR >
290 INLINE
void PRMFactory< GUM_SCALAR >::addParent(
const std::string& name) {
291 PRMClassElementContainer< GUM_SCALAR >* c = _checkStackContainter_(2);
294 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
295 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
296 _addParent_(c, a, name);
297 }
catch (FactoryInvalidState&) {
298 auto agg =
static_cast< PRMAggregate< GUM_SCALAR >* >(
299 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_aggregate));
300 _addParent_(
static_cast< PRMClass< GUM_SCALAR >* >(c), agg, name);
304 template <
typename GUM_SCALAR >
305 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByFloatLines(
const std::vector<
float >& array) {
306 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
307 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
308 _checkStack_(2, PRMObject::prm_type::CLASS);
310 if (a->cpf().domainSize() != array.size()) GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
312 std::vector< GUM_SCALAR > array2(array.begin(), array.end());
313 a->cpf().fillWith(array2);
316 template <
typename GUM_SCALAR >
317 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByLines(
const std::vector< GUM_SCALAR >& array) {
318 auto elt = _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute);
319 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(elt);
320 _checkStack_(2, PRMObject::prm_type::CLASS);
322 if (a->cpf().domainSize() != array.size()) {
323 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
326 a->cpf().fillWith(array);
329 template <
typename GUM_SCALAR >
331 PRMFactory< GUM_SCALAR >::setRawCPFByFloatColumns(
const std::vector<
float >& array) {
332 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
333 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
335 if (a->cpf().domainSize() != array.size()) {
336 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
339 std::vector< GUM_SCALAR > array2(array.begin(), array.end());
340 setRawCPFByColumns(array2);
343 template <
typename GUM_SCALAR >
345 PRMFactory< GUM_SCALAR >::setRawCPFByColumns(
const std::vector< GUM_SCALAR >& array) {
346 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
347 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
349 if (a->cpf().domainSize() != array.size()) {
350 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
353 if (a->cpf().nbrDim() == 1) {
354 setRawCPFByLines(array);
357 Instantiation inst(a->cpf());
359 for (
auto idx = inst.variablesSequence().rbegin(); idx != inst.variablesSequence().rend();
365 auto idx = (std::size_t)0;
366 while ((!jnst.end()) && idx < array.size()) {
368 a->cpf().set(inst, array[idx]);
375 template <
typename GUM_SCALAR >
377 PRMFactory< GUM_SCALAR >::setCPFByFloatRule(
const std::vector< std::string >& parents,
378 const std::vector<
float >& values) {
379 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
380 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
382 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
383 GUM_ERROR(OperationNotAllowed,
"wrong number of parents")
386 if (values.size() != a->type().variable().domainSize()) {
387 GUM_ERROR(OperationNotAllowed,
"wrong number of values")
390 std::vector< GUM_SCALAR > values2(values.begin(), values.end());
391 setCPFByRule(parents, values2);
394 template <
typename GUM_SCALAR >
395 INLINE
void PRMFactory< GUM_SCALAR >::setCPFByRule(
const std::vector< std::string >& parents,
396 const std::vector< GUM_SCALAR >& values) {
397 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
398 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
400 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
401 GUM_ERROR(OperationNotAllowed,
"wrong number of parents")
404 if (values.size() != a->type().variable().domainSize()) {
405 GUM_ERROR(OperationNotAllowed,
"wrong number of values")
408 if (
dynamic_cast< PRMFormAttribute< GUM_SCALAR >* >(a)) {
409 auto form =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(a);
412 Instantiation jnst, knst;
413 const DiscreteVariable* var = 0;
417 for (Idx i = 0; i < parents.size(); ++i) {
418 var = form->formulas().variablesSequence().atPos(1 + i);
420 if (parents[i] ==
"*") {
427 for (Size j = 0; j < var->domainSize(); ++j) {
428 if (var->label(j) == parents[i]) {
429 jnst.chgVal(*var, j);
436 std::string msg =
"could not find label ";
437 GUM_ERROR(NotFound, msg + parents[i])
442 Instantiation inst(form->formulas());
445 for (Size i = 0; i < form->type()->domainSize(); ++i) {
446 inst.chgVal(form->type().variable(), i);
448 for (inst.setFirstIn(knst); !inst.end(); inst.incIn(knst)) {
449 form->formulas().set(inst, std::to_string(values[i]));
454 GUM_ERROR(OperationNotAllowed,
"invalide attribute type")
458 template <
typename GUM_SCALAR >
459 INLINE
void PRMFactory< GUM_SCALAR >::setCPFByRule(
const std::vector< std::string >& parents,
460 const std::vector< std::string >& values) {
461 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
462 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
464 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
465 GUM_ERROR(OperationNotAllowed,
"wrong number of parents")
468 if (values.size() != a->type().variable().domainSize()) {
469 GUM_ERROR(OperationNotAllowed,
"wrong number of values")
472 if (
dynamic_cast< PRMFormAttribute< GUM_SCALAR >* >(a)) {
473 auto form =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(a);
476 Instantiation jnst, knst;
477 const DiscreteVariable* var = 0;
481 for (Idx i = 0; i < parents.size(); ++i) {
482 var = form->formulas().variablesSequence().atPos(1 + i);
484 if (parents[i] ==
"*") {
491 for (Size j = 0; j < var->domainSize(); ++j) {
492 if (var->label(j) == parents[i]) {
493 jnst.chgVal(*var, j);
500 std::string msg =
"could not find label ";
501 GUM_ERROR(NotFound, msg + parents[i])
506 Instantiation inst(form->formulas());
509 for (Size i = 0; i < form->type()->domainSize(); ++i) {
510 inst.chgVal(form->type().variable(), i);
512 for (inst.setFirstIn(knst); !inst.end(); inst.incIn(knst)) {
513 form->formulas().set(inst, values[i]);
518 GUM_ERROR(OperationNotAllowed,
"invalide attribute type")
522 template <
typename GUM_SCALAR >
523 INLINE
void PRMFactory< GUM_SCALAR >::addParameter(
const std::string& type,
524 const std::string& name,
526 auto c =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
528 PRMParameter< GUM_SCALAR >* p =
nullptr;
530 p =
new PRMParameter< GUM_SCALAR >(name,
531 PRMParameter< GUM_SCALAR >::ParameterType::INT,
533 }
else if (type ==
"real") {
534 p =
new PRMParameter< GUM_SCALAR >(name,
535 PRMParameter< GUM_SCALAR >::ParameterType::REAL,
541 }
catch (DuplicateElement&) { c->overload(p); }
544 template <
typename GUM_SCALAR >
546 PRMFactory< GUM_SCALAR >::startAggregator(
const std::string& name,
547 const std::string& agg_type,
548 const std::string& rv_type,
549 const std::vector< std::string >& params) {
550 PRMClass< GUM_SCALAR >* c
551 =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
553 auto agg =
new PRMAggregate< GUM_SCALAR >(name,
554 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
555 *_retrieveType_(rv_type));
559 }
catch (DuplicateElement&) { c->overload(agg); }
561 switch (agg->agg_type()) {
562 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT:
563 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
564 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
565 if (params.size() != 1) {
566 GUM_ERROR(OperationNotAllowed,
"aggregate requires a parameter")
568 agg->setLabel(params.front());
575 _stack_.push_back(agg);
578 template <
typename GUM_SCALAR >
579 INLINE
void PRMFactory< GUM_SCALAR >::continueAggregator(
const std::string& name) {
580 PRMClassElementContainer< GUM_SCALAR >* c = _checkStackContainter_(1);
582 if (!c->exists(name)) GUM_ERROR(NotFound,
"Element " << name <<
"not found")
584 auto& agg = c->get(name);
585 if (!PRMClassElement< GUM_SCALAR >::isAggregate(agg))
586 GUM_ERROR(OperationNotAllowed,
"Element " << name <<
" not an aggregate")
588 _stack_.push_back(&agg);
591 template <
typename GUM_SCALAR >
592 INLINE
void PRMFactory< GUM_SCALAR >::_addParent_(PRMClass< GUM_SCALAR >* c,
593 PRMAggregate< GUM_SCALAR >* agg,
594 const std::string& name) {
595 auto chains = std::vector< std::string >{name};
596 auto inputs = std::vector< PRMClassElement< GUM_SCALAR >* >();
597 _retrieveInputs_(c, chains, inputs);
599 switch (agg->agg_type()) {
600 case PRMAggregate< GUM_SCALAR >::AggregateType::OR:
601 case PRMAggregate< GUM_SCALAR >::AggregateType::AND: {
602 if (inputs.front()->type() != *(_retrieveType_(
"boolean"))) {
603 GUM_ERROR(TypeError,
"expected booleans")
609 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT:
610 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
611 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
612 if (!agg->hasLabel()) {
613 auto param = agg->labelValue();
616 while (label_idx < inputs.front()->type()->domainSize()) {
617 if (inputs.front()->type()->label(label_idx) == param) {
break; }
622 if (label_idx == inputs.front()->type()->domainSize()) {
623 GUM_ERROR(NotFound,
"could not find label")
626 agg->setLabel(label_idx);
632 case PRMAggregate< GUM_SCALAR >::AggregateType::SUM:
633 case PRMAggregate< GUM_SCALAR >::AggregateType::MEDIAN:
634 case PRMAggregate< GUM_SCALAR >::AggregateType::AMPLITUDE:
635 case PRMAggregate< GUM_SCALAR >::AggregateType::MIN:
636 case PRMAggregate< GUM_SCALAR >::AggregateType::MAX: {
641 GUM_ERROR(FatalError,
"Unknown aggregator.")
645 c->addArc(inputs.front()->safeName(), agg->safeName());
648 template <
typename GUM_SCALAR >
649 INLINE
void PRMFactory< GUM_SCALAR >::endAggregator() {
650 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_aggregate);
654 template <
typename GUM_SCALAR >
655 INLINE
void PRMFactory< GUM_SCALAR >::addAggregator(
const std::string& name,
656 const std::string& agg_type,
657 const std::vector< std::string >& chains,
658 const std::vector< std::string >& params,
660 PRMClass< GUM_SCALAR >* c
661 =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
664 if (chains.size() == 0) {
665 GUM_ERROR(OperationNotAllowed,
"a PRMAggregate<GUM_SCALAR> requires at least one parent")
669 std::vector< PRMClassElement< GUM_SCALAR >* > inputs;
674 bool hasSC = _retrieveInputs_(c, chains, inputs);
679 if (inputs.size() > 1) {
680 for (
auto iter = inputs.begin() + 1; iter != inputs.end(); ++iter) {
681 if ((**(iter - 1)).type() != (**iter).type()) {
682 GUM_ERROR(TypeError,
"found different types")
688 PRMAggregate< GUM_SCALAR >* agg =
nullptr;
690 switch (PRMAggregate< GUM_SCALAR >::str2enum(agg_type)) {
691 case PRMAggregate< GUM_SCALAR >::AggregateType::OR:
692 case PRMAggregate< GUM_SCALAR >::AggregateType::AND: {
693 if (inputs.front()->type() != *(_retrieveType_(
"boolean"))) {
694 GUM_ERROR(TypeError,
"expected booleans")
696 if (params.size() != 0) { GUM_ERROR(OperationNotAllowed,
"invalid number of paramaters") }
698 agg =
new PRMAggregate< GUM_SCALAR >(name,
699 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
700 inputs.front()->type());
705 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
706 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
707 if (params.size() != 1) { GUM_ERROR(OperationNotAllowed,
"invalid number of parameters") }
711 while (label_idx < inputs.front()->type()->domainSize()) {
712 if (inputs.front()->type()->label(label_idx) == params.front()) {
break; }
717 if (label_idx == inputs.front()->type()->domainSize()) {
718 GUM_ERROR(NotFound,
"could not find label")
722 agg =
new PRMAggregate< GUM_SCALAR >(name,
723 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
724 *(_retrieveType_(
"boolean")),
731 case PRMAggregate< GUM_SCALAR >::AggregateType::SUM:
732 case PRMAggregate< GUM_SCALAR >::AggregateType::MEDIAN:
733 case PRMAggregate< GUM_SCALAR >::AggregateType::AMPLITUDE:
734 case PRMAggregate< GUM_SCALAR >::AggregateType::MIN:
735 case PRMAggregate< GUM_SCALAR >::AggregateType::MAX: {
736 if (params.size() != 0) { GUM_ERROR(OperationNotAllowed,
"invalid number of parameters") }
738 auto output_type = _retrieveType_(type);
741 agg =
new PRMAggregate< GUM_SCALAR >(name,
742 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
748 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT: {
749 if (params.size() != 1) { GUM_ERROR(OperationNotAllowed,
"invalid number of parameters") }
753 while (label_idx < inputs.front()->type()->domainSize()) {
754 if (inputs.front()->type()->label(label_idx) == params.front()) {
break; }
759 if (label_idx == inputs.front()->type()->domainSize()) {
760 GUM_ERROR(NotFound,
"could not find label")
763 auto output_type = _retrieveType_(type);
766 agg =
new PRMAggregate< GUM_SCALAR >(name,
767 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
775 GUM_ERROR(FatalError,
"Unknown aggregator.")
779 std::string safe_name = agg->safeName();
785 }
catch (DuplicateElement&) { c->overload(agg); }
789 =
new PRMScalarAttribute< GUM_SCALAR >(agg->name(), agg->type(), agg->buildImpl());
793 }
catch (DuplicateElement&) { c->overload(attr); }
797 }
catch (DuplicateElement&) {
802 for (
const auto& elt: inputs) {
803 c->addArc(elt->safeName(), safe_name);
807 template <
typename GUM_SCALAR >
808 INLINE
void PRMFactory< GUM_SCALAR >::addReferenceSlot(
const std::string& type,
809 const std::string& name,
811 PRMClassElementContainer< GUM_SCALAR >* owner = _checkStackContainter_(1);
812 PRMClassElementContainer< GUM_SCALAR >* slotType = 0;
815 slotType = _retrieveClass_(type);
816 }
catch (NotFound&) {
818 slotType = _retrieveInterface_(type);
819 }
catch (NotFound&) { GUM_ERROR(NotFound,
"unknown ReferenceSlot<GUM_SCALAR> slot type") }
822 PRMReferenceSlot< GUM_SCALAR >* ref
823 =
new PRMReferenceSlot< GUM_SCALAR >(name, *slotType, isArray);
827 }
catch (DuplicateElement&) { owner->overload(ref); }
830 template <
typename GUM_SCALAR >
831 INLINE
void PRMFactory< GUM_SCALAR >::addArray(
const std::string& type,
832 const std::string& name,
834 PRMSystem< GUM_SCALAR >* model
835 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
836 PRMClass< GUM_SCALAR >* c = _retrieveClass_(type);
837 PRMInstance< GUM_SCALAR >* inst = 0;
840 model->addArray(name, *c);
842 for (Size i = 0; i < size; ++i) {
843 std::stringstream elt_name;
844 elt_name << name <<
"[" << i <<
"]";
845 inst =
new PRMInstance< GUM_SCALAR >(elt_name.str(), *c);
846 model->add(name, inst);
848 }
catch (PRMTypeError&) {
851 }
catch (NotFound&) {
857 template <
typename GUM_SCALAR >
858 INLINE
void PRMFactory< GUM_SCALAR >::incArray(
const std::string& l_i,
const std::string& r_i) {
859 PRMSystem< GUM_SCALAR >* model
860 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
862 if (model->isArray(l_i)) {
863 if (model->isInstance(r_i)) {
864 model->add(l_i, model->get(r_i));
866 GUM_ERROR(NotFound,
"right value is not an instance")
869 GUM_ERROR(NotFound,
"left value is no an array")
873 template <
typename GUM_SCALAR >
874 INLINE
void PRMFactory< GUM_SCALAR >::setReferenceSlot(
const std::string& l_i,
875 const std::string& l_ref,
876 const std::string& r_i) {
878 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
879 std::vector< PRMInstance< GUM_SCALAR >* > lefts;
880 std::vector< PRMInstance< GUM_SCALAR >* > rights;
882 if (model->isInstance(l_i)) {
883 lefts.push_back(&(model->get(l_i)));
884 }
else if (model->isArray(l_i)) {
885 for (
const auto& elt: model->getArray(l_i))
886 lefts.push_back(elt);
888 GUM_ERROR(NotFound,
"left value does not name an instance or an array")
891 if (model->isInstance(r_i)) {
892 rights.push_back(&(model->get(r_i)));
893 }
else if (model->isArray(r_i)) {
894 for (
const auto& elt: model->getArray(r_i))
895 rights.push_back(elt);
897 GUM_ERROR(NotFound,
"left value does not name an instance or an array")
900 for (
const auto l: lefts) {
901 for (
const auto r: rights) {
902 auto& elt = l->type().get(l_ref);
903 if (PRMClassElement< GUM_SCALAR >::isReferenceSlot(elt)) {
904 l->add(elt.id(), *r);
907 GUM_ERROR(NotFound,
"unfound reference slot")
913 template <
typename GUM_SCALAR >
914 INLINE PRMSlotChain< GUM_SCALAR >*
915 PRMFactory< GUM_SCALAR >::_buildSlotChain_(PRMClassElementContainer< GUM_SCALAR >* start,
916 const std::string& name) {
917 std::vector< std::string > v;
918 decomposePath(name, v);
919 PRMClassElementContainer< GUM_SCALAR >* current = start;
920 PRMReferenceSlot< GUM_SCALAR >* ref =
nullptr;
921 Sequence< PRMClassElement< GUM_SCALAR >* > elts;
923 for (size_t i = 0; i < v.size(); ++i) {
925 switch (current->get(v[i]).elt_type()) {
926 case PRMClassElement< GUM_SCALAR >::prm_refslot:
927 ref = &(
static_cast< PRMReferenceSlot< GUM_SCALAR >& >(current->get(v[i])));
929 current = &( (ref->slotType()));
932 case PRMClassElement< GUM_SCALAR >::prm_aggregate:
933 case PRMClassElement< GUM_SCALAR >::prm_attribute:
935 if (i == v.size() - 1) {
936 elts.insert(&(current->get(v[i])));
946 }
catch (NotFound&) {
return nullptr; }
949 GUM_ASSERT(v.size() == elts.size());
951 current->setOutputNode(*(elts.back()),
true);
953 return new PRMSlotChain< GUM_SCALAR >(name, elts);
956 template <
typename GUM_SCALAR >
957 INLINE
bool PRMFactory< GUM_SCALAR >::_retrieveInputs_(
958 PRMClass< GUM_SCALAR >* c,
959 const std::vector< std::string >& chains,
960 std::vector< PRMClassElement< GUM_SCALAR >* >& inputs) {
963 for (size_t i = 0; i < chains.size(); ++i) {
965 inputs.push_back(&(c->get(chains[i])));
966 retVal = retVal || PRMClassElement< GUM_SCALAR >::isSlotChain(*(inputs.back()));
967 }
catch (NotFound&) {
968 inputs.push_back(_buildSlotChain_(c, chains[i]));
972 c->add(inputs.back());
974 GUM_ERROR(NotFound,
"unknown slot chain")
979 PRMType* t = _retrieveCommonType_(inputs);
981 std::vector< std::pair< PRMClassElement< GUM_SCALAR >*, PRMClassElement< GUM_SCALAR >* > >
984 for (
const auto& elt: inputs) {
985 if ((*elt).type() != (*t)) {
986 if (PRMClassElement< GUM_SCALAR >::isSlotChain(*elt)) {
987 PRMSlotChain< GUM_SCALAR >* sc =
static_cast< PRMSlotChain< GUM_SCALAR >* >(elt);
988 std::stringstream name;
990 for (Size idx = 0; idx < sc->chain().size() - 1; ++idx) {
991 name << sc->chain().atPos(idx)->name() <<
".";
994 name <<
".(" << t->name() <<
")" << sc->lastElt().name();
997 toAdd.push_back(std::make_pair(elt, &(c->get(name.str()))));
998 }
catch (NotFound&) {
999 toAdd.push_back(std::make_pair(elt, _buildSlotChain_(c, name.str())));
1002 std::stringstream name;
1003 name <<
"(" << t->name() <<
")" << elt->name();
1004 toAdd.push_back(std::make_pair(elt, &(c->get(name.str()))));
1012 template <
typename GUM_SCALAR >
1013 INLINE PRMType* PRMFactory< GUM_SCALAR >::_retrieveCommonType_(
1014 const std::vector< PRMClassElement< GUM_SCALAR >* >& elts) {
1015 const PRMType* current =
nullptr;
1016 HashTable< std::string, Size > counters;
1019 for (
const auto& elt: elts) {
1021 current = &((*elt).type());
1023 while (current != 0) {
1025 if (counters.exists(current->name())) {
1026 ++(counters[current->name()]);
1028 counters.insert(current->name(), 1);
1032 if (current->isSubType()) {
1033 current = &(current->superType());
1038 }
catch (OperationNotAllowed&) {
1039 GUM_ERROR(WrongClassElement,
"found a ClassElement<GUM_SCALAR> without a type")
1048 int current_depth = 0;
1050 for (
const auto& elt: counters) {
1051 if ((elt.second) == elts.size()) {
1052 current_depth = _typeDepth_(_retrieveType_(elt.first));
1054 if (current_depth > max_depth) {
1055 max_depth = current_depth;
1056 current = _retrieveType_(elt.first);
1061 if (current) {
return const_cast< PRMType* >(current); }
1063 GUM_ERROR(NotFound,
"could not find a common type")
1066 template <
typename GUM_SCALAR >
1068 PRMFactory< GUM_SCALAR >::addNoisyOrCompound(
const std::string& name,
1069 const std::vector< std::string >& chains,
1070 const std::vector<
float >& numbers,
1072 const std::vector< std::string >& labels) {
1073 if (currentType() != PRMObject::prm_type::CLASS) {
1074 GUM_ERROR(gum::FactoryInvalidState,
"invalid state to add a noisy-or")
1077 PRMClass< GUM_SCALAR >* c =
dynamic_cast< gum::prm::PRMClass< GUM_SCALAR >* >(getCurrent());
1079 std::vector< PRMClassElement< GUM_SCALAR >* > parents;
1081 for (
const auto& elt: chains)
1082 parents.push_back(&(c->get(elt)));
1084 PRMType* common_type = _retrieveCommonType_(parents);
1086 for (size_t idx = 0; idx < parents.size(); ++idx) {
1087 if (parents[idx]->type() != (*common_type)) {
1088 PRMClassElement< GUM_SCALAR >* parent = parents[idx];
1091 std::string safe_name = parent->cast(*common_type);
1093 if (!c->exists(safe_name)) {
1094 if (PRMClassElement< GUM_SCALAR >::isSlotChain(*parent)) {
1095 parents[idx] = _buildSlotChain_(c, safe_name);
1096 c->add(parents[idx]);
1098 GUM_ERROR(NotFound,
"unable to find parent")
1101 parents[idx] = &(c->get(safe_name));
1106 if (numbers.size() == 1) {
1107 auto impl =
new gum::MultiDimNoisyORCompound< GUM_SCALAR >(leak, numbers.front());
1108 auto attr =
new PRMScalarAttribute< GUM_SCALAR >(name, retrieveType(
"boolean"), impl);
1110 }
else if (numbers.size() == parents.size()) {
1111 gum::MultiDimNoisyORCompound< GUM_SCALAR >* noisy
1112 =
new gum::MultiDimNoisyORCompound< GUM_SCALAR >(leak);
1113 gum::prm::PRMFuncAttribute< GUM_SCALAR >* attr
1114 =
new gum::prm::PRMFuncAttribute< GUM_SCALAR >(name, retrieveType(
"boolean"), noisy);
1116 for (size_t idx = 0; idx < numbers.size(); ++idx) {
1117 noisy->causalWeight(parents[idx]->type().variable(), numbers[idx]);
1122 GUM_ERROR(OperationNotAllowed,
"invalid parameters for a noisy or")
1125 if (!labels.empty()) {
1126 GUM_ERROR(OperationNotAllowed,
"labels definitions not handle for noisy-or")
1130 template <
typename GUM_SCALAR >
1131 INLINE PRMType* PRMFactory< GUM_SCALAR >::_retrieveType_(
const std::string& name)
const {
1132 PRMType* type =
nullptr;
1133 std::string full_name;
1136 if (_prm_->_typeMap_.exists(name)) {
1137 type = _prm_->_typeMap_[name];
1142 std::string prefixed = _addPrefix_(name);
1143 if (_prm_->_typeMap_.exists(prefixed)) {
1145 type = _prm_->_typeMap_[prefixed];
1146 full_name = prefixed;
1147 }
else if (full_name != prefixed) {
1148 GUM_ERROR(DuplicateElement,
"Type name '" << name <<
"' is ambiguous: specify full name.")
1153 std::string relatif_ns = currentPackage();
1154 size_t last_dot = relatif_ns.find_last_of(
'.');
1155 if (last_dot != std::string::npos) {
1156 relatif_ns = relatif_ns.substr(0, last_dot) +
'.' + name;
1157 if (_prm_->_typeMap_.exists(relatif_ns)) {
1159 type = _prm_->_typeMap_[relatif_ns];
1160 full_name = relatif_ns;
1161 }
else if (full_name != relatif_ns) {
1162 GUM_ERROR(DuplicateElement,
1163 "Type name '" << name <<
"' is ambiguous: specify full name.");
1170 if (!_namespaces_.empty()) {
1171 auto ns_list = _namespaces_.back();
1172 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1173 std::string ns = (*ns_list)[i];
1174 std::string ns_name = ns +
"." + name;
1175 if (_prm_->_typeMap_.exists(ns_name)) {
1177 type = _prm_->_typeMap_[ns_name];
1178 full_name = ns_name;
1179 }
else if (full_name != ns_name) {
1180 GUM_ERROR(DuplicateElement,
1181 "Type name '" << name <<
"' is ambiguous: specify full name.");
1187 if (type == 0) { GUM_ERROR(NotFound,
"Type '" << name <<
"' not found, check imports.") }
1192 template <
typename GUM_SCALAR >
1193 PRMClass< GUM_SCALAR >*
1194 PRMFactory< GUM_SCALAR >::_retrieveClass_(
const std::string& name)
const {
1195 PRMClass< GUM_SCALAR >* a_class =
nullptr;
1196 std::string full_name;
1199 if (_prm_->_classMap_.exists(name)) {
1200 a_class = _prm_->_classMap_[name];
1205 std::string prefixed = _addPrefix_(name);
1206 if (_prm_->_classMap_.exists(prefixed)) {
1207 if (a_class ==
nullptr) {
1208 a_class = _prm_->_classMap_[prefixed];
1209 full_name = prefixed;
1210 }
else if (full_name != prefixed) {
1211 GUM_ERROR(DuplicateElement,
1212 "Class name '" << name <<
"' is ambiguous: specify full name.");
1217 if (!_namespaces_.empty()) {
1218 auto ns_list = _namespaces_.back();
1219 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1220 std::string ns = (*ns_list)[i];
1221 std::string ns_name = ns +
"." + name;
1222 if (_prm_->_classMap_.exists(ns_name)) {
1224 a_class = _prm_->_classMap_[ns_name];
1225 full_name = ns_name;
1226 }
else if (full_name != ns_name) {
1227 GUM_ERROR(DuplicateElement,
1228 "Class name '" << name <<
"' is ambiguous: specify full name.");
1234 if (a_class == 0) { GUM_ERROR(NotFound,
"Class '" << name <<
"' not found, check imports.") }
1239 template <
typename GUM_SCALAR >
1240 PRMInterface< GUM_SCALAR >*
1241 PRMFactory< GUM_SCALAR >::_retrieveInterface_(
const std::string& name)
const {
1242 PRMInterface< GUM_SCALAR >* interface =
nullptr;
1243 std::string full_name;
1246 if (_prm_->_interfaceMap_.exists(name)) {
1247 interface = _prm_->_interfaceMap_[name];
1252 std::string prefixed = _addPrefix_(name);
1253 if (_prm_->_interfaceMap_.exists(prefixed)) {
1254 if (interface ==
nullptr) {
1255 interface = _prm_->_interfaceMap_[prefixed];
1256 full_name = prefixed;
1257 }
else if (full_name != prefixed) {
1258 GUM_ERROR(DuplicateElement,
1259 "Interface name '" << name <<
"' is ambiguous: specify full name.");
1264 if (!_namespaces_.empty()) {
1265 auto ns_list = _namespaces_.back();
1267 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1268 std::string ns = (*ns_list)[i];
1269 std::string ns_name = ns +
"." + name;
1271 if (_prm_->_interfaceMap_.exists(ns_name)) {
1272 if (interface ==
nullptr) {
1273 interface = _prm_->_interfaceMap_[ns_name];
1274 full_name = ns_name;
1275 }
else if (full_name != ns_name) {
1276 GUM_ERROR(DuplicateElement,
1277 "Interface name '" << name <<
"' is ambiguous: specify full name.");
1283 if (interface ==
nullptr) {
1284 GUM_ERROR(NotFound,
"Interface '" << name <<
"' not found, check imports.")
1290 template <
typename GUM_SCALAR >
1291 INLINE PRMFactory< GUM_SCALAR >::PRMFactory() {
1292 GUM_CONSTRUCTOR(PRMFactory);
1293 _prm_ =
new PRM< GUM_SCALAR >();
1296 template <
typename GUM_SCALAR >
1297 INLINE PRMFactory< GUM_SCALAR >::PRMFactory(PRM< GUM_SCALAR >* prm) :
1298 IPRMFactory(), _prm_(prm) {
1299 GUM_CONSTRUCTOR(PRMFactory);
1302 template <
typename GUM_SCALAR >
1303 INLINE PRMFactory< GUM_SCALAR >::~PRMFactory() {
1304 GUM_DESTRUCTOR(PRMFactory);
1305 while (!_namespaces_.empty()) {
1306 auto ns = _namespaces_.back();
1307 _namespaces_.pop_back();
1312 template <
typename GUM_SCALAR >
1313 INLINE PRM< GUM_SCALAR >* PRMFactory< GUM_SCALAR >::prm()
const {
1317 template <
typename GUM_SCALAR >
1318 INLINE PRMObject::prm_type PRMFactory< GUM_SCALAR >::currentType()
const {
1319 if (_stack_.size() == 0) { GUM_ERROR(NotFound,
"no object being built") }
1321 return _stack_.back()->obj_type();
1324 template <
typename GUM_SCALAR >
1325 INLINE PRMObject* PRMFactory< GUM_SCALAR >::getCurrent() {
1326 if (_stack_.size() == 0) { GUM_ERROR(NotFound,
"no object being built") }
1328 return _stack_.back();
1331 template <
typename GUM_SCALAR >
1332 INLINE
const PRMObject* PRMFactory< GUM_SCALAR >::getCurrent()
const {
1333 if (_stack_.size() == 0) { GUM_ERROR(NotFound,
"no object being built") }
1335 return _stack_.back();
1338 template <
typename GUM_SCALAR >
1339 INLINE PRMObject* PRMFactory< GUM_SCALAR >::closeCurrent() {
1340 if (_stack_.size() > 0) {
1341 PRMObject* obj = _stack_.back();
1349 template <
typename GUM_SCALAR >
1350 INLINE std::string PRMFactory< GUM_SCALAR >::currentPackage()
const {
1351 return (_packages_.empty()) ?
"" : _packages_.back();
1354 template <
typename GUM_SCALAR >
1355 INLINE
void PRMFactory< GUM_SCALAR >::startDiscreteType(
const std::string& name,
1356 std::string super) {
1357 std::string real_name = _addPrefix_(name);
1358 if (_prm_->_typeMap_.exists(real_name)) {
1359 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.")
1362 auto t =
new PRMType(LabelizedVariable(real_name,
"", 0));
1363 _stack_.push_back(t);
1365 auto t =
new PRMType(LabelizedVariable(real_name,
"", 0));
1366 t->_superType_ = _retrieveType_(super);
1367 t->_label_map_ =
new std::vector< Idx >();
1368 _stack_.push_back(t);
1372 template <
typename GUM_SCALAR >
1373 INLINE
void PRMFactory< GUM_SCALAR >::addLabel(
const std::string& l, std::string extends) {
1374 if (extends ==
"") {
1375 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1376 LabelizedVariable* var =
dynamic_cast< LabelizedVariable* >(t->_var_);
1379 GUM_ERROR(FatalError,
"the current type's variable is not a LabelizedVariable.")
1380 }
else if (t->_superType_) {
1381 GUM_ERROR(OperationNotAllowed,
"current type is a subtype.")
1386 }
catch (DuplicateElement&) {
1387 GUM_ERROR(DuplicateElement,
"a label '" << l <<
"' already exists")
1390 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1391 LabelizedVariable* var =
dynamic_cast< LabelizedVariable* >(t->_var_);
1394 GUM_ERROR(FatalError,
"the current type's variable is not a LabelizedVariable.")
1395 }
else if (!t->_superType_) {
1396 GUM_ERROR(OperationNotAllowed,
"current type is not a subtype.")
1401 for (Idx i = 0; i < t->_superType_->_var_->domainSize(); ++i) {
1402 if (t->_superType_->_var_->label(i) == extends) {
1405 }
catch (DuplicateElement&) {
1406 GUM_ERROR(DuplicateElement,
"a label '" << l <<
"' already exists")
1409 t->_label_map_->push_back(i);
1416 if (!found) { GUM_ERROR(NotFound,
"inexistent label in super type.") }
1420 template <
typename GUM_SCALAR >
1421 INLINE
void PRMFactory< GUM_SCALAR >::endDiscreteType() {
1422 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1424 if (!t->_isValid_()) {
1425 GUM_ERROR(OperationNotAllowed,
"current type is not a valid subtype")
1426 }
else if (t->variable().domainSize() < 2) {
1427 GUM_ERROR(OperationNotAllowed,
"current type is not a valid discrete type")
1430 _prm_->_typeMap_.insert(t->name(), t);
1432 _prm_->_types_.insert(t);
1436 template <
typename GUM_SCALAR >
1437 INLINE
void PRMFactory< GUM_SCALAR >::startDiscretizedType(
const std::string& name) {
1438 std::string real_name = _addPrefix_(name);
1439 if (_prm_->_typeMap_.exists(real_name)) {
1440 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.")
1442 auto var = DiscretizedVariable<
double >(real_name,
"");
1443 auto t =
new PRMType(var);
1444 _stack_.push_back(t);
1447 template <
typename GUM_SCALAR >
1448 INLINE
void PRMFactory< GUM_SCALAR >::addTick(
double tick) {
1449 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1450 DiscretizedVariable<
double >* var =
dynamic_cast< DiscretizedVariable<
double >* >(t->_var_);
1452 if (!var) { GUM_ERROR(FatalError,
"the current type's variable is not a LabelizedVariable.") }
1456 }
catch (DefaultInLabel&) {
1457 GUM_ERROR(OperationNotAllowed,
"tick already in used for this variable")
1461 template <
typename GUM_SCALAR >
1462 INLINE
void PRMFactory< GUM_SCALAR >::endDiscretizedType() {
1463 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1465 if (t->variable().domainSize() < 2) {
1466 GUM_ERROR(OperationNotAllowed,
"current type is not a valid discrete type")
1469 _prm_->_typeMap_.insert(t->name(), t);
1471 _prm_->_types_.insert(t);
1475 template <
typename GUM_SCALAR >
1477 PRMFactory< GUM_SCALAR >::addRangeType(
const std::string& name,
long minVal,
long maxVal) {
1478 std::string real_name = _addPrefix_(name);
1479 if (_prm_->_typeMap_.exists(real_name)) {
1480 std::stringstream msg;
1481 msg <<
"\"" << real_name <<
"' is already used.";
1482 GUM_ERROR(DuplicateElement, msg.str())
1485 auto var = RangeVariable(real_name,
"", minVal, maxVal);
1486 auto t =
new PRMType(var);
1488 if (t->variable().domainSize() < 2) {
1489 GUM_ERROR(OperationNotAllowed,
"current type is not a valid discrete type")
1492 _prm_->_typeMap_.insert(t->name(), t);
1493 _prm_->_types_.insert(t);
1496 template <
typename GUM_SCALAR >
1497 INLINE
void PRMFactory< GUM_SCALAR >::endInterface() {
1498 _checkStack_(1, PRMObject::prm_type::PRM_INTERFACE);
1502 template <
typename GUM_SCALAR >
1503 INLINE
void PRMFactory< GUM_SCALAR >::addAttribute(
const std::string& type,
1504 const std::string& name) {
1505 _checkStack_(1, PRMObject::prm_type::PRM_INTERFACE);
1506 startAttribute(type, name);
1510 template <
typename GUM_SCALAR >
1511 INLINE
void PRMFactory< GUM_SCALAR >::startAttribute(
const std::string& type,
1512 const std::string& name,
1514 PRMClassElementContainer< GUM_SCALAR >* c = _checkStackContainter_(1);
1515 PRMAttribute< GUM_SCALAR >* a =
nullptr;
1517 if (PRMObject::isClass(*c) && (!scalar_attr)) {
1518 a =
new PRMFormAttribute< GUM_SCALAR >(
static_cast< PRMClass< GUM_SCALAR >& >(*c),
1520 *_retrieveType_(type));
1523 a =
new PRMScalarAttribute< GUM_SCALAR >(name, *_retrieveType_(type));
1526 std::string dot =
".";
1531 }
catch (DuplicateElement&) { c->overload(a); }
1532 }
catch (Exception&) {
1533 if (a !=
nullptr && (!c->exists(a->id()))) {
delete a; }
1536 _stack_.push_back(a);
1539 template <
typename GUM_SCALAR >
1540 INLINE
void PRMFactory< GUM_SCALAR >::continueAttribute(
const std::string& name) {
1541 PRMClassElementContainer< GUM_SCALAR >* c = _checkStackContainter_(1);
1543 if (!c->exists(name)) GUM_ERROR(NotFound,
"Attribute " << name <<
"not found")
1545 auto& a = c->get(name);
1547 if (!PRMClassElement< GUM_SCALAR >::isAttribute(a))
1548 GUM_ERROR(OperationNotAllowed,
"Element " << name <<
" not an attribute")
1550 _stack_.push_back(&a);
1553 template <
typename GUM_SCALAR >
1554 INLINE
void PRMFactory< GUM_SCALAR >::endAttribute() {
1555 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute);
1559 template <
typename GUM_SCALAR >
1560 INLINE
void PRMFactory< GUM_SCALAR >::startSystem(
const std::string& name) {
1561 if (_prm_->_systemMap_.exists(name)) {
1562 GUM_ERROR(DuplicateElement,
"'" << name <<
"' is already used.")
1564 PRMSystem< GUM_SCALAR >* model =
new PRMSystem< GUM_SCALAR >(_addPrefix_(name));
1565 _stack_.push_back(model);
1566 _prm_->_systemMap_.insert(model->name(), model);
1567 _prm_->_systems_.insert(model);
1570 template <
typename GUM_SCALAR >
1571 INLINE
void PRMFactory< GUM_SCALAR >::endSystem() {
1573 PRMSystem< GUM_SCALAR >* model
1574 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
1576 model->instantiate();
1577 }
catch (Exception&) { GUM_ERROR(FatalError,
"could not create system") }
1580 template <
typename GUM_SCALAR >
1581 INLINE
void PRMFactory< GUM_SCALAR >::addInstance(
const std::string& type,
1582 const std::string& name) {
1583 auto c = _retrieveClass_(type);
1586 if (c->parameters().size() > 0) {
1587 HashTable< std::string,
double > params;
1588 addInstance(type, name, params);
1591 _addInstance_(c, name);
1595 template <
typename GUM_SCALAR >
1597 PRMFactory< GUM_SCALAR >::addInstance(
const std::string& type,
1598 const std::string& name,
1599 const HashTable< std::string,
double >& params) {
1600 auto c = _retrieveClass_(type);
1602 if (c->parameters().empty()) {
1603 if (params.empty()) {
1604 _addInstance_(c, name);
1606 GUM_ERROR(OperationNotAllowed,
"Class " + type +
" does not have parameters")
1610 auto my_params = params;
1612 for (
const auto& p: c->parameters()) {
1613 if (!my_params.exists(p->name())) { my_params.insert(p->name(), p->value()); }
1617 std::stringstream sBuff;
1618 sBuff << c->name() <<
"<";
1620 for (
const auto& p: my_params) {
1621 sBuff << p.first <<
"=" << p.second <<
",";
1625 std::string sub_c = sBuff.str().substr(0, sBuff.str().size() - 1) +
">";
1629 auto pck_cpy = _packages_;
1632 startClass(sub_c, c->name());
1635 for (
auto p: my_params) {
1636 auto type =
static_cast< PRMParameter< GUM_SCALAR >& >(c->get(p.first)).valueType();
1637 if (type == PRMParameter< GUM_SCALAR >::ParameterType::INT) {
1638 addParameter(
"int", p.first, p.second);
1641 addParameter(
"real", p.first, p.second);
1647 _packages_ = pck_cpy;
1649 }
catch (DuplicateElement&) {
1652 c = _retrieveClass_(sub_c);
1653 _addInstance_(c, name);
1657 template <
typename GUM_SCALAR >
1658 INLINE
void PRMFactory< GUM_SCALAR >::_addInstance_(PRMClass< GUM_SCALAR >* type,
1659 const std::string& name) {
1660 PRMInstance< GUM_SCALAR >* i =
nullptr;
1663 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
1664 i =
new PRMInstance< GUM_SCALAR >(name, *type);
1667 }
catch (OperationNotAllowed&) {
1668 if (i) {
delete i; }
1673 template <
typename GUM_SCALAR >
1674 INLINE std::string PRMFactory< GUM_SCALAR >::_addPrefix_(
const std::string& str)
const {
1675 if (!_packages_.empty()) {
1676 std::string full_name = _packages_.back();
1677 full_name.append(
".");
1678 full_name.append(str);
1685 template <
typename GUM_SCALAR >
1686 INLINE PRMObject* PRMFactory< GUM_SCALAR >::_checkStack_(Idx i, PRMObject::prm_type obj_type) {
1688 if (_stack_.size() - i > _stack_.size()) {
1689 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1692 PRMObject* obj = _stack_[_stack_.size() - i];
1694 if (obj->obj_type() != obj_type) {
1695 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1701 template <
typename GUM_SCALAR >
1702 INLINE PRMClassElementContainer< GUM_SCALAR >*
1703 PRMFactory< GUM_SCALAR >::_checkStackContainter_(Idx i) {
1705 if (_stack_.size() - i > _stack_.size()) {
1706 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1709 PRMObject* obj = _stack_[_stack_.size() - i];
1711 if ((obj->obj_type() == PRMObject::prm_type::CLASS)
1712 || (obj->obj_type() == PRMObject::prm_type::PRM_INTERFACE)) {
1713 return static_cast< PRMClassElementContainer< GUM_SCALAR >* >(obj);
1715 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1719 template <
typename GUM_SCALAR >
1720 INLINE PRMClassElement< GUM_SCALAR >* PRMFactory< GUM_SCALAR >::_checkStack_(
1722 typename PRMClassElement< GUM_SCALAR >::ClassElementType elt_type) {
1724 if (_stack_.size() - i > _stack_.size()) {
1725 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1728 PRMClassElement< GUM_SCALAR >* obj
1729 =
dynamic_cast< PRMClassElement< GUM_SCALAR >* >(_stack_[_stack_.size() - i]);
1731 if (obj == 0) { GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls") }
1733 if (obj->elt_type() != elt_type) {
1734 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1740 template <
typename GUM_SCALAR >
1741 INLINE
int PRMFactory< GUM_SCALAR >::_typeDepth_(
const PRMType* t) {
1743 const PRMType* current = t;
1745 while (current->isSubType()) {
1747 current = &(current->superType());
1753 template <
typename GUM_SCALAR >
1754 INLINE
void PRMFactory< GUM_SCALAR >::pushPackage(
const std::string& name) {
1755 _packages_.push_back(name);
1756 _namespaces_.push_back(
new List< std::string >());
1759 template <
typename GUM_SCALAR >
1760 INLINE std::string PRMFactory< GUM_SCALAR >::popPackage() {
1761 std::string plop = currentPackage();
1763 if (!_packages_.empty()) {
1764 std::string s = _packages_.back();
1765 _packages_.pop_back();
1767 if (_namespaces_.size() > 0) {
1768 delete _namespaces_.back();
1769 _namespaces_.pop_back();
1777 template <
typename GUM_SCALAR >
1778 INLINE
void PRMFactory< GUM_SCALAR >::addImport(
const std::string& name) {
1779 if (name.size() == 0) { GUM_ERROR(OperationNotAllowed,
"illegal import name") }
1780 if (_namespaces_.empty()) { _namespaces_.push_back(
new List< std::string >()); }
1781 _namespaces_.back()->push_back(name);
1784 template <
typename GUM_SCALAR >
1785 INLINE
void PRMFactory< GUM_SCALAR >::setReferenceSlot(
const std::string& l_i,
1786 const std::string& r_i) {
1787 size_t pos = l_i.find_last_of(
'.');
1789 if (pos != std::string::npos) {
1790 std::string l_ref = l_i.substr(pos + 1, std::string::npos);
1791 setReferenceSlot(l_i.substr(0, pos), l_ref, r_i);
1793 GUM_ERROR(NotFound,
"left value does not name an instance or an array")
1797 template <
typename GUM_SCALAR >
1798 INLINE PRMClass< GUM_SCALAR >&
1799 PRMFactory< GUM_SCALAR >::retrieveClass(
const std::string& name) {
1800 return *_retrieveClass_(name);
1803 template <
typename GUM_SCALAR >
1804 INLINE PRMType& PRMFactory< GUM_SCALAR >::retrieveType(
const std::string& name) {
1805 return *_retrieveType_(name);
1808 template <
typename GUM_SCALAR >
1809 INLINE PRMType& PRMFactory< GUM_SCALAR >::retrieveCommonType(
1810 const std::vector< PRMClassElement< GUM_SCALAR >* >& elts) {
1811 return *(_retrieveCommonType_(elts));
1815 template <
typename GUM_SCALAR >
1816 INLINE
bool PRMFactory< GUM_SCALAR >::isClassOrInterface(
const std::string& type)
const {
1818 _retrieveClass_(type);
1821 }
catch (NotFound&) {
1822 }
catch (DuplicateElement&) {}
1825 _retrieveInterface_(type);
1828 }
catch (NotFound&) {
1829 }
catch (DuplicateElement&) {}
1834 template <
typename GUM_SCALAR >
1835 INLINE
bool PRMFactory< GUM_SCALAR >::isArrayInCurrentSystem(
const std::string& name)
const {
1836 const PRMSystem< GUM_SCALAR >* system
1837 =
static_cast<
const PRMSystem< GUM_SCALAR >* >(getCurrent());
1838 return (system && system->isArray(name));
1841 template <
typename GUM_SCALAR >
1843 PRMFactory< GUM_SCALAR >::setRawCPFByColumns(
const std::vector< std::string >& array) {
1844 _checkStack_(2, PRMObject::prm_type::CLASS);
1846 auto a =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(
1847 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
1849 if (a->formulas().domainSize() != array.size()) {
1850 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
1853 if (a->formulas().nbrDim() == 1) {
1854 setRawCPFByLines(array);
1857 Instantiation inst(a->formulas());
1859 for (
auto idx = inst.variablesSequence().rbegin(); idx != inst.variablesSequence().rend();
1865 auto idx = (std::size_t)0;
1866 while ((!jnst.end()) && idx < array.size()) {
1868 a->formulas().set(inst, array[idx]);
1878 template <
typename GUM_SCALAR >
1880 PRMFactory< GUM_SCALAR >::setRawCPFByLines(
const std::vector< std::string >& array) {
1881 _checkStack_(2, PRMObject::prm_type::CLASS);
1883 auto a =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(
1884 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
1886 if (a->formulas().domainSize() != array.size()) {
1887 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
1890 a->formulas().populate(array);