aGrUM  0.16.0
gum::BayesBall Class Reference

Implementation of Shachter's Bayes-Ball algorithm. See the detailed description below.

#include <agrum/BN/inference/BayesBall.h>

Static Public Member Functions

Accessors / Modifiers
static void requisiteNodes (const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
 Fill the `requisite` node set with the requisite nodes of `dag` given a query and evidence. See the details below.
 
template<typename GUM_SCALAR , template< typename > class TABLE>
static void relevantPotentials (const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
 Update a set of potentials, keeping only those d-connected with the query variables given the evidence. See the details below.
 

Detailed Description

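Implementation of Shachter's Bayes-Ball algorithm (R. D. Shachter, "Bayes-Ball: The Rational Pastime", UAI 1998). It is used to determine which nodes of a DAG, and which potentials of a Bayesian network, are d-connected to a set of query nodes given hard and soft evidence.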

Definition at line 51 of file BayesBall.h.

Constructor & Destructor Documentation

◆ BayesBall()

INLINE gum::BayesBall::BayesBall ( )
private

Default constructor.

Definition at line 34 of file BayesBall_inl.h.

34 { GUM_CONSTRUCTOR(BayesBall); }
BayesBall()
Default constructor.
Definition: BayesBall_inl.h:34

◆ ~BayesBall()

INLINE gum::BayesBall::~BayesBall ( )
private

Destructor.

Definition at line 37 of file BayesBall_inl.h.

37 { GUM_DESTRUCTOR(BayesBall); }
BayesBall()
Default constructor.
Definition: BayesBall_inl.h:34

Member Function Documentation

◆ relevantPotentials()

template<typename GUM_SCALAR , template< typename > class TABLE>
void gum::BayesBall::relevantPotentials ( const IBayesNet< GUM_SCALAR > &  bn,
const NodeSet &  query,
const NodeSet &  hardEvidence,
const NodeSet &  softEvidence,
Set< const TABLE< GUM_SCALAR > * > &  potentials 
)
static

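Update a set of potentials, keeping only those d-connected with the query variables given the evidence.

The ball is first sent to every query node as if it came from one of its children. It then bounces through the DAG according to the hard and soft evidence.

A potential is kept as soon as the ball reaches any one of its variables. The traversal stops early once every potential has been reached. When it ends, every potential none of whose variables was reached is erased from `potentials`.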

Definition at line 35 of file BayesBall_tpl.h.

References gum::ArcGraphPart::children(), gum::DAGmodel::dag(), gum::List< Val, Alloc >::empty(), gum::HashTable< Key, Val, Alloc >::empty(), gum::HashTable< Key, Val, Alloc >::erase(), gum::Set< Key, Alloc >::exists(), gum::HashTable< Key, Val, Alloc >::exists(), gum::List< Val, Alloc >::front(), gum::List< Val, Alloc >::insert(), gum::HashTable< Key, Val, Alloc >::insert(), gum::IBayesNet< GUM_SCALAR >::nodeId(), gum::ArcGraphPart::parents(), gum::List< Val, Alloc >::popFront(), gum::HashTable< Key, Val, Alloc >::resize(), and gum::NodeGraphPart::size().

39  {
40  const DAG& dag = bn.dag();
41 
42  // create the marks (top = first and bottom = second)
43  NodeProperty< std::pair< bool, bool > > marks;
44  marks.resize(dag.size());
45  const std::pair< bool, bool > empty_mark(false, false);
46 
49  HashTable< NodeId, Set< const TABLE< GUM_SCALAR >* > > node2potentials;
50  for (const auto pot : potentials) {
51  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
52  for (const auto var : vars) {
53  const NodeId id = bn.nodeId(*var);
54  if (!node2potentials.exists(id)) {
55  node2potentials.insert(id, Set< const TABLE< GUM_SCALAR >* >());
56  }
57  node2potentials[id].insert(pot);
58  }
59  }
60 
61  // indicate that we will send the ball to all the query nodes (as children):
62  // in list nodes_to_visit, the first element is the next node to send the
63  // ball to and the Boolean indicates whether we shall reach it from one of
64  // its children (true) or from one parent (false)
65  List< std::pair< NodeId, bool > > nodes_to_visit;
66  for (const auto node : query) {
67  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
68  }
69 
70  // perform the bouncing ball until __node2potentials becomes empty (which
71  // means that we have reached all the potentials and, therefore, those
72  // are d-connected to query) or until there is no node in the graph to send
73  // the ball to
74  while (!nodes_to_visit.empty() && !node2potentials.empty()) {
75  // get the next node to visit
76  NodeId node = nodes_to_visit.front().first;
77 
78  // if the marks of the node do not exist, create them
79  if (!marks.exists(node)) marks.insert(node, empty_mark);
80 
81  // if the node belongs to the query, update __node2potentials: remove all
82  // the potentials containing the node
83  if (node2potentials.exists(node)) {
84  auto& pot_set = node2potentials[node];
85  for (const auto pot : pot_set) {
86  const auto& vars = pot->variablesSequence();
87  for (const auto var : vars) {
88  const NodeId id = bn.nodeId(*var);
89  if (id != node) {
90  node2potentials[id].erase(pot);
91  if (node2potentials[id].empty()) { node2potentials.erase(id); }
92  }
93  }
94  }
95  node2potentials.erase(node);
96 
97  // if __node2potentials is empty, no need to go on: all the potentials
98  // are d-connected to the query
99  if (node2potentials.empty()) return;
100  }
101 
102 
103  // bounce the ball toward the neighbors
104  if (nodes_to_visit.front().second) { // visit from a child
105  nodes_to_visit.popFront();
106 
107  if (hardEvidence.exists(node)) { continue; }
108 
109  if (!marks[node].first) {
110  marks[node].first = true; // top marked
111  for (const auto par : dag.parents(node)) {
112  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
113  }
114  }
115 
116  if (!marks[node].second) {
117  marks[node].second = true; // bottom marked
118  for (const auto chi : dag.children(node)) {
119  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
120  }
121  }
122  } else { // visit from a parent
123  nodes_to_visit.popFront();
124 
125  const bool is_hard_evidence = hardEvidence.exists(node);
126  const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
127 
128  if (is_evidence && !marks[node].first) {
129  marks[node].first = true;
130 
131  for (const auto par : dag.parents(node)) {
132  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
133  }
134  }
135 
136  if (!is_hard_evidence && !marks[node].second) {
137  marks[node].second = true;
138 
139  for (const auto chi : dag.children(node)) {
140  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
141  }
142  }
143  }
144  }
145 
146 
147  // here, all the potentials that belong to __node2potentials are d-separated
148  // from the query
149  for (const auto elt : node2potentials) {
150  for (const auto pot : elt.second) {
151  potentials.erase(pot);
152  }
153  }
154  }
Size NodeId
Type for node ids.
Definition: graphElements.h:98
+ Here is the call graph for this function:

◆ requisiteNodes()

void gum::BayesBall::requisiteNodes ( const DAG &  dag,
const NodeSet &  query,
const NodeSet &  hardEvidence,
const NodeSet &  softEvidence,
NodeSet &  requisite 
)
static

Fill the `requisite` node set with the requisite nodes of `dag` given a query and evidence.

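Requisite nodes are those that are d-connected to at least one of the query nodes given a set of hard and soft evidence. Any previous content of `requisite` is cleared first. A node is then added to it in two cases: when the ball reaches it from one of its children, or when it is an evidence node (hard or soft) that the ball reaches from one of its parents.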

Definition at line 36 of file BayesBall.cpp.

References gum::ArcGraphPart::children(), gum::Set< Key, Alloc >::clear(), gum::List< Val, Alloc >::empty(), gum::Set< Key, Alloc >::exists(), gum::List< Val, Alloc >::front(), gum::Set< Key, Alloc >::insert(), gum::List< Val, Alloc >::insert(), gum::ArcGraphPart::parents(), gum::List< Val, Alloc >::popFront(), and gum::NodeGraphPart::size().

40  {
41  // for the moment, no node is requisite
42  requisite.clear();
43 
44  // create the marks (top = first and bottom = second )
45  NodeProperty< std::pair< bool, bool > > marks(dag.size());
46  const std::pair< bool, bool > empty_mark(false, false);
47 
48  // indicate that we will send the ball to all the query nodes (as children):
49  // in list nodes_to_visit, the first element is the next node to send the
50  // ball to and the Boolean indicates whether we shall reach it from one of
51  // its children (true) or from one parent (false)
52  List< std::pair< NodeId, bool > > nodes_to_visit;
53  for (const auto node : query) {
54  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
55  }
56 
57  // perform the bouncing ball until there is no node in the graph to send
58  // the ball to
59  while (!nodes_to_visit.empty()) {
60  // get the next node to visit
61  NodeId node = nodes_to_visit.front().first;
62 
63  // if the marks of the node do not exist, create them
64  if (!marks.exists(node)) marks.insert(node, empty_mark);
65 
66  // bounce the ball toward the neighbors
67  if (nodes_to_visit.front().second) { // visit from a child
68  nodes_to_visit.popFront();
69  requisite.insert(node);
70 
71  if (hardEvidence.exists(node)) { continue; }
72 
73  if (!marks[node].first) {
74  marks[node].first = true; // top marked
75  for (const auto par : dag.parents(node)) {
76  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
77  }
78  }
79 
80  if (!marks[node].second) {
81  marks[node].second = true; // bottom marked
82  for (const auto chi : dag.children(node)) {
83  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
84  }
85  }
86  } else { // visit from a parent
87  nodes_to_visit.popFront();
88 
89  const bool is_hard_evidence = hardEvidence.exists(node);
90  const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
91 
92  if (is_evidence && !marks[node].first) {
93  marks[node].first = true;
94  requisite.insert(node);
95 
96  for (const auto par : dag.parents(node)) {
97  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
98  }
99  }
100 
101  if (!is_hard_evidence && !marks[node].second) {
102  marks[node].second = true;
103 
104  for (const auto chi : dag.children(node)) {
105  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
106  }
107  }
108  }
109  }
110  }
Size NodeId
Type for node ids.
Definition: graphElements.h:98
+ Here is the call graph for this function:

The documentation for this class was generated from the following files: