aGrUM  0.20.2
a C++ library for (probabilistic) graphical models

Implementation of Shachter's Bayes-Ball algorithm. More...

#include <agrum/BN/inference/BayesBall.h>

Static Public Member Functions

Accessors / Modifiers
static void requisiteNodes (const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
 Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence. More...
 
template<typename GUM_SCALAR , template< typename > class TABLE>
static void relevantPotentials (const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
 Update a set of potentials, keeping only those that are d-connected to the query variables given the evidence. More...
 

Detailed Description

Implementation of Shachter's Bayes-Ball algorithm.

Definition at line 51 of file BayesBall.h.

Constructor & Destructor Documentation

◆ BayesBall()

INLINE gum::BayesBall::BayesBall ( )
private

Default constructor.

Definition at line 33 of file BayesBall_inl.h.

References gum::Set< Key, Alloc >::emplace().

33 { GUM_CONSTRUCTOR(BayesBall); }
BayesBall()
Default constructor.
Definition: BayesBall_inl.h:33
+ Here is the call graph for this function:

◆ ~BayesBall()

INLINE gum::BayesBall::~BayesBall ( )
private

Destructor.

Definition at line 36 of file BayesBall_inl.h.

References gum::Set< Key, Alloc >::emplace().

36 { GUM_DESTRUCTOR(BayesBall); }
BayesBall()
Default constructor.
Definition: BayesBall_inl.h:33
+ Here is the call graph for this function:

Member Function Documentation

◆ relevantPotentials()

template<typename GUM_SCALAR , template< typename > class TABLE>
void gum::BayesBall::relevantPotentials ( const IBayesNet< GUM_SCALAR > &  bn,
const NodeSet &  query,
const NodeSet &  hardEvidence,
const NodeSet &  softEvidence,
Set< const TABLE< GUM_SCALAR > * > &  potentials 
)
static

Update a set of potentials, keeping only those that are d-connected to the query variables given the evidence.

Definition at line 34 of file BayesBall_tpl.h.

References gum::Set< Key, Alloc >::emplace().

38  {
39  const DAG& dag = bn.dag();
40 
41  // create the marks (top = first and bottom = second)
42  NodeProperty< std::pair< bool, bool > > marks;
43  marks.resize(dag.size());
44  const std::pair< bool, bool > empty_mark(false, false);
45 
48  HashTable< NodeId, Set< const TABLE< GUM_SCALAR >* > > node2potentials;
49  for (const auto pot: potentials) {
50  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
51  for (const auto var: vars) {
52  const NodeId id = bn.nodeId(*var);
53  if (!node2potentials.exists(id)) {
54  node2potentials.insert(id, Set< const TABLE< GUM_SCALAR >* >());
55  }
56  node2potentials[id].insert(pot);
57  }
58  }
59 
60  // indicate that we will send the ball to all the query nodes (as children):
61  // in list nodes_to_visit, the first element is the next node to send the
62  // ball to and the Boolean indicates whether we shall reach it from one of
63  // its children (true) or from one parent (false)
64  List< std::pair< NodeId, bool > > nodes_to_visit;
65  for (const auto node: query) {
66  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
67  }
68 
69  // perform the bouncing ball until node2potentials__ becomes empty (which
70  // means that we have reached all the potentials and, therefore, those
71  // are d-connected to query) or until there is no node in the graph to send
72  // the ball to
73  while (!nodes_to_visit.empty() && !node2potentials.empty()) {
74  // get the next node to visit
75  NodeId node = nodes_to_visit.front().first;
76 
77  // if the marks of the node do not exist, create them
78  if (!marks.exists(node)) marks.insert(node, empty_mark);
79 
80  // if the node belongs to the query, update node2potentials__: remove all
81  // the potentials containing the node
82  if (node2potentials.exists(node)) {
83  auto& pot_set = node2potentials[node];
84  for (const auto pot: pot_set) {
85  const auto& vars = pot->variablesSequence();
86  for (const auto var: vars) {
87  const NodeId id = bn.nodeId(*var);
88  if (id != node) {
89  node2potentials[id].erase(pot);
90  if (node2potentials[id].empty()) { node2potentials.erase(id); }
91  }
92  }
93  }
94  node2potentials.erase(node);
95 
96  // if node2potentials__ is empty, no need to go on: all the potentials
97  // are d-connected to the query
98  if (node2potentials.empty()) return;
99  }
100 
101 
102  // bounce the ball toward the neighbors
103  if (nodes_to_visit.front().second) { // visit from a child
104  nodes_to_visit.popFront();
105 
106  if (hardEvidence.exists(node)) { continue; }
107 
108  if (!marks[node].first) {
109  marks[node].first = true; // top marked
110  for (const auto par: dag.parents(node)) {
111  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
112  }
113  }
114 
115  if (!marks[node].second) {
116  marks[node].second = true; // bottom marked
117  for (const auto chi: dag.children(node)) {
118  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
119  }
120  }
121  } else { // visit from a parent
122  nodes_to_visit.popFront();
123 
124  const bool is_hard_evidence = hardEvidence.exists(node);
125  const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
126 
127  if (is_evidence && !marks[node].first) {
128  marks[node].first = true;
129 
130  for (const auto par: dag.parents(node)) {
131  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
132  }
133  }
134 
135  if (!is_hard_evidence && !marks[node].second) {
136  marks[node].second = true;
137 
138  for (const auto chi: dag.children(node)) {
139  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
140  }
141  }
142  }
143  }
144 
145 
146  // here, all the potentials that belong to node2potentials__ are d-separated
147  // from the query
148  for (const auto elt: node2potentials) {
149  for (const auto pot: elt.second) {
150  potentials.erase(pot);
151  }
152  }
153  }
Size NodeId
Type for node ids.
Definition: graphElements.h:97
+ Here is the call graph for this function:

◆ requisiteNodes()

void gum::BayesBall::requisiteNodes ( const DAG &  dag,
const NodeSet &  query,
const NodeSet &  hardEvidence,
const NodeSet &  softEvidence,
NodeSet &  requisite 
)
static

Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.

Requisite nodes are those that are d-connected to at least one of the query nodes given a set of hard and soft evidence.

Definition at line 35 of file BayesBall.cpp.

References gum::Set< Key, Alloc >::emplace().

39  {
40  // for the moment, no node is requisite
41  requisite.clear();
42 
43  // create the marks (top = first and bottom = second )
44  NodeProperty< std::pair< bool, bool > > marks(dag.size());
45  const std::pair< bool, bool > empty_mark(false, false);
46 
47  // indicate that we will send the ball to all the query nodes (as children):
48  // in list nodes_to_visit, the first element is the next node to send the
49  // ball to and the Boolean indicates whether we shall reach it from one of
50  // its children (true) or from one parent (false)
51  List< std::pair< NodeId, bool > > nodes_to_visit;
52  for (const auto node: query) {
53  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
54  }
55 
56  // perform the bouncing ball until there is no node in the graph to send
57  // the ball to
58  while (!nodes_to_visit.empty()) {
59  // get the next node to visit
60  NodeId node = nodes_to_visit.front().first;
61 
62  // if the marks of the node do not exist, create them
63  if (!marks.exists(node)) marks.insert(node, empty_mark);
64 
65  // bounce the ball toward the neighbors
66  if (nodes_to_visit.front().second) { // visit from a child
67  nodes_to_visit.popFront();
68  requisite.insert(node);
69 
70  if (hardEvidence.exists(node)) { continue; }
71 
72  if (!marks[node].first) {
73  marks[node].first = true; // top marked
74  for (const auto par: dag.parents(node)) {
75  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
76  }
77  }
78 
79  if (!marks[node].second) {
80  marks[node].second = true; // bottom marked
81  for (const auto chi: dag.children(node)) {
82  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
83  }
84  }
85  } else { // visit from a parent
86  nodes_to_visit.popFront();
87 
88  const bool is_hard_evidence = hardEvidence.exists(node);
89  const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
90 
91  if (is_evidence && !marks[node].first) {
92  marks[node].first = true;
93  requisite.insert(node);
94 
95  for (const auto par: dag.parents(node)) {
96  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
97  }
98  }
99 
100  if (!is_hard_evidence && !marks[node].second) {
101  marks[node].second = true;
102 
103  for (const auto chi: dag.children(node)) {
104  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
105  }
106  }
107  }
108  }
109  }
Size NodeId
Type for node ids.
Definition: graphElements.h:97
+ Here is the call graph for this function:

The documentation for this class was generated from the following files: