aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
exactBNdistance_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Brute-force (exact) computation of divergences/distances between two BNs: KL in both directions, Hellinger, Bhattacharya and Jensen-Shannon
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 
29 #include <agrum/tools/core/math/math_utils.h>
30 #include <agrum/BN/IBayesNet.h>
31 #include <agrum/BN/algorithms/divergence/BNdistance.h>
32 #include <agrum/BN/algorithms/divergence/exactBNdistance.h>
33 
34 namespace gum {
35  template < typename GUM_SCALAR >
36  ExactBNdistance< GUM_SCALAR >::ExactBNdistance(
37  const IBayesNet< GUM_SCALAR >& P,
38  const IBayesNet< GUM_SCALAR >& Q) :
39  BNdistance< GUM_SCALAR >(P, Q) {
40  GUM_CONSTRUCTOR(ExactBNdistance);
41  }
42 
// NOTE(review): part of this definition is missing from this listing. The
// visible parameter (a const BNdistance<GUM_SCALAR>& named kl) followed by
// ") :" suggests a constructor with an initializer list built from an existing
// BNdistance. Confirm against the original exactBNdistance_tpl.h.
43  template < typename GUM_SCALAR >
45  const BNdistance< GUM_SCALAR >& kl) :
48  }
49 
// NOTE(review): only the template header and the closing brace of this
// definition survive in this listing. Its signature and body are missing;
// check the original exactBNdistance_tpl.h before relying on this fragment.
50  template < typename GUM_SCALAR >
53  }
54 
// NOTE(review): several lines of this definition are missing from this listing.
// These include its signature, the declarations of map/pp/pq/pmid/lpp/lpq/lpmid,
// and whatever follows the JSD averaging. The comments below describe only the
// visible lines.
55  template < typename GUM_SCALAR >
// reset the counters of configurations where a KL term cannot be computed
// (the other accumulators are presumably reset on a line missing here)
58  errorPQ_ = errorQP_ = 0;
59 
// instantiations over all variables of p_ and of q_
60  auto Ip = p_.completeInstantiation();
61  auto Iq = q_.completeInstantiation();
62 
63  // map between p_ variables and q_ variables (using name of vars)
65 
66  for (Idx ite = 0; ite < Ip.nbrDim(); ++ite) {
68  }
// Enumerate every configuration of Ip. Iq receives the same values through
// the name-based map, so both networks are evaluated on the same configuration.
// (pp / pq are set on lines missing here; presumably the probabilities of this
// configuration in p_ and q_ respectively. Confirm.)
70  for (Ip.setFirst(); !Ip.end(); ++Ip) {
71  Iq.setValsFrom(map, Ip);
// midpoint distribution M = (P + Q) / 2, used by the Jensen-Shannon term
74  pmid = (pp + pq) / 2.0;
// base-2 logarithms; each stays 0 when its probability is 0, so the
// p * log(p) products below contribute nothing for zero probabilities
75  lpmid = lpq = lpp = (GUM_SCALAR)0.0;
76  if (pmid != (GUM_SCALAR)0.0) lpmid = std::log2(pmid);
77  if (pp != (GUM_SCALAR)0.0) lpp = std::log2(pp);
78  if (pq != (GUM_SCALAR)0.0) lpq = std::log2(pq);
79 
80 
// running sums: squared difference of square roots (Hellinger) and
// sqrt(pp * pq) (Bhattacharyya coefficient); any final transformation of
// these sums happens on lines not shown in this listing
81  hellinger_ += std::pow(std::sqrt(pp) - std::sqrt(pq), 2);
82  bhattacharya_ += std::sqrt(pp * pq);
83 
// KL(P||Q) term pp * log2(pp / pq). When pp > 0 but pq == 0 the term would be
// infinite, so the configuration is counted in errorPQ_ instead.
84  if (pp != (GUM_SCALAR)0.0) {
85  if (pq != (GUM_SCALAR)0.0) {
86  klPQ_ -= pp * (lpq - lpp); // log2(pq / pp);
87  } else {
88  errorPQ_++;
89  }
90  }
91 
// KL(Q||P) term pq * log2(pq / pp). When pq > 0 but pp == 0 the term would be
// infinite, so the configuration is counted in errorQP_ instead.
92  if (pq != (GUM_SCALAR)0.0) {
93  if (pp != (GUM_SCALAR)0.0) {
94  klQP_ -= pq * (lpp - lpq); // log2(pp / pq);
95  } else {
96  errorQP_++;
97  }
98  }
// Jensen-Shannon accumulation KL(P||M) + KL(Q||M), expressed with the
// precomputed logarithms
99  if (pmid != (GUM_SCALAR)0.0) {
100  jsd_ += pp * lpp + pq * lpq
101  - (pp + pq) * lpmid; // pp* log2(pp / pmid) + pq * log2(pq / pmid);
102  }
103  }
// JSD = (KL(P||M) + KL(Q||M)) / 2, in bits since log2 is used
104  jsd_ /= 2.0;
107  }
108 
109 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669