aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
nodeDatabase_tpl.h
Go to the documentation of this file.
1
/**
 *
 * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 * info_at_agrum_dot_org
 *
 * This library is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */
20
21
22
/**
 * @file
 * @brief Template implementations of the NodeDatabase class.
 *
 * @author Jean-Christophe MAGNAN
 */
28
// =========================================================================
29
#include <agrum/FMDP/learning/datastructure/nodeDatabase.h>

#include <sstream>
30
// =========================================================================
31
32
namespace
gum
{
33
34
// ==========================================================================
35
// Constructor & destructor.
36
// ==========================================================================
37
38
// ###################################################################
39
// Default constructor
40
// ###################################################################
41
template
< TESTNAME AttributeSelection,
bool
isScalar >
42
NodeDatabase< AttributeSelection, isScalar >::NodeDatabase(
43
const
Set<
const
DiscreteVariable* >* attrSet,
44
const
DiscreteVariable* value) :
45
value__(value) {
46
GUM_CONSTRUCTOR(NodeDatabase);
47
48
for
(SetIteratorSafe<
const
DiscreteVariable* > varIter
49
= attrSet->cbeginSafe();
50
varIter != attrSet->cendSafe();
51
++varIter)
52
attrTable__.insert(*varIter,
new
TestPolicy< ValueType >());
53
54
nbObservation__ = 0;
55
}
56
57
58
// ###################################################################
59
// Default desstructor
60
// ###################################################################
61
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
62
NodeDatabase
<
AttributeSelection
,
isScalar
>::~
NodeDatabase
() {
63
for
(
auto
varIter
=
attrTable__
.
beginSafe
();
varIter
!=
attrTable__
.
endSafe
();
64
++
varIter
)
65
delete
varIter
.
val
();
66
67
GUM_DESTRUCTOR
(
NodeDatabase
);
68
}
69
70
71
// ==========================================================================
72
// Observation handling methods
73
// ==========================================================================
74
75
// ###################################################################
76
/* Updates database with new observation
77
*
78
* Calls either @fn addObservation__( const Observation*, Int2Type<true>)
79
* or @fn addObservation__( const Observation*, Int2Type<false>)
80
* depending on if we're learning reward function or transition probability
81
*/
82
// ###################################################################
83
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
84
void
NodeDatabase
<
AttributeSelection
,
isScalar
>::
addObservation
(
85
const
Observation
*
newObs
) {
86
nbObservation__
++;
87
this
->
addObservation__
(
newObs
,
Int2Type
<
isScalar
>());
88
}
89
90
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
91
void
NodeDatabase
<
AttributeSelection
,
isScalar
>::
addObservation__
(
92
const
Observation
*
newObs
,
93
Int2Type
<
true
>) {
94
for
(
auto
varIter
=
attrTable__
.
cbeginSafe
();
95
varIter
!=
attrTable__
.
cendSafe
();
96
++
varIter
)
97
varIter
.
val
()->
addObservation
(
newObs
->
rModality
(
varIter
.
key
()),
98
newObs
->
reward
());
99
100
if
(
valueCount__
.
exists
(
newObs
->
reward
()))
101
valueCount__
[
newObs
->
reward
()]++;
102
else
103
valueCount__
.
insert
(
newObs
->
reward
(), 1);
104
}
105
106
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
107
void
NodeDatabase
<
AttributeSelection
,
isScalar
>::
addObservation__
(
108
const
Observation
*
newObs
,
109
Int2Type
<
false
>) {
110
for
(
auto
varIter
=
attrTable__
.
cbeginSafe
();
111
varIter
!=
attrTable__
.
cendSafe
();
112
++
varIter
)
113
varIter
.
val
()->
addObservation
(
newObs
->
modality
(
varIter
.
key
()),
114
newObs
->
modality
(
value__
));
115
116
if
(
valueCount__
.
exists
(
newObs
->
modality
(
value__
)))
117
valueCount__
[
newObs
->
modality
(
value__
)]++;
118
else
119
valueCount__
.
insert
(
newObs
->
modality
(
value__
), 1);
120
}
121
122
123
// ==========================================================================
124
// Aggregation Methods
125
// ==========================================================================
126
127
128
// ###################################################################
129
// Merges given NodeDatabase informations into current nDB.
130
// ###################################################################
131
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
132
NodeDatabase
<
AttributeSelection
,
isScalar
>&
133
NodeDatabase
<
AttributeSelection
,
isScalar
>::
operator
+=(
134
const
NodeDatabase
<
AttributeSelection
,
isScalar
>&
src
) {
135
this
->
nbObservation__
+=
src
.
nbObservation
();
136
137
for
(
auto
varIter
=
attrTable__
.
beginSafe
();
varIter
!=
attrTable__
.
endSafe
();
138
++
varIter
)
139
varIter
.
val
()->
add
(*(
src
.
testPolicy
(
varIter
.
key
())));
140
141
for
(
auto
valIter
=
src
.
cbeginValues
();
valIter
!=
src
.
cendValues
(); ++
valIter
)
142
if
(
valueCount__
.
exists
(
valIter
.
key
()))
143
valueCount__
[
valIter
.
key
()] +=
valIter
.
val
();
144
else
145
valueCount__
.
insert
(
valIter
.
key
(),
valIter
.
val
());
146
147
return
*
this
;
148
}
149
150
151
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
152
std
::
string
NodeDatabase
<
AttributeSelection
,
isScalar
>::
toString
()
const
{
153
std
::
stringstream
ss
;
154
155
ss
<<
"NbObservation : "
<<
this
->
nbObservation
() <<
std
::
endl
;
156
for
(
auto
varIter
=
attrTable__
.
beginSafe
();
varIter
!=
attrTable__
.
endSafe
();
157
++
varIter
)
158
ss
<<
"\t\tVariable : "
<<
varIter
.
key
()->
name
()
159
<<
" - Associated Test : "
<<
attrTable__
[
varIter
.
key
()]->
toString
()
160
<<
std
::
endl
;
161
162
return
ss
.
str
();
163
}
164
}
// End of namespace gum
165
166
167
// LEFT HERE ON PURPOSE
// NOT TO BE DELETED
// NOTE(review): if reward() below is ever revived, it must guard against
// nbObservation__ == 0 before dividing.

/*template<TESTNAME AttributeSelection, bool isScalar>
double *NodeDatabase<AttributeSelection, isScalar>::effectif(){
  double* ret = static_cast<double*>(
      SmallObjectAllocator::instance().allocate(sizeof(double)*value__->domainSize()));
  for(Idx modality = 0; modality < value__->domainSize(); ++modality)
    if( valueCount__.exists(modality) )
      ret[modality] = (double)valueCount__[modality];
    else
      ret[modality] = 0.0;
  return ret;
}*/

/*template<TESTNAME AttributeSelection, bool isScalar>
double NodeDatabase<AttributeSelection, isScalar>::reward(){
  double ret = 0.0;
  for(auto valuTer = valueCount__.cbeginSafe(); valuTer !=
      valueCount__.cendSafe(); ++valuTer)
    ret += valuTer.key() * (double) valuTer.val();
  return ret / nbObservation__;
}*/
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669