aGrUM  0.16.0
BayesNetFragment_tpl.h
Go to the documentation of this file.
1 
29 #include <agrum/BN/BayesNet.h>
32 
33 namespace gum {
34  template < typename GUM_SCALAR >
36  const IBayesNet< GUM_SCALAR >& bn) :
37  DiGraphListener(&bn.dag()),
38  __bn(bn) {
39  GUM_CONSTRUCTOR(BayesNetFragment);
40  }
41 
42  template < typename GUM_SCALAR >
44  GUM_DESTRUCTOR(BayesNetFragment);
45 
46  for (auto node : nodes())
47  if (__localCPTs.exists(node)) _uninstallCPT(node);
48  }
49 
50  //============================================================
51  // signals to keep consistency with the referred BayesNet
52  template < typename GUM_SCALAR >
53  INLINE void BayesNetFragment< GUM_SCALAR >::whenNodeAdded(const void* src,
54  NodeId id) noexcept {
55  // nothing to do
56  }
57  template < typename GUM_SCALAR >
59  NodeId id) noexcept {
60  uninstallNode(id);
61  }
62  template < typename GUM_SCALAR >
63  INLINE void BayesNetFragment< GUM_SCALAR >::whenArcAdded(const void* src,
64  NodeId from,
65  NodeId to) noexcept {
66  // nothing to do
67  }
68  template < typename GUM_SCALAR >
70  NodeId from,
71  NodeId to) noexcept {
72  if (dag().existsArc(from, to)) _uninstallArc(from, to);
73  }
74 
75  //============================================================
76  // IBayesNet interface : BayesNetFragment here is a decorator for the bn
77 
78  template < typename GUM_SCALAR >
79  INLINE const Potential< GUM_SCALAR >&
81  if (!isInstalledNode(id)) GUM_ERROR(NotFound, id << " is not installed");
82 
83  if (__localCPTs.exists(id))
84  return *__localCPTs[id];
85  else
86  return __bn.cpt(id);
87  }
88 
89  template < typename GUM_SCALAR >
90  INLINE const VariableNodeMap&
93  "Not implemented yet. please use referent bayesnet method");
94  }
95 
96  template < typename GUM_SCALAR >
97  INLINE const DiscreteVariable&
99  if (!isInstalledNode(id)) GUM_ERROR(NotFound, id << " is not installed");
100 
101  return __bn.variable(id);
102  }
103 
104  template < typename GUM_SCALAR >
105  INLINE NodeId
107  NodeId id = __bn.nodeId(var);
108 
109  if (!isInstalledNode(id))
110  GUM_ERROR(NotFound, "variable " << var.name() << " is not installed");
111 
112  return id;
113  }
114 
115  template < typename GUM_SCALAR >
116  INLINE NodeId
117  BayesNetFragment< GUM_SCALAR >::idFromName(const std::string& name) const {
118  NodeId id = __bn.idFromName(name);
119 
120  if (!isInstalledNode(id))
121  GUM_ERROR(NotFound, "variable " << name << " is not installed");
122 
123  return id;
124  }
125 
126  template < typename GUM_SCALAR >
128  const std::string& name) const {
129  NodeId id = __bn.idFromName(name);
130 
131  if (!isInstalledNode(id))
132  GUM_ERROR(NotFound, "variable " << name << " is not installed");
133 
134  return __bn.variable(id);
135  }
136 
137  //============================================================
138  // specific API for BayesNetFragment
139  template < typename GUM_SCALAR >
141  noexcept {
142  return dag().existsNode(id);
143  }
144 
145  template < typename GUM_SCALAR >
147  if (!__bn.dag().existsNode(id))
148  GUM_ERROR(NotFound, "Node " << id << " does not exist in referred BayesNet");
149 
150  if (!isInstalledNode(id)) {
151  this->_dag.addNodeWithId(id);
152 
153  // adding arcs with id as a tail
154  for (auto pa : this->__bn.parents(id)) {
155  if (isInstalledNode(pa)) this->_dag.addArc(pa, id);
156  }
157 
158  // addin arcs with id as a head
159  for (auto son : this->__bn.children(id))
160  if (isInstalledNode(son)) this->_dag.addArc(id, son);
161  }
162  }
163 
164  template < typename GUM_SCALAR >
166  installNode(id);
167 
168  // bn is a dag => this will have an end ...
169  for (auto pa : this->__bn.parents(id))
170  installAscendants(pa);
171  }
172 
173  template < typename GUM_SCALAR >
175  if (isInstalledNode(id)) {
176  this->_dag.eraseNode(id);
177  uninstallCPT(id);
178  }
179  }
180 
181  template < typename GUM_SCALAR >
183  NodeId to) noexcept {
184  this->_dag.eraseArc(Arc(from, to));
185  }
186 
187  template < typename GUM_SCALAR >
189  NodeId to) noexcept {
190  this->_dag.addArc(from, to);
191  }
192 
193  template < typename GUM_SCALAR >
195  NodeId id, const Potential< GUM_SCALAR >* pot) noexcept {
196  // topology
197  const auto& parents = this->parents(id);
198  for (auto node_it = parents.beginSafe(); node_it != parents.endSafe();
199  ++node_it) // safe iterator needed here
200  _uninstallArc(*node_it, id);
201 
202  for (Idx i = 1; i < pot->nbrDim(); i++) {
203  NodeId parent = __bn.idFromName(pot->variable(i).name());
204 
205  if (isInstalledNode(parent)) _installArc(parent, id);
206  }
207 
208  // local cpt
209  if (__localCPTs.exists(id)) _uninstallCPT(id);
210 
211  __localCPTs.insert(id, pot);
212  }
213 
214  template < typename GUM_SCALAR >
216  NodeId id, const Potential< GUM_SCALAR >* pot) {
217  if (!dag().existsNode(id))
218  GUM_ERROR(NotFound, "Node " << id << " is not installed in the fragment");
219 
220  if (&(pot->variable(0)) != &(variable(id))) {
222  "The potential is not a marginal for __bn.variable <"
223  << variable(id).name() << ">");
224  }
225 
226  const NodeSet& parents = __bn.parents(id);
227 
228  for (Idx i = 1; i < pot->nbrDim(); i++) {
229  if (!parents.contains(__bn.idFromName(pot->variable(i).name())))
231  "Variable <" << pot->variable(i).name()
232  << "> is not in the parents of node " << id);
233  }
234 
235  _installCPT(id, pot);
236  }
237 
238  template < typename GUM_SCALAR >
240  delete __localCPTs[id];
241  __localCPTs.erase(id);
242  }
243 
244  template < typename GUM_SCALAR >
246  if (__localCPTs.exists(id)) {
247  _uninstallCPT(id);
248 
249  // re-create arcs from referred potential
250  const Potential< GUM_SCALAR >& pot = cpt(id);
251 
252  for (Idx i = 1; i < pot.nbrDim(); i++) {
253  NodeId parent = __bn.idFromName(pot.variable(i).name());
254 
255  if (isInstalledNode(parent)) _installArc(parent, id);
256  }
257  }
258  }
259 
260  template < typename GUM_SCALAR >
262  NodeId id, const Potential< GUM_SCALAR >* pot) {
263  if (!isInstalledNode(id)) {
264  GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
265  }
266 
267  if (pot->nbrDim() > 1) {
268  GUM_ERROR(OperationNotAllowed, "The potential is not a marginal :" << pot);
269  }
270 
271  if (&(pot->variable(0)) != &(__bn.variable(id))) {
273  "The potential is not a marginal for __bn.variable <"
274  << __bn.variable(id).name() << ">");
275  }
276 
277  _installCPT(id, pot);
278  }
279 
280  template < typename GUM_SCALAR >
282  if (!isInstalledNode(id))
283  GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
284 
285  const auto& cpt = this->cpt(id);
286  NodeSet cpt_parents;
287 
288  for (Idx i = 1; i < cpt.nbrDim(); i++) {
289  cpt_parents.insert(__bn.idFromName(cpt.variable(i).name()));
290  }
291 
292  return (this->parents(id) == cpt_parents);
293  }
294 
295  template < typename GUM_SCALAR >
297  for (auto node : nodes())
298  if (!checkConsistency(node)) return false;
299 
300  return true;
301  }
302 
303  template < typename GUM_SCALAR >
305  std::stringstream output;
306  output << "digraph \"";
307 
308  std::string bn_name;
309 
310  static std::string inFragmentStyle = "fillcolor=\"#ffffaa\","
311  "color=\"#000000\","
312  "fontcolor=\"#000000\"";
313  static std::string styleWithLocalCPT = "fillcolor=\"#ffddaa\","
314  "color=\"#000000\","
315  "fontcolor=\"#000000\"";
316  static std::string notConsistantStyle = "fillcolor=\"#ff0000\","
317  "color=\"#000000\","
318  "fontcolor=\"#ffff00\"";
319  static std::string outFragmentStyle = "fillcolor=\"#f0f0f0\","
320  "color=\"#f0f0f0\","
321  "fontcolor=\"#000000\"";
322 
323  try {
324  bn_name = __bn.property("name");
325  } catch (NotFound&) { bn_name = "no_name"; }
326 
327  bn_name = "Fragment of " + bn_name;
328 
329  output << bn_name << "\" {" << std::endl;
330  output << " graph [bgcolor=transparent,label=\"" << bn_name << "\"];"
331  << std::endl;
332  output << " node [style=filled];" << std::endl << std::endl;
333 
334  for (auto node : __bn.nodes()) {
335  output << "\"" << __bn.variable(node).name() << "\" [comment=\"" << node
336  << ":" << __bn.variable(node) << ", \"";
337 
338  if (isInstalledNode(node)) {
339  if (!checkConsistency(node)) {
340  output << notConsistantStyle;
341  } else if (__localCPTs.exists(node))
342  output << styleWithLocalCPT;
343  else
344  output << inFragmentStyle;
345  } else
346  output << outFragmentStyle;
347 
348  output << "];" << std::endl;
349  }
350 
351  output << std::endl;
352 
353  std::string tab = " ";
354 
355  for (auto node : __bn.nodes()) {
356  if (__bn.children(node).size() > 0) {
357  for (auto child : __bn.children(node)) {
358  output << tab << "\"" << __bn.variable(node).name() << "\" -> "
359  << "\"" << __bn.variable(child).name() << "\" [";
360 
361  if (dag().existsArc(Arc(node, child)))
362  output << inFragmentStyle;
363  else
364  output << outFragmentStyle;
365 
366  output << "];" << std::endl;
367  }
368  }
369  }
370 
371  output << "}" << std::endl;
372 
373  return output.str();
374  }
375 } // namespace gum
bool contains(const Key &k) const
Indicates whether a given element belongs to the set.
Definition: set_tpl.h:581
aGrUM's Potential is a multi-dimensional array with tensor operators.
Definition: potential.h:60
virtual NodeId nodeId(const DiscreteVariable &var) const override
Returns the node id of the given discrete variable.
Abstract Base class for all diGraph Listener.
virtual void whenNodeDeleted(const void *src, NodeId id) noexcept override
the action to take when a node has just been removed from the graph
void uninstallCPT(NodeId id) noexcept
uninstall a local CPT.
virtual Idx nbrDim() const final
Returns the number of vars in the multidimensional container.
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
NodeProperty< const Potential< GUM_SCALAR > *> __localCPTs
Mapping between each variable's id and its CPT specific to this Fragment.
virtual const DiscreteVariable & variableFromName(const std::string &name) const override
Getter by name.
void _installCPT(NodeId id, const Potential< GUM_SCALAR > *pot) noexcept
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
Definition: DAGmodel_inl.h:106
void installCPT(NodeId id, const Potential< GUM_SCALAR > *pot)
install a local cpt for a node into the fragment.
virtual void whenArcAdded(const void *src, NodeId from, NodeId to) noexcept override
the action to take when a new arc is inserted into the graph
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
virtual const VariableNodeMap & variableNodeMap() const override
Returns a constant reference to the VariableNodeMap of this BN.
Container used to map discrete variables with nodes.
void _installArc(NodeId from, NodeId to) noexcept
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Base class for discrete random variable.
void installAscendants(NodeId id)
install a node and all its ascendants
Class representing the minimal interface for Bayesian Network.
Definition: IBayesNet.h:62
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Definition: agrum.h:25
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
Definition: set_tpl.h:502
virtual const Potential< GUM_SCALAR > & cpt(NodeId varId) const override
Returns the CPT of a variable.
const IBayesNet< GUM_SCALAR > & __bn
The referred BayesNet.
DAG _dag
The DAG of this Directed Graphical Model.
Definition: DAGmodel.h:203
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
virtual const DiscreteVariable & variable(NodeId id) const override
Returns a constant reference to a variable given its node id.
The base class for all directed edges. This class is used as a basis for manipulating all directed edges...
const NodeGraphPart & nodes() const
Returns a constant reference to the set of nodes (NodeGraphPart) of this Bayes Net.
Definition: DAGmodel_inl.h:115
virtual const DiscreteVariable & variable(Idx) const final
Returns a const ref to the ith var.
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:43
virtual void whenArcDeleted(const void *src, NodeId from, NodeId to) noexcept override
the action to take when an arc has just been removed from the graph
Portion of a BN identified by the list of nodes and a BayesNet.
bool existsNode(const NodeId id) const
returns true iff the NodeGraphPart contains the given nodeId
void installNode(NodeId id)
install a node referenced by its nodeId
virtual std::string toDot() const override
Creates a DOT representation of the whole referred BN, highlighting the fragment.
void uninstallNode(NodeId id) noexcept
uninstall a node referenced by its nodeId
bool isInstalledNode(NodeId id) const noexcept
check if a certain NodeId exists in the fragment
void installMarginal(NodeId id, const Potential< GUM_SCALAR > *pot)
install a local marginal for a node into the fragment.
bool checkConsistency() const noexcept
returns true if all nodes in the fragment are consistent
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
Definition: set_tpl.h:488
virtual NodeId idFromName(const std::string &name) const override
Getter by name.
Size Idx
Type for indexes.
Definition: types.h:53
virtual void whenNodeAdded(const void *src, NodeId id) noexcept override
the action to take when a new node is inserted into the graph
void _uninstallArc(NodeId from, NodeId to) noexcept
void _uninstallCPT(NodeId id) noexcept
uninstall a local CPT.
const std::string & name() const
returns the name of the variable
virtual void eraseNode(const NodeId id)
remove a node and its adjacent arcs from the graph
Definition: diGraph_inl.h:69
const DAG & dag() const
Returns a constant reference to the dag of this Bayes Net.
Definition: DAGmodel_inl.h:63
Size NodeId
Type for node ids.
Definition: graphElements.h:98
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:613
#define GUM_ERROR(type, msg)
Definition: exceptions.h:55
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.