31 #include <agrum/PRM/utils_prm.h> 33 #include <agrum/PRM/elements/PRMAttribute.h> 34 #include <agrum/PRM/elements/PRMType.h> 39 template <
typename GUM_SCALAR >
// Main constructor: builds a scalar attribute called `name`.
// The PRMAttribute base receives `name`, _type_ becomes a heap-allocated
// PRMType built from `type`, and _cpf_ a heap-allocated Potential built
// from `impl`.
// NOTE(review): the declaration of the `type` parameter is not visible in
// this excerpt (it is only seen where it is used below) -- confirm its type.
40 PRMScalarAttribute< GUM_SCALAR >::PRMScalarAttribute(
41 const std::string& name,
43 MultiDimImplementation< GUM_SCALAR >* impl) :
44 PRMAttribute< GUM_SCALAR >(name),
45 _type_(
new PRMType(type)), _cpf_(
new Potential< GUM_SCALAR >(impl)) {
46 GUM_CONSTRUCTOR(PRMScalarAttribute);
// The attribute's own type variable is added to its CPF.
47 _cpf_->add(_type_->variable());
// Safe name = LEFT_CAST + type name + RIGHT_CAST + attribute name.
49 this->safeName_ = PRMObject::LEFT_CAST() + _type_->name() + PRMObject::RIGHT_CAST() + name;
// Copy constructor: copying is not supported.
// The PRMAttribute base is initialised from `source`, then GUM_ERROR is
// invoked with FatalError (presumably this throws -- the macro definition
// is not visible here).
// NOTE(review): the remaining member initialisers and the opening brace are
// not visible in this excerpt.
52 template <
typename GUM_SCALAR >
53 PRMScalarAttribute< GUM_SCALAR >::PRMScalarAttribute(
54 const PRMScalarAttribute< GUM_SCALAR >& source) :
55 PRMAttribute< GUM_SCALAR >(source),
57 GUM_CONS_CPY(PRMScalarAttribute);
58 GUM_ERROR(FatalError,
"Illegal call to the copy constructor of gum::ScalarAttribute")
// Destructor.
// NOTE(review): only the GUM_DESTRUCTOR call is visible in this excerpt.
// The constructor allocates _type_ and _cpf_ with `new`; whether they are
// released here cannot be told from the visible lines -- confirm.
61 template <
typename GUM_SCALAR >
62 PRMScalarAttribute< GUM_SCALAR >::~PRMScalarAttribute() {
63 GUM_DESTRUCTOR(PRMScalarAttribute);
// Creates a new PRMScalarAttribute with the same name and type as this one.
// The new attribute is given a fresh implementation obtained by calling
// newFactory() on the content of this attribute's CPF.
// The parameter `c` is not used in the visible body.
// Returns a pointer to an object allocated with `new`.
68 template <
typename GUM_SCALAR >
69 PRMAttribute< GUM_SCALAR >*
70 PRMScalarAttribute< GUM_SCALAR >::newFactory(
const PRMClass< GUM_SCALAR >& c)
const {
// The result of content()->newFactory() is static_cast to the implementation
// pointer type expected by the three-argument constructor.
71 auto impl =
static_cast< MultiDimImplementation< GUM_SCALAR >* >(
72 this->cpf().content()->newFactory());
73 return new PRMScalarAttribute< GUM_SCALAR >(
this->name(),
this->type(), impl);
// Builds a copy of this attribute using the variable mapping `bij`.
// `bij` is taken by value, so the insertion below only affects this call's
// local bijection.
// NOTE(review): the return statement is not visible in this excerpt
// (presumably `copy` is returned) -- confirm.
76 template <
typename GUM_SCALAR >
77 PRMAttribute< GUM_SCALAR >* PRMScalarAttribute< GUM_SCALAR >::copy(
78 Bijection<
const DiscreteVariable*,
const DiscreteVariable* > bij)
const {
// New attribute with the same name and type, allocated with `new`.
79 auto copy =
new PRMScalarAttribute< GUM_SCALAR >(
this->name(),
this->type());
// If this attribute's type variable has no image in `bij` yet, map it to
// the type variable of the new attribute.
81 if (!bij.existsFirst(&(type().variable()))) {
82 bij.insert(&(type().variable()), &(copy->type().variable()));
// The copy's CPF pointer is replaced by the result of copyPotential(bij, cpf()).
// NOTE(review): whether the CPF created by the constructor above is released
// before this assignment is not visible in this excerpt -- confirm.
86 copy->_cpf_ = copyPotential(bij, cpf());
// Rebuilds this attribute's CPF from the CPF of `source`.
// Each variable of source.cpf() is replaced by its image bij.second(var),
// and the values are then copied entry by entry.
// NOTE(review): _cpf_ is reassigned below; whether the previous Potential is
// released first is not visible in this excerpt -- confirm.
91 template <
typename GUM_SCALAR >
92 void PRMScalarAttribute< GUM_SCALAR >::copyCpf(
93 const Bijection<
const DiscreteVariable*,
const DiscreteVariable* >& bij,
94 const PRMAttribute< GUM_SCALAR >& source) {
96 _cpf_ =
new Potential< GUM_SCALAR >();
// Add the mapped variables in the order given by the source's variable sequence.
98 for (
auto var: source.cpf().variablesSequence()) {
99 _cpf_->add(*(bij.second(var)));
// inst walks the new CPF, jnst walks the source CPF.
102 Instantiation inst(*_cpf_), jnst(source.cpf());
// Both instantiations advance together; the loop stops as soon as either ends.
104 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
105 _cpf_->set(inst, source.cpf().get(jnst));
// Checks: both walks ended together, the new CPF holds this attribute's own
// type variable, and it does not hold the source's type variable.
108 GUM_ASSERT(inst.end() && jnst.end());
109 GUM_ASSERT(_cpf_->contains(_type_->variable()));
110 GUM_ASSERT(!_cpf_->contains(source.type().variable()));
// Copy assignment: not supported.
// The visible body only invokes GUM_ERROR with FatalError; `from` is not
// used. (Presumably GUM_ERROR throws -- macro definition not visible here.)
113 template <
typename GUM_SCALAR >
114 PRMScalarAttribute< GUM_SCALAR >&
115 PRMScalarAttribute< GUM_SCALAR >::operator=(
const PRMScalarAttribute< GUM_SCALAR >& from) {
116 GUM_ERROR(FatalError,
"Illegal call to the copy operator of gum::ScalarAttribute")
// Reports the class-element kind of this object: always returns
// this->prm_attribute.
119 template <
typename GUM_SCALAR >
120 INLINE
typename PRMClassElement< GUM_SCALAR >::ClassElementType
121 PRMScalarAttribute< GUM_SCALAR >::elt_type()
const {
122 return this->prm_attribute;
// Non-const accessor declared to return a PRMType&.
// NOTE(review): the body is not visible in this excerpt; presumably it
// returns *_type_ -- confirm.
125 template <
typename GUM_SCALAR >
126 INLINE PRMType& PRMScalarAttribute< GUM_SCALAR >::type() {
// Const accessor declared to return a const PRMType&.
// NOTE(review): the body is not visible in this excerpt; presumably it
// returns *_type_ -- confirm.
130 template <
typename GUM_SCALAR >
131 INLINE
const PRMType& PRMScalarAttribute< GUM_SCALAR >::type()
const {
// Const accessor declared to return a const Potential< GUM_SCALAR >&.
// NOTE(review): the body is not visible in this excerpt; presumably it
// returns *_cpf_ -- confirm.
135 template <
typename GUM_SCALAR >
136 INLINE
const Potential< GUM_SCALAR >& PRMScalarAttribute< GUM_SCALAR >::cpf()
const {
// Registers `elt` as a parent of this attribute by adding elt's type
// variable to _cpf_.
// DuplicateElement and OperationNotAllowed are caught and re-reported through
// GUM_ERROR with the same exception type and a message naming both elements.
// NOTE(review): the return-type line and the opening `try {` are not visible
// in this excerpt; presumably the add() call below is inside that try block.
140 template <
typename GUM_SCALAR >
142 PRMScalarAttribute< GUM_SCALAR >::addParent(
const PRMClassElement< GUM_SCALAR >& elt) {
144 _cpf_->add(elt.type().variable());
145 }
catch (DuplicateElement&) {
// Message: "<elt name> as parent of <this name>".
146 GUM_ERROR(DuplicateElement, elt.name() <<
" as parent of " <<
this->name())
147 }
catch (OperationNotAllowed&) {
// Message: "<elt name> of wrong type as parent of <this name>".
148 GUM_ERROR(OperationNotAllowed, elt.name() <<
" of wrong type as parent of " <<
this->name())
// Intentionally does nothing: the body is empty and `elt` is unused.
// NOTE(review): the return-type line is not visible in this excerpt.
153 template <
typename GUM_SCALAR >
155 PRMScalarAttribute< GUM_SCALAR >::addChild(
const PRMClassElement< GUM_SCALAR >& elt) {}
// Builds an attribute of this attribute's super type that has this attribute
// as its parent, and fills its CPF with 0/1 values derived from the type's
// label map.
// NOTE(review): the opening `try {`, the `else` between the two set() calls
// and the return statement are not visible in this excerpt; presumably
// `cast` is returned -- confirm.
157 template <
typename GUM_SCALAR >
158 PRMAttribute< GUM_SCALAR >* PRMScalarAttribute< GUM_SCALAR >::getCastDescendant()
const {
159 PRMScalarAttribute< GUM_SCALAR >* cast = 0;
// Same name as this attribute, typed with type().superType().
162 cast =
new PRMScalarAttribute< GUM_SCALAR >(
this->name(), type().superType());
163 }
catch (NotFound&) {
// A NotFound caught here is re-reported as OperationNotAllowed.
164 GUM_ERROR(OperationNotAllowed,
"this ScalarAttribute can not have cast descendant")
// This attribute becomes a parent of the cast attribute, so the cast CPF
// ranges over both variables.
167 cast->addParent(*
this);
168 const DiscreteVariable& my_var = type().variable();
169 DiscreteVariable& cast_var = cast->type().variable();
170 Instantiation inst(cast->cpf());
// Entry is 1 when the label map sends this attribute's value to the cast
// variable's value; the other set() call writes 0.
172 for (inst.setFirst(); !inst.end(); inst.inc()) {
173 if (type().label_map()[inst.val(my_var)] == inst.val(cast_var)) {
174 cast->cpf().set(inst, 1);
176 cast->cpf().set(inst, 0);
// Declares `cast` as the cast descendant of this attribute: sets this
// attribute's type super type to cast's type, then asks `cast` to rebuild
// itself through becomeCastDescendant(type()).
// OperationNotAllowed and TypeError are caught and re-reported through
// GUM_ERROR with the same exception type and a specific message.
// NOTE(review): the opening `try {` is not visible in this excerpt;
// presumably the setSuper() call below is inside that try block.
183 template <
typename GUM_SCALAR >
184 void PRMScalarAttribute< GUM_SCALAR >::setAsCastDescendant(PRMAttribute< GUM_SCALAR >* cast) {
186 type().setSuper(cast->type());
187 }
catch (OperationNotAllowed&) {
188 GUM_ERROR(OperationNotAllowed,
"this ScalarAttribute can not have cast descendant")
189 }
catch (TypeError&) {
// Message: "<this type name> is not a subtype of <cast type name>".
190 std::stringstream msg;
191 msg << type().name() <<
" is not a subtype of " << cast->type().name();
192 GUM_ERROR(TypeError, msg.str())
194 cast->becomeCastDescendant(type());
// Replaces this attribute's CPF with a new Potential over two variables:
// this attribute's own type variable, then the variable of `subtype`.
// NOTE(review): _cpf_ is reassigned below; whether the previous Potential is
// released first is not visible in this excerpt. The bodies of the if
// branches inside the loop are not visible either -- confirm both.
197 template <
typename GUM_SCALAR >
198 void PRMScalarAttribute< GUM_SCALAR >::becomeCastDescendant(PRMType& subtype) {
200 _cpf_ =
new Potential< GUM_SCALAR >();
201 _cpf_->add(type().variable());
202 _cpf_->add(subtype.variable());
204 Instantiation inst(*_cpf_);
// For every entry, compare the label map image of the subtype's position
// with the position of this attribute's own variable.
206 for (inst.setFirst(); !inst.end(); inst.inc()) {
207 auto my_pos = inst.pos(subtype.variable());
208 if (subtype.label_map()[my_pos] == inst.pos(type().variable())) {
// Replaces the variable of `old_type` by the variable of `new_type` in this
// attribute's CPF, rebuilding the Potential and copying its values.
// Reports through GUM_ERROR:
//   - OperationNotAllowed if `old_type` is this attribute's own type,
//   - OperationNotAllowed if the two types have different domain sizes,
//   - NotFound if the CPF does not contain old_type's variable.
// NOTE(review): the declaration of `old`, the branch that re-adds the other
// variables, and any release of `old` are not visible in this excerpt;
// presumably `old` holds the previous _cpf_ pointer -- confirm.
216 template <
typename GUM_SCALAR >
217 void PRMScalarAttribute< GUM_SCALAR >::swap(
const PRMType& old_type,
const PRMType& new_type) {
// Identity test: compares the address of old_type with the _type_ pointer.
218 if (&(old_type) == _type_) {
219 GUM_ERROR(OperationNotAllowed,
"Cannot replace attribute own type")
// `->` is applied to PRMType references here, so PRMType is expected to
// provide operator-> (its declaration is not visible in this file).
221 if (old_type->domainSize() != new_type->domainSize()) {
222 GUM_ERROR(OperationNotAllowed,
"Cannot replace types with difference domain size")
224 if (!_cpf_->contains(old_type.variable())) {
225 GUM_ERROR(NotFound,
"could not find variable " + old_type.name())
230 _cpf_ =
new Potential< GUM_SCALAR >();
// Walk the variables of the previous CPF in order; the test distinguishes
// old_type's variable from all the others.
232 for (
auto var: old->variablesSequence()) {
233 if (var != &(old_type.variable())) {
236 _cpf_->add(new_type.variable());
// inst walks the new CPF, jnst walks the previous one.
240 Instantiation inst(_cpf_), jnst(old);
// Both instantiations advance together; the loop stops as soon as either ends.
242 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
243 _cpf_->set(inst, old->get(jnst));
// Checks: both walks ended together; the new CPF contains the attribute's
// own type variable and new_type's variable, and no longer old_type's.
248 GUM_ASSERT(inst.end() && jnst.end());
249 GUM_ASSERT(_cpf_->contains(_type_->variable()));
250 GUM_ASSERT(_cpf_->contains(new_type.variable()));
251 GUM_ASSERT(!_cpf_->contains(old_type.variable()));
// Accessor declared to return a PRMType*.
// NOTE(review): the body is not visible in this excerpt; presumably it
// returns the _type_ pointer -- confirm.
254 template <
typename GUM_SCALAR >
255 PRMType* PRMScalarAttribute< GUM_SCALAR >::type_() {
259 template <
typename GUM_SCALAR >
260 void PRMScalarAttribute< GUM_SCALAR >::type_(PRMType* t) {
261 if (_type_->variable().domainSize() != t->variable().domainSize()) {
262 GUM_ERROR(OperationNotAllowed,
"Cannot replace types with difference domain size")
266 _cpf_ =
new Potential< GUM_SCALAR >();
268 for (
auto var: old->variablesSequence()) {
269 if (var != &(_type_->variable())) {
272 _cpf_->add(t->variable());
276 Instantiation inst(_cpf_), jnst(old);
278 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end()); inst.inc(), jnst.inc()) {
279 _cpf_->set(inst, old->get(jnst));
286 GUM_ASSERT(_cpf_->contains(_type_->variable()));
287 GUM_ASSERT(inst.end() && jnst.end());