aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
BNdistance_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation of BNdistance: divergences/distances between two BNs (KL in both directions, Hellinger, Bhattacharya, JSD)
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 
29 #include <complex>
30 
31 #include <agrum/tools/core/math/math_utils.h>
32 #include <agrum/BN/IBayesNet.h>
33 #include <agrum/BN/algorithms/divergence/BNdistance.h>
34 
35 namespace gum {
36  template < typename GUM_SCALAR >
37  BNdistance< GUM_SCALAR >::BNdistance(const IBayesNet< GUM_SCALAR >& P,
38  const IBayesNet< GUM_SCALAR >& Q) :
39  p_(P),
40  q_(Q), klPQ_(0.0), klQP_(0.0), errorPQ_(0), errorQP_(0), _difficulty_(Complexity::Heavy),
41  _done_(false) {
42  _checkCompatibility_(); // may throw OperationNotAllowed
43  GUM_CONSTRUCTOR(BNdistance);
44 
45  double diff = p_.log10DomainSize();
46 
47  if (diff > GAP_COMPLEXITY_KL_HEAVY_DIFFICULT)
48  _difficulty_ = Complexity::Heavy;
49  else if (diff > GAP_COMPLEXITY_KL_DIFFICULT_CORRECT)
50  _difficulty_ = Complexity::Difficult;
51  else
52  _difficulty_ = Complexity::Correct;
53  }
54 
55  template < typename GUM_SCALAR >
56  BNdistance< GUM_SCALAR >::BNdistance(const BNdistance< GUM_SCALAR >& kl) :
57  p_(kl.p_), q_(kl.q_), klPQ_(kl.klPQ_), klQP_(kl.klQP_), errorPQ_(kl.errorPQ_),
58  errorQP_(kl.errorQP_), _difficulty_(kl._difficulty_), _done_(kl._done_) {
59  GUM_CONSTRUCTOR(BNdistance);
60  }
61 
62  template < typename GUM_SCALAR >
63  BNdistance< GUM_SCALAR >::~BNdistance() {
64  GUM_DESTRUCTOR(BNdistance);
65  }
66 
67  template < typename GUM_SCALAR >
68  Complexity BNdistance< GUM_SCALAR >::difficulty() const {
69  return _difficulty_;
70  }
71 
72  template < typename GUM_SCALAR >
73  INLINE double BNdistance< GUM_SCALAR >::klPQ() {
74  process_();
75  return klPQ_;
76  }
77 
78  template < typename GUM_SCALAR >
79  INLINE double BNdistance< GUM_SCALAR >::klQP() {
80  process_();
81  return klQP_;
82  }
83 
84  template < typename GUM_SCALAR >
85  INLINE double BNdistance< GUM_SCALAR >::hellinger() {
86  process_();
87  return hellinger_;
88  }
89 
90  template < typename GUM_SCALAR >
91  INLINE double BNdistance< GUM_SCALAR >::bhattacharya() {
92  process_();
93  return bhattacharya_;
94  }
95 
96  template < typename GUM_SCALAR >
97  INLINE double BNdistance< GUM_SCALAR >::jsd() {
98  process_();
99  return jsd_;
100  }
101 
102  template < typename GUM_SCALAR >
103  INLINE Size BNdistance< GUM_SCALAR >::errorPQ() {
104  process_();
105  return errorPQ_;
106  }
107 
108  template < typename GUM_SCALAR >
109  INLINE Size BNdistance< GUM_SCALAR >::errorQP() {
110  process_();
111  return errorQP_;
112  }
113 
114  template < typename GUM_SCALAR >
115  INLINE const IBayesNet< GUM_SCALAR >& BNdistance< GUM_SCALAR >::p() const {
116  return p_;
117  }
118 
119  template < typename GUM_SCALAR >
120  INLINE const IBayesNet< GUM_SCALAR >& BNdistance< GUM_SCALAR >::q() const {
121  return q_;
122  }
123 
124  // check if the 2 BNs are compatible
125  template < typename GUM_SCALAR >
126  bool BNdistance< GUM_SCALAR >::_checkCompatibility_() const {
127  for (auto node: p_.nodes()) {
128  const DiscreteVariable& vp = p_.variable(node);
129 
130  try {
131  const DiscreteVariable& vq = q_.variableFromName(vp.name());
132 
133  if (vp.domainSize() != vq.domainSize())
134  GUM_ERROR(OperationNotAllowed,
135  "BNdistance : the 2 BNs are not compatible "
136  "(not the same domainSize for "
137  + vp.name() + ")");
138 
139  for (Idx i = 0; i < vp.domainSize(); i++) {
140  try {
141  vq[vp.label(i)];
142  vp[vq.label(i)];
143 
144  } catch (OutOfBounds&) {
145  GUM_ERROR(OperationNotAllowed,
146  "BNdistance : the 2 BNs are not compatible F(not the same "
147  "labels for "
148  + vp.name() + ")");
149  }
150  }
151  } catch (NotFound&) {
152  GUM_ERROR(OperationNotAllowed,
153  "BNdistance : the 2 BNs are not compatible (not the same vars : " + vp.name()
154  + ")");
155  }
156  }
157 
158  // should not be used
159  if (p_.size() != q_.size())
160  GUM_ERROR(OperationNotAllowed,
161  "BNdistance : the 2 BNs are not compatible (not the same size)")
162 
163  if (std::fabs(p_.log10DomainSize() - q_.log10DomainSize()) > 1e-14) {
164  GUM_ERROR(OperationNotAllowed,
165  "BNdistance : the 2 BNs are not compatible (not the same domainSize) : p="
166  << p_.log10DomainSize() << " q=" << q_.log10DomainSize() << " => "
167  << p_.log10DomainSize() - q_.log10DomainSize());
168  }
169 
170  return true;
171  }
172 
173  // do the job if not already _done_
174  template < typename GUM_SCALAR >
175  void BNdistance< GUM_SCALAR >::process_() {
176  if (!_done_) {
177  computeKL_();
178  _done_ = true;
179  }
180  }
181 
182  // in order to keep BNdistance instantiable
183  template < typename GUM_SCALAR >
184  void BNdistance< GUM_SCALAR >::computeKL_() {
185  GUM_ERROR(OperationNotAllowed, "No default computations")
186  }
187 } // namespace gum