aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
O3ClassFactory_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation for the O3ClassFactory class.
25  *
26  * @author Christophe GONZALES(@AMU) and Pierre-Henri WUILLEMIN(@LIP6)
27  * @author Lionel TORTI
28  */
29 
30 #include <agrum/PRM/o3prm/O3ClassFactory.h>
31 
32 namespace gum {
33  namespace prm {
34  namespace o3prm {
35 
36  template < typename GUM_SCALAR >
37  INLINE O3ClassFactory< GUM_SCALAR >::O3ClassFactory(PRM< GUM_SCALAR >& prm,
38  O3PRM& o3_prm,
41  _prm_(&prm),
44  }
45 
46  template < typename GUM_SCALAR >
52  }
53 
54  template < typename GUM_SCALAR >
62  }
63 
64  template < typename GUM_SCALAR >
67  }
68 
69  template < typename GUM_SCALAR >
72  if (this == &src) { return *this; }
73  _prm_ = src._prm_;
80  _dag_ = src._dag_;
82  return *this;
83  }
84 
85  template < typename GUM_SCALAR >
88  if (this == &src) { return *this; }
89  _prm_ = std::move(src._prm_);
96  _dag_ = std::move(src._dag_);
98  return *this;
99  }
100 
101  template < typename GUM_SCALAR >
104 
105  // Class with a super class must be declared after
106  if (_checkO3Classes_()) {
108 
109  for (auto c: _o3Classes_) {
110  // Soving interfaces
111  auto implements = Set< std::string >();
112  for (auto& i: c->interfaces()) {
114  }
115 
116  // Adding the class
117  if (_solver_->resolveClass(c->superLabel())) {
118  factory.startClass(c->name().label(), c->superLabel().label(), &implements, true);
119  factory.endClass(false);
120  }
121  }
122  }
123  }
124 
125  template < typename GUM_SCALAR >
128 
129  for (auto id = topo_order.rbegin(); id != topo_order.rend(); --id) {
131  }
132  }
133 
134  template < typename GUM_SCALAR >
137  }
138 
139  template < typename GUM_SCALAR >
141  for (auto& c: _o3_prm_->classes()) {
142  auto id = _dag_.addNode();
143 
144  try {
145  _nameMap_.insert(c->name().label(), id);
146  _classMap_.insert(c->name().label(), c.get());
147  _nodeMap_.insert(id, c.get());
148 
149  } catch (DuplicateElement&) {
151  return false;
152  }
153  }
154 
155  return true;
156  }
157 
158  template < typename GUM_SCALAR >
160  for (auto& c: _o3_prm_->classes()) {
161  if (c->superLabel().label() != "") {
162  if (!_solver_->resolveClass(c->superLabel())) { return false; }
163 
164  auto head = _nameMap_[c->superLabel().label()];
165  auto tail = _nameMap_[c->name().label()];
166 
167  try {
168  _dag_.addArc(tail, head);
169  } catch (InvalidDirectedCycle&) {
170  // Cyclic inheritance
172  return false;
173  }
174  }
175  }
176 
177  return true;
178  }
179 
180  template < typename GUM_SCALAR >
182  for (auto& c: _o3_prm_->classes()) {
183  if (_checkImplementation_(*c)) {
185  }
186  }
187  }
188 
192 
193  template < typename GUM_SCALAR >
195  // Saving attributes names for fast lookup
196  auto attr_map = AttrMap();
197  for (auto& a: c.attributes()) {
198  attr_map.insert(a->name().label(), a.get());
199  }
200 
201  // Saving aggregates names for fast lookup
202  auto agg_map = AggMap();
203  for (auto& agg: c.aggregates()) {
204  agg_map.insert(agg.name().label(), &agg);
205  }
206 
207  auto ref_map = RefMap();
208  for (auto& ref: c.referenceSlots()) {
209  ref_map.insert(ref.name().label(), &ref);
210  }
211 
212  // Cheking interface implementation
213  for (auto& i: c.interfaces()) {
214  if (_solver_->resolveInterface(i)) {
215  if (!_checkImplementation_(c, i, attr_map, agg_map, ref_map)) { return false; }
216  }
217  }
218 
219  return true;
220  }
221 
222  template < typename GUM_SCALAR >
224  O3Label& i,
225  AttrMap& attr_map,
226  AggMap& agg_map,
227  RefMap& ref_map) {
228  const auto& real_i = _prm_->getInterface(i.label());
229 
230  auto counter = (Size)0;
231  for (const auto& a: real_i.attributes()) {
232  if (attr_map.exists(a->name())) {
233  ++counter;
234 
235  if (!_checkImplementation_(attr_map[a->name()]->type(), a->type())) {
237  return false;
238  }
239  }
240 
241  if (agg_map.exists(a->name())) {
242  ++counter;
243 
246  return false;
247  }
248  }
249  }
250 
251  if (counter != real_i.attributes().size()) {
253  return false;
254  }
255 
256  counter = 0;
257  for (const auto& r: real_i.referenceSlots()) {
258  if (ref_map.exists(r->name())) {
259  ++counter;
260 
261  if (!_checkImplementation_(ref_map[r->name()]->type(), r->slotType())) {
263  return false;
264  }
265  }
266  }
267  return true;
268  }
269 
270  template < typename GUM_SCALAR >
272  const PRMType& type) {
273  if (!_solver_->resolveType(o3_type)) { return false; }
274 
275  return _prm_->type(o3_type.label()).isSubTypeOf(type);
276  }
277 
278  template < typename GUM_SCALAR >
280  O3Label& o3_type,
282  if (!_solver_->resolveSlotType(o3_type)) { return false; }
283 
284  if (_prm_->isInterface(o3_type.label())) {
286  } else {
288  }
289  }
290 
291  template < typename GUM_SCALAR >
294  // Class with a super class must be declared after
295  for (auto c: _o3Classes_) {
297 
299 
301 
302  factory.endClass(false);
303  }
304  }
305 
306  template < typename GUM_SCALAR >
308  O3Class& c) {
309  for (auto& p: c.parameters()) {
310  switch (p.type()) {
311  case O3Parameter::PRMType::INT: {
312  factory.addParameter("int", p.name().label(), p.value().value());
313  break;
314  }
315 
316  case O3Parameter::PRMType::FLOAT: {
317  factory.addParameter("real", p.name().label(), p.value().value());
318  break;
319  }
320 
321  default: {
322  GUM_ERROR(FatalError, "unknown O3Parameter type")
323  }
324  }
325  }
326  }
327 
328  template < typename GUM_SCALAR >
330  // Class with a super class must be declared after
331  for (auto c: _o3Classes_) {
334  }
335  }
336 
337  template < typename GUM_SCALAR >
340 
342 
343  // References
344  for (auto& ref: c.referenceSlots()) {
345  if (_checkReferenceSlot_(c, ref)) {
347  }
348  }
349 
350  factory.endClass(false);
351  }
352 
353  template < typename GUM_SCALAR >
355  O3ReferenceSlot& ref) {
356  if (!_solver_->resolveSlotType(ref.type())) { return false; }
357 
358  const auto& real_c = _prm_->getClass(c.name().label());
359 
360  // Check for dupplicates
361  if (real_c.exists(ref.name().label())) {
362  const auto& elt = real_c.get(ref.name().label());
363 
365  auto slot_type = (PRMClassElementContainer< GUM_SCALAR >*)nullptr;
366 
367  if (_prm_->isInterface(ref.type().label())) {
369 
370  } else {
371  slot_type = &(_prm_->getClass(ref.type().label()));
372  }
373 
374  auto real_ref = static_cast< const PRMReferenceSlot< GUM_SCALAR >* >(&elt);
375 
376  if (slot_type->name() == real_ref->slotType().name()) {
378  return false;
379 
380  } else if (!slot_type->isSubTypeOf(real_ref->slotType())) {
382  return false;
383  }
384 
385  } else {
387  return false;
388  }
389  }
390 
391  // If class we need to check for illegal references
392  if (_prm_->isClass(ref.type().label())) {
393  const auto& ref_type = _prm_->getClass(ref.type().label());
394 
395  // No recursive reference
396  if ((&ref_type) == (&real_c)) {
398  return false;
399  }
400 
401  // No reference to subclasses
402  if (ref_type.isSubTypeOf(real_c)) {
404  return false;
405  }
406  }
407 
408  return true;
409  }
410 
411  template < typename GUM_SCALAR >
413  // Class with a super class must be declared after
414  for (auto c: _o3Classes_) {
417  }
418  }
419 
420  template < typename GUM_SCALAR >
422  // Class with a super class must be declared after
423  for (auto c: _o3Classes_) {
426  }
427  }
428 
429  template < typename GUM_SCALAR >
433 
434  for (auto& attr: c.attributes()) {
438  }
439  }
440 
441  factory.endClass(false);
442  }
443 
444  template < typename GUM_SCALAR >
446  O3Attribute& attr) {
447  // Check type
448  if (!_solver_->resolveType(attr.type())) { return false; }
449 
450  // Checking type legality if overload
451  if (c.superLabel().label() != "") {
452  const auto& super = _prm_->getClass(c.superLabel().label());
453 
454  if (!super.exists(attr.name().label())) { return true; }
455 
456  const auto& super_type = super.get(attr.name().label()).type();
457  const auto& type = _prm_->type(attr.type().label());
458 
459  if (!type.isSubTypeOf(super_type)) {
461  return false;
462  }
463  }
464  return true;
465  }
466 
467  template < typename GUM_SCALAR >
470 
471  // Class with a super class must be declared in order
472  for (auto c: _o3Classes_) {
475 
477 
478  if (c->superLabel().label() != "") {
479  auto& super = _prm_->getClass(c->superLabel().label());
480  auto to_complete = Set< std::string >();
481 
482  for (auto a: super.attributes()) {
484  }
485 
486  for (auto a: super.aggregates()) {
488  }
489 
490  for (auto& a: c->attributes()) {
492  _prm_->getClass(c->name().label()).get(a->name().label()).safeName());
493  }
494 
495  for (auto& a: c->aggregates()) {
497  _prm_->getClass(c->name().label()).get(a.name().label()).safeName());
498  }
499 
500  for (auto a: to_complete) {
502  }
503  }
504 
505  factory.endClass(true);
506  }
507  }
508 
509  template < typename GUM_SCALAR >
512 
513  // Class with a super class must be declared in order
514  for (auto c: _o3Classes_) {
516 
518 
519  factory.endClass(false);
520  }
521  }
522 
523  template < typename GUM_SCALAR >
524  INLINE void
526  O3Class& c) {
527  // Attributes
528  for (auto& agg: c.aggregates()) {
531 
532  for (const auto& parent: agg.parents()) {
534  }
535 
537  }
538  }
539  }
540 
541  template < typename GUM_SCALAR >
543  O3Aggregate& agg) {
544  // Checking parents
545  auto t = _checkAggParents_(c, agg);
546  if (t == nullptr) { return false; }
547 
548  // Checking parameters numbers
549  if (!_checkAggParameters_(c, agg, t)) { return false; }
550 
551  return true;
552  }
553 
554  template < typename GUM_SCALAR >
555  INLINE void
557  O3Class& c) {
558  // Attributes
559  for (auto& attr: c.attributes()) {
562 
563  for (const auto& parent: attr->parents()) {
565  }
566 
567  auto raw = dynamic_cast< const O3RawCPT* >(attr.get());
568 
569  if (raw) {
570  auto values = std::vector< std::string >();
571  for (const auto& val: raw->values()) {
573  }
575  }
576 
577  auto rule_cpt = dynamic_cast< const O3RuleCPT* >(attr.get());
578  if (rule_cpt) {
579  for (const auto& rule: rule_cpt->rules()) {
580  auto labels = std::vector< std::string >();
581  auto values = std::vector< std::string >();
582 
583  for (const auto& lbl: rule.first) {
585  }
586 
587  for (const auto& form: rule.second) {
589  }
590 
592  }
593  }
594 
596  }
597  }
598  }
599 
600  template < typename GUM_SCALAR >
602  O3Attribute& attr) {
603  // Check for parents existence
604  const auto& c = _prm_->getClass(o3_c.name().label());
605  for (auto& prnt: attr.parents()) {
606  if (!_checkParent_(c, prnt)) { return false; }
607  }
608 
609  // Check that CPT sums to 1
610  auto raw = dynamic_cast< O3RawCPT* >(&attr);
611  if (raw) { return _checkRawCPT_(c, *raw); }
612 
613  auto rule = dynamic_cast< O3RuleCPT* >(&attr);
614  if (rule) { return _checkRuleCPT_(c, *rule); }
615 
616  return true;
617  }
618 
619  template < typename GUM_SCALAR >
621  const O3Label& prnt) {
622  if (prnt.label().find('.') == std::string::npos) {
623  return _checkLocalParent_(c, prnt);
624 
625  } else {
626  return _checkRemoteParent_(c, prnt);
627  }
628  }
629 
630  template < typename GUM_SCALAR >
632  const O3Label& prnt) {
633  if (!c.exists(prnt.label())) {
635  return false;
636  }
637 
638  const auto& elt = c.get(prnt.label());
643  return false;
644  }
645 
646  return true;
647  }
648 
649  template < typename GUM_SCALAR >
652  const O3Label& prnt) {
653  if (_resolveSlotChain_(c, prnt) == nullptr) { return false; }
654  return true;
655  }
656 
657  template < typename GUM_SCALAR >
659  const O3RuleCPT::O3Rule& rule) {
660  // Check that the number of labels is correct
661  if (rule.first.size() != attr.parents().size()) {
663  return false;
664  }
665  return true;
666  }
667 
668  template < typename GUM_SCALAR >
670  const O3RuleCPT& attr,
671  const O3RuleCPT::O3Rule& rule) {
672  bool errors = false;
673  for (std::size_t i = 0; i < attr.parents().size(); ++i) {
674  auto label = rule.first[i];
675  auto prnt = attr.parents()[i];
676  try {
678  // c.get(prnt.label()).type()->labels();
679  if (label.label() != "*"
681  == real_labels.end()) {
683  errors = true;
684  }
685  } catch (Exception&) {
686  // parent does not exists and is already reported
687  }
688  }
689  return errors == false;
690  }
691 
692  template < typename GUM_SCALAR >
694  const HashTable< std::string, const PRMParameter< GUM_SCALAR >* >& scope,
695  O3RuleCPT::O3Rule& rule) {
696  // Add parameters to formulas
697  for (auto& f: rule.second) {
698  f.formula().variables().clear();
699  for (const auto& values: scope) {
701  }
702  }
703  }
704 
705 
706  template < typename GUM_SCALAR >
707  INLINE bool
709  const O3RuleCPT& attr,
710  const O3RuleCPT::O3Rule& rule) {
711  bool errors = false;
712  // Check that formulas are valid and sums to 1
713  GUM_SCALAR sum = 0.0;
714  for (const auto& f: rule.second) {
715  try {
716  auto value = GUM_SCALAR(f.formula().result());
717  sum += value;
718  if (value < 0.0 || 1.0 < value) {
720  errors = true;
721  }
722  } catch (OperationNotAllowed&) {
724  errors = true;
725  }
726  }
727 
728  // Check that CPT sums to 1
729  if (std::abs(sum - 1.0) > 1e-3) {
731  errors = true;
732  } else if (std::abs(sum - 1.0f) > 1e-6) {
734  }
735  return errors == false;
736  }
737 
738  template < typename GUM_SCALAR >
740  O3RuleCPT& attr) {
741  const auto& scope = c.scope();
742  bool errors = false;
743  for (auto& rule: attr.rules()) {
744  try {
745  if (!_checkLabelsNumber_(attr, rule)) { errors = true; }
746  if (!_checkLabelsValues_(c, attr, rule)) { errors = true; }
748  if (!_checkRuleCPTSumsTo1_(c, attr, rule)) { errors = true; }
749  } catch (Exception& e) {
750  GUM_SHOWERROR(e);
751  errors = true;
752  }
753  }
754 
755  return errors == false;
756  }
757 
758  template < typename GUM_SCALAR >
760  O3RawCPT& attr) {
761  const auto& type = _prm_->type(attr.type().label());
762 
763  auto domainSize = type->domainSize();
764  for (auto& prnt: attr.parents()) {
765  try {
766  domainSize *= c.get(prnt.label()).type()->domainSize();
767  } catch (NotFound&) {
768  // If we are here, all parents have been check so _resolveSlotChain_
769  // will not raise an error and not return a nullptr
771  }
772  }
773 
774  // Check for CPT size
775  if (domainSize != attr.values().size()) {
777  attr.name(),
778  Size(attr.values().size()),
779  domainSize,
780  *_errors_);
781  return false;
782  }
783 
784  // Add parameters to formulas
785  const auto& scope = c.scope();
786  for (auto& f: attr.values()) {
787  f.formula().variables().clear();
788 
789  for (const auto& values: scope) {
791  }
792  }
793 
794  // Check that CPT sums to 1
796  auto values = std::vector< GUM_SCALAR >(parent_size, 0.0f);
797 
798  for (std::size_t i = 0; i < attr.values().size(); ++i) {
799  try {
800  auto idx = i % parent_size;
801  auto val = (GUM_SCALAR)attr.values()[i].formula().result();
802  values[idx] += val;
803 
804  if (val < 0.0 || 1.0 < val) {
806  return false;
807  }
808  } catch (Exception&) {
810  return false;
811  }
812  }
813 
814  for (auto f: values) {
815  if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-3) {
817  return false;
818  } else if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-6) {
820  }
821  }
822  return true;
823  }
824 
825  template < typename GUM_SCALAR >
828  const O3Label& chain) {
829  auto s = chain.label();
830  auto current = &c;
831  std::vector< std::string > v;
832 
834 
835  for (size_t i = 0; i < v.size(); ++i) {
836  auto link = v[i];
837 
838  if (!_checkSlotChainLink_(*current, chain, link)) { return nullptr; }
839 
840  auto elt = &(current->get(link));
841 
842  if (i == v.size() - 1) {
843  // last link, should be an attribute or aggregate
844  return elt;
845 
846  } else {
847  // should be a reference slot
848 
849  auto ref = dynamic_cast< const PRMReferenceSlot< GUM_SCALAR >* >(elt);
850  if (ref) {
851  current = &(ref->slotType());
852  } else {
853  return nullptr; // failsafe to prevent infinite loop
854  }
855  }
856  }
857 
858  // Encountered only reference slots
859 
860  return nullptr;
861  }
862 
863  template < typename GUM_SCALAR >
866  const O3Label& chain,
867  const std::string& s) {
868  if (!c.exists(s)) {
870  return false;
871  }
872  return true;
873  }
874 
875  template < typename GUM_SCALAR >
879 
880  for (auto& agg: c.aggregates()) {
882  auto params = std::vector< std::string >();
883  for (auto& p: agg.parameters()) {
884  params.push_back(p.label());
885  }
886 
888  agg.aggregateType().label(),
889  agg.variableType().label(),
890  params);
892  }
893  }
894 
895  factory.endClass(false);
896  }
897 
898  template < typename GUM_SCALAR >
900  O3Aggregate& agg) {
901  if (!_solver_->resolveType(agg.variableType())) { return false; }
902 
903  // Checking type legality if overload
904  if (!_checkAggTypeLegality_(o3class, agg)) { return false; }
905 
906  return true;
907  }
908 
909  template < typename GUM_SCALAR >
911  O3Aggregate& agg) {
912  const auto& c = _prm_->getClass(o3class.name().label());
913  auto t = (const PRMType*)nullptr;
914 
915  for (const auto& prnt: agg.parents()) {
916  auto elt = _resolveSlotChain_(c, prnt);
917 
918  if (elt == nullptr) {
920  return nullptr;
921 
922  } else {
923  if (t == nullptr) {
924  try {
925  t = &(elt->type());
926 
927  } catch (OperationNotAllowed&) {
929  return nullptr;
930  }
931 
932  } else if ((*t) != elt->type()) {
933  // Wront type in chain
935  return nullptr;
936  }
937  }
938  }
939  return t;
940  }
941 
942  template < typename GUM_SCALAR >
944  O3Aggregate& agg) {
945  if (_prm_->isClass(o3class.superLabel().label())) {
946  const auto& super = _prm_->getClass(o3class.superLabel().label());
947  const auto& agg_type = _prm_->type(agg.variableType().label());
948 
949  if (super.exists(agg.name().label())
950  && !agg_type.isSubTypeOf(super.get(agg.name().label()).type())) {
952  return false;
953  }
954  }
955 
956  return true;
957  }
958 
959  template < typename GUM_SCALAR >
961  O3Aggregate& agg,
962  const PRMType* t) {
963  bool ok = false;
964 
974  break;
975  }
976 
981  break;
982  }
983 
984  default: {
985  GUM_ERROR(FatalError, "unknown aggregate type")
986  }
987  }
988 
989  if (!ok) { return false; }
990 
991  // Checking parameters type
997  break;
998  }
999 
1000  default: { /* Nothing to do */
1001  }
1002  }
1003 
1004  return ok;
1005  }
1006 
1007  template < typename GUM_SCALAR >
1009  if (agg.parameters().size() != n) {
1011  return false;
1012  }
1013 
1014  return true;
1015  }
1016 
1017  template < typename GUM_SCALAR >
1019  const gum::prm::PRMType& t) {
1020  const auto& param = agg.parameters().front();
1021  bool found = false;
1022  for (Size idx = 0; idx < t.variable().domainSize(); ++idx) {
1023  if (t.variable().label(idx) == param.label()) {
1024  found = true;
1025  break;
1026  }
1027  }
1028 
1029  if (!found) {
1031  return false;
1032  }
1033 
1034  return true;
1035  }
1036 
1037  } // namespace o3prm
1038  } // namespace prm
1039 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)