aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
leastSquareTestPolicy_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Template implementations for the LeastSquareTestPolicy class.
25
*
26
* @author Jean-Christophe MAGNAN
27
*/
28
// =========================================================================
29
#
include
<
agrum
/
FMDP
/
learning
/
core
/
testPolicy
/
leastSquareTestPolicy
.
h
>
30
// =========================================================================
31
32
33
namespace
gum
{
34
35
template
<
typename
GUM_SCALAR
>
36
LeastSquareTestPolicy
<
GUM_SCALAR
>::~
LeastSquareTestPolicy
() {
37
for
(
auto
obsIter
=
this
->
obsTable__
.
cbeginSafe
();
38
obsTable__
.
cendSafe
() !=
obsIter
;
39
++
obsIter
)
40
delete
obsIter
.
val
();
41
42
GUM_DESTRUCTOR
(
LeastSquareTestPolicy
);
43
}
44
45
46
// ##########################################################################
47
//
48
// ##########################################################################
49
50
// ==========================================================================
51
//
52
// ==========================================================================
53
template
<
typename
GUM_SCALAR
>
54
void
LeastSquareTestPolicy
<
GUM_SCALAR
>::
addObservation
(
Idx
attr
,
55
GUM_SCALAR
value
) {
56
ITestPolicy
<
GUM_SCALAR
>::
addObservation
(
attr
,
value
);
57
sumO__
+=
value
;
58
59
if
(
sumAttrTable__
.
exists
(
attr
))
60
sumAttrTable__
[
attr
] +=
value
;
61
else
62
sumAttrTable__
.
insert
(
attr
,
value
);
63
64
if
(
nbObsTable__
.
exists
(
attr
))
65
nbObsTable__
[
attr
]++;
66
else
67
nbObsTable__
.
insert
(
attr
, 1);
68
69
if
(!
obsTable__
.
exists
(
attr
))
70
obsTable__
.
insert
(
attr
,
new
LinkedList
<
double
>());
71
obsTable__
[
attr
]->
addLink
(
value
);
72
}
73
74
75
// ############################################################################
76
// @name Test result
77
// ############################################################################
78
79
// ============================================================================
80
// Computes the least-squares score of the current variable according to the test
81
// ============================================================================
82
template
<
typename
GUM_SCALAR
>
83
void
LeastSquareTestPolicy
<
GUM_SCALAR
>::
computeScore
() {
84
ITestPolicy
<
GUM_SCALAR
>::
computeScore
();
85
double
mean
=
sumO__
/ (
double
)
this
->
nbObservation
();
86
double
errorO
= 0.0;
87
double
sumErrorAttr
= 0.0;
88
for
(
auto
attrIter
=
sumAttrTable__
.
cbeginSafe
();
89
attrIter
!=
sumAttrTable__
.
cendSafe
();
90
++
attrIter
) {
91
Idx
key
=
attrIter
.
key
();
92
double
meanAttr
=
sumAttrTable__
[
key
] / (
double
)
nbObsTable__
[
key
];
93
double
errorAttr
= 0.0;
94
95
const
Link
<
double
>*
linky
=
obsTable__
[
key
]->
list
();
96
while
(
linky
) {
97
errorAttr
+=
std
::
pow
(
linky
->
element
() -
meanAttr
, 2);
98
errorO
+=
std
::
pow
(
linky
->
element
() -
mean
, 2);
99
linky
=
linky
->
nextLink
();
100
}
101
102
sumErrorAttr
+= ((
double
)
nbObsTable__
[
key
] / (
double
)
this
->
nbObservation
())
103
*
errorAttr
;
104
}
105
score__
=
errorO
-
sumErrorAttr
;
106
}
107
108
// ============================================================================
109
// Returns the performance of the current variable according to the test
110
// ============================================================================
111
template
<
typename
GUM_SCALAR
>
112
double
LeastSquareTestPolicy
<
GUM_SCALAR
>::
score
() {
113
if
(
this
->
isModified_
())
computeScore
();
114
return
score__
;
115
}
116
117
// ============================================================================
118
// Returns a secondary criterion used to break ties
119
// ============================================================================
120
template
<
typename
GUM_SCALAR
>
121
double
LeastSquareTestPolicy
<
GUM_SCALAR
>::
secondaryscore
()
const
{
122
if
(
this
->
isModified_
())
computeScore
();
123
return
score__
;
124
}
125
126
template
<
typename
GUM_SCALAR
>
127
void
LeastSquareTestPolicy
<
GUM_SCALAR
>::
add
(
const
LeastSquareTestPolicy
&
src
) {
128
ITestPolicy
<
GUM_SCALAR
>::
add
(
src
);
129
130
for
(
auto
obsIter
=
src
.
nbObsTable
().
cbeginSafe
();
131
obsIter
!=
src
.
nbObsTable
().
cendSafe
();
132
++
obsIter
)
133
if
(
nbObsTable__
.
exists
(
obsIter
.
key
()))
134
nbObsTable__
[
obsIter
.
key
()] +=
obsIter
.
val
();
135
else
136
nbObsTable__
.
insert
(
obsIter
.
key
(),
obsIter
.
val
());
137
138
for
(
auto
attrIter
=
src
.
sumAttrTable
().
cbeginSafe
();
139
attrIter
!=
src
.
sumAttrTable
().
cendSafe
();
140
++
attrIter
)
141
if
(
sumAttrTable__
.
exists
(
attrIter
.
key
()))
142
sumAttrTable__
[
attrIter
.
key
()] +=
attrIter
.
val
();
143
else
144
sumAttrTable__
.
insert
(
attrIter
.
key
(),
attrIter
.
val
());
145
146
for
(
auto
obsIter
=
src
.
obsTable
().
cbeginSafe
();
147
obsIter
!=
src
.
obsTable
().
cendSafe
();
148
++
obsIter
) {
149
if
(!
obsTable__
.
exists
(
obsIter
.
key
()))
150
obsTable__
.
insert
(
obsIter
.
key
(),
new
LinkedList
<
double
>());
151
const
Link
<
double
>*
srcLink
=
obsIter
.
val
()->
list
();
152
while
(
srcLink
) {
153
obsTable__
[
obsIter
.
key
()]->
addLink
(
srcLink
->
element
());
154
srcLink
=
srcLink
->
nextLink
();
155
}
156
}
157
}
158
159
}
// End of namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669