aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
structuredBayesBall_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Inline implementation of StructuredBayesBall.
25  *
26  * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 
29 #include <agrum/PRM/inference/structuredBayesBall.h>
30 
31 namespace gum {
32  namespace prm {
33 
34  template < typename GUM_SCALAR >
37 
38  for (const auto& elt: reqMap__)
39  delete elt.second.first;
40  }
41 
42  template < typename GUM_SCALAR >
44  for (const auto& elt: reqMap__)
45  delete elt.second.first;
46 
47  keyMap__.clear();
48  reqMap__.clear();
49  }
50 
51  template < typename GUM_SCALAR >
53  const PRMInstance< GUM_SCALAR >* i,
54  NodeId n) {
55  try {
56  typename PRMInference< GUM_SCALAR >::Chain chain
57  = std::make_pair(i, &(i->get(n)));
58 
59  if (inf__->hasEvidence(chain)) {
60  const Potential< GUM_SCALAR >* e = inf__->evidence(i)[n];
62  Size count = 0;
63 
64  for (inst.setFirst(); !inst.end(); inst.inc()) {
65  if ((e->get(inst) == (GUM_SCALAR)1.0))
66  ++count;
67  else if (e->get(inst) != (GUM_SCALAR)0.0)
68  return false;
69  }
70 
71  return (count == 1);
72  }
73 
74  return false;
75  } catch (NotFound&) { return false; }
76  }
77 
78  template < typename GUM_SCALAR >
80  const PRMInstance< GUM_SCALAR >* i,
81  NodeId n) {
82  clean__();
83  /// Key = instance, then node id of a PRMClassElement<GUM_SCALAR> in it
84  /// pair = <upper mark (sent to parents), lower mark (sent to children)>
86  fromChild__(i, n, marks);
88 
89  for (const auto& elt: marks)
90  delete elt.second;
91  }
92 
93  template < typename GUM_SCALAR >
95  const PRMInstance< GUM_SCALAR >* i,
96  NodeId n,
97  InstanceMap& marks) {
98  if (!marks.exists(i)) {
100  }
101 
102  if (!marks[i]->exists(n)) {
103  marks[i]->insert(n, std::pair< bool, bool >(false, false));
104  }
105 
106  // Sending message to parents
107  switch (i->type().get(n).elt_type()) {
109  if (!getMark__(marks, i, n).first) {
110  getMark__(marks, i, n).first = true;
111 
112  for (const auto inst: i->getInstances(n))
114  inst->get(getSC__(i, n).lastElt().safeName()).id(),
115  marks);
116  }
117 
118  if (!getMark__(marks, i, n).second) {
119  getMark__(marks, i, n).second = true;
120 
121  for (const auto chi: i->type().containerDag().children(n))
122  fromParent__(i, chi, marks);
123  }
124 
125  break;
126  }
127 
130  if (!getMark__(marks, i, n).first) {
131  getMark__(marks, i, n).first = true;
132 
133  if (!isHardEvidence__(i, n))
134  for (const auto par: i->type().containerDag().parents(n))
135  fromChild__(i, par, marks);
136  }
137 
138  if (!getMark__(marks, i, n).second) {
139  getMark__(marks, i, n).second = true;
140 
141  // In i.
142  for (const auto chi: i->type().containerDag().children(n))
143  fromParent__(i, chi, marks);
144 
145  // Out of i.
146  try {
147  const auto& refs = i->getRefAttr(n);
148 
149  for (auto iter = refs.begin(); iter != refs.end(); ++iter)
151  iter->first->type().get(iter->second).id(),
152  marks);
153  } catch (NotFound&) {
154  // Not an inverse slot chain: nothing to send out of i
155  }
156  }
157 
158  break;
159  }
160 
161  default: {
162  // We shouldn't reach any PRMClassElement<GUM_SCALAR> other than a
163  // PRMAttribute or a PRMSlotChain<GUM_SCALAR>.
166  GUM_ERROR(FatalError, "This case is impossible.");
167  }
168  }
169  }
170 
171  template < typename GUM_SCALAR >
173  const PRMInstance< GUM_SCALAR >* i,
174  NodeId n,
175  InstanceMap& marks) {
176  if (!marks.exists(i)) {
178  }
179 
180  if (!marks[i]->exists(n)) {
181  marks[i]->insert(n, std::pair< bool, bool >(false, false));
182  }
183 
184  // Concerns only PRMAttribute (because of the hard evidence)
185  if ((isHardEvidence__(i, n)) && (!getMark__(marks, i, n).first)) {
186  getMark__(marks, i, n).first = true;
187 
188  for (const auto par: i->type().containerDag().parents(n))
189  fromChild__(i, par, marks);
190  } else if (!getMark__(marks, i, n).second) {
191  getMark__(marks, i, n).second = true;
192 
193  // In i.
194  for (const auto chi: i->type().containerDag().children(n))
195  fromParent__(i, chi, marks);
196 
197  // Out of i.
198  try {
199  for (auto iter = i->getRefAttr(n).begin();
200  iter != i->getRefAttr(n).end();
201  ++iter)
203  iter->first->type().get(iter->second).id(),
204  marks);
205  } catch (NotFound&) {
206  // Not an inverse slot chain: nothing to send out of i
207  }
208  }
209  }
210 
211  template < typename GUM_SCALAR >
213  // First find, for each instance, its requisite nodes
214  HashTable< const PRMInstance< GUM_SCALAR >*, Set< NodeId >* > req_map;
215 
216  for (const auto& elt: marks) {
217  Set< NodeId >* req_set = new Set< NodeId >();
218 
219  for (const auto& elt2: *elt.second)
221 
223  }
224 
225  // Remove all instances with 0 requisite nodes
226  Set< const PRMInstance< GUM_SCALAR >* > to_remove;
227 
228  for (const auto& elt: req_map)
229  if (elt.second->size() == 0) to_remove.insert(elt.first);
230 
231  for (const auto remo: to_remove) {
232  delete req_map[remo];
233  req_map.erase(remo);
234  }
235 
236  // Fill reqMap__ and keyMap__
237  for (const auto& elt: req_map) {
239 
240  if (reqMap__.exists(key)) {
242  elt.first,
243  std::pair< std::string, Set< NodeId >* >(key, reqMap__[key].first));
244  reqMap__[key].second += 1;
245  delete elt.second;
246  req_map[elt.first] = 0;
247  } else {
248  reqMap__.insert(key, std::pair< Set< NodeId >*, Size >(elt.second, 1));
250  elt.first,
251  std::pair< std::string, Set< NodeId >* >(key, elt.second));
252  }
253  }
254  }
255 
256  template < typename GUM_SCALAR >
258  const PRMInstance< GUM_SCALAR >* i,
259  Set< NodeId >& req_nodes) {
261  sBuff << i->type().name();
262 
263  for (const auto node: i->type().containerDag().nodes())
264  if (req_nodes.exists(node)) sBuff << "-" << node;
265 
266  return sBuff.str();
267  }
268 
269  template < typename GUM_SCALAR >
271  const PRMInference< GUM_SCALAR >& inference) :
272  inf__(&inference) {
274  }
275 
276  template < typename GUM_SCALAR >
279  inf__(0) {
281  GUM_ERROR(FatalError, "Not allowed.");
282  }
283 
284  template < typename GUM_SCALAR >
288  GUM_ERROR(FatalError, "Not allowed.");
289  }
290 
291  template < typename GUM_SCALAR >
293  const PRMInstance< GUM_SCALAR >* i) const {
294  return keyMap__[i].first;
295  }
296 
297  template < typename GUM_SCALAR >
299  const PRMInstance< GUM_SCALAR >& i) const {
300  return keyMap__[&i].first;
301  }
302 
303  template < typename GUM_SCALAR >
305  const PRMInstance< GUM_SCALAR >* i) const {
306  return *(keyMap__[i].second);
307  }
308 
309  template < typename GUM_SCALAR >
311  const PRMInstance< GUM_SCALAR >& i) const {
312  return *(keyMap__[&i].second);
313  }
314 
315  template < typename GUM_SCALAR >
317  const std::string& key) const {
318  return reqMap__[key].second;
319  }
320 
321  template < typename GUM_SCALAR >
323  return ((float)reqMap__.size()) / ((float)keyMap__.size());
324  }
325 
326  template < typename GUM_SCALAR >
328  const PRMInstance< GUM_SCALAR >* i) const {
329  return keyMap__.exists(i);
330  }
331 
332  template < typename GUM_SCALAR >
334  const PRMInstance< GUM_SCALAR >& i) const {
335  return keyMap__.exists(&i);
336  }
337 
338  template < typename GUM_SCALAR >
340  const PRMInstance< GUM_SCALAR >* i,
341  NodeId n) {
342  compute__(i, n);
343  }
344 
345  template < typename GUM_SCALAR >
347  const PRMInstance< GUM_SCALAR >& i,
348  NodeId n) {
349  compute__(&i, n);
350  }
351 
352  template < typename GUM_SCALAR >
353  INLINE const PRMSlotChain< GUM_SCALAR >&
355  const PRMInstance< GUM_SCALAR >* i,
356  NodeId n) {
357  return static_cast< const PRMSlotChain< GUM_SCALAR >& >(i->type().get(n));
358  }
359 
360  template < typename GUM_SCALAR >
363  const PRMInstance< GUM_SCALAR >* i,
364  NodeId n) {
365  return (*(marks[i]))[n];
366  }
367 
368  } /* namespace prm */
369 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)