aGrUM  0.16.0
O3ClassFactory_tpl.h
Go to the documentation of this file.
1 
32 
33 namespace gum {
34  namespace prm {
35  namespace o3prm {
36 
37  template < typename GUM_SCALAR >
39  PRM< GUM_SCALAR >& prm,
40  O3PRM& o3_prm,
42  ErrorsContainer& errors) :
43  __prm(&prm),
44  __o3_prm(&o3_prm), __solver(&solver), __errors(&errors) {
45  GUM_CONSTRUCTOR(O3ClassFactory);
46  }
47 
48  template < typename GUM_SCALAR >
50  const O3ClassFactory< GUM_SCALAR >& src) :
51  __prm(src.__prm),
54  __nodeMap(src.__nodeMap), __dag(src.__dag),
56  GUM_CONS_CPY(O3ClassFactory);
57  }
58 
59  template < typename GUM_SCALAR >
62  __prm(std::move(src.__prm)),
63  __o3_prm(std::move(src.__o3_prm)), __solver(std::move(src.__solver)),
64  __errors(std::move(src.__errors)), __nameMap(std::move(src.__nameMap)),
65  __classMap(std::move(src.__classMap)),
66  __nodeMap(std::move(src.__nodeMap)), __dag(std::move(src.__dag)),
67  __o3Classes(std::move(src.__o3Classes)) {
68  GUM_CONS_MOV(O3ClassFactory);
69  }
70 
71  template < typename GUM_SCALAR >
73  GUM_DESTRUCTOR(O3ClassFactory);
74  }
75 
76  template < typename GUM_SCALAR >
79  if (this == &src) { return *this; }
80  __prm = src.__prm;
81  __o3_prm = src.__o3_prm;
82  __solver = src.__solver;
83  __errors = src.__errors;
84  __nameMap = src.__nameMap;
85  __classMap = src.__classMap;
86  __nodeMap = src.__nodeMap;
87  __dag = src.__dag;
89  return *this;
90  }
91 
92  template < typename GUM_SCALAR >
95  if (this == &src) { return *this; }
96  __prm = std::move(src.__prm);
97  __o3_prm = std::move(src.__o3_prm);
98  __solver = std::move(src.__solver);
99  __errors = std::move(src.__errors);
100  __nameMap = std::move(src.__nameMap);
101  __classMap = std::move(src.__classMap);
102  __nodeMap = std::move(src.__nodeMap);
103  __dag = std::move(src.__dag);
104  __o3Classes = std::move(src.__o3Classes);
105  return *this;
106  }
107 
108  template < typename GUM_SCALAR >
111 
112  // Class with a super class must be declared after
113  if (__checkO3Classes()) {
115 
116  for (auto c : __o3Classes) {
117  // Soving interfaces
118  auto implements = Set< std::string >();
119  for (auto& i : c->interfaces()) {
120  if (__solver->resolveInterface(i)) { implements.insert(i.label()); }
121  }
122 
123  // Adding the class
124  if (__solver->resolveClass(c->superLabel())) {
125  factory.startClass(
126  c->name().label(), c->superLabel().label(), &implements, true);
127  factory.endClass(false);
128  }
129  }
130  }
131  }
132 
133  template < typename GUM_SCALAR >
135  auto topo_order = __dag.topologicalOrder();
136 
137  for (auto id = topo_order.rbegin(); id != topo_order.rend(); --id) {
138  __o3Classes.push_back(__nodeMap[*id]);
139  }
140  }
141 
142  template < typename GUM_SCALAR >
145  }
146 
147  template < typename GUM_SCALAR >
149  for (auto& c : __o3_prm->classes()) {
150  auto id = __dag.addNode();
151 
152  try {
153  __nameMap.insert(c->name().label(), id);
154  __classMap.insert(c->name().label(), c.get());
155  __nodeMap.insert(id, c.get());
156 
157  } catch (DuplicateElement&) {
158  O3PRM_CLASS_DUPLICATE(c->name(), *__errors);
159  return false;
160  }
161  }
162 
163  return true;
164  }
165 
166  template < typename GUM_SCALAR >
168  for (auto& c : __o3_prm->classes()) {
169  if (c->superLabel().label() != "") {
170  if (!__solver->resolveClass(c->superLabel())) { return false; }
171 
172  auto head = __nameMap[c->superLabel().label()];
173  auto tail = __nameMap[c->name().label()];
174 
175  try {
176  __dag.addArc(tail, head);
177  } catch (InvalidDirectedCycle&) {
178  // Cyclic inheritance
179  O3PRM_CLASS_CYLIC_INHERITANCE(c->name(), c->superLabel(), *__errors);
180  return false;
181  }
182  }
183  }
184 
185  return true;
186  }
187 
188  template < typename GUM_SCALAR >
190  for (auto& c : __o3_prm->classes()) {
191  if (__checkImplementation(*c)) {
192  __prm->getClass(c->name().label()).initializeInheritance();
193  }
194  }
195  }
196 
200 
201  template < typename GUM_SCALAR >
203  // Saving attributes names for fast lookup
204  auto attr_map = AttrMap();
205  for (auto& a : c.attributes()) {
206  attr_map.insert(a->name().label(), a.get());
207  }
208 
209  // Saving aggregates names for fast lookup
210  auto agg_map = AggMap();
211  for (auto& agg : c.aggregates()) {
212  agg_map.insert(agg.name().label(), &agg);
213  }
214 
215  auto ref_map = RefMap();
216  for (auto& ref : c.referenceSlots()) {
217  ref_map.insert(ref.name().label(), &ref);
218  }
219 
220  // Cheking interface implementation
221  for (auto& i : c.interfaces()) {
222  if (__solver->resolveInterface(i)) {
223  if (!__checkImplementation(c, i, attr_map, agg_map, ref_map)) {
224  return false;
225  }
226  }
227  }
228 
229  return true;
230  }
231 
232  template < typename GUM_SCALAR >
233  INLINE bool
235  O3Label& i,
236  AttrMap& attr_map,
237  AggMap& agg_map,
238  RefMap& ref_map) {
239  const auto& real_i = __prm->getInterface(i.label());
240 
241  auto counter = (Size)0;
242  for (const auto& a : real_i.attributes()) {
243  if (attr_map.exists(a->name())) {
244  ++counter;
245 
246  if (!__checkImplementation(attr_map[a->name()]->type(), a->type())) {
247  O3PRM_CLASS_ATTR_IMPLEMENTATION(
248  c.name(), i, attr_map[a->name()]->name(), *__errors);
249  return false;
250  }
251  }
252 
253  if (agg_map.exists(a->name())) {
254  ++counter;
255 
256  if (!__checkImplementation(agg_map[a->name()]->variableType(),
257  a->type())) {
258  O3PRM_CLASS_AGG_IMPLEMENTATION(
259  c.name(), i, agg_map[a->name()]->name(), *__errors);
260  return false;
261  }
262  }
263  }
264 
265  if (counter != real_i.attributes().size()) {
266  O3PRM_CLASS_MISSING_ATTRIBUTES(c.name(), i, *__errors);
267  return false;
268  }
269 
270  counter = 0;
271  for (const auto& r : real_i.referenceSlots()) {
272  if (ref_map.exists(r->name())) {
273  ++counter;
274 
275  if (!__checkImplementation(ref_map[r->name()]->type(),
276  r->slotType())) {
277  O3PRM_CLASS_REF_IMPLEMENTATION(
278  c.name(), i, ref_map[r->name()]->name(), *__errors);
279  return false;
280  }
281  }
282  }
283  return true;
284  }
285 
286  template < typename GUM_SCALAR >
287  INLINE bool
289  const PRMType& type) {
290  if (!__solver->resolveType(o3_type)) { return false; }
291 
292  return __prm->type(o3_type.label()).isSubTypeOf(type);
293  }
294 
295  template < typename GUM_SCALAR >
297  O3Label& o3_type, const PRMClassElementContainer< GUM_SCALAR >& type) {
298  if (!__solver->resolveSlotType(o3_type)) { return false; }
299 
300  if (__prm->isInterface(o3_type.label())) {
301  return __prm->getInterface(o3_type.label()).isSubTypeOf(type);
302  } else {
303  return __prm->getClass(o3_type.label()).isSubTypeOf(type);
304  }
305  }
306 
307  template < typename GUM_SCALAR >
310  // Class with a super class must be declared after
311  for (auto c : __o3Classes) {
312  __prm->getClass(c->name().label()).inheritParameters();
313 
314  factory.continueClass(c->name().label());
315 
316  __addParameters(factory, *c);
317 
318  factory.endClass(false);
319  }
320  }
321 
322  template < typename GUM_SCALAR >
324  PRMFactory< GUM_SCALAR >& factory, O3Class& c) {
325  for (auto& p : c.parameters()) {
326  switch (p.type()) {
328  factory.addParameter("int", p.name().label(), p.value().value());
329  break;
330  }
331 
333  factory.addParameter("real", p.name().label(), p.value().value());
334  break;
335  }
336 
337  default: {
338  GUM_ERROR(FatalError, "unknown O3Parameter type");
339  }
340  }
341  }
342  }
343 
344  template < typename GUM_SCALAR >
346  // Class with a super class must be declared after
347  for (auto c : __o3Classes) {
348  __prm->getClass(c->name().label()).inheritReferenceSlots();
350  }
351  }
352 
353  template < typename GUM_SCALAR >
356 
357  factory.continueClass(c.name().label());
358 
359  // References
360  for (auto& ref : c.referenceSlots()) {
361  if (__checkReferenceSlot(c, ref)) {
362  factory.addReferenceSlot(
363  ref.type().label(), ref.name().label(), ref.isArray());
364  }
365  }
366 
367  factory.endClass(false);
368  }
369 
370  template < typename GUM_SCALAR >
371  INLINE bool
373  O3ReferenceSlot& ref) {
374  if (!__solver->resolveSlotType(ref.type())) { return false; }
375 
376  const auto& real_c = __prm->getClass(c.name().label());
377 
378  // Check for dupplicates
379  if (real_c.exists(ref.name().label())) {
380  const auto& elt = real_c.get(ref.name().label());
381 
383  auto slot_type = (PRMClassElementContainer< GUM_SCALAR >*)nullptr;
384 
385  if (__prm->isInterface(ref.type().label())) {
386  slot_type = &(__prm->getInterface(ref.type().label()));
387 
388  } else {
389  slot_type = &(__prm->getClass(ref.type().label()));
390  }
391 
392  auto real_ref =
393  static_cast< const PRMReferenceSlot< GUM_SCALAR >* >(&elt);
394 
395  if (slot_type->name() == real_ref->slotType().name()) {
396  O3PRM_CLASS_DUPLICATE_REFERENCE(ref.name(), *__errors);
397  return false;
398 
399  } else if (!slot_type->isSubTypeOf(real_ref->slotType())) {
400  O3PRM_CLASS_ILLEGAL_OVERLOAD(ref.name(), c.name(), *__errors);
401  return false;
402  }
403 
404  } else {
405  O3PRM_CLASS_DUPLICATE_REFERENCE(ref.name(), *__errors);
406  return false;
407  }
408  }
409 
410  // If class we need to check for illegal references
411  if (__prm->isClass(ref.type().label())) {
412  const auto& ref_type = __prm->getClass(ref.type().label());
413 
414  // No recursive reference
415  if ((&ref_type) == (&real_c)) {
416  O3PRM_CLASS_SELF_REFERENCE(c.name(), ref.name(), *__errors);
417  return false;
418  }
419 
420  // No reference to subclasses
421  if (ref_type.isSubTypeOf(real_c)) {
422  O3PRM_CLASS_ILLEGAL_SUB_REFERENCE(c.name(), ref.type(), *__errors);
423  return false;
424  }
425  }
426 
427  return true;
428  }
429 
430  template < typename GUM_SCALAR >
432  // Class with a super class must be declared after
433  for (auto c : __o3Classes) {
434  __prm->getClass(c->name().label()).inheritAttributes();
435  __declareAttribute(*c);
436  }
437  }
438 
439  template < typename GUM_SCALAR >
441  // Class with a super class must be declared after
442  for (auto c : __o3Classes) {
443  __prm->getClass(c->name().label()).inheritAggregates();
445  }
446  }
447 
448  template < typename GUM_SCALAR >
451  factory.continueClass(c.name().label());
452 
453  for (auto& attr : c.attributes()) {
454  if (__checkAttributeForDeclaration(c, *attr)) {
455  factory.startAttribute(attr->type().label(), attr->name().label());
456  factory.endAttribute();
457  }
458  }
459 
460  factory.endClass(false);
461  }
462 
463  template < typename GUM_SCALAR >
465  O3Class& c, O3Attribute& attr) {
466  // Check type
467  if (!__solver->resolveType(attr.type())) { return false; }
468 
469  // Checking type legality if overload
470  if (c.superLabel().label() != "") {
471  const auto& super = __prm->getClass(c.superLabel().label());
472 
473  if (!super.exists(attr.name().label())) { return true; }
474 
475  const auto& super_type = super.get(attr.name().label()).type();
476  const auto& type = __prm->type(attr.type().label());
477 
478  if (!type.isSubTypeOf(super_type)) {
479  O3PRM_CLASS_ILLEGAL_OVERLOAD(attr.name(), c.superLabel(), *__errors);
480  return false;
481  }
482  }
483  return true;
484  }
485 
486  template < typename GUM_SCALAR >
489 
490  // Class with a super class must be declared in order
491  for (auto c : __o3Classes) {
492  __prm->getClass(c->name().label()).inheritSlotChains();
493  factory.continueClass(c->name().label());
494 
495  __completeAttribute(factory, *c);
496 
497  if (c->superLabel().label() != "") {
498  auto& super = __prm->getClass(c->superLabel().label());
499  auto to_complete = Set< std::string >();
500 
501  for (auto a : super.attributes()) {
502  to_complete.insert(a->safeName());
503  }
504 
505  for (auto a : super.aggregates()) {
506  to_complete.insert(a->safeName());
507  }
508 
509  for (auto& a : c->attributes()) {
510  to_complete.erase(__prm->getClass(c->name().label())
511  .get(a->name().label())
512  .safeName());
513  }
514 
515  for (auto& a : c->aggregates()) {
516  to_complete.erase(__prm->getClass(c->name().label())
517  .get(a.name().label())
518  .safeName());
519  }
520 
521  for (auto a : to_complete) {
522  __prm->getClass(c->name().label()).completeInheritance(a);
523  }
524  }
525 
526  factory.endClass(true);
527  }
528  }
529 
530  template < typename GUM_SCALAR >
533 
534  // Class with a super class must be declared in order
535  for (auto c : __o3Classes) {
536  factory.continueClass(c->name().label());
537 
538  __completeAggregates(factory, *c);
539 
540  factory.endClass(false);
541  }
542  }
543 
544  template < typename GUM_SCALAR >
546  PRMFactory< GUM_SCALAR >& factory, O3Class& c) {
547  // Attributes
548  for (auto& agg : c.aggregates()) {
549  if (__checkAggregateForCompletion(c, agg)) {
550  factory.continueAggregator(agg.name().label());
551 
552  for (const auto& parent : agg.parents()) {
553  factory.addParent(parent.label());
554  }
555 
556  factory.endAggregator();
557  }
558  }
559  }
560 
561  template < typename GUM_SCALAR >
563  O3Class& c, O3Aggregate& agg) {
564  // Checking parents
565  auto t = __checkAggParents(c, agg);
566  if (t == nullptr) { return false; }
567 
568  // Checking parameters numbers
569  if (!__checkAggParameters(c, agg, t)) { return false; }
570 
571  return true;
572  }
573 
574  template < typename GUM_SCALAR >
576  PRMFactory< GUM_SCALAR >& factory, O3Class& c) {
577  // Attributes
578  for (auto& attr : c.attributes()) {
579  if (__checkAttributeForCompletion(c, *attr)) {
580  factory.continueAttribute(attr->name().label());
581 
582  for (const auto& parent : attr->parents()) {
583  factory.addParent(parent.label());
584  }
585 
586  auto raw = dynamic_cast< const O3RawCPT* >(attr.get());
587 
588  if (raw) {
589  auto values = std::vector< std::string >();
590  for (const auto& val : raw->values()) {
591  values.push_back(val.formula().formula());
592  }
593  factory.setRawCPFByColumns(values);
594  }
595 
596  auto rule_cpt = dynamic_cast< const O3RuleCPT* >(attr.get());
597  if (rule_cpt) {
598  for (const auto& rule : rule_cpt->rules()) {
599  auto labels = std::vector< std::string >();
600  auto values = std::vector< std::string >();
601 
602  for (const auto& lbl : rule.first) {
603  labels.push_back(lbl.label());
604  }
605 
606  for (const auto& form : rule.second) {
607  values.push_back(form.formula().formula());
608  }
609 
610  factory.setCPFByRule(labels, values);
611  }
612  }
613 
614  factory.endAttribute();
615  }
616  }
617  }
618 
619  template < typename GUM_SCALAR >
621  const O3Class& o3_c, O3Attribute& attr) {
622  // Check for parents existence
623  const auto& c = __prm->getClass(o3_c.name().label());
624  for (auto& prnt : attr.parents()) {
625  if (!__checkParent(c, prnt)) { return false; }
626  }
627 
628  // Check that CPT sums to 1
629  auto raw = dynamic_cast< O3RawCPT* >(&attr);
630  if (raw) { return __checkRawCPT(c, *raw); }
631 
632  auto rule = dynamic_cast< O3RuleCPT* >(&attr);
633  if (rule) { return __checkRuleCPT(c, *rule); }
634 
635  return true;
636  }
637 
638  template < typename GUM_SCALAR >
640  const PRMClass< GUM_SCALAR >& c, const O3Label& prnt) {
641  if (prnt.label().find('.') == std::string::npos) {
642  return __checkLocalParent(c, prnt);
643 
644  } else {
645  return __checkRemoteParent(c, prnt);
646  }
647  }
648 
649  template < typename GUM_SCALAR >
651  const PRMClass< GUM_SCALAR >& c, const O3Label& prnt) {
652  if (!c.exists(prnt.label())) {
653  O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *__errors);
654  return false;
655  }
656 
657  const auto& elt = c.get(prnt.label());
661  O3PRM_CLASS_ILLEGAL_PARENT(prnt, *__errors);
662  return false;
663  }
664 
665  return true;
666  }
667 
668  template < typename GUM_SCALAR >
670  const PRMClassElementContainer< GUM_SCALAR >& c, const O3Label& prnt) {
671  if (__resolveSlotChain(c, prnt) == nullptr) { return false; }
672  return true;
673  }
674 
675  template < typename GUM_SCALAR >
677  const O3RuleCPT& attr, const O3RuleCPT::O3Rule& rule) {
678  // Check that the number of labels is correct
679  if (rule.first.size() != attr.parents().size()) {
680  O3PRM_CLASS_ILLEGAL_RULE_SIZE(
681  rule, rule.first.size(), attr.parents().size(), *__errors);
682  return false;
683  }
684  return true;
685  }
686 
687  template < typename GUM_SCALAR >
689  const PRMClass< GUM_SCALAR >& c,
690  const O3RuleCPT& attr,
691  const O3RuleCPT::O3Rule& rule) {
692  bool errors = false;
693  for (std::size_t i = 0; i < attr.parents().size(); ++i) {
694  auto label = rule.first[i];
695  auto prnt = attr.parents()[i];
696  try {
697  auto real_labels = __resolveSlotChain(c, prnt)->type()->labels();
698  // c.get(prnt.label()).type()->labels();
699  if (label.label() != "*"
700  && std::find(real_labels.begin(), real_labels.end(), label.label())
701  == real_labels.end()) {
702  O3PRM_CLASS_ILLEGAL_RULE_LABEL(rule, label, prnt, *__errors);
703  errors = true;
704  }
705  } catch (Exception&) {
706  // parent does not exists and is already reported
707  }
708  }
709  return errors == false;
710  }
711 
712  template < typename GUM_SCALAR >
714  const HashTable< std::string, const PRMParameter< GUM_SCALAR >* >& scope,
715  O3RuleCPT::O3Rule& rule) {
716  // Add parameters to formulas
717  for (auto& f : rule.second) {
718  f.formula().variables().clear();
719  for (const auto& values : scope) {
720  f.formula().variables().insert(values.first, values.second->value());
721  }
722  }
723  }
724 
725 
726  template < typename GUM_SCALAR >
728  const PRMClass< GUM_SCALAR >& c,
729  const O3RuleCPT& attr,
730  const O3RuleCPT::O3Rule& rule) {
731  bool errors = false;
732  // Check that formulas are valid and sums to 1
733  GUM_SCALAR sum = 0.0;
734  for (const auto& f : rule.second) {
735  try {
736  auto value = GUM_SCALAR(f.formula().result());
737  sum += value;
738  if (value < 0.0 || 1.0 < value) {
739  O3PRM_CLASS_ILLEGAL_CPT_VALUE(c.name(), attr.name(), f, *__errors);
740  errors = true;
741  }
742  } catch (OperationNotAllowed&) {
743  O3PRM_CLASS_ILLEGAL_CPT_VALUE(c.name(), attr.name(), f, *__errors);
744  errors = true;
745  }
746  }
747 
748  // Check that CPT sums to 1
749  if (std::abs(sum - 1.0) > 1e-3) {
750  O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(
751  c.name(), attr.name(), float(sum), *__errors);
752  errors = true;
753  } else if (std::abs(sum - 1.0f) > 1e-6) {
754  O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(
755  c.name(), attr.name(), float(sum), *__errors);
756  }
757  return errors == false;
758  }
759 
760  template < typename GUM_SCALAR >
762  const PRMClass< GUM_SCALAR >& c, O3RuleCPT& attr) {
763  const auto& scope = c.scope();
764  bool errors = false;
765  for (auto& rule : attr.rules()) {
766  try {
767  if (!__checkLabelsNumber(attr, rule)) { errors = true; }
768  if (!__checkLabelsValues(c, attr, rule)) { errors = true; }
769  __addParamsToForms(scope, rule);
770  if (!__checkRuleCPTSumsTo1(c, attr, rule)) { errors = true; }
771  } catch (Exception& e) {
772  GUM_SHOWERROR(e);
773  errors = true;
774  }
775  }
776 
777  return errors == false;
778  }
779 
780  template < typename GUM_SCALAR >
782  const PRMClass< GUM_SCALAR >& c, O3RawCPT& attr) {
783  const auto& type = __prm->type(attr.type().label());
784 
785  auto domainSize = type->domainSize();
786  for (auto& prnt : attr.parents()) {
787  try {
788  domainSize *= c.get(prnt.label()).type()->domainSize();
789  } catch (NotFound&) {
790  // If we are here, all parents have been check so __resolveSlotChain
791  // will not raise an error and not return a nullptr
792  domainSize *= __resolveSlotChain(c, prnt)->type()->domainSize();
793  }
794  }
795 
796  // Check for CPT size
797  if (domainSize != attr.values().size()) {
798  O3PRM_CLASS_ILLEGAL_CPT_SIZE(c.name(),
799  attr.name(),
800  Size(attr.values().size()),
801  domainSize,
802  *__errors);
803  return false;
804  }
805 
806  // Add parameters to formulas
807  const auto& scope = c.scope();
808  for (auto& f : attr.values()) {
809  f.formula().variables().clear();
810 
811  for (const auto& values : scope) {
812  f.formula().variables().insert(values.first, values.second->value());
813  }
814  }
815 
816  // Check that CPT sums to 1
817  Size parent_size = domainSize / type->domainSize();
818  auto values = std::vector< GUM_SCALAR >(parent_size, 0.0f);
819 
820  for (std::size_t i = 0; i < attr.values().size(); ++i) {
821  try {
822  auto idx = i % parent_size;
823  auto val = (GUM_SCALAR)attr.values()[i].formula().result();
824  values[idx] += val;
825 
826  if (val < 0.0 || 1.0 < val) {
827  O3PRM_CLASS_ILLEGAL_CPT_VALUE(
828  c.name(), attr.name(), attr.values()[i], *__errors);
829  return false;
830  }
831  } catch (Exception&) {
832  O3PRM_CLASS_ILLEGAL_CPT_VALUE(
833  c.name(), attr.name(), attr.values()[i], *__errors);
834  return false;
835  }
836  }
837 
838  for (auto f : values) {
839  if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-3) {
840  O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(
841  c.name(), attr.name(), float(f), *__errors);
842  return false;
843  } else if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-6) {
844  O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(
845  c.name(), attr.name(), float(f), *__errors);
846  }
847  }
848  return true;
849  }
850 
851  template < typename GUM_SCALAR >
852  INLINE const PRMClassElement< GUM_SCALAR >*
855  const O3Label& chain) {
856  auto s = chain.label();
857  auto current = &c;
858  std::vector< std::string > v;
859 
860  decomposePath(chain.label(), v);
861 
862  for (size_t i = 0; i < v.size(); ++i) {
863  auto link = v[i];
864 
865  if (!__checkSlotChainLink(*current, chain, link)) { return nullptr; }
866 
867  auto elt = &(current->get(link));
868 
869  if (i == v.size() - 1) {
870  // last link, should be an attribute or aggregate
871  return elt;
872 
873  } else {
874  // should be a reference slot
875 
876  auto ref = dynamic_cast< const PRMReferenceSlot< GUM_SCALAR >* >(elt);
877  if (ref) {
878  current = &(ref->slotType());
879  } else {
880  return nullptr; // failsafe to prevent infinite loop
881  }
882  }
883  }
884 
885  // Encountered only reference slots
886 
887  return nullptr;
888  }
889 
890  template < typename GUM_SCALAR >
893  const O3Label& chain,
894  const std::string& s) {
895  if (!c.exists(s)) {
896  O3PRM_CLASS_LINK_NOT_FOUND(chain, s, *__errors);
897  return false;
898  }
899  return true;
900  }
901 
902  template < typename GUM_SCALAR >
905  factory.continueClass(c.name().label());
906 
907  for (auto& agg : c.aggregates()) {
908  if (__checkAggregateForDeclaration(c, agg)) {
909  auto params = std::vector< std::string >();
910  for (auto& p : agg.parameters()) {
911  params.push_back(p.label());
912  }
913 
914  factory.startAggregator(agg.name().label(),
915  agg.aggregateType().label(),
916  agg.variableType().label(),
917  params);
918  factory.endAggregator();
919  }
920  }
921 
922  factory.endClass(false);
923  }
924 
925  template < typename GUM_SCALAR >
927  O3Class& o3class, O3Aggregate& agg) {
928  if (!__solver->resolveType(agg.variableType())) { return false; }
929 
930  // Checking type legality if overload
931  if (!__checkAggTypeLegality(o3class, agg)) { return false; }
932 
933  return true;
934  }
935 
936  template < typename GUM_SCALAR >
937  INLINE const PRMType*
939  O3Aggregate& agg) {
940  const auto& c = __prm->getClass(o3class.name().label());
941  auto t = (const PRMType*)nullptr;
942 
943  for (const auto& prnt : agg.parents()) {
944  auto elt = __resolveSlotChain(c, prnt);
945 
946  if (elt == nullptr) {
947  O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *__errors);
948  return nullptr;
949 
950  } else {
951  if (t == nullptr) {
952  try {
953  t = &(elt->type());
954 
955  } catch (OperationNotAllowed&) {
956  O3PRM_CLASS_WRONG_PARENT(prnt, *__errors);
957  return nullptr;
958  }
959 
960  } else if ((*t) != elt->type()) {
961  // Wront type in chain
962  O3PRM_CLASS_WRONG_PARENT_TYPE(
963  prnt, t->name(), elt->type().name(), *__errors);
964  return nullptr;
965  }
966  }
967  }
968  return t;
969  }
970 
971  template < typename GUM_SCALAR >
972  INLINE bool
974  O3Aggregate& agg) {
975  if (__prm->isClass(o3class.superLabel().label())) {
976  const auto& super = __prm->getClass(o3class.superLabel().label());
977  const auto& agg_type = __prm->type(agg.variableType().label());
978 
979  if (super.exists(agg.name().label())
980  && !agg_type.isSubTypeOf(super.get(agg.name().label()).type())) {
981  O3PRM_CLASS_ILLEGAL_OVERLOAD(
982  agg.name(), o3class.superLabel(), *__errors);
983  return false;
984  }
985  }
986 
987  return true;
988  }
989 
990  template < typename GUM_SCALAR >
992  O3Class& o3class, O3Aggregate& agg, const PRMType* t) {
993  bool ok = false;
994 
996  agg.aggregateType().label())) {
1003  ok = __checkParametersNumber(agg, 0);
1004  break;
1005  }
1006 
1010  ok = __checkParametersNumber(agg, 1);
1011  break;
1012  }
1013 
1014  default: {
1015  GUM_ERROR(FatalError, "unknown aggregate type");
1016  }
1017  }
1018 
1019  if (!ok) { return false; }
1020 
1021  // Checking parameters type
1023  agg.aggregateType().label())) {
1027  ok = __checkParameterValue(agg, *t);
1028  break;
1029  }
1030 
1031  default: { /* Nothing to do */
1032  }
1033  }
1034 
1035  return ok;
1036  }
1037 
1038  template < typename GUM_SCALAR >
1039  INLINE bool
1041  Size n) {
1042  if (agg.parameters().size() != n) {
1043  O3PRM_CLASS_AGG_PARAMETERS(
1044  agg.name(), Size(n), Size(agg.parameters().size()), *__errors);
1045  return false;
1046  }
1047 
1048  return true;
1049  }
1050 
1051  template < typename GUM_SCALAR >
1053  O3Aggregate& agg, const gum::prm::PRMType& t) {
1054  const auto& param = agg.parameters().front();
1055  bool found = false;
1056  for (Size idx = 0; idx < t.variable().domainSize(); ++idx) {
1057  if (t.variable().label(idx) == param.label()) {
1058  found = true;
1059  break;
1060  }
1061  }
1062 
1063  if (!found) {
1064  O3PRM_CLASS_AGG_PARAMETER_NOT_FOUND(agg.name(), param, *__errors);
1065  return false;
1066  }
1067 
1068  return true;
1069  }
1070 
1071  } // namespace o3prm
1072  } // namespace prm
1073 } // namespace gum
O3LabelList & parents()
Definition: O3prm.cpp:1068
O3ClassList & classes()
Definition: O3prm.cpp:503
std::pair< O3LabelList, O3FormulaList > O3Rule
Definition: O3prm.h:546
bool __checkAggregateForDeclaration(O3Class &o3class, O3Aggregate &agg)
virtual O3LabelList & parents()
Definition: O3prm.cpp:668
PRMParameter is a member of a Class in a PRM.
Definition: PRMParameter.h:52
bool __checkLabelsValues(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
virtual O3RuleList & rules()
Definition: O3prm.cpp:766
DiscreteVariable & variable()
Return a reference to the DiscreteVariable contained in this.
Definition: PRMType_inl.h:45
void endAggregator()
Finishes an aggregate declaration.
bool __checkAttributeForCompletion(const O3Class &o3_c, O3Attribute &attr)
void __addParamsToForms(const HashTable< std::string, const PRMParameter< GUM_SCALAR > * > &scope, O3RuleCPT::O3Rule &rule)
const std::string & name() const
Returns the name of this object.
Definition: PRMObject_inl.h:35
bool __checkAttributeForDeclaration(O3Class &o3_c, O3Attribute &attr)
HashTable< std::string, O3Attribute *> AttrMap
O3ReferenceSlotList & referenceSlots()
Definition: O3prm.cpp:887
#define GUM_SHOWERROR(e)
Definition: exceptions.h:61
bool __checkRawCPT(const PRMClass< GUM_SCALAR > &c, O3RawCPT &attr)
The O3Aggregate is part of the AST of the O3PRM language.
Definition: O3prm.h:577
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
virtual void continueClass(const std::string &c) override
Continue the declaration of a class.
O3AggregateList & aggregates()
Definition: O3prm.cpp:895
STL namespace.
bool __checkReferenceSlot(O3Class &c, O3ReferenceSlot &ref)
HashTable< std::string, O3Aggregate *> AggMap
Abstract class representing an element of PRM class.
virtual void endAttribute() override
Tells the factory that we finished declaring an attribute.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
virtual void addReferenceSlot(const std::string &type, const std::string &name, bool isArray) override
Tells the factory that we started declaring a slot.
This class is used to contain and manipulate gum::ParseError.
The O3Label is part of the AST of the O3PRM language.
Definition: O3prm.h:173
bool __checkParent(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
The O3Attribute is part of the AST of the O3PRM language.
Definition: O3prm.h:471
virtual bool exists(const std::string &name) const
Returns true if a member with the given name exists in this PRMClassElementContainer or in the PRMCla...
virtual O3FormulaList & values()
Definition: O3prm.cpp:715
void __completeAggregates(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
O3LabelList & parameters()
Definition: O3prm.cpp:1074
O3ParameterList & parameters()
Definition: O3prm.cpp:882
bool __checkLocalParent(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
virtual NodeId addNode()
insert a new node and return its id
A PRMReferenceSlot represent a relation between two PRMClassElementContainer.
Definition: PRMObject.h:223
virtual void startClass(const std::string &c, const std::string &ext="", const Set< std::string > *implements=nullptr, bool delayInheritance=false) override
Tells the factory that we start a class declaration.
The class for generic Hash Tables.
Definition: hashTable.h:679
bool __checkRuleCPT(const PRMClass< GUM_SCALAR > &c, O3RuleCPT &attr)
void continueAggregator(const std::string &name)
Continues an aggregator declaration.
std::string & label()
Definition: O3prm.cpp:242
virtual void startAttribute(const std::string &type, const std::string &name, bool scalar_atttr=false) override
Tells the factory that we start an attribute declaration.
bool __checkRemoteParent(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &prnt)
HashTable< std::string, gum::NodeId > __nameMap
void __completeAttribute(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
The O3RuleCPT is part of the AST of the O3PRM language.
Definition: O3prm.h:542
virtual Size domainSize() const =0
void setRawCPFByColumns(const std::vector< GUM_SCALAR > &array)
Gives the factory the CPF in its raw form.
Resolves names for the different O3PRM factories.
Definition: O3NameSolver.h:57
bool __checkAggTypeLegality(O3Class &o3class, O3Aggregate &agg)
bool __checkAggParameters(O3Class &o3class, O3Aggregate &agg, const PRMType *t)
The O3Class is part of the AST of the O3PRM language.
Definition: O3prm.h:619
O3ClassFactory(PRM< GUM_SCALAR > &prm, O3PRM &o3_prm, O3NameSolver< GUM_SCALAR > &solver, ErrorsContainer &errors)
O3Label & superLabel()
Definition: O3prm.cpp:872
bool __checkParameterValue(O3Aggregate &agg, const gum::prm::PRMType &t)
Factory which builds a PRM<GUM_SCALAR>.
Definition: PRMType.h:50
HashTable< std::string, const PRMParameter< GUM_SCALAR > *> scope() const
Returns all the parameters in the scope of this class.
The O3ReferenceSlot is part of the AST of the O3PRM language.
Definition: O3prm.h:438
virtual std::string label(Idx i) const =0
Get the i-th label. This method is pure virtual.
virtual O3Label & type()
Definition: O3prm.cpp:662
virtual void setCPFByRule(const std::vector< std::string > &labels, const std::vector< GUM_SCALAR > &values)
Fills the CPF using a rule.
HashTable< NodeId, O3Class *> __nodeMap
PRMClassElement< GUM_SCALAR > & get(NodeId id)
See gum::prm::PRMClassElementContainer<GUM_SCALAR>::get(NodeId).
bool __checkAggregateForCompletion(O3Class &o3class, O3Aggregate &agg)
HashTable< std::string, O3Class *> __classMap
Base class for all aGrUM's exceptions.
Definition: exceptions.h:106
HashTable< std::string, O3ReferenceSlot *> RefMap
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:43
This is a decoration of the DiscreteVariable class.
Definition: PRMType.h:63
bool __checkLabelsNumber(const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
O3AttributeList & attributes()
Definition: O3prm.cpp:889
virtual void endClass(bool checkImplementations=true) override
Tells the factory that we finished a class declaration.
Builds gum::prm::Class from gum::prm::o3prm::O3Class.
const Sequence< NodeId > & topologicalOrder(bool clear=true) const
The topological order stays the same as long as no variable or arcs are added to or erased from the topology.
Definition: diGraph.cpp:91
bool __checkRuleCPTSumsTo1(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
The O3RawCPT is part of the AST of the O3PRM language.
Definition: O3prm.h:510
std::vector< O3Class *> __o3Classes
bool __checkSlotChainLink(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain, const std::string &s)
This class represents a Probabilistic Relational Model (PRM<GUM_SCALAR>).
Definition: PRM.h:66
<agrum/PRM/classElementContainer.h>
A PRMClass is an object of a PRM representing a fragment of a Bayesian Network which can be instantia...
Definition: PRMClass.h:66
O3NameSolver< GUM_SCALAR > * __solver
O3ClassFactory< GUM_SCALAR > & operator=(const O3ClassFactory< GUM_SCALAR > &src)
O3LabelList & interfaces()
Definition: O3prm.cpp:877
void decomposePath(const std::string &path, std::vector< std::string > &v)
Decompose a string into a vector of strings using "." as separators.
Definition: utils_prm.cpp:29
void __addParameters(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
const PRMClassElement< GUM_SCALAR > * __resolveSlotChain(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain)
virtual void continueAttribute(const std::string &name) override
Continues the declaration of an attribute.
void addParameter(const std::string &type, const std::string &name, double value) override
Add a parameter to the current class with a default value.
bool __checkParametersNumber(O3Aggregate &agg, Size n)
The O3PRM is part of the AST of the O3PRM language.
Definition: O3prm.h:892
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:48
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
virtual O3Label & name()
Definition: O3prm.cpp:665
const PRMType * __checkAggParents(O3Class &o3class, O3Aggregate &agg)
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:613
void startAggregator(const std::string &name, const std::string &agg_type, const std::string &rv_type, const std::vector< std::string > &params)
Start an aggregator declaration.
virtual void addParent(const std::string &name) override
Tells the factory that we add a parent to the current declared attribute.
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55