aGrUM  0.16.0
graphChangesSelector4DiGraph_tpl.h
Go to the documentation of this file.
1 
29 #ifndef DOXYGEN_SHOULD_SKIP_THIS
30 
31 # include <limits>
32 
33 namespace gum {
34 
35  namespace learning {
36 
38  template < typename STRUCTURAL_CONSTRAINT,
39  typename GRAPH_CHANGES_GENERATOR,
40  template < typename >
41  class ALLOC >
42  INLINE GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
43  GRAPH_CHANGES_GENERATOR,
44  ALLOC >::
45  GraphChangesSelector4DiGraph(Score< ALLOC >& score,
46  STRUCTURAL_CONSTRAINT& constraint,
47  GRAPH_CHANGES_GENERATOR& changes_generator) :
48  __score(score.clone()),
49  __constraint(&constraint), __changes_generator(&changes_generator) {
50  __parents.resize(32);
51  GUM_CONSTRUCTOR(GraphChangesSelector4DiGraph);
52  }
53 
55  template < typename STRUCTURAL_CONSTRAINT,
56  typename GRAPH_CHANGES_GENERATOR,
57  template < typename >
58  class ALLOC >
59  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
60  GRAPH_CHANGES_GENERATOR,
61  ALLOC >::
63  const GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
64  GRAPH_CHANGES_GENERATOR,
65  ALLOC >& from) :
66  __score(from.__score != nullptr ? from.__score->clone() : nullptr),
75  // for debugging
76  GUM_CONS_CPY(GraphChangesSelector4DiGraph);
77  }
78 
80  template < typename STRUCTURAL_CONSTRAINT,
81  typename GRAPH_CHANGES_GENERATOR,
82  template < typename >
83  class ALLOC >
84  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
85  GRAPH_CHANGES_GENERATOR,
86  ALLOC >::
88  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
89  GRAPH_CHANGES_GENERATOR,
90  ALLOC >&& from) :
91  __score(from.__score),
92  __constraint(std::move(from.__constraint)),
94  __changes(std::move(from.__changes)),
95  __change_scores(std::move(from.__change_scores)),
97  __node_queue(std::move(from.__node_queue)),
100  __parents(std::move(from.__parents)),
101  __queues_valid(std::move(from.__queues_valid)),
103  from.__score = nullptr;
104  // for debugging
105  GUM_CONS_MOV(GraphChangesSelector4DiGraph);
106  }
107 
109  template < typename STRUCTURAL_CONSTRAINT,
110  typename GRAPH_CHANGES_GENERATOR,
111  template < typename >
112  class ALLOC >
114 
115  STRUCTURAL_CONSTRAINT,
116  GRAPH_CHANGES_GENERATOR,
118  if (__score != nullptr) {
119  ALLOC< Score< ALLOC > > allocator(__score->getAllocator());
120  allocator.destroy(__score);
121  allocator.deallocate(__score, 1);
122  }
123  GUM_DESTRUCTOR(GraphChangesSelector4DiGraph);
124  }
125 
127  template < typename STRUCTURAL_CONSTRAINT,
128  typename GRAPH_CHANGES_GENERATOR,
129  template < typename >
130  class ALLOC >
131  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
132  GRAPH_CHANGES_GENERATOR,
133  ALLOC >&
134  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
135  GRAPH_CHANGES_GENERATOR,
136  ALLOC >::
137  operator=(const GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
138  GRAPH_CHANGES_GENERATOR,
139  ALLOC >& from) {
140  if (this != &from) {
141  // remove the old score
142  if (__score != nullptr) {
143  ALLOC< Score< ALLOC > > allocator(__score->getAllocator());
144  allocator.destroy(__score);
145  allocator.deallocate(__score, 1);
146  __score = nullptr;
147  }
148 
149  if (from.__score != nullptr) __score = from.__score->clone();
150  __constraint = from.__constraint;
151  __changes_generator = from.__changes_generator;
152  __changes = from.__changes;
153  __change_scores = from.__change_scores;
154  __change_queue_per_node = from.__change_queue_per_node;
155  __node_queue = from.__node_queue;
156  __illegal_changes = from.__illegal_changes;
157  __node_current_scores = from.__node_current_scores;
158  __parents = from.__parents;
159  __queues_valid = from.__queues_valid;
160  __queues_to_update = from.__queues_to_update;
161  }
162 
163  return *this;
164  }
165 
167  template < typename STRUCTURAL_CONSTRAINT,
168  typename GRAPH_CHANGES_GENERATOR,
169  template < typename >
170  class ALLOC >
171  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
172  GRAPH_CHANGES_GENERATOR,
173  ALLOC >&
174  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
175  GRAPH_CHANGES_GENERATOR,
176  ALLOC >::
177  operator=(GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
178  GRAPH_CHANGES_GENERATOR,
179  ALLOC >&& from) {
180  if (this != &from) {
181  __score = from.__score;
182  from.__score = nullptr;
183 
184  __constraint = std::move(from.__constraint);
185  __changes_generator = std::move(from.__changes_generator);
186  __changes = std::move(from.__changes);
187  __change_scores = std::move(from.__change_scores);
188  __change_queue_per_node = std::move(from.__change_queue_per_node);
189  __node_queue = std::move(from.__node_queue);
190  __illegal_changes = std::move(from.__illegal_changes);
191  __node_current_scores = std::move(from.__node_current_scores);
192  __parents = std::move(from.__parents);
193  __queues_valid = std::move(from.__queues_valid);
194  __queues_to_update = std::move(from.__queues_to_update);
195  }
196 
197  return *this;
198  }
199 
200 
202  template < typename STRUCTURAL_CONSTRAINT,
203  typename GRAPH_CHANGES_GENERATOR,
204  template < typename >
205  class ALLOC >
206  INLINE bool
207  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
208  GRAPH_CHANGES_GENERATOR,
209  ALLOC >::isChangeValid(const GraphChange&
210  change) const {
211  return __constraint->checkModification(change);
212  }
213 
214 
216  template < typename STRUCTURAL_CONSTRAINT,
217  typename GRAPH_CHANGES_GENERATOR,
218  template < typename >
219  class ALLOC >
220  INLINE bool
221  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
222  GRAPH_CHANGES_GENERATOR,
223  ALLOC >::__isChangeValid(const std::size_t
224  index) const {
225  return isChangeValid(__changes[index]);
226  }
227 
228 
230  template < typename STRUCTURAL_CONSTRAINT,
231  typename GRAPH_CHANGES_GENERATOR,
232  template < typename >
233  class ALLOC >
234  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
235  GRAPH_CHANGES_GENERATOR,
236  ALLOC >::setGraph(DiGraph& graph) {
237  // fill the DAG with all the missing nodes
238  const DatabaseTable< ALLOC >& database = __score->database();
239  const auto& nodeId2Columns = __score->nodeId2Columns();
240 
241  if (nodeId2Columns.empty()) {
242  const NodeId nb_nodes = NodeId(database.nbVariables());
243  for (NodeId i = NodeId(0); i < nb_nodes; ++i) {
244  if (!graph.existsNode(i)) { graph.addNodeWithId(i); }
245  }
246  } else {
247  for (auto iter = nodeId2Columns.cbegin(); iter != nodeId2Columns.cend();
248  ++iter) {
249  const NodeId id = iter.first();
250  if (!graph.existsNode(id)) { graph.addNodeWithId(id); }
251  }
252  }
253 
254 
255  // remove the node that do belong neither to the database
256  // nor to nodeId2Columns
257  if (nodeId2Columns.empty()) {
258  const NodeId nb_nodes = NodeId(database.nbVariables());
259  for (auto node : graph) {
260  if (node >= nb_nodes) { graph.eraseNode(node); }
261  }
262  } else {
263  for (auto node : graph) {
264  if (!nodeId2Columns.existsFirst(node)) { graph.eraseNode(node); }
265  }
266  }
267 
268 
269  // __constraint is the constraint used by the selector to restrict the set
270  // of applicable changes. However, the generator may have a different set
271  // of constraints (e.g., a constraintSliceOrder needs be tested only by the
272  // generator because the changes returned by the generator will always
273  // statisfy this constraint, hence the selector needs not test this
274  // constraint). Therefore, if the selector and generator have different
275  // constraints, both should use method setGraph() to initialize
276  // themselves.
277  __constraint->setGraph(graph);
278  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
279  &(__changes_generator->constraint()))
280  != __constraint) {
281  __changes_generator->constraint().setGraph(graph);
282  }
283 
284  __changes_generator->setGraph(graph);
285 
286 
287  // save the set of parents of each node (this will speed-up the
288  // computations of the scores)
289  const std::size_t nb_nodes = graph.size();
290  {
291  const std::vector< NodeId, ALLOC< NodeId > > empty_pars;
292  __parents.clear();
293  __parents.resize(nb_nodes);
294  for (const auto node : graph) {
295  auto& node_parents = __parents.insert(node, empty_pars).second;
296  const NodeSet& dag_parents = graph.parents(node);
297  if (!dag_parents.empty()) {
298  node_parents.resize(dag_parents.size());
299  std::size_t j = std::size_t(0);
300  for (const auto par : dag_parents) {
301  node_parents[j] = par;
302  ++j;
303  }
304  }
305  }
306  }
307 
308  // assign a score to each node given its parents in the current graph
309  __node_current_scores.clear();
310  __node_current_scores.resize(nb_nodes);
311  for (const auto node : graph) {
312  __node_current_scores.insert(node, __score->score(node, __parents[node]));
313  }
314 
315  // compute all the possible changes
316  __changes.clear();
317  __changes.resize(nb_nodes);
318  for (const auto& change : *__changes_generator) {
319  __changes << change;
320  }
321  __changes_generator->notifyGetCompleted();
322 
323  // determine the changes that are illegal and prepare the computation of
324  // the scores of all the legal changes
326 
327  // set the __change_scores and __change_queue_per_node for legal changes
328  __change_scores.clear();
329  __change_scores.resize(
330  __changes.size(),
331  std::pair< double, double >(std::numeric_limits< double >::min(),
332  std::numeric_limits< double >::min()));
333  __change_queue_per_node.clear();
334  __change_queue_per_node.resize(nb_nodes);
335  {
336  const PriorityQueue< std::size_t, double, std::greater< double > >
337  empty_prio;
338  for (const auto node : graph) {
339  __change_queue_per_node.insert(node, empty_prio);
340  }
341  }
342 
343  for (std::size_t i = std::size_t(0); i < __changes.size(); ++i) {
344  if (!__isChangeValid(i)) {
346  } else {
347  const GraphChange& change = __changes[i];
348 
349  switch (change.type()) {
351  auto& parents = __parents[change.node2()];
352  parents.push_back(change.node1());
353  const double delta = __score->score(change.node2(), parents)
354  - __node_current_scores[change.node2()];
355  parents.pop_back();
356 
357  __change_scores[i].second = delta;
358  __change_queue_per_node[change.node2()].insert(i, delta);
359  } break;
360 
362  auto& parents = __parents[change.node2()];
363  for (auto& par : parents) {
364  if (par == change.node1()) {
365  par = *(parents.rbegin());
366  parents.pop_back();
367  break;
368  }
369  }
370  const double delta = __score->score(change.node2(), parents)
371  - __node_current_scores[change.node2()];
372  parents.push_back(change.node1());
373 
374  __change_scores[i].second = delta;
375  __change_queue_per_node[change.node2()].insert(i, delta);
376  } break;
377 
379  // remove arc ( node1 -> node2 )
380  auto& parents2 = __parents[change.node2()];
381  for (auto& par : parents2) {
382  if (par == change.node1()) {
383  par = *(parents2.rbegin());
384  parents2.pop_back();
385  break;
386  }
387  }
388 
389  const double delta2 = __score->score(change.node2(), parents2)
390  - __node_current_scores[change.node2()];
391  parents2.push_back(change.node1());
392 
393  // add arc ( node2 -> node1 )
394  auto& parents1 = __parents[change.node1()];
395  parents1.push_back(change.node2());
396  const double delta1 = __score->score(change.node1(), parents1)
397  - __node_current_scores[change.node1()];
398  parents1.pop_back();
399 
400  __change_scores[i].first = delta1;
401  __change_scores[i].second = delta2;
402 
403  const double delta = delta1 + delta2;
404  __change_queue_per_node[change.node1()].insert(i, delta);
405  __change_queue_per_node[change.node2()].insert(i, delta);
406 
407  } break;
408 
409  default: {
410  GUM_ERROR(NotImplementedYet,
411  "Method setGraph of GraphChangesSelector4DiGraph "
412  << "does not handle yet graph change of type "
413  << change.type());
414  }
415  }
416  }
417  }
418 
419  // update the global queue
421  for (const auto node : graph) {
422  __node_queue.insert(node,
424  ? std::numeric_limits< double >::min()
425  : __change_queue_per_node[node].topPriority());
426  }
427  __queues_valid = true;
429  }
430 
431 
433  template < typename STRUCTURAL_CONSTRAINT,
434  typename GRAPH_CHANGES_GENERATOR,
435  template < typename >
436  class ALLOC >
437  void
438  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
439  GRAPH_CHANGES_GENERATOR,
440  ALLOC >::__invalidateChange(const std::size_t
441  change_index) {
442  const GraphChange& change = __changes[change_index];
443  if (change.type() == GraphChangeType::ARC_REVERSAL) {
444  // remove the tail change from its priority queue
445  PriorityQueue< std::size_t, double, std::greater< double > >& queue1 =
446  __change_queue_per_node[change.node1()];
447  queue1.erase(change_index);
448 
449  // recompute the top priority for the changes of the head
450  const double new_priority = queue1.empty()
451  ? std::numeric_limits< double >::min()
452  : queue1.topPriority();
453  __node_queue.setPriority(change.node1(), new_priority);
454  }
455 
456  // remove the head change from its priority queue
457  PriorityQueue< std::size_t, double, std::greater< double > >& queue2 =
458  __change_queue_per_node[change.node2()];
459  queue2.erase(change_index);
460 
461  // recompute the top priority for the changes of the head
462  const double new_priority = queue2.empty()
463  ? std::numeric_limits< double >::min()
464  : queue2.topPriority();
465  __node_queue.setPriority(change.node2(), new_priority);
466 
467  // put the change into the illegal set
468  __illegal_changes.insert(change_index);
469  }
470 
471 
473  template < typename STRUCTURAL_CONSTRAINT,
474  typename GRAPH_CHANGES_GENERATOR,
475  template < typename >
476  class ALLOC >
477  bool GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
478  GRAPH_CHANGES_GENERATOR,
479  ALLOC >::empty() {
480  // put into the illegal change set all the top elements of the different
481  // queues that are not valid anymore
482  if (!__queues_valid) {
483  for (auto& queue_pair : __change_queue_per_node) {
484  auto& queue = queue_pair.second;
485  while (!queue.empty() && !__isChangeValid(queue.top())) {
486  __invalidateChange(queue.top());
487  }
488  }
489  __queues_valid = true;
490  }
491 
492  return __node_queue.topPriority() == std::numeric_limits< double >::min();
493  }
494 
495 
498  template < typename STRUCTURAL_CONSTRAINT,
499  typename GRAPH_CHANGES_GENERATOR,
500  template < typename >
501  class ALLOC >
502  bool GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
503  GRAPH_CHANGES_GENERATOR,
504  ALLOC >::empty(const NodeId node) {
505  // put into the illegal change set all the top elements of the different
506  // queues that are not valid anymore
507  if (!__queues_valid) {
508  for (auto& queue_pair : __change_queue_per_node) {
509  auto& queue = queue_pair.second;
510  while (!queue.empty() && !__isChangeValid(queue.top())) {
511  __invalidateChange(queue.top());
512  }
513  }
514  __queues_valid = true;
515  }
516 
517  return __change_queue_per_node[node].empty();
518  }
519 
520 
522  template < typename STRUCTURAL_CONSTRAINT,
523  typename GRAPH_CHANGES_GENERATOR,
524  template < typename >
525  class ALLOC >
526  INLINE const GraphChange&
527  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
528  GRAPH_CHANGES_GENERATOR,
529  ALLOC >::bestChange() {
530  if (!empty())
531  return __changes[__change_queue_per_node[__node_queue.top()].top()];
532  else
533  GUM_ERROR(NotFound, "there exists no graph change applicable");
534  }
535 
536 
538  template < typename STRUCTURAL_CONSTRAINT,
539  typename GRAPH_CHANGES_GENERATOR,
540  template < typename >
541  class ALLOC >
542  INLINE const GraphChange&
543  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
544  GRAPH_CHANGES_GENERATOR,
545  ALLOC >::bestChange(const NodeId node) {
546  if (!empty(node))
547  return __changes[__change_queue_per_node[node].top()];
548  else
549  GUM_ERROR(NotFound, "there exists no graph change applicable");
550  }
551 
552 
554  template < typename STRUCTURAL_CONSTRAINT,
555  typename GRAPH_CHANGES_GENERATOR,
556  template < typename >
557  class ALLOC >
558  INLINE double GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
559  GRAPH_CHANGES_GENERATOR,
560  ALLOC >::bestScore() {
561  if (!empty())
562  return __change_queue_per_node[__node_queue.top()].topPriority();
563  else
564  GUM_ERROR(NotFound, "there exists no graph change applicable");
565  }
566 
567 
569  template < typename STRUCTURAL_CONSTRAINT,
570  typename GRAPH_CHANGES_GENERATOR,
571  template < typename >
572  class ALLOC >
573  INLINE double
574  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
575  GRAPH_CHANGES_GENERATOR,
576  ALLOC >::bestScore(const NodeId node) {
577  if (!empty(node))
578  return __change_queue_per_node[node].topPriority();
579  else
580  GUM_ERROR(NotFound, "there exists no graph change applicable");
581  }
582 
583 
585  template < typename STRUCTURAL_CONSTRAINT,
586  typename GRAPH_CHANGES_GENERATOR,
587  template < typename >
588  class ALLOC >
590  STRUCTURAL_CONSTRAINT,
591  GRAPH_CHANGES_GENERATOR,
592  ALLOC >::__illegal2LegalChanges(Set< std::size_t >& changes_to_recompute) {
593  for (auto iter = __illegal_changes.beginSafe();
594  iter != __illegal_changes.endSafe();
595  ++iter) {
596  if (__isChangeValid(*iter)) {
597  const GraphChange& change = __changes[*iter];
598  if (change.type() == GraphChangeType::ARC_REVERSAL) {
599  __change_queue_per_node[change.node1()].insert(
600  *iter, std::numeric_limits< double >::min());
601  }
602  __change_queue_per_node[change.node2()].insert(
603  *iter, std::numeric_limits< double >::min());
604 
605  changes_to_recompute.insert(*iter);
606  __illegal_changes.erase(iter);
607  }
608  }
609  }
610 
611 
613  template < typename STRUCTURAL_CONSTRAINT,
614  typename GRAPH_CHANGES_GENERATOR,
615  template < typename >
616  class ALLOC >
617  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
618  GRAPH_CHANGES_GENERATOR,
619  ALLOC >::
621  const NodeId target_node) {
622  const HashTable< std::size_t, Size >& changes =
623  __change_queue_per_node[target_node].allValues();
624  for (auto iter = changes.cbeginSafe(); iter != changes.cendSafe(); ++iter) {
625  if (!changes_to_recompute.exists(iter.key())) {
626  if (__isChangeValid(iter.key())) {
627  changes_to_recompute.insert(iter.key());
628  } else {
629  __invalidateChange(iter.key());
630  }
631  }
632  }
633  }
634 
635 
637  template < typename STRUCTURAL_CONSTRAINT,
638  typename GRAPH_CHANGES_GENERATOR,
639  template < typename >
640  class ALLOC >
642  STRUCTURAL_CONSTRAINT,
643  GRAPH_CHANGES_GENERATOR,
644  ALLOC >::__updateScores(const Set< std::size_t >& changes_to_recompute) {
645  Set< NodeId > modified_nodes(changes_to_recompute.size());
646 
647  for (const auto change_index : changes_to_recompute) {
648  const GraphChange& change = __changes[change_index];
649 
650  switch (change.type()) {
652  // add the arc
653  auto& parents = __parents[change.node2()];
654  parents.push_back(change.node1());
655  const double delta = __score->score(change.node2(), parents)
656  - __node_current_scores[change.node2()];
657  parents.pop_back();
658 
659  // update the score
660  __change_scores[change_index].second = delta;
661 
662  // update the head queue
663  __change_queue_per_node[change.node2()].setPriority(change_index,
664  delta);
665  // indicate which queue was modified
666  modified_nodes.insert(change.node2());
667  } break;
668 
670  // remove the arc
671  auto& parents = __parents[change.node2()];
672  for (auto& par : parents) {
673  if (par == change.node1()) {
674  par = *(parents.rbegin());
675  parents.pop_back();
676  break;
677  }
678  }
679  const double delta = __score->score(change.node2(), parents)
680  - __node_current_scores[change.node2()];
681  parents.push_back(change.node1());
682 
683  // update the score
684  __change_scores[change_index].second = delta;
685 
686  // update the head queue
687  __change_queue_per_node[change.node2()].setPriority(change_index,
688  delta);
689  // indicate which queue was modified
690  modified_nodes.insert(change.node2());
691  } break;
692 
694  // remove arc ( node1 -> node2 )
695  auto& parents2 = __parents[change.node2()];
696  for (auto& par : parents2) {
697  if (par == change.node1()) {
698  par = *(parents2.rbegin());
699  parents2.pop_back();
700  break;
701  }
702  }
703 
704  const double delta2 = __score->score(change.node2(), parents2)
705  - __node_current_scores[change.node2()];
706  parents2.push_back(change.node1());
707 
708  // add arc ( node2 -> node1 )
709  auto& parents1 = __parents[change.node1()];
710  parents1.push_back(change.node2());
711  const double delta1 = __score->score(change.node1(), parents1)
712  - __node_current_scores[change.node1()];
713  parents1.pop_back();
714 
715  // update the scores
716  __change_scores[change_index].first = delta1;
717  __change_scores[change_index].second = delta2;
718 
719  // update the queues
720  const double delta = delta1 + delta2;
721  __change_queue_per_node[change.node1()].setPriority(change_index,
722  delta);
723  __change_queue_per_node[change.node2()].setPriority(change_index,
724  delta);
725 
726  // indicate which queues were modified
727  modified_nodes.insert(change.node1());
728  modified_nodes.insert(change.node2());
729  } break;
730 
731  default: {
732  GUM_ERROR(NotImplementedYet,
733  "Method __updateScores of GraphChangesSelector4DiGraph "
734  << "does not handle yet graph change of type "
735  << change.type());
736  }
737  }
738  }
739 
740  // update the node queue
741  for (const auto node : modified_nodes) {
743  __change_queue_per_node[node].empty()
744  ? std::numeric_limits< double >::min()
745  : __change_queue_per_node[node].topPriority());
746  }
747  }
748 
749 
751  template < typename STRUCTURAL_CONSTRAINT,
752  typename GRAPH_CHANGES_GENERATOR,
753  template < typename >
754  class ALLOC >
755  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
756  GRAPH_CHANGES_GENERATOR,
757  ALLOC >::__getNewChanges() {
758  // ask the graph change generator for all its available changes
759  for (const auto& change : *__changes_generator) {
760  // check that the change does not already exist
761  if (!__changes.exists(change)) {
762  // add the new change. To make the addition simple, we put the new
763  // change into the illegal changes set. Afterwards, the applyChange
764  // function will put the legal changes again into the queues
766  __changes << change;
767  __change_scores.push_back(
768  std::pair< double, double >(std::numeric_limits< double >::min(),
769  std::numeric_limits< double >::min()));
770  }
771  }
772 
773  // indicate to the generator that we have finished retrieving its changes
774  __changes_generator->notifyGetCompleted();
775  }
776 
777 
779  template < typename STRUCTURAL_CONSTRAINT,
780  typename GRAPH_CHANGES_GENERATOR,
781  template < typename >
782  class ALLOC >
783  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
784  GRAPH_CHANGES_GENERATOR,
785  ALLOC >::applyChange(const GraphChange&
786  change) {
787  // first, we get the index of the change
788  const std::size_t change_index = __changes.pos(change);
789 
790  // perform the change
791  Set< std::size_t > changes_to_recompute;
792  switch (change.type()) {
794  // update the current score
795  __node_current_scores[change.node2()] +=
796  __change_scores[change_index].second;
797  __parents[change.node2()].push_back(change.node1());
798 
799  // inform the constraint that the graph has been modified
800  __constraint->modifyGraph(static_cast< const ArcAddition& >(change));
801  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
802  &(__changes_generator->constraint()))
803  != __constraint) {
804  __changes_generator->constraint().modifyGraph(
805  static_cast< const ArcAddition& >(change));
806  }
807 
808  // get new possible changes from the graph change generator
809  // warning: put the next 3 lines before calling __illegal2LegalChanges
810  __changes_generator->modifyGraph(
811  static_cast< const ArcAddition& >(change));
812  __getNewChanges();
813 
814  // check whether some illegal changes can be put into the valid queues
815  __illegal2LegalChanges(changes_to_recompute);
816  __invalidateChange(change_index);
817  __findLegalChangesNeedingUpdate(changes_to_recompute, change.node2());
818  __updateScores(changes_to_recompute);
819  } break;
820 
822  // update the current score
823  __node_current_scores[change.node2()] +=
824  __change_scores[change_index].second;
825  auto& parents = __parents[change.node2()];
826  for (auto& par : parents) {
827  if (par == change.node1()) {
828  par = *(parents.rbegin());
829  parents.pop_back();
830  break;
831  }
832  }
833 
834  // inform the constraint that the graph has been modified
835  __constraint->modifyGraph(static_cast< const ArcDeletion& >(change));
836  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
837  &(__changes_generator->constraint()))
838  != __constraint) {
839  __changes_generator->constraint().modifyGraph(
840  static_cast< const ArcDeletion& >(change));
841  }
842 
843  // get new possible changes from the graph change generator
844  // warning: put the next 3 lines before calling __illegal2LegalChanges
845  __changes_generator->modifyGraph(
846  static_cast< const ArcDeletion& >(change));
847  __getNewChanges();
848 
849  // check whether some illegal changes can be put into the valid queues
850  __illegal2LegalChanges(changes_to_recompute);
851  __invalidateChange(change_index);
852  __findLegalChangesNeedingUpdate(changes_to_recompute, change.node2());
853  __updateScores(changes_to_recompute);
854  } break;
855 
857  // update the current score
858  __node_current_scores[change.node1()] +=
859  __change_scores[change_index].first;
860  __node_current_scores[change.node2()] +=
861  __change_scores[change_index].second;
862  __parents[change.node1()].push_back(change.node2());
863  auto& parents = __parents[change.node2()];
864  for (auto& par : parents) {
865  if (par == change.node1()) {
866  par = *(parents.rbegin());
867  parents.pop_back();
868  break;
869  }
870  }
871 
872  // inform the constraint that the graph has been modified
873  __constraint->modifyGraph(static_cast< const ArcReversal& >(change));
874  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
875  &(__changes_generator->constraint()))
876  != __constraint) {
877  __changes_generator->constraint().modifyGraph(
878  static_cast< const ArcReversal& >(change));
879  }
880 
881  // get new possible changes from the graph change generator
882  // warning: put the next 3 lines before calling __illegal2LegalChanges
883  __changes_generator->modifyGraph(
884  static_cast< const ArcReversal& >(change));
885  __getNewChanges();
886 
887  // check whether some illegal changes can be put into the valid queues
888  __illegal2LegalChanges(changes_to_recompute);
889  __invalidateChange(change_index);
890  __findLegalChangesNeedingUpdate(changes_to_recompute, change.node1());
891  __findLegalChangesNeedingUpdate(changes_to_recompute, change.node2());
892  __updateScores(changes_to_recompute);
893  } break;
894 
895  default:
896  GUM_ERROR(NotImplementedYet,
897  "Method applyChange of GraphChangesSelector4DiGraph "
898  << "does not handle yet graph change of type "
899  << change.type());
900  }
901 
902  __queues_valid = false;
903  }
904 
905 
907  template < typename STRUCTURAL_CONSTRAINT,
908  typename GRAPH_CHANGES_GENERATOR,
909  template < typename >
910  class ALLOC >
912  STRUCTURAL_CONSTRAINT,
913  GRAPH_CHANGES_GENERATOR,
914  ALLOC >::applyChangeWithoutScoreUpdate(const GraphChange& change) {
915  // first, we get the index of the change
916  const std::size_t change_index = __changes.pos(change);
917 
918  // perform the change
919  switch (change.type()) {
921  // update the current score
922  __node_current_scores[change.node2()] +=
923  __change_scores[change_index].second;
924  __parents[change.node2()].push_back(change.node1());
925 
926  // inform the constraint that the graph has been modified
927  __constraint->modifyGraph(static_cast< const ArcAddition& >(change));
928  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
929  &(__changes_generator->constraint()))
930  != __constraint) {
931  __changes_generator->constraint().modifyGraph(
932  static_cast< const ArcAddition& >(change));
933  }
934 
935  // get new possible changes from the graph change generator
936  // warning: put the next 3 lines before calling __illegal2LegalChanges
937  __changes_generator->modifyGraph(
938  static_cast< const ArcAddition& >(change));
939  __getNewChanges();
940 
941  // indicate that we have just applied the change
942  __invalidateChange(change_index);
943 
944  // indicate that the queue to which the change belongs needs be
945  // updated
946  __queues_to_update.insert(change.node2());
947  } break;
948 
950  // update the current score
951  __node_current_scores[change.node2()] +=
952  __change_scores[change_index].second;
953  auto& parents = __parents[change.node2()];
954  for (auto& par : parents) {
955  if (par == change.node1()) {
956  par = *(parents.rbegin());
957  parents.pop_back();
958  break;
959  }
960  }
961 
962  // inform the constraint that the graph has been modified
963  __constraint->modifyGraph(static_cast< const ArcDeletion& >(change));
964  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
965  &(__changes_generator->constraint()))
966  != __constraint) {
967  __changes_generator->constraint().modifyGraph(
968  static_cast< const ArcDeletion& >(change));
969  }
970 
971  // get new possible changes from the graph change generator
972  // warning: put the next 3 lines before calling __illegal2LegalChanges
973  __changes_generator->modifyGraph(
974  static_cast< const ArcDeletion& >(change));
975  __getNewChanges();
976 
977  // indicate that we have just applied the change
978  __invalidateChange(change_index);
979 
980  // indicate that the queue to which the change belongs needs be
981  // updated
982  __queues_to_update.insert(change.node2());
983  } break;
984 
986  // update the current score
987  __node_current_scores[change.node1()] +=
988  __change_scores[change_index].first;
989  __node_current_scores[change.node2()] +=
990  __change_scores[change_index].second;
991  __parents[change.node1()].push_back(change.node2());
992  auto& parents = __parents[change.node2()];
993  for (auto& par : parents) {
994  if (par == change.node1()) {
995  par = *(parents.rbegin());
996  parents.pop_back();
997  break;
998  }
999  }
1000 
1001  // inform the constraint that the graph has been modified
1002  __constraint->modifyGraph(static_cast< const ArcReversal& >(change));
1003  if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(
1004  &(__changes_generator->constraint()))
1005  != __constraint) {
1006  __changes_generator->constraint().modifyGraph(
1007  static_cast< const ArcReversal& >(change));
1008  }
1009 
1010  // get new possible changes from the graph change generator
1011  // warning: put the next 3 lines before calling __illegal2LegalChanges
1012  __changes_generator->modifyGraph(
1013  static_cast< const ArcReversal& >(change));
1014  __getNewChanges();
1015 
1016  // indicate that we have just applied the change
1017  __invalidateChange(change_index);
1018 
1019  // indicate that the queue to which the change belongs needs be
1020  // updated
1021  __queues_to_update.insert(change.node1());
1022  __queues_to_update.insert(change.node2());
1023  } break;
1024 
1025  default:
1026  GUM_ERROR(NotImplementedYet,
1027  "Method applyChangeWithoutScoreUpdate of "
1028  << "GraphChangesSelector4DiGraph "
1029  << "does not handle yet graph change of type "
1030  << change.type());
1031  }
1032  }
1033 
1034 
1036  template < typename STRUCTURAL_CONSTRAINT,
1037  typename GRAPH_CHANGES_GENERATOR,
1038  template < typename >
1039  class ALLOC >
1040  void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1041  GRAPH_CHANGES_GENERATOR,
1043  // determine which changes in the illegal set are now legal
1044  Set< std::size_t > new_legal_changes;
1045  for (auto iter = __illegal_changes.beginSafe();
1046  iter != __illegal_changes.endSafe();
1047  ++iter) {
1048  if (__isChangeValid(*iter)) {
1049  new_legal_changes.insert(*iter);
1050  __illegal_changes.erase(iter);
1051  }
1052  }
1053 
1054  // update the scores that need be updated
1055  Set< std::size_t > changes_to_recompute;
1056  for (const auto& node : __queues_to_update) {
1057  __findLegalChangesNeedingUpdate(changes_to_recompute, node);
1058  }
1059  __queues_to_update.clear();
1060 
1061  // put the previously illegal changes that are now legal into their queues
1062  for (const auto change_index : new_legal_changes) {
1063  const GraphChange& change = __changes[change_index];
1064  if (change.type() == GraphChangeType::ARC_REVERSAL) {
1065  __change_queue_per_node[change.node1()].insert(
1066  change_index, std::numeric_limits< double >::min());
1067  }
1068  __change_queue_per_node[change.node2()].insert(
1069  change_index, std::numeric_limits< double >::min());
1070 
1071  changes_to_recompute.insert(change_index);
1072  }
1073 
1074  // compute the scores that we need
1075  __updateScores(changes_to_recompute);
1076 
1077  __queues_valid = false;
1078  }
1079 
1080 
1082  template < typename STRUCTURAL_CONSTRAINT,
1083  typename GRAPH_CHANGES_GENERATOR,
1084  template < typename >
1085  class ALLOC >
1086  std::vector< std::pair< NodeId, double > >
1087  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1088  GRAPH_CHANGES_GENERATOR,
1089  ALLOC >::nodesSortedByBestScore() const {
1090  std::vector< std::pair< NodeId, double > > result(__node_queue.size());
1091  for (std::size_t i = std::size_t(0); i < __node_queue.size(); ++i) {
1092  result[i].first = __node_queue[i];
1093  result[i].second = __node_queue.priorityByPos(i);
1094  }
1095 
1096  std::sort(result.begin(),
1097  result.end(),
1098  [](const std::pair< NodeId, double >& a,
1099  const std::pair< NodeId, double >& b) -> bool {
1100  return a.second > b.second;
1101  });
1102 
1103  return result;
1104  }
1105 
1106 
1108  template < typename STRUCTURAL_CONSTRAINT,
1109  typename GRAPH_CHANGES_GENERATOR,
1110  template < typename >
1111  class ALLOC >
1112  std::vector< std::pair< NodeId, double > >
1113  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1114  GRAPH_CHANGES_GENERATOR,
1115  ALLOC >::nodesUnsortedWithScore() const {
1116  std::vector< std::pair< NodeId, double > > result(__node_queue.size());
1117  for (std::size_t i = std::size_t(0); i < __node_queue.size(); ++i) {
1118  result[i].first = __node_queue[i];
1119  result[i].second = __node_queue.priorityByPos(i);
1120  }
1121 
1122  return result;
1123  }
1124 
1125 
1127  template < typename STRUCTURAL_CONSTRAINT,
1128  typename GRAPH_CHANGES_GENERATOR,
1129  template < typename >
1130  class ALLOC >
1131  INLINE typename GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1132  GRAPH_CHANGES_GENERATOR,
1133  ALLOC >::GeneratorType&
1134  GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
1135  GRAPH_CHANGES_GENERATOR,
1136  ALLOC >::graphChangeGenerator() const
1137  noexcept {
1138  return *__changes_generator;
1139  }
1140 
1141 
1142  } /* namespace learning */
1143 
1144 } /* namespace gum */
1145 
1146 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
NodeProperty< double > __node_current_scores
the current score of each node
bool empty()
indicates whether the selector still contains graph changes
GeneratorType & graphChangeGenerator() const noexcept
returns the generator used by the selector
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
void __updateScores(const Set< std::size_t > &changes_to_recompute)
perform the necessary updates of the scores
STRUCTURAL_CONSTRAINT * __constraint
the set of constraints used to determine valid changes
const GraphChange & bestChange()
returns the best graph change to examine
STL namespace.
const Val & top() const
returns the element at the top of the priority queue
void applyChange(const GraphChange &change)
indicate to the selector that a change has been applied
Set< std::size_t > __illegal_changes
the set of changes known to be currently illegal (due to the constraints)
void erase(const Key &k)
Erases an element from the set.
Definition: set_tpl.h:656
const Priority & priorityByPos(Size index) const
Returns the priority of the value passed in argument.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
PriorityQueue< NodeId, double, std::greater< double > > __node_queue
a global priority queue indicating for each node its best score
std::vector< std::pair< double, double > > __change_scores
the scores for the head and tail of all the changes
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
Definition: set_tpl.h:502
void setPriority(const Val &elt, const Priority &new_priority)
Modifies the priority of each instance of a given element.
void applyChangeWithoutScoreUpdate(const GraphChange &change)
indicate to the selector that one of several changes has been applied
bool exists(const Key &k) const
Indicates whether a given element belongs to the set.
Definition: set_tpl.h:607
bool __isChangeValid(const std::size_t index) const
indicates whether a given change is valid or not
void resize(Size new_capacity)
Changes the size of the underlying hash table containing the set.
Definition: set_tpl.h:552
NodeProperty< std::vector< NodeId, ALLOC< NodeId > > > __parents
the set of parents of each node (speeds-up score computations)
Size insert(const Val &val, const Priority &priority)
Inserts a new (a copy) element in the priority queue.
Set< NodeId > __queues_to_update
the set of queues to update when applying several changes
double bestScore()
return the score of the best graph change
bool __queues_valid
indicates whether we need to recompute whether the queue is empty or not
GraphChangesSelector4DiGraph(Score< ALLOC > &score, STRUCTURAL_CONSTRAINT &constraint, GRAPH_CHANGES_GENERATOR &changes_generator)
default constructor
void updateScoresAfterAppliedChanges()
recompute all the scores after the application of several changes
std::vector< std::pair< NodeId, double > > nodesUnsortedWithScore() const
returns the top priority of each queue (unsorted)
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
Definition: set_tpl.h:488
Sequence< GraphChange > __changes
a sequence containing all the possible changes
std::vector< std::pair< NodeId, double > > nodesSortedByBestScore() const
returns the set of queues sorted by decreasing top priority
void clear()
Removes all the elements from the queue.
void __invalidateChange(const std::size_t change_index)
put a change into the illegal set
void __findLegalChangesNeedingUpdate(Set< std::size_t > &changes_to_recompute, const NodeId target_node)
finds the changes that are affected by a given node modification
void __illegal2LegalChanges(Set< std::size_t > &changes_to_recompute)
remove the now legal changes from the illegal set
void setGraph(DiGraph &graph)
sets the graph from which scores are computed
GRAPH_CHANGES_GENERATOR GeneratorType
the type of the generator
void clear()
Removes all the elements, if any, from the set.
Definition: set_tpl.h:375
GRAPH_CHANGES_GENERATOR * __changes_generator
the generator that returns the set of possible changes
void __getNewChanges()
get from the graph change generator a new set of changes
Size size() const noexcept
Returns the number of elements in the set.
Definition: set_tpl.h:701
bool isChangeValid(const GraphChange &change) const
indicates whether a given change is valid or not
const Priority & topPriority() const
Returns the priority of the top element.
Size NodeId
Type for node ids.
Definition: graphElements.h:98
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:613
NodeProperty< PriorityQueue< std::size_t, double, std::greater< double > > > __change_queue_per_node
for each node, a priority queue sorting GraphChanges by decreasing score
Size size() const noexcept
Returns the number of elements in the priority queue.
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55