aGrUM  0.14.2
BayesNetFragment_tpl.h
Go to the documentation of this file.
1 /***************************************************************************
2  * Copyright (C) 2005 by Pierre-Henri WUILLEMIN et Christophe GONZALES *
3  * {prenom.nom}_at_lip6.fr *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the *
17  * Free Software Foundation, Inc., *
18  * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
19  ***************************************************************************/
26 #include <agrum/BN/BayesNet.h>
29 
30 namespace gum {
// Constructor: creates a fragment of the Bayesian network `bn`.
// The fragment is constructed as a DiGraphListener on bn.dag(), which is how
// node/arc deletions in the referred BN reach the when* handlers below.
// No node is installed here. Only a reference to `bn` is stored (__bn), so
// `bn` must outlive the fragment.
31  template < typename GUM_SCALAR >
33  const IBayesNet< GUM_SCALAR >& bn) :
34  DiGraphListener(&bn.dag()),
35  __bn(bn) {
36  GUM_CONSTRUCTOR(BayesNetFragment);
37  }
38 
// Destructor: frees the local CPTs owned by the fragment. _uninstallCPT
// deletes the stored potential. Only nodes currently installed in the
// fragment are visited.
39  template < typename GUM_SCALAR >
41  GUM_DESTRUCTOR(BayesNetFragment);
42 
43  for (auto node : nodes())
44  if (__localCPTs.exists(node)) _uninstallCPT(node);
45  }
46 
47  //============================================================
48  // signals to keep consistency with the referred BayesNet
49  template < typename GUM_SCALAR >
50  INLINE void BayesNetFragment< GUM_SCALAR >::whenNodeAdded(const void* src,
51  NodeId id) noexcept {
52  // nothing to do
53  }
// Signal handler: a node was removed from the referred BN's DAG, so it is
// removed from the fragment too. uninstallNode does nothing when the node
// was not installed. This handler is noexcept, so uninstallNode must not
// throw on this path.
54  template < typename GUM_SCALAR >
56  NodeId id) noexcept {
57  uninstallNode(id);
58  }
59  template < typename GUM_SCALAR >
60  INLINE void BayesNetFragment< GUM_SCALAR >::whenArcAdded(const void* src,
61  NodeId from,
62  NodeId to) noexcept {
63  // nothing to do
64  }
// Signal handler: an arc was removed from the referred BN's DAG. The matching
// arc is dropped from the fragment, but only if the fragment currently has it.
65  template < typename GUM_SCALAR >
67  NodeId from,
68  NodeId to) noexcept {
69  if (dag().existsArc(from, to)) _uninstallArc(from, to);
70  }
71 
72  //============================================================
73  // IBayesNet interface : BayesNetFragment here is a decorator for the bn
74 
// Returns the CPT the fragment uses for node `id`. A local CPT installed in
// the fragment takes precedence; otherwise the referred BN's CPT is returned.
// Raises NotFound if `id` is not installed in the fragment.
75  template < typename GUM_SCALAR >
76  INLINE const Potential< GUM_SCALAR >&
78  if (!isInstalledNode(id)) GUM_ERROR(NotFound, id << " is not installed");
79 
80  if (__localCPTs.exists(id))
81  return *__localCPTs[id];
82  else
83  return __bn.cpt(id);
84  }
85 
// Not supported by fragments. Judging from the message below, this accessor
// presumably always raises an error, and callers should use the referred
// BN's variableNodeMap() instead.
// NOTE(review): confirm which exception type the GUM_ERROR call raises.
86  template < typename GUM_SCALAR >
87  INLINE const VariableNodeMap&
90  "Not implemented yet. please use referent bayesnet method");
91  }
92 
// Returns the referred BN's variable for node `id`.
// Raises NotFound if `id` is not installed in the fragment.
93  template < typename GUM_SCALAR >
94  INLINE const DiscreteVariable&
96  if (!isInstalledNode(id)) GUM_ERROR(NotFound, id << " is not installed");
97 
98  return __bn.variable(id);
99  }
100 
// Returns the node id of `var`, resolved through the referred BN.
// Raises NotFound if that node is not installed in the fragment.
101  template < typename GUM_SCALAR >
102  INLINE NodeId
104  NodeId id = __bn.nodeId(var);
105 
106  if (!isInstalledNode(id))
107  GUM_ERROR(NotFound, "variable " << var.name() << " is not installed");
108 
109  return id;
110  }
111 
112  template < typename GUM_SCALAR >
113  INLINE NodeId
114  BayesNetFragment< GUM_SCALAR >::idFromName(const std::string& name) const {
115  NodeId id = __bn.idFromName(name);
116 
117  if (!isInstalledNode(id))
118  GUM_ERROR(NotFound, "variable " << name << " is not installed");
119 
120  return id;
121  }
122 
// Returns the referred BN's variable named `name`.
// Raises NotFound if the corresponding node is not installed in the fragment.
123  template < typename GUM_SCALAR >
125  const std::string& name) const {
126  NodeId id = __bn.idFromName(name);
127 
128  if (!isInstalledNode(id))
129  GUM_ERROR(NotFound, "variable " << name << " is not installed");
130 
131  return __bn.variable(id);
132  }
133 
134  //============================================================
135  // specific API for BayesNetFragment
// True iff node `id` belongs to the fragment. Membership means the node exists
// in the fragment's own DAG, dag(), as opposed to the referred BN's __bn.dag().
136  template < typename GUM_SCALAR >
138  noexcept {
139  return dag().existsNode(id);
140  }
141 
// Adds node `id` of the referred BN to the fragment. Every arc of the
// referred BN that links `id` to an already-installed node, in either
// direction, is added as well.
// Raises NotFound if `id` is not a node of the referred BN. Does nothing if
// `id` is already installed.
// NOTE(review): an arc towards an installed child is added even when that
// child has a local CPT that does not mention `id`. checkConsistency() then
// reports the child as inconsistent. Confirm this is intended.
142  template < typename GUM_SCALAR >
144  if (!__bn.dag().existsNode(id))
145  GUM_ERROR(NotFound, "Node " << id << " does not exist in referred BayesNet");
146 
147  if (!isInstalledNode(id)) {
148  this->_dag.addNodeWithId(id);
149 
150  // adding arcs with id as a head (from its installed parents)
151  for (auto pa : this->__bn.parents(id)) {
152  if (isInstalledNode(pa)) this->_dag.addArc(pa, id);
153  }
154 
155  // adding arcs with id as a tail (towards its installed children)
156  for (auto son : this->__bn.children(id))
157  if (isInstalledNode(son)) this->_dag.addArc(id, son);
158  }
159  }
159  }
160 
// Installs node `id` and, recursively, all of its ancestors in the referred
// BN. The recursion terminates because the referred BN's graph is a DAG.
// NOTE(review): there is no visited set, so an ancestor reachable through
// several paths is revisited once per path. The number of calls can grow
// exponentially on diamond-rich graphs. The result is unaffected, because
// installNode does nothing for already-installed nodes.
161  template < typename GUM_SCALAR >
163  installNode(id);
164 
165  // bn is a dag => this will have an end ...
166  for (auto pa : this->__bn.parents(id))
167  installAscendants(pa);
168  }
169 
// Removes node `id` from the fragment. eraseNode also drops all arcs adjacent
// to it. Does nothing if `id` is not installed.
// Declared noexcept, so nothing below may throw.
170  template < typename GUM_SCALAR >
172  if (isInstalledNode(id)) {
// Release the owned local CPT *before* the node leaves the DAG.
// Erasing the node first and then calling uninstallCPT made uninstallCPT
// call cpt(id) on a node that was no longer installed. cpt raises NotFound
// in that case, and inside this noexcept function that means std::terminate.
// uninstallCPT's arc rebuilding is useless here anyway, since eraseNode
// removes those arcs. It would also query the referred BN, which, when we
// are reached from whenNodeDeleted, presumably no longer holds `id`.
// So only the CPT itself is freed.
173  if (__localCPTs.exists(id)) _uninstallCPT(id);
174  this->_dag.eraseNode(id);
175  }
176  }
177 
// Removes the arc from -> to from the fragment's own DAG only. The referred
// BN is untouched.
178  template < typename GUM_SCALAR >
180  NodeId to) noexcept {
181  this->_dag.eraseArc(Arc(from, to));
182  }
183 
// Adds the arc from -> to to the fragment's own DAG only. The referred BN is
// untouched.
// NOTE(review): this function is noexcept, but DAG::addArc is not documented
// as noexcept. Callers pass arcs of the referred BN between installed nodes;
// confirm addArc cannot throw for those.
184  template < typename GUM_SCALAR >
186  NodeId to) noexcept {
187  this->_dag.addArc(from, to);
188  }
189 
// Unchecked installation of `pot` as the local CPT of node `id`. installCPT
// and installMarginal are the validating entry points.
// The fragment's arcs into `id` are rebuilt to match `pot`:
//   - every current parent arc is removed;
//   - an arc is added from each installed node whose variable appears in
//     `pot` at position >= 1 (position 0 is the node's own variable).
// Any previous local CPT of `id` is deleted, and the fragment takes
// ownership of `pot` (_uninstallCPT later deletes it).
// NOTE(review): if `pot` is the very pointer already installed for `id`,
// _uninstallCPT deletes it just before it is re-inserted. That leaves a
// dangling pointer in __localCPTs.
190  template < typename GUM_SCALAR >
192  NodeId id, const Potential< GUM_SCALAR >* pot) noexcept {
193  // topology
194  const auto& parents = this->parents(id);
195  for (auto node_it = parents.beginSafe(); node_it != parents.endSafe();
196  ++node_it) // safe iterator: _uninstallArc removes from the set being iterated
197  _uninstallArc(*node_it, id);
198 
199  for (Idx i = 1; i < pot->nbrDim(); i++) {
200  NodeId parent = __bn.idFromName(pot->variable(i).name());
201 
202  if (isInstalledNode(parent)) _installArc(parent, id);
203  }
204 
205  // local cpt
206  if (__localCPTs.exists(id)) _uninstallCPT(id);
207 
208  __localCPTs.insert(id, pot);
209  }
210 
// Checked entry point for installing a local CPT for node `id`.
// Raises NotFound if `id` is not installed in the fragment. Also raises an
// error if:
//   - variable 0 of `pot` is not the variable of `id`;
//   - `pot` mentions a variable that is not a parent of `id` in the
//     referred BN.
// Only after all checks pass is `pot` handed to _installCPT, which takes
// ownership. If an error is raised, nothing has been stored and the caller
// still owns `pot`.
211  template < typename GUM_SCALAR >
213  NodeId id, const Potential< GUM_SCALAR >* pot) {
214  if (!dag().existsNode(id))
215  GUM_ERROR(NotFound, "Node " << id << " is not installed in the fragment");
216 
217  if (&(pot->variable(0)) != &(variable(id))) {
219  "The potential is not a marginal for __bn.variable <"
220  << variable(id).name() << ">");
221  }
222 
223  const NodeSet& parents = __bn.parents(id);
224 
225  for (Idx i = 1; i < pot->nbrDim(); i++) {
226  if (!parents.contains(__bn.idFromName(pot->variable(i).name())))
228  "Variable <" << pot->variable(i).name()
229  << "> is not in the parents of node " << id);
230  }
231 
232  _installCPT(id, pot);
233  }
234 
// Deletes the local CPT of node `id` and removes it from __localCPTs. The
// fragment's arcs are not touched.
// NOTE(review): callers must ensure __localCPTs.exists(id); every visible
// call site checks this first.
235  template < typename GUM_SCALAR >
237  delete __localCPTs[id];
238  __localCPTs.erase(id);
239  }
240 
// Removes the local CPT of node `id`, if any. It then rebuilds the arcs into
// `id` from the referred BN's CPT (what cpt(id) returns once no local CPT
// remains), keeping only arcs whose parent is installed.
// Does nothing when `id` has no local CPT.
// NOTE(review): this is declared noexcept, but cpt(id) raises NotFound when
// `id` is not installed. Only call it on installed nodes.
241  template < typename GUM_SCALAR >
243  if (__localCPTs.exists(id)) {
244  _uninstallCPT(id);
245 
246  // re-create arcs from referred potential
247  const Potential< GUM_SCALAR >& pot = cpt(id);
248 
249  for (Idx i = 1; i < pot.nbrDim(); i++) {
250  NodeId parent = __bn.idFromName(pot.variable(i).name());
251 
252  if (isInstalledNode(parent)) _installArc(parent, id);
253  }
254  }
255  }
256 
// Installs `pot` as a local marginal for node `id`, i.e. a potential over the
// node's own variable only.
// Raises NotFound if `id` is not installed, and OperationNotAllowed if `pot`
// has more than one dimension. It also rejects a potential whose variable 0
// is not the referred BN's variable for `id`.
// On success, ownership of `pot` passes to the fragment via _installCPT.
// Because a marginal has no parent variables, every arc into `id` is dropped.
257  template < typename GUM_SCALAR >
259  NodeId id, const Potential< GUM_SCALAR >* pot) {
260  if (!isInstalledNode(id)) {
261  GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
262  }
263 
264  if (pot->nbrDim() > 1) {
// Stream the potential itself: `pot` is a pointer, so `<< pot` only put a
// memory address into the diagnostic.
265  GUM_ERROR(OperationNotAllowed, "The potential is not a marginal :" << *pot);
266  }
267 
268  if (&(pot->variable(0)) != &(__bn.variable(id))) {
270  "The potential is not a marginal for __bn.variable <"
271  << __bn.variable(id).name() << ">");
272  }
273 
274  _installCPT(id, pot);
275  }
276 
// A node is consistent when its parents in the fragment's DAG are exactly
// the nodes whose variables appear, at positions >= 1, in the CPT the
// fragment uses for it (local or referred, see cpt()).
// Raises NotFound if `id` is not installed.
277  template < typename GUM_SCALAR >
279  if (!isInstalledNode(id))
280  GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
281 
282  const auto& cpt = this->cpt(id);
283  NodeSet cpt_parents;
284 
285  for (Idx i = 1; i < cpt.nbrDim(); i++) {
286  cpt_parents.insert(__bn.idFromName(cpt.variable(i).name()));
287  }
288 
289  return (this->parents(id) == cpt_parents);
290  }
291 
// Returns true iff every installed node passes checkConsistency(node).
// Stops at the first inconsistent node.
292  template < typename GUM_SCALAR >
294  for (auto node : nodes())
295  if (!checkConsistency(node)) return false;
296 
297  return true;
298  }
299 
// Produces a Graphviz "dot" description of the *whole* referred BN, with
// fragment membership highlighted by style:
//   - installed nodes: inFragmentStyle;
//   - installed nodes with a local CPT: styleWithLocalCPT;
//   - inconsistent installed nodes: notConsistantStyle;
//   - nodes and arcs outside the fragment: outFragmentStyle.
// The graph is named "Fragment of " followed by the referred BN's "name"
// property, or "no_name" if that property is missing.
300  template < typename GUM_SCALAR >
302  std::stringstream output;
303  output << "digraph \"";
304 
305  std::string bn_name;
306 
// Node/arc attribute strings, built once per process.
307  static std::string inFragmentStyle = "fillcolor=\"#ffffaa\","
308  "color=\"#000000\","
309  "fontcolor=\"#000000\"";
310  static std::string styleWithLocalCPT = "fillcolor=\"#ffddaa\","
311  "color=\"#000000\","
312  "fontcolor=\"#000000\"";
313  static std::string notConsistantStyle = "fillcolor=\"#ff0000\","
314  "color=\"#000000\","
315  "fontcolor=\"#ffff00\"";
316  static std::string outFragmentStyle = "fillcolor=\"#f0f0f0\","
317  "color=\"#f0f0f0\","
318  "fontcolor=\"#000000\"";
319 
320  try {
321  bn_name = __bn.property("name");
322  } catch (NotFound&) { bn_name = "no_name"; }
323 
324  bn_name = "Fragment of " + bn_name;
325 
326  output << bn_name << "\" {" << std::endl;
327  output << " graph [bgcolor=transparent,label=\"" << bn_name << "\"];"
328  << std::endl;
329  output << " node [style=filled];" << std::endl << std::endl;
330 
// One node statement per node of the referred BN.
// NOTE(review): `", \""` puts ", " inside the quoted comment value and emits
// no separator before the style attributes, e.g.
// comment="0:A, "fillcolor=...
// DOT accepts this because attribute separators are optional, but `"\", "`
// was probably intended.
331  for (auto node : __bn.nodes()) {
332  output << "\"" << __bn.variable(node).name() << "\" [comment=\"" << node
333  << ":" << __bn.variable(node) << ", \"";
334 
335  if (isInstalledNode(node)) {
336  if (!checkConsistency(node)) {
337  output << notConsistantStyle;
338  } else if (__localCPTs.exists(node))
339  output << styleWithLocalCPT;
340  else
341  output << inFragmentStyle;
342  } else
343  output << outFragmentStyle;
344 
345  output << "];" << std::endl;
346  }
347 
348  output << std::endl;
349 
350  std::string tab = " ";
351 
// One edge statement per arc of the referred BN. The arc is highlighted when
// the fragment's own DAG also contains it.
352  for (auto node : __bn.nodes()) {
353  if (__bn.children(node).size() > 0) {
354  for (auto child : __bn.children(node)) {
355  output << tab << "\"" << __bn.variable(node).name() << "\" -> "
356  << "\"" << __bn.variable(child).name() << "\" [";
357 
358  if (dag().existsArc(Arc(node, child)))
359  output << inFragmentStyle;
360  else
361  output << outFragmentStyle;
362 
363  output << "];" << std::endl;
364  }
365  }
366  }
367 
368  output << "}" << std::endl;
369 
370  return output.str();
371  }
372 } // namespace gum
bool contains(const Key &k) const
Indicates whether a given elements belong to the set.
Definition: set_tpl.h:578
aGrUM's Potential is a multi-dimensional array with tensor operators.
Definition: potential.h:57
virtual NodeId nodeId(const DiscreteVariable &var) const override
Return id node from discrete var pointer.
Abstract Base class for all diGraph Listener.
virtual void whenNodeDeleted(const void *src, NodeId id) noexcept override
the action to take when a node has just been removed from the graph
void uninstallCPT(NodeId id) noexcept
uninstall a local CPT.
virtual Idx nbrDim() const final
Returns the number of vars in the multidimensional container.
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
NodeProperty< const Potential< GUM_SCALAR > *> __localCPTs
Mapping between each variable's id and its CPT specific to this Fragment.
virtual const DiscreteVariable & variableFromName(const std::string &name) const override
Getter by name.
void _installCPT(NodeId id, const Potential< GUM_SCALAR > *pot) noexcept
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
Definition: DAGmodel_inl.h:103
void installCPT(NodeId id, const Potential< GUM_SCALAR > *pot)
install a local cpt for a node into the fragment.
virtual void whenArcAdded(const void *src, NodeId from, NodeId to) noexcept override
the action to take when a new arc is inserted into the graph
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
virtual const VariableNodeMap & variableNodeMap() const override
Returns a constant reference to the VariableNodeMap of this BN.
Container used to map discrete variables with nodes.
void _installArc(NodeId from, NodeId to) noexcept
Class representing Bayesian networks.
Base class for discrete random variable.
void installAscendants(NodeId id)
install a node and all its ascendants
Class representing the minimal interface for Bayesian Network.
Definition: IBayesNet.h:59
gum is the global namespace for all aGrUM entities
Definition: agrum.h:25
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
Definition: set_tpl.h:499
virtual const Potential< GUM_SCALAR > & cpt(NodeId varId) const override
Returns the CPT of a variable.
const IBayesNet< GUM_SCALAR > & __bn
The referred BayesNet.
DAG _dag
The DAG of this Directed Graphical Model.
Definition: DAGmodel.h:200
Header of the Potential class.
virtual const DiscreteVariable & variable(NodeId id) const override
Returns a constant reference to a variable given its node id.
The base class for all directed edges. This class is used as a basis for manipulating all directed edge...
const NodeGraphPart & nodes() const
Returns a constant reference to the nodes of this Bayes Net.
Definition: DAGmodel_inl.h:112
virtual const DiscreteVariable & variable(Idx) const final
Returns a const ref to the ith var.
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition: DAG_inl.h:40
virtual void whenArcDeleted(const void *src, NodeId from, NodeId to) noexcept override
the action to take when an arc has just been removed from the graph
Portion of a BN identified by the list of nodes and a BayesNet.
bool existsNode(const NodeId id) const
returns true iff the NodeGraphPart contains the given nodeId
void installNode(NodeId id)
install a node referenced by its nodeId
virtual std::string toDot() const override
creates a dot representing the whole referred BN highlighting the fragment.
void uninstallNode(NodeId id) noexcept
uninstall a node referenced by its nodeId
bool isInstalledNode(NodeId id) const noexcept
check if a certain NodeId exists in the fragment
void installMarginal(NodeId id, const Potential< GUM_SCALAR > *pot)
install a local marginal for a node into the fragment.
bool checkConsistency() const noexcept
returns true if all nodes in the fragment are consistent
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
Definition: set_tpl.h:485
virtual NodeId idFromName(const std::string &name) const override
Getter by name.
Size Idx
Type for indexes.
Definition: types.h:50
virtual void whenNodeAdded(const void *src, NodeId id) noexcept override
the action to take when a new node is inserted into the graph
void _uninstallArc(NodeId from, NodeId to) noexcept
void _uninstallCPT(NodeId id) noexcept
uninstall a local CPT.
const std::string & name() const
returns the name of the variable
virtual void eraseNode(const NodeId id)
remove a node and its adjacent arcs from the graph
Definition: diGraph_inl.h:66
const DAG & dag() const
Returns a constant reference to the dag of this Bayes Net.
Definition: DAGmodel_inl.h:60
Size NodeId
Type for node ids.
Definition: graphElements.h:97
void insert(const Key &k)
Inserts a new element into the set.
Definition: set_tpl.h:610
#define GUM_ERROR(type, msg)
Definition: exceptions.h:52
Class representing Fragment of Bayesian networks.