35 template <
typename GUM_SCALAR >
42 __o3_prm(&o3_prm), __solver(&solver), __errors(&errors) {
46 template <
typename GUM_SCALAR >
57 template <
typename GUM_SCALAR >
69 template <
typename GUM_SCALAR >
74 template <
typename GUM_SCALAR >
77 if (
this == &src) {
return *
this; }
90 template <
typename GUM_SCALAR >
93 if (
this == &src) {
return *
this; }
94 __prm = std::move(src.__prm);
101 __dag = std::move(src.__dag);
106 template <
typename GUM_SCALAR >
117 for (
auto& i : c->interfaces()) {
118 if (
__solver->resolveInterface(i)) { implements.
insert(i.label()); }
122 if (
__solver->resolveClass(c->superLabel())) {
124 c->name().label(), c->superLabel().label(), &implements,
true);
131 template <
typename GUM_SCALAR >
135 for (
auto id = topo_order.rbegin();
id != topo_order.rend(); --id) {
140 template <
typename GUM_SCALAR >
145 template <
typename GUM_SCALAR >
152 __classMap.insert(c->name().label(), c.get());
156 O3PRM_CLASS_DUPLICATE(c->name(), *
__errors);
164 template <
typename GUM_SCALAR >
167 if (c->superLabel().label() !=
"") {
168 if (!
__solver->resolveClass(c->superLabel())) {
return false; }
170 auto head =
__nameMap[c->superLabel().label()];
171 auto tail =
__nameMap[c->name().label()];
177 O3PRM_CLASS_CYLIC_INHERITANCE(c->name(), c->superLabel(), *
__errors);
186 template <
typename GUM_SCALAR >
190 __prm->getClass(c->name().label()).initializeInheritance();
199 template <
typename GUM_SCALAR >
204 attr_map.insert(a->name().label(), a.get());
210 agg_map.insert(agg.name().label(), &agg);
215 ref_map.insert(ref.name().label(), &ref);
220 if (
__solver->resolveInterface(i)) {
230 template <
typename GUM_SCALAR >
237 const auto& real_i =
__prm->getInterface(i.
label());
239 auto counter = (
Size)0;
240 for (
const auto& a : real_i.attributes()) {
241 if (attr_map.
exists(a->name())) {
245 O3PRM_CLASS_ATTR_IMPLEMENTATION(
251 if (agg_map.
exists(a->name())) {
256 O3PRM_CLASS_AGG_IMPLEMENTATION(
263 if (counter != real_i.attributes().size()) {
269 for (
const auto& r : real_i.referenceSlots()) {
270 if (ref_map.
exists(r->name())) {
275 O3PRM_CLASS_REF_IMPLEMENTATION(
284 template <
typename GUM_SCALAR >
288 if (!
__solver->resolveType(o3_type)) {
return false; }
290 return __prm->type(o3_type.
label()).isSubTypeOf(type);
293 template <
typename GUM_SCALAR >
296 if (!
__solver->resolveSlotType(o3_type)) {
return false; }
299 return __prm->getInterface(o3_type.
label()).isSubTypeOf(type);
301 return __prm->getClass(o3_type.
label()).isSubTypeOf(type);
305 template <
typename GUM_SCALAR >
320 template <
typename GUM_SCALAR >
326 factory.
addParameter(
"int", p.name().label(), p.value().value());
331 factory.
addParameter(
"real", p.name().label(), p.value().value());
340 template <
typename GUM_SCALAR >
349 template <
typename GUM_SCALAR >
359 ref.type().label(), ref.name().label(), ref.isArray());
366 template <
typename GUM_SCALAR >
370 if (!
__solver->resolveSlotType(ref.
type())) {
return false; }
376 const auto& elt = real_c.get(ref.
name().
label());
391 if (slot_type->name() == real_ref->slotType().name()) {
395 }
else if (!slot_type->isSubTypeOf(real_ref->slotType())) {
411 if ((&ref_type) == (&real_c)) {
417 if (ref_type.isSubTypeOf(real_c)) {
426 template <
typename GUM_SCALAR >
435 template <
typename GUM_SCALAR >
444 template <
typename GUM_SCALAR >
451 factory.
startAttribute(attr->type().label(), attr->name().label());
459 template <
typename GUM_SCALAR >
463 if (!
__solver->resolveType(attr.
type())) {
return false; }
469 if (!super.exists(attr.
name().
label())) {
return true; }
471 const auto& super_type = super.get(attr.
name().
label()).type();
474 if (!type.isSubTypeOf(super_type)) {
482 template <
typename GUM_SCALAR >
497 for (
auto a : super.attributes()) {
498 to_complete.
insert(a->safeName());
501 for (
auto a : super.aggregates()) {
502 to_complete.insert(a->safeName());
507 .
get(a->name().label())
513 .
get(a.name().label())
517 for (
auto a : to_complete) {
526 template <
typename GUM_SCALAR >
540 template <
typename GUM_SCALAR >
548 for (
const auto& parent : agg.parents()) {
557 template <
typename GUM_SCALAR >
562 if (t ==
nullptr) {
return false; }
570 template <
typename GUM_SCALAR >
578 for (
const auto& parent : attr->parents()) {
582 auto raw =
dynamic_cast< const O3RawCPT*
>(attr.get());
585 auto values = std::vector< std::string >();
586 for (
const auto& val : raw->values()) {
587 values.push_back(val.formula().formula());
592 auto rule_cpt =
dynamic_cast< const O3RuleCPT*
>(attr.get());
594 for (
const auto& rule : rule_cpt->rules()) {
595 auto labels = std::vector< std::string >();
596 auto values = std::vector< std::string >();
598 for (
const auto& lbl : rule.first) {
599 labels.push_back(lbl.label());
602 for (
const auto& form : rule.second) {
603 values.push_back(form.formula().formula());
615 template <
typename GUM_SCALAR >
620 for (
auto& prnt : attr.
parents()) {
625 auto raw =
dynamic_cast< O3RawCPT*
>(&attr);
628 auto rule =
dynamic_cast< O3RuleCPT*
>(&attr);
634 template <
typename GUM_SCALAR >
637 if (prnt.
label().find(
'.') == std::string::npos) {
645 template <
typename GUM_SCALAR >
649 O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *
__errors);
653 const auto& elt = c.
get(prnt.
label());
657 O3PRM_CLASS_ILLEGAL_PARENT(prnt, *
__errors);
664 template <
typename GUM_SCALAR >
671 template <
typename GUM_SCALAR >
675 if (rule.first.size() != attr.
parents().size()) {
676 O3PRM_CLASS_ILLEGAL_RULE_SIZE(
683 template <
typename GUM_SCALAR >
689 for (std::size_t i = 0; i < attr.
parents().size(); ++i) {
690 auto label = rule.first[i];
695 if (label.label() !=
"*" 696 && std::find(real_labels.begin(), real_labels.end(), label.label())
697 == real_labels.end()) {
698 O3PRM_CLASS_ILLEGAL_RULE_LABEL(rule, label, prnt, *
__errors);
705 return errors ==
false;
708 template <
typename GUM_SCALAR >
713 for (
auto& f : rule.second) {
714 f.formula().variables().clear();
715 for (
const auto& values : scope) {
716 f.formula().variables().insert(values.first, values.second->value());
722 template <
typename GUM_SCALAR >
729 GUM_SCALAR sum = 0.0;
730 for (
const auto& f : rule.second) {
732 auto value = GUM_SCALAR(f.formula().result());
734 if (value < 0.0 || 1.0 < value) {
745 if (std::abs(sum - 1.0) > 1e-3) {
746 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(
749 }
else if (std::abs(sum - 1.0f) > 1e-6) {
750 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(
753 return errors ==
false;
756 template <
typename GUM_SCALAR >
759 const auto& scope = c.
scope();
761 for (
auto& rule : attr.
rules()) {
773 return errors ==
false;
776 template <
typename GUM_SCALAR >
781 auto domainSize = type->domainSize();
782 for (
auto& prnt : attr.
parents()) {
784 domainSize *= c.
get(prnt.label()).type()->domainSize();
793 if (domainSize != attr.
values().size()) {
794 O3PRM_CLASS_ILLEGAL_CPT_SIZE(c.
name(),
803 const auto& scope = c.
scope();
804 for (
auto& f : attr.
values()) {
805 f.formula().variables().clear();
807 for (
const auto& values : scope) {
808 f.formula().variables().insert(values.first, values.second->value());
813 Size parent_size = domainSize / type->domainSize();
814 auto values = std::vector< GUM_SCALAR >(parent_size, 0.0f);
816 for (std::size_t i = 0; i < attr.
values().size(); ++i) {
818 auto idx = i % parent_size;
819 auto val = (GUM_SCALAR)attr.
values()[i].formula().result();
822 if (val < 0.0 || 1.0 < val) {
823 O3PRM_CLASS_ILLEGAL_CPT_VALUE(
828 O3PRM_CLASS_ILLEGAL_CPT_VALUE(
834 for (
auto f : values) {
835 if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-3) {
836 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(
839 }
else if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-6) {
840 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(
847 template <
typename GUM_SCALAR >
852 auto s = chain.
label();
854 std::vector< std::string > v;
858 for (
size_t i = 0; i < v.size(); ++i) {
863 auto elt = &(current->get(link));
865 if (i == v.size() - 1) {
874 current = &(ref->slotType());
886 template <
typename GUM_SCALAR >
890 const std::string& s) {
892 O3PRM_CLASS_LINK_NOT_FOUND(chain, s, *
__errors);
898 template <
typename GUM_SCALAR >
905 auto params = std::vector< std::string >();
906 for (
auto& p : agg.parameters()) {
907 params.push_back(p.label());
911 agg.aggregateType().label(),
912 agg.variableType().label(),
921 template <
typename GUM_SCALAR >
932 template <
typename GUM_SCALAR >
937 auto t = (
const PRMType*)
nullptr;
939 for (
const auto& prnt : agg.
parents()) {
942 if (elt ==
nullptr) {
943 O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *
__errors);
952 O3PRM_CLASS_WRONG_PARENT(prnt, *
__errors);
956 }
else if ((*t) != elt->type()) {
958 O3PRM_CLASS_WRONG_PARENT_TYPE(
959 prnt, t->name(), elt->type().name(), *
__errors);
967 template <
typename GUM_SCALAR >
976 && !agg_type.isSubTypeOf(super.get(agg.
name().
label()).type())) {
977 O3PRM_CLASS_ILLEGAL_OVERLOAD(
986 template <
typename GUM_SCALAR >
1013 if (!ok) {
return false; }
1032 template <
typename GUM_SCALAR >
1037 O3PRM_CLASS_AGG_PARAMETERS(
1045 template <
typename GUM_SCALAR >
1048 const auto& param = agg.
parameters().front();
1058 O3PRM_CLASS_AGG_PARAMETER_NOT_FOUND(agg.
name(), param, *
__errors);
std::pair< O3LabelList, O3FormulaList > O3Rule
bool __checkAggregateForDeclaration(O3Class &o3class, O3Aggregate &agg)
virtual O3LabelList & parents()
PRMParameter is a member of a Class in a PRM.
void completeAggregates()
bool __checkLabelsValues(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
virtual O3RuleList & rules()
DiscreteVariable & variable()
Returns a reference to the DiscreteVariable contained in this object.
void endAggregator()
Finishes an aggregate declaration.
bool __checkAttributeForCompletion(const O3Class &o3_c, O3Attribute &attr)
void __addParamsToForms(const HashTable< std::string, const PRMParameter< GUM_SCALAR > * > &scope, O3RuleCPT::O3Rule &rule)
const std::string & name() const
Returns the name of this object.
bool __checkAttributeForDeclaration(O3Class &o3_c, O3Attribute &attr)
HashTable< std::string, O3Attribute *> AttrMap
O3ReferenceSlotList & referenceSlots()
void buildReferenceSlots()
bool __checkRawCPT(const PRMClass< GUM_SCALAR > &c, O3RawCPT &attr)
The O3Aggregate is part of the AST of the O3PRM language.
Headers for the O3ClassFactory class.
virtual void continueClass(const std::string &c) override
Continue the declaration of a class.
O3AggregateList & aggregates()
bool __checkReferenceSlot(O3Class &c, O3ReferenceSlot &ref)
HashTable< std::string, O3Aggregate *> AggMap
void __setO3ClassCreationOrder()
Abstract class representing an element of PRM class.
virtual void endAttribute() override
Tells the factory that we finished declaring an attribute.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
virtual void addReferenceSlot(const std::string &type, const std::string &name, bool isArray) override
Tells the factory that we started declaring a slot.
This class is used to contain and manipulate gum::ParseError.
The O3Label is part of the AST of the O3PRM language.
bool __checkParent(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
The O3Attribute is part of the AST of the O3PRM language.
void __addReferenceSlots(O3Class &c)
virtual bool exists(const std::string &name) const
Returns true if a member with the given name exists in this PRMClassElementContainer or in the PRMCla...
virtual O3FormulaList & values()
void __completeAggregates(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
gum is the global namespace for all aGrUM entities
O3LabelList & parameters()
O3ParameterList & parameters()
bool __checkLocalParent(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
virtual NodeId addNode()
insert a new node and return its id
A PRMReferenceSlot represent a relation between two PRMClassElementContainer.
virtual void startClass(const std::string &c, const std::string &ext="", const Set< std::string > *implements=nullptr, bool delayInheritance=false) override
Tells the factory that we start a class declaration.
The class for generic Hash Tables.
void buildImplementations()
bool __checkRuleCPT(const PRMClass< GUM_SCALAR > &c, O3RuleCPT &attr)
void continueAggregator(const std::string &name)
Continues an aggregator declaration.
bool __checkImplementation(O3Class &c)
virtual void startAttribute(const std::string &type, const std::string &name, bool scalar_atttr=false) override
Tells the factory that we start an attribute declaration.
bool __checkRemoteParent(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &prnt)
HashTable< std::string, gum::NodeId > __nameMap
void __completeAttribute(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
The O3RuleCPT is part of the AST of the O3PRM language.
virtual Size domainSize() const =0
void setRawCPFByColumns(const std::vector< GUM_SCALAR > &array)
Gives the factory the CPF in its raw form.
Resolves names for the different O3PRM factories.
bool __checkAggTypeLegality(O3Class &o3class, O3Aggregate &agg)
bool __checkAggParameters(O3Class &o3class, O3Aggregate &agg, const PRMType *t)
The O3Class is part of the AST of the O3PRM language.
O3ClassFactory(PRM< GUM_SCALAR > &prm, O3PRM &o3_prm, O3NameSolver< GUM_SCALAR > &solver, ErrorsContainer &errors)
bool __checkParameterValue(O3Aggregate &agg, const gum::prm::PRMType &t)
Factory which builds a PRM<GUM_SCALAR>.
HashTable< std::string, const PRMParameter< GUM_SCALAR > *> scope() const
Returns all the parameters in the scope of this class.
The O3ReferenceSlot is part of the AST of the O3PRM language.
virtual std::string label(Idx i) const =0
Gets the i-th label. This method is pure virtual.
virtual void setCPFByRule(const std::vector< std::string > &labels, const std::vector< GUM_SCALAR > &values)
Fills the CPF using a rule.
HashTable< NodeId, O3Class *> __nodeMap
PRMClassElement< GUM_SCALAR > & get(NodeId id)
See gum::prm::PRMClassElementContainer<GUM_SCALAR>::get(NodeId).
bool __checkAggregateForCompletion(O3Class &o3class, O3Aggregate &agg)
HashTable< std::string, O3Class *> __classMap
Base class for all aGrUM's exceptions.
HashTable< std::string, O3ReferenceSlot *> RefMap
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
void __declareAggregates(O3Class &c)
This is a decoration of the DiscreteVariable class.
bool __checkLabelsNumber(const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
O3AttributeList & attributes()
virtual void endClass(bool checkImplementations=true) override
Tells the factory that we finished a class declaration.
Builds gum::prm::Class from gum::prm::o3prm::O3Class.
const Sequence< NodeId > & topologicalOrder(bool clear=true) const
The topological order stays the same as long as no variables or arcs are added to or erased from the topology.
bool __checkRuleCPTSumsTo1(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
The O3RawCPT is part of the AST of the O3PRM language.
std::vector< O3Class *> __o3Classes
bool __checkSlotChainLink(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain, const std::string &s)
This class represents a Probabilistic Relational PRMSystem<GUM_SCALAR>.
<agrum/PRM/classElementContainer.h>
O3Label & aggregateType()
ErrorsContainer * __errors
A PRMClass is an object of a PRM representing a fragment of a Bayesian Network which can be instantia...
O3NameSolver< GUM_SCALAR > * __solver
O3ClassFactory< GUM_SCALAR > & operator=(const O3ClassFactory< GUM_SCALAR > &src)
O3LabelList & interfaces()
bool __checkAndAddArcsToDag()
void decomposePath(const std::string &path, std::vector< std::string > &v)
Decomposes a string into a vector of strings, using "." as the separator.
void __addParameters(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
const PRMClassElement< GUM_SCALAR > * __resolveSlotChain(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain)
virtual void continueAttribute(const std::string &name) override
Continues the declaration of an attribute.
void completeAttributes()
void addParameter(const std::string &type, const std::string &name, double value) override
Add a parameter to the current class with a default value.
bool __checkParametersNumber(O3Aggregate &agg, Size n)
The O3PRM is part of the AST of the O3PRM language.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
void __declareAttribute(O3Class &c)
const PRMType * __checkAggParents(O3Class &o3class, O3Aggregate &agg)
void insert(const Key &k)
Inserts a new element into the set.
void startAggregator(const std::string &name, const std::string &agg_type, const std::string &rv_type, const std::vector< std::string > &params)
Start an aggregator declaration.
bool __checkAndAddNodesToDag()
virtual void addParent(const std::string &name) override
Tells the factory that we add a parent to the current declared attribute.
#define GUM_ERROR(type, msg)
PRM< GUM_SCALAR > * __prm