aGrUM  0.14.2
searchStrategy_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
27 
28 namespace gum {
29  namespace prm {
30  namespace gspan {
31 
32  template < typename GUM_SCALAR >
34  double cost = 0;
36  *(this->_tree->data(p).iso_map.begin().val());
38 
39  for (const auto inst : seq) {
40  for (const auto input : inst->type().slotChains())
41  for (const auto inst2 : inst->getInstances(input->id()))
42  if ((!seq.exists(inst2))
43  && (!input_set.exists(
44  &(inst2->get(input->lastElt().safeName()))))) {
45  cost += std::log(input->type().variable().domainSize());
46  input_set.insert(&(inst2->get(input->lastElt().safeName())));
47  }
48 
49  for (auto vec = inst->beginInvRef(); vec != inst->endInvRef(); ++vec)
50  for (const auto inverse : *vec.val())
51  if (!seq.exists(inverse.first)) {
52  cost +=
53  std::log(inst->get(vec.key()).type().variable().domainSize());
54  break;
55  }
56  }
57 
58  return cost;
59  }
60 
61  template < typename GUM_SCALAR >
63  typename StrictSearch< GUM_SCALAR >::PData& data,
64  Set< Potential< GUM_SCALAR >* >& pool,
65  const Sequence< PRMInstance< GUM_SCALAR >* >& match) {
66  for (const auto inst : match) {
67  for (const auto& elt : *inst) {
68  // Adding the node
69  NodeId id = data.graph.addNode();
70  data.node2attr.insert(id, __str(inst, elt.second));
71  data.mod.insert(id, elt.second->type()->domainSize());
72  data.vars.insert(id, &elt.second->type().variable());
73  pool.insert(
74  const_cast< Potential< GUM_SCALAR >* >(&(elt.second->cpf())));
75  }
76  }
77 
78  // Second we add edges and nodes to inners or outputs
79  for (const auto inst : match)
80  for (const auto& elt : *inst) {
81  NodeId node = data.node2attr.first(__str(inst, elt.second));
82  bool found =
83  false; // If this is set at true, then node is an outer node
84 
85  // Children existing in the instance type's DAG
86  for (const auto chld :
87  inst->type().containerDag().children(elt.second->id())) {
88  data.graph.addEdge(
89  node, data.node2attr.first(__str(inst, inst->get(chld))));
90  }
91 
92  // Parents existing in the instance type's DAG
93  for (const auto par :
94  inst->type().containerDag().parents(elt.second->id())) {
95  switch (inst->type().get(par).elt_type()) {
98  data.graph.addEdge(
99  node, data.node2attr.first(__str(inst, inst->get(par))));
100  break;
101  }
102 
104  for (const auto inst2 : inst->getInstances(par))
105  if (match.exists(inst2))
106  data.graph.addEdge(
107  node,
108  data.node2attr.first(
109  __str(inst2,
110  static_cast< const PRMSlotChain< GUM_SCALAR >& >(
111  inst->type().get(par)))));
112 
113  break;
114  }
115 
116  default: { /* Do nothing */
117  }
118  }
119  }
120 
121  // Referring PRMAttribute<GUM_SCALAR>
122  if (inst->hasRefAttr(elt.second->id())) {
123  const std::vector<
124  std::pair< PRMInstance< GUM_SCALAR >*, std::string > >& ref_attr =
125  inst->getRefAttr(elt.second->id());
126 
127  for (auto pair = ref_attr.begin(); pair != ref_attr.end(); ++pair) {
128  if (match.exists(pair->first)) {
129  NodeId id = pair->first->type().get(pair->second).id();
130 
131  for (const auto child :
132  pair->first->type().containerDag().children(id))
133  data.graph.addEdge(node,
134  data.node2attr.first(__str(
135  pair->first, pair->first->get(child))));
136  } else {
137  found = true;
138  }
139  }
140  }
141 
142  if (found)
143  data.outputs.insert(node);
144  else
145  data.inners.insert(node);
146  }
147  }
148 
149  template < typename GUM_SCALAR >
151  typename StrictSearch< GUM_SCALAR >::PData& data,
152  Set< Potential< GUM_SCALAR >* >& pool) {
153  List< NodeSet > partial_order;
154 
155  if (data.inners.size()) partial_order.insert(data.inners);
156 
157  if (data.outputs.size()) partial_order.insert(data.outputs);
158 
159  PartialOrderedTriangulation t(&(data.graph), &(data.mod), &partial_order);
160  const std::vector< NodeId >& elim_order = t.eliminationOrder();
161  Size max(0), max_count(1);
163  Potential< GUM_SCALAR >* pot = 0;
164 
165  for (size_t idx = 0; idx < data.inners.size(); ++idx) {
167  pot->add(*(data.vars.second(elim_order[idx])));
168  trash.insert(pot);
169  Set< Potential< GUM_SCALAR >* > toRemove;
170 
171  for (const auto p : pool)
172  if (p->contains(*(data.vars.second(elim_order[idx])))) {
173  for (auto var = p->variablesSequence().begin();
174  var != p->variablesSequence().end();
175  ++var) {
176  try {
177  pot->add(**var);
178  } catch (DuplicateElement&) {}
179  }
180 
181  toRemove.insert(p);
182  }
183 
184  if (pot->domainSize() > max) {
185  max = pot->domainSize();
186  max_count = 1;
187  } else if (pot->domainSize() == max) {
188  ++max_count;
189  }
190 
191  for (const auto p : toRemove)
192  pool.erase(p);
193 
194  pot->erase(*(data.vars.second(elim_order[idx])));
195  }
196 
197  for (const auto pot : trash)
198  delete pot;
199 
200  return std::make_pair(max, max_count);
201  }
202 
203  // The SearchStrategy class
204  template < typename GUM_SCALAR >
206  GUM_CONSTRUCTOR(SearchStrategy);
207  }
208 
209  template < typename GUM_SCALAR >
211  const SearchStrategy< GUM_SCALAR >& from) :
212  _tree(from._tree) {
213  GUM_CONS_CPY(SearchStrategy);
214  }
215 
216  template < typename GUM_SCALAR >
218  GUM_DESTRUCTOR(SearchStrategy);
219  }
220 
221  template < typename GUM_SCALAR >
224  this->_tree = from._tree;
225  return *this;
226  }
227 
228  template < typename GUM_SCALAR >
229  INLINE void
231  this->_tree = tree;
232  }
233 
234  // FrequenceSearch
235 
236  // The FrequenceSearch class
237  template < typename GUM_SCALAR >
239  SearchStrategy< GUM_SCALAR >(), __freq(freq) {
240  GUM_CONSTRUCTOR(FrequenceSearch);
241  }
242 
243  template < typename GUM_SCALAR >
245  const FrequenceSearch< GUM_SCALAR >& from) :
246  SearchStrategy< GUM_SCALAR >(from),
247  __freq(from.__freq) {
248  GUM_CONS_CPY(FrequenceSearch);
249  }
250 
251  template < typename GUM_SCALAR >
253  GUM_DESTRUCTOR(FrequenceSearch);
254  }
255 
256  template < typename GUM_SCALAR >
259  __freq = from.__freq;
260  return *this;
261  }
262 
263  template < typename GUM_SCALAR >
265  return this->_tree->frequency(*r) >= __freq;
266  }
267 
268  template < typename GUM_SCALAR >
270  const Pattern* parent,
271  const Pattern* child,
272  const EdgeGrowth< GUM_SCALAR >& growh) {
273  return this->_tree->frequency(*child) >= __freq;
274  }
275 
276  template < typename GUM_SCALAR >
278  gspan::Pattern* j) {
279  // We want a descending order
280  return this->_tree->frequency(*i) > this->_tree->frequency(*j);
281  }
282 
283  template < typename GUM_SCALAR >
285  LabelData* j) {
286  return (this->_tree->graph().size(i) > this->_tree->graph().size(j));
287  }
288 
289  // StrictSearch
290 
291  // The StrictSearch class
292  template < typename GUM_SCALAR >
294  SearchStrategy< GUM_SCALAR >(), __freq(freq), __dot(".") {
295  GUM_CONSTRUCTOR(StrictSearch);
296  }
297 
298  template < typename GUM_SCALAR >
300  const StrictSearch< GUM_SCALAR >& from) :
301  SearchStrategy< GUM_SCALAR >(from),
302  __freq(from.__freq) {
303  GUM_CONS_CPY(StrictSearch);
304  }
305 
306  template < typename GUM_SCALAR >
308  GUM_DESTRUCTOR(StrictSearch);
309  }
310 
311  template < typename GUM_SCALAR >
314  __freq = from.__freq;
315  return *this;
316  }
317 
318  template < typename GUM_SCALAR >
320  return (this->_tree->frequency(*r) >= __freq);
321  }
322 
323  template < typename GUM_SCALAR >
325  const Pattern* parent,
326  const Pattern* child,
327  const EdgeGrowth< GUM_SCALAR >& growth) {
328  return __inner_cost(child)
329  + this->_tree->frequency(*child) * __outer_cost(child)
330  < this->_tree->frequency(*child) * __outer_cost(parent);
331  }
332 
333  template < typename GUM_SCALAR >
335  gspan::Pattern* j) {
336  return __inner_cost(i) + this->_tree->frequency(*i) * __outer_cost(i)
337  < __inner_cost(j) + this->_tree->frequency(*j) * __outer_cost(j);
338  }
339 
340  template < typename GUM_SCALAR >
342  LabelData* j) {
343  return i->tree_width * this->_tree->graph().size(i)
344  < j->tree_width * this->_tree->graph().size(j);
345  }
346 
347  template < typename GUM_SCALAR >
349  try {
350  return __map[p].first;
351  } catch (NotFound&) {
352  __compute_costs(p);
353  return __map[p].first;
354  }
355  }
356 
357  template < typename GUM_SCALAR >
359  try {
360  return __map[p].second;
361  } catch (NotFound&) {
362  __compute_costs(p);
363  return __map[p].second;
364  }
365  }
366 
367  template < typename GUM_SCALAR >
369  const PRMInstance< GUM_SCALAR >* i,
370  const PRMAttribute< GUM_SCALAR >* a) const {
371  return i->name() + __dot + a->safeName();
372  }
373 
374  template < typename GUM_SCALAR >
376  const PRMInstance< GUM_SCALAR >* i,
377  const PRMAttribute< GUM_SCALAR >& a) const {
378  return i->name() + __dot + a.safeName();
379  }
380 
381  template < typename GUM_SCALAR >
383  const PRMInstance< GUM_SCALAR >* i,
384  const PRMSlotChain< GUM_SCALAR >& a) const {
385  return i->name() + __dot + a.lastElt().safeName();
386  }
387 
388  template < typename GUM_SCALAR >
390  typename StrictSearch< GUM_SCALAR >::PData data;
393  data, pool, *(this->_tree->data(*p).iso_map.begin().val()));
394  double inner = std::log(__elimination_cost(data, pool).first);
395  double outer = this->_computeCost(*p);
396  __map.insert(p, std::make_pair(inner, outer));
397  }
398 
399  // TreeWidthSearch
400 
401  template < typename GUM_SCALAR >
403  SearchStrategy< GUM_SCALAR >() {
404  GUM_CONSTRUCTOR(TreeWidthSearch);
405  }
406 
407  template < typename GUM_SCALAR >
409  const TreeWidthSearch< GUM_SCALAR >& from) :
410  SearchStrategy< GUM_SCALAR >(from) {
411  GUM_CONS_CPY(TreeWidthSearch);
412  }
413 
414  template < typename GUM_SCALAR >
416  GUM_DESTRUCTOR(TreeWidthSearch);
417  }
418 
419  template < typename GUM_SCALAR >
422  return *this;
423  }
424 
425  template < typename GUM_SCALAR >
427  try {
428  return __map[&p];
429  } catch (NotFound&) {
430  __map.insert(&p, this->_computeCost(p));
431  return __map[&p];
432  }
433  }
434 
435  template < typename GUM_SCALAR >
437  Size tree_width = 0;
438 
439  for (const auto n : r->nodes())
440  tree_width += r->label(n).tree_width;
441 
442  return tree_width >= cost(*r);
443  }
444 
445  template < typename GUM_SCALAR >
447  const Pattern* parent,
448  const Pattern* child,
449  const EdgeGrowth< GUM_SCALAR >& growth) {
450  return cost(*parent) >= cost(*child);
451  }
452 
453  template < typename GUM_SCALAR >
455  gspan::Pattern* j) {
456  return cost(*i) < cost(*j);
457  }
458 
459  template < typename GUM_SCALAR >
461  LabelData* j) {
462  return i->tree_width < j->tree_width;
463  }
464 
465  } /* namespace gspan */
466  } /* namespace prm */
467 } /* namespace gum */
void setTree(DFSTree< GUM_SCALAR > *tree)
void insert(const T1 &first, const T2 &second)
Inserts a new association in the gum::Bijection.
aGrUM's Potential is a multi-dimensional array with tensor operators.
Definition: potential.h:57
void __buildPatternGraph(typename StrictSearch< GUM_SCALAR >::PData &data, Set< Potential< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match)
This class is used to define an edge growth of a pattern in this DFSTree.
Definition: edgeGrowth.h:60
NodeSet inners
Returns the set of inner nodes.
virtual Size domainSize() const final
Returns the product of the variables domain size.
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
virtual bool operator()(LabelData *i, LabelData *j)
const T1 & first(const T2 &second) const
Returns the first value of a pair given its second value.
Inner class to handle data about labels in this interface graph.
const std::string & name() const
Returns the name of this object.
Definition: PRMObject_inl.h:32
virtual bool accept_root(const Pattern *r)
A PRMInstance is a Bayesian Network fragment defined by a Class and used in a PRMSystem.
Definition: PRMInstance.h:60
virtual ~TreeWidthSearch()
Destructor.
double __outer_cost(const Pattern *p)
StrictSearch & operator=(const StrictSearch &from)
Copy operator.
The generic class for storing (ordered) sequences of objects.
Definition: sequence.h:1019
virtual void addEdge(const NodeId first, const NodeId second)
insert a new edge into the undirected graph
Definition: undiGraph_inl.h:32
virtual void erase(const DiscreteVariable &var) final
Removes a var from the variables of the multidimensional matrix.
virtual ~FrequenceSearch()
Destructor.
Abstract class representing an element of PRM class.
UndiGraph graph
A yet to be triangulated undigraph.
std::pair< Size, Size > __elimination_cost(typename StrictSearch< GUM_SCALAR >::PData &data, Set< Potential< GUM_SCALAR > * > &pool)
Bijection< NodeId, const DiscreteVariable *> vars
Bijection between the graph's nodes and their corresponding DiscreteVariable, for inference purpose...
Generic doubly linked lists.
Definition: list.h:369
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
virtual NodeId addNode()
insert a new node and return its id
LabelData & label(NodeId node)
Returns the LabelData assigned to node.
Definition: pattern_inl.h:47
DFSTree< GUM_SCALAR > * _tree
virtual bool operator()(LabelData *i, LabelData *j)
FrequenceSearch & operator=(const FrequenceSearch &from)
Copy operator.
Representation of a set. A Set is a structure that contains arbitrary elements.
Definition: set.h:162
StrictSearch(Size freq=2)
Default constructor.
const std::string & safeName() const
Returns the safe name of this PRMClassElement, if any.
A DFSTree is used by gspan to sort lexicographically patterns discovered in an interface graph...
Definition: DFSTree.h:68
A growth is accepted if and only if the new growth has a tree width less than or equal to its fath...
std::string __str(const PRMInstance< GUM_SCALAR > *i, const PRMAttribute< GUM_SCALAR > *a) const
virtual ~SearchStrategy()
Destructor.
virtual bool accept_root(const Pattern *r)
HashTable< const Pattern *, double > __map
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
class for graph triangulations for which we enforce a given partial ordering on the nodes elimination...
NodeSet outputs
Returns the set of output nodes given all the matches of the pattern.
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
virtual bool accept_root(const Pattern *r)
const std::vector< NodeId > & eliminationOrder()
returns an elimination ordering compatible with the triangulated graph
Bijection< NodeId, std::string > node2attr
A bijection to easily keep track of the mapping between graph nodes and attributes; it is of the form instance_name DOT attr...
TreeWidthSearch()
Default constructor.
bool exists(const Key &k) const
Check the existence of k in the sequence.
Definition: sequence_tpl.h:399
NodeProperty< Size > mod
The pattern's variables' modalities.
Headers of the SearchStrategy class and child.
FrequenceSearch(Size freq)
Default constructor.
A PRMSlotChain represents a sequence of gum::prm::PRMClassElement<GUM_SCALAR> where the n-1 first gum...
Definition: PRMObject.h:218
const NodeGraphPart & nodes() const
Definition: pattern_inl.h:163
Private structure to represent data about a pattern.
HashTable< const Pattern *, std::pair< double, double > > __map
This class is an implementation of a strict strategy for the GSpan algorithm.
Val & insert(const Val &val)
Inserts a new element at the end of the chained list (alias of pushBack).
Definition: list_tpl.h:1616
virtual void add(const DiscreteVariable &v) final
Adds a new var to the variables of the multidimensional matrix.
double __inner_cost(const Pattern *p)
This class is an implementation of a simple search strategy for the gspan algorithm: it accepts a g...
SearchStrategy()
Default constructor.
double _computeCost(const Pattern &p)
void __compute_costs(const Pattern *p)
Size tree_width
The size in terms of tree width of the given label.
PRMAttribute is a member of a Class in a PRM.
Definition: PRMAttribute.h:58
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:45
PRMClassElement< GUM_SCALAR > & lastElt()
Returns the last element of the slot chain, typically this is a gum::PRMAttribute or a gum::PRMAggre...
Size size() const noexcept
Returns the number of elements in the set.
Definition: set_tpl.h:698
TreeWidthSearch & operator=(const TreeWidthSearch &from)
Copy operator.
virtual ~StrictSearch()
Destructor.
This contains all the information we want for a node in a DFSTree.
Definition: pattern.h:70
This is an abstract class used to tune search strategies in the gspan algorithm.
Definition: DFSTree.h:58
Size NodeId
Type for node ids.
Definition: graphElements.h:97
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:610
Multidimensional matrix stored as a sparse array in memory.
SearchStrategy< GUM_SCALAR > & operator=(const SearchStrategy< GUM_SCALAR > &from)
Copy operator.
virtual bool operator()(LabelData *i, LabelData *j)
void insert(const Key &k)
Insert an element at the end of the sequence.
Definition: sequence_tpl.h:405