28 #include <agrum/PRM/PRMFactory.h> 33 #include <agrum/tools/core/math/formula.h> 35 #include <agrum/tools/variables/discretizedVariable.h> 36 #include <agrum/tools/variables/rangeVariable.h> 38 #include <agrum/PRM/elements/PRMFormAttribute.h> 39 #include <agrum/PRM/elements/PRMFuncAttribute.h> 45 template <
typename GUM_SCALAR >
// startClass: opens a new PRM class whose name is `name` with the current
// package prefix applied (addPrefix__). The class may extend the class named
// `extends` (empty string = no superclass) and may implement the interfaces
// named in `*implements` (nullptr = none). `delayInheritance` is forwarded to
// the PRMClass constructor.
// Throws DuplicateElement if the prefixed name is already used by a class or
// an interface. Resolving `extends` / `implements` is delegated to
// retrieveClass__ / retrieveInterface__, which may throw on their own.
// NOTE(review): this excerpt omits several original lines (the return type,
// closing braces, and whatever follows the registration at the end), so the
// complete body is not visible here.
47 PRMFactory< GUM_SCALAR >::startClass(
const std::string& name,
48 const std::string& extends,
49 const Set< std::string >* implements,
50 bool delayInheritance) {
51 std::string real_name = addPrefix__(name);
// A class may not reuse a name already taken by a class or an interface.
52 if (prm__->classMap__.exists(real_name)
53 || prm__->interfaceMap__.exists(real_name)) {
54 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.");
56 PRMClass< GUM_SCALAR >* c =
nullptr;
57 PRMClass< GUM_SCALAR >* mother =
nullptr;
58 Set< PRMInterface< GUM_SCALAR >* > impl;
// Resolve every implemented interface name into its PRMInterface.
60 if (implements != 0) {
61 for (
const auto& imp: *implements) {
62 impl.insert(retrieveInterface__(imp));
66 if (extends !=
"") { mother = retrieveClass__(extends); }
// Pick the PRMClass constructor that matches the combination
// (has superclass?, has interfaces?). The four branches are exhaustive.
68 if ((extends ==
"") && impl.empty()) {
69 c =
new PRMClass< GUM_SCALAR >(real_name);
70 }
else if ((extends !=
"") && impl.empty()) {
71 c =
new PRMClass< GUM_SCALAR >(real_name, *mother, delayInheritance);
72 }
else if ((extends ==
"") && (!impl.empty())) {
73 c =
new PRMClass< GUM_SCALAR >(real_name, impl, delayInheritance);
74 }
else if ((extends !=
"") && (!impl.empty())) {
75 c =
new PRMClass< GUM_SCALAR >(real_name, *mother, impl, delayInheritance);
// Register the new class in the PRM under its (prefixed) name.
78 prm__->classMap__.insert(c->name(), c);
79 prm__->classes__.insert(c);
// continueClass: reopens the already-declared class `name` (package prefix
// applied by addPrefix__) by pushing it back on the construction stack.
// Throws NotFound if no class with the prefixed name exists.
83 template <
typename GUM_SCALAR >
84 INLINE
void PRMFactory< GUM_SCALAR >::continueClass(
const std::string& name) {
85 std::string real_name = addPrefix__(name);
86 if (!(prm__->classMap__.exists(real_name))) {
87 std::stringstream msg;
88 msg <<
"'" << real_name <<
"' not found";
89 GUM_ERROR(NotFound, msg.str());
// Make the existing class the current element being built.
91 stack__.push_back(&(prm__->getClass(real_name)));
// endClass: finishes the class currently on top of the construction stack.
// The top element is fetched through checkStack__(1, CLASS); presumably
// checkStack__ throws if the top is not a class.
// @param checkImplementations when true, the class is validated against
//        every interface it implements (checkInterfaceImplementation__
//        raises PRMTypeError on a mismatch).
// NOTE(review): the remainder of this body (presumably popping the stack) is
// missing from this excerpt.
94 template <
typename GUM_SCALAR >
95 INLINE
void PRMFactory< GUM_SCALAR >::endClass(
bool checkImplementations) {
96 PRMClass< GUM_SCALAR >* c =
static_cast< PRMClass< GUM_SCALAR >* >(
97 checkStack__(1, PRMObject::prm_type::CLASS));
99 if (checkImplementations) { checkInterfaceImplementation__(c); }
// checkInterfaceImplementation__: verifies that class `c` respects every
// interface it implements. For each element declared by an interface:
//  - attributes/aggregates must exist in `c` as attributes or aggregates,
//    and their type must be a subtype of the interface's type;
//  - reference slots must exist in `c` as reference slots, and their slot
//    type must be a subtype of the interface's slot type.
// Violations raise PRMTypeError("class <c> does not respect interface <i>").
// A NotFound raised while looking elements up is also converted into
// PRMTypeError.
// NOTE(review): the try-statements, some case labels and the second catch
// handler's body are missing from this excerpt.
104 template <
typename GUM_SCALAR >
105 INLINE
void PRMFactory< GUM_SCALAR >::checkInterfaceImplementation__(
106 PRMClass< GUM_SCALAR >* c) {
108 for (
const auto& i: c->implements()) {
110 for (
const auto& node: i->containerDag().nodes()) {
// The element is looked up in `c` by the same name it has in the interface.
111 std::string name = i->get(node).name();
113 switch (i->get(node).elt_type()) {
114 case PRMClassElement< GUM_SCALAR >::prm_aggregate:
115 case PRMClassElement< GUM_SCALAR >::prm_attribute: {
116 if ((c->get(name).elt_type()
117 == PRMClassElement< GUM_SCALAR >::prm_attribute)
118 || (c->get(name).elt_type()
119 == PRMClassElement< GUM_SCALAR >::prm_aggregate)) {
// The implementing element's type must specialize the interface's type.
120 if (!c->get(name).type().isSubTypeOf(i->get(name).type())) {
121 std::stringstream msg;
122 msg <<
"class " << c->name()
123 <<
" does not respect interface ";
124 GUM_ERROR(PRMTypeError, msg.str() + i->name());
127 std::stringstream msg;
128 msg <<
"class " << c->name() <<
" does not respect interface ";
129 GUM_ERROR(PRMTypeError, msg.str() + i->name());
135 case PRMClassElement< GUM_SCALAR >::prm_refslot: {
136 if (c->get(name).elt_type()
137 == PRMClassElement< GUM_SCALAR >::prm_refslot) {
138 const PRMReferenceSlot< GUM_SCALAR >& ref_i
139 =
static_cast<
const PRMReferenceSlot< GUM_SCALAR >& >(
141 const PRMReferenceSlot< GUM_SCALAR >& ref_this
142 =
static_cast<
const PRMReferenceSlot< GUM_SCALAR >& >(
// The slot type of the class's reference must specialize the interface's.
145 if (!ref_this.slotType().isSubTypeOf(ref_i.slotType())) {
146 std::stringstream msg;
147 msg <<
"class " << c->name()
148 <<
" does not respect interface ";
149 GUM_ERROR(PRMTypeError, msg.str() + i->name());
152 std::stringstream msg;
153 msg <<
"class " << c->name() <<
" does not respect interface ";
154 GUM_ERROR(PRMTypeError, msg.str() + i->name());
160 case PRMClassElement< GUM_SCALAR >::prm_slotchain: {
// Element kinds not handled above end in a FatalError (the case labels
// around this message are partly missing from this excerpt).
167 =
"unexpected ClassElement<GUM_SCALAR> in interface ";
168 GUM_ERROR(FatalError, msg + i->name());
172 }
catch (NotFound&) {
// An element required by the interface is missing from the class.
173 std::stringstream msg;
174 msg <<
"class " << c->name() <<
" does not respect interface ";
175 GUM_ERROR(PRMTypeError, msg.str() + i->name());
178 }
catch (NotFound&) {
// startInterface: opens a new PRM interface named `name` (package prefix
// applied by addPrefix__), optionally extending the interface `extends`
// (empty string = none). `delayInheritance` is forwarded to the PRMInterface
// constructor when a super-interface exists.
// Throws DuplicateElement if the prefixed name is already used by a class or
// an interface. The new interface is registered in the PRM and pushed on the
// construction stack.
184 template <
typename GUM_SCALAR >
186 PRMFactory< GUM_SCALAR >::startInterface(
const std::string& name,
187 const std::string& extends,
188 bool delayInheritance) {
189 std::string real_name = addPrefix__(name);
190 if (prm__->classMap__.exists(real_name)
191 || prm__->interfaceMap__.exists(real_name)) {
192 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.");
194 PRMInterface< GUM_SCALAR >* i =
nullptr;
195 PRMInterface< GUM_SCALAR >* super =
nullptr;
// Resolve the super-interface, if any.
197 if (extends !=
"") { super = retrieveInterface__(extends); }
199 if (super !=
nullptr) {
200 i =
new PRMInterface< GUM_SCALAR >(real_name, *super, delayInheritance);
202 i =
new PRMInterface< GUM_SCALAR >(real_name);
// Register the interface and make it the current element being built.
205 prm__->interfaceMap__.insert(i->name(), i);
206 prm__->interfaces__.insert(i);
207 stack__.push_back(i);
// continueInterface: reopens the already-declared interface `name` (package
// prefix applied by addPrefix__) by pushing it back on the construction stack.
// NOTE(review): when the interface does NOT exist, this raises
// DuplicateElement with a "not found" message, whereas continueClass raises
// NotFound for the same situation. This is left unchanged because callers may
// catch DuplicateElement; consider aligning it with NotFound.
210 template <
typename GUM_SCALAR >
212 PRMFactory< GUM_SCALAR >::continueInterface(
const std::string& name) {
213 std::string real_name = addPrefix__(name);
214 if (!prm__->interfaceMap__.exists(real_name)) {
215 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' not found.");
218 PRMInterface< GUM_SCALAR >* i = retrieveInterface__(real_name);
219 stack__.push_back(i);
// addAttribute: adds `attr` to the class on top of the stack and connects it
// to its parents. Every element of the class whose variable appears in
// attr's CPF is treated as a parent, and an arc parent -> attr is added (the
// attribute's own variable is skipped).
// Throws NotFound if not every CPF variable could be matched to an element
// of the class.
// NOTE(review): the declaration/increment of `count`, the `try` and the
// insertion of `attr` into the class are on lines missing from this excerpt.
222 template <
typename GUM_SCALAR >
224 PRMFactory< GUM_SCALAR >::addAttribute(PRMAttribute< GUM_SCALAR >* attr) {
225 PRMClass< GUM_SCALAR >* c =
static_cast< PRMClass< GUM_SCALAR >* >(
226 checkStack__(1, PRMObject::prm_type::CLASS));
229 const Sequence<
const DiscreteVariable* >& vars
230 = attr->cpf().variablesSequence();
232 for (
const auto& node: c->containerDag().nodes()) {
234 if (vars.exists(&(c->get(node).type().variable()))) {
// Do not add a self-loop from the attribute to itself.
237 if (&(attr->type().variable()) != &(c->get(node).type().variable())) {
238 c->addArc(c->get(node).safeName(), attr->safeName());
// OperationNotAllowed raised inside the (partly missing) try block is
// deliberately ignored here.
241 }
catch (OperationNotAllowed&) {}
// Every CPF variable must have been matched to an element of the class.
244 if (count != attr->cpf().variablesSequence().size()) {
245 GUM_ERROR(NotFound,
"unable to found all parents of this attribute");
// addParent__ (attribute overload): makes the element named `name` a parent
// of attribute `a` inside container `c`.
//  - reference slots and multiple slot chains are rejected
//    (OperationNotAllowed);
//  - single slot chains, attributes and aggregates get an arc name -> a;
//  - if `name` is not an element of `c`, it is parsed as a slot chain with
//    buildSlotChain__. If that fails, NotFound is thrown; if the chain is
//    multiple, OperationNotAllowed is thrown.
// NOTE(review): the `try`, the null test on `sc` and the insertion of `sc`
// into `c` are on lines missing from this excerpt.
249 template <
typename GUM_SCALAR >
250 INLINE
void PRMFactory< GUM_SCALAR >::addParent__(
251 PRMClassElementContainer< GUM_SCALAR >* c,
252 PRMAttribute< GUM_SCALAR >* a,
253 const std::string& name) {
255 PRMClassElement< GUM_SCALAR >& elt = c->get(name);
257 switch (elt.elt_type()) {
258 case PRMClassElement< GUM_SCALAR >::prm_refslot: {
259 GUM_ERROR(OperationNotAllowed,
260 "can not add a reference slot as a parent of an attribute");
264 case PRMClassElement< GUM_SCALAR >::prm_slotchain: {
265 if (
static_cast< PRMSlotChain< GUM_SCALAR >& >(elt).isMultiple()) {
266 GUM_ERROR(OperationNotAllowed,
267 "can not add a multiple slot chain to an attribute");
270 c->addArc(name, a->name());
275 case PRMClassElement< GUM_SCALAR >::prm_attribute:
276 case PRMClassElement< GUM_SCALAR >::prm_aggregate: {
277 c->addArc(name, a->name());
282 GUM_ERROR(FatalError,
"unknown ClassElement<GUM_SCALAR>");
// `name` is not a direct element of `c`: try to interpret it as a slot chain.
285 }
catch (NotFound&) {
287 PRMSlotChain< GUM_SCALAR >* sc = buildSlotChain__(c, name);
291 =
"found no ClassElement<GUM_SCALAR> with the given name ";
292 GUM_ERROR(NotFound, msg + name);
293 }
else if (!sc->isMultiple()) {
295 c->addArc(sc->name(), a->name());
298 GUM_ERROR(OperationNotAllowed,
299 "Impossible to add a multiple reference slot as" 300 " direct parent of an PRMAttribute<GUM_SCALAR>.");
// addParent: adds `name` as a parent of the element currently on top of the
// stack. The container is taken from stack position 2. The top element is
// first tried as an attribute; if checkStack__ raises FactoryInvalidState, it
// is retried as an aggregate and the aggregate overload of addParent__ is
// used.
// NOTE(review): the aggregate branch static_casts the container to PRMClass;
// this assumes aggregates are only built inside classes (not interfaces).
// Confirm this against callers.
306 template <
typename GUM_SCALAR >
307 INLINE
void PRMFactory< GUM_SCALAR >::addParent(
const std::string& name) {
308 PRMClassElementContainer< GUM_SCALAR >* c = checkStackContainter__(2);
311 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
312 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
313 addParent__(c, a, name);
// The top of the stack is not an attribute: fall back to an aggregate.
314 }
catch (FactoryInvalidState&) {
315 auto agg =
static_cast< PRMAggregate< GUM_SCALAR >* >(
316 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_aggregate));
317 addParent__(
static_cast< PRMClass< GUM_SCALAR >* >(c), agg, name);
// setRawCPFByFloatLines: float overload of setRawCPFByLines. Fills the CPF of
// the attribute on top of the stack (position 2 must be a class) with
// `array`, after converting each value to GUM_SCALAR.
// Throws OperationNotAllowed if array.size() differs from the CPF's domain
// size.
321 template <
typename GUM_SCALAR >
322 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByFloatLines(
323 const std::vector<
float >& array) {
324 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
325 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
326 checkStack__(2, PRMObject::prm_type::CLASS);
328 if (a->cpf().domainSize() != array.size())
329 GUM_ERROR(OperationNotAllowed,
"illegal CPF size");
// Element-wise float -> GUM_SCALAR conversion.
331 std::vector< GUM_SCALAR > array2(array.begin(), array.end());
332 a->cpf().fillWith(array2);
// setRawCPFByLines: fills the CPF of the attribute on top of the stack
// (position 2 must be a class) with `array`, passed directly to fillWith in
// the order given.
// Throws OperationNotAllowed if array.size() differs from the CPF's domain
// size.
335 template <
typename GUM_SCALAR >
336 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByLines(
337 const std::vector< GUM_SCALAR >& array) {
338 auto elt = checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute);
339 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(elt);
340 checkStack__(2, PRMObject::prm_type::CLASS);
342 if (a->cpf().domainSize() != array.size()) {
343 GUM_ERROR(OperationNotAllowed,
"illegal CPF size");
346 a->cpf().fillWith(array);
// setRawCPFByFloatColumns: float overload of setRawCPFByColumns. Checks the
// size against the CPF of the attribute on top of the stack, converts the
// values to GUM_SCALAR and delegates to setRawCPFByColumns.
// Throws OperationNotAllowed if array.size() differs from the CPF's domain
// size.
349 template <
typename GUM_SCALAR >
350 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByFloatColumns(
351 const std::vector<
float >& array) {
352 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
353 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
355 if (a->cpf().domainSize() != array.size()) {
356 GUM_ERROR(OperationNotAllowed,
"illegal CPF size");
359 std::vector< GUM_SCALAR > array2(array.begin(), array.end());
360 setRawCPFByColumns(array2);
// setRawCPFByColumns: fills the CPF of the attribute on top of the stack with
// `array`, given in column order. A one-dimensional CPF is delegated to
// setRawCPFByLines. Otherwise the CPF is filled by walking an instantiation
// whose variable order is built by iterating the CPF's variables in reverse.
// Throws OperationNotAllowed if array.size() differs from the CPF's domain
// size.
// NOTE(review): the construction of `jnst` and the loop increments are on
// lines missing from this excerpt.
363 template <
typename GUM_SCALAR >
364 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByColumns(
365 const std::vector< GUM_SCALAR >& array) {
366 PRMAttribute< GUM_SCALAR >* a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
367 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
369 if (a->cpf().domainSize() != array.size()) {
370 GUM_ERROR(OperationNotAllowed,
"illegal CPF size");
// With a single variable, lines and columns coincide.
373 if (a->cpf().nbrDim() == 1) {
374 setRawCPFByLines(array);
377 Instantiation inst(a->cpf());
379 for (
auto idx = inst.variablesSequence().rbegin();
380 idx != inst.variablesSequence().rend();
386 auto idx = (std::size_t)0;
387 while ((!jnst.end()) && idx < array.size()) {
389 a->cpf().set(inst, array[idx]);
// setCPFByFloatRule: float overload of setCPFByRule. Validates the number of
// parents (the CPF must have parents.size() + 1 variables) and of values (one
// per value of the attribute's variable), converts the values to GUM_SCALAR
// and delegates to setCPFByRule.
// Throws OperationNotAllowed on either size mismatch.
396 template <
typename GUM_SCALAR >
397 INLINE
void PRMFactory< GUM_SCALAR >::setCPFByFloatRule(
398 const std::vector< std::string >& parents,
399 const std::vector<
float >& values) {
400 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
401 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
403 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
404 GUM_ERROR(OperationNotAllowed,
"wrong number of parents");
407 if (values.size() != a->type().variable().domainSize()) {
408 GUM_ERROR(OperationNotAllowed,
"wrong number of values");
411 std::vector< GUM_SCALAR > values2(values.begin(), values.end());
412 setCPFByRule(parents, values2);
// setCPFByRule (numeric values): sets the formulas of the PRMFormAttribute on
// top of the stack for one "rule". parents[i] is either "*" or a label of the
// i-th parent variable (position 1 + i in the formulas' variable sequence),
// and values[k] is written, as a string, for the k-th value of the attribute
// under every parent configuration matching the rule.
// Throws OperationNotAllowed on a wrong number of parents/values or when the
// attribute is not a PRMFormAttribute, and NotFound when a label is unknown.
// NOTE(review): std::to_string on floating-point GUM_SCALAR formats with
// "%f" (6 fixed decimals), so small probabilities (e.g. 1e-7) become
// "0.000000". Consider a precision-preserving conversion.
// NOTE(review): the handling of "*" and the construction of `knst` are on
// lines missing from this excerpt.
415 template <
typename GUM_SCALAR >
416 INLINE
void PRMFactory< GUM_SCALAR >::setCPFByRule(
417 const std::vector< std::string >& parents,
418 const std::vector< GUM_SCALAR >& values) {
419 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
420 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
422 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
423 GUM_ERROR(OperationNotAllowed,
"wrong number of parents");
426 if (values.size() != a->type().variable().domainSize()) {
427 GUM_ERROR(OperationNotAllowed,
"wrong number of values");
// Rules are only supported on formula-based attributes.
430 if (
dynamic_cast< PRMFormAttribute< GUM_SCALAR >* >(a)) {
431 auto form =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(a);
434 Instantiation jnst, knst;
435 const DiscreteVariable* var = 0;
// Position 0 of the sequence is skipped, so parent i is at position 1 + i.
439 for (Idx i = 0; i < parents.size(); ++i) {
440 var = form->formulas().variablesSequence().atPos(1 + i);
442 if (parents[i] ==
"*") {
// Fix the parent to the value whose label matches the rule.
449 for (Size j = 0; j < var->domainSize(); ++j) {
450 if (var->label(j) == parents[i]) {
451 jnst.chgVal(*var, j);
458 std::string msg =
"could not find label ";
459 GUM_ERROR(NotFound, msg + parents[i]);
464 Instantiation inst(form->formulas());
467 for (Size i = 0; i < form->type()->domainSize(); ++i) {
468 inst.chgVal(form->type().variable(), i);
470 for (inst.setFirstIn(knst); !inst.end(); inst.incIn(knst)) {
471 form->formulas().set(inst, std::to_string(values[i]));
476 GUM_ERROR(OperationNotAllowed,
"invalide attribute type");
// setCPFByRule (formula values): same as the numeric overload, but each
// values[k] is a formula string stored verbatim for the k-th value of the
// attribute. parents[i] is "*" or a label of the parent at position 1 + i of
// the formulas' variable sequence.
// Throws OperationNotAllowed on a wrong number of parents/values or when the
// attribute is not a PRMFormAttribute, and NotFound when a label is unknown.
// NOTE(review): the handling of "*" and the construction of `knst` are on
// lines missing from this excerpt.
480 template <
typename GUM_SCALAR >
481 INLINE
void PRMFactory< GUM_SCALAR >::setCPFByRule(
482 const std::vector< std::string >& parents,
483 const std::vector< std::string >& values) {
484 auto a =
static_cast< PRMAttribute< GUM_SCALAR >* >(
485 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
487 if ((parents.size() + 1) != a->cpf().variablesSequence().size()) {
488 GUM_ERROR(OperationNotAllowed,
"wrong number of parents");
491 if (values.size() != a->type().variable().domainSize()) {
492 GUM_ERROR(OperationNotAllowed,
"wrong number of values");
// Rules are only supported on formula-based attributes.
495 if (
dynamic_cast< PRMFormAttribute< GUM_SCALAR >* >(a)) {
496 auto form =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(a);
499 Instantiation jnst, knst;
500 const DiscreteVariable* var = 0;
504 for (Idx i = 0; i < parents.size(); ++i) {
505 var = form->formulas().variablesSequence().atPos(1 + i);
507 if (parents[i] ==
"*") {
// Fix the parent to the value whose label matches the rule.
514 for (Size j = 0; j < var->domainSize(); ++j) {
515 if (var->label(j) == parents[i]) {
516 jnst.chgVal(*var, j);
523 std::string msg =
"could not find label ";
524 GUM_ERROR(NotFound, msg + parents[i]);
529 Instantiation inst(form->formulas());
532 for (Size i = 0; i < form->type()->domainSize(); ++i) {
533 inst.chgVal(form->type().variable(), i);
535 for (inst.setFirstIn(knst); !inst.end(); inst.incIn(knst)) {
536 form->formulas().set(inst, values[i]);
541 GUM_ERROR(OperationNotAllowed,
"invalide attribute type");
// addParameter: adds a parameter named `name` to the class on top of the
// stack. The visible branches create an INT parameter or, for type "real", a
// REAL parameter. If the class already has an element with that name
// (DuplicateElement), the parameter overloads it instead.
// NOTE(review): the value parameter, the first branch condition and the
// insertion call are on lines missing from this excerpt.
545 template <
typename GUM_SCALAR >
546 INLINE
void PRMFactory< GUM_SCALAR >::addParameter(
const std::string& type,
547 const std::string& name,
549 auto c =
static_cast< PRMClass< GUM_SCALAR >* >(
550 checkStack__(1, PRMObject::prm_type::CLASS));
552 PRMParameter< GUM_SCALAR >* p =
nullptr;
554 p =
new PRMParameter< GUM_SCALAR >(
556 PRMParameter< GUM_SCALAR >::ParameterType::INT,
558 }
else if (type ==
"real") {
559 p =
new PRMParameter< GUM_SCALAR >(
561 PRMParameter< GUM_SCALAR >::ParameterType::REAL,
// An inherited element with the same name is overloaded.
567 }
catch (DuplicateElement&) { c->overload(p); }
// startAggregator: creates an aggregate of kind `agg_type` (parsed with
// PRMAggregate::str2enum) whose output type is `rv_type`, adds it to the class
// on top of the stack (overloading an existing element on DuplicateElement)
// and pushes it on the construction stack so parents can be added later.
// COUNT, EXISTS and FORALL require exactly one parameter, which is stored as
// the aggregate's label; otherwise OperationNotAllowed is thrown.
570 template <
typename GUM_SCALAR >
571 INLINE
void PRMFactory< GUM_SCALAR >::startAggregator(
572 const std::string& name,
573 const std::string& agg_type,
574 const std::string& rv_type,
575 const std::vector< std::string >& params) {
576 PRMClass< GUM_SCALAR >* c =
static_cast< PRMClass< GUM_SCALAR >* >(
577 checkStack__(1, PRMObject::prm_type::CLASS));
579 auto agg =
new PRMAggregate< GUM_SCALAR >(
581 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
582 *retrieveType__(rv_type));
586 }
catch (DuplicateElement&) { c->overload(agg); }
588 switch (agg->agg_type()) {
589 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT:
590 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
591 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
592 if (params.size() != 1) {
593 GUM_ERROR(OperationNotAllowed,
"aggregate requires a parameter");
// The label is kept as a string here; it is resolved to an index once
// the aggregate's input type is known.
595 agg->setLabel(params.front());
602 stack__.push_back(agg);
// continueAggregator: pushes the existing aggregate `name` of the container on
// top of the stack back on the construction stack.
// Throws NotFound if `name` does not exist, and OperationNotAllowed if it is
// not an aggregate.
// NOTE(review): the NotFound message has no space between the name and
// "not found".
605 template <
typename GUM_SCALAR >
607 PRMFactory< GUM_SCALAR >::continueAggregator(
const std::string& name) {
608 PRMClassElementContainer< GUM_SCALAR >* c = checkStackContainter__(1);
609 if (!c->exists(name)) { GUM_ERROR(NotFound, name <<
"not found"); }
610 auto& agg = c->get(name);
611 if (!PRMClassElement< GUM_SCALAR >::isAggregate(agg)) {
612 GUM_ERROR(OperationNotAllowed, name <<
" not an aggregate");
614 stack__.push_back(&agg);
// addParent__ (aggregate overload): resolves `name` (element or slot chain)
// into an input of aggregate `agg` in class `c`, validates it against the
// aggregate kind and adds an arc input -> agg.
//  - OR / AND: the input must be of type "boolean" (TypeError otherwise);
//  - COUNT / EXISTS / FORALL: if no label index is set yet, the stored label
//    value is looked up among the input type's labels (NotFound if absent);
//  - any unknown kind raises FatalError.
617 template <
typename GUM_SCALAR >
619 PRMFactory< GUM_SCALAR >::addParent__(PRMClass< GUM_SCALAR >* c,
620 PRMAggregate< GUM_SCALAR >* agg,
621 const std::string& name) {
622 auto chains = std::vector< std::string >{name};
623 auto inputs = std::vector< PRMClassElement< GUM_SCALAR >* >();
624 retrieveInputs__(c, chains, inputs);
626 switch (agg->agg_type()) {
627 case PRMAggregate< GUM_SCALAR >::AggregateType::OR:
628 case PRMAggregate< GUM_SCALAR >::AggregateType::AND: {
629 if (inputs.front()->type() != *(retrieveType__(
"boolean"))) {
630 GUM_ERROR(TypeError,
"expected booleans");
636 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT:
637 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
638 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
639 if (!agg->hasLabel()) {
640 auto param = agg->labelValue();
// Linear search for the label's index in the input's type.
643 while (label_idx < inputs.front()->type()->domainSize()) {
644 if (inputs.front()->type()->label(label_idx) == param) {
break; }
649 if (label_idx == inputs.front()->type()->domainSize()) {
650 GUM_ERROR(NotFound,
"could not find label");
653 agg->setLabel(label_idx);
659 case PRMAggregate< GUM_SCALAR >::AggregateType::SUM:
660 case PRMAggregate< GUM_SCALAR >::AggregateType::MEDIAN:
661 case PRMAggregate< GUM_SCALAR >::AggregateType::AMPLITUDE:
662 case PRMAggregate< GUM_SCALAR >::AggregateType::MIN:
663 case PRMAggregate< GUM_SCALAR >::AggregateType::MAX: {
668 GUM_ERROR(FatalError,
"Unknown aggregator.");
672 c->addArc(inputs.front()->safeName(), agg->safeName());
// endAggregator: finishes the aggregate on top of the stack; checkStack__ is
// called with prm_aggregate to fetch/validate the top element.
// NOTE(review): the rest of the body (presumably popping the stack) is
// missing from this excerpt.
675 template <
typename GUM_SCALAR >
676 INLINE
void PRMFactory< GUM_SCALAR >::endAggregator() {
677 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_aggregate);
// addAggregator: builds an aggregate of kind `agg_type` in the class on top of
// the stack from the elements/slot chains named in `chains` and the
// parameters `params`, then adds an arc from every input to the aggregate.
// Checks performed (as visible here):
//  - at least one chain is required (OperationNotAllowed);
//  - all inputs must share the same type (TypeError);
//  - OR / AND: boolean inputs, no parameter; output type is the input type;
//  - EXISTS / FORALL: one parameter naming a label of the input type
//    (NotFound if absent); output type "boolean";
//  - SUM / MEDIAN / AMPLITUDE / MIN / MAX: no parameter; output type is
//    retrieved from `type`;
//  - COUNT: one parameter naming a label of the input type; output type is
//    retrieved from `type`;
//  - any other kind raises FatalError.
// Elements that already exist (DuplicateElement) are overloaded.
// NOTE(review): the `type` parameter, the use of `hasSC` and several
// statements are on lines missing from this excerpt.
681 template <
typename GUM_SCALAR >
682 INLINE
void PRMFactory< GUM_SCALAR >::addAggregator(
683 const std::string& name,
684 const std::string& agg_type,
685 const std::vector< std::string >& chains,
686 const std::vector< std::string >& params,
688 PRMClass< GUM_SCALAR >* c =
static_cast< PRMClass< GUM_SCALAR >* >(
689 checkStack__(1, PRMObject::prm_type::CLASS));
692 if (chains.size() == 0) {
693 GUM_ERROR(OperationNotAllowed,
694 "a PRMAggregate<GUM_SCALAR> requires at least one parent");
698 std::vector< PRMClassElement< GUM_SCALAR >* > inputs;
// Resolve every chain into an input element (slot chains are built on the fly).
703 bool hasSC = retrieveInputs__(c, chains, inputs);
// All inputs must have exactly the same type.
708 if (inputs.size() > 1) {
709 for (
auto iter = inputs.begin() + 1; iter != inputs.end(); ++iter) {
710 if ((**(iter - 1)).type() != (**iter).type()) {
711 GUM_ERROR(TypeError,
"found different types");
717 PRMAggregate< GUM_SCALAR >* agg =
nullptr;
719 switch (PRMAggregate< GUM_SCALAR >::str2enum(agg_type)) {
720 case PRMAggregate< GUM_SCALAR >::AggregateType::OR:
721 case PRMAggregate< GUM_SCALAR >::AggregateType::AND: {
722 if (inputs.front()->type() != *(retrieveType__(
"boolean"))) {
723 GUM_ERROR(TypeError,
"expected booleans");
725 if (params.size() != 0) {
726 GUM_ERROR(OperationNotAllowed,
"invalid number of paramaters");
729 agg =
new PRMAggregate< GUM_SCALAR >(
731 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
732 inputs.front()->type());
737 case PRMAggregate< GUM_SCALAR >::AggregateType::EXISTS:
738 case PRMAggregate< GUM_SCALAR >::AggregateType::FORALL: {
739 if (params.size() != 1) {
740 GUM_ERROR(OperationNotAllowed,
"invalid number of parameters");
// Find the index of the label named by the single parameter.
745 while (label_idx < inputs.front()->type()->domainSize()) {
746 if (inputs.front()->type()->label(label_idx) == params.front()) {
753 if (label_idx == inputs.front()->type()->domainSize()) {
754 GUM_ERROR(NotFound,
"could not find label");
758 agg =
new PRMAggregate< GUM_SCALAR >(
760 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
761 *(retrieveType__(
"boolean")),
768 case PRMAggregate< GUM_SCALAR >::AggregateType::SUM:
769 case PRMAggregate< GUM_SCALAR >::AggregateType::MEDIAN:
770 case PRMAggregate< GUM_SCALAR >::AggregateType::AMPLITUDE:
771 case PRMAggregate< GUM_SCALAR >::AggregateType::MIN:
772 case PRMAggregate< GUM_SCALAR >::AggregateType::MAX: {
773 if (params.size() != 0) {
774 GUM_ERROR(OperationNotAllowed,
"invalid number of parameters");
777 auto output_type = retrieveType__(type);
780 agg =
new PRMAggregate< GUM_SCALAR >(
782 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
788 case PRMAggregate< GUM_SCALAR >::AggregateType::COUNT: {
789 if (params.size() != 1) {
790 GUM_ERROR(OperationNotAllowed,
"invalid number of parameters");
795 while (label_idx < inputs.front()->type()->domainSize()) {
796 if (inputs.front()->type()->label(label_idx) == params.front()) {
803 if (label_idx == inputs.front()->type()->domainSize()) {
804 GUM_ERROR(NotFound,
"could not find label");
807 auto output_type = retrieveType__(type);
810 agg =
new PRMAggregate< GUM_SCALAR >(
812 PRMAggregate< GUM_SCALAR >::str2enum(agg_type),
820 GUM_ERROR(FatalError,
"Unknown aggregator.");
824 std::string safe_name = agg->safeName();
830 }
catch (DuplicateElement&) { c->overload(agg); }
833 auto attr =
new PRMScalarAttribute< GUM_SCALAR >(agg->name(),
839 }
catch (DuplicateElement&) { c->overload(attr); }
843 }
catch (DuplicateElement&) {
// Connect every input to the aggregate (by its safe name).
848 for (
const auto& elt: inputs) {
849 c->addArc(elt->safeName(), safe_name);
// addReferenceSlot: adds a reference slot `name` of slot type `type` to the
// container on top of the stack. `type` is resolved first as a class, then as
// an interface; if both fail, NotFound is thrown. An existing element with the
// same name (DuplicateElement) is overloaded.
// NOTE(review): the `isArray` parameter and the insertion call are on lines
// missing from this excerpt.
853 template <
typename GUM_SCALAR >
854 INLINE
void PRMFactory< GUM_SCALAR >::addReferenceSlot(
const std::string& type,
855 const std::string& name,
857 PRMClassElementContainer< GUM_SCALAR >* owner = checkStackContainter__(1);
858 PRMClassElementContainer< GUM_SCALAR >* slotType = 0;
861 slotType = retrieveClass__(type);
// Not a class: try an interface instead.
862 }
catch (NotFound&) {
864 slotType = retrieveInterface__(type);
865 }
catch (NotFound&) {
866 GUM_ERROR(NotFound,
"unknown ReferenceSlot<GUM_SCALAR> slot type");
870 PRMReferenceSlot< GUM_SCALAR >* ref
871 =
new PRMReferenceSlot< GUM_SCALAR >(name, *slotType, isArray);
875 }
catch (DuplicateElement&) { owner->overload(ref); }
// addArray: declares, in the system on top of the stack, an array `name` of
// instances of class `type`. It then creates `size` instances named
// "name[0]" ... "name[size-1]" and adds each of them to the array.
// NOTE(review): the `size` parameter, the `try` and the bodies of the
// PRMTypeError / NotFound handlers are missing from this excerpt. If
// model->add throws, the freshly allocated `inst` may leak; verify this
// against the handlers.
878 template <
typename GUM_SCALAR >
879 INLINE
void PRMFactory< GUM_SCALAR >::addArray(
const std::string& type,
880 const std::string& name,
882 PRMSystem< GUM_SCALAR >* model =
static_cast< PRMSystem< GUM_SCALAR >* >(
883 checkStack__(1, PRMObject::prm_type::SYSTEM));
884 PRMClass< GUM_SCALAR >* c = retrieveClass__(type);
885 PRMInstance< GUM_SCALAR >* inst = 0;
888 model->addArray(name, *c);
890 for (Size i = 0; i < size; ++i) {
891 std::stringstream elt_name;
892 elt_name << name <<
"[" << i <<
"]";
893 inst =
new PRMInstance< GUM_SCALAR >(elt_name.str(), *c);
894 model->add(name, inst);
896 }
catch (PRMTypeError&) {
899 }
catch (NotFound&) {
// incArray: appends the existing instance `r_i` to the array `l_i` of the
// system on top of the stack.
// Throws NotFound if `l_i` is not an array or `r_i` is not an instance.
// NOTE(review): the message "left value is no an array" contains a typo
// ("no" -> "not").
905 template <
typename GUM_SCALAR >
906 INLINE
void PRMFactory< GUM_SCALAR >::incArray(
const std::string& l_i,
907 const std::string& r_i) {
908 PRMSystem< GUM_SCALAR >* model =
static_cast< PRMSystem< GUM_SCALAR >* >(
909 checkStack__(1, PRMObject::prm_type::SYSTEM));
911 if (model->isArray(l_i)) {
912 if (model->isInstance(r_i)) {
913 model->add(l_i, model->get(r_i));
915 GUM_ERROR(NotFound,
"right value is not an instance");
918 GUM_ERROR(NotFound,
"left value is no an array");
922 template <
typename GUM_SCALAR >
924 PRMFactory< GUM_SCALAR >::setReferenceSlot(
const std::string& l_i,
925 const std::string& l_ref,
926 const std::string& r_i) {
927 auto model =
static_cast< PRMSystem< GUM_SCALAR >* >(
928 checkStack__(1, PRMObject::prm_type::SYSTEM));
929 std::vector< PRMInstance< GUM_SCALAR >* > lefts;
930 std::vector< PRMInstance< GUM_SCALAR >* > rights;
932 if (model->isInstance(l_i)) {
933 lefts.push_back(&(model->get(l_i)));
934 }
else if (model->isArray(l_i)) {
935 for (
const auto& elt: model->getArray(l_i))
936 lefts.push_back(elt);
938 GUM_ERROR(NotFound,
"left value does not name an instance or an array");
941 if (model->isInstance(r_i)) {
942 rights.push_back(&(model->get(r_i)));
943 }
else if (model->isArray(r_i)) {
944 for (
const auto& elt: model->getArray(r_i))
945 rights.push_back(elt);
947 GUM_ERROR(NotFound,
"left value does not name an instance or an array");
950 for (
const auto l: lefts) {
951 for (
const auto r: rights) {
952 auto& elt = l->type().get(l_ref);
953 if (PRMClassElement< GUM_SCALAR >::isReferenceSlot(elt)) {
954 l->add(elt.id(), *r);
957 GUM_ERROR(NotFound,
"unfound reference slot");
// buildSlotChain__: parses the dotted path `name` (split by decomposePath) as
// a slot chain starting in container `start`. Reference slots are followed
// from one container to the next, and an attribute/aggregate is accepted as
// the last element. The last element is flagged as an output node of the
// container holding it, and a new PRMSlotChain is returned.
// Returns nullptr if any path component cannot be found (NotFound). Other
// failure cases are on lines missing from this excerpt.
// NOTE(review): the code that advances `current` and inserts reference slots
// into `elts` is on lines missing from this excerpt.
963 template <
typename GUM_SCALAR >
964 INLINE PRMSlotChain< GUM_SCALAR >* PRMFactory< GUM_SCALAR >::buildSlotChain__(
965 PRMClassElementContainer< GUM_SCALAR >* start,
966 const std::string& name) {
967 std::vector< std::string > v;
968 decomposePath(name, v);
969 PRMClassElementContainer< GUM_SCALAR >* current = start;
970 PRMReferenceSlot< GUM_SCALAR >* ref =
nullptr;
971 Sequence< PRMClassElement< GUM_SCALAR >* > elts;
973 for (size_t i = 0; i < v.size(); ++i) {
975 switch (current->get(v[i]).elt_type()) {
976 case PRMClassElement< GUM_SCALAR >::prm_refslot:
977 ref = &(
static_cast< PRMReferenceSlot< GUM_SCALAR >& >(
978 current->get(v[i])));
984 case PRMClassElement< GUM_SCALAR >::prm_aggregate:
985 case PRMClassElement< GUM_SCALAR >::prm_attribute:
// An attribute/aggregate is only valid as the final path component.
987 if (i == v.size() - 1) {
988 elts.insert(&(current->get(v[i])));
998 }
catch (NotFound&) {
return nullptr; }
1001 GUM_ASSERT(v.size() == elts.size());
// The chain's target must be visible from outside its container.
1003 current->setOutputNode(*(elts.back()),
true);
1005 return new PRMSlotChain< GUM_SCALAR >(name, elts);
// retrieveInputs__: resolves each name in `chains` to an element of class `c`
// and appends it to `inputs`. Names that are not direct elements are built as
// slot chains and added to `c`; NotFound("unknown slot chain") is thrown if
// that fails. The return value records whether a slot chain was involved
// (retVal is OR-ed with isSlotChain).
// Inputs whose type differs from the common type (retrieveCommonType__) are
// then paired with a cast version named "(type)elt" (or ".(type)last" for
// slot chains), which is looked up or built; the pairs are collected in
// `toAdd`.
// NOTE(review): the declaration of `toAdd`, its use and the return statement
// are on lines missing from this excerpt.
1008 template <
typename GUM_SCALAR >
1009 INLINE
bool PRMFactory< GUM_SCALAR >::retrieveInputs__(
1010 PRMClass< GUM_SCALAR >* c,
1011 const std::vector< std::string >& chains,
1012 std::vector< PRMClassElement< GUM_SCALAR >* >& inputs) {
1013 bool retVal =
false;
1015 for (size_t i = 0; i < chains.size(); ++i) {
1017 inputs.push_back(&(c->get(chains[i])));
1019 || PRMClassElement< GUM_SCALAR >::isSlotChain(*(inputs.back()));
// Not a direct element: interpret the name as a slot chain.
1020 }
catch (NotFound&) {
1021 inputs.push_back(buildSlotChain__(c, chains[i]));
1024 if (inputs.back()) {
1025 c->add(inputs.back());
1027 GUM_ERROR(NotFound,
"unknown slot chain");
1032 PRMType* t = retrieveCommonType__(inputs);
1034 std::vector< std::pair< PRMClassElement< GUM_SCALAR >*,
1035 PRMClassElement< GUM_SCALAR >* > >
1038 for (
const auto& elt: inputs) {
1039 if ((*elt).type() != (*t)) {
1040 if (PRMClassElement< GUM_SCALAR >::isSlotChain(*elt)) {
1041 PRMSlotChain< GUM_SCALAR >* sc
1042 =
static_cast< PRMSlotChain< GUM_SCALAR >* >(elt);
// Build "a.b.(T)last": the chain prefix followed by the cast last element.
1043 std::stringstream name;
1045 for (Size idx = 0; idx < sc->chain().size() - 1; ++idx) {
1046 name << sc->chain().atPos(idx)->name() <<
".";
1049 name <<
".(" << t->name() <<
")" << sc->lastElt().name();
1052 toAdd.push_back(std::make_pair(elt, &(c->get(name.str()))));
1053 }
catch (NotFound&) {
1055 std::make_pair(elt, buildSlotChain__(c, name.str())));
// Plain element: its cast version is named "(T)elt".
1058 std::stringstream name;
1059 name <<
"(" << t->name() <<
")" << elt->name();
1060 toAdd.push_back(std::make_pair(elt, &(c->get(name.str()))));
// retrieveCommonType__: returns the most specific type shared by all elements
// of `elts`. For each element, every type on its super-type chain (the type
// itself, then superType() while isSubType()) is counted by name. Among the
// types counted once per element (counter == elts.size()), the one with the
// greatest typeDepth__ is returned.
// Throws WrongClassElement if an element has no type (OperationNotAllowed),
// and NotFound if no common type exists.
// NOTE(review): the resets of `current` / `max_depth` between the two phases
// are on lines missing from this excerpt.
1068 template <
typename GUM_SCALAR >
1069 INLINE PRMType* PRMFactory< GUM_SCALAR >::retrieveCommonType__(
1070 const std::vector< PRMClassElement< GUM_SCALAR >* >& elts) {
1071 const PRMType* current =
nullptr;
1072 HashTable< std::string, Size > counters;
1075 for (
const auto& elt: elts) {
1077 current = &((*elt).type());
// Count every type on this element's super-type chain.
1079 while (current != 0) {
1081 if (counters.exists(current->name())) {
1082 ++(counters[current->name()]);
1084 counters.insert(current->name(), 1);
1088 if (current->isSubType()) {
1089 current = &(current->superType());
1094 }
catch (OperationNotAllowed&) {
1095 GUM_ERROR(WrongClassElement,
1096 "found a ClassElement<GUM_SCALAR> without a type");
1105 int current_depth = 0;
// Keep the deepest (most specific) type shared by every element.
1107 for (
const auto& elt: counters) {
1108 if ((elt.second) == elts.size()) {
1109 current_depth = typeDepth__(retrieveType__(elt.first));
1111 if (current_depth > max_depth) {
1112 max_depth = current_depth;
1113 current = retrieveType__(elt.first);
1118 if (current) {
return const_cast< PRMType* >(current); }
1120 GUM_ERROR(NotFound,
"could not find a common type");
// addNoisyOrCompound: adds a boolean noisy-or attribute `name` to the current
// class. Its parents are the elements named in `chains`, cast to their common
// type when needed (cast slot chains are built on demand).
//  - one number: a MultiDimNoisyORCompound with that single weight, wrapped
//    in a PRMScalarAttribute;
//  - one number per parent: a PRMFuncAttribute whose noisy-or has a causal
//    weight per parent variable;
//  - anything else: OperationNotAllowed.
// Throws FactoryInvalidState if the current object is not a class, NotFound
// if a cast parent cannot be found, and OperationNotAllowed if `labels` is
// non-empty (labels are not supported).
// NOTE(review): in the visible order, the `labels` check comes after the
// attribute is built, so the attribute may already be created/added when
// this error is raised. Consider checking it first. The `leak` parameter and
// several statements are on lines missing from this excerpt.
1123 template <
typename GUM_SCALAR >
1124 INLINE
void PRMFactory< GUM_SCALAR >::addNoisyOrCompound(
1125 const std::string& name,
1126 const std::vector< std::string >& chains,
1127 const std::vector<
float >& numbers,
1129 const std::vector< std::string >& labels) {
1130 if (currentType() != PRMObject::prm_type::CLASS) {
1131 GUM_ERROR(gum::FactoryInvalidState,
"invalid state to add a noisy-or");
1134 PRMClass< GUM_SCALAR >* c
1135 =
dynamic_cast< gum::prm::PRMClass< GUM_SCALAR >* >(getCurrent());
1137 std::vector< PRMClassElement< GUM_SCALAR >* > parents;
1139 for (
const auto& elt: chains)
1140 parents.push_back(&(c->get(elt)));
1142 PRMType* common_type = retrieveCommonType__(parents);
// Replace every parent whose type differs from the common type by its cast.
1144 for (size_t idx = 0; idx < parents.size(); ++idx) {
1145 if (parents[idx]->type() != (*common_type)) {
1146 PRMClassElement< GUM_SCALAR >* parent = parents[idx];
1149 std::string safe_name = parent->cast(*common_type);
1151 if (!c->exists(safe_name)) {
1152 if (PRMClassElement< GUM_SCALAR >::isSlotChain(*parent)) {
1153 parents[idx] = buildSlotChain__(c, safe_name);
1154 c->add(parents[idx]);
1156 GUM_ERROR(NotFound,
"unable to find parent");
1159 parents[idx] = &(c->get(safe_name));
// A single number is used as one shared weight.
1164 if (numbers.size() == 1) {
1166 =
new gum::MultiDimNoisyORCompound< GUM_SCALAR >(leak, numbers.front());
1167 auto attr =
new PRMScalarAttribute< GUM_SCALAR >(name,
1168 retrieveType(
"boolean"),
// One number per parent: per-parent causal weights.
1171 }
else if (numbers.size() == parents.size()) {
1172 gum::MultiDimNoisyORCompound< GUM_SCALAR >* noisy
1173 =
new gum::MultiDimNoisyORCompound< GUM_SCALAR >(leak);
1174 gum::prm::PRMFuncAttribute< GUM_SCALAR >* attr
1175 =
new gum::prm::PRMFuncAttribute< GUM_SCALAR >(name,
1176 retrieveType(
"boolean"),
1179 for (size_t idx = 0; idx < numbers.size(); ++idx) {
1180 noisy->causalWeight(parents[idx]->type().variable(), numbers[idx]);
1185 GUM_ERROR(OperationNotAllowed,
"invalid parameters for a noisy or");
1188 if (!labels.empty()) {
1189 GUM_ERROR(OperationNotAllowed,
1190 "labels definitions not handle for noisy-or");
// retrieveType__: resolves a (possibly unqualified) type name. Candidates are
// tried in this order:
//  1. `name` as a full name;
//  2. `name` with the current package prefix (addPrefix__);
//  3. `name` in the parent package of currentPackage();
//  4. `name` in each namespace of the innermost namespaces__ list (imports).
// If a later candidate also exists but differs from the full name already
// found, DuplicateElement ("ambiguous") is thrown. If nothing is found,
// NotFound is thrown.
// NOTE(review): the `type == nullptr` guards that pair with the
// `else if (full_name != ...)` branches are on lines missing from this
// excerpt (compare the explicit guards in retrieveClass__).
1194 template <
typename GUM_SCALAR >
1196 PRMFactory< GUM_SCALAR >::retrieveType__(
const std::string& name)
const {
1197 PRMType* type =
nullptr;
1198 std::string full_name;
// 1. Full name.
1201 if (prm__->typeMap__.exists(name)) {
1202 type = prm__->typeMap__[name];
// 2. Current package prefix.
1207 std::string prefixed = addPrefix__(name);
1208 if (prm__->typeMap__.exists(prefixed)) {
1210 type = prm__->typeMap__[prefixed];
1211 full_name = prefixed;
1212 }
else if (full_name != prefixed) {
1213 GUM_ERROR(DuplicateElement,
1214 "Type name '" << name <<
"' is ambiguous: specify full name.");
// 3. Parent package of the current package.
1219 std::string relatif_ns = currentPackage();
1220 size_t last_dot = relatif_ns.find_last_of(
'.');
1221 if (last_dot != std::string::npos) {
1222 relatif_ns = relatif_ns.substr(0, last_dot) +
'.' + name;
1223 if (prm__->typeMap__.exists(relatif_ns)) {
1225 type = prm__->typeMap__[relatif_ns];
1226 full_name = relatif_ns;
1227 }
else if (full_name != relatif_ns) {
1228 GUM_ERROR(DuplicateElement,
1229 "Type name '" << name
1230 <<
"' is ambiguous: specify full name.");
// 4. Imported namespaces (innermost list only).
1237 if (!namespaces__.empty()) {
1238 auto ns_list = namespaces__.back();
1239 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1240 std::string ns = (*ns_list)[i];
1241 std::string ns_name = ns +
"." + name;
1242 if (prm__->typeMap__.exists(ns_name)) {
1244 type = prm__->typeMap__[ns_name];
1245 full_name = ns_name;
1246 }
else if (full_name != ns_name) {
1247 GUM_ERROR(DuplicateElement,
1248 "Type name '" << name
1249 <<
"' is ambiguous: specify full name.");
1256 GUM_ERROR(NotFound,
"Type '" << name <<
"' not found, check imports.");
// retrieveClass__: resolves a (possibly unqualified) class name. It tries the
// full name, then the current package prefix, then each namespace of the
// innermost namespaces__ list. A second, different match raises
// DuplicateElement ("ambiguous"); no match raises NotFound.
// NOTE(review): some guards and closing lines are missing from this excerpt.
1262 template <
typename GUM_SCALAR >
1263 PRMClass< GUM_SCALAR >*
1264 PRMFactory< GUM_SCALAR >::retrieveClass__(
const std::string& name)
const {
1265 PRMClass< GUM_SCALAR >* a_class =
nullptr;
1266 std::string full_name;
// Full name.
1269 if (prm__->classMap__.exists(name)) {
1270 a_class = prm__->classMap__[name];
// Current package prefix.
1275 std::string prefixed = addPrefix__(name);
1276 if (prm__->classMap__.exists(prefixed)) {
1277 if (a_class ==
nullptr) {
1278 a_class = prm__->classMap__[prefixed];
1279 full_name = prefixed;
1280 }
else if (full_name != prefixed) {
1281 GUM_ERROR(DuplicateElement,
1282 "Class name '" << name
1283 <<
"' is ambiguous: specify full name.");
// Imported namespaces (innermost list only).
1288 if (!namespaces__.empty()) {
1289 auto ns_list = namespaces__.back();
1290 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1291 std::string ns = (*ns_list)[i];
1292 std::string ns_name = ns +
"." + name;
1293 if (prm__->classMap__.exists(ns_name)) {
1295 a_class = prm__->classMap__[ns_name];
1296 full_name = ns_name;
1297 }
else if (full_name != ns_name) {
1298 GUM_ERROR(DuplicateElement,
1299 "Class name '" << name
1300 <<
"' is ambiguous: specify full name.");
1307 GUM_ERROR(NotFound,
"Class '" << name <<
"' not found, check imports.");
// retrieveInterface__: resolves a (possibly unqualified) interface name. It
// tries the full name, then the current package prefix, then each namespace
// of the innermost namespaces__ list. A second, different match raises
// DuplicateElement ("ambiguous"). If nothing is found, an error reporting
// "Interface '<name>' not found, check imports." is raised (its GUM_ERROR
// line is missing from this excerpt).
1313 template <
typename GUM_SCALAR >
1314 PRMInterface< GUM_SCALAR >* PRMFactory< GUM_SCALAR >::retrieveInterface__(
1315 const std::string& name)
const {
1316 PRMInterface< GUM_SCALAR >* interface =
nullptr;
1317 std::string full_name;
// Full name.
1320 if (prm__->interfaceMap__.exists(name)) {
1321 interface = prm__->interfaceMap__[name];
// Current package prefix.
1326 std::string prefixed = addPrefix__(name);
1327 if (prm__->interfaceMap__.exists(prefixed)) {
1328 if (interface ==
nullptr) {
1329 interface = prm__->interfaceMap__[prefixed];
1330 full_name = prefixed;
1331 }
else if (full_name != prefixed) {
1332 GUM_ERROR(DuplicateElement,
1333 "Interface name '" << name
1334 <<
"' is ambiguous: specify full name.");
// Imported namespaces (innermost list only).
1339 if (!namespaces__.empty()) {
1340 auto ns_list = namespaces__.back();
1342 for (gum::Size i = 0; i < ns_list->size(); ++i) {
1343 std::string ns = (*ns_list)[i];
1344 std::string ns_name = ns +
"." + name;
1346 if (prm__->interfaceMap__.exists(ns_name)) {
1347 if (interface ==
nullptr) {
1348 interface = prm__->interfaceMap__[ns_name];
1349 full_name = ns_name;
1350 }
else if (full_name != ns_name) {
1351 GUM_ERROR(DuplicateElement,
1353 << name <<
"' is ambiguous: specify full name.");
1359 if (interface ==
nullptr) {
1361 "Interface '" << name <<
"' not found, check imports.");
1367 template <
typename GUM_SCALAR >
1368 INLINE PRMFactory< GUM_SCALAR >::PRMFactory() {
1369 GUM_CONSTRUCTOR(PRMFactory);
1370 prm__ =
new PRM< GUM_SCALAR >();
1373 template <
typename GUM_SCALAR >
1374 INLINE PRMFactory< GUM_SCALAR >::PRMFactory(PRM< GUM_SCALAR >* prm) :
1375 IPRMFactory(), prm__(prm) {
1376 GUM_CONSTRUCTOR(PRMFactory);
// Destructor: empties namespaces__, the stack of per-package namespace lists.
// NOTE(review): this listing stops after `namespaces__.pop_back();` (original
// lines 1385-1388 are missing); presumably the popped list `ns` is deleted
// there, as popPackage() does — TODO confirm. Nothing visible here deletes
// prm__ or the objects still on stack__.
1379 template <
typename GUM_SCALAR >
1380 INLINE PRMFactory< GUM_SCALAR >::~PRMFactory() {
1381 GUM_DESTRUCTOR(PRMFactory);
1382 while (!namespaces__.empty()) {
1383 auto ns = namespaces__.back();
1384 namespaces__.pop_back();
// Accessor for the PRM this factory builds into.
// NOTE(review): the body (original lines 1391-1393) is missing from this
// listing; presumably it is `return prm__;` — TODO confirm.
1389 template <
typename GUM_SCALAR >
1390 INLINE PRM< GUM_SCALAR >* PRMFactory< GUM_SCALAR >::prm()
const {
1394 template <
typename GUM_SCALAR >
1395 INLINE PRMObject::prm_type PRMFactory< GUM_SCALAR >::currentType()
const {
1396 if (stack__.size() == 0) { GUM_ERROR(NotFound,
"no object being built"); }
1398 return stack__.back()->obj_type();
1401 template <
typename GUM_SCALAR >
1402 INLINE PRMObject* PRMFactory< GUM_SCALAR >::getCurrent() {
1403 if (stack__.size() == 0) { GUM_ERROR(NotFound,
"no object being built"); }
1405 return stack__.back();
1408 template <
typename GUM_SCALAR >
1409 INLINE
const PRMObject* PRMFactory< GUM_SCALAR >::getCurrent()
const {
1410 if (stack__.size() == 0) { GUM_ERROR(NotFound,
"no object being built"); }
1412 return stack__.back();
// Closes the object on top of the factory stack, if there is one.
// NOTE(review): only the first lines are visible (original lines 1419-1425
// are missing); presumably the top object is popped and returned, and
// nullptr is returned for an empty stack — TODO confirm.
1415 template <
typename GUM_SCALAR >
1416 INLINE PRMObject* PRMFactory< GUM_SCALAR >::closeCurrent() {
1417 if (stack__.size() > 0) {
1418 PRMObject* obj = stack__.back();
1426 template <
typename GUM_SCALAR >
1427 INLINE std::string PRMFactory< GUM_SCALAR >::currentPackage()
const {
1428 return (packages__.empty()) ?
"" : packages__.back();
// Starts the declaration of a labelized (discrete) type named `name`,
// prefixed with the current package, and pushes it on the factory stack.
// Labels are then added with addLabel().
// @throws DuplicateElement if the prefixed name is already a registered type.
// NOTE(review): original lines 1432 (presumably the return type), 1438-1439
// and 1442 are missing from this listing. The two `auto t = new PRMType(...)`
// statements below belong to branches that are not shown, presumably on
// `super` being empty. The second branch also sets superType__ from `super`
// and allocates label_map__, which addLabel(l, extends) fills — TODO confirm.
1431 template <
typename GUM_SCALAR >
1433 PRMFactory< GUM_SCALAR >::startDiscreteType(
const std::string& name,
1434 std::string super) {
1435 std::string real_name = addPrefix__(name);
1436 if (prm__->typeMap__.exists(real_name)) {
1437 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.");
// plain type: empty LabelizedVariable, labels added later
1440 auto t =
new PRMType(LabelizedVariable(real_name,
"", 0));
1441 stack__.push_back(t);
// subtype: remembers its super type and a label -> super-label index map
1443 auto t =
new PRMType(LabelizedVariable(real_name,
"", 0));
1444 t->superType__ = retrieveType__(super);
1445 t->label_map__ =
new std::vector< Idx >();
1446 stack__.push_back(t);
// Adds label `l` to the discrete type on top of the factory stack.
//  - extends == "": the current type must not be a subtype; `l` is
//    (presumably) added as a plain label of its LabelizedVariable.
//  - otherwise: the current type must be a subtype, and `extends` must be a
//    label of its super type. The index of that super-type label is appended
//    to label_map__, mapping the new label onto it.
// @throws FatalError if the current type's variable is not a LabelizedVariable.
// @throws OperationNotAllowed on a subtype / non-subtype mismatch.
// @throws DuplicateElement if label `l` already exists.
// @throws NotFound if `extends` is not a label of the super type.
// NOTE(review): many original lines are missing from this listing (e.g. 1454,
// 1457-1458, 1463-1466, 1469-1471, 1474-1475, 1480-1483, 1486-1487,
// 1490-1491, 1493-1498). These include the declarations of `t` and `found`,
// the null-checks on `var`, the `try` blocks and the label insertion itself —
// TODO confirm against the real source.
1450 template <
typename GUM_SCALAR >
1451 INLINE
void PRMFactory< GUM_SCALAR >::addLabel(
const std::string& l,
1452 std::string extends) {
1453 if (extends ==
"") {
1455 =
static_cast< PRMType* >(checkStack__(1, PRMObject::prm_type::TYPE));
1456 LabelizedVariable* var =
dynamic_cast< LabelizedVariable* >(t->var__);
1459 GUM_ERROR(FatalError,
1460 "the current type's variable is not a LabelizedVariable.");
1461 }
else if (t->superType__) {
1462 GUM_ERROR(OperationNotAllowed,
"current type is a subtype.");
1467 }
catch (DuplicateElement&) {
1468 GUM_ERROR(DuplicateElement,
"a label '" << l <<
"' already exists");
// extends != "": the new label refines label `extends` of the super type
1472 =
static_cast< PRMType* >(checkStack__(1, PRMObject::prm_type::TYPE));
1473 LabelizedVariable* var =
dynamic_cast< LabelizedVariable* >(t->var__);
1476 GUM_ERROR(FatalError,
1477 "the current type's variable is not a LabelizedVariable.");
1478 }
else if (!t->superType__) {
1479 GUM_ERROR(OperationNotAllowed,
"current type is not a subtype.");
// linear search for `extends` among the super type's labels
1484 for (Idx i = 0; i < t->superType__->var__->domainSize(); ++i) {
1485 if (t->superType__->var__->label(i) == extends) {
1488 }
catch (DuplicateElement&) {
1489 GUM_ERROR(DuplicateElement,
"a label '" << l <<
"' already exists");
1492 t->label_map__->push_back(i);
1499 if (!found) { GUM_ERROR(NotFound,
"inexistent label in super type."); }
// Finishes the discrete type on top of the stack and registers it in the PRM
// (in typeMap__ under its name, and in types__).
// @throws OperationNotAllowed if the type is not a valid subtype (isValid__())
//         or has fewer than two labels (domainSize() < 2).
// NOTE(review): original lines 1505 (declaration of `t`), 1507, 1513-1514,
// 1516 and the end of the definition are missing from this listing;
// presumably the type is popped from stack__ there — TODO confirm.
1503 template <
typename GUM_SCALAR >
1504 INLINE
void PRMFactory< GUM_SCALAR >::endDiscreteType() {
1506 =
static_cast< PRMType* >(checkStack__(1, PRMObject::prm_type::TYPE));
1508 if (!t->isValid__()) {
1509 GUM_ERROR(OperationNotAllowed,
"current type is not a valid subtype");
1510 }
else if (t->variable().domainSize() < 2) {
1511 GUM_ERROR(OperationNotAllowed,
1512 "current type is not a valid discrete type");
1515 prm__->typeMap__.insert(t->name(), t);
1517 prm__->types__.insert(t);
// Starts the declaration of a discretized type named `name` (prefixed with
// the current package) and pushes it on the factory stack. The type is backed
// by a DiscretizedVariable<double>; addTick() later casts it back to that
// class.
// @throws DuplicateElement if the prefixed name is already a registered type.
// NOTE(review): original line 1522 (presumably the return type) is missing
// from this listing, as are the closing braces — TODO confirm.
1521 template <
typename GUM_SCALAR >
1523 PRMFactory< GUM_SCALAR >::startDiscretizedType(
const std::string& name) {
1524 std::string real_name = addPrefix__(name);
1525 if (prm__->typeMap__.exists(real_name)) {
1526 GUM_ERROR(DuplicateElement,
"'" << real_name <<
"' is already used.");
1528 auto var = DiscretizedVariable<
double >(real_name,
"");
1529 auto t =
new PRMType(var);
1530 stack__.push_back(t);
// Adds tick `tick` to the discretized type on top of the factory stack.
// @throws FatalError if the current type's variable is not a
//         DiscretizedVariable<double>.
// @throws OperationNotAllowed when the variable rejects the tick with
//         DefaultInLabel (tick already used).
// NOTE(review): the FatalError message below says "LabelizedVariable", but the
// dynamic_cast checks for DiscretizedVariable<double>. The text looks
// copy-pasted from addLabel and should probably be corrected.
// NOTE(review): original lines 1535, 1539-1540 and 1543-1546 are missing from
// this listing. They presumably hold the declaration of `t`, the null-check on
// `var`, the `try {` and the actual tick insertion — TODO confirm.
1533 template <
typename GUM_SCALAR >
1534 INLINE
void PRMFactory< GUM_SCALAR >::addTick(
double tick) {
1536 =
static_cast< PRMType* >(checkStack__(1, PRMObject::prm_type::TYPE));
1537 DiscretizedVariable<
double >* var
1538 =
dynamic_cast< DiscretizedVariable<
double >* >(t->var__);
1541 GUM_ERROR(FatalError,
1542 "the current type's variable is not a LabelizedVariable.");
1547 }
catch (DefaultInLabel&) {
1548 GUM_ERROR(OperationNotAllowed,
"tick already in used for this variable");
// Finishes the discretized type on top of the stack and registers it in the
// PRM (in typeMap__ under its name, and in types__).
// @throws OperationNotAllowed if the type has fewer than two values
//         (domainSize() < 2).
// NOTE(review): original lines 1554 (declaration of `t`), 1556, 1560-1561,
// 1563 and the end of the definition are missing from this listing;
// presumably the type is popped from stack__ there — TODO confirm.
1552 template <
typename GUM_SCALAR >
1553 INLINE
void PRMFactory< GUM_SCALAR >::endDiscretizedType() {
1555 =
static_cast< PRMType* >(checkStack__(1, PRMObject::prm_type::TYPE));
1557 if (t->variable().domainSize() < 2) {
1558 GUM_ERROR(OperationNotAllowed,
1559 "current type is not a valid discrete type");
1562 prm__->typeMap__.insert(t->name(), t);
1564 prm__->types__.insert(t);
1568 template <
typename GUM_SCALAR >
1569 INLINE
void PRMFactory< GUM_SCALAR >::addRangeType(
const std::string& name,
1572 std::string real_name = addPrefix__(name);
1573 if (prm__->typeMap__.exists(real_name)) {
1574 std::stringstream msg;
1575 msg <<
"\"" << real_name <<
"' is already used.";
1576 GUM_ERROR(DuplicateElement, msg.str());
1579 auto var = RangeVariable(real_name,
"", minVal, maxVal);
1580 auto t =
new PRMType(var);
1582 if (t->variable().domainSize() < 2) {
1583 GUM_ERROR(OperationNotAllowed,
1584 "current type is not a valid discrete type");
1587 prm__->typeMap__.insert(t->name(), t);
1588 prm__->types__.insert(t);
// Closes the interface on top of the factory stack.
// @throws FactoryInvalidState (raised by checkStack__) if the top object is
//         not an interface.
// NOTE(review): the rest of the definition (original lines 1594-1596,
// presumably popping stack__) is missing from this listing — TODO confirm.
1591 template <
typename GUM_SCALAR >
1592 INLINE
void PRMFactory< GUM_SCALAR >::endInterface() {
1593 checkStack__(1, PRMObject::prm_type::PRM_INTERFACE);
// Declares attribute `name` of type `type` in the interface being built.
// It is only allowed on top of an interface (checkStack__ on PRM_INTERFACE)
// and delegates creation to startAttribute(type, name).
// NOTE(review): original lines 1602-1604 are missing from this listing;
// presumably the attribute is closed with endAttribute() there — TODO confirm.
1597 template <
typename GUM_SCALAR >
1598 INLINE
void PRMFactory< GUM_SCALAR >::addAttribute(
const std::string& type,
1599 const std::string& name) {
1600 checkStack__(1, PRMObject::prm_type::PRM_INTERFACE);
1601 startAttribute(type, name);
// Creates attribute `name` of type `type` in the class or interface on top of
// the factory stack, then pushes the attribute on the stack.
//  - Inside a class and without `scalar_attr`: a PRMFormAttribute is built.
//  - Otherwise (interface, or `scalar_attr` set): a PRMScalarAttribute is built.
// If adding it to the container raises DuplicateElement, it is treated as an
// overload (c->overload(a)). On any Exception caught by the outer handler, the
// new attribute is deleted unless the container already holds its id.
// NOTE(review): the third parameter (original line 1608, presumably
// `bool scalar_attr`) and lines 1611, 1615, 1617-1618, 1620-1621, 1623-1626
// and 1630-1631 are missing from this listing. These include the `try` and
// `else` lines and the c->add(...) call. Whether the outer handler rethrows is
// not visible — TODO confirm.
1605 template <
typename GUM_SCALAR >
1606 INLINE
void PRMFactory< GUM_SCALAR >::startAttribute(
const std::string& type,
1607 const std::string& name,
1609 PRMClassElementContainer< GUM_SCALAR >* c = checkStackContainter__(1);
1610 PRMAttribute< GUM_SCALAR >* a =
nullptr;
1612 if (PRMObject::isClass(*c) && (!scalar_attr)) {
1613 a =
new PRMFormAttribute< GUM_SCALAR >(
1614 static_cast< PRMClass< GUM_SCALAR >& >(*c),
1616 *retrieveType__(type));
1619 a =
new PRMScalarAttribute< GUM_SCALAR >(name, *retrieveType__(type));
1622 std::string dot =
".";
1627 }
catch (DuplicateElement&) { c->overload(a); }
1628 }
catch (Exception&) {
1629 if (a !=
nullptr && (!c->exists(a->id()))) {
delete a; }
1632 stack__.push_back(a);
1635 template <
typename GUM_SCALAR >
1637 PRMFactory< GUM_SCALAR >::continueAttribute(
const std::string& name) {
1638 PRMClassElementContainer< GUM_SCALAR >* c = checkStackContainter__(1);
1639 if (!c->exists(name)) { GUM_ERROR(NotFound, name <<
"not found"); }
1640 auto& a = c->get(name);
1641 if (!PRMClassElement< GUM_SCALAR >::isAttribute(a)) {
1642 GUM_ERROR(OperationNotAllowed, name <<
" not an attribute");
1644 stack__.push_back(&a);
// Closes the attribute on top of the factory stack.
// @throws FactoryInvalidState (raised by checkStack__) if the top object is
//         not an attribute.
// NOTE(review): the rest of the definition (original lines 1650-1652,
// presumably popping stack__) is missing from this listing — TODO confirm.
1647 template <
typename GUM_SCALAR >
1648 INLINE
void PRMFactory< GUM_SCALAR >::endAttribute() {
1649 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute);
1653 template <
typename GUM_SCALAR >
1654 INLINE
void PRMFactory< GUM_SCALAR >::startSystem(
const std::string& name) {
1655 if (prm__->systemMap__.exists(name)) {
1656 GUM_ERROR(DuplicateElement,
"'" << name <<
"' is already used.");
1658 PRMSystem< GUM_SCALAR >* model
1659 =
new PRMSystem< GUM_SCALAR >(addPrefix__(name));
1660 stack__.push_back(model);
1661 prm__->systemMap__.insert(model->name(), model);
1662 prm__->systems__.insert(model);
// Closes the system on top of the factory stack and instantiates it
// (model->instantiate()).
// @throws FatalError ("could not create system") when any Exception is raised
//         inside the try block.
// NOTE(review): original lines 1667 and 1670 are missing from this listing,
// presumably `try {` and popping stack__. If checkStack__ sits inside that
// try, its FactoryInvalidState is also reported as FatalError — TODO confirm.
1665 template <
typename GUM_SCALAR >
1666 INLINE
void PRMFactory< GUM_SCALAR >::endSystem() {
1668 PRMSystem< GUM_SCALAR >* model =
static_cast< PRMSystem< GUM_SCALAR >* >(
1669 checkStack__(1, PRMObject::prm_type::SYSTEM));
1671 model->instantiate();
1672 }
catch (Exception&) { GUM_ERROR(FatalError,
"could not create system"); }
// Adds instance `name` of class `type` to the system being built.
// If the class declares parameters, this forwards to the parameterized
// overload with an empty parameter table, so every parameter takes its
// declared value. Otherwise the instance is created directly via
// addInstance__.
// NOTE(review): original lines 1679-1680 and 1684-1685 are missing from this
// listing (presumably a comment and `} else {`) — TODO confirm.
1675 template <
typename GUM_SCALAR >
1676 INLINE
void PRMFactory< GUM_SCALAR >::addInstance(
const std::string& type,
1677 const std::string& name) {
1678 auto c = retrieveClass__(type);
1681 if (c->parameters().size() > 0) {
1682 HashTable< std::string,
double > params;
1683 addInstance(type, name, params);
1686 addInstance__(c, name);
// Adds instance `name` of class `type`, with parameter values `params`, to the
// system being built.
//  - Class without parameters: only accepted with an empty `params`.
//  - Class with parameters: missing parameters are filled with their declared
//    values (p->value()). A subclass named "<class><p1=v1,p2=v2,...>" is then
//    declared with startClass(sub_c, c->name()), carrying those values as
//    "int"/"real" parameters; packages__ is saved around this and restored.
//    The instance is presumably created from that subclass. If the subclass
//    already exists (DuplicateElement), it is retrieved and reused.
// @throws OperationNotAllowed if `params` is non-empty for a class without
//         parameters.
// NOTE(review): sub_c is built by iterating the HashTable `my_params`, so the
//   name depends on the table's iteration order. Confirm that order is stable
//   for equal parameter sets, or identical instantiations could yield
//   distinct subclasses.
// NOTE(review): the local `type` declared in the second loop shadows the
//   `type` parameter.
// NOTE(review): many original lines are missing from this listing (e.g. 1696,
//   1700, 1703-1705, 1707, 1711-1714, 1717, 1720-1722, 1724-1726, 1728-1729,
//   1731-1732, 1735, 1738-1739, 1741-1745, 1747, 1749-1750) — TODO confirm.
1690 template <
typename GUM_SCALAR >
1691 INLINE
void PRMFactory< GUM_SCALAR >::addInstance(
1692 const std::string& type,
1693 const std::string& name,
1694 const HashTable< std::string,
double >& params) {
1695 auto c = retrieveClass__(type);
1697 if (c->parameters().empty()) {
1698 if (params.empty()) {
1699 addInstance__(c, name);
1701 GUM_ERROR(OperationNotAllowed,
1702 "Class " + type +
" does not have parameters");
// complete the user-supplied values with the class defaults
1706 auto my_params = params;
1708 for (
const auto& p: c->parameters()) {
1709 if (!my_params.exists(p->name())) {
1710 my_params.insert(p->name(), p->value());
// build the subclass name "<class><k1=v1,k2=v2,...>"
1715 std::stringstream sBuff;
1716 sBuff << c->name() <<
"<";
1718 for (
const auto& p: my_params) {
1719 sBuff << p.first <<
"=" << p.second <<
",";
// drop the trailing ',' and close the bracket
1723 std::string sub_c = sBuff.str().substr(0, sBuff.str().size() - 1) +
">";
1727 auto pck_cpy = packages__;
1730 startClass(sub_c, c->name());
1733 for (
auto p: my_params) {
1734 auto type =
static_cast< PRMParameter< GUM_SCALAR >& >(c->get(p.first))
1736 if (type == PRMParameter< GUM_SCALAR >::ParameterType::INT) {
1737 addParameter(
"int", p.first, p.second);
1740 addParameter(
"real", p.first, p.second);
1746 packages__ = pck_cpy;
1748 }
catch (DuplicateElement&) {
1751 c = retrieveClass__(sub_c);
1752 addInstance__(c, name);
// Creates PRMInstance `name` of class `*type` for the system on top of the
// factory stack. On OperationNotAllowed, the instance allocated here is
// deleted.
// NOTE(review): the return type (line 1757), the `try {` (line 1761), the
// lines that presumably hand `i` to system `s` (1765-1766) and the end of the
// catch (1769-1772) are missing from this listing. Whether the exception is
// rethrown is not visible — TODO confirm.
1756 template <
typename GUM_SCALAR >
1758 PRMFactory< GUM_SCALAR >::addInstance__(PRMClass< GUM_SCALAR >* type,
1759 const std::string& name) {
1760 PRMInstance< GUM_SCALAR >* i =
nullptr;
1762 auto s =
static_cast< PRMSystem< GUM_SCALAR >* >(
1763 checkStack__(1, PRMObject::prm_type::SYSTEM));
1764 i =
new PRMInstance< GUM_SCALAR >(name, *type);
1767 }
catch (OperationNotAllowed&) {
1768 if (i) {
delete i; }
// Qualifies `str` with the current package: "<package>.<str>" when a package
// is open.
// NOTE(review): the return type (line 1774) and the remainder (lines
// 1780-1785) are missing from this listing. Presumably they return full_name,
// or `str` unchanged when no package is open — TODO confirm.
1773 template <
typename GUM_SCALAR >
1775 PRMFactory< GUM_SCALAR >::addPrefix__(
const std::string& str)
const {
1776 if (!packages__.empty()) {
1777 std::string full_name = packages__.back();
1778 full_name.append(
".");
1779 full_name.append(str);
// Returns the i-th object from the top of the factory stack (i == 1 is the
// top), after checking that it has kind `obj_type`.
// @throws FactoryInvalidState if the stack holds fewer than i objects, or if
//         the object's obj_type() differs from `obj_type`.
// NOTE(review): the bound test `stack__.size() - i > stack__.size()` relies on
// unsigned wrap-around and is true exactly when i > stack__.size(). i == 0 is
// not rejected and would index stack__[stack__.size()] (out of range); all
// visible callers pass 1 or 2.
// NOTE(review): the return type (line 1787) and the final lines (1799-1803,
// presumably `return obj;`) are missing from this listing.
1786 template <
typename GUM_SCALAR >
1788 PRMFactory< GUM_SCALAR >::checkStack__(Idx i,
1789 PRMObject::prm_type obj_type) {
1791 if (stack__.size() - i > stack__.size()) {
1792 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls");
1795 PRMObject* obj = stack__[stack__.size() - i];
1797 if (obj->obj_type() != obj_type) {
1798 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls");
// Returns the i-th object from the top of the factory stack (i == 1 is the
// top) as a PRMClassElementContainer, provided it is a class or an interface.
// @throws FactoryInvalidState if the stack holds fewer than i objects, or if
//         the object is neither a class nor an interface.
// NOTE(review): same unsigned wrap-around bound test as the other stack
// checks; i == 0 is not rejected and would index stack__[stack__.size()].
1804 template <
typename GUM_SCALAR >
1805 INLINE PRMClassElementContainer< GUM_SCALAR >*
1806 PRMFactory< GUM_SCALAR >::checkStackContainter__(Idx i) {
1808 if (stack__.size() - i > stack__.size()) {
1809 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls");
1812 PRMObject* obj = stack__[stack__.size() - i];
1814 if ((obj->obj_type() == PRMObject::prm_type::CLASS)
1815 || (obj->obj_type() == PRMObject::prm_type::PRM_INTERFACE)) {
1816 return static_cast< PRMClassElementContainer< GUM_SCALAR >* >(obj);
1818 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls");
// Returns the i-th object from the top of the factory stack as a
// PRMClassElement, provided it is a class element of kind `elt_type`.
// @throws FactoryInvalidState if the stack holds fewer than i objects, if the
//         object is not a PRMClassElement (dynamic_cast fails), or if its
//         elt_type() differs from `elt_type`.
// NOTE(review): the first parameter (line 1824, presumably `Idx i,`), the
// null-check after the dynamic_cast (lines 1834-1835) and the return (lines
// 1841-1845) are missing from this listing — TODO confirm.
1822 template <
typename GUM_SCALAR >
1823 INLINE PRMClassElement< GUM_SCALAR >* PRMFactory< GUM_SCALAR >::checkStack__(
1825 typename PRMClassElement< GUM_SCALAR >::ClassElementType elt_type) {
1827 if (stack__.size() - i > stack__.size()) {
1828 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls");
1831 PRMClassElement< GUM_SCALAR >* obj
1832 =
dynamic_cast< PRMClassElement< GUM_SCALAR >* >(
1833 stack__[stack__.size() - i]);
1836 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls");
1839 if (obj->elt_type() != elt_type) {
1840 GUM_ERROR(FactoryInvalidState,
"illegal sequence of calls");
// Walks the super-type chain of `t` (isSubType() / superType()) up to its
// root type.
// NOTE(review): the counter and return lines (1848, 1850, 1852, 1854-1858) are
// missing from this listing. Presumably the function returns the number of
// super-type hops — TODO confirm.
1846 template <
typename GUM_SCALAR >
1847 INLINE
int PRMFactory< GUM_SCALAR >::typeDepth__(
const PRMType* t) {
1849 const PRMType* current = t;
1851 while (current->isSubType()) {
1853 current = &(current->superType());
1859 template <
typename GUM_SCALAR >
1860 INLINE
void PRMFactory< GUM_SCALAR >::pushPackage(
const std::string& name) {
1861 packages__.push_back(name);
1862 namespaces__.push_back(
new List< std::string >());
// Closes the innermost package: pops packages__, then deletes and pops the
// last namespace list in namespaces__.
// NOTE(review): `plop` captures currentPackage() before popping, and `s` is a
// copy of the popped package that looks unused in the shown lines. The return
// statement (lines 1876-1882) is missing from this listing, so which value is
// returned (presumably `plop`, i.e. "" when no package was open) is not
// visible — TODO confirm.
1865 template <
typename GUM_SCALAR >
1866 INLINE std::string PRMFactory< GUM_SCALAR >::popPackage() {
1867 std::string plop = currentPackage();
1869 if (!packages__.empty()) {
1870 std::string s = packages__.back();
1871 packages__.pop_back();
1873 if (namespaces__.size() > 0) {
1874 delete namespaces__.back();
1875 namespaces__.pop_back();
1883 template <
typename GUM_SCALAR >
1884 INLINE
void PRMFactory< GUM_SCALAR >::addImport(
const std::string& name) {
1885 if (name.size() == 0) {
1886 GUM_ERROR(OperationNotAllowed,
"illegal import name");
1888 if (namespaces__.empty()) {
1889 namespaces__.push_back(
new List< std::string >());
1891 namespaces__.back()->push_back(name);
// Two-argument form: splits `l_i` at its last '.' into a left part (an
// instance or an array) and a reference-slot name. It then forwards to
// setReferenceSlot(left part, slot name, r_i).
// @throws NotFound if `l_i` contains no '.'.
// NOTE(review): the return type (line 1895) and the `} else {` (line 1903)
// are missing from this listing.
1894 template <
typename GUM_SCALAR >
1896 PRMFactory< GUM_SCALAR >::setReferenceSlot(
const std::string& l_i,
1897 const std::string& r_i) {
1898 size_t pos = l_i.find_last_of(
'.');
1900 if (pos != std::string::npos) {
1901 std::string l_ref = l_i.substr(pos + 1, std::string::npos);
1902 setReferenceSlot(l_i.substr(0, pos), l_ref, r_i);
1904 GUM_ERROR(NotFound,
"left value does not name an instance or an array");
1908 template <
typename GUM_SCALAR >
1909 INLINE PRMClass< GUM_SCALAR >&
1910 PRMFactory< GUM_SCALAR >::retrieveClass(
const std::string& name) {
1911 return *retrieveClass__(name);
// Public, reference-returning front end of retrieveType__(name).
// NOTE(review): the return type (line 1915, presumably `INLINE PRMType&`) is
// missing from this listing.
1914 template <
typename GUM_SCALAR >
1916 PRMFactory< GUM_SCALAR >::retrieveType(
const std::string& name) {
1917 return *retrieveType__(name);
1920 template <
typename GUM_SCALAR >
1921 INLINE PRMType& PRMFactory< GUM_SCALAR >::retrieveCommonType(
1922 const std::vector< PRMClassElement< GUM_SCALAR >* >& elts) {
1923 return *(retrieveCommonType__(elts));
// Tells whether `type` resolves to a class or, failing that, to an interface.
// NotFound and DuplicateElement raised by the lookups are both swallowed, so
// an ambiguous name is handled like an unknown one here.
// NOTE(review): the `try {` lines and the return statements (lines 1930,
// 1932-1933, 1936-1937, 1939-1940, 1943-1946) are missing from this listing.
// Presumably it returns true after a successful lookup and false otherwise —
// TODO confirm.
1927 template <
typename GUM_SCALAR >
1928 INLINE
bool PRMFactory< GUM_SCALAR >::isClassOrInterface(
1929 const std::string& type)
const {
1931 retrieveClass__(type);
1934 }
catch (NotFound&) {
1935 }
catch (DuplicateElement&) {}
1938 retrieveInterface__(type);
1941 }
catch (NotFound&) {
1942 }
catch (DuplicateElement&) {}
1947 template <
typename GUM_SCALAR >
1948 INLINE
bool PRMFactory< GUM_SCALAR >::isArrayInCurrentSystem(
1949 const std::string& name)
const {
1950 const PRMSystem< GUM_SCALAR >* system
1951 =
static_cast<
const PRMSystem< GUM_SCALAR >* >(getCurrent());
1952 return (system && system->isArray(name));
// Fills the CPF of the attribute being built (top of the stack, directly
// inside a class) from `array`, given column by column. One-dimensional CPFs
// are delegated to setRawCPFByLines(). Otherwise an Instantiation over
// a->formulas() is used together with a second one (`jnst`) built from the
// variables visited in reverse order, and array[idx] is assigned at `inst`.
// @throws FactoryInvalidState (raised by checkStack__) on an illegal call
//         sequence.
// @throws OperationNotAllowed if array.size() differs from the CPF's domain
//         size.
// NOTE(review): the top element is only checked to be an attribute, then
// static_cast to PRMFormAttribute. startAttribute builds a PRMScalarAttribute
// inside a class when `scalar_attr` is set, which would make this cast
// invalid. Confirm that callers never reach here for scalar attributes.
// NOTE(review): many lines are missing from this listing (1959, 1962,
// 1965-1966, 1969-1970, 1972, 1975-1979, 1982, 1984-1992). These include the
// declaration of `jnst`, the loop body and the increments — TODO confirm.
1955 template <
typename GUM_SCALAR >
1956 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByColumns(
1957 const std::vector< std::string >& array) {
1958 checkStack__(2, PRMObject::prm_type::CLASS);
1960 auto a =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(
1961 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
1963 if (a->formulas().domainSize() != array.size()) {
1964 GUM_ERROR(OperationNotAllowed,
"illegal CPF size");
1967 if (a->formulas().nbrDim() == 1) {
1968 setRawCPFByLines(array);
1971 Instantiation inst(a->formulas());
1973 for (
auto idx = inst.variablesSequence().rbegin();
1974 idx != inst.variablesSequence().rend();
1980 auto idx = (std::size_t)0;
1981 while ((!jnst.end()) && idx < array.size()) {
1983 a->formulas().set(inst, array[idx]);
// Fills the CPF of the attribute being built (top of the stack, directly
// inside a class) from `array`, passing it unchanged to
// a->formulas().populate().
// @throws FactoryInvalidState (raised by checkStack__) on an illegal call
//         sequence.
// @throws OperationNotAllowed if array.size() differs from the CPF's domain
//         size.
// NOTE(review): the top element is only checked to be an attribute, then
// static_cast to PRMFormAttribute. A PRMScalarAttribute created inside a class
// (startAttribute with `scalar_attr`) would make this cast invalid — confirm
// callers.
1993 template <
typename GUM_SCALAR >
1994 INLINE
void PRMFactory< GUM_SCALAR >::setRawCPFByLines(
1995 const std::vector< std::string >& array) {
1996 checkStack__(2, PRMObject::prm_type::CLASS);
1998 auto a =
static_cast< PRMFormAttribute< GUM_SCALAR >* >(
1999 checkStack__(1, PRMClassElement< GUM_SCALAR >::prm_attribute));
2001 if (a->formulas().domainSize() != array.size()) {
2002 GUM_ERROR(OperationNotAllowed,
"illegal CPF size");
2005 a->formulas().populate(array);