// Header includes, followed on the same listing line by the start of the main
// PRMScalarAttribute constructor.
//
// Constructor: forwards `name` to the PRMAttribute base, stores a newly
// allocated PRMType copy-constructed from `type` in type__, and wraps `impl`
// in a newly allocated Potential stored in cpf__.
// NOTE(review): `type` is used in the initializer list but its declaration is
// not visible in this listing (presumably a parameter between `name` and
// `impl`) — confirm against the header.
31 #include <agrum/PRM/utils_prm.h> 33 #include <agrum/PRM/elements/PRMAttribute.h> 34 #include <agrum/PRM/elements/PRMType.h> 39 template <
typename GUM_SCALAR >
40 PRMScalarAttribute< GUM_SCALAR >::PRMScalarAttribute(
41 const std::string& name,
43 MultiDimImplementation< GUM_SCALAR >* impl) :
44 PRMAttribute< GUM_SCALAR >(name),
45 type__(
new PRMType(type)), cpf__(
new Potential< GUM_SCALAR >(impl)) {
46 GUM_CONSTRUCTOR(PRMScalarAttribute);
// The attribute's own variable is the first variable added to its potential.
47 cpf__->add(type__->variable());
// The safe name is the type name wrapped in the cast markers, followed by the
// attribute name.
49 this->safeName_ = PRMObject::LEFT_CAST() + type__->name()
50 + PRMObject::RIGHT_CAST() + name;
// Copy constructor. The message text below states that calling it is illegal.
// NOTE(review): the rest of the initializer list and the statement that uses
// the message string are not visible in this listing — confirm how the error
// is reported.
53 template <
typename GUM_SCALAR >
54 PRMScalarAttribute< GUM_SCALAR >::PRMScalarAttribute(
55 const PRMScalarAttribute< GUM_SCALAR >& source) :
56 PRMAttribute< GUM_SCALAR >(source),
58 GUM_CONS_CPY(PRMScalarAttribute);
60 "Illegal call to the copy constructor of gum::ScalarAttribute");
// Destructor: invokes the GUM_DESTRUCTOR macro for this class.
// NOTE(review): no release of type__ or cpf__ (both allocated with `new` in
// the main constructor) is visible in this listing — confirm they are freed.
63 template <
typename GUM_SCALAR >
64 PRMScalarAttribute< GUM_SCALAR >::~PRMScalarAttribute() {
65 GUM_DESTRUCTOR(PRMScalarAttribute);
// Builds a new PRMScalarAttribute carrying this attribute's name.
// A fresh implementation object is obtained from the content of the current
// cpf via newFactory() and cast to MultiDimImplementation< GUM_SCALAR >*.
// NOTE(review): the constructor arguments after `this->name()` and any use of
// the parameter `c` are not visible in this listing.
70 template <
typename GUM_SCALAR >
71 PRMAttribute< GUM_SCALAR >* PRMScalarAttribute< GUM_SCALAR >::newFactory(
72 const PRMClass< GUM_SCALAR >& c)
const {
73 auto impl =
static_cast< MultiDimImplementation< GUM_SCALAR >* >(
74 this->cpf().content()->newFactory());
75 return new PRMScalarAttribute< GUM_SCALAR >(
this->name(),
// Creates a new attribute with this attribute's name and type, then gives it
// a cpf produced by copyPotential(bij, cpf()).
// `bij` is taken by value, so the insertion below only affects the local copy.
80 template <
typename GUM_SCALAR >
81 PRMAttribute< GUM_SCALAR >* PRMScalarAttribute< GUM_SCALAR >::copy(
82 Bijection<
const DiscreteVariable*,
const DiscreteVariable* > bij)
const {
83 auto copy =
new PRMScalarAttribute< GUM_SCALAR >(
this->name(),
this->type());
// Register the mapping from this attribute's variable to the copy's variable
// unless the bijection already has an entry for it.
85 if (!bij.existsFirst(&(type().variable()))) {
86 bij.insert(&(type().variable()), &(copy->type().variable()));
// NOTE(review): the lines between the insert and this assignment are not
// visible in this listing — confirm the Potential allocated by the
// constructor of `copy` is released before cpf__ is overwritten. The return
// statement is also not visible.
90 copy->cpf__ = copyPotential(bij, cpf());
// Rebuilds cpf__ from the cpf of `source`, translating every variable through
// `bij` (the image returned by bij.second(var) is added to the new potential).
// NOTE(review): what happens to the previous cpf__ is not visible in this
// listing — confirm it is released before being replaced.
95 template <
typename GUM_SCALAR >
96 void PRMScalarAttribute< GUM_SCALAR >::copyCpf(
97 const Bijection<
const DiscreteVariable*,
const DiscreteVariable* >& bij,
98 const PRMAttribute< GUM_SCALAR >& source) {
100 cpf__ =
new Potential< GUM_SCALAR >();
// Add the mapped variables in the same order as in the source cpf.
102 for (
auto var: source.cpf().variablesSequence()) {
103 cpf__->add(*(bij.second(var)));
// Walk both potentials in lockstep and copy each value across.
106 Instantiation inst(*cpf__), jnst(source.cpf());
108 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
109 inst.inc(), jnst.inc()) {
110 cpf__->set(inst, source.cpf().get(jnst));
// Sanity checks: both walks finished together, the new cpf holds this
// attribute's own variable and does not hold the source's type variable.
113 GUM_ASSERT(inst.end() && jnst.end());
114 GUM_ASSERT(cpf__->contains(type__->variable()));
115 GUM_ASSERT(!cpf__->contains(source.type().variable()));
// Copy assignment: the only visible statement reports a FatalError through
// GUM_ERROR, with a message stating that the call is illegal.
118 template <
typename GUM_SCALAR >
119 PRMScalarAttribute< GUM_SCALAR >& PRMScalarAttribute< GUM_SCALAR >::operator=(
120 const PRMScalarAttribute< GUM_SCALAR >& from) {
121 GUM_ERROR(FatalError,
122 "Illegal call to the copy operator of gum::ScalarAttribute");
// Returns the class-element kind of this object: always `prm_attribute`.
125 template <
typename GUM_SCALAR >
126 INLINE
typename PRMClassElement< GUM_SCALAR >::ClassElementType
127 PRMScalarAttribute< GUM_SCALAR >::elt_type()
const {
128 return this->prm_attribute;
// Non-const accessor returning a PRMType reference.
// NOTE(review): the body is not visible in this listing; presumably it
// returns the object pointed to by type__ — confirm.
131 template <
typename GUM_SCALAR >
132 INLINE PRMType& PRMScalarAttribute< GUM_SCALAR >::type() {
// Const accessor returning a const PRMType reference.
// NOTE(review): the body is not visible in this listing; presumably it
// returns the object pointed to by type__ — confirm.
136 template <
typename GUM_SCALAR >
137 INLINE
const PRMType& PRMScalarAttribute< GUM_SCALAR >::type()
const {
// Const accessor returning a const reference to a Potential< GUM_SCALAR >.
// NOTE(review): the body is not visible in this listing; presumably it
// returns the object pointed to by cpf__ — confirm.
141 template <
typename GUM_SCALAR >
142 INLINE
const Potential< GUM_SCALAR >&
143 PRMScalarAttribute< GUM_SCALAR >::cpf()
const {
// Adds the variable of `elt`'s type to this attribute's cpf.
// DuplicateElement and OperationNotAllowed raised while adding are caught and
// re-reported with a message naming both `elt` and this attribute.
// NOTE(review): the opening `try {` is not visible in this listing.
147 template <
typename GUM_SCALAR >
148 INLINE
void PRMScalarAttribute< GUM_SCALAR >::addParent(
149 const PRMClassElement< GUM_SCALAR >& elt) {
151 cpf__->add(elt.type().variable());
152 }
catch (DuplicateElement&) {
// Same error type, with a message identifying the parent and the child.
153 GUM_ERROR(DuplicateElement,
154 elt.name() <<
" as parent of " <<
this->name());
155 }
catch (OperationNotAllowed&) {
// Same error type, with a message flagging the parent as being of wrong type.
156 GUM_ERROR(OperationNotAllowed,
157 elt.name() <<
" of wrong type as parent of " <<
this->name());
// No-op: the body is empty, so `elt` is ignored.
162 template <
typename GUM_SCALAR >
163 INLINE
void PRMScalarAttribute< GUM_SCALAR >::addChild(
164 const PRMClassElement< GUM_SCALAR >& elt) {}
// Builds a new PRMScalarAttribute named like this one, makes this attribute
// its parent, and fills its cpf with 0/1 values derived from this type's
// label_map().
// NOTE(review): the `try {` opener, the constructor arguments after
// `this->name()`, the `else` between the two set() calls and the return
// statement are not visible in this listing.
166 template <
typename GUM_SCALAR >
167 PRMAttribute< GUM_SCALAR >*
168 PRMScalarAttribute< GUM_SCALAR >::getCastDescendant()
const {
169 PRMScalarAttribute< GUM_SCALAR >* cast = 0;
172 cast =
new PRMScalarAttribute< GUM_SCALAR >(
this->name(),
174 }
catch (NotFound&) {
// A NotFound raised while building the descendant is reported as
// OperationNotAllowed.
175 GUM_ERROR(OperationNotAllowed,
176 "this ScalarAttribute can not have cast descendant");
179 cast->addParent(*
this);
180 const DiscreteVariable& my_var = type().variable();
181 DiscreteVariable& cast_var = cast->type().variable();
182 Instantiation inst(cast->cpf());
// For every joint assignment: 1 when the label_map() entry for this
// attribute's value equals the descendant's value; the 0 assignment is
// presumably the else branch.
184 for (inst.setFirst(); !inst.end(); inst.inc()) {
185 if (type().label_map()[inst.val(my_var)] == inst.val(cast_var)) {
186 cast->cpf().set(inst, 1);
188 cast->cpf().set(inst, 0);
// Sets the type of `cast` as the super type of this attribute's type, then
// asks `cast` to become the cast descendant for this attribute's type.
// OperationNotAllowed and TypeError raised by setSuper() are caught and
// re-reported with the messages below.
// NOTE(review): the opening `try {` is not visible in this listing.
195 template <
typename GUM_SCALAR >
196 void PRMScalarAttribute< GUM_SCALAR >::setAsCastDescendant(
197 PRMAttribute< GUM_SCALAR >* cast) {
199 type().setSuper(cast->type());
200 }
catch (OperationNotAllowed&) {
201 GUM_ERROR(OperationNotAllowed,
202 "this ScalarAttribute can not have cast descendant");
203 }
catch (TypeError&) {
// Build a message naming both types before reporting the TypeError.
204 std::stringstream msg;
205 msg << type().name() <<
" is not a subtype of " << cast->type().name();
206 GUM_ERROR(TypeError, msg.str());
208 cast->becomeCastDescendant(type());
// Replaces cpf__ with a fresh potential over this attribute's variable and
// the variable of `subtype`, added in that order, then iterates over every
// joint assignment.
// NOTE(review): what happens to the previous cpf__ and the values written
// inside the loop are not visible in this listing.
211 template <
typename GUM_SCALAR >
212 void PRMScalarAttribute< GUM_SCALAR >::becomeCastDescendant(PRMType& subtype) {
214 cpf__ =
new Potential< GUM_SCALAR >();
215 cpf__->add(type().variable());
216 cpf__->add(subtype.variable());
218 Instantiation inst(*cpf__);
220 for (inst.setFirst(); !inst.end(); inst.inc()) {
// Compare the label_map() entry for the subtype's position with the position
// of this attribute's own variable.
221 auto my_pos = inst.pos(subtype.variable());
222 if (subtype.label_map()[my_pos] == inst.pos(type().variable())) {
// Replaces the variable of `old_type` by the variable of `new_type` in this
// attribute's cpf, copying the values across.
// Reports OperationNotAllowed when `old_type` is this attribute's own type or
// when the two domain sizes differ, and NotFound when the cpf does not
// contain `old_type`'s variable.
// NOTE(review): the declaration of `old` (presumably the previous cpf__) and
// the branch that re-adds the other variables are not visible in this
// listing.
230 template <
typename GUM_SCALAR >
231 void PRMScalarAttribute< GUM_SCALAR >::swap(
const PRMType& old_type,
232 const PRMType& new_type) {
// Identity check by address against the attribute's own type object.
233 if (&(old_type) == type__) {
234 GUM_ERROR(OperationNotAllowed,
"Cannot replace attribute own type");
236 if (old_type->domainSize() != new_type->domainSize()) {
237 GUM_ERROR(OperationNotAllowed,
238 "Cannot replace types with difference domain size");
240 if (!cpf__->contains(old_type.variable())) {
241 GUM_ERROR(NotFound,
"could not find variable " + old_type.name());
246 cpf__ =
new Potential< GUM_SCALAR >();
// Rebuild the variable list from `old`, comparing each variable by address
// with `old_type`'s variable; `new_type`'s variable is added below.
248 for (
auto var: old->variablesSequence()) {
249 if (var != &(old_type.variable())) {
252 cpf__->add(new_type.variable());
// Walk the new and the old potential in lockstep and copy each value.
256 Instantiation inst(cpf__), jnst(old);
258 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
259 inst.inc(), jnst.inc()) {
260 cpf__->set(inst, old->get(jnst));
// Sanity checks on the rebuilt potential.
265 GUM_ASSERT(inst.end() && jnst.end());
266 GUM_ASSERT(cpf__->contains(type__->variable()));
267 GUM_ASSERT(cpf__->contains(new_type.variable()));
268 GUM_ASSERT(!cpf__->contains(old_type.variable()));
// Accessor returning a PRMType pointer.
// NOTE(review): the body is not visible in this listing; presumably it
// returns type__ — confirm.
271 template <
typename GUM_SCALAR >
272 PRMType* PRMScalarAttribute< GUM_SCALAR >::type_() {
// Rebuilds cpf__ so that the variable of `t` stands in for the variable of
// the current type, copying the values across.
// Reports OperationNotAllowed when the two variables differ in domain size.
// NOTE(review): the declaration of `old` (presumably the previous cpf__), the
// branch that re-adds the other variables, and the lines between the copy
// loop and the final assertions are not visible in this listing. The first
// assertion presumably relies on type__ having been reassigned to `t` in
// those lines — confirm.
276 template <
typename GUM_SCALAR >
277 void PRMScalarAttribute< GUM_SCALAR >::type_(PRMType* t) {
278 if (type__->variable().domainSize() != t->variable().domainSize()) {
279 GUM_ERROR(OperationNotAllowed,
280 "Cannot replace types with difference domain size");
284 cpf__ =
new Potential< GUM_SCALAR >();
// Rebuild the variable list from `old`, comparing each variable by address
// with the current type's variable; `t`'s variable is added below.
286 for (
auto var: old->variablesSequence()) {
287 if (var != &(type__->variable())) {
290 cpf__->add(t->variable());
// Walk the new and the old potential in lockstep and copy each value.
294 Instantiation inst(cpf__), jnst(old);
296 for (inst.setFirst(), jnst.setFirst(); !(inst.end() || jnst.end());
297 inst.inc(), jnst.inc()) {
298 cpf__->set(inst, old->get(jnst));
// Sanity checks on the rebuilt potential.
305 GUM_ASSERT(cpf__->contains(type__->variable()));
306 GUM_ASSERT(inst.end() && jnst.end());