// aGrUM 0.20.3 - a C++ library for (probabilistic) graphical models
// nodeDatabase_tpl.h
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief Template implementations of the NodeDatabase class.
 *
 * @author Jean-Christophe MAGNAN
 */
// =========================================================================
#include <agrum/FMDP/learning/datastructure/nodeDatabase.h>

#include <sstream>
// =========================================================================

namespace
gum
{
33
34
// ==========================================================================
35
// Constructor & destructor.
36
// ==========================================================================
37
38
// ###################################################################
39
// Default constructor
40
// ###################################################################
41
template
< TESTNAME AttributeSelection,
bool
isScalar >
42
NodeDatabase< AttributeSelection, isScalar >::NodeDatabase(
43
const
Set<
const
DiscreteVariable* >* attrSet,
44
const
DiscreteVariable* value) :
45
_value_(value) {
46
GUM_CONSTRUCTOR(NodeDatabase);
47
48
for
(SetIteratorSafe<
const
DiscreteVariable* > varIter = attrSet->cbeginSafe();
49
varIter != attrSet->cendSafe();
50
++varIter)
51
_attrTable_.insert(*varIter,
new
TestPolicy< ValueType >());
52
53
_nbObservation_ = 0;
54
}
55
56
57
// ###################################################################
58
// Default desstructor
59
// ###################################################################
60
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
61
NodeDatabase
<
AttributeSelection
,
isScalar
>::~
NodeDatabase
() {
62
for
(
auto
varIter
=
_attrTable_
.
beginSafe
();
varIter
!=
_attrTable_
.
endSafe
(); ++
varIter
)
63
delete
varIter
.
val
();
64
65
GUM_DESTRUCTOR
(
NodeDatabase
);
66
}
67
68
69
// ==========================================================================
70
// Observation handling methods
71
// ==========================================================================
72
73
// ###################################################################
74
/* Updates database with new observation
75
*
76
* Calls either @fn _addObservation_( const Observation*, Int2Type<true>)
77
* or @fn _addObservation_( const Observation*, Int2Type<false>)
78
* depending on if we're learning reward function or transition probability
79
*/
80
// ###################################################################
81
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
82
void
NodeDatabase
<
AttributeSelection
,
isScalar
>::
addObservation
(
const
Observation
*
newObs
) {
83
_nbObservation_
++;
84
this
->
_addObservation_
(
newObs
,
Int2Type
<
isScalar
>());
85
}
86
87
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
88
void
NodeDatabase
<
AttributeSelection
,
isScalar
>::
_addObservation_
(
const
Observation
*
newObs
,
89
Int2Type
<
true
>) {
90
for
(
auto
varIter
=
_attrTable_
.
cbeginSafe
();
varIter
!=
_attrTable_
.
cendSafe
(); ++
varIter
)
91
varIter
.
val
()->
addObservation
(
newObs
->
rModality
(
varIter
.
key
()),
newObs
->
reward
());
92
93
if
(
_valueCount_
.
exists
(
newObs
->
reward
()))
94
_valueCount_
[
newObs
->
reward
()]++;
95
else
96
_valueCount_
.
insert
(
newObs
->
reward
(), 1);
97
}
98
99
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
100
void
NodeDatabase
<
AttributeSelection
,
isScalar
>::
_addObservation_
(
const
Observation
*
newObs
,
101
Int2Type
<
false
>) {
102
for
(
auto
varIter
=
_attrTable_
.
cbeginSafe
();
varIter
!=
_attrTable_
.
cendSafe
(); ++
varIter
)
103
varIter
.
val
()->
addObservation
(
newObs
->
modality
(
varIter
.
key
()),
newObs
->
modality
(
_value_
));
104
105
if
(
_valueCount_
.
exists
(
newObs
->
modality
(
_value_
)))
106
_valueCount_
[
newObs
->
modality
(
_value_
)]++;
107
else
108
_valueCount_
.
insert
(
newObs
->
modality
(
_value_
), 1);
109
}
110
111
112
// ==========================================================================
113
// Aggregation Methods
114
// ==========================================================================
115
116
117
// ###################################################################
118
// Merges given NodeDatabase informations into current nDB.
119
// ###################################################################
120
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
121
NodeDatabase
<
AttributeSelection
,
isScalar
>&
122
NodeDatabase
<
AttributeSelection
,
isScalar
>::
operator
+=(
123
const
NodeDatabase
<
AttributeSelection
,
isScalar
>&
src
) {
124
this
->
_nbObservation_
+=
src
.
nbObservation
();
125
126
for
(
auto
varIter
=
_attrTable_
.
beginSafe
();
varIter
!=
_attrTable_
.
endSafe
(); ++
varIter
)
127
varIter
.
val
()->
add
(*(
src
.
testPolicy
(
varIter
.
key
())));
128
129
for
(
auto
valIter
=
src
.
cbeginValues
();
valIter
!=
src
.
cendValues
(); ++
valIter
)
130
if
(
_valueCount_
.
exists
(
valIter
.
key
()))
131
_valueCount_
[
valIter
.
key
()] +=
valIter
.
val
();
132
else
133
_valueCount_
.
insert
(
valIter
.
key
(),
valIter
.
val
());
134
135
return
*
this
;
136
}
137
138
139
template
<
TESTNAME
AttributeSelection
,
bool
isScalar
>
140
std
::
string
NodeDatabase
<
AttributeSelection
,
isScalar
>::
toString
()
const
{
141
std
::
stringstream
ss
;
142
143
ss
<<
"NbObservation : "
<<
this
->
nbObservation
() <<
std
::
endl
;
144
for
(
auto
varIter
=
_attrTable_
.
beginSafe
();
varIter
!=
_attrTable_
.
endSafe
(); ++
varIter
)
145
ss
<<
"\t\tVariable : "
<<
varIter
.
key
()->
name
()
146
<<
" - Associated Test : "
<<
_attrTable_
[
varIter
.
key
()]->
toString
() <<
std
::
endl
;
147
148
return
ss
.
str
();
149
}
150
}
// End of namespace gum


// LEFT HERE ON PURPOSE
// NOT TO BE DELETED
// (disabled implementations of effectif() and reward(), kept for reference)

/*template<TESTNAME AttributeSelection, bool isScalar>
double *NodeDatabase<AttributeSelection, isScalar>::effectif(){
  double* ret = static_cast<double*>(
      SmallObjectAllocator::instance().allocate(sizeof(double)* _value_->domainSize()));
  for(Idx modality = 0; modality < _value_->domainSize(); ++modality)
    if( _valueCount_.exists(modality) )
      ret[modality] = (double) _valueCount_[modality];
    else
      ret[modality] = 0.0;
  return ret;
}*/

/*template<TESTNAME AttributeSelection, bool isScalar>
double NodeDatabase<AttributeSelection, isScalar>::reward(){
  double ret = 0.0;
  for(auto valuTer = _valueCount_.cbeginSafe(); valuTer !=
      _valueCount_.cendSafe(); ++valuTer)
    ret += valuTer.key() * (double) valuTer.val();
  return ret / _nbObservation_;
}*/
// Doxygen cross-reference: gum::Set::emplace
// (INLINE void emplace(Args&&... args)), defined in set_tpl.h:643