aGrUM  0.14.2
O3ClassFactory_tpl.h
Go to the documentation of this file.
1 /**************************************************************************
2  * Copyright (C) 2005 by Pierre-Henri WUILLEMIN et Christophe GONZALES *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
20 
30 
31 namespace gum {
32  namespace prm {
33  namespace o3prm {
34 
35  template < typename GUM_SCALAR >
37  PRM< GUM_SCALAR >& prm,
38  O3PRM& o3_prm,
40  ErrorsContainer& errors) :
41  __prm(&prm),
42  __o3_prm(&o3_prm), __solver(&solver), __errors(&errors) {
43  GUM_CONSTRUCTOR(O3ClassFactory);
44  }
45 
46  template < typename GUM_SCALAR >
48  const O3ClassFactory< GUM_SCALAR >& src) :
49  __prm(src.__prm),
52  __nodeMap(src.__nodeMap), __dag(src.__dag),
54  GUM_CONS_CPY(O3ClassFactory);
55  }
56 
57  template < typename GUM_SCALAR >
60  __prm(std::move(src.__prm)),
61  __o3_prm(std::move(src.__o3_prm)), __solver(std::move(src.__solver)),
62  __errors(std::move(src.__errors)), __nameMap(std::move(src.__nameMap)),
63  __classMap(std::move(src.__classMap)),
64  __nodeMap(std::move(src.__nodeMap)), __dag(std::move(src.__dag)),
65  __o3Classes(std::move(src.__o3Classes)) {
66  GUM_CONS_MOV(O3ClassFactory);
67  }
68 
69  template < typename GUM_SCALAR >
71  GUM_DESTRUCTOR(O3ClassFactory);
72  }
73 
74  template < typename GUM_SCALAR >
77  if (this == &src) { return *this; }
78  __prm = src.__prm;
79  __o3_prm = src.__o3_prm;
80  __solver = src.__solver;
81  __errors = src.__errors;
82  __nameMap = src.__nameMap;
83  __classMap = src.__classMap;
84  __nodeMap = src.__nodeMap;
85  __dag = src.__dag;
87  return *this;
88  }
89 
90  template < typename GUM_SCALAR >
93  if (this == &src) { return *this; }
94  __prm = std::move(src.__prm);
95  __o3_prm = std::move(src.__o3_prm);
96  __solver = std::move(src.__solver);
97  __errors = std::move(src.__errors);
98  __nameMap = std::move(src.__nameMap);
99  __classMap = std::move(src.__classMap);
100  __nodeMap = std::move(src.__nodeMap);
101  __dag = std::move(src.__dag);
102  __o3Classes = std::move(src.__o3Classes);
103  return *this;
104  }
105 
// Declares every class of the O3PRM AST in the PRM. Each class gets only its
// name, its super class and the interfaces the solver could resolve; the
// other members are added later by the continueClass-based passes of this
// factory. Nothing is declared unless __checkO3Classes() succeeds.
// NOTE(review): the signature, the declaration of `factory` and one line
// between the check and the loop are not visible in this listing; `factory`
// is presumably a PRMFactory over *__prm -- confirm.
106  template < typename GUM_SCALAR >
109 
110  // Class with a super class must be declared after
111  if (__checkO3Classes()) {
113 
114  for (auto c : __o3Classes) {
115  // Solving interfaces: unresolved interfaces are left out of `implements`
116  auto implements = Set< std::string >();
117  for (auto& i : c->interfaces()) {
118  if (__solver->resolveInterface(i)) { implements.insert(i.label()); }
119  }
120 
121  // Adding the class (skipped if its super class label does not resolve)
// The trailing `true` passed to startClass delays inheritance, and
// endClass(false) skips implementation checks at this stage.
122  if (__solver->resolveClass(c->superLabel())) {
123  factory.startClass(
124  c->name().label(), c->superLabel().label(), &implements, true);
125  factory.endClass(false);
126  }
127  }
128  }
129  }
130 
// Fills __o3Classes with the O3 classes in reverse topological order of
// __dag. Arcs are added from a class to its super class, so, assuming
// topologicalOrder() lists arc tails before heads, super classes end up
// before their subclasses -- the order the later passes iterate over.
// NOTE(review): appends without clearing __o3Classes; a second call would
// duplicate every entry.
// NOTE(review): topologicalOrder() returns a const reference; `auto` copies
// the whole Sequence where `const auto&` would not.
// NOTE: steps with --id from rbegin() to rend(), which presumes aGrUM's
// Sequence reverse-iterator semantics (not std::reverse_iterator) -- confirm.
131  template < typename GUM_SCALAR >
133  auto topo_order = __dag.topologicalOrder();
134 
135  for (auto id = topo_order.rbegin(); id != topo_order.rend(); --id) {
136  __o3Classes.push_back(__nodeMap[*id]);
137  }
138  }
139 
140  template < typename GUM_SCALAR >
143  }
144 
// Adds one __dag node per O3 class and records the name -> node,
// name -> class and node -> class mappings. As soon as a class name is
// already registered, reports O3PRM_CLASS_DUPLICATE and returns false.
// NOTE(review): on a duplicate, the node just created by addNode() stays in
// __dag without a __nodeMap entry; presumably harmless only if callers stop
// on false -- confirm.
145  template < typename GUM_SCALAR >
147  for (auto& c : __o3_prm->classes()) {
148  auto id = __dag.addNode();
149 
150  try {
// A duplicate class name makes the __nameMap insertion throw.
151  __nameMap.insert(c->name().label(), id);
152  __classMap.insert(c->name().label(), c.get());
153  __nodeMap.insert(id, c.get());
154 
155  } catch (DuplicateElement&) {
156  O3PRM_CLASS_DUPLICATE(c->name(), *__errors);
157  return false;
158  }
159  }
160 
161  return true;
162  }
163 
164  template < typename GUM_SCALAR >
166  for (auto& c : __o3_prm->classes()) {
167  if (c->superLabel().label() != "") {
168  if (!__solver->resolveClass(c->superLabel())) { return false; }
169 
170  auto head = __nameMap[c->superLabel().label()];
171  auto tail = __nameMap[c->name().label()];
172 
173  try {
174  __dag.addArc(tail, head);
175  } catch (InvalidDirectedCycle&) {
176  // Cyclic inheritance
177  O3PRM_CLASS_CYLIC_INHERITANCE(c->name(), c->superLabel(), *__errors);
178  return false;
179  }
180  }
181  }
182 
183  return true;
184  }
185 
186  template < typename GUM_SCALAR >
188  for (auto& c : __o3_prm->classes()) {
189  if (__checkImplementation(*c)) {
190  __prm->getClass(c->name().label()).initializeInheritance();
191  }
192  }
193  }
194 
198 
// Checks that O3 class `c` correctly implements every interface it declares
// that the solver can resolve. It builds name lookup tables of the members
// declared directly in `c`, then delegates each interface to the
// interface-specific __checkImplementation overload. Returns false at the
// first interface that is not correctly implemented.
// NOTE(review): only members declared in `c` itself are put in the maps;
// members inherited from a super class are not -- confirm whether an
// interface satisfied through inheritance is meant to be rejected.
// NOTE(review): interfaces the solver cannot resolve are skipped here, not
// reported (the solver may report them itself -- confirm).
199  template < typename GUM_SCALAR >
201  // Saving attributes names for fast lookup
202  auto attr_map = AttrMap();
203  for (auto& a : c.attributes()) {
204  attr_map.insert(a->name().label(), a.get());
205  }
206 
207  // Saving aggregates names for fast lookup
208  auto agg_map = AggMap();
209  for (auto& agg : c.aggregates()) {
210  agg_map.insert(agg.name().label(), &agg);
211  }
212 
// Saving reference slot names for fast lookup
213  auto ref_map = RefMap();
214  for (auto& ref : c.referenceSlots()) {
215  ref_map.insert(ref.name().label(), &ref);
216  }
217 
218  // Checking interface implementation
219  for (auto& i : c.interfaces()) {
220  if (__solver->resolveInterface(i)) {
221  if (!__checkImplementation(c, i, attr_map, agg_map, ref_map)) {
222  return false;
223  }
224  }
225  }
226 
227  return true;
228  }
229 
// Checks that the members of class `c` (given as name lookup tables)
// implement interface `i`:
//  - every interface attribute must be matched by a class attribute or
//    aggregate whose type is a subtype of the interface attribute's type;
//  - every interface reference slot found in `ref_map` must have a slot type
//    that is a subtype of the interface's slot type.
// Reports the first violation to __errors and returns false.
// NOTE(review): for reference slots, `counter` is incremented but never
// compared with real_i.referenceSlots().size(), so interface reference slots
// missing from `c` are not reported (unlike missing attributes) -- confirm
// whether this is intentional.
// NOTE(review): a name present in both attr_map and agg_map increments
// `counter` twice, which can hide a missing attribute.
230  template < typename GUM_SCALAR >
231  INLINE bool
233  O3Label& i,
234  AttrMap& attr_map,
235  AggMap& agg_map,
236  RefMap& ref_map) {
237  const auto& real_i = __prm->getInterface(i.label());
238 
// Each interface attribute may be implemented by an attribute or an aggregate
239  auto counter = (Size)0;
240  for (const auto& a : real_i.attributes()) {
241  if (attr_map.exists(a->name())) {
242  ++counter;
243 
244  if (!__checkImplementation(attr_map[a->name()]->type(), a->type())) {
245  O3PRM_CLASS_ATTR_IMPLEMENTATION(
246  c.name(), i, attr_map[a->name()]->name(), *__errors);
247  return false;
248  }
249  }
250 
251  if (agg_map.exists(a->name())) {
252  ++counter;
253 
254  if (!__checkImplementation(agg_map[a->name()]->variableType(),
255  a->type())) {
256  O3PRM_CLASS_AGG_IMPLEMENTATION(
257  c.name(), i, agg_map[a->name()]->name(), *__errors);
258  return false;
259  }
260  }
261  }
262 
263  if (counter != real_i.attributes().size()) {
264  O3PRM_CLASS_MISSING_ATTRIBUTES(c.name(), i, *__errors);
265  return false;
266  }
267 
// Reference slots: only slots present in `c` are type-checked (see NOTE above)
268  counter = 0;
269  for (const auto& r : real_i.referenceSlots()) {
270  if (ref_map.exists(r->name())) {
271  ++counter;
272 
273  if (!__checkImplementation(ref_map[r->name()]->type(),
274  r->slotType())) {
275  O3PRM_CLASS_REF_IMPLEMENTATION(
276  c.name(), i, ref_map[r->name()]->name(), *__errors);
277  return false;
278  }
279  }
280  }
281  return true;
282  }
283 
284  template < typename GUM_SCALAR >
285  INLINE bool
287  const PRMType& type) {
288  if (!__solver->resolveType(o3_type)) { return false; }
289 
290  return __prm->type(o3_type.label()).isSubTypeOf(type);
291  }
292 
293  template < typename GUM_SCALAR >
295  O3Label& o3_type, const PRMClassElementContainer< GUM_SCALAR >& type) {
296  if (!__solver->resolveSlotType(o3_type)) { return false; }
297 
298  if (__prm->isInterface(o3_type.label())) {
299  return __prm->getInterface(o3_type.label()).isSubTypeOf(type);
300  } else {
301  return __prm->getClass(o3_type.label()).isSubTypeOf(type);
302  }
303  }
304 
305  template < typename GUM_SCALAR >
308  // Class with a super class must be declared after
309  for (auto c : __o3Classes) {
310  __prm->getClass(c->name().label()).inheritParameters();
311 
312  factory.continueClass(c->name().label());
313 
314  __addParameters(factory, *c);
315 
316  factory.endClass(false);
317  }
318  }
319 
320  template < typename GUM_SCALAR >
322  PRMFactory< GUM_SCALAR >& factory, O3Class& c) {
323  for (auto& p : c.parameters()) {
324  switch (p.type()) {
326  factory.addParameter("int", p.name().label(), p.value().value());
327  break;
328  }
329 
331  factory.addParameter("real", p.name().label(), p.value().value());
332  break;
333  }
334 
335  default: { GUM_ERROR(FatalError, "unknown O3Parameter type"); }
336  }
337  }
338  }
339 
340  template < typename GUM_SCALAR >
342  // Class with a super class must be declared after
343  for (auto c : __o3Classes) {
344  __prm->getClass(c->name().label()).inheritReferenceSlots();
346  }
347  }
348 
349  template < typename GUM_SCALAR >
352 
353  factory.continueClass(c.name().label());
354 
355  // References
356  for (auto& ref : c.referenceSlots()) {
357  if (__checkReferenceSlot(c, ref)) {
358  factory.addReferenceSlot(
359  ref.type().label(), ref.name().label(), ref.isArray());
360  }
361  }
362 
363  factory.endClass(false);
364  }
365 
// Validates reference slot `ref` of O3 class `c` before it is added to the
// PRM class. The checks are:
//  - its slot type must resolve;
//  - if the class already has a member with that name, the new slot's type
//    must differ by name from the existing slot's type and be a subtype of
//    it (a legal overload); otherwise it is a duplicate;
//  - a slot typed by a class may reference neither `c` itself nor a
//    subclass of `c`.
// Reports the first problem to __errors and returns false.
366  template < typename GUM_SCALAR >
367  INLINE bool
369  O3ReferenceSlot& ref) {
370  if (!__solver->resolveSlotType(ref.type())) { return false; }
371 
372  const auto& real_c = __prm->getClass(c.name().label());
373 
374  // Check for duplicates
375  if (real_c.exists(ref.name().label())) {
376  const auto& elt = real_c.get(ref.name().label());
377 
// NOTE(review): the line opening the branch closed by the `} else {` below is
// missing from this listing; given the static_cast, it presumably tests
// that `elt` is a reference slot -- confirm against the original source.
379  auto slot_type = (PRMClassElementContainer< GUM_SCALAR >*)nullptr;
380 
381  if (__prm->isInterface(ref.type().label())) {
382  slot_type = &(__prm->getInterface(ref.type().label()));
383 
384  } else {
385  slot_type = &(__prm->getClass(ref.type().label()));
386  }
387 
388  auto real_ref =
389  static_cast< const PRMReferenceSlot< GUM_SCALAR >* >(&elt);
390 
// Same slot type name: a plain redeclaration, not an overload
391  if (slot_type->name() == real_ref->slotType().name()) {
392  O3PRM_CLASS_DUPLICATE_REFERENCE(ref.name(), *__errors);
393  return false;
394 
395  } else if (!slot_type->isSubTypeOf(real_ref->slotType())) {
396  O3PRM_CLASS_ILLEGAL_OVERLOAD(ref.name(), c.name(), *__errors);
397  return false;
398  }
399 
400  } else {
401  O3PRM_CLASS_DUPLICATE_REFERENCE(ref.name(), *__errors);
402  return false;
403  }
404  }
405 
406  // If the slot type is a class, check for illegal references
407  if (__prm->isClass(ref.type().label())) {
408  const auto& ref_type = __prm->getClass(ref.type().label());
409 
410  // No recursive reference
411  if ((&ref_type) == (&real_c)) {
412  O3PRM_CLASS_SELF_REFERENCE(c.name(), ref.name(), *__errors);
413  return false;
414  }
415 
416  // No reference to subclasses
417  if (ref_type.isSubTypeOf(real_c)) {
418  O3PRM_CLASS_ILLEGAL_SUB_REFERENCE(c.name(), ref.type(), *__errors);
419  return false;
420  }
421  }
422 
423  return true;
424  }
425 
426  template < typename GUM_SCALAR >
428  // Class with a super class must be declared after
429  for (auto c : __o3Classes) {
430  __prm->getClass(c->name().label()).inheritAttributes();
431  __declareAttribute(*c);
432  }
433  }
434 
435  template < typename GUM_SCALAR >
437  // Class with a super class must be declared after
438  for (auto c : __o3Classes) {
439  __prm->getClass(c->name().label()).inheritAggregates();
441  }
442  }
443 
444  template < typename GUM_SCALAR >
447  factory.continueClass(c.name().label());
448 
449  for (auto& attr : c.attributes()) {
450  if (__checkAttributeForDeclaration(c, *attr)) {
451  factory.startAttribute(attr->type().label(), attr->name().label());
452  factory.endAttribute();
453  }
454  }
455 
456  factory.endClass(false);
457  }
458 
459  template < typename GUM_SCALAR >
461  O3Class& c, O3Attribute& attr) {
462  // Check type
463  if (!__solver->resolveType(attr.type())) { return false; }
464 
465  // Checking type legality if overload
466  if (c.superLabel().label() != "") {
467  const auto& super = __prm->getClass(c.superLabel().label());
468 
469  if (!super.exists(attr.name().label())) { return true; }
470 
471  const auto& super_type = super.get(attr.name().label()).type();
472  const auto& type = __prm->type(attr.type().label());
473 
474  if (!type.isSubTypeOf(super_type)) {
475  O3PRM_CLASS_ILLEGAL_OVERLOAD(attr.name(), c.superLabel(), *__errors);
476  return false;
477  }
478  }
479  return true;
480  }
481 
482  template < typename GUM_SCALAR >
485 
486  // Class with a super class must be declared in order
487  for (auto c : __o3Classes) {
488  __prm->getClass(c->name().label()).inheritSlotChains();
489  factory.continueClass(c->name().label());
490 
491  __completeAttribute(factory, *c);
492 
493  if (c->superLabel().label() != "") {
494  auto& super = __prm->getClass(c->superLabel().label());
495  auto to_complete = Set< std::string >();
496 
497  for (auto a : super.attributes()) {
498  to_complete.insert(a->safeName());
499  }
500 
501  for (auto a : super.aggregates()) {
502  to_complete.insert(a->safeName());
503  }
504 
505  for (auto& a : c->attributes()) {
506  to_complete.erase(__prm->getClass(c->name().label())
507  .get(a->name().label())
508  .safeName());
509  }
510 
511  for (auto& a : c->aggregates()) {
512  to_complete.erase(__prm->getClass(c->name().label())
513  .get(a.name().label())
514  .safeName());
515  }
516 
517  for (auto a : to_complete) {
518  __prm->getClass(c->name().label()).completeInheritance(a);
519  }
520  }
521 
522  factory.endClass(true);
523  }
524  }
525 
526  template < typename GUM_SCALAR >
529 
530  // Class with a super class must be declared in order
531  for (auto c : __o3Classes) {
532  factory.continueClass(c->name().label());
533 
534  __completeAggregates(factory, *c);
535 
536  factory.endClass(false);
537  }
538  }
539 
// Completes the aggregates declared in O3 class `c`. For each aggregate that
// passes __checkAggregateForCompletion, it reopens the aggregate in `factory`
// and adds its parents. Aggregates failing the check are skipped here; the
// check itself reports the errors.
540  template < typename GUM_SCALAR >
542  PRMFactory< GUM_SCALAR >& factory, O3Class& c) {
543  // Aggregates
544  for (auto& agg : c.aggregates()) {
545  if (__checkAggregateForCompletion(c, agg)) {
546  factory.continueAggregator(agg.name().label());
547 
548  for (const auto& parent : agg.parents()) {
549  factory.addParent(parent.label());
550  }
551 
552  factory.endAggregator();
553  }
554  }
555  }
556 
557  template < typename GUM_SCALAR >
559  O3Class& c, O3Aggregate& agg) {
560  // Checking parents
561  auto t = __checkAggParents(c, agg);
562  if (t == nullptr) { return false; }
563 
564  // Checking parameters numbers
565  if (!__checkAggParameters(c, agg, t)) { return false; }
566 
567  return true;
568  }
569 
570  template < typename GUM_SCALAR >
572  PRMFactory< GUM_SCALAR >& factory, O3Class& c) {
573  // Attributes
574  for (auto& attr : c.attributes()) {
575  if (__checkAttributeForCompletion(c, *attr)) {
576  factory.continueAttribute(attr->name().label());
577 
578  for (const auto& parent : attr->parents()) {
579  factory.addParent(parent.label());
580  }
581 
582  auto raw = dynamic_cast< const O3RawCPT* >(attr.get());
583 
584  if (raw) {
585  auto values = std::vector< std::string >();
586  for (const auto& val : raw->values()) {
587  values.push_back(val.formula().formula());
588  }
589  factory.setRawCPFByColumns(values);
590  }
591 
592  auto rule_cpt = dynamic_cast< const O3RuleCPT* >(attr.get());
593  if (rule_cpt) {
594  for (const auto& rule : rule_cpt->rules()) {
595  auto labels = std::vector< std::string >();
596  auto values = std::vector< std::string >();
597 
598  for (const auto& lbl : rule.first) {
599  labels.push_back(lbl.label());
600  }
601 
602  for (const auto& form : rule.second) {
603  values.push_back(form.formula().formula());
604  }
605 
606  factory.setCPFByRule(labels, values);
607  }
608  }
609 
610  factory.endAttribute();
611  }
612  }
613  }
614 
615  template < typename GUM_SCALAR >
617  const O3Class& o3_c, O3Attribute& attr) {
618  // Check for parents existence
619  const auto& c = __prm->getClass(o3_c.name().label());
620  for (auto& prnt : attr.parents()) {
621  if (!__checkParent(c, prnt)) { return false; }
622  }
623 
624  // Check that CPT sums to 1
625  auto raw = dynamic_cast< O3RawCPT* >(&attr);
626  if (raw) { return __checkRawCPT(c, *raw); }
627 
628  auto rule = dynamic_cast< O3RuleCPT* >(&attr);
629  if (rule) { return __checkRuleCPT(c, *rule); }
630 
631  return true;
632  }
633 
634  template < typename GUM_SCALAR >
636  const PRMClass< GUM_SCALAR >& c, const O3Label& prnt) {
637  if (prnt.label().find('.') == std::string::npos) {
638  return __checkLocalParent(c, prnt);
639 
640  } else {
641  return __checkRemoteParent(c, prnt);
642  }
643  }
644 
// Checks that local parent `prnt` (a label without '.') names a member of
// class `c` that is allowed to be a parent. If not, reports
// O3PRM_CLASS_PARENT_NOT_FOUND or O3PRM_CLASS_ILLEGAL_PARENT and returns
// false.
645  template < typename GUM_SCALAR >
647  const PRMClass< GUM_SCALAR >& c, const O3Label& prnt) {
648  if (!c.exists(prnt.label())) {
649  O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *__errors);
650  return false;
651  }
652 
653  const auto& elt = c.get(prnt.label());
// NOTE(review): the condition guarding ILLEGAL_PARENT is missing from this
// listing; it presumably tests the kind of `elt` -- confirm against source.
657  O3PRM_CLASS_ILLEGAL_PARENT(prnt, *__errors);
658  return false;
659  }
660 
661  return true;
662  }
663 
664  template < typename GUM_SCALAR >
666  const PRMClassElementContainer< GUM_SCALAR >& c, const O3Label& prnt) {
667  if (__resolveSlotChain(c, prnt) == nullptr) { return false; }
668  return true;
669  }
670 
671  template < typename GUM_SCALAR >
673  const O3RuleCPT& attr, const O3RuleCPT::O3Rule& rule) {
674  // Check that the number of labels is correct
675  if (rule.first.size() != attr.parents().size()) {
676  O3PRM_CLASS_ILLEGAL_RULE_SIZE(
677  rule, rule.first.size(), attr.parents().size(), *__errors);
678  return false;
679  }
680  return true;
681  }
682 
// Checks that each label of `rule` is either the wildcard "*" or one of the
// labels of the corresponding parent's type. It reports every illegal label
// (not just the first) and returns false if any was found.
// NOTE(review): rule.first[i] is indexed for every parent of `attr`. If the
// rule has fewer labels than there are parents, this is an out-of-range
// access (undefined behaviour, presumably, if O3LabelList is a std::vector).
// __checkRuleCPT calls this function even when __checkLabelsNumber has just
// failed. Either bound the loop by min(rule.first.size(),
// attr.parents().size()) or skip the call.
// NOTE(review): __resolveSlotChain returns nullptr on failure, and
// dereferencing nullptr does not throw, so the catch below does not cover
// that case. The function relies on the parents having been validated
// beforehand, as __checkAttributeForCompletion does.
683  template < typename GUM_SCALAR >
685  const PRMClass< GUM_SCALAR >& c,
686  const O3RuleCPT& attr,
687  const O3RuleCPT::O3Rule& rule) {
688  bool errors = false;
689  for (std::size_t i = 0; i < attr.parents().size(); ++i) {
690  auto label = rule.first[i];
691  auto prnt = attr.parents()[i];
692  try {
693  auto real_labels = __resolveSlotChain(c, prnt)->type()->labels();
694  // c.get(prnt.label()).type()->labels();
695  if (label.label() != "*"
696  && std::find(real_labels.begin(), real_labels.end(), label.label())
697  == real_labels.end()) {
698  O3PRM_CLASS_ILLEGAL_RULE_LABEL(rule, label, prnt, *__errors);
699  errors = true;
700  }
701  } catch (Exception&) {
702  // parent does not exist and is already reported
703  }
704  }
705  return errors == false;
706  }
707 
708  template < typename GUM_SCALAR >
710  const HashTable< std::string, const PRMParameter< GUM_SCALAR >* >& scope,
711  O3RuleCPT::O3Rule& rule) {
712  // Add parameters to formulas
713  for (auto& f : rule.second) {
714  f.formula().variables().clear();
715  for (const auto& values : scope) {
716  f.formula().variables().insert(values.first, values.second->value());
717  }
718  }
719  }
720 
721 
// Evaluates the formulas of `rule` and checks two things:
//  - each value lies in [0, 1];
//  - the values sum to 1.
// A deviation above 1e-3 is an error; a deviation above 1e-6 only produces
// a warning. All formulas are checked before returning false.
// NOTE(review): only OperationNotAllowed is caught here; other exception
// types propagate to the caller.
// NOTE(review): an out-of-range value is still added to `sum`.
// NOTE(review): the two tolerance tests mix `1.0` and `1.0f`. Using
// GUM_SCALAR(1.0) for both, as __checkRawCPT does, would be more consistent.
722  template < typename GUM_SCALAR >
724  const PRMClass< GUM_SCALAR >& c,
725  const O3RuleCPT& attr,
726  const O3RuleCPT::O3Rule& rule) {
727  bool errors = false;
728  // Check that formulas are valid and sum to 1
729  GUM_SCALAR sum = 0.0;
730  for (const auto& f : rule.second) {
731  try {
732  auto value = GUM_SCALAR(f.formula().result());
733  sum += value;
734  if (value < 0.0 || 1.0 < value) {
735  O3PRM_CLASS_ILLEGAL_CPT_VALUE(c.name(), attr.name(), f, *__errors);
736  errors = true;
737  }
738  } catch (OperationNotAllowed&) {
739  O3PRM_CLASS_ILLEGAL_CPT_VALUE(c.name(), attr.name(), f, *__errors);
740  errors = true;
741  }
742  }
743 
744  // Check that CPT sums to 1
745  if (std::abs(sum - 1.0) > 1e-3) {
746  O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(
747  c.name(), attr.name(), float(sum), *__errors);
748  errors = true;
749  } else if (std::abs(sum - 1.0f) > 1e-6) {
750  O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(
751  c.name(), attr.name(), float(sum), *__errors);
752  }
753  return errors == false;
754  }
755 
// Validates every rule of the rule-based CPT `attr` of class `c`. For each
// rule it checks the label count and the label values, binds the class
// parameters in scope into the rule's formulas, and then checks that the
// rule's values are valid probabilities summing to 1. All rules are checked;
// any aGrUM Exception is shown and counted as an error.
// NOTE(review): __checkLabelsValues runs even when __checkLabelsNumber has
// just failed. With fewer labels than parents, it then indexes past the end
// of the rule's label list. Consider skipping it when the count check fails.
756  template < typename GUM_SCALAR >
758  const PRMClass< GUM_SCALAR >& c, O3RuleCPT& attr) {
759  const auto& scope = c.scope();
760  bool errors = false;
761  for (auto& rule : attr.rules()) {
762  try {
763  if (!__checkLabelsNumber(attr, rule)) { errors = true; }
764  if (!__checkLabelsValues(c, attr, rule)) { errors = true; }
765  __addParamsToForms(scope, rule);
766  if (!__checkRuleCPTSumsTo1(c, attr, rule)) { errors = true; }
767  } catch (Exception& e) {
768  GUM_SHOWERROR(e);
769  errors = true;
770  }
771  }
772 
773  return errors == false;
774  }
775 
// Validates the raw CPT `attr` of class `c`:
//  - the number of values must equal the product of the domain sizes of the
//    attribute and of all its parents (local or reached via a slot chain);
//  - the class parameters in scope are bound into every value's formula;
//  - every value must evaluate without error to a number in [0, 1];
//  - the values sharing the same index i % parent_size (one parent
//    configuration) must sum to 1 within 1e-3. A deviation above 1e-6 only
//    produces a warning.
// Reports the first error to __errors and returns false.
776  template < typename GUM_SCALAR >
778  const PRMClass< GUM_SCALAR >& c, O3RawCPT& attr) {
779  const auto& type = __prm->type(attr.type().label());
780 
781  auto domainSize = type->domainSize();
782  for (auto& prnt : attr.parents()) {
783  try {
784  domainSize *= c.get(prnt.label()).type()->domainSize();
785  } catch (NotFound&) {
786  // If we are here, all parents have been checked so __resolveSlotChain
787  // will not raise an error and not return a nullptr
788  domainSize *= __resolveSlotChain(c, prnt)->type()->domainSize();
789  }
790  }
791 
792  // Check for CPT size
793  if (domainSize != attr.values().size()) {
794  O3PRM_CLASS_ILLEGAL_CPT_SIZE(c.name(),
795  attr.name(),
796  Size(attr.values().size()),
797  domainSize,
798  *__errors);
799  return false;
800  }
801 
802  // Add parameters to formulas
803  const auto& scope = c.scope();
804  for (auto& f : attr.values()) {
805  f.formula().variables().clear();
806 
807  for (const auto& values : scope) {
808  f.formula().variables().insert(values.first, values.second->value());
809  }
810  }
811 
812  // Check that CPT sums to 1
// values[k] accumulates the probabilities of parent configuration k
813  Size parent_size = domainSize / type->domainSize();
814  auto values = std::vector< GUM_SCALAR >(parent_size, 0.0f);
815 
816  for (std::size_t i = 0; i < attr.values().size(); ++i) {
817  try {
818  auto idx = i % parent_size;
819  auto val = (GUM_SCALAR)attr.values()[i].formula().result();
820  values[idx] += val;
821 
822  if (val < 0.0 || 1.0 < val) {
823  O3PRM_CLASS_ILLEGAL_CPT_VALUE(
824  c.name(), attr.name(), attr.values()[i], *__errors);
825  return false;
826  }
827  } catch (Exception&) {
828  O3PRM_CLASS_ILLEGAL_CPT_VALUE(
829  c.name(), attr.name(), attr.values()[i], *__errors);
830  return false;
831  }
832  }
833 
834  for (auto f : values) {
835  if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-3) {
836  O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(
837  c.name(), attr.name(), float(f), *__errors);
838  return false;
839  } else if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-6) {
840  O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(
841  c.name(), attr.name(), float(f), *__errors);
842  }
843  }
844  return true;
845  }
846 
// Follows the slot chain `chain` (split into links by decomposePath),
// starting from container `c`, and returns the element named by the last
// link. Returns nullptr in two cases:
//  - a link does not exist (reported to __errors by __checkSlotChainLink);
//  - an intermediate link is not a reference slot (not reported).
// NOTE(review): the last element is returned whatever its kind; the
// "attribute or aggregate" expectation is not enforced here.
// NOTE(review): the local `s` is an unused copy of chain.label().
847  template < typename GUM_SCALAR >
848  INLINE const PRMClassElement< GUM_SCALAR >*
851  const O3Label& chain) {
852  auto s = chain.label();
853  auto current = &c;
854  std::vector< std::string > v;
855 
856  decomposePath(chain.label(), v);
857 
858  for (size_t i = 0; i < v.size(); ++i) {
859  auto link = v[i];
860 
861  if (!__checkSlotChainLink(*current, chain, link)) { return nullptr; }
862 
863  auto elt = &(current->get(link));
864 
865  if (i == v.size() - 1) {
866  // last link, should be an attribute or aggregate
867  return elt;
868 
869  } else {
870  // should be a reference slot
871 
872  auto ref = dynamic_cast< const PRMReferenceSlot< GUM_SCALAR >* >(elt);
873  if (ref) {
874  current = &(ref->slotType());
875  } else {
876  return nullptr; // intermediate link is not a reference slot: chain cannot be followed
877  }
878  }
879  }
880 
881  // Only reached when decomposePath produced no link: the last iteration always returns
882 
883  return nullptr;
884  }
885 
886  template < typename GUM_SCALAR >
889  const O3Label& chain,
890  const std::string& s) {
891  if (!c.exists(s)) {
892  O3PRM_CLASS_LINK_NOT_FOUND(chain, s, *__errors);
893  return false;
894  }
895  return true;
896  }
897 
898  template < typename GUM_SCALAR >
901  factory.continueClass(c.name().label());
902 
903  for (auto& agg : c.aggregates()) {
904  if (__checkAggregateForDeclaration(c, agg)) {
905  auto params = std::vector< std::string >();
906  for (auto& p : agg.parameters()) {
907  params.push_back(p.label());
908  }
909 
910  factory.startAggregator(agg.name().label(),
911  agg.aggregateType().label(),
912  agg.variableType().label(),
913  params);
914  factory.endAggregator();
915  }
916  }
917 
918  factory.endClass(false);
919  }
920 
921  template < typename GUM_SCALAR >
923  O3Class& o3class, O3Aggregate& agg) {
924  if (!__solver->resolveType(agg.variableType())) { return false; }
925 
926  // Checking type legality if overload
927  if (!__checkAggTypeLegality(o3class, agg)) { return false; }
928 
929  return true;
930  }
931 
// Resolves every parent of aggregate `agg` in class `o3class` and returns
// their common type. It reports an error and returns nullptr when a parent
// cannot be resolved, when the first parent's type() throws
// OperationNotAllowed, or when a parent's type differs from the first
// parent's.
// NOTE(review): with no parents the loop never runs, and nullptr is returned
// without any error being reported. Callers such as
// __checkAggregateForCompletion treat nullptr as a failure.
// NOTE(review): elt->type() is guarded against OperationNotAllowed only for
// the first parent; for later parents that exception would propagate.
// NOTE(review): when __resolveSlotChain fails on a missing link, it has
// already reported LINK_NOT_FOUND, so PARENT_NOT_FOUND can be a second
// report for the same parent.
932  template < typename GUM_SCALAR >
933  INLINE const PRMType*
935  O3Aggregate& agg) {
936  const auto& c = __prm->getClass(o3class.name().label());
937  auto t = (const PRMType*)nullptr;
938 
939  for (const auto& prnt : agg.parents()) {
940  auto elt = __resolveSlotChain(c, prnt);
941 
942  if (elt == nullptr) {
943  O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *__errors);
944  return nullptr;
945 
946  } else {
947  if (t == nullptr) {
948  try {
949  t = &(elt->type());
950 
951  } catch (OperationNotAllowed&) {
952  O3PRM_CLASS_WRONG_PARENT(prnt, *__errors);
953  return nullptr;
954  }
955 
956  } else if ((*t) != elt->type()) {
957  // Wrong type in chain
958  O3PRM_CLASS_WRONG_PARENT_TYPE(
959  prnt, t->name(), elt->type().name(), *__errors);
960  return nullptr;
961  }
962  }
963  }
964  return t;
965  }
966 
967  template < typename GUM_SCALAR >
968  INLINE bool
970  O3Aggregate& agg) {
971  if (__prm->isClass(o3class.superLabel().label())) {
972  const auto& super = __prm->getClass(o3class.superLabel().label());
973  const auto& agg_type = __prm->type(agg.variableType().label());
974 
975  if (super.exists(agg.name().label())
976  && !agg_type.isSubTypeOf(super.get(agg.name().label()).type())) {
977  O3PRM_CLASS_ILLEGAL_OVERLOAD(
978  agg.name(), o3class.superLabel(), *__errors);
979  return false;
980  }
981  }
982 
983  return true;
984  }
985 
986  template < typename GUM_SCALAR >
988  O3Class& o3class, O3Aggregate& agg, const PRMType* t) {
989  bool ok = false;
990 
992  agg.aggregateType().label())) {
999  ok = __checkParametersNumber(agg, 0);
1000  break;
1001  }
1002 
1006  ok = __checkParametersNumber(agg, 1);
1007  break;
1008  }
1009 
1010  default: { GUM_ERROR(FatalError, "unknown aggregate type"); }
1011  }
1012 
1013  if (!ok) { return false; }
1014 
1015  // Checking parameters type
1017  agg.aggregateType().label())) {
1021  ok = __checkParameterValue(agg, *t);
1022  break;
1023  }
1024 
1025  default: { /* Nothing to do */
1026  }
1027  }
1028 
1029  return ok;
1030  }
1031 
1032  template < typename GUM_SCALAR >
1033  INLINE bool
1035  Size n) {
1036  if (agg.parameters().size() != n) {
1037  O3PRM_CLASS_AGG_PARAMETERS(
1038  agg.name(), Size(n), Size(agg.parameters().size()), *__errors);
1039  return false;
1040  }
1041 
1042  return true;
1043  }
1044 
1045  template < typename GUM_SCALAR >
1047  O3Aggregate& agg, const gum::prm::PRMType& t) {
1048  const auto& param = agg.parameters().front();
1049  bool found = false;
1050  for (Size idx = 0; idx < t.variable().domainSize(); ++idx) {
1051  if (t.variable().label(idx) == param.label()) {
1052  found = true;
1053  break;
1054  }
1055  }
1056 
1057  if (!found) {
1058  O3PRM_CLASS_AGG_PARAMETER_NOT_FOUND(agg.name(), param, *__errors);
1059  return false;
1060  }
1061 
1062  return true;
1063  }
1064 
1065  } // namespace o3prm
1066  } // namespace prm
1067 } // namespace gum
O3LabelList & parents()
Definition: O3prm.cpp:1066
O3ClassList & classes()
Definition: O3prm.cpp:501
std::pair< O3LabelList, O3FormulaList > O3Rule
Definition: O3prm.h:544
bool __checkAggregateForDeclaration(O3Class &o3class, O3Aggregate &agg)
virtual O3LabelList & parents()
Definition: O3prm.cpp:666
PRMParameter is a member of a Class in a PRM.
Definition: PRMParameter.h:49
bool __checkLabelsValues(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
virtual O3RuleList & rules()
Definition: O3prm.cpp:764
DiscreteVariable & variable()
Return a reference on the DiscreteVariable contained in this.
Definition: PRMType_inl.h:42
void endAggregator()
Finishes an aggregate declaration.
bool __checkAttributeForCompletion(const O3Class &o3_c, O3Attribute &attr)
void __addParamsToForms(const HashTable< std::string, const PRMParameter< GUM_SCALAR > * > &scope, O3RuleCPT::O3Rule &rule)
const std::string & name() const
Returns the name of this object.
Definition: PRMObject_inl.h:32
bool __checkAttributeForDeclaration(O3Class &o3_c, O3Attribute &attr)
HashTable< std::string, O3Attribute *> AttrMap
O3ReferenceSlotList & referenceSlots()
Definition: O3prm.cpp:885
#define GUM_SHOWERROR(e)
Definition: exceptions.h:58
bool __checkRawCPT(const PRMClass< GUM_SCALAR > &c, O3RawCPT &attr)
The O3Aggregate is part of the AST of the O3PRM language.
Definition: O3prm.h:575
Headers for the O3ClassFactory class.
virtual void continueClass(const std::string &c) override
Continue the declaration of a class.
O3AggregateList & aggregates()
Definition: O3prm.cpp:893
STL namespace.
bool __checkReferenceSlot(O3Class &c, O3ReferenceSlot &ref)
HashTable< std::string, O3Aggregate *> AggMap
Abstract class representing an element of PRM class.
virtual void endAttribute() override
Tells the factory that we finished declaring an attribute.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
virtual void addReferenceSlot(const std::string &type, const std::string &name, bool isArray) override
Tells the factory that we started declaring a slot.
This class is used contain and manipulate gum::ParseError.
The O3Label is part of the AST of the O3PRM language.
Definition: O3prm.h:171
bool __checkParent(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
The O3Attribute is part of the AST of the O3PRM language.
Definition: O3prm.h:469
virtual bool exists(const std::string &name) const
Returns true if a member with the given name exists in this PRMClassElementContainer or in the PRMCla...
virtual O3FormulaList & values()
Definition: O3prm.cpp:713
void __completeAggregates(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
O3LabelList & parameters()
Definition: O3prm.cpp:1072
O3ParameterList & parameters()
Definition: O3prm.cpp:880
bool __checkLocalParent(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
virtual NodeId addNode()
insert a new node and return its id
A PRMReferenceSlot represent a relation between two PRMClassElementContainer.
Definition: PRMObject.h:220
virtual void startClass(const std::string &c, const std::string &ext="", const Set< std::string > *implements=nullptr, bool delayInheritance=false) override
Tells the factory that we start a class declaration.
The class for generic Hash Tables.
Definition: hashTable.h:676
bool __checkRuleCPT(const PRMClass< GUM_SCALAR > &c, O3RuleCPT &attr)
void continueAggregator(const std::string &name)
Continues an aggregator declaration.
std::string & label()
Definition: O3prm.cpp:240
virtual void startAttribute(const std::string &type, const std::string &name, bool scalar_atttr=false) override
Tells the factory that we start an attribute declaration.
bool __checkRemoteParent(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &prnt)
HashTable< std::string, gum::NodeId > __nameMap
void __completeAttribute(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
The O3RuleCPT is part of the AST of the O3PRM language.
Definition: O3prm.h:540
virtual Size domainSize() const =0
void setRawCPFByColumns(const std::vector< GUM_SCALAR > &array)
Gives the factory the CPF in its raw form.
Resolves names for the different O3PRM factories.
Definition: O3NameSolver.h:55
bool __checkAggTypeLegality(O3Class &o3class, O3Aggregate &agg)
bool __checkAggParameters(O3Class &o3class, O3Aggregate &agg, const PRMType *t)
The O3Class is part of the AST of the O3PRM language.
Definition: O3prm.h:617
O3ClassFactory(PRM< GUM_SCALAR > &prm, O3PRM &o3_prm, O3NameSolver< GUM_SCALAR > &solver, ErrorsContainer &errors)
O3Label & superLabel()
Definition: O3prm.cpp:870
bool __checkParameterValue(O3Aggregate &agg, const gum::prm::PRMType &t)
Factory which builds a PRM<GUM_SCALAR>.
Definition: PRMType.h:47
HashTable< std::string, const PRMParameter< GUM_SCALAR > *> scope() const
Returns all the parameters in the scope of this class.
The O3ReferenceSlot is part of the AST of the O3PRM language.
Definition: O3prm.h:436
virtual std::string label(Idx i) const =0
Get the i-th label. This method is pure virtual.
virtual O3Label & type()
Definition: O3prm.cpp:660
virtual void setCPFByRule(const std::vector< std::string > &labels, const std::vector< GUM_SCALAR > &values)
Fills the CPF using a rule.
HashTable< NodeId, O3Class *> __nodeMap
PRMClassElement< GUM_SCALAR > & get(NodeId id)
See gum::prm::PRMClassElementContainer<GUM_SCALAR>::get(NodeId).
bool __checkAggregateForCompletion(O3Class &o3class, O3Aggregate &agg)
HashTable< std::string, O3Class *> __classMap
Base class for all aGrUM's exceptions.
Definition: exceptions.h:103
HashTable< std::string, O3ReferenceSlot *> RefMap
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:40
This is a decoration of the DiscreteVariable class.
Definition: PRMType.h:60
bool __checkLabelsNumber(const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
O3AttributeList & attributes()
Definition: O3prm.cpp:887
virtual void endClass(bool checkImplementations=true) override
Tells the factory that we finished a class declaration.
Builds gum::prm::Class from gum::prm::o3prm::O3Class.
const Sequence< NodeId > & topologicalOrder(bool clear=true) const
The topological order stays the same as long as no variable or arcs are added or erased from the topol...
Definition: diGraph.cpp:88
bool __checkRuleCPTSumsTo1(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
The O3RawCPT is part of the AST of the O3PRM language.
Definition: O3prm.h:508
std::vector< O3Class *> __o3Classes
bool __checkSlotChainLink(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain, const std::string &s)
This class represents a Probabilistic Relational PRMSystem<GUM_SCALAR>.
Definition: PRM.h:63
<agrum/PRM/classElementContainer.h>
A PRMClass is an object of a PRM representing a fragment of a Bayesian Network which can be instantia...
Definition: PRMClass.h:63
O3NameSolver< GUM_SCALAR > * __solver
O3ClassFactory< GUM_SCALAR > & operator=(const O3ClassFactory< GUM_SCALAR > &src)
O3LabelList & interfaces()
Definition: O3prm.cpp:875
void decomposePath(const std::string &path, std::vector< std::string > &v)
Decompose a string in a vector of strings using "." as separators.
Definition: utils_prm.cpp:27
void __addParameters(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
const PRMClassElement< GUM_SCALAR > * __resolveSlotChain(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain)
virtual void continueAttribute(const std::string &name) override
Continues the declaration of an attribute.
void addParameter(const std::string &type, const std::string &name, double value) override
Add a parameter to the current class with a default value.
bool __checkParametersNumber(O3Aggregate &agg, Size n)
The O3PRM is part of the AST of the O3PRM language.
Definition: O3prm.h:890
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:45
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
virtual O3Label & name()
Definition: O3prm.cpp:663
const PRMType * __checkAggParents(O3Class &o3class, O3Aggregate &agg)
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:610
void startAggregator(const std::string &name, const std::string &agg_type, const std::string &rv_type, const std::vector< std::string > &params)
Start an aggregator declaration.
virtual void addParent(const std::string &name) override
Tells the factory that we add a parent to the current declared attribute.
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52