aGrUM  0.14.2
gum::BayesBall Class Reference

Implementation of Shachter's Bayes-Ball algorithm. More...

#include <agrum/BN/inference/BayesBall.h>

Static Public Member Functions

Accessors / Modifiers
static void requisiteNodes (const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
 Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence. More...
 
template<typename GUM_SCALAR , template< typename > class TABLE>
static void relevantPotentials (const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
 Update a set of potentials, keeping only those d-connected with the query variables given the evidence. More...
 

Detailed Description

Implementation of Shachter's Bayes-Ball algorithm.

Definition at line 48 of file BayesBall.h.

Constructor & Destructor Documentation

◆ BayesBall()

INLINE gum::BayesBall::BayesBall ( )
private

Default constructor.

Definition at line 31 of file BayesBall_inl.h.

31 { GUM_CONSTRUCTOR(BayesBall); }
BayesBall()
Default constructor.
Definition: BayesBall_inl.h:31

◆ ~BayesBall()

INLINE gum::BayesBall::~BayesBall ( )
private

Destructor.

Definition at line 34 of file BayesBall_inl.h.

34 { GUM_DESTRUCTOR(BayesBall); }
BayesBall()
Default constructor.
Definition: BayesBall_inl.h:31

Member Function Documentation

◆ relevantPotentials()

template<typename GUM_SCALAR , template< typename > class TABLE>
void gum::BayesBall::relevantPotentials ( const IBayesNet< GUM_SCALAR > &  bn,
const NodeSet &  query,
const NodeSet &  hardEvidence,
const NodeSet &  softEvidence,
Set< const TABLE< GUM_SCALAR > * > &  potentials 
)
static

Update a set of potentials, keeping only those d-connected with the query variables given the evidence.

Definition at line 32 of file BayesBall_tpl.h.

References gum::ArcGraphPart::children(), gum::DAGmodel::dag(), gum::List< Val, Alloc >::empty(), gum::HashTable< Key, Val, Alloc >::empty(), gum::HashTable< Key, Val, Alloc >::erase(), gum::Set< Key, Alloc >::exists(), gum::HashTable< Key, Val, Alloc >::exists(), gum::List< Val, Alloc >::front(), gum::List< Val, Alloc >::insert(), gum::HashTable< Key, Val, Alloc >::insert(), gum::IBayesNet< GUM_SCALAR >::nodeId(), gum::ArcGraphPart::parents(), gum::List< Val, Alloc >::popFront(), gum::HashTable< Key, Val, Alloc >::resize(), and gum::NodeGraphPart::size().

36  {
37  const DAG& dag = bn.dag();
38 
39  // create the marks (top = first and bottom = second)
40  NodeProperty< std::pair< bool, bool > > marks;
41  marks.resize(dag.size());
42  const std::pair< bool, bool > empty_mark(false, false);
43 
46  HashTable< NodeId, Set< const TABLE< GUM_SCALAR >* > > node2potentials;
47  for (const auto pot : potentials) {
48  const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
49  for (const auto var : vars) {
50  const NodeId id = bn.nodeId(*var);
51  if (!node2potentials.exists(id)) {
52  node2potentials.insert(id, Set< const TABLE< GUM_SCALAR >* >());
53  }
54  node2potentials[id].insert(pot);
55  }
56  }
57 
58  // indicate that we will send the ball to all the query nodes (as children):
59  // in list nodes_to_visit, the first element is the next node to send the
60  // ball to and the Boolean indicates whether we shall reach it from one of
61  // its children (true) or from one parent (false)
62  List< std::pair< NodeId, bool > > nodes_to_visit;
63  for (const auto node : query) {
64  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
65  }
66 
67  // perform the bouncing ball until __node2potentials becomes empty (which
68  // means that we have reached all the potentials and, therefore, those
69  // are d-connected to query) or until there is no node in the graph to send
70  // the ball to
71  while (!nodes_to_visit.empty() && !node2potentials.empty()) {
72  // get the next node to visit
73  NodeId node = nodes_to_visit.front().first;
74 
75  // if the marks of the node do not exist, create them
76  if (!marks.exists(node)) marks.insert(node, empty_mark);
77 
78  // if the node belongs to the query, update __node2potentials: remove all
79  // the potentials containing the node
80  if (node2potentials.exists(node)) {
81  auto& pot_set = node2potentials[node];
82  for (const auto pot : pot_set) {
83  const auto& vars = pot->variablesSequence();
84  for (const auto var : vars) {
85  const NodeId id = bn.nodeId(*var);
86  if (id != node) {
87  node2potentials[id].erase(pot);
88  if (node2potentials[id].empty()) { node2potentials.erase(id); }
89  }
90  }
91  }
92  node2potentials.erase(node);
93 
94  // if __node2potentials is empty, no need to go on: all the potentials
95  // are d-connected to the query
96  if (node2potentials.empty()) return;
97  }
98 
99 
100  // bounce the ball toward the neighbors
101  if (nodes_to_visit.front().second) { // visit from a child
102  nodes_to_visit.popFront();
103 
104  if (hardEvidence.exists(node)) { continue; }
105 
106  if (!marks[node].first) {
107  marks[node].first = true; // top marked
108  for (const auto par : dag.parents(node)) {
109  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
110  }
111  }
112 
113  if (!marks[node].second) {
114  marks[node].second = true; // bottom marked
115  for (const auto chi : dag.children(node)) {
116  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
117  }
118  }
119  } else { // visit from a parent
120  nodes_to_visit.popFront();
121 
122  const bool is_hard_evidence = hardEvidence.exists(node);
123  const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
124 
125  if (is_evidence && !marks[node].first) {
126  marks[node].first = true;
127 
128  for (const auto par : dag.parents(node)) {
129  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
130  }
131  }
132 
133  if (!is_hard_evidence && !marks[node].second) {
134  marks[node].second = true;
135 
136  for (const auto chi : dag.children(node)) {
137  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
138  }
139  }
140  }
141  }
142 
143 
144  // here, all the potentials that belong to __node2potentials are d-separated
145  // from the query
146  for (const auto elt : node2potentials) {
147  for (const auto pot : elt.second) {
148  potentials.erase(pot);
149  }
150  }
151  }
Size NodeId
Type for node ids.
Definition: graphElements.h:97
+ Here is the call graph for this function:

◆ requisiteNodes()

void gum::BayesBall::requisiteNodes ( const DAG &  dag,
const NodeSet &  query,
const NodeSet &  hardEvidence,
const NodeSet &  softEvidence,
NodeSet &  requisite 
)
static

Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.

Requisite nodes are those that are d-connected to at least one of the query nodes, given the sets of hard and soft evidence.

Definition at line 33 of file BayesBall.cpp.

References gum::ArcGraphPart::children(), gum::Set< Key, Alloc >::clear(), gum::List< Val, Alloc >::empty(), gum::Set< Key, Alloc >::exists(), gum::List< Val, Alloc >::front(), gum::Set< Key, Alloc >::insert(), gum::List< Val, Alloc >::insert(), gum::ArcGraphPart::parents(), gum::List< Val, Alloc >::popFront(), and gum::NodeGraphPart::size().

37  {
38  // for the moment, no node is requisite
39  requisite.clear();
40 
41  // create the marks (top = first and bottom = second )
42  NodeProperty< std::pair< bool, bool > > marks(dag.size());
43  const std::pair< bool, bool > empty_mark(false, false);
44 
45  // indicate that we will send the ball to all the query nodes (as children):
46  // in list nodes_to_visit, the first element is the next node to send the
47  // ball to and the Boolean indicates whether we shall reach it from one of
48  // its children (true) or from one parent (false)
49  List< std::pair< NodeId, bool > > nodes_to_visit;
50  for (const auto node : query) {
51  nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
52  }
53 
54  // perform the bouncing ball until there is no node in the graph to send
55  // the ball to
56  while (!nodes_to_visit.empty()) {
57  // get the next node to visit
58  NodeId node = nodes_to_visit.front().first;
59 
60  // if the marks of the node do not exist, create them
61  if (!marks.exists(node)) marks.insert(node, empty_mark);
62 
63  // bounce the ball toward the neighbors
64  if (nodes_to_visit.front().second) { // visit from a child
65  nodes_to_visit.popFront();
66  requisite.insert(node);
67 
68  if (hardEvidence.exists(node)) { continue; }
69 
70  if (!marks[node].first) {
71  marks[node].first = true; // top marked
72  for (const auto par : dag.parents(node)) {
73  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
74  }
75  }
76 
77  if (!marks[node].second) {
78  marks[node].second = true; // bottom marked
79  for (const auto chi : dag.children(node)) {
80  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
81  }
82  }
83  } else { // visit from a parent
84  nodes_to_visit.popFront();
85 
86  const bool is_hard_evidence = hardEvidence.exists(node);
87  const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
88 
89  if (is_evidence && !marks[node].first) {
90  marks[node].first = true;
91  requisite.insert(node);
92 
93  for (const auto par : dag.parents(node)) {
94  nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
95  }
96  }
97 
98  if (!is_hard_evidence && !marks[node].second) {
99  marks[node].second = true;
100 
101  for (const auto chi : dag.children(node)) {
102  nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
103  }
104  }
105  }
106  }
107  }
Size NodeId
Type for node ids.
Definition: graphElements.h:97
+ Here is the call graph for this function:

The documentation for this class was generated from the following files: