aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
exactBNdistance_tpl.h
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief KL divergence between BNs brute force implementation
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6)
27  */
28 
29 #include <agrum/tools/core/math/math_utils.h>
30 #include <agrum/BN/IBayesNet.h>
31 #include <agrum/BN/algorithms/divergence/BNdistance.h>
32 #include <agrum/BN/algorithms/divergence/exactBNdistance.h>
33 
34 namespace gum {
35  template < typename GUM_SCALAR >  // Constructs the exact (brute-force) distance calculator between two BNs
36  ExactBNdistance< GUM_SCALAR >::ExactBNdistance(const IBayesNet< GUM_SCALAR >& P,  // P: first Bayesian network
37  const IBayesNet< GUM_SCALAR >& Q) :  // Q: second Bayesian network
38  BNdistance< GUM_SCALAR >(P, Q) {  // both networks are forwarded unchanged to the BNdistance base constructor
39  GUM_CONSTRUCTOR(ExactBNdistance);  // aGrUM constructor-tracing macro (presumably debug bookkeeping only — confirm)
40  }
41 
42  template < typename GUM_SCALAR >
46  }
47 
48  template < typename GUM_SCALAR >
51  }
52 
53  template < typename GUM_SCALAR >
56  errorPQ_ = errorQP_ = 0;
57 
58  auto Ip = p_.completeInstantiation();
59  auto Iq = q_.completeInstantiation();
60 
61  // map between p_ variables and q_ variables (using name of vars)
63 
64  for (Idx ite = 0; ite < Ip.nbrDim(); ++ite) {
66  }
68  for (Ip.setFirst(); !Ip.end(); ++Ip) {
69  Iq.setValsFrom(map, Ip);
72  pmid = (pp + pq) / 2.0;
73  lpmid = lpq = lpp = (GUM_SCALAR)0.0;
74  if (pmid != (GUM_SCALAR)0.0) lpmid = std::log2(pmid);
75  if (pp != (GUM_SCALAR)0.0) lpp = std::log2(pp);
76  if (pq != (GUM_SCALAR)0.0) lpq = std::log2(pq);
77 
78 
79  hellinger_ += std::pow(std::sqrt(pp) - std::sqrt(pq), 2);
80  bhattacharya_ += std::sqrt(pp * pq);
81 
82  if (pp != (GUM_SCALAR)0.0) {
83  if (pq != (GUM_SCALAR)0.0) {
84  klPQ_ -= pp * (lpq - lpp); // log2(pq / pp);
85  } else {
86  errorPQ_++;
87  }
88  }
89 
90  if (pq != (GUM_SCALAR)0.0) {
91  if (pp != (GUM_SCALAR)0.0) {
92  klQP_ -= pq * (lpp - lpq); // log2(pp / pq);
93  } else {
94  errorQP_++;
95  }
96  }
97  if (pmid != (GUM_SCALAR)0.0) {
98  jsd_ += pp * lpp + pq * lpq
99  - (pp + pq) * lpmid; // pp* log2(pp / pmid) + pq * log2(pq / pmid);
100  }
101  }
102  jsd_ /= 2.0;
105  }
106 
107 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643