aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
searchStrategy_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Inline implementation of the SearchStrategy class.
25  *
26  * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 #include <agrum/PRM/gspan/searchStrategy.h>
29 
30 namespace gum {
31  namespace prm {
32  namespace gspan {
33 
34  template < typename GUM_SCALAR >
35  double SearchStrategy< GUM_SCALAR >::computeCost_(const Pattern& p) {
36  double cost = 0;
37  const Sequence< PRMInstance< GUM_SCALAR >* >& seq
38  = *(this->tree_->data(p).iso_map.begin().val());
40 
41  for (const auto inst: seq) {
42  for (const auto input: inst->type().slotChains())
43  for (const auto inst2: inst->getInstances(input->id()))
44  if ((!seq.exists(inst2))
45  && (!input_set.exists(&(inst2->get(input->lastElt().safeName()))))) {
46  cost += std::log(input->type().variable().domainSize());
48  }
49 
50  for (auto vec = inst->beginInvRef(); vec != inst->endInvRef(); ++vec)
51  for (const auto inverse: *vec.val())
52  if (!seq.exists(inverse.first)) {
53  cost += std::log(inst->get(vec.key()).type().variable().domainSize());
54  break;
55  }
56  }
57 
58  return cost;
59  }
60 
61  template < typename GUM_SCALAR >
62  void StrictSearch< GUM_SCALAR >::_buildPatternGraph_(
63  typename StrictSearch< GUM_SCALAR >::PData& data,
64  Set< Potential< GUM_SCALAR >* >& pool,
65  const Sequence< PRMInstance< GUM_SCALAR >* >& match) {
66  for (const auto inst: match) {
67  for (const auto& elt: *inst) {
68  // Adding the node
73  pool.insert(const_cast< Potential< GUM_SCALAR >* >(&(elt.second->cpf())));
74  }
75  }
76 
77  // Second we add edges and nodes to inners or outputs
78  for (const auto inst: match)
79  for (const auto& elt: *inst) {
81  bool found = false; // If this is set at true, then node is an outer node
82 
83  // Children existing in the instance type's DAG
84  for (const auto chld: inst->type().containerDag().children(elt.second->id())) {
86  }
87 
88  // Parents existing in the instance type's DAG
89  for (const auto par: inst->type().containerDag().parents(elt.second->id())) {
90  switch (inst->type().get(par).elt_type()) {
94  break;
95  }
96 
98  for (const auto inst2: inst->getInstances(par))
99  if (match.exists(inst2))
102  _str_(inst2,
103  static_cast< const PRMSlotChain< GUM_SCALAR >& >(
104  inst->type().get(par)))));
105 
106  break;
107  }
108 
109  default: { /* Do nothing */
110  }
111  }
112  }
113 
114  // Referring PRMAttribute<GUM_SCALAR>
115  if (inst->hasRefAttr(elt.second->id())) {
116  const std::vector< std::pair< PRMInstance< GUM_SCALAR >*, std::string > >& ref_attr
117  = inst->getRefAttr(elt.second->id());
118 
119  for (auto pair = ref_attr.begin(); pair != ref_attr.end(); ++pair) {
120  if (match.exists(pair->first)) {
121  NodeId id = pair->first->type().get(pair->second).id();
122 
123  for (const auto child: pair->first->type().containerDag().children(id))
125  node,
127  } else {
128  found = true;
129  }
130  }
131  }
132 
133  if (found)
135  else
137  }
138  }
139 
140  template < typename GUM_SCALAR >
142  typename StrictSearch< GUM_SCALAR >::PData& data,
143  Set< Potential< GUM_SCALAR >* >& pool) {
145 
147 
149 
151  const std::vector< NodeId >& elim_order = t.eliminationOrder();
152  Size max(0), max_count(1);
153  Set< Potential< GUM_SCALAR >* > trash;
154  Potential< GUM_SCALAR >* pot = 0;
155 
156  for (size_t idx = 0; idx < data.inners.size(); ++idx) {
157  pot = new Potential< GUM_SCALAR >(new MultiDimSparse< GUM_SCALAR >(0));
159  trash.insert(pot);
161 
162  for (const auto p: pool)
163  if (p->contains(*(data.vars.second(elim_order[idx])))) {
164  for (auto var = p->variablesSequence().begin(); var != p->variablesSequence().end();
165  ++var) {
166  try {
167  pot->add(**var);
168  } catch (DuplicateElement&) {}
169  }
170 
171  toRemove.insert(p);
172  }
173 
174  if (pot->domainSize() > max) {
175  max = pot->domainSize();
176  max_count = 1;
177  } else if (pot->domainSize() == max) {
178  ++max_count;
179  }
180 
181  for (const auto p: toRemove)
182  pool.erase(p);
183 
185  }
186 
187  for (const auto pot: trash)
188  delete pot;
189 
190  return std::make_pair(max, max_count);
191  }
192 
193  // The SearchStrategy class
194  template < typename GUM_SCALAR >
197  }
198 
199  template < typename GUM_SCALAR >
200  INLINE
202  tree_(from.tree_) {
204  }
205 
206  template < typename GUM_SCALAR >
209  }
210 
211  template < typename GUM_SCALAR >
214  this->tree_ = from.tree_;
215  return *this;
216  }
217 
218  template < typename GUM_SCALAR >
220  this->tree_ = tree;
221  }
222 
223  // FrequenceSearch
224 
225  // The FrequenceSearch class
226  template < typename GUM_SCALAR >
230  }
231 
232  template < typename GUM_SCALAR >
233  INLINE
236  _freq_(from._freq_) {
238  }
239 
240  template < typename GUM_SCALAR >
243  }
244 
245  template < typename GUM_SCALAR >
248  _freq_ = from._freq_;
249  return *this;
250  }
251 
252  template < typename GUM_SCALAR >
254  return this->tree_->frequency(*r) >= _freq_;
255  }
256 
257  template < typename GUM_SCALAR >
258  INLINE bool
260  const Pattern* child,
261  const EdgeGrowth< GUM_SCALAR >& growh) {
262  return this->tree_->frequency(*child) >= _freq_;
263  }
264 
265  template < typename GUM_SCALAR >
267  // We want a descending order
268  return this->tree_->frequency(*i) > this->tree_->frequency(*j);
269  }
270 
271  template < typename GUM_SCALAR >
273  return (this->tree_->graph().size(i) > this->tree_->graph().size(j));
274  }
275 
276  // StrictSearch
277 
278  // The StrictSearch class
279  template < typename GUM_SCALAR >
281  SearchStrategy< GUM_SCALAR >(), _freq_(freq), _dot_(".") {
283  }
284 
285  template < typename GUM_SCALAR >
289  }
290 
291  template < typename GUM_SCALAR >
294  }
295 
296  template < typename GUM_SCALAR >
299  _freq_ = from._freq_;
300  return *this;
301  }
302 
303  template < typename GUM_SCALAR >
305  return (this->tree_->frequency(*r) >= _freq_);
306  }
307 
308  template < typename GUM_SCALAR >
309  INLINE bool
311  const Pattern* child,
312  const EdgeGrowth< GUM_SCALAR >& growth) {
314  < this->tree_->frequency(*child) * _outer_cost_(parent);
315  }
316 
317  template < typename GUM_SCALAR >
319  return _inner_cost_(i) + this->tree_->frequency(*i) * _outer_cost_(i)
320  < _inner_cost_(j) + this->tree_->frequency(*j) * _outer_cost_(j);
321  }
322 
323  template < typename GUM_SCALAR >
325  return i->tree_width * this->tree_->graph().size(i)
326  < j->tree_width * this->tree_->graph().size(j);
327  }
328 
329  template < typename GUM_SCALAR >
331  try {
332  return _map_[p].first;
333  } catch (NotFound&) {
335  return _map_[p].first;
336  }
337  }
338 
339  template < typename GUM_SCALAR >
341  try {
342  return _map_[p].second;
343  } catch (NotFound&) {
345  return _map_[p].second;
346  }
347  }
348 
349  template < typename GUM_SCALAR >
350  INLINE std::string
352  const PRMAttribute< GUM_SCALAR >* a) const {
353  return i->name() + _dot_ + a->safeName();
354  }
355 
356  template < typename GUM_SCALAR >
357  INLINE std::string
359  const PRMAttribute< GUM_SCALAR >& a) const {
360  return i->name() + _dot_ + a.safeName();
361  }
362 
363  template < typename GUM_SCALAR >
364  INLINE std::string
366  const PRMSlotChain< GUM_SCALAR >& a) const {
367  return i->name() + _dot_ + a.lastElt().safeName();
368  }
369 
370  template < typename GUM_SCALAR >
372  typename StrictSearch< GUM_SCALAR >::PData data;
373  Set< Potential< GUM_SCALAR >* > pool;
374  _buildPatternGraph_(data, pool, *(this->tree_->data(*p).iso_map.begin().val()));
376  double outer = this->computeCost_(*p);
378  }
379 
380  // TreeWidthSearch
381 
382  template < typename GUM_SCALAR >
385  }
386 
387  template < typename GUM_SCALAR >
388  INLINE
392  }
393 
394  template < typename GUM_SCALAR >
397  }
398 
399  template < typename GUM_SCALAR >
402  return *this;
403  }
404 
405  template < typename GUM_SCALAR >
407  try {
408  return _map_[&p];
409  } catch (NotFound&) {
410  _map_.insert(&p, this->computeCost_(p));
411  return _map_[&p];
412  }
413  }
414 
415  template < typename GUM_SCALAR >
417  Size tree_width = 0;
418 
419  for (const auto n: r->nodes())
421 
422  return tree_width >= cost(*r);
423  }
424 
425  template < typename GUM_SCALAR >
426  INLINE bool
428  const Pattern* child,
429  const EdgeGrowth< GUM_SCALAR >& growth) {
430  return cost(*parent) >= cost(*child);
431  }
432 
433  template < typename GUM_SCALAR >
435  return cost(*i) < cost(*j);
436  }
437 
438  template < typename GUM_SCALAR >
440  return i->tree_width < j->tree_width;
441  }
442 
443  } /* namespace gspan */
444  } /* namespace prm */
445 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)
INLINE std::ostream & operator<<(std::ostream &out, const EdgeData< GUM_SCALAR > &data)
Print a EdgeData<GUM_SCALAR> in out.