aGrUM
0.21.0
a C++ library for (probabilistic) graphical models
leastSquareTestPolicy_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Template implementations for the LeastSquareTestPolicy class.
25
*
26
* @author Jean-Christophe MAGNAN
27
*/
28
// =========================================================================
29
#
include
<
agrum
/
FMDP
/
learning
/
core
/
testPolicy
/
leastSquareTestPolicy
.
h
>
30
// =========================================================================
31
32
33
namespace
gum
{
34
35
template
<
typename
GUM_SCALAR
>
36
LeastSquareTestPolicy
<
GUM_SCALAR
>::~
LeastSquareTestPolicy
() {
37
for
(
auto
obsIter
=
this
->
_obsTable_
.
cbeginSafe
();
_obsTable_
.
cendSafe
() !=
obsIter
; ++
obsIter
)
38
delete
obsIter
.
val
();
39
40
GUM_DESTRUCTOR
(
LeastSquareTestPolicy
);
41
}
42
43
44
// ##########################################################################
45
//
46
// ##########################################################################
47
48
// ==========================================================================
49
//
50
// ==========================================================================
51
template
<
typename
GUM_SCALAR
>
52
void
LeastSquareTestPolicy
<
GUM_SCALAR
>::
addObservation
(
Idx
attr
,
GUM_SCALAR
value
) {
53
ITestPolicy
<
GUM_SCALAR
>::
addObservation
(
attr
,
value
);
54
_sumO_
+=
value
;
55
56
if
(
_sumAttrTable_
.
exists
(
attr
))
57
_sumAttrTable_
[
attr
] +=
value
;
58
else
59
_sumAttrTable_
.
insert
(
attr
,
value
);
60
61
if
(
_nbObsTable_
.
exists
(
attr
))
62
_nbObsTable_
[
attr
]++;
63
else
64
_nbObsTable_
.
insert
(
attr
, 1);
65
66
if
(!
_obsTable_
.
exists
(
attr
))
_obsTable_
.
insert
(
attr
,
new
LinkedList
<
double
>());
67
_obsTable_
[
attr
]->
addLink
(
value
);
68
}
69
70
71
// ############################################################################
72
// @name Test result
73
// ############################################################################
74
75
// ============================================================================
76
// Computes the least-squares score (reduction of the squared error) of the current variable
77
// ============================================================================
78
template
<
typename
GUM_SCALAR
>
79
void
LeastSquareTestPolicy
<
GUM_SCALAR
>::
computeScore
() {
80
ITestPolicy
<
GUM_SCALAR
>::
computeScore
();
81
double
mean
=
_sumO_
/ (
double
)
this
->
nbObservation
();
82
double
errorO
= 0.0;
83
double
sumErrorAttr
= 0.0;
84
for
(
auto
attrIter
=
_sumAttrTable_
.
cbeginSafe
();
attrIter
!=
_sumAttrTable_
.
cendSafe
();
85
++
attrIter
) {
86
Idx
key
=
attrIter
.
key
();
87
double
meanAttr
=
_sumAttrTable_
[
key
] / (
double
)
_nbObsTable_
[
key
];
88
double
errorAttr
= 0.0;
89
90
const
Link
<
double
>*
linky
=
_obsTable_
[
key
]->
list
();
91
while
(
linky
) {
92
errorAttr
+=
std
::
pow
(
linky
->
element
() -
meanAttr
, 2);
93
errorO
+=
std
::
pow
(
linky
->
element
() -
mean
, 2);
94
linky
=
linky
->
nextLink
();
95
}
96
97
sumErrorAttr
+= ((
double
)
_nbObsTable_
[
key
] / (
double
)
this
->
nbObservation
()) *
errorAttr
;
98
}
99
_score_
=
errorO
-
sumErrorAttr
;
100
}
101
102
// ============================================================================
103
// Returns the score of the current variable according to the test (recomputed if observations changed)
104
// ============================================================================
105
template
<
typename
GUM_SCALAR
>
106
double
LeastSquareTestPolicy
<
GUM_SCALAR
>::
score
() {
107
if
(
this
->
isModified_
())
computeScore
();
108
return
_score_
;
109
}
110
111
// ============================================================================
112
// Returns a second criterion to break ties
113
// ============================================================================
114
template
<
typename
GUM_SCALAR
>
115
double
LeastSquareTestPolicy
<
GUM_SCALAR
>::
secondaryscore
()
const
{
116
if
(
this
->
isModified_
())
computeScore
();
117
return
_score_
;
118
}
119
120
template
<
typename
GUM_SCALAR
>
121
void
LeastSquareTestPolicy
<
GUM_SCALAR
>::
add
(
const
LeastSquareTestPolicy
&
src
) {
122
ITestPolicy
<
GUM_SCALAR
>::
add
(
src
);
123
124
for
(
auto
obsIter
=
src
.
nbObsTable
().
cbeginSafe
();
obsIter
!=
src
.
nbObsTable
().
cendSafe
();
125
++
obsIter
)
126
if
(
_nbObsTable_
.
exists
(
obsIter
.
key
()))
127
_nbObsTable_
[
obsIter
.
key
()] +=
obsIter
.
val
();
128
else
129
_nbObsTable_
.
insert
(
obsIter
.
key
(),
obsIter
.
val
());
130
131
for
(
auto
attrIter
=
src
.
sumAttrTable
().
cbeginSafe
();
attrIter
!=
src
.
sumAttrTable
().
cendSafe
();
132
++
attrIter
)
133
if
(
_sumAttrTable_
.
exists
(
attrIter
.
key
()))
134
_sumAttrTable_
[
attrIter
.
key
()] +=
attrIter
.
val
();
135
else
136
_sumAttrTable_
.
insert
(
attrIter
.
key
(),
attrIter
.
val
());
137
138
for
(
auto
obsIter
=
src
.
obsTable
().
cbeginSafe
();
obsIter
!=
src
.
obsTable
().
cendSafe
();
139
++
obsIter
) {
140
if
(!
_obsTable_
.
exists
(
obsIter
.
key
()))
141
_obsTable_
.
insert
(
obsIter
.
key
(),
new
LinkedList
<
double
>());
142
const
Link
<
double
>*
srcLink
=
obsIter
.
val
()->
list
();
143
while
(
srcLink
) {
144
_obsTable_
[
obsIter
.
key
()]->
addLink
(
srcLink
->
element
());
145
srcLink
=
srcLink
->
nextLink
();
146
}
147
}
148
}
149
150
}
// End of namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643