aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
fmdpLearner_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Template Implementations of the FMDPLearner class.
25
*
26
* @author Jean-Christophe MAGNAN
27
*/
28
29
// =========================================================================
30
#
include
<
agrum
/
FMDP
/
learning
/
fmdpLearner
.
h
>
31
// =========================================================================
32
33
namespace
gum
{
34
35
// ==========================================================================
36
// Constructor & destructor.
37
// ==========================================================================
38
39
// ###################################################################
40
// Default constructor
41
// ###################################################################
42
template
< TESTNAME VariableAttributeSelection,
43
TESTNAME RewardAttributeSelection,
44
LEARNERNAME LearnerSelection >
45
FMDPLearner< VariableAttributeSelection, RewardAttributeSelection, LearnerSelection >::
46
FMDPLearner(
double
lT,
bool
actionReward,
double
sT) :
47
_actionReward_(actionReward),
48
_learningThreshold_(lT), _similarityThreshold_(sT) {
49
GUM_CONSTRUCTOR(FMDPLearner);
50
_rewardLearner_ =
nullptr
;
51
}
52
53
54
// ###################################################################
55
// Default destructor
56
// ###################################################################
57
template
<
TESTNAME
VariableAttributeSelection
,
58
TESTNAME
RewardAttributeSelection
,
59
LEARNERNAME
LearnerSelection
>
60
FMDPLearner
<
VariableAttributeSelection
,
RewardAttributeSelection
,
LearnerSelection
>::
61
~
FMDPLearner
() {
62
for
(
auto
actionIter
=
_actionLearners_
.
beginSafe
();
actionIter
!=
_actionLearners_
.
endSafe
();
63
++
actionIter
) {
64
for
(
auto
learnerIter
=
actionIter
.
val
()->
beginSafe
();
65
learnerIter
!=
actionIter
.
val
()->
endSafe
();
66
++
learnerIter
)
67
delete
learnerIter
.
val
();
68
delete
actionIter
.
val
();
69
if
(
_actionRewardLearners_
.
exists
(
actionIter
.
key
()))
70
delete
_actionRewardLearners_
[
actionIter
.
key
()];
71
}
72
73
if
(
_rewardLearner_
)
delete
_rewardLearner_
;
74
75
GUM_DESTRUCTOR
(
FMDPLearner
);
76
}
77
78
79
// ==========================================================================
80
// Initialization, observation handling and FMDP update
81
// ==========================================================================
82
83
// ###################################################################
84
// Binds the learner to an FMDP and creates its learners
85
// ###################################################################
86
template
<
TESTNAME
VariableAttributeSelection
,
87
TESTNAME
RewardAttributeSelection
,
88
LEARNERNAME
LearnerSelection
>
89
void
FMDPLearner
<
VariableAttributeSelection
,
RewardAttributeSelection
,
LearnerSelection
>::
90
initialize
(
FMDP
<
double
>*
fmdp
) {
91
_fmdp_
=
fmdp
;
92
93
_modaMax_
= 0;
94
_rmax_
= 0.0;
95
96
Set
<
const
DiscreteVariable
* >
mainVariables
;
97
for
(
auto
varIter
=
_fmdp_
->
beginVariables
();
varIter
!=
_fmdp_
->
endVariables
(); ++
varIter
) {
98
mainVariables
.
insert
(*
varIter
);
99
_modaMax_
=
_modaMax_
< (*
varIter
)->
domainSize
() ? (*
varIter
)->
domainSize
() :
_modaMax_
;
100
}
101
102
for
(
auto
actionIter
=
_fmdp_
->
beginActions
();
actionIter
!=
_fmdp_
->
endActions
();
103
++
actionIter
) {
104
// Adding a Hashtable for the action
105
_actionLearners_
.
insert
(*
actionIter
,
new
VarLearnerTable
());
106
107
// Adding a learner for each variable
108
for
(
auto
varIter
=
_fmdp_
->
beginVariables
();
varIter
!=
_fmdp_
->
endVariables
(); ++
varIter
) {
109
MultiDimFunctionGraph
<
double
>*
varTrans
=
_instantiateFunctionGraph_
();
110
varTrans
->
setTableName
(
"ACTION : "
+
_fmdp_
->
actionName
(*
actionIter
)
111
+
" - VARIABLE : "
+ (*
varIter
)->
name
());
112
_fmdp_
->
addTransitionForAction
(*
actionIter
, *
varIter
,
varTrans
);
113
_actionLearners_
[*
actionIter
]->
insert
(
114
(*
varIter
),
115
_instantiateVarLearner_
(
varTrans
,
mainVariables
,
_fmdp_
->
main2prime
(*
varIter
)));
116
}
117
118
if
(
_actionReward_
) {
119
MultiDimFunctionGraph
<
double
>*
reward
=
_instantiateFunctionGraph_
();
120
reward
->
setTableName
(
"REWARD - ACTION : "
+
_fmdp_
->
actionName
(*
actionIter
));
121
_fmdp_
->
addRewardForAction
(*
actionIter
,
reward
);
122
_actionRewardLearners_
.
insert
(*
actionIter
,
123
_instantiateRewardLearner_
(
reward
,
mainVariables
));
124
}
125
}
126
127
if
(!
_actionReward_
) {
128
MultiDimFunctionGraph
<
double
>*
reward
=
_instantiateFunctionGraph_
();
129
reward
->
setTableName
(
"REWARD"
);
130
_fmdp_
->
addReward
(
reward
);
131
_rewardLearner_
=
_instantiateRewardLearner_
(
reward
,
mainVariables
);
132
}
133
}
134
135
// ###################################################################
136
// Adds a new observation for the given action
137
// ###################################################################
138
template
<
TESTNAME
VariableAttributeSelection
,
139
TESTNAME
RewardAttributeSelection
,
140
LEARNERNAME
LearnerSelection
>
141
bool
FMDPLearner
<
VariableAttributeSelection
,
RewardAttributeSelection
,
LearnerSelection
>::
142
addObservation
(
Idx
actionId
,
const
Observation
*
newObs
) {
143
for
(
SequenceIteratorSafe
<
const
DiscreteVariable
* >
varIter
=
_fmdp_
->
beginVariables
();
144
varIter
!=
_fmdp_
->
endVariables
();
145
++
varIter
) {
146
_actionLearners_
[
actionId
]->
getWithDefault
(*
varIter
,
nullptr
)->
addObservation
(
newObs
);
147
_actionLearners_
[
actionId
]->
getWithDefault
(*
varIter
,
nullptr
)->
updateGraph
();
148
}
149
150
if
(
_actionReward_
) {
151
_actionRewardLearners_
[
actionId
]->
addObservation
(
newObs
);
152
_actionRewardLearners_
[
actionId
]->
updateGraph
();
153
}
else
{
154
_rewardLearner_
->
addObservation
(
newObs
);
155
_rewardLearner_
->
updateGraph
();
156
}
157
158
_rmax_
=
_rmax_
<
std
::
abs
(
newObs
->
reward
()) ?
std
::
abs
(
newObs
->
reward
()) :
_rmax_
;
159
160
return
false
;
161
}
162
163
// ###################################################################
164
// Sums the sizes of all the learners
165
// ###################################################################
166
template
<
TESTNAME
VariableAttributeSelection
,
167
TESTNAME
RewardAttributeSelection
,
168
LEARNERNAME
LearnerSelection
>
169
Size
170
FMDPLearner
<
VariableAttributeSelection
,
RewardAttributeSelection
,
LearnerSelection
>::
size
() {
171
Size
s
= 0;
172
for
(
SequenceIteratorSafe
<
Idx
>
actionIter
=
_fmdp_
->
beginActions
();
173
actionIter
!=
_fmdp_
->
endActions
();
174
++
actionIter
) {
175
for
(
SequenceIteratorSafe
<
const
DiscreteVariable
* >
varIter
=
_fmdp_
->
beginVariables
();
176
varIter
!=
_fmdp_
->
endVariables
();
177
++
varIter
)
178
s
+=
_actionLearners_
[*
actionIter
]->
getWithDefault
(*
varIter
,
nullptr
)->
size
();
179
if
(
_actionReward_
)
s
+=
_actionRewardLearners_
[*
actionIter
]->
size
();
180
}
181
182
if
(!
_actionReward_
)
s
+=
_rewardLearner_
->
size
();
183
184
return
s
;
185
}
186
187
188
// ###################################################################
189
// Asks every learner to update its function graph
190
// ###################################################################
191
template
<
TESTNAME
VariableAttributeSelection
,
192
TESTNAME
RewardAttributeSelection
,
193
LEARNERNAME
LearnerSelection
>
194
void
FMDPLearner
<
VariableAttributeSelection
,
RewardAttributeSelection
,
LearnerSelection
>::
195
updateFMDP
() {
196
for
(
SequenceIteratorSafe
<
Idx
>
actionIter
=
_fmdp_
->
beginActions
();
197
actionIter
!=
_fmdp_
->
endActions
();
198
++
actionIter
) {
199
for
(
SequenceIteratorSafe
<
const
DiscreteVariable
* >
varIter
=
_fmdp_
->
beginVariables
();
200
varIter
!=
_fmdp_
->
endVariables
();
201
++
varIter
)
202
_actionLearners_
[*
actionIter
]->
getWithDefault
(*
varIter
,
nullptr
)->
updateFunctionGraph
();
203
if
(
_actionReward_
)
_actionRewardLearners_
[*
actionIter
]->
updateFunctionGraph
();
204
}
205
206
if
(!
_actionReward_
)
_rewardLearner_
->
updateFunctionGraph
();
207
}
208
}
// End of namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643