aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
PRMFormAttribute_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Inline implementation of gum::PRMFormAttribute
25  *
26  * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27  */
#include <iomanip>
#include <iostream>
#include <limits>
#include <locale>
#include <sstream>

#include <agrum/tools/core/math/formula.h>

#include <agrum/PRM/elements/PRMScalarAttribute.h>
#include <agrum/PRM/elements/PRMType.h>

// to ease IDE parser
#include <agrum/PRM/elements/PRMFormAttribute.h>
37 
38 namespace gum {
39  namespace prm {
40 
41  template < typename GUM_SCALAR >
42  PRMFormAttribute< GUM_SCALAR >::PRMFormAttribute(const PRMClass< GUM_SCALAR >& c,
43  const std::string& name,
44  const PRMType& type,
45  MultiDimImplementation< std::string >* impl) :
46  PRMAttribute< GUM_SCALAR >(name),
47  _type_(new PRMType(type)), _cpf_(0), _formulas_(impl), _class_(&c) {
48  GUM_CONSTRUCTOR(PRMFormAttribute);
49  _formulas_->add(_type_->variable());
50  this->safeName_ = PRMObject::LEFT_CAST() + _type_->name() + PRMObject::RIGHT_CAST() + name;
51  }
52 
53  template < typename GUM_SCALAR >
54  PRMFormAttribute< GUM_SCALAR >::~PRMFormAttribute() {
55  GUM_DESTRUCTOR(PRMFormAttribute);
56  delete _type_;
57  delete _cpf_;
58  delete _formulas_;
59  }
60 
61  template < typename GUM_SCALAR >
62  PRMAttribute< GUM_SCALAR >*
63  PRMFormAttribute< GUM_SCALAR >::newFactory(const PRMClass< GUM_SCALAR >& c) const {
64  auto impl
65  = static_cast< MultiDimImplementation< std::string >* >(this->_formulas_->newFactory());
66  return new PRMFormAttribute< GUM_SCALAR >(c, this->name(), this->type(), impl);
67  }
68 
69  template < typename GUM_SCALAR >
70  PRMAttribute< GUM_SCALAR >* PRMFormAttribute< GUM_SCALAR >::copy(
71  Bijection< const DiscreteVariable*, const DiscreteVariable* > bij) const {
72  auto copy = new PRMFormAttribute< GUM_SCALAR >(*_class_, this->name(), this->type());
73  for (auto var: _formulas_->variablesSequence()) {
74  if (var != &(_type_->variable())) { copy->_formulas_->add(*var); }
75  }
76 
77  Instantiation inst(*(copy->_formulas_)), jnst(*_formulas_);
78  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
79  copy->_formulas_->set(inst, _formulas_->get(jnst));
80  }
81 
82  GUM_ASSERT(copy->_formulas_->contains(copy->_type_->variable()));
83  return copy;
84  }
85 
86  template < typename GUM_SCALAR >
87  void PRMFormAttribute< GUM_SCALAR >::copyCpf(
88  const Bijection< const DiscreteVariable*, const DiscreteVariable* >& bij,
89  const PRMAttribute< GUM_SCALAR >& source) {
90  delete _formulas_;
91  _formulas_ = new MultiDimArray< std::string >();
92 
93  for (const auto& var: source.cpf().variablesSequence()) {
94  _formulas_->add(*(bij.second(var)));
95  }
96 
97  if (dynamic_cast< const PRMFormAttribute< GUM_SCALAR >* >(&source)) {
98  const auto& src = static_cast< const PRMFormAttribute< GUM_SCALAR >& >(source);
99 
100  Instantiation inst(_formulas_), jnst(src._formulas_);
101 
102  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
103  inst.inc(), jnst.inc()) {
104  _formulas_->set(inst, src._formulas_->get(jnst));
105  }
106 
107  GUM_ASSERT(inst.end() && jnst.end());
108 
109  } else {
110  Instantiation inst(_formulas_), jnst(source.cpf());
111 
112  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
113  inst.inc(), jnst.inc()) {
114  auto val = std::to_string(source.cpf().get(jnst));
115  _formulas_->set(inst, val);
116  }
117 
118  GUM_ASSERT(inst.end() && jnst.end());
119  }
120 
121  if (_cpf_) {
122  delete _cpf_;
123  _cpf_ = 0;
124  }
125 
126  GUM_ASSERT(_formulas_->contains(_type_->variable()));
127  GUM_ASSERT(!_formulas_->contains(source.type().variable()));
128  }
129 
130  template < typename GUM_SCALAR >
131  typename PRMClassElement< GUM_SCALAR >::ClassElementType
132  PRMFormAttribute< GUM_SCALAR >::elt_type() const {
133  return this->prm_attribute;
134  }
135 
136  template < typename GUM_SCALAR >
137  PRMType& PRMFormAttribute< GUM_SCALAR >::type() {
138  return *_type_;
139  }
140 
141  template < typename GUM_SCALAR >
142  const PRMType& PRMFormAttribute< GUM_SCALAR >::type() const {
143  return *_type_;
144  }
145 
146  template < typename GUM_SCALAR >
147  const Potential< GUM_SCALAR >& PRMFormAttribute< GUM_SCALAR >::cpf() const {
148  if (_cpf_ == 0) { _fillCpf_(); }
149  return *_cpf_;
150  }
151 
152  template < typename GUM_SCALAR >
153  void PRMFormAttribute< GUM_SCALAR >::addParent(const PRMClassElement< GUM_SCALAR >& elt) {
154  try {
155  if (_cpf_) {
156  delete _cpf_;
157  _cpf_ = 0;
158  }
159  _formulas_->add(elt.type().variable());
160  } catch (DuplicateElement&) {
161  GUM_ERROR(DuplicateElement, elt.name() << " as parent of " << this->name())
162  } catch (OperationNotAllowed&) {
163  GUM_ERROR(OperationNotAllowed,
164  elt.name() << " of wrong type as parent of " << this->name();)
165  }
166 
167  GUM_ASSERT(_formulas_->contains(_type_->variable()));
168  }
169 
170  template < typename GUM_SCALAR >
171  void PRMFormAttribute< GUM_SCALAR >::addChild(const PRMClassElement< GUM_SCALAR >& elt) {}
172 
173  template < typename GUM_SCALAR >
174  PRMAttribute< GUM_SCALAR >* PRMFormAttribute< GUM_SCALAR >::getCastDescendant() const {
175  PRMScalarAttribute< GUM_SCALAR >* cast = 0;
176 
177  try {
178  cast = new PRMScalarAttribute< GUM_SCALAR >(this->name(), type().superType());
179  } catch (NotFound&) {
180  GUM_ERROR(OperationNotAllowed, "this ScalarAttribute can not have cast descendant")
181  }
182 
183  cast->addParent(*this);
184 
185  const DiscreteVariable& my_var = type().variable();
186  DiscreteVariable& cast_var = cast->type().variable();
187  Instantiation inst(cast->cpf());
188 
189  for (inst.setFirst(); !inst.end(); inst.inc()) {
190  if (type().label_map()[inst.val(my_var)] == inst.val(cast_var)) {
191  cast->cpf().set(inst, 1);
192  } else {
193  cast->cpf().set(inst, 0);
194  }
195  }
196 
197  GUM_ASSERT(_formulas_->contains(_type_->variable()));
198  return cast;
199  }
200 
201  template < typename GUM_SCALAR >
202  void PRMFormAttribute< GUM_SCALAR >::setAsCastDescendant(PRMAttribute< GUM_SCALAR >* cast) {
203  try {
204  type().setSuper(cast->type());
205  } catch (OperationNotAllowed&) {
206  GUM_ERROR(OperationNotAllowed, "this ScalarAttribute can not have cast descendant")
207  } catch (TypeError&) {
208  std::stringstream msg;
209  msg << type().name() << " is not a subtype of " << cast->type().name();
210  GUM_ERROR(TypeError, msg.str())
211  }
212 
213  cast->becomeCastDescendant(type());
214  }
215 
216  template < typename GUM_SCALAR >
217  void PRMFormAttribute< GUM_SCALAR >::becomeCastDescendant(PRMType& subtype) {
218  delete _formulas_;
219 
220  _formulas_ = new MultiDimArray< std::string >();
221  _formulas_->add(type().variable());
222  _formulas_->add(subtype.variable());
223 
224  Instantiation inst(_formulas_);
225 
226  for (inst.setFirst(); !inst.end(); inst.inc()) {
227  auto my_pos = inst.pos(subtype.variable());
228  if (subtype.label_map()[my_pos] == inst.pos(type().variable())) {
229  _formulas_->set(inst, "1");
230  } else {
231  _formulas_->set(inst, "0");
232  }
233  }
234 
235  if (_cpf_) {
236  delete _cpf_;
237  _cpf_ = nullptr;
238  }
239  }
240 
241  template < typename GUM_SCALAR >
242  PRMFormAttribute< GUM_SCALAR >::PRMFormAttribute(const PRMFormAttribute& source) :
243  PRMAttribute< GUM_SCALAR >(source.name()) {
244  GUM_CONS_CPY(PRMFormAttribute);
245  GUM_ERROR(OperationNotAllowed, "Cannot copy FormAttribute")
246  }
247 
248  template < typename GUM_SCALAR >
249  PRMFormAttribute< GUM_SCALAR >&
250  PRMFormAttribute< GUM_SCALAR >::operator=(const PRMFormAttribute< GUM_SCALAR >& source) {
251  GUM_ERROR(OperationNotAllowed, "Cannot copy FormAttribute")
252  }
253 
254  template < typename GUM_SCALAR >
255  void PRMFormAttribute< GUM_SCALAR >::_fillCpf_() const {
256  try {
257  if (_cpf_) { delete _cpf_; }
258 
259  _cpf_ = new Potential< GUM_SCALAR >();
260 
261  for (auto var: _formulas_->variablesSequence()) {
262  _cpf_->add(*var);
263  }
264 
265  auto params = _class_->scope();
266 
267  Instantiation inst(_formulas_);
268  Instantiation jnst(_cpf_);
269 
270  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
271  inst.inc(), jnst.inc()) {
272  // With CPT defined using rules, empty values can appear
273  auto val = _formulas_->get(inst);
274  if (val == "") { val = "0.0"; }
275 
276  Formula f(val);
277 
278  for (auto item: params) {
279  f.variables().insert(item.first, item.second->value());
280  }
281 
282  _cpf_->set(jnst, (GUM_SCALAR)f.result());
283  }
284 
285  GUM_ASSERT(inst.end() && jnst.end());
286 
287  } catch (Exception&) { GUM_ERROR(NotFound, "undefined value in cpt") }
288  GUM_ASSERT(_formulas_->contains(_type_->variable()))
289  }
290 
291  template < typename GUM_SCALAR >
292  MultiDimImplementation< std::string >& PRMFormAttribute< GUM_SCALAR >::formulas() {
293  if (_cpf_) {
294  delete _cpf_;
295  _cpf_ = 0;
296  }
297  return *_formulas_;
298  }
299 
300  template < typename GUM_SCALAR >
301  const MultiDimImplementation< std::string >& PRMFormAttribute< GUM_SCALAR >::formulas() const {
302  return *_formulas_;
303  }
304 
305  template < typename GUM_SCALAR >
306  void PRMFormAttribute< GUM_SCALAR >::swap(const PRMType& old_type, const PRMType& new_type) {
307  if (&(old_type) == _type_) {
308  GUM_ERROR(OperationNotAllowed, "Cannot replace attribute own type")
309  }
310  if (old_type->domainSize() != new_type->domainSize()) {
311  GUM_ERROR(OperationNotAllowed, "Cannot replace types with difference domain size")
312  }
313  if (!_formulas_->contains(old_type.variable())) {
314  GUM_ERROR(NotFound, "could not find variable " + old_type.name())
315  }
316 
317  auto old = _formulas_;
318 
319  _formulas_ = new MultiDimArray< std::string >();
320 
321  for (auto var: old->variablesSequence()) {
322  if (var != &(old_type.variable())) {
323  _formulas_->add(*var);
324  } else {
325  _formulas_->add(new_type.variable());
326  }
327  }
328 
329  Instantiation inst(_formulas_), jnst(old);
330 
331  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
332  _formulas_->set(inst, old->get(jnst));
333  }
334 
335  delete old;
336 
337  if (_cpf_) {
338  delete _cpf_;
339  _cpf_ = 0;
340  }
341 
342  GUM_ASSERT(inst.end() && jnst.end());
343  GUM_ASSERT(_formulas_->contains(_type_->variable()));
344  GUM_ASSERT(!_formulas_->contains(new_type.variable()));
345  GUM_ASSERT(_formulas_->contains(new_type.variable()));
346  }
347 
348  template < typename GUM_SCALAR >
349  PRMType* PRMFormAttribute< GUM_SCALAR >::type_() {
350  return _type_;
351  }
352 
353  template < typename GUM_SCALAR >
354  void PRMFormAttribute< GUM_SCALAR >::type_(PRMType* t) {
355  if (_type_->variable().domainSize() != t->variable().domainSize()) {
356  GUM_ERROR(OperationNotAllowed, "Cannot replace types with difference domain size")
357  }
358  auto old = _formulas_;
359 
360  _formulas_ = new MultiDimArray< std::string >();
361 
362  for (auto var: old->variablesSequence()) {
363  if (var != &(_type_->variable())) {
364  _formulas_->add(*var);
365  } else {
366  _formulas_->add(t->variable());
367  }
368  }
369 
370  Instantiation inst(_formulas_), jnst(old);
371 
372  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
373  _formulas_->set(inst, old->get(jnst));
374  }
375 
376  delete old;
377 
378  _type_ = t;
379 
380  if (_cpf_) {
381  delete _cpf_;
382  _cpf_ = 0;
383  }
384 
385  GUM_ASSERT(_formulas_->contains(_type_->variable()));
386  GUM_ASSERT(inst.end() && jnst.end());
387  }
388 
389  } /* namespace prm */
390 } /* namespace gum */