aGrUM  0.16.0
dSeparation_tpl.h
Go to the documentation of this file.
1 
31 namespace gum {
32 
33 
34  // update a set of potentials, keeping only those d-connected with
35  // query variables given evidence
//
// Sends a "bouncing ball" through bn.dag(), starting from every query node
// (as if the ball came from one of its children). Every potential that
// contains at least one node reached by the ball is kept; all the other
// potentials are erased from `potentials`, which is filtered in place.
// Potentials with no variables never enter node2potentials and are therefore
// always kept.
//
// Ball-passing rules applied below:
//  - a node that is not hard evidence and is reached from a child forwards
//    the ball to its parents and to its children;
//  - otherwise (reached from a parent, or hard evidence reached from a child)
//    the ball goes to the children only if the node is not hard evidence, and
//    back to the parents only if the node is an ancestor of some evidence
//    (hard and soft evidence nodes count as their own ancestors).
// Each node is expanded at most once per direction.
//
// Parameters:
//  bn           - the Bayes net whose DAG is traversed; its nodeId() maps the
//                 potentials' variables to nodes
//  query        - the nodes the ball starts from
//  hardEvidence - nodes that never pass the ball on to their children; also
//                 used to compute the evidence ancestors
//  softEvidence - only used to compute the evidence ancestors
//  potentials   - in/out: the set of potentials to filter
//
// NOTE(review): this listing lacks source line 37 (the line naming the
// function, relevantPotentials) and source lines 70-72 (which must declare
// node2potentials). From its use below, node2potentials appears to map each
// NodeId to the Set< const TABLE< GUM_SCALAR >* > of potentials containing
// it -- confirm against the original header.
36  template < typename GUM_SCALAR, template < typename > class TABLE >
38  const IBayesNet< GUM_SCALAR >& bn,
39  const NodeSet& query,
40  const NodeSet& hardEvidence,
41  const NodeSet& softEvidence,
42  Set< const TABLE< GUM_SCALAR >* >& potentials) {
43  const DAG& dag = bn.dag();
44 
45  // mark the set of ancestors of the evidence (evidence nodes included)
46  NodeSet ev_ancestors(dag.size());
47  {
48  List< NodeId > anc_to_visit;
49  for (const auto node : hardEvidence)
50  anc_to_visit.insert(node);
51  for (const auto node : softEvidence)
52  anc_to_visit.insert(node);
53  while (!anc_to_visit.empty()) {
54  const NodeId node = anc_to_visit.front();
55  anc_to_visit.popFront();
56 
57  if (!ev_ancestors.exists(node)) {
58  ev_ancestors.insert(node);
59  for (const auto par : dag.parents(node)) {
60  anc_to_visit.insert(par);
61  }
62  }
63  }
64  }
65 
66  // create the marks indicating that we have visited a node in a given direction
67  NodeSet visited_from_child(dag.size());
68  NodeSet visited_from_parent(dag.size());
69 
// NOTE(review): source lines 70-72 (presumably the node2potentials
// declaration) are missing here. The loop below records, for each node, the
// potentials whose variables include it.
73  for (const auto pot : potentials) {
74  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
75  for (const auto var : vars) {
76  const NodeId id = bn.nodeId(*var);
77  if (!node2potentials.exists(id)) {
78  node2potentials.insert(id, Set< const TABLE< GUM_SCALAR >* >());
79  }
80  node2potentials[id].insert(pot);
81  }
82  }
83 
84  // indicate that we will send the ball to all the query nodes (as children):
85  // in list nodes_to_visit, the first element is the next node to send the
86  // ball to and the Boolean indicates whether we shall reach it from one of
87  // its children (true) or from one parent (false)
88  List< std::pair< NodeId, bool > > nodes_to_visit;
89  for (const auto node : query) {
90  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
91  }
92 
93  // perform the bouncing ball until there is no node in the graph to send
94  // the ball to
95  while (!nodes_to_visit.empty() && !node2potentials.empty()) {
96  // get the next node to visit
97  const NodeId node = nodes_to_visit.front().first;
98  const bool direction = nodes_to_visit.front().second;
99  nodes_to_visit.popFront();
100 
101  // check if the node has not already been visited in the same direction
102  bool already_visited;
103  if (direction) {
104  already_visited = visited_from_child.exists(node);
105  if (!already_visited) { visited_from_child.insert(node); }
106  } else {
107  already_visited = visited_from_parent.exists(node);
108  if (!already_visited) { visited_from_parent.insert(node); }
109  }
110 
111  // if some not-yet-reached potentials contain the node, they are reachable
112  // from the query: remove them from node2potentials (for all their nodes)
113  if (node2potentials.exists(node)) {
114  auto& pot_set = node2potentials[node];
115  for (const auto pot : pot_set) {
116  const auto& vars = pot->variablesSequence();
117  for (const auto var : vars) {
118  const NodeId id = bn.nodeId(*var);
119  if (id != node) {
120  node2potentials[id].erase(pot);
121  if (node2potentials[id].empty()) { node2potentials.erase(id); }
122  }
123  }
124  }
125  node2potentials.erase(node);
126 
127  // if node2potentials is empty, no need to go on: all the potentials
128  // are d-connected to the query
129  if (node2potentials.empty()) return;
130  }
131 
132  // if this is the first time we meet the node, then visit it
133  if (!already_visited) {
134  // hard-evidence nodes never forward the ball to their children
135  const bool is_hard_evidence = hardEvidence.exists(node);
136 
137  // bounce the ball toward the neighbors
138  if (direction && !is_hard_evidence) { // reached from a child, not hard evidence
139  // visit the parents
140  for (const auto par : dag.parents(node)) {
141  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
142  }
143 
144  // visit the children
145  for (const auto chi : dag.children(node)) {
146  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
147  }
148  } else { // reached from a parent, or hard evidence reached from a child
149  if (!hardEvidence.exists(node)) {
150  // visit the children
151  for (const auto chi : dag.children(node)) {
152  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
153  }
154  }
155  if (ev_ancestors.exists(node)) {
156  // visit the parents (the node is an ancestor of some evidence)
157  for (const auto par : dag.parents(node)) {
158  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
159  }
160  }
161  }
162  }
163  }
164 
165  // here, all the potentials that belong to node2potentials are d-separated
166  // from the query
167  for (const auto elt : node2potentials) {
168  for (const auto pot : elt.second) {
169  potentials.erase(pot);
170  }
171  }
172  }
173 
174 
175 } /* namespace gum */
bool empty() const noexcept
Returns a boolean indicating whether the chained list is empty.
Definition: list_tpl.h:1970
Size size() const
alias for sizeNodes
The generic class for storing (ordered) sequences of objects.
Definition: sequence.h:1022
void relevantPotentials(const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
update a set of potentials, keeping only those d-connected with query variables given evidence ...
Generic doubly linked lists.
Definition: list.h:372
Class representing the minimal interface for Bayesian Network.
Definition: IBayesNet.h:62
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
void popFront()
Removes the first element of a List, if any.
Definition: list_tpl.h:1964
The class for generic Hash Tables.
Definition: hashTable.h:679
bool exists(const Key &k) const
Indicates whether a given element belongs to the set.
Definition: set_tpl.h:607
const NodeSet & parents(const NodeId id) const
Returns the set of nodes with an arc going into a given node.
Val & insert(const Val &val)
Inserts a new element at the end of the chained list (alias of pushBack).
Definition: list_tpl.h:1619
Val & front() const
Returns a reference to the first element of a list, if any.
Definition: list_tpl.h:1831
const NodeSet & children(const NodeId id) const
Returns the set of nodes reached by an arc going out of a given node.
virtual NodeId nodeId(const DiscreteVariable &var) const =0
Returns the id of the node associated with a given discrete variable.
Base class for dag.
Definition: DAG.h:102
const DAG & dag() const
Returns a constant reference to the dag of this Bayes Net.
Definition: DAGmodel_inl.h:63
Size NodeId
Type for node ids.
Definition: graphElements.h:98