aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
BayesNetFragment_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Template implementation of BN/BayesNetFragment.h classes.
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
27  */
28 #include <agrum/BN/BayesNet.h>
29 #include <agrum/BN/BayesNetFragment.h>
30 #include <agrum/tools/multidim/potential.h>
31 
32 namespace gum {
33  template < typename GUM_SCALAR >
34  BayesNetFragment< GUM_SCALAR >::BayesNetFragment(
35  const IBayesNet< GUM_SCALAR >& bn) :
36  DiGraphListener(&bn.dag()),
37  bn__(bn) {
38  GUM_CONSTRUCTOR(BayesNetFragment);
39  }
40 
41  template < typename GUM_SCALAR >
44 
45  for (auto node: nodes())
47  }
48 
49  //============================================================
50  // signals to keep consistency with the referred BayesNet
51  template < typename GUM_SCALAR >
53  NodeId id) {
54  // nothing to do
55  }
56  template < typename GUM_SCALAR >
58  NodeId id) {
60  }
61  template < typename GUM_SCALAR >
63  NodeId from,
64  NodeId to) {
65  // nothing to do
66  }
67  template < typename GUM_SCALAR >
69  NodeId from,
70  NodeId to) {
72  }
73 
74  //============================================================
75  // IBayesNet interface : BayesNetFragment here is a decorator for the bn
76 
77  template < typename GUM_SCALAR >
78  INLINE const Potential< GUM_SCALAR >&
80  if (!isInstalledNode(id))
81  GUM_ERROR(NotFound, "NodeId " << id << " is not installed");
82 
83  if (localCPTs__.exists(id))
84  return *localCPTs__[id];
85  else
86  return bn__.cpt(id);
87  }
88 
89  template < typename GUM_SCALAR >
90  INLINE const VariableNodeMap&
92  return this->bn__.variableNodeMap();
93  }
94 
95  template < typename GUM_SCALAR >
98  if (!isInstalledNode(id))
99  GUM_ERROR(NotFound, "NodeId " << id << " is not installed");
100 
101  return bn__.variable(id);
102  }
103 
104  template < typename GUM_SCALAR >
105  INLINE NodeId
107  NodeId id = bn__.nodeId(var);
108 
109  if (!isInstalledNode(id))
110  GUM_ERROR(NotFound, "variable " << var.name() << " is not installed");
111 
112  return id;
113  }
114 
115  template < typename GUM_SCALAR >
116  INLINE NodeId
119 
120  if (!isInstalledNode(id))
121  GUM_ERROR(NotFound, "variable " << name << " is not installed");
122 
123  return id;
124  }
125 
126  template < typename GUM_SCALAR >
128  const std::string& name) const {
130 
131  if (!isInstalledNode(id))
132  GUM_ERROR(NotFound, "variable " << name << " is not installed");
133 
134  return bn__.variable(id);
135  }
136 
137  //============================================================
138  // specific API for BayesNetFragment
139  template < typename GUM_SCALAR >
141  return dag().existsNode(id);
142  }
143 
144  template < typename GUM_SCALAR >
146  if (!bn__.dag().existsNode(id))
147  GUM_ERROR(NotFound, "Node " << id << " does not exist in referred BayesNet");
148 
149  if (!isInstalledNode(id)) {
150  this->dag_.addNodeWithId(id);
151 
152  // adding arcs with id as a tail
153  for (auto pa: this->bn__.parents(id)) {
154  if (isInstalledNode(pa)) this->dag_.addArc(pa, id);
155  }
156 
157  // adding arcs with id as a head
158  for (auto son: this->bn__.children(id))
159  if (isInstalledNode(son)) this->dag_.addArc(id, son);
160  }
161  }
162 
163  template < typename GUM_SCALAR >
165  installNode(id);
166 
167  // bn is a dag => this will have an end ...
168  for (auto pa: this->bn__.parents(id))
170  }
171 
172  template < typename GUM_SCALAR >
174  if (isInstalledNode(id)) {
175  uninstallCPT(id);
176  this->dag_.eraseNode(id);
177  }
178  }
179 
180  template < typename GUM_SCALAR >
182  NodeId to) {
183  this->dag_.eraseArc(Arc(from, to));
184  }
185 
186  template < typename GUM_SCALAR >
188  this->dag_.addArc(from, to);
189  }
190 
191  template < typename GUM_SCALAR >
193  NodeId id,
194  const Potential< GUM_SCALAR >& pot) {
195  // topology
196  const auto& parents = this->parents(id);
197  for (auto node_it = parents.beginSafe(); node_it != parents.endSafe();
198  ++node_it) // safe iterator needed here
200 
201  for (Idx i = 1; i < pot.nbrDim(); i++) {
203 
205  }
206 
207  // local cpt
209 
211  }
212 
213  template < typename GUM_SCALAR >
215  NodeId id,
216  const Potential< GUM_SCALAR >& pot) {
217  if (!dag().existsNode(id))
218  GUM_ERROR(NotFound, "Node " << id << " is not installed in the fragment");
219 
220  if (&(pot.variable(0)) != &(variable(id))) {
222  "The potential is not a marginal for bn__.variable <"
223  << variable(id).name() << ">");
224  }
225 
226  const NodeSet& parents = bn__.parents(id);
227 
228  for (Idx i = 1; i < pot.nbrDim(); i++) {
231  "Variable <" << pot.variable(i).name()
232  << "> is not in the parents of node " << id);
233  }
234 
235  installCPT_(id, pot);
236  }
237 
238  template < typename GUM_SCALAR >
240  delete localCPTs__[id];
242  }
243 
244  template < typename GUM_SCALAR >
246  if (localCPTs__.exists(id)) {
247  uninstallCPT_(id);
248 
249  // re-create arcs from referred potential
250  const Potential< GUM_SCALAR >& pot = cpt(id);
251 
252  for (Idx i = 1; i < pot.nbrDim(); i++) {
254 
256  }
257  }
258  }
259 
260  template < typename GUM_SCALAR >
262  NodeId id,
263  const Potential< GUM_SCALAR >& pot) {
264  if (!isInstalledNode(id)) {
265  GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
266  }
267 
268  if (pot.nbrDim() > 1) {
269  GUM_ERROR(OperationNotAllowed, "The potential is not a marginal :" << pot);
270  }
271 
272  if (&(pot.variable(0)) != &(bn__.variable(id))) {
274  "The potential is not a marginal for bn__.variable <"
275  << bn__.variable(id).name() << ">");
276  }
277 
278  installCPT_(id, pot);
279  }
280 
281  template < typename GUM_SCALAR >
283  if (!isInstalledNode(id))
284  GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
285 
286  const auto& cpt = this->cpt(id);
288 
289  for (Idx i = 1; i < cpt.nbrDim(); i++) {
291  }
292 
293  return (this->parents(id) == cpt_parents);
294  }
295 
296  template < typename GUM_SCALAR >
298  for (auto node: nodes())
299  if (!checkConsistency(node)) return false;
300 
301  return true;
302  }
303 
304  template < typename GUM_SCALAR >
307  output << "digraph \"";
308 
309  std::string bn_name;
310 
311  static std::string inFragmentStyle = "fillcolor=\"#ffffaa\","
312  "color=\"#000000\","
313  "fontcolor=\"#000000\"";
314  static std::string styleWithLocalCPT = "fillcolor=\"#ffddaa\","
315  "color=\"#000000\","
316  "fontcolor=\"#000000\"";
317  static std::string notConsistantStyle = "fillcolor=\"#ff0000\","
318  "color=\"#000000\","
319  "fontcolor=\"#ffff00\"";
320  static std::string outFragmentStyle = "fillcolor=\"#f0f0f0\","
321  "color=\"#f0f0f0\","
322  "fontcolor=\"#000000\"";
323 
324  try {
325  bn_name = bn__.property("name");
326  } catch (NotFound&) { bn_name = "no_name"; }
327 
328  bn_name = "Fragment of " + bn_name;
329 
330  output << bn_name << "\" {" << std::endl;
331  output << " graph [bgcolor=transparent,label=\"" << bn_name << "\"];"
332  << std::endl;
333  output << " node [style=filled];" << std::endl << std::endl;
334 
335  for (auto node: bn__.nodes()) {
336  output << "\"" << bn__.variable(node).name() << "\" [comment=\"" << node
337  << ":" << bn__.variable(node) << ", \"";
338 
339  if (isInstalledNode(node)) {
340  if (!checkConsistency(node)) {
342  } else if (localCPTs__.exists(node))
344  else
346  } else
348 
349  output << "];" << std::endl;
350  }
351 
352  output << std::endl;
353 
354  std::string tab = " ";
355 
356  for (auto node: bn__.nodes()) {
357  if (bn__.children(node).size() > 0) {
358  for (auto child: bn__.children(node)) {
359  output << tab << "\"" << bn__.variable(node).name() << "\" -> "
360  << "\"" << bn__.variable(child).name() << "\" [";
361 
362  if (dag().existsArc(Arc(node, child)))
364  else
366 
367  output << "];" << std::endl;
368  }
369  }
370  }
371 
372  output << "}" << std::endl;
373 
374  return output.str();
375  }
376 
377  template < typename GUM_SCALAR >
379  if (!checkConsistency()) {
380  GUM_ERROR(OperationNotAllowed, "The fragment contains un-consistent node(s)")
381  }
383  for (const auto nod: nodes()) {
384  res.add(variable(nod), nod);
385  }
386  for (const auto& arc: dag().arcs()) {
387  res.addArc(arc.tail(), arc.head());
388  }
389  for (const auto nod: nodes()) {
390  res.cpt(nod).fillWith(cpt(nod));
391  }
392 
393  return res;
394  }
395 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669