aGrUM
0.20.3
a C++ library for (probabilistic) graphical models
BayesNetFactory_tpl.h
Go to the documentation of this file.
/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

/**
 * @file
 * @brief Implementation of the BayesNetFactory class.
 *
 * @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
 *
 */

#include <agrum/BN/BayesNetFactory.h>

#include <cstdio>
#include <limits>
#include <memory>
#include <sstream>
#include <string>

namespace
gum
{
33
34
// Default constructor.
35
// @param bn A pointer over the BayesNet filled by this factory.
36
// @throw DuplicateElement Raised if two variables in bn share the same
37
// name.
38
template
<
typename
GUM_SCALAR >
39
INLINE BayesNetFactory<
GUM_SCALAR
>::
BayesNetFactory
(
BayesNet
<
GUM_SCALAR
>*
bn
) :
40
_parents_
(0),
_impl_
(0),
_bn_
(
bn
) {
41
GUM_CONSTRUCTOR
(
BayesNetFactory
);
42
_states_
.
push_back
(
factory_state
::
NONE
);
43
44
for
(
auto
node
:
bn
->
nodes
()) {
45
if
(
_varNameMap_
.
exists
(
bn
->
variable
(
node
).
name
()))
46
GUM_ERROR
(
DuplicateElement
,
"Name already used: "
<<
bn
->
variable
(
node
).
name
())
47
48
_varNameMap_
.
insert
(
bn
->
variable
(
node
).
name
(),
node
);
49
}
50
51
resetVerbose
();
52
}
53
54
// Copy constructor.
55
// The copy will have an exact copy of the constructed BayesNet in source.
56
template
<
typename
GUM_SCALAR
>
57
INLINE
58
BayesNetFactory
<
GUM_SCALAR
>::
BayesNetFactory
(
const
BayesNetFactory
<
GUM_SCALAR
>&
source
) :
59
_parents_
(
nullptr
),
60
_impl_
(
nullptr
),
_bn_
(
nullptr
) {
61
GUM_CONS_CPY
(
BayesNetFactory
);
62
63
if
(
source
.
state
() !=
factory_state
::
NONE
) {
64
GUM_ERROR
(
OperationNotAllowed
,
"Illegal state to proceed make a copy."
)
65
}
else
{
66
_states_
=
source
.
_states_
;
67
_bn_
=
new
BayesNet
<
GUM_SCALAR
>(*(
source
.
_bn_
));
68
}
69
}
70
71
// Destructor
72
template
<
typename
GUM_SCALAR
>
73
INLINE
BayesNetFactory
<
GUM_SCALAR
>::~
BayesNetFactory
() {
74
GUM_DESTRUCTOR
(
BayesNetFactory
);
75
76
if
(
_parents_
!=
nullptr
)
delete
_parents_
;
77
78
if
(
_impl_
!=
nullptr
) {
79
//@todo better than throwing an exception from inside a destructor but
80
// still ...
81
std
::
cerr
<<
"[BN factory] Implementation defined for a variable but not used. "
82
"You should call endVariableDeclaration() before "
83
"deleting me."
84
<<
std
::
endl
;
85
exit
(1);
86
}
87
}
88
89
// Returns the BayesNet created by this factory.
90
template
<
typename
GUM_SCALAR
>
91
INLINE
BayesNet
<
GUM_SCALAR
>*
BayesNetFactory
<
GUM_SCALAR
>::
bayesNet
() {
92
return
_bn_
;
93
}
94
95
template
<
typename
GUM_SCALAR
>
96
INLINE
const
DiscreteVariable
&
BayesNetFactory
<
GUM_SCALAR
>::
varInBN
(
NodeId
id
) {
97
return
_bn_
->
variable
(
id
);
98
}
99
100
// Returns the current state of the factory.
101
template
<
typename
GUM_SCALAR
>
102
INLINE
IBayesNetFactory
::
factory_state
BayesNetFactory
<
GUM_SCALAR
>::
state
()
const
{
103
// This is ok because there is always at least the state NONE in the stack.
104
return
_states_
.
back
();
105
}
106
107
// Returns the NodeId of a variable given it's name.
108
// @throw NotFound Raised if no variable matches the name.
109
template
<
typename
GUM_SCALAR
>
110
INLINE
NodeId
BayesNetFactory
<
GUM_SCALAR
>::
variableId
(
const
std
::
string
&
name
)
const
{
111
try
{
112
return
_varNameMap_
[
name
];
113
}
catch
(
NotFound
&) {
GUM_ERROR
(
NotFound
,
name
) }
114
}
115
116
// Returns a constant reference on a variable given it's name.
117
// @throw NotFound Raised if no variable matches the name.
118
template
<
typename
GUM_SCALAR
>
119
INLINE
const
DiscreteVariable
&
120
BayesNetFactory
<
GUM_SCALAR
>::
variable
(
const
std
::
string
&
name
)
const
{
121
try
{
122
return
_bn_
->
variable
(
variableId
(
name
));
123
}
catch
(
NotFound
&) {
GUM_ERROR
(
NotFound
,
name
) }
124
}
125
126
// Returns the domainSize of the cpt for the node n.
127
// @throw NotFound raised if no such NodeId exists.
128
// @throw OperationNotAllowed if there is no Bayesian networks.
129
template
<
typename
GUM_SCALAR
>
130
INLINE
Size
BayesNetFactory
<
GUM_SCALAR
>::
cptDomainSize
(
const
NodeId
n
)
const
{
131
return
_bn_
->
cpt
(
n
).
domainSize
();
132
}
133
134
// Tells the factory that we're in a network declaration.
135
template
<
typename
GUM_SCALAR
>
136
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startNetworkDeclaration
() {
137
if
(
state
() !=
factory_state
::
NONE
) {
138
_illegalStateError_
(
"startNetworkDeclaration"
);
139
}
else
{
140
_states_
.
push_back
(
factory_state
::
NETWORK
);
141
}
142
}
143
144
// Tells the factory to add a property to the current network.
145
template
<
typename
GUM_SCALAR
>
146
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addNetworkProperty
(
const
std
::
string
&
propName
,
147
const
std
::
string
&
propValue
) {
148
_bn_
->
setProperty
(
propName
,
propValue
);
149
}
150
151
// Tells the factory that we're out of a network declaration.
152
template
<
typename
GUM_SCALAR
>
153
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endNetworkDeclaration
() {
154
if
(
state
() !=
factory_state
::
NETWORK
) {
155
_illegalStateError_
(
"endNetworkDeclaration"
);
156
}
else
{
157
_states_
.
pop_back
();
158
}
159
}
160
161
// Tells the factory that we're in a variable declaration.
162
// A variable is considered as a LabelizedVariable while its type is not defined.
163
template
<
typename
GUM_SCALAR
>
164
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startVariableDeclaration
() {
165
if
(
state
() !=
factory_state
::
NONE
) {
166
_illegalStateError_
(
"startVariableDeclaration"
);
167
}
else
{
168
_states_
.
push_back
(
factory_state
::
VARIABLE
);
169
_stringBag_
.
push_back
(
"name"
);
170
_stringBag_
.
push_back
(
"desc"
);
171
_stringBag_
.
push_back
(
"L"
);
172
}
173
}
174
175
// Tells the factory the current variable's name.
176
template
<
typename
GUM_SCALAR
>
177
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
variableName
(
const
std
::
string
&
name
) {
178
if
(
state
() !=
factory_state
::
VARIABLE
) {
179
_illegalStateError_
(
"variableName"
);
180
}
else
{
181
if
(
_varNameMap_
.
exists
(
name
)) {
GUM_ERROR
(
DuplicateElement
,
"Name already used: "
<<
name
) }
182
183
_foo_flag_
=
true
;
184
_stringBag_
[0] =
name
;
185
}
186
}
187
188
// Tells the factory the current variable's description.
189
template
<
typename
GUM_SCALAR
>
190
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
variableDescription
(
const
std
::
string
&
desc
) {
191
if
(
state
() !=
factory_state
::
VARIABLE
) {
192
_illegalStateError_
(
"variableDescription"
);
193
}
else
{
194
_bar_flag_
=
true
;
195
_stringBag_
[1] =
desc
;
196
}
197
}
198
199
// Tells the factory the current variable's type.
200
// L : Labelized
201
// R : Range
202
// C : Continuous
203
// D : Discretized
204
template
<
typename
GUM_SCALAR
>
205
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
variableType
(
const
gum
::
VarType
&
type
) {
206
if
(
state
() !=
factory_state
::
VARIABLE
) {
207
_illegalStateError_
(
"variableType"
);
208
}
else
{
209
switch
(
type
) {
210
case
VarType
::
Discretized
:
211
_stringBag_
[2] =
"D"
;
212
break
;
213
case
VarType
::
Range
:
214
_stringBag_
[2] =
"R"
;
215
break
;
216
case
VarType
::
Continuous
:
217
GUM_ERROR
(
OperationNotAllowed
,
218
"Continuous variable ("
+
_stringBag_
[0]
219
+
") are not supported in Bayesian networks."
)
220
case
VarType
::
Labelized
:
221
_stringBag_
[2] =
"L"
;
222
break
;
223
}
224
}
225
}
226
227
// Adds a modality to the current variable.
228
// @throw DuplicateElement If the current variable already has a modality
229
// with the same name.
230
template
<
typename
GUM_SCALAR
>
231
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addModality
(
const
std
::
string
&
name
) {
232
if
(
state
() !=
factory_state
::
VARIABLE
) {
233
_illegalStateError_
(
"addModality"
);
234
}
else
{
235
_checkModalityInBag_
(
name
);
236
_stringBag_
.
push_back
(
name
);
237
}
238
}
239
240
// Adds a modality to the current variable.
241
// @throw DuplicateElement If the current variable already has a modality
242
// with the same name.
243
template
<
typename
GUM_SCALAR
>
244
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addMin
(
const
long
&
min
) {
245
if
(
state
() !=
factory_state
::
VARIABLE
) {
246
_illegalStateError_
(
"addMin"
);
247
}
else
{
248
_stringBag_
.
push_back
(
std
::
to_string
(
min
));
249
}
250
}
251
252
// Adds a modality to the current variable.
253
// @throw DuplicateElement If the current variable already has a modality
254
// with the same name.
255
template
<
typename
GUM_SCALAR
>
256
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addMax
(
const
long
&
max
) {
257
if
(
state
() !=
factory_state
::
VARIABLE
) {
258
_illegalStateError_
(
"addMin"
);
259
}
else
{
260
_stringBag_
.
push_back
(
std
::
to_string
(
max
));
261
}
262
}
263
264
// Adds a modality to the current variable.
265
// @throw DuplicateElement If the current variable already has a modality
266
// with the same name.
267
template
<
typename
GUM_SCALAR
>
268
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addTick
(
const
GUM_SCALAR
&
tick
) {
269
if
(
state
() !=
factory_state
::
VARIABLE
) {
270
_illegalStateError_
(
"addTick"
);
271
}
else
{
272
_stringBag_
.
push_back
(
std
::
to_string
(
tick
));
273
}
274
}
275
276
// @brief Defines the implementation to use for Potential.
277
// @warning The implementation must be empty.
278
// @warning The pointer is always delegated to Potential! No copy of it
279
// is made.
280
// @todo When copy of a MultiDimImplementation is available use a copy
281
// behaviour for this method.
282
// @throw NotFound Raised if no variable matches var.
283
// @throw OperationNotAllowed Raised if impl is not empty.
284
// @throw OperationNotAllowed If an implementation is already defined for the
285
// current variable.
286
template
<
typename
GUM_SCALAR
>
287
INLINE
void
288
BayesNetFactory
<
GUM_SCALAR
>::
setVariableCPTImplementation
(
MultiDimAdressable
*
adressable
) {
289
MultiDimImplementation
<
GUM_SCALAR
>*
impl
290
=
dynamic_cast
<
MultiDimImplementation
<
GUM_SCALAR
>* >(
adressable
);
291
292
if
(
state
() !=
factory_state
::
VARIABLE
) {
293
_illegalStateError_
(
"setVariableCPTImplementation"
);
294
}
else
{
295
if
(
impl
== 0) {
296
GUM_ERROR
(
OperationNotAllowed
,
297
"An implementation for this variable is already "
298
"defined."
)
299
}
else
if
(
impl
->
nbrDim
() > 0) {
300
GUM_ERROR
(
OperationNotAllowed
,
"This implementation is not empty."
)
301
}
302
303
_impl_
=
impl
;
304
}
305
}
306
307
// Tells the factory that we're out of a variable declaration.
308
template
<
typename
GUM_SCALAR
>
309
INLINE
NodeId
BayesNetFactory
<
GUM_SCALAR
>::
endVariableDeclaration
() {
310
if
(
state
() !=
factory_state
::
VARIABLE
) {
311
_illegalStateError_
(
"endVariableDeclaration"
);
312
}
else
if
(
_foo_flag_
&& (
_stringBag_
.
size
() > 4)) {
313
DiscreteVariable
*
var
=
nullptr
;
314
315
// if the current variable is a LabelizedVariable
316
if
(
_stringBag_
[2] ==
"L"
) {
317
LabelizedVariable
*
l
318
=
new
LabelizedVariable
(
_stringBag_
[0], (
_bar_flag_
) ?
_stringBag_
[1] :
""
, 0);
319
320
for
(
size_t
i
= 3;
i
<
_stringBag_
.
size
(); ++
i
) {
321
l
->
addLabel
(
_stringBag_
[
i
]);
322
}
323
324
var
=
l
;
325
// if the current variable is a RangeVariable
326
}
else
if
(
_stringBag_
[2] ==
"R"
) {
327
RangeVariable
*
r
=
new
RangeVariable
(
_stringBag_
[0],
328
(
_bar_flag_
) ?
_stringBag_
[1] :
""
,
329
std
::
stol
(
_stringBag_
[3]),
330
std
::
stol
(
_stringBag_
[4]));
331
332
var
=
r
;
333
// if the current variable is a DiscretizedVariable
334
}
else
if
(
_stringBag_
[2] ==
"D"
) {
335
DiscretizedVariable
<
GUM_SCALAR
>*
d
336
=
new
DiscretizedVariable
<
GUM_SCALAR
>(
_stringBag_
[0],
337
(
_bar_flag_
) ?
_stringBag_
[1] :
""
);
338
339
for
(
size_t
i
= 3;
i
<
_stringBag_
.
size
(); ++
i
) {
340
d
->
addTick
(
std
::
stof
(
_stringBag_
[
i
]));
341
}
342
343
var
=
d
;
344
}
345
346
if
(
_impl_
!= 0) {
347
_varNameMap_
.
insert
(
var
->
name
(),
_bn_
->
add
(*
var
,
_impl_
));
348
_impl_
= 0;
349
}
else
{
350
_varNameMap_
.
insert
(
var
->
name
(),
_bn_
->
add
(*
var
));
351
}
352
353
NodeId
retVal
=
_varNameMap_
[
var
->
name
()];
354
355
delete
var
;
356
357
_resetParts_
();
358
_states_
.
pop_back
();
359
360
return
retVal
;
361
}
else
{
362
std
::
stringstream
msg
;
363
msg
<<
"Not enough modalities ("
;
364
365
if
(
_stringBag_
.
size
() > 3) {
366
msg
<<
_stringBag_
.
size
() - 3;
367
}
else
{
368
msg
<< 0;
369
}
370
371
msg
<<
") declared for variable "
;
372
373
if
(
_foo_flag_
) {
374
msg
<<
_stringBag_
[0];
375
}
else
{
376
msg
<<
"unknown"
;
377
}
378
379
_resetParts_
();
380
381
_states_
.
pop_back
();
382
GUM_ERROR
(
OperationNotAllowed
,
msg
.
str
())
383
}
384
385
// For noisy compilers
386
return
0;
387
}
388
389
// Tells the factory that we're declaring parents for some variable.
390
// @var The concerned variable's name.
391
template
<
typename
GUM_SCALAR
>
392
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startParentsDeclaration
(
const
std
::
string
&
var
) {
393
if
(
state
() !=
factory_state
::
NONE
) {
394
_illegalStateError_
(
"startParentsDeclaration"
);
395
}
else
{
396
_checkVariableName_
(
var
);
397
std
::
vector
<
std
::
string
>::
iterator
iter
=
_stringBag_
.
begin
();
398
_stringBag_
.
insert
(
iter
,
var
);
399
_states_
.
push_back
(
factory_state
::
PARENTS
);
400
}
401
}
402
403
// Tells the factory for which variable we're declaring parents.
404
// @var The parent's name.
405
// @throw NotFound Raised if var does not exists.
406
template
<
typename
GUM_SCALAR
>
407
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addParent
(
const
std
::
string
&
var
) {
408
if
(
state
() !=
factory_state
::
PARENTS
) {
409
_illegalStateError_
(
"addParent"
);
410
}
else
{
411
_checkVariableName_
(
var
);
412
_stringBag_
.
push_back
(
var
);
413
}
414
}
415
416
// Tells the factory that we've finished declaring parents for some
417
// variable. When parents exist, endParentsDeclaration creates some arcs.
418
// These arcs are created in the inverse order of the order of the parent
419
// specification.
420
template
<
typename
GUM_SCALAR
>
421
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endParentsDeclaration
() {
422
if
(
state
() !=
factory_state
::
PARENTS
) {
423
_illegalStateError_
(
"endParentsDeclaration"
);
424
}
else
{
425
NodeId
id
=
_varNameMap_
[
_stringBag_
[0]];
426
427
// PLEASE NOTE THAT THE ORDER IS INVERSE
428
429
for
(
size_t
i
=
_stringBag_
.
size
() - 1;
i
> 0; --
i
) {
430
_bn_
->
addArc
(
_varNameMap_
[
_stringBag_
[
i
]],
id
);
431
}
432
433
_resetParts_
();
434
435
_states_
.
pop_back
();
436
}
437
}
438
439
// Tells the factory that we're declaring a conditional probability table
440
// for some variable.
441
// @param var The concerned variable's name.
442
template
<
typename
GUM_SCALAR
>
443
INLINE
void
444
BayesNetFactory
<
GUM_SCALAR
>::
startRawProbabilityDeclaration
(
const
std
::
string
&
var
) {
445
if
(
state
() !=
factory_state
::
NONE
) {
446
_illegalStateError_
(
"startRawProbabilityDeclaration"
);
447
}
else
{
448
_checkVariableName_
(
var
);
449
_stringBag_
.
push_back
(
var
);
450
_states_
.
push_back
(
factory_state
::
RAW_CPT
);
451
}
452
}
453
454
// @brief Fills the variable's table with the values in rawTable.
455
// Parse the parents in the same order in which they were added to the
456
// variable.
457
// Given a sequence [var, p_1, p_2, ...,p_n-1, p_n] of parents, modalities are
458
// parsed
459
// in the given order (if all p_i are binary):
460
// [0, 0, ..., 0, 0], [0, 0, ..., 0, 1],
461
// [0, 0, ..., 1, 0], [0, 0, ..., 1, 1],
462
// ...,
463
// [1, 1, ..., 1, 0], [1, 1, ..., 1, 1].
464
// @param rawTable The raw table.
465
template
<
typename
GUM_SCALAR
>
466
INLINE
void
467
BayesNetFactory
<
GUM_SCALAR
>::
rawConditionalTable
(
const
std
::
vector
<
std
::
string
>&
variables
,
468
const
std
::
vector
<
float
>&
rawTable
) {
469
if
(
state
() !=
factory_state
::
RAW_CPT
) {
470
_illegalStateError_
(
"rawConditionalTable"
);
471
}
else
{
472
_fillProbaWithValuesTable_
(
variables
,
rawTable
);
473
}
474
}
475
476
template
<
typename
GUM_SCALAR
>
477
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
_fillProbaWithValuesTable_
(
478
const
std
::
vector
<
std
::
string
>&
variables
,
479
const
std
::
vector
<
float
>&
rawTable
) {
480
const
Potential
<
GUM_SCALAR
>&
table
=
_bn_
->
cpt
(
_varNameMap_
[
_stringBag_
[0]]);
481
Instantiation
cptInst
(
table
);
482
483
List
<
const
DiscreteVariable
* >
varList
;
484
485
for
(
size_t
i
= 0;
i
<
variables
.
size
(); ++
i
) {
486
varList
.
pushBack
(&(
_bn_
->
variable
(
_varNameMap_
[
variables
[
i
]])));
487
}
488
489
// varList.pushFront(&( _bn_->variable( _varNameMap_[ _stringBag_[0]])));
490
491
Idx
nbrVar
=
varList
.
size
();
492
493
std
::
vector
<
Idx
>
modCounter
;
494
495
// initializing the array
496
for
(
NodeId
i
= 0;
i
<
nbrVar
;
i
++) {
497
modCounter
.
push_back
(
Idx
(0));
498
}
499
500
Idx
j
= 0;
501
502
do
{
503
for
(
NodeId
i
= 0;
i
<
nbrVar
;
i
++) {
504
cptInst
.
chgVal
(*(
varList
[
i
]),
modCounter
[
i
]);
505
}
506
507
if
(
j
<
rawTable
.
size
()) {
508
table
.
set
(
cptInst
, (
GUM_SCALAR
)
rawTable
[
j
]);
509
}
else
{
510
table
.
set
(
cptInst
, (
GUM_SCALAR
)0);
511
}
512
513
j
++;
514
}
while
(
_increment_
(
modCounter
,
varList
));
515
}
516
517
template
<
typename
GUM_SCALAR
>
518
INLINE
void
519
BayesNetFactory
<
GUM_SCALAR
>::
rawConditionalTable
(
const
std
::
vector
<
float
>&
rawTable
) {
520
if
(
state
() !=
factory_state
::
RAW_CPT
) {
521
_illegalStateError_
(
"rawConditionalTable"
);
522
}
else
{
523
_fillProbaWithValuesTable_
(
rawTable
);
524
}
525
}
526
527
template
<
typename
GUM_SCALAR
>
528
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
_fillProbaWithValuesTable_
(
529
const
std
::
vector
<
float
>&
rawTable
) {
530
const
Potential
<
GUM_SCALAR
>&
table
=
_bn_
->
cpt
(
_varNameMap_
[
_stringBag_
[0]]);
531
532
Instantiation
cptInst
(
table
);
533
534
// the main loop is on the first variables. The others are in the right
535
// order.
536
const
DiscreteVariable
&
first
=
table
.
variable
(0);
537
Idx
j
= 0;
538
539
for
(
cptInst
.
setFirstVar
(
first
); !
cptInst
.
end
();
cptInst
.
incVar
(
first
)) {
540
for
(
cptInst
.
setFirstNotVar
(
first
); !
cptInst
.
end
();
cptInst
.
incNotVar
(
first
))
541
table
.
set
(
cptInst
, (
j
<
rawTable
.
size
()) ? (
GUM_SCALAR
)
rawTable
[
j
++] : (
GUM_SCALAR
)0);
542
543
cptInst
.
unsetEnd
();
544
}
545
}
546
547
template
<
typename
GUM_SCALAR
>
548
INLINE
bool
BayesNetFactory
<
GUM_SCALAR
>::
_increment_
(
std
::
vector
<
gum
::
Idx
>&
modCounter
,
549
List
<
const
DiscreteVariable
* >&
varList
) {
550
bool
last
=
true
;
551
552
for
(
NodeId
j
= 0;
j
<
modCounter
.
size
();
j
++) {
553
last
= (
modCounter
[
j
] == (
varList
[
j
]->
domainSize
() - 1)) &&
last
;
554
555
if
(!
last
)
break
;
556
}
557
558
if
(
last
) {
return
false
; }
559
560
bool
add
=
false
;
561
562
NodeId
i
=
NodeId
(
varList
.
size
() - 1);
563
564
do
{
565
if
(
modCounter
[
i
] == (
varList
[
i
]->
domainSize
() - 1)) {
566
modCounter
[
i
] = 0;
567
add
=
true
;
568
}
else
{
569
modCounter
[
i
] += 1;
570
add
=
false
;
571
}
572
573
i
--;
574
}
while
(
add
);
575
576
return
true
;
577
}
578
579
// Tells the factory that we finished declaring a conditional probability
580
// table.
581
template
<
typename
GUM_SCALAR
>
582
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endRawProbabilityDeclaration
() {
583
if
(
state
() !=
factory_state
::
RAW_CPT
) {
584
_illegalStateError_
(
"endRawProbabilityDeclaration"
);
585
}
else
{
586
_resetParts_
();
587
_states_
.
pop_back
();
588
}
589
}
590
591
// Tells the factory that we're starting a factorized declaration.
592
template
<
typename
GUM_SCALAR
>
593
INLINE
void
594
BayesNetFactory
<
GUM_SCALAR
>::
startFactorizedProbabilityDeclaration
(
const
std
::
string
&
var
) {
595
if
(
state
() !=
factory_state
::
NONE
) {
596
_illegalStateError_
(
"startFactorizedProbabilityDeclaration"
);
597
}
else
{
598
_checkVariableName_
(
var
);
599
std
::
vector
<
std
::
string
>::
iterator
iter
=
_stringBag_
.
begin
();
600
_stringBag_
.
insert
(
iter
,
var
);
601
_states_
.
push_back
(
factory_state
::
FACT_CPT
);
602
}
603
}
604
605
// Tells the factory that we start an entry of a factorized conditional
606
// probability table.
607
template
<
typename
GUM_SCALAR
>
608
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startFactorizedEntry
() {
609
if
(
state
() !=
factory_state
::
FACT_CPT
) {
610
_illegalStateError_
(
"startFactorizedEntry"
);
611
}
else
{
612
_parents_
=
new
Instantiation
();
613
_states_
.
push_back
(
factory_state
::
FACT_ENTRY
);
614
}
615
}
616
617
// Tells the factory that we finished declaring a conditional probability
618
// table.
619
template
<
typename
GUM_SCALAR
>
620
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endFactorizedEntry
() {
621
if
(
state
() !=
factory_state
::
FACT_ENTRY
) {
622
_illegalStateError_
(
"endFactorizedEntry"
);
623
}
else
{
624
delete
_parents_
;
625
_parents_
= 0;
626
_states_
.
pop_back
();
627
}
628
}
629
630
// Tells the factory on which modality we want to instantiate one of
631
// variable's parent.
632
template
<
typename
GUM_SCALAR
>
633
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setParentModality
(
const
std
::
string
&
parent
,
634
const
std
::
string
&
modality
) {
635
if
(
state
() !=
factory_state
::
FACT_ENTRY
) {
636
_illegalStateError_
(
"string"
);
637
}
else
{
638
_checkVariableName_
(
parent
);
639
Idx
id
=
_checkVariableModality_
(
parent
,
modality
);
640
(*
_parents_
) <<
_bn_
->
variable
(
_varNameMap_
[
parent
]);
641
_parents_
->
chgVal
(
_bn_
->
variable
(
_varNameMap_
[
parent
]),
id
);
642
}
643
}
644
645
// @brief Gives the values of the variable with respect to precedent
646
// parents modality.
647
// If some parents have no modality set, then we apply values for all
648
// instantiations of that parent.
649
//
650
// This means you can declare a default value for the table by doing
651
// @code
652
// BayesNetFactory factory;
653
// // Do stuff
654
// factory.startVariableDeclaration();
655
// factory.variableName("foo");
656
// factory.endVariableDeclaration();
657
// factory.startParentsDeclaration("foo");
658
// // add parents
659
// factory.endParentsDeclaration();
660
// factory.startFactorizedProbabilityDeclaration("foo");
661
// std::vector<float> seq;
662
// seq.insert(0.4); // if foo true
663
// seq.insert(O.6); // if foo false
664
// factory.setVariableValues(seq); // fills the table with a default value
665
// // finish your stuff
666
// factory.endFactorizedProbabilityDeclaration();
667
// @code
668
// as for raw Probability, if value's size is different than the number of
669
// modalities of the current variable, we don't use the supplementary values and
670
// we fill by 0 the missing values.
671
template
<
typename
GUM_SCALAR
>
672
INLINE
void
673
BayesNetFactory
<
GUM_SCALAR
>::
setVariableValuesUnchecked
(
const
std
::
vector
<
float
>&
values
) {
674
if
(
state
() !=
factory_state
::
FACT_ENTRY
) {
675
_illegalStateError_
(
"setVariableValues"
);
676
}
else
{
677
const
DiscreteVariable
&
var
=
_bn_
->
variable
(
_varNameMap_
[
_stringBag_
[0]]);
678
NodeId
varId
=
_varNameMap_
[
_stringBag_
[0]];
679
680
if
(
_parents_
->
domainSize
() > 0) {
681
Instantiation
inst
(
_bn_
->
cpt
(
_varNameMap_
[
var
.
name
()]));
682
inst
.
setVals
(*
_parents_
);
683
// Creating an instantiation containing all the variables not ins
684
// _parents_.
685
Instantiation
inst_default
;
686
inst_default
<<
var
;
687
688
for
(
auto
node
:
_bn_
->
parents
(
varId
)) {
689
if
(!
_parents_
->
contains
(
_bn_
->
variable
(
node
))) {
inst_default
<<
_bn_
->
variable
(
node
); }
690
}
691
692
// Filling the variable's table.
693
for
(
inst
.
setFirstIn
(
inst_default
); !
inst
.
end
();
inst
.
incIn
(
inst_default
)) {
694
(
_bn_
->
cpt
(
varId
))
695
.
set
(
inst
,
696
inst
.
val
(
var
) <
values
.
size
() ? (
GUM_SCALAR
)
values
[
inst
.
val
(
var
)]
697
: (
GUM_SCALAR
)0);
698
}
699
}
else
{
700
Instantiation
inst
(
_bn_
->
cpt
(
_varNameMap_
[
var
.
name
()]));
701
Instantiation
var_inst
;
702
var_inst
<<
var
;
703
704
for
(
var_inst
.
setFirst
(); !
var_inst
.
end
(); ++
var_inst
) {
705
inst
.
setVals
(
var_inst
);
706
707
for
(
inst
.
setFirstOut
(
var_inst
); !
inst
.
end
();
inst
.
incOut
(
var_inst
)) {
708
(
_bn_
->
cpt
(
varId
))
709
.
set
(
inst
,
710
inst
.
val
(
var
) <
values
.
size
() ? (
GUM_SCALAR
)
values
[
inst
.
val
(
var
)]
711
: (
GUM_SCALAR
)0);
712
}
713
}
714
}
715
}
716
}
717
718
template
<
typename
GUM_SCALAR
>
719
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setVariableValues
(
const
std
::
vector
<
float
>&
values
) {
720
if
(
state
() !=
factory_state
::
FACT_ENTRY
) {
721
_illegalStateError_
(
"setVariableValues"
);
722
}
else
{
723
const
DiscreteVariable
&
var
=
_bn_
->
variable
(
_varNameMap_
[
_stringBag_
[0]]);
724
// Checking consistency between values and var.
725
726
if
(
values
.
size
() !=
var
.
domainSize
()) {
727
GUM_ERROR
(
OperationNotAllowed
,
728
var
.
name
() <<
" : invalid number of modalities: found "
<<
values
.
size
()
729
<<
" while needed "
<<
var
.
domainSize
())
730
}
731
732
setVariableValuesUnchecked
(
values
);
733
}
734
}
735
736
// Tells the factory that we finished declaring a conditional probability
737
// table.
738
template
<
typename
GUM_SCALAR
>
739
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endFactorizedProbabilityDeclaration
() {
740
if
(
state
() !=
factory_state
::
FACT_CPT
) {
741
_illegalStateError_
(
"endFactorizedProbabilityDeclaration"
);
742
}
else
{
743
_resetParts_
();
744
_states_
.
pop_back
();
745
}
746
}
747
748
// @brief Define a variable.
749
// You can only call this method is the factory is in the NONE or NETWORK
750
// state.
751
// The variable is added by copy.
752
// @param var The pointer over a DiscreteVariable used to define a new
753
// variable in the built BayesNet.
754
// @throw DuplicateElement Raised if a variable with the same name already
755
// exists.
756
// @throw OperationNotAllowed Raised if redefineParents == false and if table
757
// is not a valid CPT for var in the current state
758
// of the BayesNet.
759
template
<
typename
GUM_SCALAR
>
760
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setVariable
(
const
DiscreteVariable
&
var
) {
761
if
((
state
() !=
factory_state
::
NONE
)) {
762
_illegalStateError_
(
"setVariable"
);
763
}
else
{
764
try
{
765
_checkVariableName_
(
var
.
name
());
766
GUM_ERROR
(
DuplicateElement
,
"Name already used: "
<<
var
.
name
())
767
}
catch
(
NotFound
&) {
768
// The var name is unused
769
_varNameMap_
.
insert
(
var
.
name
(),
_bn_
->
add
(
var
));
770
}
771
}
772
}
773
774
// @brief Define a variable's CPT.
775
// You can only call this method if the factory is in the NONE or NETWORK
776
// state.
777
// Be careful that table is given to the built BayesNet, so it will be
778
// deleted with it, and you should not directly access it after you call
779
// this method.
780
// When the redefineParents flag is set to true the constructed BayesNet's
781
// DAG is changed to fit with table's definition.
782
// @param var The name of the concerned variable.
783
// @param table A pointer over the CPT used for var.
784
// @param redefineParents If true redefine parents of the variable to match
785
// table's
786
// variables set.
787
//
788
// @throw NotFound Raised if no variable matches var.
789
// @throw OperationNotAllowed Raised if redefineParents == false and if table
790
// is not a valid CPT for var in the current state
791
// of the BayesNet.
792
template
<
typename
GUM_SCALAR
>
793
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setVariableCPT
(
const
std
::
string
&
varName
,
794
MultiDimAdressable
*
table
,
795
bool
redefineParents
) {
796
auto
pot
=
dynamic_cast
<
Potential
<
GUM_SCALAR
>* >(
table
);
797
798
if
(
state
() !=
factory_state
::
NONE
) {
799
_illegalStateError_
(
"setVariableCPT"
);
800
}
else
{
801
_checkVariableName_
(
varName
);
802
const
DiscreteVariable
&
var
=
_bn_
->
variable
(
_varNameMap_
[
varName
]);
803
NodeId
varId
=
_varNameMap_
[
varName
];
804
// If we have to change the structure of the BayesNet, then we call a sub
805
// method.
806
807
if
(
redefineParents
) {
808
_setCPTAndParents_
(
var
,
pot
);
809
}
else
if
(
pot
->
contains
(
var
)) {
810
for
(
auto
node
:
_bn_
->
parents
(
varId
)) {
811
if
(!
pot
->
contains
(
_bn_
->
variable
(
node
))) {
812
GUM_ERROR
(
OperationNotAllowed
,
"The CPT is not valid in the current BayesNet."
)
813
}
814
}
815
816
// CPT are created when a variable is added.
817
_bn_
->
_unsafeChangePotential_
(
varId
,
pot
);
818
}
819
}
820
}
821
822
// Raise an OperationNotAllowed with the message "Illegal state."
823
template
<
typename
GUM_SCALAR
>
824
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
_illegalStateError_
(
const
std
::
string
&
s
) {
825
std
::
string
msg
=
"Illegal state call ("
;
826
msg
+=
s
;
827
msg
+=
") in state "
;
828
829
switch
(
state
()) {
830
case
factory_state
::
NONE
: {
831
msg
+=
"NONE"
;
832
break
;
833
}
834
835
case
factory_state
::
NETWORK
: {
836
msg
+=
"NETWORK"
;
837
break
;
838
}
839
840
case
factory_state
::
VARIABLE
: {
841
msg
+=
"VARIABLE"
;
842
break
;
843
}
844
845
case
factory_state
::
PARENTS
: {
846
msg
+=
"PARENTS"
;
847
break
;
848
}
849
850
case
factory_state
::
RAW_CPT
: {
851
msg
+=
"RAW_CPT"
;
852
break
;
853
}
854
855
case
factory_state
::
FACT_CPT
: {
856
msg
+=
"FACT_CPT"
;
857
break
;
858
}
859
860
case
factory_state
::
FACT_ENTRY
: {
861
msg
+=
"FACT_ENTRY"
;
862
break
;
863
}
864
865
default
: {
866
msg
+=
"Unknown state"
;
867
}
868
}
869
870
GUM_ERROR
(
OperationNotAllowed
,
msg
)
871
}
872
873
// Check if a variable with the given name exists, if not raise an NotFound
874
// exception.
875
template
<
typename
GUM_SCALAR
>
876
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
_checkVariableName_
(
const
std
::
string
&
name
) {
877
if
(!
_varNameMap_
.
exists
(
name
)) {
GUM_ERROR
(
NotFound
,
name
) }
878
}
879
880
// Check if var exists and if mod is one of it's modality, if not raise an
881
// NotFound exception.
882
template
<
typename
GUM_SCALAR
>
883
INLINE
Idx
BayesNetFactory
<
GUM_SCALAR
>::
_checkVariableModality_
(
const
std
::
string
&
name
,
884
const
std
::
string
&
mod
) {
885
_checkVariableName_
(
name
);
886
const
DiscreteVariable
&
var
=
_bn_
->
variable
(
_varNameMap_
[
name
]);
887
888
for
(
Idx
i
= 0;
i
<
var
.
domainSize
(); ++
i
) {
889
if
(
mod
==
var
.
label
(
i
)) {
return
i
; }
890
}
891
892
GUM_ERROR
(
NotFound
,
mod
)
893
}
894
895
// Check if in _stringBag_ there is no other modality with the same name.
896
template
<
typename
GUM_SCALAR
>
897
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
_checkModalityInBag_
(
const
std
::
string
&
mod
) {
898
for
(
size_t
i
= 3;
i
<
_stringBag_
.
size
(); ++
i
) {
899
if
(
mod
==
_stringBag_
[
i
]) {
GUM_ERROR
(
DuplicateElement
,
"Label already used: "
<<
mod
) }
900
}
901
}
902
903
// Sub method of setVariableCPT() which redefine the BayesNet's DAG with
904
// respect to table.
905
template
<
typename
GUM_SCALAR
>
906
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
_setCPTAndParents_
(
const
DiscreteVariable
&
var
,
907
Potential
<
GUM_SCALAR
>*
table
) {
908
NodeId
varId
=
_varNameMap_
[
var
.
name
()];
909
_bn_
->
dag_
.
eraseParents
(
varId
);
910
911
for
(
auto
v
:
table
->
variablesSequence
()) {
912
if
(
v
!= (&
var
)) {
913
_checkVariableName_
(
v
->
name
());
914
_bn_
->
dag_
.
addArc
(
_varNameMap_
[
v
->
name
()],
varId
);
915
}
916
}
917
918
// CPT are created when a variable is added.
919
_bn_
->
_unsafeChangePotential_
(
varId
,
table
);
920
}
921
922
// Reset the different parts used to constructed the BayesNet.
923
template
<
typename
GUM_SCALAR
>
924
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
_resetParts_
() {
925
_foo_flag_
=
false
;
926
_bar_flag_
=
false
;
927
_stringBag_
.
clear
();
928
}
929
}
/* namespace gum */
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:643