37 template <
typename GUM_SCALAR >
44 __o3_prm(&o3_prm), __solver(&solver), __errors(&errors) {
48 template <
typename GUM_SCALAR >
59 template <
typename GUM_SCALAR >
71 template <
typename GUM_SCALAR >
76 template <
typename GUM_SCALAR >
79 if (
this == &src) {
return *
this; }
92 template <
typename GUM_SCALAR >
95 if (
this == &src) {
return *
this; }
96 __prm = std::move(src.__prm);
103 __dag = std::move(src.__dag);
108 template <
typename GUM_SCALAR >
119 for (
auto& i : c->interfaces()) {
120 if (
__solver->resolveInterface(i)) { implements.
insert(i.label()); }
124 if (
__solver->resolveClass(c->superLabel())) {
126 c->name().label(), c->superLabel().label(), &implements,
true);
133 template <
typename GUM_SCALAR >
137 for (
auto id = topo_order.rbegin();
id != topo_order.rend(); --id) {
142 template <
typename GUM_SCALAR >
147 template <
typename GUM_SCALAR >
154 __classMap.insert(c->name().label(), c.get());
158 O3PRM_CLASS_DUPLICATE(c->name(), *
__errors);
166 template <
typename GUM_SCALAR >
169 if (c->superLabel().label() !=
"") {
170 if (!
__solver->resolveClass(c->superLabel())) {
return false; }
172 auto head =
__nameMap[c->superLabel().label()];
173 auto tail =
__nameMap[c->name().label()];
179 O3PRM_CLASS_CYLIC_INHERITANCE(c->name(), c->superLabel(), *
__errors);
188 template <
typename GUM_SCALAR >
192 __prm->getClass(c->name().label()).initializeInheritance();
201 template <
typename GUM_SCALAR >
206 attr_map.insert(a->name().label(), a.get());
212 agg_map.insert(agg.name().label(), &agg);
217 ref_map.insert(ref.name().label(), &ref);
222 if (
__solver->resolveInterface(i)) {
232 template <
typename GUM_SCALAR >
239 const auto& real_i =
__prm->getInterface(i.
label());
241 auto counter = (
Size)0;
242 for (
const auto& a : real_i.attributes()) {
243 if (attr_map.
exists(a->name())) {
247 O3PRM_CLASS_ATTR_IMPLEMENTATION(
253 if (agg_map.
exists(a->name())) {
258 O3PRM_CLASS_AGG_IMPLEMENTATION(
265 if (counter != real_i.attributes().size()) {
271 for (
const auto& r : real_i.referenceSlots()) {
272 if (ref_map.
exists(r->name())) {
277 O3PRM_CLASS_REF_IMPLEMENTATION(
286 template <
typename GUM_SCALAR >
290 if (!
__solver->resolveType(o3_type)) {
return false; }
292 return __prm->type(o3_type.
label()).isSubTypeOf(type);
295 template <
typename GUM_SCALAR >
298 if (!
__solver->resolveSlotType(o3_type)) {
return false; }
301 return __prm->getInterface(o3_type.
label()).isSubTypeOf(type);
303 return __prm->getClass(o3_type.
label()).isSubTypeOf(type);
307 template <
typename GUM_SCALAR >
322 template <
typename GUM_SCALAR >
328 factory.
addParameter(
"int", p.name().label(), p.value().value());
333 factory.
addParameter(
"real", p.name().label(), p.value().value());
344 template <
typename GUM_SCALAR >
353 template <
typename GUM_SCALAR >
363 ref.type().label(), ref.name().label(), ref.isArray());
370 template <
typename GUM_SCALAR >
374 if (!
__solver->resolveSlotType(ref.
type())) {
return false; }
380 const auto& elt = real_c.get(ref.
name().
label());
395 if (slot_type->name() == real_ref->slotType().name()) {
399 }
else if (!slot_type->isSubTypeOf(real_ref->slotType())) {
415 if ((&ref_type) == (&real_c)) {
421 if (ref_type.isSubTypeOf(real_c)) {
430 template <
typename GUM_SCALAR >
439 template <
typename GUM_SCALAR >
448 template <
typename GUM_SCALAR >
455 factory.
startAttribute(attr->type().label(), attr->name().label());
463 template <
typename GUM_SCALAR >
467 if (!
__solver->resolveType(attr.
type())) {
return false; }
473 if (!super.exists(attr.
name().
label())) {
return true; }
475 const auto& super_type = super.get(attr.
name().
label()).type();
478 if (!type.isSubTypeOf(super_type)) {
486 template <
typename GUM_SCALAR >
501 for (
auto a : super.attributes()) {
502 to_complete.
insert(a->safeName());
505 for (
auto a : super.aggregates()) {
506 to_complete.insert(a->safeName());
511 .
get(a->name().label())
517 .
get(a.name().label())
521 for (
auto a : to_complete) {
530 template <
typename GUM_SCALAR >
544 template <
typename GUM_SCALAR >
552 for (
const auto& parent : agg.parents()) {
561 template <
typename GUM_SCALAR >
566 if (t ==
nullptr) {
return false; }
574 template <
typename GUM_SCALAR >
582 for (
const auto& parent : attr->parents()) {
586 auto raw =
dynamic_cast< const O3RawCPT*
>(attr.get());
589 auto values = std::vector< std::string >();
590 for (
const auto& val : raw->values()) {
591 values.push_back(val.formula().formula());
596 auto rule_cpt =
dynamic_cast< const O3RuleCPT*
>(attr.get());
598 for (
const auto& rule : rule_cpt->rules()) {
599 auto labels = std::vector< std::string >();
600 auto values = std::vector< std::string >();
602 for (
const auto& lbl : rule.first) {
603 labels.push_back(lbl.label());
606 for (
const auto& form : rule.second) {
607 values.push_back(form.formula().formula());
619 template <
typename GUM_SCALAR >
624 for (
auto& prnt : attr.
parents()) {
629 auto raw =
dynamic_cast< O3RawCPT*
>(&attr);
632 auto rule =
dynamic_cast< O3RuleCPT*
>(&attr);
638 template <
typename GUM_SCALAR >
641 if (prnt.
label().find(
'.') == std::string::npos) {
649 template <
typename GUM_SCALAR >
653 O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *
__errors);
657 const auto& elt = c.
get(prnt.
label());
661 O3PRM_CLASS_ILLEGAL_PARENT(prnt, *
__errors);
668 template <
typename GUM_SCALAR >
675 template <
typename GUM_SCALAR >
679 if (rule.first.size() != attr.
parents().size()) {
680 O3PRM_CLASS_ILLEGAL_RULE_SIZE(
687 template <
typename GUM_SCALAR >
693 for (std::size_t i = 0; i < attr.
parents().size(); ++i) {
694 auto label = rule.first[i];
699 if (label.label() !=
"*" 700 && std::find(real_labels.begin(), real_labels.end(), label.label())
701 == real_labels.end()) {
702 O3PRM_CLASS_ILLEGAL_RULE_LABEL(rule, label, prnt, *
__errors);
709 return errors ==
false;
712 template <
typename GUM_SCALAR >
717 for (
auto& f : rule.second) {
718 f.formula().variables().clear();
719 for (
const auto& values : scope) {
720 f.formula().variables().insert(values.first, values.second->value());
726 template <
typename GUM_SCALAR >
733 GUM_SCALAR sum = 0.0;
734 for (
const auto& f : rule.second) {
736 auto value = GUM_SCALAR(f.formula().result());
738 if (value < 0.0 || 1.0 < value) {
749 if (std::abs(sum - 1.0) > 1e-3) {
750 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(
753 }
else if (std::abs(sum - 1.0f) > 1e-6) {
754 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(
757 return errors ==
false;
760 template <
typename GUM_SCALAR >
763 const auto& scope = c.
scope();
765 for (
auto& rule : attr.
rules()) {
777 return errors ==
false;
780 template <
typename GUM_SCALAR >
785 auto domainSize = type->domainSize();
786 for (
auto& prnt : attr.
parents()) {
788 domainSize *= c.
get(prnt.label()).type()->domainSize();
797 if (domainSize != attr.
values().size()) {
798 O3PRM_CLASS_ILLEGAL_CPT_SIZE(c.
name(),
807 const auto& scope = c.
scope();
808 for (
auto& f : attr.
values()) {
809 f.formula().variables().clear();
811 for (
const auto& values : scope) {
812 f.formula().variables().insert(values.first, values.second->value());
817 Size parent_size = domainSize / type->domainSize();
818 auto values = std::vector< GUM_SCALAR >(parent_size, 0.0f);
820 for (std::size_t i = 0; i < attr.
values().size(); ++i) {
822 auto idx = i % parent_size;
823 auto val = (GUM_SCALAR)attr.
values()[i].formula().result();
826 if (val < 0.0 || 1.0 < val) {
827 O3PRM_CLASS_ILLEGAL_CPT_VALUE(
832 O3PRM_CLASS_ILLEGAL_CPT_VALUE(
838 for (
auto f : values) {
839 if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-3) {
840 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(
843 }
else if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-6) {
844 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(
851 template <
typename GUM_SCALAR >
856 auto s = chain.
label();
858 std::vector< std::string > v;
862 for (
size_t i = 0; i < v.size(); ++i) {
867 auto elt = &(current->get(link));
869 if (i == v.size() - 1) {
878 current = &(ref->slotType());
890 template <
typename GUM_SCALAR >
894 const std::string& s) {
896 O3PRM_CLASS_LINK_NOT_FOUND(chain, s, *
__errors);
902 template <
typename GUM_SCALAR >
909 auto params = std::vector< std::string >();
910 for (
auto& p : agg.parameters()) {
911 params.push_back(p.label());
915 agg.aggregateType().label(),
916 agg.variableType().label(),
925 template <
typename GUM_SCALAR >
936 template <
typename GUM_SCALAR >
941 auto t = (
const PRMType*)
nullptr;
943 for (
const auto& prnt : agg.
parents()) {
946 if (elt ==
nullptr) {
947 O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *
__errors);
956 O3PRM_CLASS_WRONG_PARENT(prnt, *
__errors);
960 }
else if ((*t) != elt->type()) {
962 O3PRM_CLASS_WRONG_PARENT_TYPE(
963 prnt, t->name(), elt->type().name(), *
__errors);
971 template <
typename GUM_SCALAR >
980 && !agg_type.isSubTypeOf(super.get(agg.
name().
label()).type())) {
981 O3PRM_CLASS_ILLEGAL_OVERLOAD(
990 template <
typename GUM_SCALAR >
1019 if (!ok) {
return false; }
1038 template <
typename GUM_SCALAR >
1043 O3PRM_CLASS_AGG_PARAMETERS(
1051 template <
typename GUM_SCALAR >
1054 const auto& param = agg.
parameters().front();
1064 O3PRM_CLASS_AGG_PARAMETER_NOT_FOUND(agg.
name(), param, *
__errors);
std::pair< O3LabelList, O3FormulaList > O3Rule
bool __checkAggregateForDeclaration(O3Class &o3class, O3Aggregate &agg)
virtual O3LabelList & parents()
PRMParameter is a member of a Class in a PRM.
void completeAggregates()
bool __checkLabelsValues(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
virtual O3RuleList & rules()
DiscreteVariable & variable()
Return a reference on the DiscreteVariable contained in this.
void endAggregator()
Finishes an aggregate declaration.
bool __checkAttributeForCompletion(const O3Class &o3_c, O3Attribute &attr)
void __addParamsToForms(const HashTable< std::string, const PRMParameter< GUM_SCALAR > * > &scope, O3RuleCPT::O3Rule &rule)
const std::string & name() const
Returns the name of this object.
bool __checkAttributeForDeclaration(O3Class &o3_c, O3Attribute &attr)
HashTable< std::string, O3Attribute *> AttrMap
O3ReferenceSlotList & referenceSlots()
void buildReferenceSlots()
bool __checkRawCPT(const PRMClass< GUM_SCALAR > &c, O3RawCPT &attr)
The O3Aggregate is part of the AST of the O3PRM language.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
virtual void continueClass(const std::string &c) override
Continue the declaration of a class.
O3AggregateList & aggregates()
bool __checkReferenceSlot(O3Class &c, O3ReferenceSlot &ref)
HashTable< std::string, O3Aggregate *> AggMap
void __setO3ClassCreationOrder()
Abstract class representing an element of PRM class.
virtual void endAttribute() override
Tells the factory that we finished declaring an attribute.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
virtual void addReferenceSlot(const std::string &type, const std::string &name, bool isArray) override
Tells the factory that we started declaring a slot.
This class is used contain and manipulate gum::ParseError.
The O3Label is part of the AST of the O3PRM language.
bool __checkParent(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
The O3Attribute is part of the AST of the O3PRM language.
void __addReferenceSlots(O3Class &c)
virtual bool exists(const std::string &name) const
Returns true if a member with the given name exists in this PRMClassElementContainer or in the PRMCla...
virtual O3FormulaList & values()
void __completeAggregates(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
O3LabelList & parameters()
O3ParameterList & parameters()
bool __checkLocalParent(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
virtual NodeId addNode()
insert a new node and return its id
A PRMReferenceSlot represent a relation between two PRMClassElementContainer.
virtual void startClass(const std::string &c, const std::string &ext="", const Set< std::string > *implements=nullptr, bool delayInheritance=false) override
Tells the factory that we start a class declaration.
The class for generic Hash Tables.
void buildImplementations()
bool __checkRuleCPT(const PRMClass< GUM_SCALAR > &c, O3RuleCPT &attr)
void continueAggregator(const std::string &name)
Conitnues an aggregator declaration.
bool __checkImplementation(O3Class &c)
virtual void startAttribute(const std::string &type, const std::string &name, bool scalar_atttr=false) override
Tells the factory that we start an attribute declaration.
bool __checkRemoteParent(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &prnt)
HashTable< std::string, gum::NodeId > __nameMap
void __completeAttribute(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
The O3RuleCPT is part of the AST of the O3PRM language.
virtual Size domainSize() const =0
void setRawCPFByColumns(const std::vector< GUM_SCALAR > &array)
Gives the factory the CPF in its raw form.
Resolves names for the different O3PRM factories.
bool __checkAggTypeLegality(O3Class &o3class, O3Aggregate &agg)
bool __checkAggParameters(O3Class &o3class, O3Aggregate &agg, const PRMType *t)
The O3Class is part of the AST of the O3PRM language.
O3ClassFactory(PRM< GUM_SCALAR > &prm, O3PRM &o3_prm, O3NameSolver< GUM_SCALAR > &solver, ErrorsContainer &errors)
bool __checkParameterValue(O3Aggregate &agg, const gum::prm::PRMType &t)
Factory which builds a PRM<GUM_SCALAR>.
HashTable< std::string, const PRMParameter< GUM_SCALAR > *> scope() const
Returns all the parameters in the scope of this class.
The O3ReferenceSlot is part of the AST of the O3PRM language.
virtual std::string label(Idx i) const =0
get the indice-th label. This method is pure virtual.
virtual void setCPFByRule(const std::vector< std::string > &labels, const std::vector< GUM_SCALAR > &values)
Fills the CPF using a rule.
HashTable< NodeId, O3Class *> __nodeMap
PRMClassElement< GUM_SCALAR > & get(NodeId id)
See gum::prm::PRMClassElementContainer<GUM_SCALAR>::get(NodeId).
bool __checkAggregateForCompletion(O3Class &o3class, O3Aggregate &agg)
HashTable< std::string, O3Class *> __classMap
Base class for all aGrUM's exceptions.
HashTable< std::string, O3ReferenceSlot *> RefMap
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
void __declareAggregates(O3Class &c)
This is a decoration of the DiscreteVariable class.
bool __checkLabelsNumber(const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
O3AttributeList & attributes()
virtual void endClass(bool checkImplementations=true) override
Tells the factory that we finished a class declaration.
Builds gum::prm::Class from gum::prm::o3prm::O3Class.
const Sequence< NodeId > & topologicalOrder(bool clear=true) const
The topological order stays the same as long as no variable or arcs are added or erased src the topol...
bool __checkRuleCPTSumsTo1(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
The O3RawCPT is part of the AST of the O3PRM language.
std::vector< O3Class *> __o3Classes
bool __checkSlotChainLink(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain, const std::string &s)
This class represents a Probabilistic Relational PRMSystem<GUM_SCALAR>.
<agrum/PRM/classElementContainer.h>
O3Label & aggregateType()
ErrorsContainer * __errors
A PRMClass is an object of a PRM representing a fragment of a Bayesian Network which can be instantia...
O3NameSolver< GUM_SCALAR > * __solver
O3ClassFactory< GUM_SCALAR > & operator=(const O3ClassFactory< GUM_SCALAR > &src)
O3LabelList & interfaces()
bool __checkAndAddArcsToDag()
void decomposePath(const std::string &path, std::vector< std::string > &v)
Decompose a string in a vector of strings using "." as separators.
void __addParameters(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
const PRMClassElement< GUM_SCALAR > * __resolveSlotChain(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain)
virtual void continueAttribute(const std::string &name) override
Continues the declaration of an attribute.
void completeAttributes()
void addParameter(const std::string &type, const std::string &name, double value) override
Add a parameter to the current class with a default value.
bool __checkParametersNumber(O3Aggregate &agg, Size n)
The O3PRM is part of the AST of the O3PRM language.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
void __declareAttribute(O3Class &c)
const PRMType * __checkAggParents(O3Class &o3class, O3Aggregate &agg)
void insert(const Key &k)
Inserts a new element into the set.
void startAggregator(const std::string &name, const std::string &agg_type, const std::string &rv_type, const std::vector< std::string > ¶ms)
Start an aggregator declaration.
bool __checkAndAddNodesToDag()
virtual void addParent(const std::string &name) override
Tells the factory that we add a parent to the current declared attribute.
#define GUM_ERROR(type, msg)
PRM< GUM_SCALAR > * __prm