30 #include <agrum/tools/core/math/formula.h> 32 #include <agrum/PRM/elements/PRMScalarAttribute.h> 33 #include <agrum/PRM/elements/PRMType.h> 36 #include <agrum/PRM/elements/PRMFormAttribute.h> 41 template <
typename GUM_SCALAR >
/**
 * @brief Builds a formula attribute named `name` belonging to class `c`.
 *
 * - The PRMAttribute base is initialised with `name`.
 * - `_type_` is a heap-allocated copy of `type` (`new PRMType(type)`).
 * - `_cpf_` starts null; cpf() builds it lazily through _fillCpf_().
 * - `_formulas_` is set to `impl`, and this attribute's own type variable is
 *   added to it.
 * - safeName_ becomes LEFT_CAST() + type name + RIGHT_CAST() + name.
 *
 * NOTE(review): the declaration of the `type` parameter (original line 44) is
 * missing from this listing although the initialiser uses it — confirm against
 * the real file.
 * NOTE(review): presumably takes ownership of `impl`; the destructor body is
 * not visible here to confirm it is released.
 */
42 PRMFormAttribute< GUM_SCALAR >::PRMFormAttribute(
const PRMClass< GUM_SCALAR >& c,
43 const std::string& name,
45 MultiDimImplementation< std::string >* impl) :
46 PRMAttribute< GUM_SCALAR >(name),
47 _type_(
new PRMType(type)), _cpf_(0), _formulas_(impl), _class_(&c) {
48 GUM_CONSTRUCTOR(PRMFormAttribute);
// The attribute's own variable is always part of its formula table.
49 _formulas_->add(_type_->variable());
50 this->safeName_ = PRMObject::LEFT_CAST() + _type_->name() + PRMObject::RIGHT_CAST() + name;
/**
 * @brief Destroys the attribute.
 *
 * NOTE(review): only the GUM_DESTRUCTOR line is visible in this listing
 * (original lines 56-59 are missing). Confirm that the heap objects owned by
 * this attribute (`_type_` from the constructor, `_cpf_` from _fillCpf_(),
 * `_formulas_`) are released there.
 */
53 template <
typename GUM_SCALAR >
54 PRMFormAttribute< GUM_SCALAR >::~PRMFormAttribute() {
55 GUM_DESTRUCTOR(PRMFormAttribute);
/**
 * @brief Creates an attribute with this attribute's name and type for class `c`.
 *
 * The new attribute's formula table comes from `_formulas_->newFactory()`,
 * cast to MultiDimImplementation<std::string>*. Presumably this is an empty
 * table of the same concrete kind — confirm in MultiDimImplementation.
 *
 * @param c class the new attribute will belong to.
 * @return a heap-allocated PRMFormAttribute; the caller handles its lifetime.
 *
 * NOTE(review): the left-hand side of the `impl` declaration (original line 64)
 * is missing from this listing; line 65 is only its continuation.
 */
61 template <
typename GUM_SCALAR >
62 PRMAttribute< GUM_SCALAR >*
63 PRMFormAttribute< GUM_SCALAR >::newFactory(
const PRMClass< GUM_SCALAR >& c)
const {
65 =
static_cast< MultiDimImplementation< std::string >* >(
this->_formulas_->newFactory());
66 return new PRMFormAttribute< GUM_SCALAR >(c,
this->name(),
this->type(), impl);
/**
 * @brief Creates a copy of this attribute in the same class (`*_class_`).
 *
 * The copy has this attribute's name and type. Every variable of this
 * attribute's formula table except its own type variable is added to the
 * copy's table. Then the formula strings are copied cell by cell.
 *
 * NOTE(review): `bij` is not referenced in the visible lines, so parent
 * variables are added to the copy unchanged rather than mapped through `bij`.
 * Confirm this is intended. The final `return copy;` line is not visible.
 */
69 template <
typename GUM_SCALAR >
70 PRMAttribute< GUM_SCALAR >* PRMFormAttribute< GUM_SCALAR >::copy(
71 Bijection<
const DiscreteVariable*,
const DiscreteVariable* > bij)
const {
72 auto copy =
new PRMFormAttribute< GUM_SCALAR >(*_class_,
this->name(),
this->type());
// The copy's constructor already added its own type variable; add the rest.
73 for (
auto var: _formulas_->variablesSequence()) {
74 if (var != &(_type_->variable())) { copy->_formulas_->add(*var); }
// Both tables are walked in lockstep; this relies on them enumerating cells
// in the same order (own variable first, then parents in insertion order).
77 Instantiation inst(*(copy->_formulas_)), jnst(*_formulas_);
78 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
79 copy->_formulas_->set(inst, _formulas_->get(jnst));
82 GUM_ASSERT(copy->_formulas_->contains(copy->_type_->variable()));
/**
 * @brief Rebuilds this attribute's formula table from `source`.
 *
 * A fresh MultiDimArray is built over the images, through `bij.second()`, of
 * the variables of `source.cpf()`. If `source` is itself a PRMFormAttribute,
 * its formula strings are copied. Otherwise each numeric CPF value is
 * converted to a string with std::to_string. Both tables are walked in
 * lockstep in either case.
 *
 * @param bij maps `source`'s variables to this attribute's variables.
 * @param source attribute whose CPF/formulas are copied.
 *
 * NOTE(review): several original lines (90, 111, 119-125) are missing from
 * this listing, including the `else` branch opener.
 */
86 template <
typename GUM_SCALAR >
87 void PRMFormAttribute< GUM_SCALAR >::copyCpf(
88 const Bijection<
const DiscreteVariable*,
const DiscreteVariable* >& bij,
89 const PRMAttribute< GUM_SCALAR >& source) {
// NOTE(review): the previous `_formulas_` pointer is overwritten here; the line
// before it (original 90) is not visible — confirm the old table is released.
91 _formulas_ =
new MultiDimArray< std::string >();
93 for (
const auto& var: source.cpf().variablesSequence()) {
94 _formulas_->add(*(bij.second(var)));
// Formula source: copy the strings verbatim.
97 if (
dynamic_cast<
const PRMFormAttribute< GUM_SCALAR >* >(&source)) {
98 const auto& src =
static_cast<
const PRMFormAttribute< GUM_SCALAR >& >(source);
100 Instantiation inst(_formulas_), jnst(src._formulas_);
102 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
103 inst.inc(), jnst.inc()) {
104 _formulas_->set(inst, src._formulas_->get(jnst));
107 GUM_ASSERT(inst.end() && jnst.end());
// Non-formula source: turn each numeric CPF value into a formula string.
110 Instantiation inst(_formulas_), jnst(source.cpf());
112 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
113 inst.inc(), jnst.inc()) {
// NOTE(review): std::to_string on floating-point values formats like "%f"
// (six decimals), so e.g. 1e-7 becomes "0.000000". Consider a higher-precision
// formatting that the Formula parser accepts.
114 auto val = std::to_string(source.cpf().get(jnst));
115 _formulas_->set(inst, val);
118 GUM_ASSERT(inst.end() && jnst.end());
126 GUM_ASSERT(_formulas_->contains(_type_->variable()));
127 GUM_ASSERT(!_formulas_->contains(source.type().variable()));
/**
 * @brief Reports this class element's kind.
 * @return `prm_attribute`.
 */
130 template <
typename GUM_SCALAR >
131 typename PRMClassElement< GUM_SCALAR >::ClassElementType
132 PRMFormAttribute< GUM_SCALAR >::elt_type()
const {
133 return this->prm_attribute;
/**
 * @brief Mutable access to this attribute's type.
 *
 * NOTE(review): the body (original lines 138-140) is not visible in this
 * listing; presumably it returns `*_type_`.
 */
136 template <
typename GUM_SCALAR >
137 PRMType& PRMFormAttribute< GUM_SCALAR >::type() {
/**
 * @brief Read-only access to this attribute's type.
 *
 * NOTE(review): the body (original lines 143-145) is not visible in this
 * listing; presumably it returns `*_type_`.
 */
141 template <
typename GUM_SCALAR >
142 const PRMType& PRMFormAttribute< GUM_SCALAR >::type()
const {
/**
 * @brief Returns the numeric CPF evaluated from the formulas.
 *
 * The CPF is built lazily: when `_cpf_` is still null, _fillCpf_() evaluates
 * the formula table first.
 *
 * NOTE(review): the return statement (original line 149) is not visible here.
 */
146 template <
typename GUM_SCALAR >
147 const Potential< GUM_SCALAR >& PRMFormAttribute< GUM_SCALAR >::cpf()
const {
148 if (_cpf_ == 0) { _fillCpf_(); }
/**
 * @brief Adds `elt`'s variable to the formula table as a parent.
 *
 * @param elt class element whose type variable becomes a parent.
 * @throws DuplicateElement when adding the variable raises DuplicateElement
 *         (message: "<elt> as parent of <this>").
 * @throws OperationNotAllowed when adding the variable raises
 *         OperationNotAllowed (message: "<elt> of wrong type as parent of <this>").
 *
 * NOTE(review): the opening `try {` (original lines 154-158) is not visible in
 * this listing.
 */
152 template <
typename GUM_SCALAR >
153 void PRMFormAttribute< GUM_SCALAR >::addParent(
const PRMClassElement< GUM_SCALAR >& elt) {
159 _formulas_->add(elt.type().variable());
160 }
catch (DuplicateElement&) {
161 GUM_ERROR(DuplicateElement, elt.name() <<
" as parent of " <<
this->name())
162 }
catch (OperationNotAllowed&) {
// The stray ';' that was inside this macro argument has been removed, to match
// the DuplicateElement handler above.
163 GUM_ERROR(OperationNotAllowed,
164 elt.name() <<
" of wrong type as parent of " <<
this->name())
167 GUM_ASSERT(_formulas_->contains(_type_->variable()));
/**
 * @brief Children are not recorded by this attribute: the body is empty and
 *        `elt` is ignored.
 */
170 template <
typename GUM_SCALAR >
171 void PRMFormAttribute< GUM_SCALAR >::addChild(
const PRMClassElement< GUM_SCALAR >& elt) {}
/**
 * @brief Builds a PRMScalarAttribute that casts this attribute to its super type.
 *
 * The new attribute has this attribute's name and `type().superType()`, and
 * receives this attribute as its parent. Its CPF is set to 1 where
 * `type().label_map()` maps this attribute's value to the cast's value, and
 * to 0 otherwise.
 *
 * @throws OperationNotAllowed if building the attribute raises NotFound
 *         (e.g. from superType()).
 *
 * NOTE(review): the `try {`, `else` and return lines are not visible in this
 * listing. The error message mentions "ScalarAttribute" although this is a
 * form attribute.
 */
173 template <
typename GUM_SCALAR >
174 PRMAttribute< GUM_SCALAR >* PRMFormAttribute< GUM_SCALAR >::getCastDescendant()
const {
175 PRMScalarAttribute< GUM_SCALAR >* cast = 0;
178 cast =
new PRMScalarAttribute< GUM_SCALAR >(
this->name(), type().superType());
179 }
catch (NotFound&) {
180 GUM_ERROR(OperationNotAllowed,
"this ScalarAttribute can not have cast descendant")
183 cast->addParent(*
this);
185 const DiscreteVariable& my_var = type().variable();
186 DiscreteVariable& cast_var = cast->type().variable();
187 Instantiation inst(cast->cpf());
// Deterministic table: 1 when the sub-type label maps onto the super-type label.
189 for (inst.setFirst(); !inst.end(); inst.inc()) {
190 if (type().label_map()[inst.val(my_var)] == inst.val(cast_var)) {
191 cast->cpf().set(inst, 1);
193 cast->cpf().set(inst, 0);
197 GUM_ASSERT(_formulas_->contains(_type_->variable()));
/**
 * @brief Makes this attribute's type a subtype of `cast`'s type, then asks
 *        `cast` to become this attribute's cast descendant.
 *
 * @param cast attribute whose type becomes the super type of this one.
 * @throws OperationNotAllowed if PRMType::setSuper raises OperationNotAllowed.
 * @throws TypeError if setSuper raises TypeError
 *         ("<type> is not a subtype of <cast type>").
 *
 * NOTE(review): the opening `try {` (original line 203) is not visible here.
 */
201 template <
typename GUM_SCALAR >
202 void PRMFormAttribute< GUM_SCALAR >::setAsCastDescendant(PRMAttribute< GUM_SCALAR >* cast) {
204 type().setSuper(cast->type());
205 }
catch (OperationNotAllowed&) {
206 GUM_ERROR(OperationNotAllowed,
"this ScalarAttribute can not have cast descendant")
207 }
catch (TypeError&) {
208 std::stringstream msg;
209 msg << type().name() <<
" is not a subtype of " << cast->type().name();
210 GUM_ERROR(TypeError, msg.str())
213 cast->becomeCastDescendant(type());
/**
 * @brief Turns this attribute into the cast of an attribute typed `subtype`.
 *
 * The formula table is rebuilt over this attribute's variable and `subtype`'s
 * variable. A cell holds "1" when `subtype.label_map()` maps the current value
 * of the subtype variable onto the current value of this attribute's variable,
 * and "0" otherwise.
 *
 * Fix: the current values are now read with Instantiation::val().
 * Instantiation::pos() returns the variable's index inside the instantiation,
 * which is constant over the loop, so the previous test filled every cell with
 * the same string. getCastDescendant() already uses val() for the same mapping.
 *
 * NOTE(review): original lines 218-219 (before the table is replaced) and 230
 * (presumably `} else {`) are not visible in this listing. Confirm the
 * previous table is released.
 */
216 template <
typename GUM_SCALAR >
217 void PRMFormAttribute< GUM_SCALAR >::becomeCastDescendant(PRMType& subtype) {
220 _formulas_ =
new MultiDimArray< std::string >();
221 _formulas_->add(type().variable());
222 _formulas_->add(subtype.variable());
224 Instantiation inst(_formulas_);
226 for (inst.setFirst(); !inst.end(); inst.inc()) {
227 auto sub_val = inst.val(subtype.variable());
228 if (subtype.label_map()[sub_val] == inst.val(type().variable())) {
229 _formulas_->set(inst,
"1");
231 _formulas_->set(inst,
"0");
/**
 * @brief Copy construction is not supported.
 *
 * The base is initialised with `source`'s name only so the object can be
 * constructed far enough to throw.
 * @throws OperationNotAllowed always ("Cannot copy FormAttribute").
 */
241 template <
typename GUM_SCALAR >
242 PRMFormAttribute< GUM_SCALAR >::PRMFormAttribute(
const PRMFormAttribute& source) :
243 PRMAttribute< GUM_SCALAR >(source.name()) {
244 GUM_CONS_CPY(PRMFormAttribute);
245 GUM_ERROR(OperationNotAllowed,
"Cannot copy FormAttribute")
/**
 * @brief Copy assignment is not supported.
 * @throws OperationNotAllowed always ("Cannot copy FormAttribute").
 */
248 template <
typename GUM_SCALAR >
249 PRMFormAttribute< GUM_SCALAR >&
250 PRMFormAttribute< GUM_SCALAR >::operator=(
const PRMFormAttribute< GUM_SCALAR >& source) {
251 GUM_ERROR(OperationNotAllowed,
"Cannot copy FormAttribute")
/**
 * @brief (Re)builds the numeric CPF `_cpf_` from the string formula table.
 *
 * Any previous `_cpf_` is deleted and replaced by a fresh Potential. Each
 * formula cell is evaluated with every entry of the owning class's scope()
 * bound as a named variable, and the result is written to the matching CPF
 * cell. Both tables are walked in lockstep, and an empty formula is evaluated
 * as "0.0".
 *
 * @throws NotFound ("undefined value in cpt") when a gum::Exception escapes the
 *         guarded section.
 *
 * Changes:
 * - scope entries are iterated by const reference instead of being copied;
 * - the C-style cast is replaced by static_cast and `== ""` by empty();
 * - the trailing GUM_ASSERT gets the `;` every other assertion has.
 *
 * NOTE(review): the `try {` opener, the body of the variable loop (presumably
 * adding each variable to `_cpf_`) and the declaration of the formula object
 * `f` are not visible in this listing.
 */
254 template <
typename GUM_SCALAR >
255 void PRMFormAttribute< GUM_SCALAR >::_fillCpf_()
const {
257 if (_cpf_) {
delete _cpf_; }
259 _cpf_ =
new Potential< GUM_SCALAR >();
261 for (
auto var: _formulas_->variablesSequence()) {
265 auto params = _class_->scope();
267 Instantiation inst(_formulas_);
268 Instantiation jnst(_cpf_);
270 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
271 inst.inc(), jnst.inc()) {
273 auto val = _formulas_->get(inst);
// An empty cell is evaluated as the constant formula "0.0".
274 if (val.empty()) { val =
"0.0"; }
// Bind every class parameter, by name, before evaluating the formula.
278 for (
const auto& item: params) {
279 f.variables().insert(item.first, item.second->value());
282 _cpf_->set(jnst, static_cast< GUM_SCALAR >(f.result()));
285 GUM_ASSERT(inst.end() && jnst.end());
287 }
catch (Exception&) { GUM_ERROR(NotFound,
"undefined value in cpt") }
288 GUM_ASSERT(_formulas_->contains(_type_->variable()));
/**
 * @brief Mutable access to the string formula table.
 *
 * NOTE(review): the body (original lines 293-299) is not visible in this
 * listing. Confirm whether modifying the table through this reference also
 * invalidates the cached `_cpf_`.
 */
291 template <
typename GUM_SCALAR >
292 MultiDimImplementation< std::string >& PRMFormAttribute< GUM_SCALAR >::formulas() {
/**
 * @brief Read-only access to the string formula table.
 *
 * NOTE(review): the body (original lines 302-304) is not visible in this
 * listing; presumably it returns `*_formulas_`.
 */
300 template <
typename GUM_SCALAR >
301 const MultiDimImplementation< std::string >& PRMFormAttribute< GUM_SCALAR >::formulas()
const {
/**
 * @brief Replaces the parent variable of `old_type` by the variable of `new_type`.
 *
 * The formula table is rebuilt as a new MultiDimArray. Every variable of the
 * old table is kept except `old_type`'s, and `new_type`'s variable is added.
 * The formula strings are then copied cell by cell in iteration order.
 *
 * @param old_type parent type to remove; must not be this attribute's own type.
 * @param new_type replacement type; must have the same domain size.
 * @throws OperationNotAllowed if `old_type` is this attribute's own type, or if
 *         the two types' domain sizes differ.
 * @throws NotFound if `old_type`'s variable is not in the formula table.
 *
 * Fixes:
 * - the post-condition asserted both `!contains(new_type)` and
 *   `contains(new_type)`, which can never both hold; it now checks that the
 *   old variable is gone;
 * - the domain-size check reads the sizes through variable(), as the
 *   type_(PRMType*) setter does.
 *
 * NOTE(review): original line 324 (presumably `} else {`) and lines 333-341
 * (presumably releasing `old`) are not visible in this listing.
 */
305 template <
typename GUM_SCALAR >
306 void PRMFormAttribute< GUM_SCALAR >::swap(
const PRMType& old_type,
const PRMType& new_type) {
307 if (&(old_type) == _type_) {
308 GUM_ERROR(OperationNotAllowed,
"Cannot replace attribute own type")
310 if (old_type.variable().domainSize() != new_type.variable().domainSize()) {
311 GUM_ERROR(OperationNotAllowed,
"Cannot replace types with difference domain size")
313 if (!_formulas_->contains(old_type.variable())) {
314 GUM_ERROR(NotFound,
"could not find variable " + old_type.name())
317 auto old = _formulas_;
319 _formulas_ =
new MultiDimArray< std::string >();
321 for (
auto var: old->variablesSequence()) {
322 if (var != &(old_type.variable())) {
323 _formulas_->add(*var);
325 _formulas_->add(new_type.variable());
// Same domain sizes, so both tables enumerate the same number of cells.
329 Instantiation inst(_formulas_), jnst(old);
331 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
332 _formulas_->set(inst, old->get(jnst));
342 GUM_ASSERT(inst.end() && jnst.end());
343 GUM_ASSERT(_formulas_->contains(_type_->variable()));
// The old parent variable must be gone and the new one present.
344 GUM_ASSERT(!_formulas_->contains(old_type.variable()));
345 GUM_ASSERT(_formulas_->contains(new_type.variable()));
/**
 * @brief Raw access to this attribute's type pointer.
 *
 * NOTE(review): the body (original lines 350-352) is not visible in this
 * listing; presumably it returns `_type_`.
 */
348 template <
typename GUM_SCALAR >
349 PRMType* PRMFormAttribute< GUM_SCALAR >::type_() {
353 template <
typename GUM_SCALAR >
354 void PRMFormAttribute< GUM_SCALAR >::type_(PRMType* t) {
355 if (_type_->variable().domainSize() != t->variable().domainSize()) {
356 GUM_ERROR(OperationNotAllowed,
"Cannot replace types with difference domain size")
358 auto old = _formulas_;
360 _formulas_ =
new MultiDimArray< std::string >();
362 for (
auto var: old->variablesSequence()) {
363 if (var != &(_type_->variable())) {
364 _formulas_->add(*var);
366 _formulas_->add(t->variable());
370 Instantiation inst(_formulas_), jnst(old);
372 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
373 _formulas_->set(inst, old->get(jnst));
385 GUM_ASSERT(_formulas_->contains(_type_->variable()));
386 GUM_ASSERT(inst.end() && jnst.end());