aGrUM  0.17.2
a C++ library for (probabilistic) graphical models
BayesNetFragment_tpl.h
Go to the documentation of this file.
1 
29 #include <agrum/BN/BayesNet.h>
32 
33 namespace gum {
34  template < typename GUM_SCALAR >
36  const IBayesNet< GUM_SCALAR >& bn) :
37  DiGraphListener(&bn.dag()),
38  __bn(bn) {
39  GUM_CONSTRUCTOR(BayesNetFragment);
40  }
41 
42  template < typename GUM_SCALAR >
44  GUM_DESTRUCTOR(BayesNetFragment);
45 
46  for (auto node: nodes())
47  if (__localCPTs.exists(node)) _uninstallCPT(node);
48  }
49 
50  //============================================================
51  // signals to keep consistency with the referred BayesNet
52  template < typename GUM_SCALAR >
53  INLINE void BayesNetFragment< GUM_SCALAR >::whenNodeAdded(const void* src,
54  NodeId id) {
55  // nothing to do
56  }
57  template < typename GUM_SCALAR >
59  NodeId id) {
60  uninstallNode(id);
61  }
62  template < typename GUM_SCALAR >
63  INLINE void BayesNetFragment< GUM_SCALAR >::whenArcAdded(const void* src,
64  NodeId from,
65  NodeId to) {
66  // nothing to do
67  }
68  template < typename GUM_SCALAR >
70  NodeId from,
71  NodeId to) {
72  if (dag().existsArc(from, to)) _uninstallArc(from, to);
73  }
74 
75  //============================================================
76  // IBayesNet interface : BayesNetFragment here is a decorator for the bn
77 
78  template < typename GUM_SCALAR >
79  INLINE const Potential< GUM_SCALAR >&
81  if (!isInstalledNode(id))
82  GUM_ERROR(NotFound, "NodeId " << id << " is not installed");
83 
84  if (__localCPTs.exists(id))
85  return *__localCPTs[id];
86  else
87  return __bn.cpt(id);
88  }
89 
90  template < typename GUM_SCALAR >
91  INLINE const VariableNodeMap&
94  "Not implemented yet. please use referent bayesnet method");
95  }
96 
97  template < typename GUM_SCALAR >
98  INLINE const DiscreteVariable&
100  if (!isInstalledNode(id))
101  GUM_ERROR(NotFound, "NodeId " << id << " is not installed");
102 
103  return __bn.variable(id);
104  }
105 
106  template < typename GUM_SCALAR >
107  INLINE NodeId
109  NodeId id = __bn.nodeId(var);
110 
111  if (!isInstalledNode(id))
112  GUM_ERROR(NotFound, "variable " << var.name() << " is not installed");
113 
114  return id;
115  }
116 
117  template < typename GUM_SCALAR >
118  INLINE NodeId
119  BayesNetFragment< GUM_SCALAR >::idFromName(const std::string& name) const {
120  NodeId id = __bn.idFromName(name);
121 
122  if (!isInstalledNode(id))
123  GUM_ERROR(NotFound, "variable " << name << " is not installed");
124 
125  return id;
126  }
127 
128  template < typename GUM_SCALAR >
130  const std::string& name) const {
131  NodeId id = idFromName(name);
132 
133  if (!isInstalledNode(id))
134  GUM_ERROR(NotFound, "variable " << name << " is not installed");
135 
136  return __bn.variable(id);
137  }
138 
139  //============================================================
140  // specific API for BayesNetFragment
141  template < typename GUM_SCALAR >
143  return dag().existsNode(id);
144  }
145 
146  template < typename GUM_SCALAR >
148  if (!__bn.dag().existsNode(id))
149  GUM_ERROR(NotFound, "Node " << id << " does not exist in referred BayesNet");
150 
151  if (!isInstalledNode(id)) {
152  this->_dag.addNodeWithId(id);
153 
154  // adding arcs with id as a tail
155  for (auto pa: this->__bn.parents(id)) {
156  if (isInstalledNode(pa)) this->_dag.addArc(pa, id);
157  }
158 
159  // addin arcs with id as a head
160  for (auto son: this->__bn.children(id))
161  if (isInstalledNode(son)) this->_dag.addArc(id, son);
162  }
163  }
164 
165  template < typename GUM_SCALAR >
167  installNode(id);
168 
169  // bn is a dag => this will have an end ...
170  for (auto pa: this->__bn.parents(id))
171  installAscendants(pa);
172  }
173 
174  template < typename GUM_SCALAR >
176  if (isInstalledNode(id)) {
177  uninstallCPT(id);
178  this->_dag.eraseNode(id);
179  }
180  }
181 
182  template < typename GUM_SCALAR >
184  NodeId to) {
185  this->_dag.eraseArc(Arc(from, to));
186  }
187 
188  template < typename GUM_SCALAR >
190  this->_dag.addArc(from, to);
191  }
192 
193  template < typename GUM_SCALAR >
195  NodeId id, const Potential< GUM_SCALAR >& pot) {
196  // topology
197  const auto& parents = this->parents(id);
198  for (auto node_it = parents.beginSafe(); node_it != parents.endSafe();
199  ++node_it) // safe iterator needed here
200  _uninstallArc(*node_it, id);
201 
202  for (Idx i = 1; i < pot.nbrDim(); i++) {
203  NodeId parent = __bn.idFromName(pot.variable(i).name());
204 
205  if (isInstalledNode(parent)) _installArc(parent, id);
206  }
207 
208  // local cpt
209  if (__localCPTs.exists(id)) _uninstallCPT(id);
210 
211  __localCPTs.insert(id, new gum::Potential< GUM_SCALAR >(pot));
212  }
213 
214  template < typename GUM_SCALAR >
216  NodeId id, const Potential< GUM_SCALAR >& pot) {
217  if (!dag().existsNode(id))
218  GUM_ERROR(NotFound, "Node " << id << " is not installed in the fragment");
219 
220  if (&(pot.variable(0)) != &(variable(id))) {
222  "The potential is not a marginal for __bn.variable <"
223  << variable(id).name() << ">");
224  }
225 
226  const NodeSet& parents = __bn.parents(id);
227 
228  for (Idx i = 1; i < pot.nbrDim(); i++) {
229  if (!parents.contains(__bn.idFromName(pot.variable(i).name())))
231  "Variable <" << pot.variable(i).name()
232  << "> is not in the parents of node " << id);
233  }
234 
235  _installCPT(id, pot);
236  }
237 
238  template < typename GUM_SCALAR >
240  delete __localCPTs[id];
241  __localCPTs.erase(id);
242  }
243 
244  template < typename GUM_SCALAR >
246  if (__localCPTs.exists(id)) {
247  _uninstallCPT(id);
248 
249  // re-create arcs from referred potential
250  const Potential< GUM_SCALAR >& pot = cpt(id);
251 
252  for (Idx i = 1; i < pot.nbrDim(); i++) {
253  NodeId parent = __bn.idFromName(pot.variable(i).name());
254 
255  if (isInstalledNode(parent)) _installArc(parent, id);
256  }
257  }
258  }
259 
260  template < typename GUM_SCALAR >
262  NodeId id, const Potential< GUM_SCALAR >& pot) {
263  if (!isInstalledNode(id)) {
264  GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
265  }
266 
267  if (pot.nbrDim() > 1) {
268  GUM_ERROR(OperationNotAllowed, "The potential is not a marginal :" << pot);
269  }
270 
271  if (&(pot.variable(0)) != &(__bn.variable(id))) {
273  "The potential is not a marginal for __bn.variable <"
274  << __bn.variable(id).name() << ">");
275  }
276 
277  _installCPT(id, pot);
278  }
279 
280  template < typename GUM_SCALAR >
282  if (!isInstalledNode(id))
283  GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
284 
285  const auto& cpt = this->cpt(id);
286  NodeSet cpt_parents;
287 
288  for (Idx i = 1; i < cpt.nbrDim(); i++) {
289  cpt_parents.insert(__bn.idFromName(cpt.variable(i).name()));
290  }
291 
292  return (this->parents(id) == cpt_parents);
293  }
294 
295  template < typename GUM_SCALAR >
297  for (auto node: nodes())
298  if (!checkConsistency(node)) return false;
299 
300  return true;
301  }
302 
303  template < typename GUM_SCALAR >
305  std::stringstream output;
306  output << "digraph \"";
307 
308  std::string bn_name;
309 
310  static std::string inFragmentStyle = "fillcolor=\"#ffffaa\","
311  "color=\"#000000\","
312  "fontcolor=\"#000000\"";
313  static std::string styleWithLocalCPT = "fillcolor=\"#ffddaa\","
314  "color=\"#000000\","
315  "fontcolor=\"#000000\"";
316  static std::string notConsistantStyle = "fillcolor=\"#ff0000\","
317  "color=\"#000000\","
318  "fontcolor=\"#ffff00\"";
319  static std::string outFragmentStyle = "fillcolor=\"#f0f0f0\","
320  "color=\"#f0f0f0\","
321  "fontcolor=\"#000000\"";
322 
323  try {
324  bn_name = __bn.property("name");
325  } catch (NotFound&) { bn_name = "no_name"; }
326 
327  bn_name = "Fragment of " + bn_name;
328 
329  output << bn_name << "\" {" << std::endl;
330  output << " graph [bgcolor=transparent,label=\"" << bn_name << "\"];"
331  << std::endl;
332  output << " node [style=filled];" << std::endl << std::endl;
333 
334  for (auto node: __bn.nodes()) {
335  output << "\"" << __bn.variable(node).name() << "\" [comment=\"" << node
336  << ":" << __bn.variable(node) << ", \"";
337 
338  if (isInstalledNode(node)) {
339  if (!checkConsistency(node)) {
340  output << notConsistantStyle;
341  } else if (__localCPTs.exists(node))
342  output << styleWithLocalCPT;
343  else
344  output << inFragmentStyle;
345  } else
346  output << outFragmentStyle;
347 
348  output << "];" << std::endl;
349  }
350 
351  output << std::endl;
352 
353  std::string tab = " ";
354 
355  for (auto node: __bn.nodes()) {
356  if (__bn.children(node).size() > 0) {
357  for (auto child: __bn.children(node)) {
358  output << tab << "\"" << __bn.variable(node).name() << "\" -> "
359  << "\"" << __bn.variable(child).name() << "\" [";
360 
361  if (dag().existsArc(Arc(node, child)))
362  output << inFragmentStyle;
363  else
364  output << outFragmentStyle;
365 
366  output << "];" << std::endl;
367  }
368  }
369  }
370 
371  output << "}" << std::endl;
372 
373  return output.str();
374  }
375 
376  template < typename GUM_SCALAR >
378  if (!checkConsistency()) {
379  GUM_ERROR(OperationNotAllowed, "The fragment contains un-consistent node(s)")
380  }
382  for (const auto nod: nodes()) {
383  res.add(variable(nod), nod);
384  }
385  for (const auto arc: dag().arcs()) {
386  res.addArc(arc.tail(), arc.head());
387  }
388  for (const auto nod: nodes()) {
389  res.cpt(nod).fillWith(cpt(nod));
390  }
391 
392  return res;
393  }
394 } // namespace gum
void addArc(NodeId tail, NodeId head)
Add an arc in the BN, and update arc.head's CPT.
Definition: BayesNet_tpl.h:369
bool contains(const Key &k) const
Indicates whether a given elements belong to the set.
Definition: set_tpl.h:583
aGrUM's Potential is a multi-dimensional array with tensor operators.
Definition: potential.h:60
Class representing a Bayesian Network.
Definition: BayesNet.h:78
Abstract Base class for all diGraph Listener.
const ArcSet & arcs() const
returns the set of arcs of the DAG
Definition: DAGmodel_inl.h:44
gum::BayesNet< GUM_SCALAR > toBN() const
create a brand new BayesNet from a fragment.
virtual Idx nbrDim() const final
Returns the number of vars in the multidimensional container.
void _installArc(NodeId from, NodeId to)
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
NodeProperty< const Potential< GUM_SCALAR > *> __localCPTs
Mapping between each variable's id and the CPT specific to this Fragment.
void installCPT(NodeId id, const Potential< GUM_SCALAR > &pot)
install a local CPT BY COPY for a node into the fragment.
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
Definition: DAGmodel_inl.h:46
virtual void whenNodeDeleted(const void *src, NodeId id) final
the action to take when a node has just been removed from the graph
virtual NodeId nodeId(const DiscreteVariable &var) const final
Returns the node id associated with a discrete variable.
NodeId add(const DiscreteVariable &var)
Add a variable to the gum::BayesNet.
Definition: BayesNet_tpl.h:243
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Container used to map discrete variables with nodes.
void uninstallNode(NodeId id)
uninstall a node referenced by its nodeId
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
virtual const DiscreteVariable & variable(NodeId id) const final
Returns a constant reference to a variable given its node id.
Base class for discrete random variable.
void installAscendants(NodeId id)
install a node and all its ascendants
Class representing the minimal interface for Bayesian Network.
Definition: IBayesNet.h:62
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
Definition: agrum.h:25
virtual void whenNodeAdded(const void *src, NodeId id) final
the action to take when a new node is inserted into the graph
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
Definition: set_tpl.h:504
const IBayesNet< GUM_SCALAR > & __bn
The referred BayesNet.
DAG _dag
The DAG of this Directed Graphical Model.
Definition: DAGmodel.h:162
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
void installMarginal(NodeId id, const Potential< GUM_SCALAR > &pot)
install a local marginal BY COPY for a node into the fragment.
virtual const VariableNodeMap & variableNodeMap() const final
Returns a constant reference to the VariableNodeMap of this BN.
void _installCPT(NodeId id, const Potential< GUM_SCALAR > &pot)
The base class for all directed edges. This class is used as a basis for manipulating all directed edge...
virtual void whenArcAdded(const void *src, NodeId from, NodeId to) final
the action to take when a new arc is inserted into the graph
const NodeGraphPart & nodes() const
Returns a constant reference to the set of nodes of this Bayes Net.
Definition: DAGmodel_inl.h:60
virtual const DiscreteVariable & variable(Idx) const final
Returns a const ref to the ith var.
virtual const DiscreteVariable & variableFromName(const std::string &name) const final
Getter by name.
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:43
const Potential< GUM_SCALAR > & cpt(NodeId varId) const final
Returns the CPT of a variable.
Definition: BayesNet_tpl.h:326
virtual NodeId idFromName(const std::string &name) const final
Getter by name.
Portion of a BN identified by the list of nodes and a BayesNet.
bool existsNode(const NodeId id) const
returns true iff the NodeGraphPart contains the given nodeId
void installNode(NodeId id)
install a node referenced by its nodeId
void _uninstallCPT(NodeId id)
uninstall a local CPT.
bool isInstalledNode(NodeId id) const
check if a certain NodeId exists in the fragment
virtual const Potential< GUM_SCALAR > & cpt(NodeId varId) const final
Returns the CPT of a variable.
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
Definition: set_tpl.h:490
bool checkConsistency() const
returns true if all nodes in the fragment are consistent
Size Idx
Type for indexes.
Definition: types.h:53
void uninstallCPT(NodeId id)
uninstall a local CPT.
const std::string & name() const
returns the name of the variable
virtual void whenArcDeleted(const void *src, NodeId from, NodeId to) final
the action to take when an arc has just been removed from the graph
virtual void eraseNode(const NodeId id)
remove a node and its adjacent arcs from the graph
Definition: diGraph_inl.h:69
const DAG & dag() const
Returns a constant reference to the dag of this Bayes Net.
Definition: DAGmodel_inl.h:36
Size NodeId
Type for node ids.
Definition: graphElements.h:98
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:615
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
virtual std::string toDot() const final
creates a DOT representation of the whole referred BN, highlighting the fragment.
void _uninstallArc(NodeId from, NodeId to)
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.