aGrUM  0.20.3
a C++ library for (probabilistic) graphical models

Implementation of Shachter's Bayes-Ball algorithm. More...

#include <agrum/BN/inference/BayesBall.h>

Static Public Member Functions

Accessors / Modifiers
static void requisiteNodes (const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
 Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence. More...
 
template<typename GUM_SCALAR , template< typename > class TABLE>
static void relevantPotentials (const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
 Update a set of potentials, keeping only those d-connected with the query variables given the evidence. More...
 

Detailed Description

Implementation of Shachter's Bayes-Ball algorithm.

Definition at line 51 of file BayesBall.h.

Constructor & Destructor Documentation

◆ BayesBall()

INLINE gum::BayesBall::BayesBall ( )
private

Default constructor.

Definition at line 33 of file BayesBall_inl.h.

References gum::Set< Key, Alloc >::emplace().

33  {
34  GUM_CONSTRUCTOR(BayesBall);
35  ;
36  }
BayesBall()
Default constructor.
Definition: BayesBall_inl.h:33
+ Here is the call graph for this function:

◆ ~BayesBall()

INLINE gum::BayesBall::~BayesBall ( )
private

Destructor.

Definition at line 39 of file BayesBall_inl.h.

References gum::Set< Key, Alloc >::emplace().

39  {
40  GUM_DESTRUCTOR(BayesBall);
41  ;
42  }
BayesBall()
Default constructor.
Definition: BayesBall_inl.h:33
+ Here is the call graph for this function:

Member Function Documentation

◆ relevantPotentials()

template<typename GUM_SCALAR , template< typename > class TABLE>
void gum::BayesBall::relevantPotentials ( const IBayesNet< GUM_SCALAR > &  bn,
const NodeSet &  query,
const NodeSet &  hardEvidence,
const NodeSet &  softEvidence,
Set< const TABLE< GUM_SCALAR > * > &  potentials 
)
static

Update a set of potentials, keeping only those d-connected with the query variables given the evidence.

Definition at line 33 of file BayesBall_tpl.h.

References gum::Set< Key, Alloc >::emplace().

37  {
38  const DAG& dag = bn.dag();
39 
40  // create the marks (top = first and bottom = second)
41  NodeProperty< std::pair< bool, bool > > marks;
42  marks.resize(dag.size());
43  const std::pair< bool, bool > empty_mark(false, false);
44 
47  HashTable< NodeId, Set< const TABLE< GUM_SCALAR >* > > node2potentials;
48  for (const auto pot: potentials) {
49  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
50  for (const auto var: vars) {
51  const NodeId id = bn.nodeId(*var);
52  if (!node2potentials.exists(id)) {
53  node2potentials.insert(id, Set< const TABLE< GUM_SCALAR >* >());
54  }
55  node2potentials[id].insert(pot);
56  }
57  }
58 
59  // indicate that we will send the ball to all the query nodes (as children):
60  // in list nodes_to_visit, the first element is the next node to send the
61  // ball to and the Boolean indicates whether we shall reach it from one of
62  // its children (true) or from one parent (false)
63  List< std::pair< NodeId, bool > > nodes_to_visit;
64  for (const auto node: query) {
65  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
66  }
67 
68  // perform the bouncing ball until _node2potentials_ becomes empty (which
69  // means that we have reached all the potentials and, therefore, those
70  // are d-connected to query) or until there is no node in the graph to send
71  // the ball to
72  while (!nodes_to_visit.empty() && !node2potentials.empty()) {
73  // get the next node to visit
74  NodeId node = nodes_to_visit.front().first;
75 
76  // if the marks of the node do not exist, create them
77  if (!marks.exists(node)) marks.insert(node, empty_mark);
78 
79  // if the node belongs to the query, update _node2potentials_: remove all
80  // the potentials containing the node
81  if (node2potentials.exists(node)) {
82  auto& pot_set = node2potentials[node];
83  for (const auto pot: pot_set) {
84  const auto& vars = pot->variablesSequence();
85  for (const auto var: vars) {
86  const NodeId id = bn.nodeId(*var);
87  if (id != node) {
88  node2potentials[id].erase(pot);
89  if (node2potentials[id].empty()) { node2potentials.erase(id); }
90  }
91  }
92  }
93  node2potentials.erase(node);
94 
95  // if _node2potentials_ is empty, no need to go on: all the potentials
96  // are d-connected to the query
97  if (node2potentials.empty()) return;
98  }
99 
100 
101  // bounce the ball toward the neighbors
102  if (nodes_to_visit.front().second) { // visit from a child
103  nodes_to_visit.popFront();
104 
105  if (hardEvidence.exists(node)) { continue; }
106 
107  if (!marks[node].first) {
108  marks[node].first = true; // top marked
109  for (const auto par: dag.parents(node)) {
110  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
111  }
112  }
113 
114  if (!marks[node].second) {
115  marks[node].second = true; // bottom marked
116  for (const auto chi: dag.children(node)) {
117  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
118  }
119  }
120  } else { // visit from a parent
121  nodes_to_visit.popFront();
122 
123  const bool is_hard_evidence = hardEvidence.exists(node);
124  const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
125 
126  if (is_evidence && !marks[node].first) {
127  marks[node].first = true;
128 
129  for (const auto par: dag.parents(node)) {
130  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
131  }
132  }
133 
134  if (!is_hard_evidence && !marks[node].second) {
135  marks[node].second = true;
136 
137  for (const auto chi: dag.children(node)) {
138  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
139  }
140  }
141  }
142  }
143 
144 
145  // here, all the potentials that belong to _node2potentials_ are d-separated
146  // from the query
147  for (const auto elt: node2potentials) {
148  for (const auto pot: elt.second) {
149  potentials.erase(pot);
150  }
151  }
152  }
Size NodeId
Type for node ids.
Definition: graphElements.h:97
+ Here is the call graph for this function:

◆ requisiteNodes()

void gum::BayesBall::requisiteNodes ( const DAG &  dag,
const NodeSet &  query,
const NodeSet &  hardEvidence,
const NodeSet &  softEvidence,
NodeSet &  requisite 
)
static

Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.

Requisite nodes are those that are d-connected to at least one of the query nodes given a set of hard and soft evidence.

Definition at line 35 of file BayesBall.cpp.

References gum::Set< Key, Alloc >::emplace().

39  {
40  // for the moment, no node is requisite
41  requisite.clear();
42 
43  // create the marks (top = first and bottom = second )
44  NodeProperty< std::pair< bool, bool > > marks(dag.size());
45  const std::pair< bool, bool > empty_mark(false, false);
46 
47  // indicate that we will send the ball to all the query nodes (as children):
48  // in list nodes_to_visit, the first element is the next node to send the
49  // ball to and the Boolean indicates whether we shall reach it from one of
50  // its children (true) or from one parent (false)
51  List< std::pair< NodeId, bool > > nodes_to_visit;
52  for (const auto node: query) {
53  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
54  }
55 
56  // perform the bouncing ball until there is no node in the graph to send
57  // the ball to
58  while (!nodes_to_visit.empty()) {
59  // get the next node to visit
60  NodeId node = nodes_to_visit.front().first;
61 
62  // if the marks of the node do not exist, create them
63  if (!marks.exists(node)) marks.insert(node, empty_mark);
64 
65  // bounce the ball toward the neighbors
66  if (nodes_to_visit.front().second) { // visit from a child
67  nodes_to_visit.popFront();
68  requisite.insert(node);
69 
70  if (hardEvidence.exists(node)) { continue; }
71 
72  if (!marks[node].first) {
73  marks[node].first = true; // top marked
74  for (const auto par: dag.parents(node)) {
75  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
76  }
77  }
78 
79  if (!marks[node].second) {
80  marks[node].second = true; // bottom marked
81  for (const auto chi: dag.children(node)) {
82  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
83  }
84  }
85  } else { // visit from a parent
86  nodes_to_visit.popFront();
87 
88  const bool is_hard_evidence = hardEvidence.exists(node);
89  const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
90 
91  if (is_evidence && !marks[node].first) {
92  marks[node].first = true;
93  requisite.insert(node);
94 
95  for (const auto par: dag.parents(node)) {
96  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
97  }
98  }
99 
100  if (!is_hard_evidence && !marks[node].second) {
101  marks[node].second = true;
102 
103  for (const auto chi: dag.children(node)) {
104  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
105  }
106  }
107  }
108  }
109  }
Size NodeId
Type for node ids.
Definition: graphElements.h:97
+ Here is the call graph for this function:

The documentation for this class was generated from the following files: