aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
BayesBall_tpl.h
Go to the documentation of this file.
/**
 *
 * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 * info_at_agrum_dot_org
 *
 * This library is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Implementation of the template member functions of the BayesBall
 * class.
 */

27 namespace gum {
28 
29 
30  // update a set of potentials, keeping only those d-connected with
31  // query variables
32  template < typename GUM_SCALAR, template < typename > class TABLE >
33  void
35  const NodeSet& query,
36  const NodeSet& hardEvidence,
37  const NodeSet& softEvidence,
38  Set< const TABLE< GUM_SCALAR >* >& potentials) {
39  const DAG& dag = bn.dag();
40 
41  // create the marks (top = first and bottom = second)
42  NodeProperty< std::pair< bool, bool > > marks;
43  marks.resize(dag.size());
44  const std::pair< bool, bool > empty_mark(false, false);
45 
46  /// for relevant potentials: indicate which tables contain a variable
47  /// (nodeId)
49  for (const auto pot: potentials) {
50  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
51  for (const auto var: vars) {
52  const NodeId id = bn.nodeId(*var);
53  if (!node2potentials.exists(id)) {
54  node2potentials.insert(id, Set< const TABLE< GUM_SCALAR >* >());
55  }
57  }
58  }
59 
60  // indicate that we will send the ball to all the query nodes (as children):
61  // in list nodes_to_visit, the first element is the next node to send the
62  // ball to and the Boolean indicates whether we shall reach it from one of
63  // its children (true) or from one parent (false)
64  List< std::pair< NodeId, bool > > nodes_to_visit;
65  for (const auto node: query) {
66  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
67  }
68 
69  // perform the bouncing ball until node2potentials__ becomes empty (which
70  // means that we have reached all the potentials and, therefore, those
71  // are d-connected to query) or until there is no node in the graph to send
72  // the ball to
73  while (!nodes_to_visit.empty() && !node2potentials.empty()) {
74  // get the next node to visit
76 
77  // if the marks of the node do not exist, create them
79 
80  // if the node belongs to the query, update node2potentials__: remove all
81  // the potentials containing the node
83  auto& pot_set = node2potentials[node];
84  for (const auto pot: pot_set) {
85  const auto& vars = pot->variablesSequence();
86  for (const auto var: vars) {
87  const NodeId id = bn.nodeId(*var);
88  if (id != node) {
91  }
92  }
93  }
95 
96  // if node2potentials__ is empty, no need to go on: all the potentials
97  // are d-connected to the query
98  if (node2potentials.empty()) return;
99  }
100 
101 
102  // bounce the ball toward the neighbors
103  if (nodes_to_visit.front().second) { // visit from a child
105 
106  if (hardEvidence.exists(node)) { continue; }
107 
108  if (!marks[node].first) {
109  marks[node].first = true; // top marked
110  for (const auto par: dag.parents(node)) {
111  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
112  }
113  }
114 
115  if (!marks[node].second) {
116  marks[node].second = true; // bottom marked
117  for (const auto chi: dag.children(node)) {
118  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
119  }
120  }
121  } else { // visit from a parent
123 
126 
127  if (is_evidence && !marks[node].first) {
128  marks[node].first = true;
129 
130  for (const auto par: dag.parents(node)) {
131  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
132  }
133  }
134 
135  if (!is_hard_evidence && !marks[node].second) {
136  marks[node].second = true;
137 
138  for (const auto chi: dag.children(node)) {
139  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
140  }
141  }
142  }
143  }
144 
145 
146  // here, all the potentials that belong to node2potentials__ are d-separated
147  // from the query
148  for (const auto elt: node2potentials) {
149  for (const auto pot: elt.second) {
151  }
152  }
153  }
154 
155 
156 } /* namespace gum */
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669