aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
O3ClassFactory_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation for the O3ClassFactory class.
25  *
26  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
27  * @author Lionel TORTI
28  */
29 
30 #include <agrum/PRM/o3prm/O3ClassFactory.h>
31 
32 namespace gum {
33  namespace prm {
34  namespace o3prm {
35 
36  template < typename GUM_SCALAR >
37  INLINE O3ClassFactory< GUM_SCALAR >::O3ClassFactory(
38  PRM< GUM_SCALAR >& prm,
39  O3PRM& o3_prm,
42  prm__(&prm),
45  }
46 
47  template < typename GUM_SCALAR >
49  const O3ClassFactory< GUM_SCALAR >& src) :
50  prm__(src.prm__),
56  }
57 
58  template < typename GUM_SCALAR >
61  prm__(std::move(src.prm__)),
68  }
69 
70  template < typename GUM_SCALAR >
73  }
74 
75  template < typename GUM_SCALAR >
77  const O3ClassFactory< GUM_SCALAR >& src) {
78  if (this == &src) { return *this; }
79  prm__ = src.prm__;
86  dag__ = src.dag__;
88  return *this;
89  }
90 
91  template < typename GUM_SCALAR >
94  if (this == &src) { return *this; }
95  prm__ = std::move(src.prm__);
102  dag__ = std::move(src.dag__);
104  return *this;
105  }
106 
107  template < typename GUM_SCALAR >
110 
111  // Class with a super class must be declared after
112  if (checkO3Classes__()) {
114 
115  for (auto c: o3Classes__) {
116  // Soving interfaces
117  auto implements = Set< std::string >();
118  for (auto& i: c->interfaces()) {
120  }
121 
122  // Adding the class
123  if (solver__->resolveClass(c->superLabel())) {
125  c->superLabel().label(),
126  &implements,
127  true);
128  factory.endClass(false);
129  }
130  }
131  }
132  }
133 
134  template < typename GUM_SCALAR >
137 
138  for (auto id = topo_order.rbegin(); id != topo_order.rend(); --id) {
140  }
141  }
142 
143  template < typename GUM_SCALAR >
146  }
147 
148  template < typename GUM_SCALAR >
150  for (auto& c: o3_prm__->classes()) {
151  auto id = dag__.addNode();
152 
153  try {
154  nameMap__.insert(c->name().label(), id);
155  classMap__.insert(c->name().label(), c.get());
156  nodeMap__.insert(id, c.get());
157 
158  } catch (DuplicateElement&) {
160  return false;
161  }
162  }
163 
164  return true;
165  }
166 
167  template < typename GUM_SCALAR >
169  for (auto& c: o3_prm__->classes()) {
170  if (c->superLabel().label() != "") {
171  if (!solver__->resolveClass(c->superLabel())) { return false; }
172 
173  auto head = nameMap__[c->superLabel().label()];
174  auto tail = nameMap__[c->name().label()];
175 
176  try {
177  dag__.addArc(tail, head);
178  } catch (InvalidDirectedCycle&) {
179  // Cyclic inheritance
181  return false;
182  }
183  }
184  }
185 
186  return true;
187  }
188 
189  template < typename GUM_SCALAR >
191  for (auto& c: o3_prm__->classes()) {
192  if (checkImplementation__(*c)) {
194  }
195  }
196  }
197 
201 
202  template < typename GUM_SCALAR >
204  // Saving attributes names for fast lookup
205  auto attr_map = AttrMap();
206  for (auto& a: c.attributes()) {
207  attr_map.insert(a->name().label(), a.get());
208  }
209 
210  // Saving aggregates names for fast lookup
211  auto agg_map = AggMap();
212  for (auto& agg: c.aggregates()) {
213  agg_map.insert(agg.name().label(), &agg);
214  }
215 
216  auto ref_map = RefMap();
217  for (auto& ref: c.referenceSlots()) {
218  ref_map.insert(ref.name().label(), &ref);
219  }
220 
221  // Cheking interface implementation
222  for (auto& i: c.interfaces()) {
223  if (solver__->resolveInterface(i)) {
225  return false;
226  }
227  }
228  }
229 
230  return true;
231  }
232 
233  template < typename GUM_SCALAR >
234  INLINE bool
236  O3Label& i,
237  AttrMap& attr_map,
238  AggMap& agg_map,
239  RefMap& ref_map) {
240  const auto& real_i = prm__->getInterface(i.label());
241 
242  auto counter = (Size)0;
243  for (const auto& a: real_i.attributes()) {
244  if (attr_map.exists(a->name())) {
245  ++counter;
246 
247  if (!checkImplementation__(attr_map[a->name()]->type(), a->type())) {
249  i,
250  attr_map[a->name()]->name(),
251  *errors__);
252  return false;
253  }
254  }
255 
256  if (agg_map.exists(a->name())) {
257  ++counter;
258 
260  a->type())) {
262  i,
263  agg_map[a->name()]->name(),
264  *errors__);
265  return false;
266  }
267  }
268  }
269 
270  if (counter != real_i.attributes().size()) {
272  return false;
273  }
274 
275  counter = 0;
276  for (const auto& r: real_i.referenceSlots()) {
277  if (ref_map.exists(r->name())) {
278  ++counter;
279 
281  r->slotType())) {
283  i,
284  ref_map[r->name()]->name(),
285  *errors__);
286  return false;
287  }
288  }
289  }
290  return true;
291  }
292 
293  template < typename GUM_SCALAR >
294  INLINE bool
296  const PRMType& type) {
297  if (!solver__->resolveType(o3_type)) { return false; }
298 
299  return prm__->type(o3_type.label()).isSubTypeOf(type);
300  }
301 
302  template < typename GUM_SCALAR >
304  O3Label& o3_type,
306  if (!solver__->resolveSlotType(o3_type)) { return false; }
307 
308  if (prm__->isInterface(o3_type.label())) {
310  } else {
312  }
313  }
314 
315  template < typename GUM_SCALAR >
318  // Class with a super class must be declared after
319  for (auto c: o3Classes__) {
321 
323 
325 
326  factory.endClass(false);
327  }
328  }
329 
330  template < typename GUM_SCALAR >
333  O3Class& c) {
334  for (auto& p: c.parameters()) {
335  switch (p.type()) {
336  case O3Parameter::PRMType::INT: {
337  factory.addParameter("int", p.name().label(), p.value().value());
338  break;
339  }
340 
341  case O3Parameter::PRMType::FLOAT: {
342  factory.addParameter("real", p.name().label(), p.value().value());
343  break;
344  }
345 
346  default: {
347  GUM_ERROR(FatalError, "unknown O3Parameter type");
348  }
349  }
350  }
351  }
352 
353  template < typename GUM_SCALAR >
355  // Class with a super class must be declared after
356  for (auto c: o3Classes__) {
359  }
360  }
361 
362  template < typename GUM_SCALAR >
365 
367 
368  // References
369  for (auto& ref: c.referenceSlots()) {
370  if (checkReferenceSlot__(c, ref)) {
372  ref.name().label(),
373  ref.isArray());
374  }
375  }
376 
377  factory.endClass(false);
378  }
379 
380  template < typename GUM_SCALAR >
381  INLINE bool
383  O3ReferenceSlot& ref) {
384  if (!solver__->resolveSlotType(ref.type())) { return false; }
385 
386  const auto& real_c = prm__->getClass(c.name().label());
387 
388  // Check for dupplicates
389  if (real_c.exists(ref.name().label())) {
390  const auto& elt = real_c.get(ref.name().label());
391 
393  auto slot_type = (PRMClassElementContainer< GUM_SCALAR >*)nullptr;
394 
395  if (prm__->isInterface(ref.type().label())) {
397 
398  } else {
399  slot_type = &(prm__->getClass(ref.type().label()));
400  }
401 
402  auto real_ref
403  = static_cast< const PRMReferenceSlot< GUM_SCALAR >* >(&elt);
404 
405  if (slot_type->name() == real_ref->slotType().name()) {
407  return false;
408 
409  } else if (!slot_type->isSubTypeOf(real_ref->slotType())) {
411  return false;
412  }
413 
414  } else {
416  return false;
417  }
418  }
419 
420  // If class we need to check for illegal references
421  if (prm__->isClass(ref.type().label())) {
422  const auto& ref_type = prm__->getClass(ref.type().label());
423 
424  // No recursive reference
425  if ((&ref_type) == (&real_c)) {
427  return false;
428  }
429 
430  // No reference to subclasses
431  if (ref_type.isSubTypeOf(real_c)) {
433  return false;
434  }
435  }
436 
437  return true;
438  }
439 
440  template < typename GUM_SCALAR >
442  // Class with a super class must be declared after
443  for (auto c: o3Classes__) {
446  }
447  }
448 
449  template < typename GUM_SCALAR >
451  // Class with a super class must be declared after
452  for (auto c: o3Classes__) {
455  }
456  }
457 
458  template < typename GUM_SCALAR >
462 
463  for (auto& attr: c.attributes()) {
467  }
468  }
469 
470  factory.endClass(false);
471  }
472 
473  template < typename GUM_SCALAR >
475  O3Class& c,
476  O3Attribute& attr) {
477  // Check type
478  if (!solver__->resolveType(attr.type())) { return false; }
479 
480  // Checking type legality if overload
481  if (c.superLabel().label() != "") {
482  const auto& super = prm__->getClass(c.superLabel().label());
483 
484  if (!super.exists(attr.name().label())) { return true; }
485 
486  const auto& super_type = super.get(attr.name().label()).type();
487  const auto& type = prm__->type(attr.type().label());
488 
489  if (!type.isSubTypeOf(super_type)) {
491  return false;
492  }
493  }
494  return true;
495  }
496 
497  template < typename GUM_SCALAR >
500 
501  // Class with a super class must be declared in order
502  for (auto c: o3Classes__) {
505 
507 
508  if (c->superLabel().label() != "") {
509  auto& super = prm__->getClass(c->superLabel().label());
510  auto to_complete = Set< std::string >();
511 
512  for (auto a: super.attributes()) {
514  }
515 
516  for (auto a: super.aggregates()) {
518  }
519 
520  for (auto& a: c->attributes()) {
522  .get(a->name().label())
523  .safeName());
524  }
525 
526  for (auto& a: c->aggregates()) {
528  .get(a.name().label())
529  .safeName());
530  }
531 
532  for (auto a: to_complete) {
534  }
535  }
536 
537  factory.endClass(true);
538  }
539  }
540 
541  template < typename GUM_SCALAR >
544 
545  // Class with a super class must be declared in order
546  for (auto c: o3Classes__) {
548 
550 
551  factory.endClass(false);
552  }
553  }
554 
555  template < typename GUM_SCALAR >
558  O3Class& c) {
559  // Attributes
560  for (auto& agg: c.aggregates()) {
563 
564  for (const auto& parent: agg.parents()) {
566  }
567 
569  }
570  }
571  }
572 
573  template < typename GUM_SCALAR >
575  O3Class& c,
576  O3Aggregate& agg) {
577  // Checking parents
578  auto t = checkAggParents__(c, agg);
579  if (t == nullptr) { return false; }
580 
581  // Checking parameters numbers
582  if (!checkAggParameters__(c, agg, t)) { return false; }
583 
584  return true;
585  }
586 
587  template < typename GUM_SCALAR >
590  O3Class& c) {
591  // Attributes
592  for (auto& attr: c.attributes()) {
595 
596  for (const auto& parent: attr->parents()) {
598  }
599 
600  auto raw = dynamic_cast< const O3RawCPT* >(attr.get());
601 
602  if (raw) {
603  auto values = std::vector< std::string >();
604  for (const auto& val: raw->values()) {
606  }
608  }
609 
610  auto rule_cpt = dynamic_cast< const O3RuleCPT* >(attr.get());
611  if (rule_cpt) {
612  for (const auto& rule: rule_cpt->rules()) {
613  auto labels = std::vector< std::string >();
614  auto values = std::vector< std::string >();
615 
616  for (const auto& lbl: rule.first) {
618  }
619 
620  for (const auto& form: rule.second) {
622  }
623 
625  }
626  }
627 
629  }
630  }
631  }
632 
633  template < typename GUM_SCALAR >
635  const O3Class& o3_c,
636  O3Attribute& attr) {
637  // Check for parents existence
638  const auto& c = prm__->getClass(o3_c.name().label());
639  for (auto& prnt: attr.parents()) {
640  if (!checkParent__(c, prnt)) { return false; }
641  }
642 
643  // Check that CPT sums to 1
644  auto raw = dynamic_cast< O3RawCPT* >(&attr);
645  if (raw) { return checkRawCPT__(c, *raw); }
646 
647  auto rule = dynamic_cast< O3RuleCPT* >(&attr);
648  if (rule) { return checkRuleCPT__(c, *rule); }
649 
650  return true;
651  }
652 
653  template < typename GUM_SCALAR >
655  const PRMClass< GUM_SCALAR >& c,
656  const O3Label& prnt) {
657  if (prnt.label().find('.') == std::string::npos) {
658  return checkLocalParent__(c, prnt);
659 
660  } else {
661  return checkRemoteParent__(c, prnt);
662  }
663  }
664 
665  template < typename GUM_SCALAR >
667  const PRMClass< GUM_SCALAR >& c,
668  const O3Label& prnt) {
669  if (!c.exists(prnt.label())) {
671  return false;
672  }
673 
674  const auto& elt = c.get(prnt.label());
679  return false;
680  }
681 
682  return true;
683  }
684 
685  template < typename GUM_SCALAR >
688  const O3Label& prnt) {
689  if (resolveSlotChain__(c, prnt) == nullptr) { return false; }
690  return true;
691  }
692 
693  template < typename GUM_SCALAR >
695  const O3RuleCPT& attr,
696  const O3RuleCPT::O3Rule& rule) {
697  // Check that the number of labels is correct
698  if (rule.first.size() != attr.parents().size()) {
700  rule.first.size(),
701  attr.parents().size(),
702  *errors__);
703  return false;
704  }
705  return true;
706  }
707 
708  template < typename GUM_SCALAR >
710  const PRMClass< GUM_SCALAR >& c,
711  const O3RuleCPT& attr,
712  const O3RuleCPT::O3Rule& rule) {
713  bool errors = false;
714  for (std::size_t i = 0; i < attr.parents().size(); ++i) {
715  auto label = rule.first[i];
716  auto prnt = attr.parents()[i];
717  try {
719  // c.get(prnt.label()).type()->labels();
720  if (label.label() != "*"
722  == real_labels.end()) {
724  errors = true;
725  }
726  } catch (Exception&) {
727  // parent does not exists and is already reported
728  }
729  }
730  return errors == false;
731  }
732 
733  template < typename GUM_SCALAR >
735  const HashTable< std::string, const PRMParameter< GUM_SCALAR >* >& scope,
736  O3RuleCPT::O3Rule& rule) {
737  // Add parameters to formulas
738  for (auto& f: rule.second) {
739  f.formula().variables().clear();
740  for (const auto& values: scope) {
742  }
743  }
744  }
745 
746 
747  template < typename GUM_SCALAR >
749  const PRMClass< GUM_SCALAR >& c,
750  const O3RuleCPT& attr,
751  const O3RuleCPT::O3Rule& rule) {
752  bool errors = false;
753  // Check that formulas are valid and sums to 1
754  GUM_SCALAR sum = 0.0;
755  for (const auto& f: rule.second) {
756  try {
757  auto value = GUM_SCALAR(f.formula().result());
758  sum += value;
759  if (value < 0.0 || 1.0 < value) {
761  errors = true;
762  }
763  } catch (OperationNotAllowed&) {
765  errors = true;
766  }
767  }
768 
769  // Check that CPT sums to 1
770  if (std::abs(sum - 1.0) > 1e-3) {
772  attr.name(),
773  float(sum),
774  *errors__);
775  errors = true;
776  } else if (std::abs(sum - 1.0f) > 1e-6) {
778  attr.name(),
779  float(sum),
780  *errors__);
781  }
782  return errors == false;
783  }
784 
785  template < typename GUM_SCALAR >
787  const PRMClass< GUM_SCALAR >& c,
788  O3RuleCPT& attr) {
789  const auto& scope = c.scope();
790  bool errors = false;
791  for (auto& rule: attr.rules()) {
792  try {
793  if (!checkLabelsNumber__(attr, rule)) { errors = true; }
794  if (!checkLabelsValues__(c, attr, rule)) { errors = true; }
796  if (!checkRuleCPTSumsTo1__(c, attr, rule)) { errors = true; }
797  } catch (Exception& e) {
798  GUM_SHOWERROR(e);
799  errors = true;
800  }
801  }
802 
803  return errors == false;
804  }
805 
806  template < typename GUM_SCALAR >
808  const PRMClass< GUM_SCALAR >& c,
809  O3RawCPT& attr) {
810  const auto& type = prm__->type(attr.type().label());
811 
812  auto domainSize = type->domainSize();
813  for (auto& prnt: attr.parents()) {
814  try {
815  domainSize *= c.get(prnt.label()).type()->domainSize();
816  } catch (NotFound&) {
817  // If we are here, all parents have been check so resolveSlotChain__
818  // will not raise an error and not return a nullptr
820  }
821  }
822 
823  // Check for CPT size
824  if (domainSize != attr.values().size()) {
826  attr.name(),
827  Size(attr.values().size()),
828  domainSize,
829  *errors__);
830  return false;
831  }
832 
833  // Add parameters to formulas
834  const auto& scope = c.scope();
835  for (auto& f: attr.values()) {
836  f.formula().variables().clear();
837 
838  for (const auto& values: scope) {
840  }
841  }
842 
843  // Check that CPT sums to 1
845  auto values = std::vector< GUM_SCALAR >(parent_size, 0.0f);
846 
847  for (std::size_t i = 0; i < attr.values().size(); ++i) {
848  try {
849  auto idx = i % parent_size;
850  auto val = (GUM_SCALAR)attr.values()[i].formula().result();
851  values[idx] += val;
852 
853  if (val < 0.0 || 1.0 < val) {
855  attr.name(),
856  attr.values()[i],
857  *errors__);
858  return false;
859  }
860  } catch (Exception&) {
862  attr.name(),
863  attr.values()[i],
864  *errors__);
865  return false;
866  }
867  }
868 
869  for (auto f: values) {
870  if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-3) {
872  attr.name(),
873  float(f),
874  *errors__);
875  return false;
876  } else if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-6) {
878  attr.name(),
879  float(f),
880  *errors__);
881  }
882  }
883  return true;
884  }
885 
886  template < typename GUM_SCALAR >
890  const O3Label& chain) {
891  auto s = chain.label();
892  auto current = &c;
893  std::vector< std::string > v;
894 
896 
897  for (size_t i = 0; i < v.size(); ++i) {
898  auto link = v[i];
899 
900  if (!checkSlotChainLink__(*current, chain, link)) { return nullptr; }
901 
902  auto elt = &(current->get(link));
903 
904  if (i == v.size() - 1) {
905  // last link, should be an attribute or aggregate
906  return elt;
907 
908  } else {
909  // should be a reference slot
910 
911  auto ref = dynamic_cast< const PRMReferenceSlot< GUM_SCALAR >* >(elt);
912  if (ref) {
913  current = &(ref->slotType());
914  } else {
915  return nullptr; // failsafe to prevent infinite loop
916  }
917  }
918  }
919 
920  // Encountered only reference slots
921 
922  return nullptr;
923  }
924 
925  template < typename GUM_SCALAR >
928  const O3Label& chain,
929  const std::string& s) {
930  if (!c.exists(s)) {
932  return false;
933  }
934  return true;
935  }
936 
937  template < typename GUM_SCALAR >
941 
942  for (auto& agg: c.aggregates()) {
944  auto params = std::vector< std::string >();
945  for (auto& p: agg.parameters()) {
946  params.push_back(p.label());
947  }
948 
950  agg.aggregateType().label(),
951  agg.variableType().label(),
952  params);
954  }
955  }
956 
957  factory.endClass(false);
958  }
959 
960  template < typename GUM_SCALAR >
962  O3Class& o3class,
963  O3Aggregate& agg) {
964  if (!solver__->resolveType(agg.variableType())) { return false; }
965 
966  // Checking type legality if overload
967  if (!checkAggTypeLegality__(o3class, agg)) { return false; }
968 
969  return true;
970  }
971 
972  template < typename GUM_SCALAR >
973  INLINE const PRMType*
975  O3Aggregate& agg) {
976  const auto& c = prm__->getClass(o3class.name().label());
977  auto t = (const PRMType*)nullptr;
978 
979  for (const auto& prnt: agg.parents()) {
980  auto elt = resolveSlotChain__(c, prnt);
981 
982  if (elt == nullptr) {
984  return nullptr;
985 
986  } else {
987  if (t == nullptr) {
988  try {
989  t = &(elt->type());
990 
991  } catch (OperationNotAllowed&) {
993  return nullptr;
994  }
995 
996  } else if ((*t) != elt->type()) {
997  // Wront type in chain
999  t->name(),
1000  elt->type().name(),
1001  *errors__);
1002  return nullptr;
1003  }
1004  }
1005  }
1006  return t;
1007  }
1008 
1009  template < typename GUM_SCALAR >
1010  INLINE bool
1012  O3Aggregate& agg) {
1013  if (prm__->isClass(o3class.superLabel().label())) {
1014  const auto& super = prm__->getClass(o3class.superLabel().label());
1015  const auto& agg_type = prm__->type(agg.variableType().label());
1016 
1017  if (super.exists(agg.name().label())
1018  && !agg_type.isSubTypeOf(super.get(agg.name().label()).type())) {
1020  o3class.superLabel(),
1021  *errors__);
1022  return false;
1023  }
1024  }
1025 
1026  return true;
1027  }
1028 
1029  template < typename GUM_SCALAR >
1030  INLINE bool
1032  O3Aggregate& agg,
1033  const PRMType* t) {
1034  bool ok = false;
1035 
1036  switch (gum::prm::PRMAggregate< GUM_SCALAR >::str2enum(
1037  agg.aggregateType().label())) {
1046  break;
1047  }
1048 
1053  break;
1054  }
1055 
1056  default: {
1057  GUM_ERROR(FatalError, "unknown aggregate type");
1058  }
1059  }
1060 
1061  if (!ok) { return false; }
1062 
1063  // Checking parameters type
1064  switch (gum::prm::PRMAggregate< GUM_SCALAR >::str2enum(
1065  agg.aggregateType().label())) {
1070  break;
1071  }
1072 
1073  default: { /* Nothing to do */
1074  }
1075  }
1076 
1077  return ok;
1078  }
1079 
1080  template < typename GUM_SCALAR >
1081  INLINE bool
1083  Size n) {
1084  if (agg.parameters().size() != n) {
1086  Size(n),
1087  Size(agg.parameters().size()),
1088  *errors__);
1089  return false;
1090  }
1091 
1092  return true;
1093  }
1094 
1095  template < typename GUM_SCALAR >
1097  O3Aggregate& agg,
1098  const gum::prm::PRMType& t) {
1099  const auto& param = agg.parameters().front();
1100  bool found = false;
1101  for (Size idx = 0; idx < t.variable().domainSize(); ++idx) {
1102  if (t.variable().label(idx) == param.label()) {
1103  found = true;
1104  break;
1105  }
1106  }
1107 
1108  if (!found) {
1110  return false;
1111  }
1112 
1113  return true;
1114  }
1115 
1116  } // namespace o3prm
1117  } // namespace prm
1118 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)