aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
searchStrategy_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Inline implementation of the SearchStrategy class.
25  *
26  * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 #include <agrum/PRM/gspan/searchStrategy.h>
29 
30 namespace gum {
31  namespace prm {
32  namespace gspan {
33 
34  template < typename GUM_SCALAR >
35  double SearchStrategy< GUM_SCALAR >::computeCost_(const Pattern& p) {
36  double cost = 0;
37  const Sequence< PRMInstance< GUM_SCALAR >* >& seq
38  = *(this->tree_->data(p).iso_map.begin().val());
40 
41  for (const auto inst: seq) {
42  for (const auto input: inst->type().slotChains())
43  for (const auto inst2: inst->getInstances(input->id()))
44  if ((!seq.exists(inst2))
45  && (!input_set.exists(
46  &(inst2->get(input->lastElt().safeName()))))) {
47  cost += std::log(input->type().variable().domainSize());
49  }
50 
51  for (auto vec = inst->beginInvRef(); vec != inst->endInvRef(); ++vec)
52  for (const auto inverse: *vec.val())
53  if (!seq.exists(inverse.first)) {
54  cost += std::log(
55  inst->get(vec.key()).type().variable().domainSize());
56  break;
57  }
58  }
59 
60  return cost;
61  }
62 
63  template < typename GUM_SCALAR >
64  void StrictSearch< GUM_SCALAR >::buildPatternGraph__(
65  typename StrictSearch< GUM_SCALAR >::PData& data,
66  Set< Potential< GUM_SCALAR >* >& pool,
67  const Sequence< PRMInstance< GUM_SCALAR >* >& match) {
68  for (const auto inst: match) {
69  for (const auto& elt: *inst) {
70  // Adding the node
75  pool.insert(
76  const_cast< Potential< GUM_SCALAR >* >(&(elt.second->cpf())));
77  }
78  }
79 
80  // Second we add edges and nodes to inners or outputs
81  for (const auto inst: match)
82  for (const auto& elt: *inst) {
84  bool found
85  = false; // If this is set at true, then node is an outer node
86 
87  // Children existing in the instance type's DAG
88  for (const auto chld:
91  node,
93  }
94 
95  // Parents existing in the instance type's DAG
96  for (const auto par:
98  switch (inst->type().get(par).elt_type()) {
102  node,
104  break;
105  }
106 
108  for (const auto inst2: inst->getInstances(par))
109  if (match.exists(inst2))
111  node,
113  str__(inst2,
114  static_cast< const PRMSlotChain< GUM_SCALAR >& >(
115  inst->type().get(par)))));
116 
117  break;
118  }
119 
120  default: { /* Do nothing */
121  }
122  }
123  }
124 
125  // Referring PRMAttribute<GUM_SCALAR>
126  if (inst->hasRefAttr(elt.second->id())) {
127  const std::vector<
129  = inst->getRefAttr(elt.second->id());
130 
131  for (auto pair = ref_attr.begin(); pair != ref_attr.end(); ++pair) {
132  if (match.exists(pair->first)) {
133  NodeId id = pair->first->type().get(pair->second).id();
134 
135  for (const auto child:
138  node,
140  str__(pair->first, pair->first->get(child))));
141  } else {
142  found = true;
143  }
144  }
145  }
146 
147  if (found)
149  else
151  }
152  }
153 
154  template < typename GUM_SCALAR >
156  typename StrictSearch< GUM_SCALAR >::PData& data,
157  Set< Potential< GUM_SCALAR >* >& pool) {
159 
161 
163 
165  const std::vector< NodeId >& elim_order = t.eliminationOrder();
166  Size max(0), max_count(1);
167  Set< Potential< GUM_SCALAR >* > trash;
168  Potential< GUM_SCALAR >* pot = 0;
169 
170  for (size_t idx = 0; idx < data.inners.size(); ++idx) {
171  pot = new Potential< GUM_SCALAR >(new MultiDimSparse< GUM_SCALAR >(0));
173  trash.insert(pot);
175 
176  for (const auto p: pool)
177  if (p->contains(*(data.vars.second(elim_order[idx])))) {
178  for (auto var = p->variablesSequence().begin();
179  var != p->variablesSequence().end();
180  ++var) {
181  try {
182  pot->add(**var);
183  } catch (DuplicateElement&) {}
184  }
185 
186  toRemove.insert(p);
187  }
188 
189  if (pot->domainSize() > max) {
190  max = pot->domainSize();
191  max_count = 1;
192  } else if (pot->domainSize() == max) {
193  ++max_count;
194  }
195 
196  for (const auto p: toRemove)
197  pool.erase(p);
198 
200  }
201 
202  for (const auto pot: trash)
203  delete pot;
204 
205  return std::make_pair(max, max_count);
206  }
207 
208  // The SearchStrategy class
209  template < typename GUM_SCALAR >
212  }
213 
214  template < typename GUM_SCALAR >
216  const SearchStrategy< GUM_SCALAR >& from) :
217  tree_(from.tree_) {
219  }
220 
221  template < typename GUM_SCALAR >
224  }
225 
226  template < typename GUM_SCALAR >
228  const SearchStrategy< GUM_SCALAR >& from) {
229  this->tree_ = from.tree_;
230  return *this;
231  }
232 
233  template < typename GUM_SCALAR >
234  INLINE void
236  this->tree_ = tree;
237  }
238 
239  // FrequenceSearch
240 
241  // The FrequenceSearch class
242  template < typename GUM_SCALAR >
246  }
247 
248  template < typename GUM_SCALAR >
250  const FrequenceSearch< GUM_SCALAR >& from) :
252  freq__(from.freq__) {
254  }
255 
256  template < typename GUM_SCALAR >
259  }
260 
261  template < typename GUM_SCALAR >
264  const FrequenceSearch< GUM_SCALAR >& from) {
265  freq__ = from.freq__;
266  return *this;
267  }
268 
269  template < typename GUM_SCALAR >
271  return this->tree_->frequency(*r) >= freq__;
272  }
273 
274  template < typename GUM_SCALAR >
276  const Pattern* parent,
277  const Pattern* child,
278  const EdgeGrowth< GUM_SCALAR >& growh) {
279  return this->tree_->frequency(*child) >= freq__;
280  }
281 
282  template < typename GUM_SCALAR >
284  gspan::Pattern* j) {
285  // We want a descending order
286  return this->tree_->frequency(*i) > this->tree_->frequency(*j);
287  }
288 
289  template < typename GUM_SCALAR >
291  LabelData* j) {
292  return (this->tree_->graph().size(i) > this->tree_->graph().size(j));
293  }
294 
295  // StrictSearch
296 
297  // The StrictSearch class
298  template < typename GUM_SCALAR >
300  SearchStrategy< GUM_SCALAR >(), freq__(freq), dot__(".") {
302  }
303 
304  template < typename GUM_SCALAR >
306  const StrictSearch< GUM_SCALAR >& from) :
308  freq__(from.freq__) {
310  }
311 
312  template < typename GUM_SCALAR >
315  }
316 
317  template < typename GUM_SCALAR >
319  const StrictSearch< GUM_SCALAR >& from) {
320  freq__ = from.freq__;
321  return *this;
322  }
323 
324  template < typename GUM_SCALAR >
326  return (this->tree_->frequency(*r) >= freq__);
327  }
328 
329  template < typename GUM_SCALAR >
331  const Pattern* parent,
332  const Pattern* child,
333  const EdgeGrowth< GUM_SCALAR >& growth) {
334  return inner_cost__(child)
335  + this->tree_->frequency(*child) * outer_cost__(child)
336  < this->tree_->frequency(*child) * outer_cost__(parent);
337  }
338 
339  template < typename GUM_SCALAR >
341  gspan::Pattern* j) {
342  return inner_cost__(i) + this->tree_->frequency(*i) * outer_cost__(i)
343  < inner_cost__(j) + this->tree_->frequency(*j) * outer_cost__(j);
344  }
345 
346  template < typename GUM_SCALAR >
348  LabelData* j) {
349  return i->tree_width * this->tree_->graph().size(i)
350  < j->tree_width * this->tree_->graph().size(j);
351  }
352 
353  template < typename GUM_SCALAR >
355  try {
356  return map__[p].first;
357  } catch (NotFound&) {
359  return map__[p].first;
360  }
361  }
362 
363  template < typename GUM_SCALAR >
365  try {
366  return map__[p].second;
367  } catch (NotFound&) {
369  return map__[p].second;
370  }
371  }
372 
373  template < typename GUM_SCALAR >
375  const PRMInstance< GUM_SCALAR >* i,
376  const PRMAttribute< GUM_SCALAR >* a) const {
377  return i->name() + dot__ + a->safeName();
378  }
379 
380  template < typename GUM_SCALAR >
382  const PRMInstance< GUM_SCALAR >* i,
383  const PRMAttribute< GUM_SCALAR >& a) const {
384  return i->name() + dot__ + a.safeName();
385  }
386 
387  template < typename GUM_SCALAR >
389  const PRMInstance< GUM_SCALAR >* i,
390  const PRMSlotChain< GUM_SCALAR >& a) const {
391  return i->name() + dot__ + a.lastElt().safeName();
392  }
393 
394  template < typename GUM_SCALAR >
396  typename StrictSearch< GUM_SCALAR >::PData data;
397  Set< Potential< GUM_SCALAR >* > pool;
399  pool,
400  *(this->tree_->data(*p).iso_map.begin().val()));
402  double outer = this->computeCost_(*p);
404  }
405 
406  // TreeWidthSearch
407 
408  template < typename GUM_SCALAR >
412  }
413 
414  template < typename GUM_SCALAR >
416  const TreeWidthSearch< GUM_SCALAR >& from) :
419  }
420 
421  template < typename GUM_SCALAR >
424  }
425 
426  template < typename GUM_SCALAR >
429  const TreeWidthSearch< GUM_SCALAR >& from) {
430  return *this;
431  }
432 
433  template < typename GUM_SCALAR >
435  try {
436  return map__[&p];
437  } catch (NotFound&) {
438  map__.insert(&p, this->computeCost_(p));
439  return map__[&p];
440  }
441  }
442 
443  template < typename GUM_SCALAR >
445  Size tree_width = 0;
446 
447  for (const auto n: r->nodes())
449 
450  return tree_width >= cost(*r);
451  }
452 
453  template < typename GUM_SCALAR >
455  const Pattern* parent,
456  const Pattern* child,
457  const EdgeGrowth< GUM_SCALAR >& growth) {
458  return cost(*parent) >= cost(*child);
459  }
460 
461  template < typename GUM_SCALAR >
463  gspan::Pattern* j) {
464  return cost(*i) < cost(*j);
465  }
466 
467  template < typename GUM_SCALAR >
469  LabelData* j) {
470  return i->tree_width < j->tree_width;
471  }
472 
473  } /* namespace gspan */
474  } /* namespace prm */
475 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)
INLINE std::ostream & operator<<(std::ostream &out, const EdgeData< GUM_SCALAR > &data)
Print an EdgeData<GUM_SCALAR> to out.