aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
PRMFormAttribute_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Inline implementation of gum::PRMFormAttribute
25  *
26  * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 #include <iostream>
29 
30 #include <agrum/tools/core/math/formula.h>
31 
32 #include <agrum/PRM/elements/PRMScalarAttribute.h>
33 #include <agrum/PRM/elements/PRMType.h>
34 
35 // to ease IDE parser
36 #include <agrum/PRM/elements/PRMFormAttribute.h>
37 
38 namespace gum {
39  namespace prm {
40 
41  template < typename GUM_SCALAR >
42  PRMFormAttribute< GUM_SCALAR >::PRMFormAttribute(
43  const PRMClass< GUM_SCALAR >& c,
44  const std::string& name,
45  const PRMType& type,
46  MultiDimImplementation< std::string >* impl) :
47  PRMAttribute< GUM_SCALAR >(name),
48  type__(new PRMType(type)), cpf__(0), formulas__(impl), class__(&c) {
49  GUM_CONSTRUCTOR(PRMFormAttribute);
50  formulas__->add(type__->variable());
51  this->safeName_ = PRMObject::LEFT_CAST() + type__->name()
52  + PRMObject::RIGHT_CAST() + name;
53  }
54 
55  template < typename GUM_SCALAR >
56  PRMFormAttribute< GUM_SCALAR >::~PRMFormAttribute() {
57  GUM_DESTRUCTOR(PRMFormAttribute);
58  delete type__;
59  delete cpf__;
60  delete formulas__;
61  }
62 
63  template < typename GUM_SCALAR >
64  PRMAttribute< GUM_SCALAR >* PRMFormAttribute< GUM_SCALAR >::newFactory(
65  const PRMClass< GUM_SCALAR >& c) const {
66  auto impl = static_cast< MultiDimImplementation< std::string >* >(
67  this->formulas__->newFactory());
68  return new PRMFormAttribute< GUM_SCALAR >(c,
69  this->name(),
70  this->type(),
71  impl);
72  }
73 
74  template < typename GUM_SCALAR >
75  PRMAttribute< GUM_SCALAR >* PRMFormAttribute< GUM_SCALAR >::copy(
76  Bijection< const DiscreteVariable*, const DiscreteVariable* > bij) const {
77  auto copy = new PRMFormAttribute< GUM_SCALAR >(*class__,
78  this->name(),
79  this->type());
80  for (auto var: formulas__->variablesSequence()) {
81  if (var != &(type__->variable())) { copy->formulas__->add(*var); }
82  }
83 
84  Instantiation inst(*(copy->formulas__)), jnst(*formulas__);
85  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
86  inst.inc(), jnst.inc()) {
87  copy->formulas__->set(inst, formulas__->get(jnst));
88  }
89 
90  GUM_ASSERT(copy->formulas__->contains(copy->type__->variable()));
91  return copy;
92  }
93 
94  template < typename GUM_SCALAR >
95  void PRMFormAttribute< GUM_SCALAR >::copyCpf(
96  const Bijection< const DiscreteVariable*, const DiscreteVariable* >& bij,
97  const PRMAttribute< GUM_SCALAR >& source) {
98  delete formulas__;
99  formulas__ = new MultiDimArray< std::string >();
100 
101  for (const auto& var: source.cpf().variablesSequence()) {
102  formulas__->add(*(bij.second(var)));
103  }
104 
105  if (dynamic_cast< const PRMFormAttribute< GUM_SCALAR >* >(&source)) {
106  const auto& src
107  = static_cast< const PRMFormAttribute< GUM_SCALAR >& >(source);
108 
109  Instantiation inst(formulas__), jnst(src.formulas__);
110 
111  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
112  inst.inc(), jnst.inc()) {
113  formulas__->set(inst, src.formulas__->get(jnst));
114  }
115 
116  GUM_ASSERT(inst.end() && jnst.end());
117 
118  } else {
119  Instantiation inst(formulas__), jnst(source.cpf());
120 
121  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
122  inst.inc(), jnst.inc()) {
123  auto val = std::to_string(source.cpf().get(jnst));
124  formulas__->set(inst, val);
125  }
126 
127  GUM_ASSERT(inst.end() && jnst.end());
128  }
129 
130  if (cpf__) {
131  delete cpf__;
132  cpf__ = 0;
133  }
134 
135  GUM_ASSERT(formulas__->contains(type__->variable()));
136  GUM_ASSERT(!formulas__->contains(source.type().variable()));
137  }
138 
139  template < typename GUM_SCALAR >
140  typename PRMClassElement< GUM_SCALAR >::ClassElementType
141  PRMFormAttribute< GUM_SCALAR >::elt_type() const {
142  return this->prm_attribute;
143  }
144 
145  template < typename GUM_SCALAR >
146  PRMType& PRMFormAttribute< GUM_SCALAR >::type() {
147  return *type__;
148  }
149 
150  template < typename GUM_SCALAR >
151  const PRMType& PRMFormAttribute< GUM_SCALAR >::type() const {
152  return *type__;
153  }
154 
155  template < typename GUM_SCALAR >
156  const Potential< GUM_SCALAR >& PRMFormAttribute< GUM_SCALAR >::cpf() const {
157  if (cpf__ == 0) { fillCpf__(); }
158  return *cpf__;
159  }
160 
161  template < typename GUM_SCALAR >
162  void PRMFormAttribute< GUM_SCALAR >::addParent(
163  const PRMClassElement< GUM_SCALAR >& elt) {
164  try {
165  if (cpf__) {
166  delete cpf__;
167  cpf__ = 0;
168  }
169  formulas__->add(elt.type().variable());
170  } catch (DuplicateElement&) {
171  GUM_ERROR(DuplicateElement,
172  elt.name() << " as parent of " << this->name());
173  } catch (OperationNotAllowed&) {
174  GUM_ERROR(OperationNotAllowed,
175  elt.name() << " of wrong type as parent of " << this->name(););
176  }
177 
178  GUM_ASSERT(formulas__->contains(type__->variable()));
179  }
180 
181  template < typename GUM_SCALAR >
182  void PRMFormAttribute< GUM_SCALAR >::addChild(
183  const PRMClassElement< GUM_SCALAR >& elt) {}
184 
185  template < typename GUM_SCALAR >
186  PRMAttribute< GUM_SCALAR >*
187  PRMFormAttribute< GUM_SCALAR >::getCastDescendant() const {
188  PRMScalarAttribute< GUM_SCALAR >* cast = 0;
189 
190  try {
191  cast = new PRMScalarAttribute< GUM_SCALAR >(this->name(),
192  type().superType());
193  } catch (NotFound&) {
194  GUM_ERROR(OperationNotAllowed,
195  "this ScalarAttribute can not have cast descendant");
196  }
197 
198  cast->addParent(*this);
199 
200  const DiscreteVariable& my_var = type().variable();
201  DiscreteVariable& cast_var = cast->type().variable();
202  Instantiation inst(cast->cpf());
203 
204  for (inst.setFirst(); !inst.end(); inst.inc()) {
205  if (type().label_map()[inst.val(my_var)] == inst.val(cast_var)) {
206  cast->cpf().set(inst, 1);
207  } else {
208  cast->cpf().set(inst, 0);
209  }
210  }
211 
212  GUM_ASSERT(formulas__->contains(type__->variable()));
213  return cast;
214  }
215 
216  template < typename GUM_SCALAR >
217  void PRMFormAttribute< GUM_SCALAR >::setAsCastDescendant(
218  PRMAttribute< GUM_SCALAR >* cast) {
219  try {
220  type().setSuper(cast->type());
221  } catch (OperationNotAllowed&) {
222  GUM_ERROR(OperationNotAllowed,
223  "this ScalarAttribute can not have cast descendant");
224  } catch (TypeError&) {
225  std::stringstream msg;
226  msg << type().name() << " is not a subtype of " << cast->type().name();
227  GUM_ERROR(TypeError, msg.str());
228  }
229 
230  cast->becomeCastDescendant(type());
231  }
232 
233  template < typename GUM_SCALAR >
234  void PRMFormAttribute< GUM_SCALAR >::becomeCastDescendant(PRMType& subtype) {
235  delete formulas__;
236 
237  formulas__ = new MultiDimArray< std::string >();
238  formulas__->add(type().variable());
239  formulas__->add(subtype.variable());
240 
241  Instantiation inst(formulas__);
242 
243  for (inst.setFirst(); !inst.end(); inst.inc()) {
244  auto my_pos = inst.pos(subtype.variable());
245  if (subtype.label_map()[my_pos] == inst.pos(type().variable())) {
246  formulas__->set(inst, "1");
247  } else {
248  formulas__->set(inst, "0");
249  }
250  }
251 
252  if (cpf__) {
253  delete cpf__;
254  cpf__ = nullptr;
255  }
256  }
257 
258  template < typename GUM_SCALAR >
259  PRMFormAttribute< GUM_SCALAR >::PRMFormAttribute(
260  const PRMFormAttribute& source) :
261  PRMAttribute< GUM_SCALAR >(source.name()) {
262  GUM_CONS_CPY(PRMFormAttribute);
263  GUM_ERROR(OperationNotAllowed, "Cannot copy FormAttribute");
264  }
265 
266  template < typename GUM_SCALAR >
267  PRMFormAttribute< GUM_SCALAR >& PRMFormAttribute< GUM_SCALAR >::operator=(
268  const PRMFormAttribute< GUM_SCALAR >& source) {
269  GUM_ERROR(OperationNotAllowed, "Cannot copy FormAttribute");
270  }
271 
272  template < typename GUM_SCALAR >
273  void PRMFormAttribute< GUM_SCALAR >::fillCpf__() const {
274  try {
275  if (cpf__) { delete cpf__; }
276 
277  cpf__ = new Potential< GUM_SCALAR >();
278 
279  for (auto var: formulas__->variablesSequence()) {
280  cpf__->add(*var);
281  }
282 
283  auto params = class__->scope();
284 
285  Instantiation inst(formulas__);
286  Instantiation jnst(cpf__);
287 
288  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
289  inst.inc(), jnst.inc()) {
290  // With CPT defined using rules, empty values can appear
291  auto val = formulas__->get(inst);
292  if (val == "") { val = "0.0"; }
293 
294  Formula f(val);
295 
296  for (auto item: params) {
297  f.variables().insert(item.first, item.second->value());
298  }
299 
300  cpf__->set(jnst, (GUM_SCALAR)f.result());
301  }
302 
303  GUM_ASSERT(inst.end() && jnst.end());
304 
305  } catch (Exception&) { GUM_ERROR(NotFound, "undefined value in cpt"); }
306  GUM_ASSERT(formulas__->contains(type__->variable()));
307  }
308 
309  template < typename GUM_SCALAR >
310  MultiDimImplementation< std::string >&
311  PRMFormAttribute< GUM_SCALAR >::formulas() {
312  if (cpf__) {
313  delete cpf__;
314  cpf__ = 0;
315  }
316  return *formulas__;
317  }
318 
319  template < typename GUM_SCALAR >
320  const MultiDimImplementation< std::string >&
321  PRMFormAttribute< GUM_SCALAR >::formulas() const {
322  return *formulas__;
323  }
324 
325  template < typename GUM_SCALAR >
326  void PRMFormAttribute< GUM_SCALAR >::swap(const PRMType& old_type,
327  const PRMType& new_type) {
328  if (&(old_type) == type__) {
329  GUM_ERROR(OperationNotAllowed, "Cannot replace attribute own type");
330  }
331  if (old_type->domainSize() != new_type->domainSize()) {
332  GUM_ERROR(OperationNotAllowed,
333  "Cannot replace types with difference domain size");
334  }
335  if (!formulas__->contains(old_type.variable())) {
336  GUM_ERROR(NotFound, "could not find variable " + old_type.name());
337  }
338 
339  auto old = formulas__;
340 
341  formulas__ = new MultiDimArray< std::string >();
342 
343  for (auto var: old->variablesSequence()) {
344  if (var != &(old_type.variable())) {
345  formulas__->add(*var);
346  } else {
347  formulas__->add(new_type.variable());
348  }
349  }
350 
351  Instantiation inst(formulas__), jnst(old);
352 
353  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
354  inst.inc(), jnst.inc()) {
355  formulas__->set(inst, old->get(jnst));
356  }
357 
358  delete old;
359 
360  if (cpf__) {
361  delete cpf__;
362  cpf__ = 0;
363  }
364 
365  GUM_ASSERT(inst.end() && jnst.end());
366  GUM_ASSERT(formulas__->contains(type__->variable()));
367  GUM_ASSERT(!formulas__->contains(new_type.variable()));
368  GUM_ASSERT(formulas__->contains(new_type.variable()));
369  }
370 
371  template < typename GUM_SCALAR >
372  PRMType* PRMFormAttribute< GUM_SCALAR >::type_() {
373  return type__;
374  }
375 
376  template < typename GUM_SCALAR >
377  void PRMFormAttribute< GUM_SCALAR >::type_(PRMType* t) {
378  if (type__->variable().domainSize() != t->variable().domainSize()) {
379  GUM_ERROR(OperationNotAllowed,
380  "Cannot replace types with difference domain size");
381  }
382  auto old = formulas__;
383 
384  formulas__ = new MultiDimArray< std::string >();
385 
386  for (auto var: old->variablesSequence()) {
387  if (var != &(type__->variable())) {
388  formulas__->add(*var);
389  } else {
390  formulas__->add(t->variable());
391  }
392  }
393 
394  Instantiation inst(formulas__), jnst(old);
395 
396  for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
397  inst.inc(), jnst.inc()) {
398  formulas__->set(inst, old->get(jnst));
399  }
400 
401  delete old;
402 
403  type__ = t;
404 
405  if (cpf__) {
406  delete cpf__;
407  cpf__ = 0;
408  }
409 
410  GUM_ASSERT(formulas__->contains(type__->variable()));
411  GUM_ASSERT(inst.end() && jnst.end());
412  }
413 
414  } /* namespace prm */
415 } /* namespace gum */