aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
BNdistance_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief KL divergence between BNs implementation
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 
29 #include <complex>
30 
31 #include <agrum/tools/core/math/math_utils.h>
32 #include <agrum/BN/IBayesNet.h>
33 #include <agrum/BN/algorithms/divergence/BNdistance.h>
34 
35 namespace gum {
36  template < typename GUM_SCALAR >
37  BNdistance< GUM_SCALAR >::BNdistance(const IBayesNet< GUM_SCALAR >& P,
38  const IBayesNet< GUM_SCALAR >& Q) :
39  p_(P),
40  q_(Q), klPQ_(0.0), klQP_(0.0), errorPQ_(0), errorQP_(0),
41  difficulty__(Complexity::Heavy), done__(false) {
42  checkCompatibility__(); // may throw OperationNotAllowed
43  GUM_CONSTRUCTOR(BNdistance);
44 
45  double diff = p_.log10DomainSize();
46 
47  if (diff > GAP_COMPLEXITY_KL_HEAVY_DIFFICULT)
48  difficulty__ = Complexity::Heavy;
49  else if (diff > GAP_COMPLEXITY_KL_DIFFICULT_CORRECT)
50  difficulty__ = Complexity::Difficult;
51  else
52  difficulty__ = Complexity::Correct;
53  }
54 
55  template < typename GUM_SCALAR >
56  BNdistance< GUM_SCALAR >::BNdistance(const BNdistance< GUM_SCALAR >& kl) :
57  p_(kl.p_), q_(kl.q_), klPQ_(kl.klPQ_), klQP_(kl.klQP_),
58  errorPQ_(kl.errorPQ_), errorQP_(kl.errorQP_), difficulty__(kl.difficulty__),
59  done__(kl.done__) {
60  GUM_CONSTRUCTOR(BNdistance);
61  }
62 
/**
 * @brief Destructor.
 *
 * Only invokes GUM_DESTRUCTOR (presumably aGrUM's debug-mode object
 * tracking -- confirm). P and Q are not released here.
 */
template < typename GUM_SCALAR >
BNdistance< GUM_SCALAR >::~BNdistance() {
  GUM_DESTRUCTOR(BNdistance);
}
67 
68  template < typename GUM_SCALAR >
69  Complexity BNdistance< GUM_SCALAR >::difficulty() const {
70  return difficulty__;
71  }
72 
73  template < typename GUM_SCALAR >
74  INLINE double BNdistance< GUM_SCALAR >::klPQ() {
75  process_();
76  return klPQ_;
77  }
78 
79  template < typename GUM_SCALAR >
80  INLINE double BNdistance< GUM_SCALAR >::klQP() {
81  process_();
82  return klQP_;
83  }
84 
85  template < typename GUM_SCALAR >
86  INLINE double BNdistance< GUM_SCALAR >::hellinger() {
87  process_();
88  return hellinger_;
89  }
90 
91  template < typename GUM_SCALAR >
92  INLINE double BNdistance< GUM_SCALAR >::bhattacharya() {
93  process_();
94  return bhattacharya_;
95  }
96 
97  template < typename GUM_SCALAR >
98  INLINE double BNdistance< GUM_SCALAR >::jsd() {
99  process_();
100  return jsd_;
101  }
102 
103  template < typename GUM_SCALAR >
104  INLINE Size BNdistance< GUM_SCALAR >::errorPQ() {
105  process_();
106  return errorPQ_;
107  }
108 
109  template < typename GUM_SCALAR >
110  INLINE Size BNdistance< GUM_SCALAR >::errorQP() {
111  process_();
112  return errorQP_;
113  }
114 
115  template < typename GUM_SCALAR >
116  INLINE const IBayesNet< GUM_SCALAR >& BNdistance< GUM_SCALAR >::p() const {
117  return p_;
118  }
119 
120  template < typename GUM_SCALAR >
121  INLINE const IBayesNet< GUM_SCALAR >& BNdistance< GUM_SCALAR >::q() const {
122  return q_;
123  }
124 
125  // check if the 2 BNs are compatible
126  template < typename GUM_SCALAR >
127  bool BNdistance< GUM_SCALAR >::checkCompatibility__() const {
128  for (auto node: p_.nodes()) {
129  const DiscreteVariable& vp = p_.variable(node);
130 
131  try {
132  const DiscreteVariable& vq = q_.variableFromName(vp.name());
133 
134  if (vp.domainSize() != vq.domainSize())
135  GUM_ERROR(OperationNotAllowed,
136  "BNdistance : the 2 BNs are not compatible "
137  "(not the same domainSize for "
138  + vp.name() + ")");
139 
140  for (Idx i = 0; i < vp.domainSize(); i++) {
141  try {
142  vq[vp.label(i)];
143  vp[vq.label(i)];
144 
145  } catch (OutOfBounds&) {
146  GUM_ERROR(OperationNotAllowed,
147  "BNdistance : the 2 BNs are not compatible F(not the same "
148  "labels for "
149  + vp.name() + ")");
150  }
151  }
152  } catch (NotFound&) {
153  GUM_ERROR(OperationNotAllowed,
154  "BNdistance : the 2 BNs are not compatible (not the same vars : "
155  + vp.name() + ")");
156  }
157  }
158 
159  // should not be used
160  if (p_.size() != q_.size())
161  GUM_ERROR(OperationNotAllowed,
162  "BNdistance : the 2 BNs are not compatible (not the same size)");
163 
164  if (std::fabs(p_.log10DomainSize() - q_.log10DomainSize()) > 1e-14) {
165  GUM_ERROR(
166  OperationNotAllowed,
167  "BNdistance : the 2 BNs are not compatible (not the same domainSize) : p="
168  << p_.log10DomainSize() << " q=" << q_.log10DomainSize() << " => "
169  << p_.log10DomainSize() - q_.log10DomainSize());
170  }
171 
172  return true;
173  }
174 
175  // do the job if not already done__
176  template < typename GUM_SCALAR >
177  void BNdistance< GUM_SCALAR >::process_() {
178  if (!done__) {
179  computeKL_();
180  done__ = true;
181  }
182  }
183 
// in order to keep BNdistance instantiable
/**
 * @brief Default computation hook; always throws.
 *
 * Provides a definition so that BNdistance itself can be instantiated.
 * Presumably derived classes override it to fill klPQ_, klQP_, hellinger_,
 * bhattacharya_, jsd_, errorPQ_ and errorQP_ -- confirm in the header.
 *
 * @throws OperationNotAllowed always.
 */
template < typename GUM_SCALAR >
void BNdistance< GUM_SCALAR >::computeKL_() {
  GUM_ERROR(OperationNotAllowed, "No default computations");
}
189 } // namespace gum