28 #include <agrum/PRM/PRMFactory.h> 33 #include <agrum/tools/core/math/formula.h> 35 #include <agrum/tools/variables/discretizedVariable.h> 36 #include <agrum/tools/variables/integerVariable.h> 37 #include <agrum/tools/variables/rangeVariable.h> 39 #include <agrum/PRM/elements/PRMFormAttribute.h> 40 #include <agrum/PRM/elements/PRMFuncAttribute.h> 46 template <
typename GUM_SCALAR >
47 INLINE
void PRMFactory< GUM_SCALAR >::startClass(
const std::string& name,
48 const std::string& extends,
49 const Set< std::string >* implements,
50 bool delayInheritance) {
// startClass: begins the declaration of a new class called `name` (after _addPrefix_ adds the
// current package). `extends` names an optional super class ("" means none) and `implements`
// an optional set of interface names (0 means none). `delayInheritance` is forwarded to every
// PRMClass constructor that receives a mother class or interfaces.
// Throws DuplicateElement when the prefixed name is already used by a class or an interface.
// Name-resolution errors from _retrieveInterface_ / _retrieveClass_ propagate to the caller.
51 std::string real_name = _addPrefix_(name);
52 if (_prm_->_classMap_.exists(real_name) || _prm_->_interfaceMap_.exists(real_name)) {
53 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.")
55 PRMClass< GUM_SCALAR >* c =
nullptr;
56 PRMClass< GUM_SCALAR >* mother =
nullptr;
57 Set< PRMInterface< GUM_SCALAR >* > impl;
// Resolve every implemented interface by name.
59 if (implements != 0) {
60 for (
const auto& imp: *implements) {
61 impl.insert(_retrieveInterface_(imp));
// Resolve the super class, if any.
65 if (extends !=
"") { mother = _retrieveClass_(extends); }
// Pick the PRMClass constructor that matches which of "super class" / "interfaces" were given.
67 if ((extends ==
"") && impl.empty()) {
68 c =
new PRMClass< GUM_SCALAR >(real_name);
69 }
else if ((extends !=
"") && impl.empty()) {
70 c =
new PRMClass< GUM_SCALAR >(real_name, *mother, delayInheritance);
71 }
else if ((extends ==
"") && (!impl.empty())) {
72 c =
new PRMClass< GUM_SCALAR >(real_name, impl, delayInheritance);
73 }
else if ((extends !=
"") && (!impl.empty())) {
74 c =
new PRMClass< GUM_SCALAR >(real_name, *mother, impl, delayInheritance);
// Register the class in the PRM, both in the name map and in the class set.
// NOTE(review): unlike startInterface, no `_stack_.push_back(c)` is visible in this excerpt.
// The following lines of the original appear to be missing; confirm against the full file.
77 _prm_->_classMap_.insert(c->name(), c);
78 _prm_->_classes_.insert(c);
// continueClass: re-opens the already declared class `name` by pushing it back on the factory
// stack, so that more elements can be added to it. The name is resolved with the current
// package prefix.
// Throws NotFound when no class with the prefixed name exists.
82 template <
typename GUM_SCALAR >
83 INLINE
void PRMFactory< GUM_SCALAR >::continueClass(
const std::string& name) {
84 std::string real_name = _addPrefix_(name);
85 if (!(_prm_->_classMap_.exists(real_name))) {
86 std::stringstream msg;
87 msg <<
"'" << real_name <<
"' not found";
88 GUM_ERROR(NotFound, msg.str())
90 _stack_.push_back(&(_prm_->getClass(real_name)));
// endClass: closes the class on top of the factory stack, which is obtained through
// _checkStack_(1, CLASS). That call presumably rejects a non-class top element; confirm in
// _checkStack_.
// When `checkImplementations` is true, _checkInterfaceImplementation_ verifies that the class
// respects every interface it implements.
// NOTE(review): no pop of _stack_ is visible in this excerpt. Confirm it happens in the
// omitted lines.
93 template <
typename GUM_SCALAR >
94 INLINE
void PRMFactory< GUM_SCALAR >::endClass(
bool checkImplementations) {
95 PRMClass< GUM_SCALAR >* c
96 =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
98 if (checkImplementations) { _checkInterfaceImplementation_(c); }
// _checkInterfaceImplementation_: verifies that class `c` respects each interface in
// c->implements(). For every element of an interface, `c` must contain an element of the
// same name such that:
//   - an interface attribute/aggregate maps to an attribute or aggregate whose type is a
//     subtype of the interface element's type;
//   - an interface reference slot maps to a reference slot whose slot type is a subtype of
//     the interface slot's type.
// Any mismatch raises PRMTypeError, and so does a NotFound caught below. An unexpected
// element kind in the interface raises FatalError.
103 template <
typename GUM_SCALAR >
105 PRMFactory< GUM_SCALAR >::_checkInterfaceImplementation_(PRMClass< GUM_SCALAR >* c) {
107 for (
const auto& i: c->implements()) {
109 for (
const auto& node: i->containerDag().nodes()) {
110 std::string name = i->get(node).name();
// Compare the interface element with the class element of the same name.
112 switch (i->get(node).elt_type()) {
113 case PRMClassElement< GUM_SCALAR >::prm_aggregate:
114 case PRMClassElement< GUM_SCALAR >::prm_attribute: {
115 if ((c->get(name).elt_type() == PRMClassElement< GUM_SCALAR >::prm_attribute)
116 || (c->get(name).elt_type()
117 == PRMClassElement< GUM_SCALAR >::prm_aggregate)) {
118 if (!c->get(name).type().isSubTypeOf(i->get(name).type())) {
119 std::stringstream msg;
120 msg <<
"class " << c->name() <<
" does not respect interface ";
121 GUM_ERROR(PRMTypeError, msg.str() + i->name())
124 std::stringstream msg;
125 msg <<
"class " << c->name() <<
" does not respect interface ";
126 GUM_ERROR(PRMTypeError, msg.str() + i->name())
132 case PRMClassElement< GUM_SCALAR >::prm_refslot: {
133 if (c->get(name).elt_type() == PRMClassElement< GUM_SCALAR >::prm_refslot) {
134 const PRMReferenceSlot< GUM_SCALAR >& ref_i
135 =
static_cast<
const PRMReferenceSlot< GUM_SCALAR >& >(i->get(name));
136 const PRMReferenceSlot< GUM_SCALAR >& ref_this
137 =
static_cast<
const PRMReferenceSlot< GUM_SCALAR >& >(c->get(name));
139 if (!ref_this.slotType().isSubTypeOf(ref_i.slotType())) {
140 std::stringstream msg;
141 msg <<
"class " << c->name() <<
" does not respect interface ";
142 GUM_ERROR(PRMTypeError, msg.str() + i->name())
145 std::stringstream msg;
146 msg <<
"class " << c->name() <<
" does not respect interface ";
147 GUM_ERROR(PRMTypeError, msg.str() + i->name())
153 case PRMClassElement< GUM_SCALAR >::prm_slotchain: {
159 std::string msg =
"unexpected ClassElement<GUM_SCALAR> in interface ";
160 GUM_ERROR(FatalError, msg + i->name())
164 }
catch (NotFound&) {
// Presumably c->get(name) failed: the class lacks an element required by the interface.
165 std::stringstream msg;
166 msg <<
"class " << c->name() <<
" does not respect interface ";
167 GUM_ERROR(PRMTypeError, msg.str() + i->name())
170 }
catch (NotFound&) {
// NOTE(review): the body of this outer NotFound handler is not visible in this excerpt.
// startInterface: begins the declaration of a new interface called `name` (after _addPrefix_).
// The interface optionally extends the interface named `extends` ("" means none);
// `delayInheritance` is forwarded to the inheriting constructor. The new interface is
// registered in the PRM and pushed on the factory stack.
// Throws DuplicateElement when the prefixed name is already used by a class or an interface.
176 template <
typename GUM_SCALAR >
177 INLINE
void PRMFactory< GUM_SCALAR >::startInterface(
const std::string& name,
178 const std::string& extends,
179 bool delayInheritance) {
180 std::string real_name = _addPrefix_(name);
181 if (_prm_->_classMap_.exists(real_name) || _prm_->_interfaceMap_.exists(real_name)) {
182 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.")
184 PRMInterface< GUM_SCALAR >* i =
nullptr;
185 PRMInterface< GUM_SCALAR >* super =
nullptr;
187 if (extends !=
"") { super = _retrieveInterface_(extends); }
189 if (super !=
nullptr) {
190 i =
new PRMInterface< GUM_SCALAR >(real_name, *super, delayInheritance);
192 i =
new PRMInterface< GUM_SCALAR >(real_name);
// Register the interface in the PRM and make it the current object of the factory.
195 _prm_->_interfaceMap_.insert(i->name(), i);
196 _prm_->_interfaces_.insert(i);
197 _stack_.push_back(i);
200 template <
typename GUM_SCALAR >
201 INLINE
void PRMFactory< GUM_SCALAR >::continueInterface(
const std::string& name) {
202 std::string real_name = _addPrefix_(name);
203 if (!_prm_->_interfaceMap_.exists(real_name)) {
204 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' not found.")
207 PRMInterface< GUM_SCALAR >* i = _retrieveInterface_(real_name);
208 _stack_.push_back(i);
// addAttribute: adds `attr` to the class on top of the factory stack and links it to its
// parents. For every element of the class whose variable appears in attr's CPF (other than
// attr's own variable), an arc element -> attr is added. An OperationNotAllowed raised while
// inspecting an element is ignored by the empty handler below.
// Throws NotFound when `count` differs from the number of CPF variables, presumably meaning
// that not every CPF variable was matched to an element of the class.
// NOTE(review): this excerpt does not show the declaration or increment of `count`, nor the
// insertion of `attr` into `c`. Confirm against the full file.
211 template <
typename GUM_SCALAR >
212 INLINE
void PRMFactory< GUM_SCALAR >::addAttribute(PRMAttribute< GUM_SCALAR >* attr) {
213 PRMClass< GUM_SCALAR >* c
214 =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
217 const Sequence<
const DiscreteVariable* >& vars = attr->cpf().variablesSequence();
219 for (
const auto& node: c->containerDag().nodes()) {
221 if (vars.exists(&(c->get(node).type().variable()))) {
// No self-arc: skip the element carrying attr's own variable.
224 if (&(attr->type().variable()) != &(c->get(node).type().variable())) {
225 c->addArc(c->get(node).safeName(), attr->safeName());
228 }
catch (OperationNotAllowed&) {}
231 if (count != attr->cpf().variablesSequence().size()) {
232 GUM_ERROR(NotFound,
"unable to found all parents of this attribute")
// _addParent_ (attribute overload): makes the element called `name` of container `c` a parent
// of attribute `a` by adding the arc name -> a.
//   - reference slots are rejected with OperationNotAllowed;
//   - slot chains are accepted only when they are not multiple (OperationNotAllowed otherwise);
//   - attributes and aggregates are always accepted; any other kind raises FatalError.
// When `c` has no element called `name` (NotFound), a slot chain is built from the path with
// _buildSlotChain_:
//   - a failed build raises NotFound (presumably when the result is nullptr);
//   - a single chain is linked to `a`;
//   - a multiple chain raises OperationNotAllowed.
236 template <
typename GUM_SCALAR >
237 INLINE
void PRMFactory< GUM_SCALAR >::_addParent_(PRMClassElementContainer< GUM_SCALAR >* c,
238 PRMAttribute< GUM_SCALAR >* a,
239 const std::string& name) {
241 PRMClassElement< GUM_SCALAR >& elt = c->get(name);
243 switch (elt.elt_type()) {
244 case PRMClassElement< GUM_SCALAR >::prm_refslot: {
245 GUM_ERROR(OperationNotAllowed,
246 "can not add a reference slot as a parent of an attribute")
250 case PRMClassElement< GUM_SCALAR >::prm_slotchain: {
251 if (
static_cast< PRMSlotChain< GUM_SCALAR >& >(elt).isMultiple()) {
252 GUM_ERROR(OperationNotAllowed,
"can not add a multiple slot chain to an attribute")
255 c->addArc(name, a->name());
260 case PRMClassElement< GUM_SCALAR >::prm_attribute:
261 case PRMClassElement< GUM_SCALAR >::prm_aggregate: {
262 c->addArc(name, a->name());
267 GUM_ERROR(FatalError,
"unknown ClassElement<GUM_SCALAR>")
270 }
catch (NotFound&) {
// `name` is not a direct element of `c`: interpret it as a slot-chain path.
272 PRMSlotChain< GUM_SCALAR >* sc = _buildSlotChain_(c, name);
275 std::string msg =
"found no ClassElement<GUM_SCALAR> with the given name ";
276 GUM_ERROR(NotFound, msg + name)
277 }
else if (!sc->isMultiple()) {
279 c->addArc(sc->name(), a->name());
282 GUM_ERROR(OperationNotAllowed,
283 "Impossible to add a multiple reference slot as" 284 " direct parent of an PRMAttribute<GUM_SCALAR>.");
// addParent: adds `name` as a parent of the element on top of the factory stack. That
// element's container is the one just below it (_checkStackContainter_(2)). The top element is
// first treated as an attribute; if _checkStack_ rejects it with FactoryInvalidState, it is
// treated as an aggregate instead.
// NOTE(review): the aggregate path static_casts `c` to PRMClass without checking that the
// container is a class rather than an interface.
290 template <
typename GUM_SCALAR >
291 INLINE
void PRMFactory< GUM_SCALAR >::addParent(
const std::string& name) {
292 PRMClassElementContainer< GUM_SCALAR >* c = _checkStackContainter_(2);
295 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
296 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
297 _addParent_(c, a, name);
298 }
catch (FactoryInvalidState&) {
// The top element is not an attribute: retry as an aggregate.
299 auto agg =
static_cast< PRMAggregate< GUM_SCALAR >* >(
300 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_aggregate));
301 _addParent_(
static_cast< PRMClass< GUM_SCALAR >* >(c), agg, name);
// setRawCPFByFloatLines: fills the CPF of the attribute on top of the stack with `array`.
// The attribute's owner, one level below on the stack, must be a class. Each float is
// converted to GUM_SCALAR and the values are passed to fillWith in their given order.
// Throws OperationNotAllowed ("illegal CPF size") if array.size() differs from the CPF
// domain size.
305 template <
typename GUM_SCALAR >
306 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByFloatLines(
const std::vector<
float >& array) {
307 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
308 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
309 _checkStack_(2, PRMObject::prm_type::CLASS);
311 if (a->cpf().domainSize() != array.size()) GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
313 std::vector< GUM_SCALAR > array2(array.begin(), array.end());
314 a->cpf().fillWith(array2);
// setRawCPFByLines: fills the CPF of the attribute on top of the stack with `array`, passing
// the values to fillWith in their given order. The attribute's owner, one level below on the
// stack, must be a class.
// Throws OperationNotAllowed ("illegal CPF size") if array.size() differs from the CPF
// domain size.
317 template <
typename GUM_SCALAR >
318 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByLines(
const std::vector< GUM_SCALAR >& array) {
319 auto elt = _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute);
320 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(elt);
321 _checkStack_(2, PRMObject::prm_type::CLASS);
323 if (a->cpf().domainSize() != array.size()) {
324 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
327 a->cpf().fillWith(array);
// setRawCPFByFloatColumns: float overload of setRawCPFByColumns. It checks the size against
// the CPF of the attribute on top of the stack, converts the values to GUM_SCALAR and
// delegates.
// Throws OperationNotAllowed ("illegal CPF size") on a size mismatch.
// NOTE(review): unlike setRawCPFByFloatLines, no _checkStack_(2, CLASS) check is visible here.
330 template <
typename GUM_SCALAR >
332 PRMFactory< GUM_SCALAR >::setRawCPFByFloatColumns(
const std::vector<
float >& array) {
333 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
334 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
336 if (a->cpf().domainSize() != array.size()) {
337 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
340 std::vector< GUM_SCALAR > array2(array.begin(), array.end());
341 setRawCPFByColumns(array2);
// setRawCPFByColumns: fills the CPF of the attribute on top of the stack from `array`.
// A one-dimensional CPF is delegated to setRawCPFByLines. Otherwise an Instantiation over the
// CPF is built and its variables are visited in reverse sequence order (rbegin/rend) before
// array[idx] values are written with cpf().set. This is presumably so that `array` is read
// column by column; confirm against the omitted lines.
// Throws OperationNotAllowed ("illegal CPF size") if array.size() differs from the CPF
// domain size.
// NOTE(review): `jnst`, used in the while condition, is not declared in the visible lines.
344 template <
typename GUM_SCALAR >
346 PRMFactory< GUM_SCALAR >::setRawCPFByColumns(
const std::vector< GUM_SCALAR >& array) {
347 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
348 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
350 if (a->cpf().domainSize() != array.size()) {
351 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
354 if (a->cpf().nbrDim() == 1) {
355 setRawCPFByLines(array);
358 Instantiation inst(a->cpf());
360 for (
auto idx = inst.variablesSequence().rbegin(); idx != inst.variablesSequence().rend();
366 auto idx = (std::size_t)0;
367 while ((!jnst.end()) && idx < array.size()) {
369 a->cpf().set(inst, array[idx]);
// setCPFByFloatRule: float overload of setCPFByRule. It checks that parents.size() + 1 equals
// the number of CPF variables and that `values` has one entry per value of the attribute's
// variable; otherwise it throws OperationNotAllowed. The values are then converted to
// GUM_SCALAR and forwarded to setCPFByRule.
376 template <
typename GUM_SCALAR >
378 PRMFactory< GUM_SCALAR >::setCPFByFloatRule(
const std::vector< std::string >& parents,
379 const std::vector<
float >& values) {
380 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
381 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
383 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
384 GUM_ERROR(OperationNotAllowed,
"wrong number of parents")
387 if (values.size() != a->type().variable().domainSize()) {
388 GUM_ERROR(OperationNotAllowed,
"wrong number of values")
391 std::vector< GUM_SCALAR > values2(values.begin(), values.end());
392 setCPFByRule(parents, values2);
// setCPFByRule (numeric values): applies one rule to the formula table of the attribute on
// top of the stack.
//   - parents[i] is matched against the labels of formula variable 1 + i.
//   - The entry "*" is handled in a separate branch whose body is not visible in this
//     excerpt; it is presumably a wildcard.
//   - For every value i of the attribute's variable, all cells reached by iterating `knst`
//     receive values[i], converted with std::to_string.
// Throws OperationNotAllowed on a wrong number of parents or values, or when the attribute is
// not a PRMFormAttribute. Throws NotFound when a label is unknown.
// NOTE(review): std::to_string formats floating point like "%f" (six decimals), so 1e-7 is
// stored as "0.000000". Consider a precision-preserving conversion.
395 template <
typename GUM_SCALAR >
396 INLINE
void PRMFactory< GUM_SCALAR >::setCPFByRule(
const std::vector< std::string >& parents,
397 const std::vector< GUM_SCALAR >& values) {
398 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
399 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
401 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
402 GUM_ERROR(OperationNotAllowed,
"wrong number of parents")
405 if (values.size() != a->type().variable().domainSize()) {
406 GUM_ERROR(OperationNotAllowed,
"wrong number of values")
409 if (
dynamic_cast< PRMFormAttribute< GUM_SCALAR >* >(a)) {
410 auto form =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(a);
413 Instantiation jnst, knst;
414 const DiscreteVariable* var = 0;
// Position 0 of the formula table is skipped; parent i sits at position 1 + i.
418 for (Idx i = 0; i < parents.size(); ++i) {
419 var = form->formulas().variablesSequence().atPos(1 + i);
421 if (parents[i] ==
"*") {
428 for (Size j = 0; j < var->domainSize(); ++j) {
429 if (var->label(j) == parents[i]) {
430 jnst.chgVal(*var, j);
437 std::string msg =
"could not find label ";
438 GUM_ERROR(NotFound, msg + parents[i])
443 Instantiation inst(form->formulas());
446 for (Size i = 0; i < form->type()->domainSize(); ++i) {
447 inst.chgVal(form->type().variable(), i);
449 for (inst.setFirstIn(knst); !inst.end(); inst.incIn(knst)) {
450 form->formulas().set(inst, std::to_string(values[i]));
455 GUM_ERROR(OperationNotAllowed,
"invalide attribute type")
// setCPFByRule (formula strings): applies one rule to the formula table of the attribute on
// top of the stack.
//   - parents[i] is matched against the labels of formula variable 1 + i.
//   - The entry "*" is handled in a separate branch whose body is not visible in this
//     excerpt; it is presumably a wildcard.
//   - For every value i of the attribute's variable, all cells reached by iterating `knst`
//     receive the string values[i] unchanged.
// Throws OperationNotAllowed on a wrong number of parents or values, or when the attribute is
// not a PRMFormAttribute. Throws NotFound when a label is unknown.
459 template <
typename GUM_SCALAR >
460 INLINE
void PRMFactory< GUM_SCALAR >::setCPFByRule(
const std::vector< std::string >& parents,
461 const std::vector< std::string >& values) {
462 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
463 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
465 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
466 GUM_ERROR(OperationNotAllowed,
"wrong number of parents")
469 if (values.size() != a->type().variable().domainSize()) {
470 GUM_ERROR(OperationNotAllowed,
"wrong number of values")
473 if (
dynamic_cast< PRMFormAttribute< GUM_SCALAR >* >(a)) {
474 auto form =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(a);
477 Instantiation jnst, knst;
478 const DiscreteVariable* var = 0;
// Position 0 of the formula table is skipped; parent i sits at position 1 + i.
482 for (Idx i = 0; i < parents.size(); ++i) {
483 var = form->formulas().variablesSequence().atPos(1 + i);
485 if (parents[i] ==
"*") {
492 for (Size j = 0; j < var->domainSize(); ++j) {
493 if (var->label(j) == parents[i]) {
494 jnst.chgVal(*var, j);
501 std::string msg =
"could not find label ";
502 GUM_ERROR(NotFound, msg + parents[i])
507 Instantiation inst(form->formulas());
510 for (Size i = 0; i < form->type()->domainSize(); ++i) {
511 inst.chgVal(form->type().variable(), i);
513 for (inst.setFirstIn(knst); !inst.end(); inst.incIn(knst)) {
514 form->formulas().set(inst, values[i]);
519 GUM_ERROR(OperationNotAllowed,
"invalide attribute type")
// addParameter: adds a parameter `name` to the class on top of the stack. It is created as an
// INT or a REAL PRMParameter depending on `type`. Only the "real" comparison is visible here;
// the INT condition and the remaining constructor arguments are on omitted lines. If adding
// the parameter raises DuplicateElement, it overloads the existing element instead.
523 template <
typename GUM_SCALAR >
524 INLINE
void PRMFactory< GUM_SCALAR >::addParameter(
const std::string& type,
525 const std::string& name,
527 auto c =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
529 PRMParameter< GUM_SCALAR >* p =
nullptr;
531 p =
new PRMParameter< GUM_SCALAR >(name,
532 PRMParameter< GUM_SCALAR >::ParameterType::INT,
534 }
else if (type ==
"real") {
535 p =
new PRMParameter< GUM_SCALAR >(name,
536 PRMParameter< GUM_SCALAR >::ParameterType::REAL,
542 }
catch (DuplicateElement&) { c->overload(p); }
// startAggregator: creates, in the class on top of the stack, an aggregate `name` of kind
// `agg_type` (parsed by str2enum) whose type is `rv_type`. On DuplicateElement it overloads
// the existing element. COUNT, EXISTS and FORALL require exactly one parameter, which is
// stored as the aggregate's label (OperationNotAllowed otherwise). The aggregate is then
// pushed on the stack so that parents can be added to it.
545 template <
typename GUM_SCALAR >
547 PRMFactory< GUM_SCALAR >::startAggregator(
const std::string& name,
548 const std::string& agg_type,
549 const std::string& rv_type,
550 const std::vector< std::string >& params) {
551 PRMClass< GUM_SCALAR >* c
552 =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
554 auto agg =
new PRMAggregate< GUM_SCALAR >(name,
555 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
556 *_retrieveType_(rv_type));
560 }
catch (DuplicateElement&) { c->overload(agg); }
562 switch (agg->agg_type()) {
563 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT:
564 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
565 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
566 if (params.size() != 1) {
567 GUM_ERROR(OperationNotAllowed,
"aggregate requires a parameter")
// The label is stored as a string here; it is turned into an index when a parent is added.
569 agg->setLabel(params.front());
576 _stack_.push_back(agg);
// continueAggregator: re-opens the aggregate `name` of the container on top of the stack by
// pushing it on the stack.
// Throws NotFound if the container has no such element, and OperationNotAllowed if the
// element is not an aggregate.
// NOTE(review): the NotFound message has no space before "not found".
579 template <
typename GUM_SCALAR >
580 INLINE
void PRMFactory< GUM_SCALAR >::continueAggregator(
const std::string& name) {
581 PRMClassElementContainer< GUM_SCALAR >* c = _checkStackContainter_(1);
583 if (!c->exists(name)) GUM_ERROR(NotFound,
"Element " << name <<
"not found")
585 auto& agg = c->get(name);
586 if (!PRMClassElement< GUM_SCALAR >::isAggregate(agg))
587 GUM_ERROR(OperationNotAllowed,
"Element " << name <<
" not an aggregate")
589 _stack_.push_back(&agg);
// _addParent_ (aggregate overload): resolves `name` into an input of aggregate `agg` with
// _retrieveInputs_, then checks it against the aggregate kind:
//   - OR/AND: the input type must be "boolean" (TypeError otherwise);
//   - COUNT/EXISTS/FORALL: if the aggregate has no label yet, its label value is looked up
//     among the input type's labels and the matching index is set (NotFound if absent);
//   - SUM/MEDIAN/AMPLITUDE/MIN/MAX: any checks for these kinds are on omitted lines;
//   - anything else raises FatalError.
// Finally an arc is added from the first input to the aggregate.
592 template <
typename GUM_SCALAR >
593 INLINE
void PRMFactory< GUM_SCALAR >::_addParent_(PRMClass< GUM_SCALAR >* c,
594 PRMAggregate< GUM_SCALAR >* agg,
595 const std::string& name) {
596 auto chains = std::vector< std::string >{name};
597 auto inputs = std::vector< PRMClassElement< GUM_SCALAR >* >();
598 _retrieveInputs_(c, chains, inputs);
600 switch (agg->agg_type()) {
601 case PRMAggregate< GUM_SCALAR >::AggregateType::OR:
602 case PRMAggregate< GUM_SCALAR >::AggregateType::AND: {
603 if (inputs.front()->type() != *(_retrieveType_(
"boolean"))) {
604 GUM_ERROR(TypeError,
"expected booleans")
610 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT:
611 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
612 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
613 if (!agg->hasLabel()) {
614 auto param = agg->labelValue();
// Linear search for the label among the input type's labels.
617 while (label_idx < inputs.front()->type()->domainSize()) {
618 if (inputs.front()->type()->label(label_idx) == param) {
break; }
623 if (label_idx == inputs.front()->type()->domainSize()) {
624 GUM_ERROR(NotFound,
"could not find label")
627 agg->setLabel(label_idx);
633 case PRMAggregate< GUM_SCALAR >::AggregateType::SUM:
634 case PRMAggregate< GUM_SCALAR >::AggregateType::MEDIAN:
635 case PRMAggregate< GUM_SCALAR >::AggregateType::AMPLITUDE:
636 case PRMAggregate< GUM_SCALAR >::AggregateType::MIN:
637 case PRMAggregate< GUM_SCALAR >::AggregateType::MAX: {
642 GUM_ERROR(FatalError,
"Unknown aggregator.")
646 c->addArc(inputs.front()->safeName(), agg->safeName());
// endAggregator: closes the aggregate on top of the stack; _checkStack_ is used to require an
// aggregate there.
// NOTE(review): no pop of _stack_ is visible in this excerpt. Confirm it happens in the
// omitted lines.
649 template <
typename GUM_SCALAR >
650 INLINE
void PRMFactory< GUM_SCALAR >::endAggregator() {
651 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_aggregate);
// addAggregator: adds to the class on top of the stack an aggregate `name` of kind `agg_type`.
// Its inputs are `chains`, which _retrieveInputs_ resolves (building slot chains if needed).
// Preconditions, each raising the exception noted:
//   - at least one chain (OperationNotAllowed);
//   - all inputs share the same type (TypeError);
//   - OR/AND: boolean inputs (TypeError) and no parameter (OperationNotAllowed); the
//     aggregate takes the input type;
//   - EXISTS/FORALL: exactly one parameter (OperationNotAllowed) that is a label of the input
//     type (NotFound); the aggregate is boolean;
//   - SUM/MEDIAN/AMPLITUDE/MIN/MAX: no parameter (OperationNotAllowed); the aggregate type
//     is _retrieveType_(type);
//   - COUNT: exactly one parameter that is a label of the input type; the aggregate type is
//     _retrieveType_(type);
//   - any other kind raises FatalError.
// The result is added to the class, overloading an existing element on DuplicateElement. On a
// path not fully visible here (presumably depending on `hasSC`), a PRMScalarAttribute built
// from agg->buildImpl() is added instead. Arcs are then added from every input to the new
// element.
// `type` (the output type name) is a parameter declared on a line omitted from this excerpt.
// NOTE(review): the OR/AND error message misspells "paramaters".
655 template <
typename GUM_SCALAR >
656 INLINE
void PRMFactory< GUM_SCALAR >::addAggregator(
const std::string& name,
657 const std::string& agg_type,
658 const std::vector< std::string >& chains,
659 const std::vector< std::string >& params,
661 PRMClass< GUM_SCALAR >* c
662 =
static_cast< PRMClass< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::CLASS));
665 if (chains.size() == 0) {
666 GUM_ERROR(OperationNotAllowed,
"a PRMAggregate<GUM_SCALAR> requires at least one parent")
670 std::vector< PRMClassElement< GUM_SCALAR >* > inputs;
675 bool hasSC = _retrieveInputs_(c, chains, inputs);
// All inputs must have exactly the same type (pairwise check of neighbours).
680 if (inputs.size() > 1) {
681 for (
auto iter = inputs.begin() + 1; iter != inputs.end(); ++iter) {
682 if ((**(iter - 1)).type() != (**iter).type()) {
683 GUM_ERROR(TypeError,
"found different types")
689 PRMAggregate< GUM_SCALAR >* agg =
nullptr;
691 switch (PRMAggregate< GUM_SCALAR >::str2enum(agg_type)) {
692 case PRMAggregate< GUM_SCALAR >::AggregateType::OR:
693 case PRMAggregate< GUM_SCALAR >::AggregateType::AND: {
694 if (inputs.front()->type() != *(_retrieveType_(
"boolean"))) {
695 GUM_ERROR(TypeError,
"expected booleans")
697 if (params.size() != 0) { GUM_ERROR(OperationNotAllowed,
"invalid number of paramaters") }
699 agg =
new PRMAggregate< GUM_SCALAR >(name,
700 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
701 inputs.front()->type());
706 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
707 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
708 if (params.size() != 1) { GUM_ERROR(OperationNotAllowed,
"invalid number of parameters") }
712 while (label_idx < inputs.front()->type()->domainSize()) {
713 if (inputs.front()->type()->label(label_idx) == params.front()) {
break; }
718 if (label_idx == inputs.front()->type()->domainSize()) {
719 GUM_ERROR(NotFound,
"could not find label")
723 agg =
new PRMAggregate< GUM_SCALAR >(name,
724 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
725 *(_retrieveType_(
"boolean")),
732 case PRMAggregate< GUM_SCALAR >::AggregateType::SUM:
733 case PRMAggregate< GUM_SCALAR >::AggregateType::MEDIAN:
734 case PRMAggregate< GUM_SCALAR >::AggregateType::AMPLITUDE:
735 case PRMAggregate< GUM_SCALAR >::AggregateType::MIN:
736 case PRMAggregate< GUM_SCALAR >::AggregateType::MAX: {
737 if (params.size() != 0) { GUM_ERROR(OperationNotAllowed,
"invalid number of parameters") }
739 auto output_type = _retrieveType_(type);
742 agg =
new PRMAggregate< GUM_SCALAR >(name,
743 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
749 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT: {
750 if (params.size() != 1) { GUM_ERROR(OperationNotAllowed,
"invalid number of parameters") }
754 while (label_idx < inputs.front()->type()->domainSize()) {
755 if (inputs.front()->type()->label(label_idx) == params.front()) {
break; }
760 if (label_idx == inputs.front()->type()->domainSize()) {
761 GUM_ERROR(NotFound,
"could not find label")
764 auto output_type = _retrieveType_(type);
767 agg =
new PRMAggregate< GUM_SCALAR >(name,
768 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
776 GUM_ERROR(FatalError,
"Unknown aggregator.")
780 std::string safe_name = agg->safeName();
786 }
catch (DuplicateElement&) { c->overload(agg); }
790 =
new PRMScalarAttribute< GUM_SCALAR >(agg->name(), agg->type(), agg->buildImpl());
794 }
catch (DuplicateElement&) { c->overload(attr); }
798 }
catch (DuplicateElement&) {
// Link every input to the element registered under safe_name.
803 for (
const auto& elt: inputs) {
804 c->addArc(elt->safeName(), safe_name);
// addReferenceSlot: adds a reference slot `name` to the container on top of the stack.
// `type` is resolved first as a class, then as an interface; NotFound is raised if it is
// neither. `isArray` (declared on an omitted line) is forwarded to the PRMReferenceSlot
// constructor. On DuplicateElement the slot overloads the existing element.
808 template <
typename GUM_SCALAR >
809 INLINE
void PRMFactory< GUM_SCALAR >::addReferenceSlot(
const std::string& type,
810 const std::string& name,
812 PRMClassElementContainer< GUM_SCALAR >* owner = _checkStackContainter_(1);
813 PRMClassElementContainer< GUM_SCALAR >* slotType = 0;
816 slotType = _retrieveClass_(type);
817 }
catch (NotFound&) {
// Not a class: fall back to an interface of that name.
819 slotType = _retrieveInterface_(type);
820 }
catch (NotFound&) { GUM_ERROR(NotFound,
"unknown ReferenceSlot<GUM_SCALAR> slot type") }
823 PRMReferenceSlot< GUM_SCALAR >* ref
824 =
new PRMReferenceSlot< GUM_SCALAR >(name, *slotType, isArray);
828 }
catch (DuplicateElement&) { owner->overload(ref); }
// addArray: adds to the system on top of the stack an array `name` of class `type`. It then
// creates `size` instances named "name[0]" ... "name[size-1]" and adds each to that array.
// `size` is declared on an omitted line, and the bodies of the PRMTypeError / NotFound
// handlers are not visible in this excerpt.
831 template <
typename GUM_SCALAR >
832 INLINE
void PRMFactory< GUM_SCALAR >::addArray(
const std::string& type,
833 const std::string& name,
835 PRMSystem< GUM_SCALAR >* model
836 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
837 PRMClass< GUM_SCALAR >* c = _retrieveClass_(type);
838 PRMInstance< GUM_SCALAR >* inst = 0;
841 model->addArray(name, *c);
843 for (Size i = 0; i < size; ++i) {
844 std::stringstream elt_name;
845 elt_name << name <<
"[" << i <<
"]";
846 inst =
new PRMInstance< GUM_SCALAR >(elt_name.str(), *c);
847 model->add(name, inst);
849 }
catch (PRMTypeError&) {
852 }
catch (NotFound&) {
858 template <
typename GUM_SCALAR >
859 INLINE
void PRMFactory< GUM_SCALAR >::incArray(
const std::string& l_i,
const std::string& r_i) {
860 PRMSystem< GUM_SCALAR >* model
861 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
863 if (model->isArray(l_i)) {
864 if (model->isInstance(r_i)) {
865 model->add(l_i, model->get(r_i));
867 GUM_ERROR(NotFound,
"right value is not an instance")
870 GUM_ERROR(NotFound,
"left value is no an array")
874 template <
typename GUM_SCALAR >
875 INLINE
void PRMFactory< GUM_SCALAR >::setReferenceSlot(
const std::string& l_i,
876 const std::string& l_ref,
877 const std::string& r_i) {
879 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
880 std::vector< PRMInstance< GUM_SCALAR >* > lefts;
881 std::vector< PRMInstance< GUM_SCALAR >* > rights;
883 if (model->isInstance(l_i)) {
884 lefts.push_back(&(model->get(l_i)));
885 }
else if (model->isArray(l_i)) {
886 for (
const auto& elt: model->getArray(l_i))
887 lefts.push_back(elt);
889 GUM_ERROR(NotFound,
"left value does not name an instance or an array")
892 if (model->isInstance(r_i)) {
893 rights.push_back(&(model->get(r_i)));
894 }
else if (model->isArray(r_i)) {
895 for (
const auto& elt: model->getArray(r_i))
896 rights.push_back(elt);
898 GUM_ERROR(NotFound,
"left value does not name an instance or an array")
901 for (
const auto l: lefts) {
902 for (
const auto r: rights) {
903 auto& elt = l->type().get(l_ref);
904 if (PRMClassElement< GUM_SCALAR >::isReferenceSlot(elt)) {
905 l->add(elt.id(), *r);
908 GUM_ERROR(NotFound,
"unfound reference slot")
// _buildSlotChain_: builds a slot chain for the dotted path `name`, split by decomposePath,
// starting at container `start`.
//   - A reference slot moves `current` to its slot type.
//   - An attribute or aggregate is appended to the chain when it is the last path component.
// The last element is marked as an output node of its container, and a newly allocated
// PRMSlotChain is returned. Some callers in this file pass it to c->add.
// Returns nullptr when a path component is not found (NotFound).
914 template <
typename GUM_SCALAR >
915 INLINE PRMSlotChain< GUM_SCALAR >*
916 PRMFactory< GUM_SCALAR >::_buildSlotChain_(PRMClassElementContainer< GUM_SCALAR >* start,
917 const std::string& name) {
918 std::vector< std::string > v;
919 decomposePath(name, v);
920 PRMClassElementContainer< GUM_SCALAR >* current = start;
921 PRMReferenceSlot< GUM_SCALAR >* ref =
nullptr;
922 Sequence< PRMClassElement< GUM_SCALAR >* > elts;
924 for (size_t i = 0; i < v.size(); ++i) {
926 switch (current->get(v[i]).elt_type()) {
927 case PRMClassElement< GUM_SCALAR >::prm_refslot:
928 ref = &(
static_cast< PRMReferenceSlot< GUM_SCALAR >& >(current->get(v[i])));
// Follow the reference: the next component is looked up in the referenced type.
930 current = &( (ref->slotType()));
933 case PRMClassElement< GUM_SCALAR >::prm_aggregate:
934 case PRMClassElement< GUM_SCALAR >::prm_attribute:
936 if (i == v.size() - 1) {
937 elts.insert(&(current->get(v[i])));
947 }
catch (NotFound&) {
return nullptr; }
950 GUM_ASSERT(v.size() == elts.size());
952 current->setOutputNode(*(elts.back()),
true);
954 return new PRMSlotChain< GUM_SCALAR >(name, elts);
// _retrieveInputs_: appends to `inputs` the element of class `c` named by each entry of
// `chains`. Names that are not elements of `c` are built as slot chains and added to `c`;
// a failed build raises NotFound "unknown slot chain" (presumably when _buildSlotChain_
// returns nullptr).
// The returned `retVal` is set when an existing input is a slot chain.
// Each input whose type differs from the common type (_retrieveCommonType_) is then paired,
// in `toAdd`, with a cast element. That element is named "(T)name", or for slot chains is
// built from the chain's element names followed by ".(T)" and the last element's name. It is
// looked up in `c` or built as a slot chain. What is done with `toAdd` afterwards is on
// omitted lines.
// NOTE(review): the loop already ends each component with '.', and ".(" adds another, giving
// "a..(T)b". Confirm this matches the cast-chain naming that decomposePath expects.
957 template <
typename GUM_SCALAR >
958 INLINE
bool PRMFactory< GUM_SCALAR >::_retrieveInputs_(
959 PRMClass< GUM_SCALAR >* c,
960 const std::vector< std::string >& chains,
961 std::vector< PRMClassElement< GUM_SCALAR >* >& inputs) {
964 for (size_t i = 0; i < chains.size(); ++i) {
966 inputs.push_back(&(c->get(chains[i])));
967 retVal = retVal || PRMClassElement< GUM_SCALAR >::isSlotChain(*(inputs.back()));
968 }
catch (NotFound&) {
// Not an existing element: try to build it as a slot chain.
969 inputs.push_back(_buildSlotChain_(c, chains[i]));
973 c->add(inputs.back());
975 GUM_ERROR(NotFound,
"unknown slot chain")
980 PRMType* t = _retrieveCommonType_(inputs);
982 std::vector< std::pair< PRMClassElement< GUM_SCALAR >*, PRMClassElement< GUM_SCALAR >* > >
985 for (
const auto& elt: inputs) {
986 if ((*elt).type() != (*t)) {
987 if (PRMClassElement< GUM_SCALAR >::isSlotChain(*elt)) {
988 PRMSlotChain< GUM_SCALAR >* sc =
static_cast< PRMSlotChain< GUM_SCALAR >* >(elt);
989 std::stringstream name;
991 for (Size idx = 0; idx < sc->chain().size() - 1; ++idx) {
992 name << sc->chain().atPos(idx)->name() <<
".";
995 name <<
".(" << t->name() <<
")" << sc->lastElt().name();
998 toAdd.push_back(std::make_pair(elt, &(c->get(name.str()))));
999 }
catch (NotFound&) {
1000 toAdd.push_back(std::make_pair(elt, _buildSlotChain_(c, name.str())));
1003 std::stringstream name;
1004 name <<
"(" << t->name() <<
")" << elt->name();
1005 toAdd.push_back(std::make_pair(elt, &(c->get(name.str()))));
// _retrieveCommonType_: returns the deepest type shared by all `elts`.
// For each element, its type and all of its super types are counted by name. Among the names
// counted once per element (count == elts.size()), the one with the greatest _typeDepth_ is
// returned.
// Throws WrongClassElement when an element has no type (OperationNotAllowed from type()), and
// NotFound when no common type exists.
// `max_depth` is declared on an omitted line.
1013 template <
typename GUM_SCALAR >
1014 INLINE PRMType* PRMFactory< GUM_SCALAR >::_retrieveCommonType_(
1015 const std::vector< PRMClassElement< GUM_SCALAR >* >& elts) {
1016 const PRMType* current =
nullptr;
1017 HashTable< std::string, Size > counters;
1020 for (
const auto& elt: elts) {
1022 current = &((*elt).type());
// Walk up the super-type hierarchy, counting every type seen.
1024 while (current != 0) {
1026 if (counters.exists(current->name())) {
1027 ++(counters[current->name()]);
1029 counters.insert(current->name(), 1);
1033 if (current->isSubType()) {
1034 current = &(current->superType());
1039 }
catch (OperationNotAllowed&) {
1040 GUM_ERROR(WrongClassElement,
"found a ClassElement<GUM_SCALAR> without a type")
// Keep the deepest type that every element shares.
1049 int current_depth = 0;
1051 for (
const auto& elt: counters) {
1052 if ((elt.second) == elts.size()) {
1053 current_depth = _typeDepth_(_retrieveType_(elt.first));
1055 if (current_depth > max_depth) {
1056 max_depth = current_depth;
1057 current = _retrieveType_(elt.first);
1062 if (current) {
return const_cast< PRMType* >(current); }
1064 GUM_ERROR(NotFound,
"could not find a common type")
// addNoisyOrCompound: adds a noisy-or attribute `name` with parents `chains` to the current
// class. FactoryInvalidState is raised if the current object is not a class.
// Parents whose type differs from their common type are replaced by the element named
// parent->cast(common_type). That element is built as a slot chain when the parent is a slot
// chain; otherwise NotFound is raised.
//   - one number: a MultiDimNoisyORCompound(leak, numbers.front()) wrapped in a boolean
//     PRMScalarAttribute;
//   - one number per parent: a MultiDimNoisyORCompound(leak) in a boolean PRMFuncAttribute,
//     with causalWeight set for each parent's variable;
//   - otherwise OperationNotAllowed.
// `leak` is declared on an omitted line. A non-empty `labels` raises OperationNotAllowed.
// NOTE(review): the `labels` check comes after the attribute and its implementation were
// allocated. Checking it first would avoid leaking or half-registering them; confirm against
// the omitted lines.
1067 template <
typename GUM_SCALAR >
1069 PRMFactory< GUM_SCALAR >::addNoisyOrCompound(
const std::string& name,
1070 const std::vector< std::string >& chains,
1071 const std::vector<
float >& numbers,
1073 const std::vector< std::string >& labels) {
1074 if (currentType() != PRMObject::prm_type::CLASS) {
1075 GUM_ERROR(gum::FactoryInvalidState,
"invalid state to add a noisy-or")
1078 PRMClass< GUM_SCALAR >* c =
dynamic_cast< gum::prm::PRMClass< GUM_SCALAR >* >(getCurrent());
1080 std::vector< PRMClassElement< GUM_SCALAR >* > parents;
1082 for (
const auto& elt: chains)
1083 parents.push_back(&(c->get(elt)));
1085 PRMType* common_type = _retrieveCommonType_(parents);
// Replace every parent whose type is not the common type by its cast version.
1087 for (size_t idx = 0; idx < parents.size(); ++idx) {
1088 if (parents[idx]->type() != (*common_type)) {
1089 PRMClassElement< GUM_SCALAR >* parent = parents[idx];
1092 std::string safe_name = parent->cast(*common_type);
1094 if (!c->exists(safe_name)) {
1095 if (PRMClassElement< GUM_SCALAR >::isSlotChain(*parent)) {
1096 parents[idx] = _buildSlotChain_(c, safe_name);
1097 c->add(parents[idx]);
1099 GUM_ERROR(NotFound,
"unable to find parent")
1102 parents[idx] = &(c->get(safe_name));
1107 if (numbers.size() == 1) {
1108 auto impl =
new gum::MultiDimNoisyORCompound< GUM_SCALAR >(leak, numbers.front());
1109 auto attr =
new PRMScalarAttribute< GUM_SCALAR >(name, retrieveType(
"boolean"), impl);
1111 }
else if (numbers.size() == parents.size()) {
1112 gum::MultiDimNoisyORCompound< GUM_SCALAR >* noisy
1113 =
new gum::MultiDimNoisyORCompound< GUM_SCALAR >(leak);
1114 gum::prm::PRMFuncAttribute< GUM_SCALAR >* attr
1115 =
new gum::prm::PRMFuncAttribute< GUM_SCALAR >(name, retrieveType(
"boolean"), noisy);
1117 for (size_t idx = 0; idx < numbers.size(); ++idx) {
1118 noisy->causalWeight(parents[idx]->type().variable(), numbers[idx]);
1123 GUM_ERROR(OperationNotAllowed,
"invalid parameters for a noisy or")
1126 if (!labels.empty()) {
1127 GUM_ERROR(OperationNotAllowed,
"labels definitions not handle for noisy-or")
// _retrieveType_: resolves the type called `name`. It tries, in order:
//   1. the name as given;
//   2. the name prefixed with the current package (_addPrefix_);
//   3. the name relative to the parent of currentPackage();
//   4. the name inside each namespace of the list at _namespaces_.back().
// Throws DuplicateElement when candidates resolve to different full names ("ambiguous"), and
// NotFound when nothing matches.
// NOTE(review): as shown, each `else if (full_name != ...)` pairs with an `exists` test, so it
// would fire whenever a candidate name is absent. _retrieveInterface_ pairs it with a
// `nullptr` guard instead. Lines appear to be missing here; confirm against the full file.
1131 template <
typename GUM_SCALAR >
1132 INLINE PRMType* PRMFactory< GUM_SCALAR >::_retrieveType_(
const std::string& name)
const {
1133 PRMType* type =
nullptr;
1134 std::string full_name;
// 1. Exact name.
1137 if (_prm_->_typeMap_.exists(name)) {
1138 type = _prm_->_typeMap_[name];
// 2. Name inside the current package.
1143 std::string prefixed = _addPrefix_(name);
1144 if (_prm_->_typeMap_.exists(prefixed)) {
1146 type = _prm_->_typeMap_[prefixed];
1147 full_name = prefixed;
1148 }
else if (full_name != prefixed) {
1149 GUM_ERROR(DuplicateElement,
"Type name '" << name <<
"' is ambiguous: specify full name.")
// 3. Name inside the parent of the current package.
1154 std::string relatif_ns = currentPackage();
1155 size_t last_dot = relatif_ns.find_last_of(
'.');
1156 if (last_dot != std::string::npos) {
1157 relatif_ns = relatif_ns.substr(0, last_dot) +
'.' + name;
1158 if (_prm_->_typeMap_.exists(relatif_ns)) {
1160 type = _prm_->_typeMap_[relatif_ns];
1161 full_name = relatif_ns;
1162 }
else if (full_name != relatif_ns) {
1163 GUM_ERROR(DuplicateElement,
1164 "Type name '" << name <<
"' is ambiguous: specify full name.");
// 4. Name inside each imported namespace.
1171 if (!_namespaces_.empty()) {
1172 auto ns_list = _namespaces_.back();
1173 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1174 std::string ns = (*ns_list)[i];
1175 std::string ns_name = ns +
"." + name;
1176 if (_prm_->_typeMap_.exists(ns_name)) {
1178 type = _prm_->_typeMap_[ns_name];
1179 full_name = ns_name;
1180 }
else if (full_name != ns_name) {
1181 GUM_ERROR(DuplicateElement,
1182 "Type name '" << name <<
"' is ambiguous: specify full name.");
1188 if (type == 0) { GUM_ERROR(NotFound,
"Type '" << name <<
"' not found, check imports.") }
// _retrieveClass_: resolves the class called `name`. It tries, in order:
//   1. the name as given;
//   2. the name prefixed with the current package (_addPrefix_);
//   3. the name inside each namespace of the list at _namespaces_.back().
// Throws DuplicateElement when candidates resolve to different full names ("ambiguous"), and
// NotFound when nothing matches.
// NOTE(review): in the namespace loop, as shown, the `else if (full_name != ns_name)` pairs
// with the `exists` test rather than with an `a_class == nullptr` guard (compare the prefixed
// lookup just above). Lines appear to be missing; confirm against the full file.
1193 template <
typename GUM_SCALAR >
1194 PRMClass< GUM_SCALAR >*
1195 PRMFactory< GUM_SCALAR >::_retrieveClass_(
const std::string& name)
const {
1196 PRMClass< GUM_SCALAR >* a_class =
nullptr;
1197 std::string full_name;
// 1. Exact name.
1200 if (_prm_->_classMap_.exists(name)) {
1201 a_class = _prm_->_classMap_[name];
// 2. Name inside the current package.
1206 std::string prefixed = _addPrefix_(name);
1207 if (_prm_->_classMap_.exists(prefixed)) {
1208 if (a_class ==
nullptr) {
1209 a_class = _prm_->_classMap_[prefixed];
1210 full_name = prefixed;
1211 }
else if (full_name != prefixed) {
1212 GUM_ERROR(DuplicateElement,
1213 "Class name '" << name <<
"' is ambiguous: specify full name.");
// 3. Name inside each imported namespace.
1218 if (!_namespaces_.empty()) {
1219 auto ns_list = _namespaces_.back();
1220 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1221 std::string ns = (*ns_list)[i];
1222 std::string ns_name = ns +
"." + name;
1223 if (_prm_->_classMap_.exists(ns_name)) {
1225 a_class = _prm_->_classMap_[ns_name];
1226 full_name = ns_name;
1227 }
else if (full_name != ns_name) {
1228 GUM_ERROR(DuplicateElement,
1229 "Class name '" << name <<
"' is ambiguous: specify full name.");
1235 if (a_class == 0) { GUM_ERROR(NotFound,
"Class '" << name <<
"' not found, check imports.") }
// _retrieveInterface_: resolves the interface called `name`. It tries, in order:
//   1. the name as given;
//   2. the name prefixed with the current package (_addPrefix_);
//   3. the name inside each namespace of the list at _namespaces_.back().
// A later candidate is taken only if none was found yet; a different full name raises
// DuplicateElement ("ambiguous"). Throws NotFound when nothing matches.
// NOTE(review): the local variable is called `interface`. Windows SDK headers (combaseapi.h,
// pulled in by windows.h) #define `interface`, which would break this function in
// translation units that include them; consider renaming it.
1240 template <
typename GUM_SCALAR >
1241 PRMInterface< GUM_SCALAR >*
1242 PRMFactory< GUM_SCALAR >::_retrieveInterface_(
const std::string& name)
const {
1243 PRMInterface< GUM_SCALAR >* interface =
nullptr;
1244 std::string full_name;
// 1. Exact name.
1247 if (_prm_->_interfaceMap_.exists(name)) {
1248 interface = _prm_->_interfaceMap_[name];
// 2. Name inside the current package.
1253 std::string prefixed = _addPrefix_(name);
1254 if (_prm_->_interfaceMap_.exists(prefixed)) {
1255 if (interface ==
nullptr) {
1256 interface = _prm_->_interfaceMap_[prefixed];
1257 full_name = prefixed;
1258 }
else if (full_name != prefixed) {
1259 GUM_ERROR(DuplicateElement,
1260 "Interface name '" << name <<
"' is ambiguous: specify full name.");
// 3. Name inside each imported namespace.
1265 if (!_namespaces_.empty()) {
1266 auto ns_list = _namespaces_.back();
1268 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1269 std::string ns = (*ns_list)[i];
1270 std::string ns_name = ns +
"." + name;
1272 if (_prm_->_interfaceMap_.exists(ns_name)) {
1273 if (interface ==
nullptr) {
1274 interface = _prm_->_interfaceMap_[ns_name];
1275 full_name = ns_name;
1276 }
else if (full_name != ns_name) {
1277 GUM_ERROR(DuplicateElement,
1278 "Interface name '" << name <<
"' is ambiguous: specify full name.");
1284 if (interface ==
nullptr) {
1285 GUM_ERROR(NotFound,
"Interface '" << name <<
"' not found, check imports.")
// Default constructor: registers the object with GUM_CONSTRUCTOR and
// allocates a fresh, empty PRM< GUM_SCALAR > for the factory to populate.
1291 template <
typename GUM_SCALAR >
1292 INLINE PRMFactory< GUM_SCALAR >::PRMFactory() {
1293 GUM_CONSTRUCTOR(PRMFactory);
1294 _prm_ =
new PRM< GUM_SCALAR >();
// Constructs a factory that builds into an existing PRM `prm`.
// The pointer is stored as-is (no copy is made).
1297 template <
typename GUM_SCALAR >
1298 INLINE PRMFactory< GUM_SCALAR >::PRMFactory(PRM< GUM_SCALAR >* prm) :
1299 IPRMFactory(), _prm_(prm) {
1300 GUM_CONSTRUCTOR(PRMFactory);
// Destructor: drains the stack of namespace (import) lists.
// NOTE(review): the statements that release each popped list, and the rest
// of the body, are not visible in this listing (presumably `delete ns;`).
1303 template <
typename GUM_SCALAR >
1304 INLINE PRMFactory< GUM_SCALAR >::~PRMFactory() {
1305 GUM_DESTRUCTOR(PRMFactory);
1306 while (!_namespaces_.empty()) {
1307 auto ns = _namespaces_.back();
1308 _namespaces_.pop_back();
// Accessor for the PRM being built by this factory.
// NOTE(review): the body is not visible in this listing.
1313 template <
typename GUM_SCALAR >
1314 INLINE PRM< GUM_SCALAR >* PRMFactory< GUM_SCALAR >::prm()
const {
// Returns the prm_type of the object on top of the construction stack.
// Throws NotFound when no object is being built (empty stack).
1318 template <
typename GUM_SCALAR >
1319 INLINE PRMObject::prm_type PRMFactory< GUM_SCALAR >::currentType()
const {
1320 if (_stack_.size() == 0) { GUM_ERROR(NotFound,
"no object being built") }
1322 return _stack_.back()->obj_type();
// Returns the object on top of the construction stack (mutable access).
// Throws NotFound when the stack is empty.
1325 template <
typename GUM_SCALAR >
1326 INLINE PRMObject* PRMFactory< GUM_SCALAR >::getCurrent() {
1327 if (_stack_.size() == 0) { GUM_ERROR(NotFound,
"no object being built") }
1329 return _stack_.back();
// Const overload: returns the object on top of the construction stack.
// Throws NotFound when the stack is empty.
1332 template <
typename GUM_SCALAR >
1333 INLINE
const PRMObject* PRMFactory< GUM_SCALAR >::getCurrent()
const {
1334 if (_stack_.size() == 0) { GUM_ERROR(NotFound,
"no object being built") }
1336 return _stack_.back();
// Closes the object currently being built: when the stack is non-empty,
// it reads the top element.
// NOTE(review): the pop and the return statements are not visible in this
// listing; presumably it pops and returns `obj`, or nullptr when empty.
1339 template <
typename GUM_SCALAR >
1340 INLINE PRMObject* PRMFactory< GUM_SCALAR >::closeCurrent() {
1341 if (_stack_.size() > 0) {
1342 PRMObject* obj = _stack_.back();
// Returns the innermost package name, or "" when no package has been pushed.
1350 template <
typename GUM_SCALAR >
1351 INLINE std::string PRMFactory< GUM_SCALAR >::currentPackage()
const {
1352 return (_packages_.empty()) ?
"" : _packages_.back();
// Begins the declaration of a labelized (discrete) type named `name`,
// qualified with the current package. Throws DuplicateElement if a type with
// that full name already exists. A new PRMType wrapping an empty
// LabelizedVariable is pushed on the construction stack. In the subtype
// branch, the super type is resolved with _retrieveType_(super) and a label
// map (vector< Idx >) is allocated so that labels can be mapped onto the
// super type's labels.
// NOTE(review): the branch condition that selects between the two paths
// (presumably `super == ""`) is not visible in this listing.
1355 template <
typename GUM_SCALAR >
1356 INLINE
void PRMFactory< GUM_SCALAR >::startDiscreteType(
const std::string& name,
1357 std::string super) {
1358 std::string real_name = _addPrefix_(name);
1359 if (_prm_->_typeMap_.exists(real_name)) {
1360 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.")
// Plain (non-sub) type: labels are added later via addLabel().
1363 auto t =
new PRMType(LabelizedVariable(real_name,
"", 0));
1364 _stack_.push_back(t);
// Subtype: remember the super type and how our labels map onto it.
1366 auto t =
new PRMType(LabelizedVariable(real_name,
"", 0));
1367 t->_superType_ = _retrieveType_(super);
1368 t->_label_map_ =
new std::vector< Idx >();
1369 _stack_.push_back(t);
// Adds label `l` to the discrete type on top of the stack.
// - extends == "": the current type must not be a subtype
//   (OperationNotAllowed otherwise).
// - extends != "": the current type must be a subtype, and `extends` must
//   name a label of its super type. The index of that super label is
//   appended to _label_map_. NotFound is raised if no super label matches.
// In both paths, a DuplicateElement raised while adding the label is
// re-raised with a message naming `l`. A FatalError is raised if the
// type's variable is not a LabelizedVariable.
// NOTE(review): this listing has dropped lines: the `if (!var)` guards, the
// `try` blocks that add the label, and the `found` bookkeeping.
1373 template <
typename GUM_SCALAR >
1374 INLINE
void PRMFactory< GUM_SCALAR >::addLabel(
const std::string& l, std::string extends) {
// Label of a root (non-sub) type.
1375 if (extends ==
"") {
1376 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1377 LabelizedVariable* var =
dynamic_cast< LabelizedVariable* >(t->_var_);
1380 GUM_ERROR(FatalError,
"the current type's variable is not a LabelizedVariable.")
1381 }
else if (t->_superType_) {
1382 GUM_ERROR(OperationNotAllowed,
"current type is a subtype.")
1387 }
catch (DuplicateElement&) {
1388 GUM_ERROR(DuplicateElement,
"a label '" << l <<
"' already exists")
// Label of a subtype: must extend an existing label of the super type.
1391 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1392 LabelizedVariable* var =
dynamic_cast< LabelizedVariable* >(t->_var_);
1395 GUM_ERROR(FatalError,
"the current type's variable is not a LabelizedVariable.")
1396 }
else if (!t->_superType_) {
1397 GUM_ERROR(OperationNotAllowed,
"current type is not a subtype.")
// Linear scan of the super type's labels for `extends`.
1402 for (Idx i = 0; i < t->_superType_->_var_->domainSize(); ++i) {
1403 if (t->_superType_->_var_->label(i) == extends) {
1406 }
catch (DuplicateElement&) {
1407 GUM_ERROR(DuplicateElement,
"a label '" << l <<
"' already exists")
// Record that the new label maps onto super label i.
1410 t->_label_map_->push_back(i);
1417 if (!found) { GUM_ERROR(NotFound,
"inexistent label in super type.") }
// Finishes the discrete type on top of the stack. The type must pass
// _isValid_() and have a domain of at least two labels (OperationNotAllowed
// otherwise). It is then registered by name in _typeMap_ and in _types_.
// NOTE(review): the statement that pops the stack is not visible in this listing.
1421 template <
typename GUM_SCALAR >
1422 INLINE
void PRMFactory< GUM_SCALAR >::endDiscreteType() {
1423 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1425 if (!t->_isValid_()) {
1426 GUM_ERROR(OperationNotAllowed,
"current type is not a valid subtype")
1427 }
else if (t->variable().domainSize() < 2) {
1428 GUM_ERROR(OperationNotAllowed,
"current type is not a valid discrete type")
1431 _prm_->_typeMap_.insert(t->name(), t);
1433 _prm_->_types_.insert(t);
// Begins the declaration of a discretized type named `name`, qualified with
// the current package. Throws DuplicateElement if the full name is already
// a type. A PRMType wrapping an empty DiscretizedVariable<double> is pushed
// on the stack; ticks are added later with addTick().
1437 template <
typename GUM_SCALAR >
1438 INLINE
void PRMFactory< GUM_SCALAR >::startDiscretizedType(
const std::string& name) {
1439 std::string real_name = _addPrefix_(name);
1440 if (_prm_->_typeMap_.exists(real_name)) {
1441 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.")
1443 auto var = DiscretizedVariable<
double >(real_name,
"");
1444 auto t =
new PRMType(var);
1445 _stack_.push_back(t);
// Adds `tick` to the discretized type on top of the stack. A FatalError is
// raised if the type's variable is not a DiscretizedVariable<double>. A
// DefaultInLabel raised while adding the tick is turned into
// OperationNotAllowed.
// NOTE(review): the FatalError text says "LabelizedVariable", but the cast
// checks for DiscretizedVariable<double>. The message looks copy-pasted from
// addLabel. "already in used" is also a typo.
// NOTE(review): the `try { ... addTick ... }` lines are not visible in this listing.
1448 template <
typename GUM_SCALAR >
1449 INLINE
void PRMFactory< GUM_SCALAR >::addTick(
double tick) {
1450 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1451 DiscretizedVariable<
double >* var =
dynamic_cast< DiscretizedVariable<
double >* >(t->_var_);
1453 if (!var) { GUM_ERROR(FatalError,
"the current type's variable is not a LabelizedVariable.") }
1457 }
catch (DefaultInLabel&) {
1458 GUM_ERROR(OperationNotAllowed,
"tick already in used for this variable")
// Finishes the discretized type on top of the stack. It must have a domain
// of at least two intervals (OperationNotAllowed otherwise). It is then
// registered by name in _typeMap_ and in _types_.
// NOTE(review): the pop of the stack is not visible in this listing.
1462 template <
typename GUM_SCALAR >
1463 INLINE
void PRMFactory< GUM_SCALAR >::endDiscretizedType() {
1464 PRMType* t =
static_cast< PRMType* >(_checkStack_(1, PRMObject::prm_type::TYPE));
1466 if (t->variable().domainSize() < 2) {
1467 GUM_ERROR(OperationNotAllowed,
"current type is not a valid discrete type")
1470 _prm_->_typeMap_.insert(t->name(), t);
1472 _prm_->_types_.insert(t);
// Declares, in one call, an integer range type over [minVal, maxVal] named
// `name` (package-qualified). Throws DuplicateElement if the full name is
// taken, and OperationNotAllowed if the range yields fewer than two values.
// On success the type is registered in _typeMap_ and _types_. Nothing is
// pushed on the construction stack here.
// NOTE(review): the duplicate message opens with a double quote (\") but
// closes with a single quote ('). The sibling start*Type functions use '...'
// on both sides.
// NOTE(review): if the domain-size check throws, `t` has already been
// allocated and does not appear to be released (no visible delete).
1476 template <
typename GUM_SCALAR >
1478 PRMFactory< GUM_SCALAR >::addRangeType(
const std::string& name,
long minVal,
long maxVal) {
1479 std::string real_name = _addPrefix_(name);
1480 if (_prm_->_typeMap_.exists(real_name)) {
1481 std::stringstream msg;
1482 msg <<
"\"" << real_name <<
"' is already used.";
1483 GUM_ERROR(DuplicateElement, msg.str())
1486 auto var = RangeVariable(real_name,
"", minVal, maxVal);
1487 auto t =
new PRMType(var);
1489 if (t->variable().domainSize() < 2) {
1490 GUM_ERROR(OperationNotAllowed,
"current type is not a valid discrete type")
1493 _prm_->_typeMap_.insert(t->name(), t);
1494 _prm_->_types_.insert(t);
// Ends the interface declaration currently on top of the stack. Raises
// FactoryInvalidState (via _checkStack_) if the top object is not a PRM_INTERFACE.
// NOTE(review): the remaining statements (presumably the pop) are not visible.
1497 template <
typename GUM_SCALAR >
1498 INLINE
void PRMFactory< GUM_SCALAR >::endInterface() {
1499 _checkStack_(1, PRMObject::prm_type::PRM_INTERFACE);
// Declares an attribute of type `type` named `name` inside the interface on
// top of the stack. The top object must be a PRM_INTERFACE. The work is
// delegated to startAttribute().
// NOTE(review): the rest of the body (presumably closing the attribute) is
// not visible in this listing.
1503 template <
typename GUM_SCALAR >
1504 INLINE
void PRMFactory< GUM_SCALAR >::addAttribute(
const std::string& type,
1505 const std::string& name) {
1506 _checkStack_(1, PRMObject::prm_type::PRM_INTERFACE);
1507 startAttribute(type, name);
// Begins an attribute named `name` of type `type` in the class or interface
// on top of the stack.
// - In a class, when `scalar_attr` is false, a PRMFormAttribute bound to
//   that class is created.
// - Otherwise a PRMScalarAttribute is created.
// The attribute is added to the container. If the name already exists
// (DuplicateElement), the attribute is installed with overload() instead.
// On any other Exception, an attribute not owned by the container is
// deleted. The attribute is then pushed on the construction stack.
// NOTE(review): this listing drops the `scalar_attr` parameter line, the
// `try` blocks, the add call and whatever follows the delete (presumably a
// rethrow). If nothing rethrows, the deleted pointer would still be pushed
// below — confirm against the full source.
1511 template <
typename GUM_SCALAR >
1512 INLINE
void PRMFactory< GUM_SCALAR >::startAttribute(
const std::string& type,
1513 const std::string& name,
1515 PRMClassElementContainer< GUM_SCALAR >* c = _checkStackContainter_(1);
1516 PRMAttribute< GUM_SCALAR >* a =
nullptr;
1518 if (PRMObject::isClass(*c) && (!scalar_attr)) {
1519 a =
new PRMFormAttribute< GUM_SCALAR >(
static_cast< PRMClass< GUM_SCALAR >& >(*c),
1521 *_retrieveType_(type));
1524 a =
new PRMScalarAttribute< GUM_SCALAR >(name, *_retrieveType_(type));
1527 std::string dot =
".";
1532 }
catch (DuplicateElement&) { c->overload(a); }
1533 }
catch (Exception&) {
// Only free the attribute if the container did not take it.
1534 if (a !=
nullptr && (!c->exists(a->id()))) {
delete a; }
1537 _stack_.push_back(a);
// Re-opens an existing attribute `name` of the class/interface on top of the
// stack so that it can be edited further. The element is pushed back onto
// the stack.
// Throws NotFound if no element has that name, and OperationNotAllowed if
// the element is not an attribute.
// NOTE(review): the NotFound message lacks a space: it renders as
// "Attribute <name>not found".
1540 template <
typename GUM_SCALAR >
1541 INLINE
void PRMFactory< GUM_SCALAR >::continueAttribute(
const std::string& name) {
1542 PRMClassElementContainer< GUM_SCALAR >* c = _checkStackContainter_(1);
1544 if (!c->exists(name)) GUM_ERROR(NotFound,
"Attribute " << name <<
"not found")
1546 auto& a = c->get(name);
1548 if (!PRMClassElement< GUM_SCALAR >::isAttribute(a))
1549 GUM_ERROR(OperationNotAllowed,
"Element " << name <<
" not an attribute")
1551 _stack_.push_back(&a);
// Ends the attribute on top of the stack. Raises FactoryInvalidState if the
// top element is not a prm_attribute.
// NOTE(review): the remaining statements (presumably the pop) are not visible.
1554 template <
typename GUM_SCALAR >
1555 INLINE
void PRMFactory< GUM_SCALAR >::endAttribute() {
1556 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute);
// Begins a new PRMSystem. The system is created under the package-qualified
// name, pushed on the construction stack, and registered in _systemMap_
// (under model->name()) and _systems_.
// NOTE(review): the duplicate check looks up the unprefixed `name`, while
// the system is created with _addPrefix_(name) and presumably inserted under
// that qualified name. Inside a package, a second system with the same name
// would therefore pass this check. The sibling start* functions check the
// prefixed name.
1560 template <
typename GUM_SCALAR >
1561 INLINE
void PRMFactory< GUM_SCALAR >::startSystem(
const std::string& name) {
1562 if (_prm_->_systemMap_.exists(name)) {
1563 GUM_ERROR(DuplicateElement,
"'" << name <<
"' is already used.")
1565 PRMSystem< GUM_SCALAR >* model =
new PRMSystem< GUM_SCALAR >(_addPrefix_(name));
1566 _stack_.push_back(model);
1567 _prm_->_systemMap_.insert(model->name(), model);
1568 _prm_->_systems_.insert(model);
// Ends the system on top of the stack and calls instantiate() on it. Any
// Exception raised during instantiation is replaced by a FatalError
// ("could not create system"), which discards the original error.
// NOTE(review): the `try` line and the pop of the stack are not visible here.
1571 template <
typename GUM_SCALAR >
1572 INLINE
void PRMFactory< GUM_SCALAR >::endSystem() {
1574 PRMSystem< GUM_SCALAR >* model
1575 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
1577 model->instantiate();
1578 }
catch (Exception&) { GUM_ERROR(FatalError,
"could not create system") }
// Adds an instance `name` of class `type` to the system being built.
// For a parameterized class, it delegates to the overload taking a
// parameter table, passing an empty table so that every parameter uses its
// default value.
// Otherwise the instance is created directly by _addInstance_().
// NOTE(review): the `else` separating the two paths is not visible here.
1581 template <
typename GUM_SCALAR >
1582 INLINE
void PRMFactory< GUM_SCALAR >::addInstance(
const std::string& type,
1583 const std::string& name) {
1584 auto c = _retrieveClass_(type);
1587 if (c->parameters().size() > 0) {
1588 HashTable< std::string,
double > params;
1589 addInstance(type, name, params);
1592 _addInstance_(c, name);
// Adds an instance `name` of class `type`, with parameter values `params`.
// - Class without parameters: allowed only with an empty `params`
//   (otherwise OperationNotAllowed).
// - Parameterized class: missing parameters are filled with their defaults.
//   A specialized subclass named "Class<p1=v1,p2=v2,...>" is then declared
//   via startClass(sub_c, c->name()) with one addParameter per value
//   ("int" or "real" depending on the parameter's valueType()). The
//   package stack is saved beforehand and restored afterwards.
// If startClass reports DuplicateElement, the specialized subclass already
// exists and is reused. In both cases the instance is created from it.
// NOTE(review): the subclass name follows HashTable iteration order. Confirm
// that this order is stable, so that identical parameter sets always map to
// the same subclass.
// NOTE(review): the loop variable `type` below shadows the `type`
// parameter, and `auto p` copies each (name, value) pair.
// NOTE(review): this listing drops lines (return type, `try`, any package
// clearing, endClass). If _packages_ is modified in them, confirm it is also
// restored on the DuplicateElement path.
1596 template <
typename GUM_SCALAR >
1598 PRMFactory< GUM_SCALAR >::addInstance(
const std::string& type,
1599 const std::string& name,
1600 const HashTable< std::string,
double >& params) {
1601 auto c = _retrieveClass_(type);
// Non-parameterized class: parameters make no sense here.
1603 if (c->parameters().empty()) {
1604 if (params.empty()) {
1605 _addInstance_(c, name);
1607 GUM_ERROR(OperationNotAllowed,
"Class " + type +
" does not have parameters")
// Complete the user-supplied values with each parameter's default.
1611 auto my_params = params;
1613 for (
const auto& p: c->parameters()) {
1614 if (!my_params.exists(p->name())) { my_params.insert(p->name(), p->value()); }
// Build the specialized class name "Class<p=v,...>" (trailing ',' dropped).
1618 std::stringstream sBuff;
1619 sBuff << c->name() <<
"<";
1621 for (
const auto& p: my_params) {
1622 sBuff << p.first <<
"=" << p.second <<
",";
1626 std::string sub_c = sBuff.str().substr(0, sBuff.str().size() - 1) +
">";
// Save the package stack so that it can be restored after declaring sub_c.
1630 auto pck_cpy = _packages_;
1633 startClass(sub_c, c->name());
1636 for (
auto p: my_params) {
1637 auto type =
static_cast< PRMParameter< GUM_SCALAR >& >(c->get(p.first)).valueType();
1638 if (type == PRMParameter< GUM_SCALAR >::ParameterType::INT) {
1639 addParameter(
"int", p.first, p.second);
1642 addParameter(
"real", p.first, p.second);
1648 _packages_ = pck_cpy;
// The specialized class was already declared by an earlier instance: reuse it.
1650 }
catch (DuplicateElement&) {
1653 c = _retrieveClass_(sub_c);
1654 _addInstance_(c, name);
// Creates a PRMInstance `name` of class `*type` for the system on top of the
// stack. The top object must be a SYSTEM. If OperationNotAllowed is raised,
// the freshly allocated instance is deleted.
// NOTE(review): the `model` declaration head, the `try`, the call that adds
// `i` to the system, and any rethrow are not visible in this listing.
1658 template <
typename GUM_SCALAR >
1659 INLINE
void PRMFactory< GUM_SCALAR >::_addInstance_(PRMClass< GUM_SCALAR >* type,
1660 const std::string& name) {
1661 PRMInstance< GUM_SCALAR >* i =
nullptr;
1664 =
static_cast< PRMSystem< GUM_SCALAR >* >(_checkStack_(1, PRMObject::prm_type::SYSTEM));
1665 i =
new PRMInstance< GUM_SCALAR >(name, *type);
1668 }
catch (OperationNotAllowed&) {
// Avoid leaking the instance when the system refuses it.
1669 if (i) {
delete i; }
// Qualifies `str` with the current package: when a package is active, it
// builds "<package>.<str>".
// NOTE(review): the return statements (and the behavior when no package is
// active, presumably returning `str` unchanged) are not visible here.
1674 template <
typename GUM_SCALAR >
1675 INLINE std::string PRMFactory< GUM_SCALAR >::_addPrefix_(
const std::string& str)
const {
1676 if (!_packages_.empty()) {
1677 std::string full_name = _packages_.back();
1678 full_name.append(
".");
1679 full_name.append(str);
// Returns the i-th object from the top of the construction stack
// (i == 1 is the top). Raises FactoryInvalidState when i exceeds the stack
// depth or when that object's prm_type differs from `obj_type`.
// The depth test relies on unsigned wrap-around: size() - i exceeds size()
// only when i > size().
// NOTE(review): i == 0 passes that test and then reads _stack_[size()],
// one past the end.
1686 template <
typename GUM_SCALAR >
1687 INLINE PRMObject* PRMFactory< GUM_SCALAR >::_checkStack_(Idx i, PRMObject::prm_type obj_type) {
1689 if (_stack_.size() - i > _stack_.size()) {
1690 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1693 PRMObject* obj = _stack_[_stack_.size() - i];
1695 if (obj->obj_type() != obj_type) {
1696 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
// Returns the i-th object from the top of the stack (i == 1 is the top) as a
// class-element container. Only CLASS and PRM_INTERFACE objects qualify;
// anything else, or i beyond the stack depth, raises FactoryInvalidState.
// The depth test uses unsigned wrap-around, as in _checkStack_.
// NOTE(review): i == 0 would index one past the end.
1702 template <
typename GUM_SCALAR >
1703 INLINE PRMClassElementContainer< GUM_SCALAR >*
1704 PRMFactory< GUM_SCALAR >::_checkStackContainter_(Idx i) {
1706 if (_stack_.size() - i > _stack_.size()) {
1707 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1710 PRMObject* obj = _stack_[_stack_.size() - i];
1712 if ((obj->obj_type() == PRMObject::prm_type::CLASS)
1713 || (obj->obj_type() == PRMObject::prm_type::PRM_INTERFACE)) {
1714 return static_cast< PRMClassElementContainer< GUM_SCALAR >* >(obj);
1716 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
// Class-element overload: returns the i-th stack object from the top as a
// PRMClassElement, checked with dynamic_cast. Raises FactoryInvalidState if
// i exceeds the stack depth, if the object is not a class element, or if
// its elt_type() differs from `elt_type`.
// NOTE(review): the `Idx i` parameter line is not visible in this listing.
1720 template <
typename GUM_SCALAR >
1721 INLINE PRMClassElement< GUM_SCALAR >* PRMFactory< GUM_SCALAR >::_checkStack_(
1723 typename PRMClassElement< GUM_SCALAR >::ClassElementType elt_type) {
1725 if (_stack_.size() - i > _stack_.size()) {
1726 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
1729 PRMClassElement< GUM_SCALAR >* obj
1730 =
dynamic_cast< PRMClassElement< GUM_SCALAR >* >(_stack_[_stack_.size() - i]);
1732 if (obj == 0) { GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls") }
1734 if (obj->elt_type() != elt_type) {
1735 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls")
// Walks the super-type chain of `t` until it reaches a type that is not a
// subtype.
// NOTE(review): the counter and the return statement are not visible here;
// presumably it returns the number of super-type hops.
1741 template <
typename GUM_SCALAR >
1742 INLINE
int PRMFactory< GUM_SCALAR >::_typeDepth_(
const PRMType* t) {
1744 const PRMType* current = t;
1746 while (current->isSubType()) {
1748 current = &(current->superType());
// Enters package `name`: the name becomes the prefix used by _addPrefix_(),
// and a fresh, empty import list is pushed so that imports are scoped to
// this package.
1754 template <
typename GUM_SCALAR >
1755 INLINE
void PRMFactory< GUM_SCALAR >::pushPackage(
const std::string& name) {
1756 _packages_.push_back(name);
1757 _namespaces_.push_back(
new List< std::string >());
// Leaves the current package. It remembers the current package name, pops
// it, and deletes and pops the innermost import list if one exists.
// NOTE(review): `s` is unused in the visible lines. The return statement is
// not visible here (presumably it returns `plop`).
1760 template <
typename GUM_SCALAR >
1761 INLINE std::string PRMFactory< GUM_SCALAR >::popPackage() {
1762 std::string plop = currentPackage();
1764 if (!_packages_.empty()) {
1765 std::string s = _packages_.back();
1766 _packages_.pop_back();
1768 if (_namespaces_.size() > 0) {
1769 delete _namespaces_.back();
1770 _namespaces_.pop_back();
// Records `name` as an import of the innermost scope. The name resolvers
// then also try "<name>.<identifier>". An empty name raises
// OperationNotAllowed. If no package scope exists yet, a top-level import
// list is created first.
1778 template <
typename GUM_SCALAR >
1779 INLINE
void PRMFactory< GUM_SCALAR >::addImport(
const std::string& name) {
1780 if (name.size() == 0) { GUM_ERROR(OperationNotAllowed,
"illegal import name") }
1781 if (_namespaces_.empty()) { _namespaces_.push_back(
new List< std::string >()); }
1782 _namespaces_.back()->push_back(name);
// Two-argument form: splits `l_i` ("<instance>.<reference>") at its last '.'
// and forwards to the three-argument overload
// (instance, reference slot, right value). Raises NotFound when `l_i`
// contains no '.'.
// NOTE(review): the `else` line before the GUM_ERROR is not visible here.
1785 template <
typename GUM_SCALAR >
1786 INLINE
void PRMFactory< GUM_SCALAR >::setReferenceSlot(
const std::string& l_i,
1787 const std::string& r_i) {
1788 size_t pos = l_i.find_last_of(
'.');
1790 if (pos != std::string::npos) {
1791 std::string l_ref = l_i.substr(pos + 1, std::string::npos);
1792 setReferenceSlot(l_i.substr(0, pos), l_ref, r_i);
1794 GUM_ERROR(NotFound,
"left value does not name an instance or an array")
// Public wrapper around _retrieveClass_(): returns the resolved class by
// reference and propagates its NotFound / DuplicateElement errors.
1798 template <
typename GUM_SCALAR >
1799 INLINE PRMClass< GUM_SCALAR >&
1800 PRMFactory< GUM_SCALAR >::retrieveClass(
const std::string& name) {
1801 return *_retrieveClass_(name);
// Public wrapper around _retrieveType_(): returns the resolved type by
// reference and propagates its errors.
1804 template <
typename GUM_SCALAR >
1805 INLINE PRMType& PRMFactory< GUM_SCALAR >::retrieveType(
const std::string& name) {
1806 return *_retrieveType_(name);
// Public wrapper around _retrieveCommonType_(elts): returns its result by
// reference.
1809 template <
typename GUM_SCALAR >
1810 INLINE PRMType& PRMFactory< GUM_SCALAR >::retrieveCommonType(
1811 const std::vector< PRMClassElement< GUM_SCALAR >* >& elts) {
1812 return *(_retrieveCommonType_(elts));
// Tests whether `type` resolves to a class, and then to an interface.
// NotFound and DuplicateElement from either resolver are swallowed. An
// ambiguous name is therefore treated the same as an unknown one.
// NOTE(review): the `try` lines and the return statements are not visible in
// this listing (presumably `return true` after a successful resolution and
// `return false` at the end).
1816 template <
typename GUM_SCALAR >
1817 INLINE
bool PRMFactory< GUM_SCALAR >::isClassOrInterface(
const std::string& type)
const {
1819 _retrieveClass_(type);
1822 }
catch (NotFound&) {
1823 }
catch (DuplicateElement&) {}
1826 _retrieveInterface_(type);
1829 }
catch (NotFound&) {
1830 }
catch (DuplicateElement&) {}
// Returns whether `name` is an array of the system currently being built.
// NOTE(review): getCurrent() is cast with an unchecked static_cast. If the
// top of the stack is not a SYSTEM (see currentType()), the call is
// undefined behavior.
// NOTE(review): getCurrent() throws NotFound on an empty stack rather than
// returning nullptr, so the `system &&` test does not guard that case.
1835 template <
typename GUM_SCALAR >
1836 INLINE
bool PRMFactory< GUM_SCALAR >::isArrayInCurrentSystem(
const std::string& name)
const {
1837 const PRMSystem< GUM_SCALAR >* system
1838 =
static_cast<
const PRMSystem< GUM_SCALAR >* >(getCurrent());
1839 return (system && system->isArray(name));
// Fills the CPF formulas of the attribute on top of the stack (inside a
// class) from `array`, given in column order. `array.size()` must equal the
// formulas' domain size (OperationNotAllowed otherwise). A one-dimensional
// table is forwarded to setRawCPFByLines(). Otherwise the values are
// written cell by cell while an instantiation iterates the variables in
// reverse order.
// NOTE(review): startAttribute() can create a PRMScalarAttribute inside a
// class (when scalar_attr is true). The unchecked static_cast to
// PRMFormAttribute below assumes that is never the case here — confirm.
// NOTE(review): many lines are missing from this listing (return type, loop
// bodies, the `jnst` declaration and increments), so the exact traversal is
// not visible.
1842 template <
typename GUM_SCALAR >
1844 PRMFactory< GUM_SCALAR >::setRawCPFByColumns(
const std::vector< std::string >& array) {
1845 _checkStack_(2, PRMObject::prm_type::CLASS);
1847 auto a =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(
1848 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
1850 if (a->formulas().domainSize() != array.size()) {
1851 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
// A single-variable table has the same layout by lines and by columns.
1854 if (a->formulas().nbrDim() == 1) {
1855 setRawCPFByLines(array);
1858 Instantiation inst(a->formulas());
1860 for (
auto idx = inst.variablesSequence().rbegin(); idx != inst.variablesSequence().rend();
1866 auto idx = (std::size_t)0;
1867 while ((!jnst.end()) && idx < array.size()) {
1869 a->formulas().set(inst, array[idx]);
1879 template <
typename GUM_SCALAR >
1881 PRMFactory< GUM_SCALAR >::setRawCPFByLines(
const std::vector< std::string >& array) {
1882 _checkStack_(2, PRMObject::prm_type::CLASS);
1884 auto a =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(
1885 _checkStack_(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
1887 if (a->formulas().domainSize() != array.size()) {
1888 GUM_ERROR(OperationNotAllowed,
"illegal CPF size")
1891 a->formulas().populate(array);