aGrUM  0.14.2
graphChangesSelector4DiGraph_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Christophe GONZALES and Pierre-Henri WUILLEMIN *
3  * {prenom.nom}@lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
26 #ifndef DOXYGEN_SHOULD_SKIP_THIS
27 
28 # include <limits>
29 
30 namespace gum {
31 
32  namespace learning {
33 
35  template < typename STRUCTURAL_CONSTRAINT,
36  typename GRAPH_CHANGES_GENERATOR,
37  template < typename >
38  class ALLOC >
39  INLINE GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
40  GRAPH_CHANGES_GENERATOR,
41  ALLOC >::
42  GraphChangesSelector4DiGraph(Score< ALLOC >& score,
43  STRUCTURAL_CONSTRAINT& constraint,
44  GRAPH_CHANGES_GENERATOR& changes_generator) :
45  __score(score.clone()),
46  __constraint(&constraint), __changes_generator(&changes_generator) {
47  __parents.resize(32);
48  GUM_CONSTRUCTOR(GraphChangesSelector4DiGraph);
49  }
50 
/// copy constructor: the score of "from" is deep-copied through clone()
/// (a null score stays null)
// NOTE(review): original source lines 59 and 64-71 are absent from this
// listing; check the real file before relying on this excerpt.
52  template < typename STRUCTURAL_CONSTRAINT,
53  typename GRAPH_CHANGES_GENERATOR,
54  template < typename >
55  class ALLOC >
56  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
57  GRAPH_CHANGES_GENERATOR,
58  ALLOC >::
60  const GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
61  GRAPH_CHANGES_GENERATOR,
62  ALLOC >& from) :
63  __score(from.__score != nullptr ? from.__score->clone() : nullptr),
72  // for debugging
73  GUM_CONS_CPY(GraphChangesSelector4DiGraph);
74  }
75 
/// move constructor: takes over the score pointer of "from" and resets
/// from.__score to nullptr, so that only *this will release that score
// NOTE(review): original source lines 84, 90, 93, 95-96 and 99 are absent
// from this listing.
77  template < typename STRUCTURAL_CONSTRAINT,
78  typename GRAPH_CHANGES_GENERATOR,
79  template < typename >
80  class ALLOC >
81  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
82  GRAPH_CHANGES_GENERATOR,
83  ALLOC >::
85  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
86  GRAPH_CHANGES_GENERATOR,
87  ALLOC >&& from) :
88  __score(from.__score),
89  __constraint(std::move(from.__constraint)),
91  __changes(std::move(from.__changes)),
92  __change_scores(std::move(from.__change_scores)),
94  __node_queue(std::move(from.__node_queue)),
97  __parents(std::move(from.__parents)),
98  __queues_valid(std::move(from.__queues_valid)),
// "from" no longer owns the score: its destructor must not release it
100  from.__score = nullptr;
101  // for debugging
102  GUM_CONS_MOV(GraphChangesSelector4DiGraph);
103  }
104 
/// destructor: releases the score owned by the selector (if any), using an
/// allocator built from the score's own allocator
// NOTE(review): original source lines 110 and 114 are absent from this
// listing.
106  template < typename STRUCTURAL_CONSTRAINT,
107  typename GRAPH_CHANGES_GENERATOR,
108  template < typename >
109  class ALLOC >
111 
112  STRUCTURAL_CONSTRAINT,
113  GRAPH_CHANGES_GENERATOR,
115  if (__score != nullptr) {
116  ALLOC< Score< ALLOC > > allocator(__score->getAllocator());
117  allocator.destroy(__score);
118  allocator.deallocate(__score, 1);
119  }
120  GUM_DESTRUCTOR(GraphChangesSelector4DiGraph);
121  }
122 
124  template < typename STRUCTURAL_CONSTRAINT,
125  typename GRAPH_CHANGES_GENERATOR,
126  template < typename >
127  class ALLOC >
128  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
129  GRAPH_CHANGES_GENERATOR,
130  ALLOC >&
131  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
132  GRAPH_CHANGES_GENERATOR,
133  ALLOC >::
134  operator=(const GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
135  GRAPH_CHANGES_GENERATOR,
136  ALLOC >& from) {
137  if (this != &from) {
138  // remove the old score
139  if (__score != nullptr) {
140  ALLOC< Score< ALLOC > > allocator(__score->getAllocator());
141  allocator.destroy(__score);
142  allocator.deallocate(__score, 1);
143  __score = nullptr;
144  }
145 
146  if (from.__score != nullptr) __score = from.__score->clone();
147  __constraint = from.__constraint;
148  __changes_generator = from.__changes_generator;
149  __changes = from.__changes;
150  __change_scores = from.__change_scores;
151  __change_queue_per_node = from.__change_queue_per_node;
152  __node_queue = from.__node_queue;
153  __illegal_changes = from.__illegal_changes;
154  __node_current_scores = from.__node_current_scores;
155  __parents = from.__parents;
156  __queues_valid = from.__queues_valid;
157  __queues_to_update = from.__queues_to_update;
158  }
159 
160  return *this;
161  }
162 
164  template < typename STRUCTURAL_CONSTRAINT,
165  typename GRAPH_CHANGES_GENERATOR,
166  template < typename >
167  class ALLOC >
168  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
169  GRAPH_CHANGES_GENERATOR,
170  ALLOC >&
171  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
172  GRAPH_CHANGES_GENERATOR,
173  ALLOC >::
174  operator=(GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
175  GRAPH_CHANGES_GENERATOR,
176  ALLOC >&& from) {
177  if (this != &from) {
178  __score = from.__score;
179  from.__score = nullptr;
180 
181  __constraint = std::move(from.__constraint);
182  __changes_generator = std::move(from.__changes_generator);
183  __changes = std::move(from.__changes);
184  __change_scores = std::move(from.__change_scores);
185  __change_queue_per_node = std::move(from.__change_queue_per_node);
186  __node_queue = std::move(from.__node_queue);
187  __illegal_changes = std::move(from.__illegal_changes);
188  __node_current_scores = std::move(from.__node_current_scores);
189  __parents = std::move(from.__parents);
190  __queues_valid = std::move(from.__queues_valid);
191  __queues_to_update = std::move(from.__queues_to_update);
192  }
193 
194  return *this;
195  }
196 
197 
199  template < typename STRUCTURAL_CONSTRAINT,
200  typename GRAPH_CHANGES_GENERATOR,
201  template < typename >
202  class ALLOC >
203  INLINE bool
204  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
205  GRAPH_CHANGES_GENERATOR,
206  ALLOC >::isChangeValid(const GraphChange&
207  change) const {
208  return __constraint->checkModification(change);
209  }
210 
211 
213  template < typename STRUCTURAL_CONSTRAINT,
214  typename GRAPH_CHANGES_GENERATOR,
215  template < typename >
216  class ALLOC >
217  INLINE bool
218  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
219  GRAPH_CHANGES_GENERATOR,
220  ALLOC >::__isChangeValid(const std::size_t
221  index) const {
222  return isChangeValid(__changes[index]);
223  }
224 
225 
/// Initializes the selector for @p graph.
///
/// Visible steps:
/// (1) add to @p graph the nodes known to the score (ids 0 ..
///     database.nbVariables()-1 when nodeId2Columns is empty, the keys of
///     nodeId2Columns otherwise) and erase the other nodes;
/// (2) pass the graph to __constraint, to the generator's constraint when it
///     is a different object, and to the generator;
/// (3) cache each node's parents and its current score;
/// (4) collect every change produced by the generator;
/// (5) for each valid change, compute its score delta(s) and insert it into
///     the priority queue of the node(s) whose score it modifies;
/// (6) give each node, in __node_queue, the best priority of its change
///     queue.
/// NOTE(review): original source lines 322, 342, 347, 358, 375, 417, 420 and
/// 425 are absent from this listing, including the case labels of the switch
/// below. The comments on the switch branches are inferred from their
/// bodies.
227  template < typename STRUCTURAL_CONSTRAINT,
228  typename GRAPH_CHANGES_GENERATOR,
229  template < typename >
230  class ALLOC >
231  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
232  GRAPH_CHANGES_GENERATOR,
233  ALLOC >::setGraph(DiGraph& graph) {
234  // fill the DAG with all the missing nodes
235  const DatabaseTable< ALLOC >& database = __score->database();
236  const auto& nodeId2Columns = __score->nodeId2Columns();
237 
// without a nodeId2Columns mapping, the node ids used are
// 0 .. database.nbVariables()-1
238  if (nodeId2Columns.empty()) {
239  const NodeId nb_nodes = NodeId(database.nbVariables());
240  for (NodeId i = NodeId(0); i < nb_nodes; ++i) {
241  if (!graph.existsNode(i)) { graph.addNodeWithId(i); }
242  }
243  } else {
244  for (auto iter = nodeId2Columns.cbegin(); iter != nodeId2Columns.cend();
245  ++iter) {
246  const NodeId id = iter.first();
247  if (!graph.existsNode(id)) { graph.addNodeWithId(id); }
248  }
249  }
250 
251 
252  // remove the nodes that belong neither to the database
253  // nor to nodeId2Columns
// NOTE(review): the loops below erase nodes from graph while range-iterating
// over graph; confirm that DiGraph's node iterator remains valid after
// eraseNode().
254  if (nodeId2Columns.empty()) {
255  const NodeId nb_nodes = NodeId(database.nbVariables());
256  for (auto node : graph) {
257  if (node >= nb_nodes) { graph.eraseNode(node); }
258  }
259  } else {
260  for (auto node : graph) {
261  if (!nodeId2Columns.existsFirst(node)) { graph.eraseNode(node); }
262  }
263  }
264 
265 
266  // __constraint is the constraint used by the selector to restrict the set
267  // of applicable changes. However, the generator may have a different set
268  // of constraints (e.g., a constraintSliceOrder needs be tested only by the
269  // generator because the changes returned by the generator will always
270  // satisfy this constraint, hence the selector needs not test this
271  // constraint). Therefore, if the selector and generator have different
272  // constraints, both should use method setGraph() to initialize
273  // themselves.
// The constraints are compared by address, so a constraint object shared
// by the selector and the generator is initialized only once.
274  __constraint->setGraph(graph);
275  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
276  &(__changes_generator->constraint()))
277  != __constraint) {
278  __changes_generator->constraint().setGraph(graph);
279  }
280 
281  __changes_generator->setGraph(graph);
282 
283 
284  // save the set of parents of each node (this will speed-up the
285  // computations of the scores)
286  const std::size_t nb_nodes = graph.size();
287  {
288  const std::vector< NodeId, ALLOC< NodeId > > empty_pars;
289  __parents.clear();
290  __parents.resize(nb_nodes);
// copy each node's parent set into a vector; these vectors are the
// conditioning sets passed to __score->score()
291  for (const auto node : graph) {
292  auto& node_parents = __parents.insert(node, empty_pars).second;
293  const NodeSet& dag_parents = graph.parents(node);
294  if (!dag_parents.empty()) {
295  node_parents.resize(dag_parents.size());
296  std::size_t j = std::size_t(0);
297  for (const auto par : dag_parents) {
298  node_parents[j] = par;
299  ++j;
300  }
301  }
302  }
303  }
304 
305  // assign a score to each node given its parents in the current graph
306  __node_current_scores.clear();
307  __node_current_scores.resize(nb_nodes);
308  for (const auto node : graph) {
309  __node_current_scores.insert(node, __score->score(node, __parents[node]));
310  }
311 
312  // compute all the possible changes
313  __changes.clear();
314  __changes.resize(nb_nodes);
315  for (const auto& change : *__changes_generator) {
316  __changes << change;
317  }
318  __changes_generator->notifyGetCompleted();
319 
320  // determine the changes that are illegal and prepare the computation of
321  // the scores of all the legal changes
323 
324  // set the __change_scores and __change_queue_per_node for legal changes
// NOTE(review): std::numeric_limits< double >::min() is the smallest
// *positive* double, not the most negative one. It is used throughout as
// the "no score / empty queue" sentinel, so a change with a negative delta
// ranks below that sentinel. Confirm this is intended; lowest() may be what
// was meant.
325  __change_scores.clear();
326  __change_scores.resize(
327  __changes.size(),
328  std::pair< double, double >(std::numeric_limits< double >::min(),
329  std::numeric_limits< double >::min()));
330  __change_queue_per_node.clear();
331  __change_queue_per_node.resize(nb_nodes);
332  {
333  const PriorityQueue< std::size_t, double, std::greater< double > >
334  empty_prio;
335  for (const auto node : graph) {
336  __change_queue_per_node.insert(node, empty_prio);
337  }
338  }
339 
// NOTE(review): the body of the !__isChangeValid(i) branch (original line
// 342) is absent from this listing.
340  for (std::size_t i = std::size_t(0); i < __changes.size(); ++i) {
341  if (!__isChangeValid(i)) {
343  } else {
344  const GraphChange& change = __changes[i];
345 
346  switch (change.type()) {
// arc addition (inferred): temporarily append node1 to node2's parents,
// score node2, then restore node2's parents
348  auto& parents = __parents[change.node2()];
349  parents.push_back(change.node1());
350  const double delta = __score->score(change.node2(), parents)
351  - __node_current_scores[change.node2()];
352  parents.pop_back();
353 
354  __change_scores[i].second = delta;
355  __change_queue_per_node[change.node2()].insert(i, delta);
356  } break;
357 
// arc deletion (inferred): temporarily remove node1 from node2's parents
// (overwrite it with the last parent, then pop), score node2, then push
// node1 back; the order of node2's parents may therefore change
359  auto& parents = __parents[change.node2()];
360  for (auto& par : parents) {
361  if (par == change.node1()) {
362  par = *(parents.rbegin());
363  parents.pop_back();
364  break;
365  }
366  }
367  const double delta = __score->score(change.node2(), parents)
368  - __node_current_scores[change.node2()];
369  parents.push_back(change.node1());
370 
371  __change_scores[i].second = delta;
372  __change_queue_per_node[change.node2()].insert(i, delta);
373  } break;
374 
// arc reversal (inferred): the scores of both endpoints change; the change
// is queued for node1 and node2 with priority delta1 + delta2
376  // remove arc ( node1 -> node2 )
377  auto& parents2 = __parents[change.node2()];
378  for (auto& par : parents2) {
379  if (par == change.node1()) {
380  par = *(parents2.rbegin());
381  parents2.pop_back();
382  break;
383  }
384  }
385 
386  const double delta2 = __score->score(change.node2(), parents2)
387  - __node_current_scores[change.node2()];
388  parents2.push_back(change.node1());
389 
390  // add arc ( node2 -> node1 )
391  auto& parents1 = __parents[change.node1()];
392  parents1.push_back(change.node2());
393  const double delta1 = __score->score(change.node1(), parents1)
394  - __node_current_scores[change.node1()];
395  parents1.pop_back();
396 
397  __change_scores[i].first = delta1;
398  __change_scores[i].second = delta2;
399 
400  const double delta = delta1 + delta2;
401  __change_queue_per_node[change.node1()].insert(i, delta);
402  __change_queue_per_node[change.node2()].insert(i, delta);
403 
404  } break;
405 
406  default: {
407  GUM_ERROR(NotImplementedYet,
408  "Method setGraph of GraphChangesSelector4DiGraph "
409  << "does not handle yet graph change of type "
410  << change.type());
411  }
412  }
413  }
414  }
415 
416  // update the global queue
// each node gets the top priority of its change queue, or the min()
// sentinel (the selecting condition, original line 420, is absent from
// this listing)
418  for (const auto node : graph) {
419  __node_queue.insert(node,
421  ? std::numeric_limits< double >::min()
422  : __change_queue_per_node[node].topPriority());
423  }
424  __queues_valid = true;
426  }
427 
428 
430  template < typename STRUCTURAL_CONSTRAINT,
431  typename GRAPH_CHANGES_GENERATOR,
432  template < typename >
433  class ALLOC >
434  void
435  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
436  GRAPH_CHANGES_GENERATOR,
437  ALLOC >::__invalidateChange(const std::size_t
438  change_index) {
439  const GraphChange& change = __changes[change_index];
440  if (change.type() == GraphChangeType::ARC_REVERSAL) {
441  // remove the tail change from its priority queue
442  PriorityQueue< std::size_t, double, std::greater< double > >& queue1 =
443  __change_queue_per_node[change.node1()];
444  queue1.erase(change_index);
445 
446  // recompute the top priority for the changes of the head
447  const double new_priority = queue1.empty()
448  ? std::numeric_limits< double >::min()
449  : queue1.topPriority();
450  __node_queue.setPriority(change.node1(), new_priority);
451  }
452 
453  // remove the head change from its priority queue
454  PriorityQueue< std::size_t, double, std::greater< double > >& queue2 =
455  __change_queue_per_node[change.node2()];
456  queue2.erase(change_index);
457 
458  // recompute the top priority for the changes of the head
459  const double new_priority = queue2.empty()
460  ? std::numeric_limits< double >::min()
461  : queue2.topPriority();
462  __node_queue.setPriority(change.node2(), new_priority);
463 
464  // put the change into the illegal set
465  __illegal_changes.insert(change_index);
466  }
467 
468 
470  template < typename STRUCTURAL_CONSTRAINT,
471  typename GRAPH_CHANGES_GENERATOR,
472  template < typename >
473  class ALLOC >
474  bool GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
475  GRAPH_CHANGES_GENERATOR,
476  ALLOC >::empty() {
477  // put into the illegal change set all the top elements of the different
478  // queues that are not valid anymore
479  if (!__queues_valid) {
480  for (auto& queue_pair : __change_queue_per_node) {
481  auto& queue = queue_pair.second;
482  while (!queue.empty() && !__isChangeValid(queue.top())) {
483  __invalidateChange(queue.top());
484  }
485  }
486  __queues_valid = true;
487  }
488 
489  return __node_queue.topPriority() == std::numeric_limits< double >::min();
490  }
491 
492 
495  template < typename STRUCTURAL_CONSTRAINT,
496  typename GRAPH_CHANGES_GENERATOR,
497  template < typename >
498  class ALLOC >
499  bool GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
500  GRAPH_CHANGES_GENERATOR,
501  ALLOC >::empty(const NodeId node) {
502  // put into the illegal change set all the top elements of the different
503  // queues that are not valid anymore
504  if (!__queues_valid) {
505  for (auto& queue_pair : __change_queue_per_node) {
506  auto& queue = queue_pair.second;
507  while (!queue.empty() && !__isChangeValid(queue.top())) {
508  __invalidateChange(queue.top());
509  }
510  }
511  __queues_valid = true;
512  }
513 
514  return __change_queue_per_node[node].empty();
515  }
516 
517 
519  template < typename STRUCTURAL_CONSTRAINT,
520  typename GRAPH_CHANGES_GENERATOR,
521  template < typename >
522  class ALLOC >
523  INLINE const GraphChange&
524  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
525  GRAPH_CHANGES_GENERATOR,
526  ALLOC >::bestChange() {
527  if (!empty())
528  return __changes[__change_queue_per_node[__node_queue.top()].top()];
529  else
530  GUM_ERROR(NotFound, "there exists no graph change applicable");
531  }
532 
533 
535  template < typename STRUCTURAL_CONSTRAINT,
536  typename GRAPH_CHANGES_GENERATOR,
537  template < typename >
538  class ALLOC >
539  INLINE const GraphChange&
540  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
541  GRAPH_CHANGES_GENERATOR,
542  ALLOC >::bestChange(const NodeId node) {
543  if (!empty(node))
544  return __changes[__change_queue_per_node[node].top()];
545  else
546  GUM_ERROR(NotFound, "there exists no graph change applicable");
547  }
548 
549 
551  template < typename STRUCTURAL_CONSTRAINT,
552  typename GRAPH_CHANGES_GENERATOR,
553  template < typename >
554  class ALLOC >
555  INLINE double GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
556  GRAPH_CHANGES_GENERATOR,
557  ALLOC >::bestScore() {
558  if (!empty())
559  return __change_queue_per_node[__node_queue.top()].topPriority();
560  else
561  GUM_ERROR(NotFound, "there exists no graph change applicable");
562  }
563 
564 
566  template < typename STRUCTURAL_CONSTRAINT,
567  typename GRAPH_CHANGES_GENERATOR,
568  template < typename >
569  class ALLOC >
570  INLINE double
571  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
572  GRAPH_CHANGES_GENERATOR,
573  ALLOC >::bestScore(const NodeId node) {
574  if (!empty(node))
575  return __change_queue_per_node[node].topPriority();
576  else
577  GUM_ERROR(NotFound, "there exists no graph change applicable");
578  }
579 
580 
/// Moves back into the change queues every illegal change that has become
/// valid. Each such change is:
/// - inserted, with the min() placeholder priority, into the queue of its
///   node2 and, for an arc reversal, of its node1;
/// - added to @p changes_to_recompute so that its actual priority can be
///   computed afterwards;
/// - erased from __illegal_changes.
/// NOTE(review): original source line 586 is absent from this listing.
582  template < typename STRUCTURAL_CONSTRAINT,
583  typename GRAPH_CHANGES_GENERATOR,
584  template < typename >
585  class ALLOC >
587  STRUCTURAL_CONSTRAINT,
588  GRAPH_CHANGES_GENERATOR,
589  ALLOC >::__illegal2LegalChanges(Set< std::size_t >& changes_to_recompute) {
// safe iterators are used because elements of __illegal_changes are erased
// inside the loop
590  for (auto iter = __illegal_changes.beginSafe();
591  iter != __illegal_changes.endSafe();
592  ++iter) {
593  if (__isChangeValid(*iter)) {
594  const GraphChange& change = __changes[*iter];
595  if (change.type() == GraphChangeType::ARC_REVERSAL) {
596  __change_queue_per_node[change.node1()].insert(
597  *iter, std::numeric_limits< double >::min());
598  }
599  __change_queue_per_node[change.node2()].insert(
600  *iter, std::numeric_limits< double >::min());
601 
602  changes_to_recompute.insert(*iter);
603  __illegal_changes.erase(iter);
604  }
605  }
606  }
607 
608 
/// Collects into changes_to_recompute the changes currently queued for
/// @p target_node that are still valid and not already collected. The
/// queued changes that have become invalid are moved to the illegal set
/// through __invalidateChange().
/// NOTE(review): original source line 617 is absent from this listing;
/// presumably it holds the function name (the callers use
/// __findLegalChangesNeedingUpdate) and the changes_to_recompute parameter.
/// NOTE(review): __invalidateChange() erases from the very queue whose
/// allValues() table is iterated here; the loop relies on the safe
/// iterators (cbeginSafe) for this. Confirm that this is supported.
610  template < typename STRUCTURAL_CONSTRAINT,
611  typename GRAPH_CHANGES_GENERATOR,
612  template < typename >
613  class ALLOC >
614  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
615  GRAPH_CHANGES_GENERATOR,
616  ALLOC >::
618  const NodeId target_node) {
619  const HashTable< std::size_t, Size >& changes =
620  __change_queue_per_node[target_node].allValues();
621  for (auto iter = changes.cbeginSafe(); iter != changes.cendSafe(); ++iter) {
622  if (!changes_to_recompute.exists(iter.key())) {
623  if (__isChangeValid(iter.key())) {
624  changes_to_recompute.insert(iter.key());
625  } else {
626  __invalidateChange(iter.key());
627  }
628  }
629  }
630  }
631 
632 
/// Recomputes, from the cached parents and current node scores, the score
/// delta(s) of every change in @p changes_to_recompute. Their priorities are
/// updated in the per-node queues, then the priority of every node whose
/// queue was touched is refreshed in __node_queue.
/// NOTE(review): original source lines 638, 648, 666, 690 and 739 are absent
/// from this listing, including the case labels of the switch below.
634  template < typename STRUCTURAL_CONSTRAINT,
635  typename GRAPH_CHANGES_GENERATOR,
636  template < typename >
637  class ALLOC >
639  STRUCTURAL_CONSTRAINT,
640  GRAPH_CHANGES_GENERATOR,
641  ALLOC >::__updateScores(const Set< std::size_t >& changes_to_recompute) {
642  Set< NodeId > modified_nodes(changes_to_recompute.size());
643 
644  for (const auto change_index : changes_to_recompute) {
645  const GraphChange& change = __changes[change_index];
646 
647  switch (change.type()) {
649  // add the arc
650  auto& parents = __parents[change.node2()];
651  parents.push_back(change.node1());
652  const double delta = __score->score(change.node2(), parents)
653  - __node_current_scores[change.node2()];
654  parents.pop_back();
655 
656  // update the score
657  __change_scores[change_index].second = delta;
658 
659  // update the head queue
660  __change_queue_per_node[change.node2()].setPriority(change_index,
661  delta);
662  // indicate which queue was modified
663  modified_nodes.insert(change.node2());
664  } break;
665 
// node1 is removed by overwriting it with the last parent; push_back below
// restores it (the order of the parents may change)
667  // remove the arc
668  auto& parents = __parents[change.node2()];
669  for (auto& par : parents) {
670  if (par == change.node1()) {
671  par = *(parents.rbegin());
672  parents.pop_back();
673  break;
674  }
675  }
676  const double delta = __score->score(change.node2(), parents)
677  - __node_current_scores[change.node2()];
678  parents.push_back(change.node1());
679 
680  // update the score
681  __change_scores[change_index].second = delta;
682 
683  // update the head queue
684  __change_queue_per_node[change.node2()].setPriority(change_index,
685  delta);
686  // indicate which queue was modified
687  modified_nodes.insert(change.node2());
688  } break;
689 
691  // remove arc ( node1 -> node2 )
692  auto& parents2 = __parents[change.node2()];
693  for (auto& par : parents2) {
694  if (par == change.node1()) {
695  par = *(parents2.rbegin());
696  parents2.pop_back();
697  break;
698  }
699  }
700 
701  const double delta2 = __score->score(change.node2(), parents2)
702  - __node_current_scores[change.node2()];
703  parents2.push_back(change.node1());
704 
705  // add arc ( node2 -> node1 )
706  auto& parents1 = __parents[change.node1()];
707  parents1.push_back(change.node2());
708  const double delta1 = __score->score(change.node1(), parents1)
709  - __node_current_scores[change.node1()];
710  parents1.pop_back();
711 
712  // update the scores
713  __change_scores[change_index].first = delta1;
714  __change_scores[change_index].second = delta2;
715 
716  // update the queues
717  const double delta = delta1 + delta2;
718  __change_queue_per_node[change.node1()].setPriority(change_index,
719  delta);
720  __change_queue_per_node[change.node2()].setPriority(change_index,
721  delta);
722 
723  // indicate which queues were modified
724  modified_nodes.insert(change.node1());
725  modified_nodes.insert(change.node2());
726  } break;
727 
728  default: {
729  GUM_ERROR(NotImplementedYet,
730  "Method __updateScores of GraphChangesSelector4DiGraph "
731  << "does not handle yet graph change of type "
732  << change.type());
733  }
734  }
735  }
736 
737  // update the node queue
738  for (const auto node : modified_nodes) {
740  __change_queue_per_node[node].empty()
741  ? std::numeric_limits< double >::min()
742  : __change_queue_per_node[node].topPriority());
743  }
744  }
745 
746 
/// Appends to __changes every change produced by the generator that is not
/// already known, with a placeholder (min(), min()) entry in
/// __change_scores, then notifies the generator that retrieval is complete.
/// NOTE(review): original source line 762 is absent from this listing; per
/// the comment above it, it presumably inserts the new change's index into
/// __illegal_changes.
748  template < typename STRUCTURAL_CONSTRAINT,
749  typename GRAPH_CHANGES_GENERATOR,
750  template < typename >
751  class ALLOC >
752  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
753  GRAPH_CHANGES_GENERATOR,
754  ALLOC >::__getNewChanges() {
755  // ask the graph change generator for all its available changes
756  for (const auto& change : *__changes_generator) {
757  // check that the change does not already exist
758  if (!__changes.exists(change)) {
759  // add the new change. To make the addition simple, we put the new
760  // change into the illegal changes set. Afterwards, the applyChange
761  // function will put the legal changes again into the queues
763  __changes << change;
764  __change_scores.push_back(
765  std::pair< double, double >(std::numeric_limits< double >::min(),
766  std::numeric_limits< double >::min()));
767  }
768  }
769 
770  // indicate to the generator that we have finished retrieving its changes
771  __changes_generator->notifyGetCompleted();
772  }
773 
774 
/// Applies @p change to the selector's state and immediately recomputes the
/// affected scores. The steps are:
/// - add the cached delta(s) to the current node score(s) and update
///   __parents;
/// - notify __constraint, the generator's constraint (when it is a distinct
///   object) and the generator;
/// - fetch the generator's new changes and re-queue the illegal changes that
///   became valid;
/// - invalidate the applied change;
/// - recompute the scores of the changes queued for the modified node(s).
/// Finally the queues are flagged as needing validation (__queues_valid).
/// Only the selector's internal state is updated here; no DiGraph is touched
/// by this method.
/// NOTE(review): original source lines 790, 818 and 853 (the case labels)
/// are absent from this listing.
/// NOTE(review): @p change may refer to an element of __changes (e.g. the
/// result of bestChange()). It is still read after __getNewChanges()
/// appends to __changes; confirm that Sequence keeps element addresses
/// stable on insertion.
776  template < typename STRUCTURAL_CONSTRAINT,
777  typename GRAPH_CHANGES_GENERATOR,
778  template < typename >
779  class ALLOC >
780  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
781  GRAPH_CHANGES_GENERATOR,
782  ALLOC >::applyChange(const GraphChange&
783  change) {
784  // first, we get the index of the change
785  const std::size_t change_index = __changes.pos(change);
786 
787  // perform the change
788  Set< std::size_t > changes_to_recompute;
789  switch (change.type()) {
791  // update the current score
792  __node_current_scores[change.node2()] +=
793  __change_scores[change_index].second;
794  __parents[change.node2()].push_back(change.node1());
795 
796  // inform the constraint that the graph has been modified
797  __constraint->modifyGraph(static_cast< const ArcAddition& >(change));
798  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
799  &(__changes_generator->constraint()))
800  != __constraint) {
801  __changes_generator->constraint().modifyGraph(
802  static_cast< const ArcAddition& >(change));
803  }
804 
805  // get new possible changes from the graph change generator
806  // warning: put the next 3 lines before calling __illegal2LegalChanges
807  __changes_generator->modifyGraph(
808  static_cast< const ArcAddition& >(change));
809  __getNewChanges();
810 
811  // check whether some illegal changes can be put into the valid queues
812  __illegal2LegalChanges(changes_to_recompute);
813  __invalidateChange(change_index);
814  __findLegalChangesNeedingUpdate(changes_to_recompute, change.node2());
815  __updateScores(changes_to_recompute);
816  } break;
817 
819  // update the current score
820  __node_current_scores[change.node2()] +=
821  __change_scores[change_index].second;
822  auto& parents = __parents[change.node2()];
823  for (auto& par : parents) {
824  if (par == change.node1()) {
825  par = *(parents.rbegin());
826  parents.pop_back();
827  break;
828  }
829  }
830 
831  // inform the constraint that the graph has been modified
832  __constraint->modifyGraph(static_cast< const ArcDeletion& >(change));
833  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
834  &(__changes_generator->constraint()))
835  != __constraint) {
836  __changes_generator->constraint().modifyGraph(
837  static_cast< const ArcDeletion& >(change));
838  }
839 
840  // get new possible changes from the graph change generator
841  // warning: put the next 3 lines before calling __illegal2LegalChanges
842  __changes_generator->modifyGraph(
843  static_cast< const ArcDeletion& >(change));
844  __getNewChanges();
845 
846  // check whether some illegal changes can be put into the valid queues
847  __illegal2LegalChanges(changes_to_recompute);
848  __invalidateChange(change_index);
849  __findLegalChangesNeedingUpdate(changes_to_recompute, change.node2());
850  __updateScores(changes_to_recompute);
851  } break;
852 
854  // update the current score
855  __node_current_scores[change.node1()] +=
856  __change_scores[change_index].first;
857  __node_current_scores[change.node2()] +=
858  __change_scores[change_index].second;
859  __parents[change.node1()].push_back(change.node2());
860  auto& parents = __parents[change.node2()];
861  for (auto& par : parents) {
862  if (par == change.node1()) {
863  par = *(parents.rbegin());
864  parents.pop_back();
865  break;
866  }
867  }
868 
869  // inform the constraint that the graph has been modified
870  __constraint->modifyGraph(static_cast< const ArcReversal& >(change));
871  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
872  &(__changes_generator->constraint()))
873  != __constraint) {
874  __changes_generator->constraint().modifyGraph(
875  static_cast< const ArcReversal& >(change));
876  }
877 
878  // get new possible changes from the graph change generator
879  // warning: put the next 3 lines before calling __illegal2LegalChanges
880  __changes_generator->modifyGraph(
881  static_cast< const ArcReversal& >(change));
882  __getNewChanges();
883 
884  // check whether some illegal changes can be put into the valid queues
885  __illegal2LegalChanges(changes_to_recompute);
886  __invalidateChange(change_index);
887  __findLegalChangesNeedingUpdate(changes_to_recompute, change.node1());
888  __findLegalChangesNeedingUpdate(changes_to_recompute, change.node2());
889  __updateScores(changes_to_recompute);
890  } break;
891 
892  default:
893  GUM_ERROR(NotImplementedYet,
894  "Method applyChange of GraphChangesSelector4DiGraph "
895  << "does not handle yet graph change of type "
896  << change.type());
897  }
898 
// the tops of the queues may now be invalid: empty() re-validates them
899  __queues_valid = false;
900  }
901 
902 
/// Applies @p change to the selector's state without recomputing the scores
/// of the other changes. The steps are:
/// - add the cached delta(s) to the current node score(s) and update
///   __parents;
/// - notify __constraint, the generator's constraint (when distinct) and the
///   generator;
/// - fetch new changes and invalidate the applied change;
/// - record the modified node(s) in __queues_to_update.
/// NOTE(review): unlike applyChange(), the visible code neither calls
/// __illegal2LegalChanges() nor sets __queues_valid to false; presumably a
/// later call to updateScoresAfterAppliedChanges() restores consistency.
/// Confirm this. The "warning: ... before calling __illegal2LegalChanges"
/// comments below appear to be copied from applyChange().
/// NOTE(review): original source lines 908, 917, 946 and 982 are absent
/// from this listing, including the case labels.
904  template < typename STRUCTURAL_CONSTRAINT,
905  typename GRAPH_CHANGES_GENERATOR,
906  template < typename >
907  class ALLOC >
909  STRUCTURAL_CONSTRAINT,
910  GRAPH_CHANGES_GENERATOR,
911  ALLOC >::applyChangeWithoutScoreUpdate(const GraphChange& change) {
912  // first, we get the index of the change
913  const std::size_t change_index = __changes.pos(change);
914 
915  // perform the change
916  switch (change.type()) {
918  // update the current score
919  __node_current_scores[change.node2()] +=
920  __change_scores[change_index].second;
921  __parents[change.node2()].push_back(change.node1());
922 
923  // inform the constraint that the graph has been modified
924  __constraint->modifyGraph(static_cast< const ArcAddition& >(change));
925  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
926  &(__changes_generator->constraint()))
927  != __constraint) {
928  __changes_generator->constraint().modifyGraph(
929  static_cast< const ArcAddition& >(change));
930  }
931 
932  // get new possible changes from the graph change generator
933  // warning: put the next 3 lines before calling __illegal2LegalChanges
934  __changes_generator->modifyGraph(
935  static_cast< const ArcAddition& >(change));
936  __getNewChanges();
937 
938  // indicate that we have just applied the change
939  __invalidateChange(change_index);
940 
941  // indicate that the queue to which the change belongs needs be
942  // updated
943  __queues_to_update.insert(change.node2());
944  } break;
945 
947  // update the current score
948  __node_current_scores[change.node2()] +=
949  __change_scores[change_index].second;
950  auto& parents = __parents[change.node2()];
951  for (auto& par : parents) {
952  if (par == change.node1()) {
953  par = *(parents.rbegin());
954  parents.pop_back();
955  break;
956  }
957  }
958 
959  // inform the constraint that the graph has been modified
960  __constraint->modifyGraph(static_cast< const ArcDeletion& >(change));
961  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
962  &(__changes_generator->constraint()))
963  != __constraint) {
964  __changes_generator->constraint().modifyGraph(
965  static_cast< const ArcDeletion& >(change));
966  }
967 
968  // get new possible changes from the graph change generator
969  // warning: put the next 3 lines before calling __illegal2LegalChanges
970  __changes_generator->modifyGraph(
971  static_cast< const ArcDeletion& >(change));
972  __getNewChanges();
973 
974  // indicate that we have just applied the change
975  __invalidateChange(change_index);
976 
977  // indicate that the queue to which the change belongs needs be
978  // updated
979  __queues_to_update.insert(change.node2());
980  } break;
981 
983  // update the current score
984  __node_current_scores[change.node1()] +=
985  __change_scores[change_index].first;
986  __node_current_scores[change.node2()] +=
987  __change_scores[change_index].second;
988  __parents[change.node1()].push_back(change.node2());
989  auto& parents = __parents[change.node2()];
990  for (auto& par : parents) {
991  if (par == change.node1()) {
992  par = *(parents.rbegin());
993  parents.pop_back();
994  break;
995  }
996  }
997 
998  // inform the constraint that the graph has been modified
999  __constraint->modifyGraph(static_cast< const ArcReversal& >(change));
1000  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
1001  &(__changes_generator->constraint()))
1002  != __constraint) {
1003  __changes_generator->constraint().modifyGraph(
1004  static_cast< const ArcReversal& >(change));
1005  }
1006 
1007  // get new possible changes from the graph change generator
1008  // warning: put the next 3 lines before calling __illegal2LegalChanges
1009  __changes_generator->modifyGraph(
1010  static_cast< const ArcReversal& >(change));
1011  __getNewChanges();
1012 
1013  // indicate that we have just applied the change
1014  __invalidateChange(change_index);
1015 
1016  // indicate that the queue to which the change belongs needs be
1017  // updated
1018  __queues_to_update.insert(change.node1());
1019  __queues_to_update.insert(change.node2());
1020  } break;
1021 
1022  default:
1023  GUM_ERROR(NotImplementedYet,
1024  "Method applyChangeWithoutScoreUpdate of "
1025  << "GraphChangesSelector4DiGraph "
1026  << "does not handle yet graph change of type "
1027  << change.type());
1028  }
1029  }
1030 
1031 
1033  template < typename STRUCTURAL_CONSTRAINT,
1034  typename GRAPH_CHANGES_GENERATOR,
1035  template < typename >
1036  class ALLOC >
1037  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1038  GRAPH_CHANGES_GENERATOR,
1040  // determine which changes in the illegal set are now legal
1041  Set< std::size_t > new_legal_changes;
1042  for (auto iter = __illegal_changes.beginSafe();
1043  iter != __illegal_changes.endSafe();
1044  ++iter) {
1045  if (__isChangeValid(*iter)) {
1046  new_legal_changes.insert(*iter);
1047  __illegal_changes.erase(iter);
1048  }
1049  }
1050 
1051  // update the scores that need be updated
1052  Set< std::size_t > changes_to_recompute;
1053  for (const auto& node : __queues_to_update) {
1054  __findLegalChangesNeedingUpdate(changes_to_recompute, node);
1055  }
1056  __queues_to_update.clear();
1057 
1058  // put the previously illegal changes that are now legal into their queues
1059  for (const auto change_index : new_legal_changes) {
1060  const GraphChange& change = __changes[change_index];
1061  if (change.type() == GraphChangeType::ARC_REVERSAL) {
1062  __change_queue_per_node[change.node1()].insert(
1063  change_index, std::numeric_limits< double >::min());
1064  }
1065  __change_queue_per_node[change.node2()].insert(
1066  change_index, std::numeric_limits< double >::min());
1067 
1068  changes_to_recompute.insert(change_index);
1069  }
1070 
1071  // compute the scores that we need
1072  __updateScores(changes_to_recompute);
1073 
1074  __queues_valid = false;
1075  }
1076 
1077 
1079  template < typename STRUCTURAL_CONSTRAINT,
1080  typename GRAPH_CHANGES_GENERATOR,
1081  template < typename >
1082  class ALLOC >
1083  std::vector< std::pair< NodeId, double > >
1084  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1085  GRAPH_CHANGES_GENERATOR,
1086  ALLOC >::nodesSortedByBestScore() const {
1087  std::vector< std::pair< NodeId, double > > result(__node_queue.size());
1088  for (std::size_t i = std::size_t(0); i < __node_queue.size(); ++i) {
1089  result[i].first = __node_queue[i];
1090  result[i].second = __node_queue.priorityByPos(i);
1091  }
1092 
1093  std::sort(result.begin(),
1094  result.end(),
1095  [](const std::pair< NodeId, double >& a,
1096  const std::pair< NodeId, double >& b) -> bool {
1097  return a.second > b.second;
1098  });
1099 
1100  return result;
1101  }
1102 
1103 
1105  template < typename STRUCTURAL_CONSTRAINT,
1106  typename GRAPH_CHANGES_GENERATOR,
1107  template < typename >
1108  class ALLOC >
1109  std::vector< std::pair< NodeId, double > >
1110  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1111  GRAPH_CHANGES_GENERATOR,
1112  ALLOC >::nodesUnsortedWithScore() const {
1113  std::vector< std::pair< NodeId, double > > result(__node_queue.size());
1114  for (std::size_t i = std::size_t(0); i < __node_queue.size(); ++i) {
1115  result[i].first = __node_queue[i];
1116  result[i].second = __node_queue.priorityByPos(i);
1117  }
1118 
1119  return result;
1120  }
1121 
1122 
1124  template < typename STRUCTURAL_CONSTRAINT,
1125  typename GRAPH_CHANGES_GENERATOR,
1126  template < typename >
1127  class ALLOC >
1128  INLINE typename GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1129  GRAPH_CHANGES_GENERATOR,
1130  ALLOC >::GeneratorType&
1131  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1132  GRAPH_CHANGES_GENERATOR,
1133  ALLOC >::graphChangeGenerator() const
1134  noexcept {
1135  return *__changes_generator;
1136  }
1137 
1138 
1139  } /* namespace learning */
1140 
1141 } /* namespace gum */
1142 
1143 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
NodeProperty< double > __node_current_scores
the current score of each node
bool empty()
indicates whether the selector still contains graph changes
GeneratorType & graphChangeGenerator() const noexcept
returns the generator used by the selector
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
void __updateScores(const Set< std::size_t > &changes_to_recompute)
perform the necessary updates of the scores
STRUCTURAL_CONSTRAINT * __constraint
the set of constraints used to determine valid changes
const GraphChange & bestChange()
returns the best graph change to examine
STL namespace.
const Val & top() const
returns the element at the top of the priority queue
void applyChange(const GraphChange &change)
indicate to the selector that a change has been applied
Set< std::size_t > __illegal_changes
the set of changes known to be currently illegal (due to the constraints)
void erase(const Key &k)
Erases an element from the set.
Definition: set_tpl.h:653
const Priority & priorityByPos(Size index) const
Returns the priority of the element stored at the given position.
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
PriorityQueue< NodeId, double, std::greater< double > > __node_queue
a global priority queue indicating for each node its best score
std::vector< std::pair< double, double > > __change_scores
the scores for the head and tail of all the changes
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
Definition: set_tpl.h:499
void setPriority(const Val &elt, const Priority &new_priority)
Modifies the priority of each instance of a given element.
void applyChangeWithoutScoreUpdate(const GraphChange &change)
indicate to the selector that one of several changes has been applied
bool exists(const Key &k) const
Indicates whether a given element belongs to the set.
Definition: set_tpl.h:604
bool __isChangeValid(const std::size_t index) const
indicates whether a given change is valid or not
void resize(Size new_capacity)
Changes the size of the underlying hash table containing the set.
Definition: set_tpl.h:549
NodeProperty< std::vector< NodeId, ALLOC< NodeId > > > __parents
the set of parents of each node (speeds-up score computations)
Size insert(const Val &val, const Priority &priority)
Inserts a new (a copy) element in the priority queue.
Set< NodeId > __queues_to_update
the set of queues to update when applying several changes
double bestScore()
return the score of the best graph change
bool __queues_valid
indicates whether we need to recompute whether the queue is empty or not
GraphChangesSelector4DiGraph(Score< ALLOC > &score, STRUCTURAL_CONSTRAINT &constraint, GRAPH_CHANGES_GENERATOR &changes_generator)
default constructor
void updateScoresAfterAppliedChanges()
recompute all the scores after the application of several changes
std::vector< std::pair< NodeId, double > > nodesUnsortedWithScore() const
returns the set of queues top priorities
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
Definition: set_tpl.h:485
Sequence< GraphChange > __changes
a sequence containing all the possible changes
std::vector< std::pair< NodeId, double > > nodesSortedByBestScore() const
returns the set of queues sorted by decreasing top priority
void clear()
Removes all the elements from the queue.
void __invalidateChange(const std::size_t change_index)
put a change into the illegal set
void __findLegalChangesNeedingUpdate(Set< std::size_t > &changes_to_recompute, const NodeId target_node)
finds the changes that are affected by a given node modification
void __illegal2LegalChanges(Set< std::size_t > &changes_to_recompute)
remove the now legal changes from the illegal set
void setGraph(DiGraph &graph)
sets the graph from which scores are computed
GRAPH_CHANGES_GENERATOR GeneratorType
the type of the generator
void clear()
Removes all the elements, if any, from the set.
Definition: set_tpl.h:372
GRAPH_CHANGES_GENERATOR * __changes_generator
the generator that returns the set of possible changes
void __getNewChanges()
get from the graph change generator a new set of changes
Size size() const noexcept
Returns the number of elements in the set.
Definition: set_tpl.h:698
bool isChangeValid(const GraphChange &change) const
indicates whether a given change is valid or not
const Priority & topPriority() const
Returns the priority of the top element.
Size NodeId
Type for node ids.
Definition: graphElements.h:97
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:610
NodeProperty< PriorityQueue< std::size_t, double, std::greater< double > > > __change_queue_per_node
for each node, a priority queue sorting GraphChanges by decreasing score
Size size() const noexcept
Returns the number of elements in the priority queue.
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52