aGrUM  0.16.0
searchStrategy_tpl.h
Go to the documentation of this file.
1 
30 
31 namespace gum {
32  namespace prm {
33  namespace gspan {
34 
35  template < typename GUM_SCALAR >
37  double cost = 0;
39  *(this->_tree->data(p).iso_map.begin().val());
41 
42  for (const auto inst : seq) {
43  for (const auto input : inst->type().slotChains())
44  for (const auto inst2 : inst->getInstances(input->id()))
45  if ((!seq.exists(inst2))
46  && (!input_set.exists(
47  &(inst2->get(input->lastElt().safeName()))))) {
48  cost += std::log(input->type().variable().domainSize());
49  input_set.insert(&(inst2->get(input->lastElt().safeName())));
50  }
51 
52  for (auto vec = inst->beginInvRef(); vec != inst->endInvRef(); ++vec)
53  for (const auto inverse : *vec.val())
54  if (!seq.exists(inverse.first)) {
55  cost +=
56  std::log(inst->get(vec.key()).type().variable().domainSize());
57  break;
58  }
59  }
60 
61  return cost;
62  }
63 
64  template < typename GUM_SCALAR >
66  typename StrictSearch< GUM_SCALAR >::PData& data,
67  Set< Potential< GUM_SCALAR >* >& pool,
68  const Sequence< PRMInstance< GUM_SCALAR >* >& match) {
69  for (const auto inst : match) {
70  for (const auto& elt : *inst) {
71  // Adding the node
72  NodeId id = data.graph.addNode();
73  data.node2attr.insert(id, __str(inst, elt.second));
74  data.mod.insert(id, elt.second->type()->domainSize());
75  data.vars.insert(id, &elt.second->type().variable());
76  pool.insert(
77  const_cast< Potential< GUM_SCALAR >* >(&(elt.second->cpf())));
78  }
79  }
80 
81  // Second we add edges and nodes to inners or outputs
82  for (const auto inst : match)
83  for (const auto& elt : *inst) {
84  NodeId node = data.node2attr.first(__str(inst, elt.second));
85  bool found =
86  false; // If this is set at true, then node is an outer node
87 
88  // Children existing in the instance type's DAG
89  for (const auto chld :
90  inst->type().containerDag().children(elt.second->id())) {
91  data.graph.addEdge(
92  node, data.node2attr.first(__str(inst, inst->get(chld))));
93  }
94 
95  // Parents existing in the instance type's DAG
96  for (const auto par :
97  inst->type().containerDag().parents(elt.second->id())) {
98  switch (inst->type().get(par).elt_type()) {
101  data.graph.addEdge(
102  node, data.node2attr.first(__str(inst, inst->get(par))));
103  break;
104  }
105 
107  for (const auto inst2 : inst->getInstances(par))
108  if (match.exists(inst2))
109  data.graph.addEdge(
110  node,
111  data.node2attr.first(
112  __str(inst2,
113  static_cast< const PRMSlotChain< GUM_SCALAR >& >(
114  inst->type().get(par)))));
115 
116  break;
117  }
118 
119  default: { /* Do nothing */
120  }
121  }
122  }
123 
124  // Referring PRMAttribute<GUM_SCALAR>
125  if (inst->hasRefAttr(elt.second->id())) {
126  const std::vector<
127  std::pair< PRMInstance< GUM_SCALAR >*, std::string > >& ref_attr =
128  inst->getRefAttr(elt.second->id());
129 
130  for (auto pair = ref_attr.begin(); pair != ref_attr.end(); ++pair) {
131  if (match.exists(pair->first)) {
132  NodeId id = pair->first->type().get(pair->second).id();
133 
134  for (const auto child :
135  pair->first->type().containerDag().children(id))
136  data.graph.addEdge(node,
137  data.node2attr.first(__str(
138  pair->first, pair->first->get(child))));
139  } else {
140  found = true;
141  }
142  }
143  }
144 
145  if (found)
146  data.outputs.insert(node);
147  else
148  data.inners.insert(node);
149  }
150  }
151 
152  template < typename GUM_SCALAR >
154  typename StrictSearch< GUM_SCALAR >::PData& data,
155  Set< Potential< GUM_SCALAR >* >& pool) {
156  List< NodeSet > partial_order;
157 
158  if (data.inners.size()) partial_order.insert(data.inners);
159 
160  if (data.outputs.size()) partial_order.insert(data.outputs);
161 
162  PartialOrderedTriangulation t(&(data.graph), &(data.mod), &partial_order);
163  const std::vector< NodeId >& elim_order = t.eliminationOrder();
164  Size max(0), max_count(1);
166  Potential< GUM_SCALAR >* pot = 0;
167 
168  for (size_t idx = 0; idx < data.inners.size(); ++idx) {
170  pot->add(*(data.vars.second(elim_order[idx])));
171  trash.insert(pot);
172  Set< Potential< GUM_SCALAR >* > toRemove;
173 
174  for (const auto p : pool)
175  if (p->contains(*(data.vars.second(elim_order[idx])))) {
176  for (auto var = p->variablesSequence().begin();
177  var != p->variablesSequence().end();
178  ++var) {
179  try {
180  pot->add(**var);
181  } catch (DuplicateElement&) {}
182  }
183 
184  toRemove.insert(p);
185  }
186 
187  if (pot->domainSize() > max) {
188  max = pot->domainSize();
189  max_count = 1;
190  } else if (pot->domainSize() == max) {
191  ++max_count;
192  }
193 
194  for (const auto p : toRemove)
195  pool.erase(p);
196 
197  pot->erase(*(data.vars.second(elim_order[idx])));
198  }
199 
200  for (const auto pot : trash)
201  delete pot;
202 
203  return std::make_pair(max, max_count);
204  }
205 
206  // The SearchStrategy class
207  template < typename GUM_SCALAR >
209  GUM_CONSTRUCTOR(SearchStrategy);
210  }
211 
212  template < typename GUM_SCALAR >
214  const SearchStrategy< GUM_SCALAR >& from) :
215  _tree(from._tree) {
216  GUM_CONS_CPY(SearchStrategy);
217  }
218 
219  template < typename GUM_SCALAR >
221  GUM_DESTRUCTOR(SearchStrategy);
222  }
223 
224  template < typename GUM_SCALAR >
227  this->_tree = from._tree;
228  return *this;
229  }
230 
231  template < typename GUM_SCALAR >
232  INLINE void
234  this->_tree = tree;
235  }
236 
237  // FrequenceSearch
238 
239  // The FrequenceSearch class
240  template < typename GUM_SCALAR >
242  SearchStrategy< GUM_SCALAR >(), __freq(freq) {
243  GUM_CONSTRUCTOR(FrequenceSearch);
244  }
245 
246  template < typename GUM_SCALAR >
248  const FrequenceSearch< GUM_SCALAR >& from) :
249  SearchStrategy< GUM_SCALAR >(from),
250  __freq(from.__freq) {
251  GUM_CONS_CPY(FrequenceSearch);
252  }
253 
254  template < typename GUM_SCALAR >
256  GUM_DESTRUCTOR(FrequenceSearch);
257  }
258 
259  template < typename GUM_SCALAR >
262  __freq = from.__freq;
263  return *this;
264  }
265 
266  template < typename GUM_SCALAR >
268  return this->_tree->frequency(*r) >= __freq;
269  }
270 
271  template < typename GUM_SCALAR >
273  const Pattern* parent,
274  const Pattern* child,
275  const EdgeGrowth< GUM_SCALAR >& growh) {
276  return this->_tree->frequency(*child) >= __freq;
277  }
278 
279  template < typename GUM_SCALAR >
281  gspan::Pattern* j) {
282  // We want a descending order
283  return this->_tree->frequency(*i) > this->_tree->frequency(*j);
284  }
285 
286  template < typename GUM_SCALAR >
288  LabelData* j) {
289  return (this->_tree->graph().size(i) > this->_tree->graph().size(j));
290  }
291 
292  // StrictSearch
293 
294  // The StrictSearch class
295  template < typename GUM_SCALAR >
297  SearchStrategy< GUM_SCALAR >(), __freq(freq), __dot(".") {
298  GUM_CONSTRUCTOR(StrictSearch);
299  }
300 
301  template < typename GUM_SCALAR >
303  const StrictSearch< GUM_SCALAR >& from) :
304  SearchStrategy< GUM_SCALAR >(from),
305  __freq(from.__freq) {
306  GUM_CONS_CPY(StrictSearch);
307  }
308 
309  template < typename GUM_SCALAR >
311  GUM_DESTRUCTOR(StrictSearch);
312  }
313 
314  template < typename GUM_SCALAR >
317  __freq = from.__freq;
318  return *this;
319  }
320 
321  template < typename GUM_SCALAR >
323  return (this->_tree->frequency(*r) >= __freq);
324  }
325 
326  template < typename GUM_SCALAR >
328  const Pattern* parent,
329  const Pattern* child,
330  const EdgeGrowth< GUM_SCALAR >& growth) {
331  return __inner_cost(child)
332  + this->_tree->frequency(*child) * __outer_cost(child)
333  < this->_tree->frequency(*child) * __outer_cost(parent);
334  }
335 
336  template < typename GUM_SCALAR >
338  gspan::Pattern* j) {
339  return __inner_cost(i) + this->_tree->frequency(*i) * __outer_cost(i)
340  < __inner_cost(j) + this->_tree->frequency(*j) * __outer_cost(j);
341  }
342 
343  template < typename GUM_SCALAR >
345  LabelData* j) {
346  return i->tree_width * this->_tree->graph().size(i)
347  < j->tree_width * this->_tree->graph().size(j);
348  }
349 
350  template < typename GUM_SCALAR >
352  try {
353  return __map[p].first;
354  } catch (NotFound&) {
355  __compute_costs(p);
356  return __map[p].first;
357  }
358  }
359 
360  template < typename GUM_SCALAR >
362  try {
363  return __map[p].second;
364  } catch (NotFound&) {
365  __compute_costs(p);
366  return __map[p].second;
367  }
368  }
369 
370  template < typename GUM_SCALAR >
372  const PRMInstance< GUM_SCALAR >* i,
373  const PRMAttribute< GUM_SCALAR >* a) const {
374  return i->name() + __dot + a->safeName();
375  }
376 
377  template < typename GUM_SCALAR >
379  const PRMInstance< GUM_SCALAR >* i,
380  const PRMAttribute< GUM_SCALAR >& a) const {
381  return i->name() + __dot + a.safeName();
382  }
383 
384  template < typename GUM_SCALAR >
386  const PRMInstance< GUM_SCALAR >* i,
387  const PRMSlotChain< GUM_SCALAR >& a) const {
388  return i->name() + __dot + a.lastElt().safeName();
389  }
390 
391  template < typename GUM_SCALAR >
393  typename StrictSearch< GUM_SCALAR >::PData data;
396  data, pool, *(this->_tree->data(*p).iso_map.begin().val()));
397  double inner = std::log(__elimination_cost(data, pool).first);
398  double outer = this->_computeCost(*p);
399  __map.insert(p, std::make_pair(inner, outer));
400  }
401 
402  // TreeWidthSearch
403 
404  template < typename GUM_SCALAR >
406  SearchStrategy< GUM_SCALAR >() {
407  GUM_CONSTRUCTOR(TreeWidthSearch);
408  }
409 
410  template < typename GUM_SCALAR >
412  const TreeWidthSearch< GUM_SCALAR >& from) :
413  SearchStrategy< GUM_SCALAR >(from) {
414  GUM_CONS_CPY(TreeWidthSearch);
415  }
416 
417  template < typename GUM_SCALAR >
419  GUM_DESTRUCTOR(TreeWidthSearch);
420  }
421 
422  template < typename GUM_SCALAR >
425  return *this;
426  }
427 
428  template < typename GUM_SCALAR >
430  try {
431  return __map[&p];
432  } catch (NotFound&) {
433  __map.insert(&p, this->_computeCost(p));
434  return __map[&p];
435  }
436  }
437 
438  template < typename GUM_SCALAR >
440  Size tree_width = 0;
441 
442  for (const auto n : r->nodes())
443  tree_width += r->label(n).tree_width;
444 
445  return tree_width >= cost(*r);
446  }
447 
448  template < typename GUM_SCALAR >
450  const Pattern* parent,
451  const Pattern* child,
452  const EdgeGrowth< GUM_SCALAR >& growth) {
453  return cost(*parent) >= cost(*child);
454  }
455 
456  template < typename GUM_SCALAR >
458  gspan::Pattern* j) {
459  return cost(*i) < cost(*j);
460  }
461 
462  template < typename GUM_SCALAR >
464  LabelData* j) {
465  return i->tree_width < j->tree_width;
466  }
467 
468  } /* namespace gspan */
469  } /* namespace prm */
470 } /* namespace gum */
void setTree(DFSTree< GUM_SCALAR > *tree)
void insert(const T1 &first, const T2 &second)
Inserts a new association in the gum::Bijection.
aGrUM's Potential is a multi-dimensional array with tensor operators.
Definition: potential.h:60
void __buildPatternGraph(typename StrictSearch< GUM_SCALAR >::PData &data, Set< Potential< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match)
This class is used to define an edge growth of a pattern in this DFSTree.
Definition: edgeGrowth.h:63
NodeSet inners
Returns the set of inner nodes.
virtual Size domainSize() const final
Returns the product of the variables domain size.
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
virtual bool operator()(LabelData *i, LabelData *j)
const T1 & first(const T2 &second) const
Returns the first value of a pair given its second value.
Inner class to handle data about labels in this interface graph.
const std::string & name() const
Returns the name of this object.
Definition: PRMObject_inl.h:35
virtual bool accept_root(const Pattern *r)
An PRMInstance is a Bayesian Network fragment defined by a Class and used in a PRMSystem.
Definition: PRMInstance.h:63
virtual ~TreeWidthSearch()
Destructor.
double __outer_cost(const Pattern *p)
StrictSearch & operator=(const StrictSearch &from)
Copy operator.
The generic class for storing (ordered) sequences of objects.
Definition: sequence.h:1022
virtual void addEdge(const NodeId first, const NodeId second)
insert a new edge into the undirected graph
Definition: undiGraph_inl.h:35
virtual void erase(const DiscreteVariable &var) final
Removes a var from the variables of the multidimensional matrix.
virtual ~FrequenceSearch()
Destructor.
Abstract class representing an element of PRM class.
UndiGraph graph
A yet to be triangulated undigraph.
std::pair< Size, Size > __elimination_cost(typename StrictSearch< GUM_SCALAR >::PData &data, Set< Potential< GUM_SCALAR > * > &pool)
Bijection< NodeId, const DiscreteVariable *> vars
Bijection between the graph's nodes and their corresponding DiscreteVariable, for inference purposes...
Generic doubly linked lists.
Definition: list.h:372
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
virtual NodeId addNode()
insert a new node and return its id
LabelData & label(NodeId node)
Returns the LabelData assigned to node.
Definition: pattern_inl.h:50
DFSTree< GUM_SCALAR > * _tree
virtual bool operator()(LabelData *i, LabelData *j)
FrequenceSearch & operator=(const FrequenceSearch &from)
Copy operator.
Representation of a set. A Set is a structure that contains arbitrary elements.
Definition: set.h:165
StrictSearch(Size freq=2)
Default constructor.
const std::string & safeName() const
Returns the safe name of this PRMClassElement, if any.
A DFSTree is used by gspan to sort lexicographically patterns discovered in an interface graph...
Definition: DFSTree.h:71
A growth is accepted if and only if the new growth has a tree width less than or equal to its father's.
std::string __str(const PRMInstance< GUM_SCALAR > *i, const PRMAttribute< GUM_SCALAR > *a) const
virtual ~SearchStrategy()
Destructor.
virtual bool accept_root(const Pattern *r)
HashTable< const Pattern *, double > __map
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
Class for graph triangulations for which we enforce a given partial ordering on the node elimination...
NodeSet outputs
Returns the set of output nodes given all the matches of the pattern.
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
virtual bool accept_root(const Pattern *r)
const std::vector< NodeId > & eliminationOrder()
returns an elimination ordering compatible with the triangulated graph
Bijection< NodeId, std::string > node2attr
A bijection to easily keep track of the mapping between graph nodes and attributes; it is of the form instance_name DOT attr_name.
TreeWidthSearch()
Default constructor.
bool exists(const Key &k) const
Check the existence of k in the sequence.
Definition: sequence_tpl.h:402
NodeProperty< Size > mod
The modalities of the pattern's variables.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
FrequenceSearch(Size freq)
Default constructor.
A PRMSlotChain represents a sequence of gum::prm::PRMClassElement<GUM_SCALAR> where the n-1 first gum...
Definition: PRMObject.h:221
const NodeGraphPart & nodes() const
Definition: pattern_inl.h:166
Private structure to represent data about a pattern.
HashTable< const Pattern *, std::pair< double, double > > __map
This class is an implementation of a strict strategy for the gSpan algorithm.
Val & insert(const Val &val)
Inserts a new element at the end of the chained list (alias of pushBack).
Definition: list_tpl.h:1619
virtual void add(const DiscreteVariable &v) final
Adds a new var to the variables of the multidimensional matrix.
double __inner_cost(const Pattern *p)
This class is an implementation of a simple search strategy for the gSpan algorithm: it accepts a growth if its frequency is greater than or equal to a user-defined value.
SearchStrategy()
Default constructor.
double _computeCost(const Pattern &p)
void __compute_costs(const Pattern *p)
Size tree_width
The size in terms of tree width of the given label.
PRMAttribute is a member of a Class in a PRM.
Definition: PRMAttribute.h:61
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition: types.h:48
PRMClassElement< GUM_SCALAR > & lastElt()
Returns the last element of the slot chain; typically this is a gum::PRMAttribute or a gum::PRMAggre...
Size size() const noexcept
Returns the number of elements in the set.
Definition: set_tpl.h:701
TreeWidthSearch & operator=(const TreeWidthSearch &from)
Copy operator.
virtual ~StrictSearch()
Destructor.
This contains all the information we want for a node in a DFSTree.
Definition: pattern.h:73
This is an abstract class used to tune search strategies in the gspan algorithm.
Definition: DFSTree.h:61
Size NodeId
Type for node ids.
Definition: graphElements.h:98
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:613
Multidimensional matrix stored as a sparse array in memory.
SearchStrategy< GUM_SCALAR > & operator=(const SearchStrategy< GUM_SCALAR > &from)
Copy operator.
virtual bool operator()(LabelData *i, LabelData *j)
void insert(const Key &k)
Insert an element at the end of the sequence.
Definition: sequence_tpl.h:408