aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
BayesNetFactory_tpl.h
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of the BayesNetFactory class.
25
*
26
* @author Lionel TORTI and Pierre-Henri WUILLEMIN(@LIP6)
27
28
*/
29
30
#include <agrum/BN/BayesNetFactory.h>

#include <cstdio>
#include <limits>
#include <memory>
31
32
// Prints x (anything streamable into std::cerr, e.g. "a" << b) prefixed by
// "[BN factory] " when isVerbose() returns true. Must be expanded where
// isVerbose() is visible (i.e. inside member functions of the factory).
// The do { } while (0) wrapper turns `VERBOSITY(x);` into a single statement,
// so the macro stays correct inside unbraced if/else branches.
#define VERBOSITY(x)                                                 \
  do {                                                               \
    if (isVerbose()) std::cerr << "[BN factory] " << x << std::endl; \
  } while (0)
36
37
namespace
gum
{
38
39
// Default constructor.
40
// @param bn A pointer over the BayesNet filled by this factory.
41
// @throw DuplicateElement Raised if two variables in bn share the same
42
// name.
43
template
<
typename
GUM_SCALAR >
44
INLINE
45
BayesNetFactory<
GUM_SCALAR
>::
BayesNetFactory
(
BayesNet
<
GUM_SCALAR
>*
bn
) :
46
parents__
(0),
47
impl__
(0),
bn__
(
bn
) {
48
GUM_CONSTRUCTOR
(
BayesNetFactory
);
49
states__
.
push_back
(
factory_state
::
NONE
);
50
51
for
(
auto
node
:
bn
->
nodes
()) {
52
if
(
varNameMap__
.
exists
(
bn
->
variable
(
node
).
name
()))
53
GUM_ERROR
(
DuplicateElement
,
54
"Name already used: "
<<
bn
->
variable
(
node
).
name
());
55
56
varNameMap__
.
insert
(
bn
->
variable
(
node
).
name
(),
node
);
57
}
58
59
resetVerbose
();
60
}
61
62
// Copy constructor.
63
// The copy will have an exact copy of the constructed BayesNet in source.
64
template
<
typename
GUM_SCALAR
>
65
INLINE
BayesNetFactory
<
GUM_SCALAR
>::
BayesNetFactory
(
66
const
BayesNetFactory
<
GUM_SCALAR
>&
source
) :
67
parents__
(0),
68
impl__
(0),
bn__
(0) {
69
GUM_CONS_CPY
(
BayesNetFactory
);
70
71
if
(
source
.
state
() !=
factory_state
::
NONE
) {
72
GUM_ERROR
(
OperationNotAllowed
,
"Illegal state to proceed make a copy."
);
73
}
else
{
74
states__
=
source
.
states__
;
75
bn__
=
new
BayesNet
<
GUM_SCALAR
>(*(
source
.
bn__
));
76
}
77
}
78
79
// Destructor
80
template
<
typename
GUM_SCALAR
>
81
INLINE
BayesNetFactory
<
GUM_SCALAR
>::~
BayesNetFactory
() {
82
GUM_DESTRUCTOR
(
BayesNetFactory
);
83
84
if
(
parents__
!= 0)
delete
parents__
;
85
86
if
(
impl__
!= 0) {
87
//@todo better than throwing an exception from inside a destructor but
88
// still ...
89
std
::
cerr
90
<<
"[BN factory] Implementation defined for a variable but not used. "
91
"You should call endVariableDeclaration() before "
92
"deleting me."
93
<<
std
::
endl
;
94
exit
(1000);
95
}
96
}
97
98
// Returns the BayesNet created by this factory.
99
template
<
typename
GUM_SCALAR
>
100
INLINE
BayesNet
<
GUM_SCALAR
>*
BayesNetFactory
<
GUM_SCALAR
>::
bayesNet
() {
101
return
bn__
;
102
}
103
104
template
<
typename
GUM_SCALAR
>
105
INLINE
const
DiscreteVariable
&
106
BayesNetFactory
<
GUM_SCALAR
>::
varInBN
(
NodeId
id
) {
107
return
bn__
->
variable
(
id
);
108
}
109
110
// Returns the current state of the factory.
111
template
<
typename
GUM_SCALAR
>
112
INLINE
IBayesNetFactory
::
factory_state
113
BayesNetFactory
<
GUM_SCALAR
>::
state
()
const
{
114
// This is ok because there is alway at least the state NONE in the stack.
115
return
states__
.
back
();
116
}
117
118
// Returns the NodeId of a variable given it's name.
119
// @throw NotFound Raised if no variable matches the name.
120
template
<
typename
GUM_SCALAR
>
121
INLINE
NodeId
122
BayesNetFactory
<
GUM_SCALAR
>::
variableId
(
const
std
::
string
&
name
)
const
{
123
try
{
124
return
varNameMap__
[
name
];
125
}
catch
(
NotFound
&) {
GUM_ERROR
(
NotFound
,
name
); }
126
}
127
128
// Returns a constant reference on a variable given it's name.
129
// @throw NotFound Raised if no variable matches the name.
130
template
<
typename
GUM_SCALAR
>
131
INLINE
const
DiscreteVariable
&
132
BayesNetFactory
<
GUM_SCALAR
>::
variable
(
const
std
::
string
&
name
)
const
{
133
try
{
134
return
bn__
->
variable
(
variableId
(
name
));
135
}
catch
(
NotFound
&) {
GUM_ERROR
(
NotFound
,
name
); }
136
}
137
138
// Returns the domainSize of the cpt for the node n.
139
// @throw NotFound raised if no such NodeId exists.
140
// @throw OperationNotAllowed if there is no Bayesian networks.
141
template
<
typename
GUM_SCALAR
>
142
INLINE
Size
BayesNetFactory
<
GUM_SCALAR
>::
cptDomainSize
(
const
NodeId
n
)
const
{
143
return
bn__
->
cpt
(
n
).
domainSize
();
144
}
145
146
// Tells the factory that we're in a network declaration.
147
template
<
typename
GUM_SCALAR
>
148
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startNetworkDeclaration
() {
149
if
(
state
() !=
factory_state
::
NONE
) {
150
illegalStateError__
(
"startNetworkDeclaration"
);
151
}
else
{
152
states__
.
push_back
(
factory_state
::
NETWORK
);
153
}
154
155
VERBOSITY
(
"starting network"
);
156
}
157
158
// Tells the factory to add a property to the current network.
159
template
<
typename
GUM_SCALAR
>
160
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addNetworkProperty
(
161
const
std
::
string
&
propName
,
162
const
std
::
string
&
propValue
) {
163
bn__
->
setProperty
(
propName
,
propValue
);
164
}
165
166
// Tells the factory that we're out of a network declaration.
167
template
<
typename
GUM_SCALAR
>
168
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endNetworkDeclaration
() {
169
if
(
state
() !=
factory_state
::
NETWORK
) {
170
illegalStateError__
(
"endNetworkDeclaration"
);
171
}
else
{
172
states__
.
pop_back
();
173
}
174
175
VERBOSITY
(
"network OK"
);
176
}
177
178
// Tells the factory that we're in a variable declaration.
179
// A variable is considered as a LabelizedVariable while its type is not defined.
180
template
<
typename
GUM_SCALAR
>
181
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startVariableDeclaration
() {
182
if
(
state
() !=
factory_state
::
NONE
) {
183
illegalStateError__
(
"startVariableDeclaration"
);
184
}
else
{
185
states__
.
push_back
(
factory_state
::
VARIABLE
);
186
stringBag__
.
push_back
(
"name"
);
187
stringBag__
.
push_back
(
"desc"
);
188
stringBag__
.
push_back
(
"L"
);
189
}
190
191
VERBOSITY
(
" starting variable"
);
192
}
193
194
// Tells the factory the current variable's name.
195
template
<
typename
GUM_SCALAR
>
196
INLINE
void
197
BayesNetFactory
<
GUM_SCALAR
>::
variableName
(
const
std
::
string
&
name
) {
198
if
(
state
() !=
factory_state
::
VARIABLE
) {
199
illegalStateError__
(
"variableName"
);
200
}
else
{
201
if
(
varNameMap__
.
exists
(
name
)) {
202
GUM_ERROR
(
DuplicateElement
,
"Name already used: "
<<
name
);
203
}
204
205
foo_flag__
=
true
;
206
stringBag__
[0] =
name
;
207
VERBOSITY
(
" -- variable "
<<
name
);
208
}
209
}
210
211
// Tells the factory the current variable's description.
212
template
<
typename
GUM_SCALAR
>
213
INLINE
void
214
BayesNetFactory
<
GUM_SCALAR
>::
variableDescription
(
const
std
::
string
&
desc
) {
215
if
(
state
() !=
factory_state
::
VARIABLE
) {
216
illegalStateError__
(
"variableDescription"
);
217
}
else
{
218
bar_flag__
=
true
;
219
stringBag__
[1] =
desc
;
220
}
221
}
222
223
// Tells the factory the current variable's type.
224
// L : Labelized
225
// R : Range
226
// C : Continuous
227
// D : Discretized
228
template
<
typename
GUM_SCALAR
>
229
INLINE
void
230
BayesNetFactory
<
GUM_SCALAR
>::
variableType
(
const
gum
::
VarType
&
type
) {
231
if
(
state
() !=
factory_state
::
VARIABLE
) {
232
illegalStateError__
(
"variableType"
);
233
}
else
{
234
switch
(
type
) {
235
case
VarType
::
Discretized
:
236
stringBag__
[2] =
"D"
;
237
break
;
238
case
VarType
::
Range
:
239
stringBag__
[2] =
"R"
;
240
break
;
241
case
VarType
::
Continuous
:
242
GUM_ERROR
(
OperationNotAllowed
,
243
"Continuous variable ("
+
stringBag__
[0]
244
+
") are not supported in Bayesian networks."
);
245
case
VarType
::
Labelized
:
246
stringBag__
[2] =
"L"
;
247
break
;
248
}
249
}
250
}
251
252
// Adds a modality to the current variable.
253
// @throw DuplicateElement If the current variable already has a modality
254
// with the same name.
255
template
<
typename
GUM_SCALAR
>
256
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addModality
(
const
std
::
string
&
name
) {
257
if
(
state
() !=
factory_state
::
VARIABLE
) {
258
illegalStateError__
(
"addModality"
);
259
}
else
{
260
checkModalityInBag__
(
name
);
261
stringBag__
.
push_back
(
name
);
262
}
263
}
264
265
// Adds a modality to the current variable.
266
// @throw DuplicateElement If the current variable already has a modality
267
// with the same name.
268
template
<
typename
GUM_SCALAR
>
269
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addMin
(
const
long
&
min
) {
270
if
(
state
() !=
factory_state
::
VARIABLE
) {
271
illegalStateError__
(
"addMin"
);
272
}
else
{
273
stringBag__
.
push_back
(
std
::
to_string
(
min
));
274
}
275
}
276
277
// Adds a modality to the current variable.
278
// @throw DuplicateElement If the current variable already has a modality
279
// with the same name.
280
template
<
typename
GUM_SCALAR
>
281
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addMax
(
const
long
&
max
) {
282
if
(
state
() !=
factory_state
::
VARIABLE
) {
283
illegalStateError__
(
"addMin"
);
284
}
else
{
285
stringBag__
.
push_back
(
std
::
to_string
(
max
));
286
}
287
}
288
289
// Adds a modality to the current variable.
290
// @throw DuplicateElement If the current variable already has a modality
291
// with the same name.
292
template
<
typename
GUM_SCALAR
>
293
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addTick
(
const
GUM_SCALAR
&
tick
) {
294
if
(
state
() !=
factory_state
::
VARIABLE
) {
295
illegalStateError__
(
"addTick"
);
296
}
else
{
297
stringBag__
.
push_back
(
std
::
to_string
(
tick
));
298
}
299
}
300
301
// @brief Defines the implementation to use for var's Potential.
302
// @warning The implementation must be empty.
303
// @warning The pointer is always delegated to var's Potential! No copy of it
304
// is made.
305
// @todo When copy of a MultiDimImplementation is available use a copy
306
// behaviour for this method.
307
// @throw NotFound Raised if no variable matches var.
308
// @throw OperationNotAllowed Raised if impl is not empty.
309
// @throw OperationNotAllowed If an implementation is already defined for the
310
// current variable.
311
template
<
typename
GUM_SCALAR
>
312
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setVariableCPTImplementation
(
313
MultiDimAdressable
*
adressable
) {
314
MultiDimImplementation
<
GUM_SCALAR
>*
impl
315
=
dynamic_cast
<
MultiDimImplementation
<
GUM_SCALAR
>* >(
adressable
);
316
317
if
(
state
() !=
factory_state
::
VARIABLE
) {
318
illegalStateError__
(
"setVariableCPTImplementation"
);
319
}
else
{
320
if
(
impl
== 0) {
321
GUM_ERROR
(
OperationNotAllowed
,
322
"An implementation for this variable is already "
323
"defined."
);
324
}
else
if
(
impl
->
nbrDim
() > 0) {
325
GUM_ERROR
(
OperationNotAllowed
,
"This implementation is not empty."
);
326
}
327
328
impl__
=
impl
;
329
}
330
}
331
332
// Tells the factory that we're out of a variable declaration.
333
template
<
typename
GUM_SCALAR
>
334
INLINE
NodeId
BayesNetFactory
<
GUM_SCALAR
>::
endVariableDeclaration
() {
335
if
(
state
() !=
factory_state
::
VARIABLE
) {
336
illegalStateError__
(
"endVariableDeclaration"
);
337
}
else
if
(
foo_flag__
&& (
stringBag__
.
size
() > 4)) {
338
DiscreteVariable
*
var
=
nullptr
;
339
340
// if the current variable is a LabelizedVariable
341
if
(
stringBag__
[2] ==
"L"
) {
342
LabelizedVariable
*
l
343
=
new
LabelizedVariable
(
stringBag__
[0],
344
(
bar_flag__
) ?
stringBag__
[1] :
""
,
345
0);
346
347
for
(
size_t
i
= 3;
i
<
stringBag__
.
size
(); ++
i
) {
348
l
->
addLabel
(
stringBag__
[
i
]);
349
}
350
351
var
=
l
;
352
// if the current variable is a RangeVariable
353
}
else
if
(
stringBag__
[2] ==
"R"
) {
354
RangeVariable
*
r
=
new
RangeVariable
(
stringBag__
[0],
355
(
bar_flag__
) ?
stringBag__
[1] :
""
,
356
std
::
stol
(
stringBag__
[3]),
357
std
::
stol
(
stringBag__
[4]));
358
359
var
=
r
;
360
// if the current variable is a DiscretizedVariable
361
}
else
if
(
stringBag__
[2] ==
"D"
) {
362
DiscretizedVariable
<
GUM_SCALAR
>*
d
363
=
new
DiscretizedVariable
<
GUM_SCALAR
>(
stringBag__
[0],
364
(
bar_flag__
) ?
stringBag__
[1]
365
:
""
);
366
367
for
(
size_t
i
= 3;
i
<
stringBag__
.
size
(); ++
i
) {
368
d
->
addTick
(
std
::
stof
(
stringBag__
[
i
]));
369
}
370
371
var
=
d
;
372
}
373
374
if
(
impl__
!= 0) {
375
varNameMap__
.
insert
(
var
->
name
(),
bn__
->
add
(*
var
,
impl__
));
376
impl__
= 0;
377
}
else
{
378
varNameMap__
.
insert
(
var
->
name
(),
bn__
->
add
(*
var
));
379
}
380
381
NodeId
retVal
=
varNameMap__
[
var
->
name
()];
382
383
delete
var
;
384
385
resetParts__
();
386
states__
.
pop_back
();
387
388
VERBOSITY
(
" variable "
<<
var
->
name
() <<
" OK"
);
389
return
retVal
;
390
}
else
{
391
std
::
stringstream
msg
;
392
msg
<<
"Not enough modalities ("
;
393
394
if
(
stringBag__
.
size
() > 3) {
395
msg
<<
stringBag__
.
size
() - 3;
396
}
else
{
397
msg
<< 0;
398
}
399
400
msg
<<
") declared for variable "
;
401
402
if
(
foo_flag__
) {
403
msg
<<
stringBag__
[0];
404
}
else
{
405
msg
<<
"unknown"
;
406
}
407
408
resetParts__
();
409
410
states__
.
pop_back
();
411
GUM_ERROR
(
OperationNotAllowed
,
msg
.
str
());
412
}
413
414
// For noisy compilers
415
return
0;
416
}
417
418
// Tells the factory that we're declaring parents for some variable.
419
// @var The concerned variable's name.
420
template
<
typename
GUM_SCALAR
>
421
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startParentsDeclaration
(
422
const
std
::
string
&
var
) {
423
if
(
state
() !=
factory_state
::
NONE
) {
424
illegalStateError__
(
"startParentsDeclaration"
);
425
}
else
{
426
checkVariableName__
(
var
);
427
std
::
vector
<
std
::
string
>::
iterator
iter
=
stringBag__
.
begin
();
428
stringBag__
.
insert
(
iter
,
var
);
429
states__
.
push_back
(
factory_state
::
PARENTS
);
430
}
431
432
VERBOSITY
(
"starting parents for "
<<
var
);
433
}
434
435
// Tells the factory for which variable we're declaring parents.
436
// @var The parent's name.
437
// @throw NotFound Raised if var does not exists.
438
template
<
typename
GUM_SCALAR
>
439
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
addParent
(
const
std
::
string
&
var
) {
440
if
(
state
() !=
factory_state
::
PARENTS
) {
441
illegalStateError__
(
"addParent"
);
442
}
else
{
443
checkVariableName__
(
var
);
444
stringBag__
.
push_back
(
var
);
445
}
446
}
447
448
// Tells the factory that we've finished declaring parents for some
449
// variable. When parents exist, endParentsDeclaration creates some arcs.
450
// These arcs are created in the inverse order of the order of the parent
451
// specification.
452
template
<
typename
GUM_SCALAR
>
453
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endParentsDeclaration
() {
454
if
(
state
() !=
factory_state
::
PARENTS
) {
455
illegalStateError__
(
"endParentsDeclaration"
);
456
}
else
{
457
NodeId
id
=
varNameMap__
[
stringBag__
[0]];
458
459
// PLEASE NOTE THAT THE ORDER IS INVERSE
460
461
for
(
size_t
i
=
stringBag__
.
size
() - 1;
i
> 0; --
i
) {
462
bn__
->
addArc
(
varNameMap__
[
stringBag__
[
i
]],
id
);
463
VERBOSITY
(
" adding parent "
<<
stringBag__
[
i
] <<
" for "
464
<<
stringBag__
[0]);
465
}
466
467
resetParts__
();
468
469
states__
.
pop_back
();
470
}
471
472
VERBOSITY
(
"end of parents for "
<<
stringBag__
[0]);
473
}
474
475
// Tells the factory that we're declaring a conditional probability table
476
// for some variable.
477
// @param var The concerned variable's name.
478
template
<
typename
GUM_SCALAR
>
479
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startRawProbabilityDeclaration
(
480
const
std
::
string
&
var
) {
481
if
(
state
() !=
factory_state
::
NONE
) {
482
illegalStateError__
(
"startRawProbabilityDeclaration"
);
483
}
else
{
484
checkVariableName__
(
var
);
485
stringBag__
.
push_back
(
var
);
486
states__
.
push_back
(
factory_state
::
RAW_CPT
);
487
}
488
489
VERBOSITY
(
" cpt starting for "
<<
var
);
490
}
491
492
// @brief Fills the variable's table with the values in rawTable.
493
// Parse the parents in the same order in which they were added to the
494
// variable.
495
// Given a sequence [var, p_1, p_2, ...,p_n-1, p_n] of parents, modalities are
496
// parsed
497
// in the given order (if all p_i are binary):
498
// [0, 0, ..., 0, 0], [0, 0, ..., 0, 1],
499
// [0, 0, ..., 1, 0], [0, 0, ..., 1, 1],
500
// ...,
501
// [1, 1, ..., 1, 0], [1, 1, ..., 1, 1].
502
// @param rawTable The raw table.
503
template
<
typename
GUM_SCALAR
>
504
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
rawConditionalTable
(
505
const
std
::
vector
<
std
::
string
>&
variables
,
506
const
std
::
vector
<
float
>&
rawTable
) {
507
if
(
state
() !=
factory_state
::
RAW_CPT
) {
508
illegalStateError__
(
"rawConditionalTable"
);
509
}
else
{
510
fillProbaWithValuesTable__
(
variables
,
rawTable
);
511
}
512
}
513
514
template
<
typename
GUM_SCALAR
>
515
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
fillProbaWithValuesTable__
(
516
const
std
::
vector
<
std
::
string
>&
variables
,
517
const
std
::
vector
<
float
>&
rawTable
) {
518
const
Potential
<
GUM_SCALAR
>&
table
=
bn__
->
cpt
(
varNameMap__
[
stringBag__
[0]]);
519
Instantiation
cptInst
(
table
);
520
521
List
<
const
DiscreteVariable
* >
varList
;
522
523
for
(
size_t
i
= 0;
i
<
variables
.
size
(); ++
i
) {
524
varList
.
pushBack
(&(
bn__
->
variable
(
varNameMap__
[
variables
[
i
]])));
525
}
526
527
// varList.pushFront(&(bn__->variable(varNameMap__[stringBag__[0]])));
528
529
Idx
nbrVar
=
varList
.
size
();
530
531
std
::
vector
<
Idx
>
modCounter
;
532
533
// initializing the array
534
for
(
NodeId
i
= 0;
i
<
nbrVar
;
i
++) {
535
modCounter
.
push_back
(
Idx
(0));
536
}
537
538
Idx
j
= 0;
539
540
do
{
541
for
(
NodeId
i
= 0;
i
<
nbrVar
;
i
++) {
542
cptInst
.
chgVal
(*(
varList
[
i
]),
modCounter
[
i
]);
543
}
544
545
if
(
j
<
rawTable
.
size
()) {
546
table
.
set
(
cptInst
, (
GUM_SCALAR
)
rawTable
[
j
]);
547
}
else
{
548
table
.
set
(
cptInst
, (
GUM_SCALAR
)0);
549
}
550
551
j
++;
552
}
while
(
increment__
(
modCounter
,
varList
));
553
}
554
555
template
<
typename
GUM_SCALAR
>
556
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
rawConditionalTable
(
557
const
std
::
vector
<
float
>&
rawTable
) {
558
if
(
state
() !=
factory_state
::
RAW_CPT
) {
559
illegalStateError__
(
"rawConditionalTable"
);
560
}
else
{
561
fillProbaWithValuesTable__
(
rawTable
);
562
}
563
}
564
565
template
<
typename
GUM_SCALAR
>
566
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
fillProbaWithValuesTable__
(
567
const
std
::
vector
<
float
>&
rawTable
) {
568
const
Potential
<
GUM_SCALAR
>&
table
=
bn__
->
cpt
(
varNameMap__
[
stringBag__
[0]]);
569
570
Instantiation
cptInst
(
table
);
571
572
// the main loop is on the first variables. The others are in the right
573
// order.
574
const
DiscreteVariable
&
first
=
table
.
variable
(0);
575
Idx
j
= 0;
576
577
for
(
cptInst
.
setFirstVar
(
first
); !
cptInst
.
end
();
cptInst
.
incVar
(
first
)) {
578
for
(
cptInst
.
setFirstNotVar
(
first
); !
cptInst
.
end
();
cptInst
.
incNotVar
(
first
))
579
table
.
set
(
cptInst
,
580
(
j
<
rawTable
.
size
()) ? (
GUM_SCALAR
)
rawTable
[
j
++]
581
: (
GUM_SCALAR
)0);
582
583
cptInst
.
unsetEnd
();
584
}
585
}
586
587
template
<
typename
GUM_SCALAR
>
588
INLINE
bool
BayesNetFactory
<
GUM_SCALAR
>::
increment__
(
589
std
::
vector
<
gum
::
Idx
>&
modCounter
,
590
List
<
const
DiscreteVariable
* >&
varList
) {
591
bool
last
=
true
;
592
593
for
(
NodeId
j
= 0;
j
<
modCounter
.
size
();
j
++) {
594
last
= (
modCounter
[
j
] == (
varList
[
j
]->
domainSize
() - 1)) &&
last
;
595
596
if
(!
last
)
break
;
597
}
598
599
if
(
last
) {
return
false
; }
600
601
bool
add
=
false
;
602
603
NodeId
i
=
NodeId
(
varList
.
size
() - 1);
604
605
do
{
606
if
(
modCounter
[
i
] == (
varList
[
i
]->
domainSize
() - 1)) {
607
modCounter
[
i
] = 0;
608
add
=
true
;
609
}
else
{
610
modCounter
[
i
] += 1;
611
add
=
false
;
612
}
613
614
i
--;
615
}
while
(
add
);
616
617
return
true
;
618
}
619
620
// Tells the factory that we finished declaring a conditional probability
621
// table.
622
template
<
typename
GUM_SCALAR
>
623
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endRawProbabilityDeclaration
() {
624
if
(
state
() !=
factory_state
::
RAW_CPT
) {
625
illegalStateError__
(
"endRawProbabilityDeclaration"
);
626
}
else
{
627
resetParts__
();
628
states__
.
pop_back
();
629
}
630
631
VERBOSITY
(
" cpt ending for "
<<
stringBag__
[0]);
632
}
633
634
// Tells the factory that we're starting a factorized declaration.
635
template
<
typename
GUM_SCALAR
>
636
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startFactorizedProbabilityDeclaration
(
637
const
std
::
string
&
var
) {
638
if
(
state
() !=
factory_state
::
NONE
) {
639
illegalStateError__
(
"startFactorizedProbabilityDeclaration"
);
640
}
else
{
641
checkVariableName__
(
var
);
642
std
::
vector
<
std
::
string
>::
iterator
iter
=
stringBag__
.
begin
();
643
stringBag__
.
insert
(
iter
,
var
);
644
states__
.
push_back
(
factory_state
::
FACT_CPT
);
645
}
646
}
647
648
// Tells the factory that we start an entry of a factorized conditional
649
// probability table.
650
template
<
typename
GUM_SCALAR
>
651
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
startFactorizedEntry
() {
652
if
(
state
() !=
factory_state
::
FACT_CPT
) {
653
illegalStateError__
(
"startFactorizedEntry"
);
654
}
else
{
655
parents__
=
new
Instantiation
();
656
states__
.
push_back
(
factory_state
::
FACT_ENTRY
);
657
}
658
}
659
660
// Tells the factory that we finished declaring a conditional probability
661
// table.
662
template
<
typename
GUM_SCALAR
>
663
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
endFactorizedEntry
() {
664
if
(
state
() !=
factory_state
::
FACT_ENTRY
) {
665
illegalStateError__
(
"endFactorizedEntry"
);
666
}
else
{
667
delete
parents__
;
668
parents__
= 0;
669
states__
.
pop_back
();
670
}
671
}
672
673
// Tells the factory on which modality we want to instantiate one of
674
// variable's parent.
675
template
<
typename
GUM_SCALAR
>
676
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setParentModality
(
677
const
std
::
string
&
parent
,
678
const
std
::
string
&
modality
) {
679
if
(
state
() !=
factory_state
::
FACT_ENTRY
) {
680
illegalStateError__
(
"string"
);
681
}
else
{
682
checkVariableName__
(
parent
);
683
Idx
id
=
checkVariableModality__
(
parent
,
modality
);
684
(*
parents__
) <<
bn__
->
variable
(
varNameMap__
[
parent
]);
685
parents__
->
chgVal
(
bn__
->
variable
(
varNameMap__
[
parent
]),
id
);
686
}
687
}
688
689
// @brief Gives the values of the variable with respect to precedent
690
// parents modality.
691
// If some parents have no modality set, then we apply values for all
692
// instantiations of that parent.
693
//
694
// This means you can declare a default value for the table by doing
695
// @code
696
// BayesNetFactory factory;
697
// // Do stuff
698
// factory.startVariableDeclaration();
699
// factory.variableName("foo");
700
// factory.endVariableDeclaration();
701
// factory.startParentsDeclaration("foo");
702
// // add parents
703
// factory.endParentsDeclaration();
704
// factory.startFactorizedProbabilityDeclaration("foo");
705
// std::vector<float> seq;
706
// seq.insert(0.4); // if foo true
707
// seq.insert(O.6); // if foo false
708
// factory.setVariableValues(seq); // fills the table with a default value
709
// // finish your stuff
710
// factory.endFactorizedProbabilityDeclaration();
711
// @code
712
// as for rawProba, if value's size is different than the number of modalities
713
// of
714
// the current variable,
715
// we don't use the supplementary values and we fill by 0 the missign values.
716
template
<
typename
GUM_SCALAR
>
717
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setVariableValuesUnchecked
(
718
const
std
::
vector
<
float
>&
values
) {
719
if
(
state
() !=
factory_state
::
FACT_ENTRY
) {
720
illegalStateError__
(
"setVariableValues"
);
721
}
else
{
722
const
DiscreteVariable
&
var
=
bn__
->
variable
(
varNameMap__
[
stringBag__
[0]]);
723
NodeId
varId
=
varNameMap__
[
stringBag__
[0]];
724
725
if
(
parents__
->
domainSize
() > 0) {
726
Instantiation
inst
(
bn__
->
cpt
(
varNameMap__
[
var
.
name
()]));
727
inst
.
setVals
(*
parents__
);
728
// Creating an instantiation containing all the variables not ins
729
// parents__.
730
Instantiation
inst_default
;
731
inst_default
<<
var
;
732
733
for
(
auto
node
:
bn__
->
parents
(
varId
)) {
734
if
(!
parents__
->
contains
(
bn__
->
variable
(
node
))) {
735
inst_default
<<
bn__
->
variable
(
node
);
736
}
737
}
738
739
// Filling the variable's table.
740
for
(
inst
.
setFirstIn
(
inst_default
); !
inst
.
end
();
741
inst
.
incIn
(
inst_default
)) {
742
(
bn__
->
cpt
(
varId
))
743
.
set
(
inst
,
744
inst
.
val
(
var
) <
values
.
size
() ? (
GUM_SCALAR
)
values
[
inst
.
val
(
var
)]
745
: (
GUM_SCALAR
)0);
746
}
747
}
else
{
748
Instantiation
inst
(
bn__
->
cpt
(
varNameMap__
[
var
.
name
()]));
749
Instantiation
var_inst
;
750
var_inst
<<
var
;
751
752
for
(
var_inst
.
setFirst
(); !
var_inst
.
end
(); ++
var_inst
) {
753
inst
.
setVals
(
var_inst
);
754
755
for
(
inst
.
setFirstOut
(
var_inst
); !
inst
.
end
();
inst
.
incOut
(
var_inst
)) {
756
(
bn__
->
cpt
(
varId
))
757
.
set
(
inst
,
758
inst
.
val
(
var
) <
values
.
size
()
759
? (
GUM_SCALAR
)
values
[
inst
.
val
(
var
)]
760
: (
GUM_SCALAR
)0);
761
}
762
}
763
}
764
}
765
}
766
767
template
<
typename
GUM_SCALAR
>
768
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setVariableValues
(
769
const
std
::
vector
<
float
>&
values
) {
770
if
(
state
() !=
factory_state
::
FACT_ENTRY
) {
771
illegalStateError__
(
"setVariableValues"
);
772
}
else
{
773
const
DiscreteVariable
&
var
=
bn__
->
variable
(
varNameMap__
[
stringBag__
[0]]);
774
// Checking consistency between values and var.
775
776
if
(
values
.
size
() !=
var
.
domainSize
()) {
777
GUM_ERROR
(
OperationNotAllowed
,
778
var
.
name
()
779
<<
" : invalid number of modalities: found "
<<
values
.
size
()
780
<<
" while needed "
<<
var
.
domainSize
());
781
}
782
783
setVariableValuesUnchecked
(
values
);
784
}
785
}
786
787
// Tells the factory that we finished declaring a conditional probability
788
// table.
789
template
<
typename
GUM_SCALAR
>
790
INLINE
void
791
BayesNetFactory
<
GUM_SCALAR
>::
endFactorizedProbabilityDeclaration
() {
792
if
(
state
() !=
factory_state
::
FACT_CPT
) {
793
illegalStateError__
(
"endFactorizedProbabilityDeclaration"
);
794
}
else
{
795
resetParts__
();
796
states__
.
pop_back
();
797
}
798
}
799
800
// @brief Define a variable.
801
// You can only call this method is the factory is in the NONE or NETWORK
802
// state.
803
// The variable is added by copy.
804
// @param var The pointer over a DiscreteVariable used to define a new
805
// variable in the built BayesNet.
806
// @throw DuplicateElement Raised if a variable with the same name already
807
// exists.
808
// @throw OperationNotAllowed Raised if redefineParents == false and if table
809
// is not a valid CPT for var in the current state
810
// of the BayesNet.
811
template
<
typename
GUM_SCALAR
>
812
INLINE
void
813
BayesNetFactory
<
GUM_SCALAR
>::
setVariable
(
const
DiscreteVariable
&
var
) {
814
if
((
state
() !=
factory_state
::
NONE
)) {
815
illegalStateError__
(
"setVariable"
);
816
}
else
{
817
try
{
818
checkVariableName__
(
var
.
name
());
819
GUM_ERROR
(
DuplicateElement
,
"Name already used: "
<<
var
.
name
());
820
}
catch
(
NotFound
&) {
821
// The var name is unused
822
varNameMap__
.
insert
(
var
.
name
(),
bn__
->
add
(
var
));
823
}
824
}
825
}
826
827
// @brief Define a variable's CPT.
828
// You can only call this method if the factory is in the NONE or NETWORK
829
// state.
830
// Be careful that table is given to the built BayesNet, so it will be
831
// deleted with it, and you should not directly access it after you call
832
// this method.
833
// When the redefineParents flag is set to true the constructed BayesNet's
834
// DAG is changed to fit with table's definition.
835
// @param var The name of the concerned variable.
836
// @param table A pointer over the CPT used for var.
837
// @param redefineParents If true redefine var's parents to match table's
838
// variables set.
839
//
840
// @throw NotFound Raised if no variable matches var.
841
// @throw OperationNotAllowed Raised if redefineParents == false and if table
842
// is not a valid CPT for var in the current state
843
// of the BayesNet.
844
template
<
typename
GUM_SCALAR
>
845
INLINE
void
846
BayesNetFactory
<
GUM_SCALAR
>::
setVariableCPT
(
const
std
::
string
&
varName
,
847
MultiDimAdressable
*
table
,
848
bool
redefineParents
) {
849
auto
pot
=
dynamic_cast
<
Potential
<
GUM_SCALAR
>* >(
table
);
850
851
if
(
state
() !=
factory_state
::
NONE
) {
852
illegalStateError__
(
"setVariableCPT"
);
853
}
else
{
854
checkVariableName__
(
varName
);
855
const
DiscreteVariable
&
var
=
bn__
->
variable
(
varNameMap__
[
varName
]);
856
NodeId
varId
=
varNameMap__
[
varName
];
857
// If we have to change the structure of the BayesNet, then we call a sub
858
// method.
859
860
if
(
redefineParents
) {
861
setCPTAndParents__
(
var
,
pot
);
862
}
else
if
(
pot
->
contains
(
var
)) {
863
for
(
auto
node
:
bn__
->
parents
(
varId
)) {
864
if
(!
pot
->
contains
(
bn__
->
variable
(
node
))) {
865
GUM_ERROR
(
OperationNotAllowed
,
866
"The CPT is not valid in the current BayesNet."
);
867
}
868
}
869
870
// CPT are created when a variable is added.
871
bn__
->
unsafeChangePotential_
(
varId
,
pot
);
872
}
873
}
874
}
875
876
// Copy operator is illegal, use only copy constructor.
877
template
<
typename
GUM_SCALAR
>
878
INLINE
BayesNetFactory
<
GUM_SCALAR
>&
BayesNetFactory
<
GUM_SCALAR
>::
operator
=(
879
const
BayesNetFactory
<
GUM_SCALAR
>&
source
) {
880
GUM_ERROR
(
OperationNotAllowed
,
"Illegal!"
);
881
// For noisy compilers
882
return
*
this
;
883
}
884
885
// Raise an OperationNotAllowed with the message "Illegal state."
886
template
<
typename
GUM_SCALAR
>
887
INLINE
void
888
BayesNetFactory
<
GUM_SCALAR
>::
illegalStateError__
(
const
std
::
string
&
s
) {
889
std
::
string
msg
=
"Illegal state call ("
;
890
msg
+=
s
;
891
msg
+=
") in state "
;
892
893
switch
(
state
()) {
894
case
factory_state
::
NONE
: {
895
msg
+=
"NONE"
;
896
break
;
897
}
898
899
case
factory_state
::
NETWORK
: {
900
msg
+=
"NETWORK"
;
901
break
;
902
}
903
904
case
factory_state
::
VARIABLE
: {
905
msg
+=
"VARIABLE"
;
906
break
;
907
}
908
909
case
factory_state
::
PARENTS
: {
910
msg
+=
"PARENTS"
;
911
break
;
912
}
913
914
case
factory_state
::
RAW_CPT
: {
915
msg
+=
"RAW_CPT"
;
916
break
;
917
}
918
919
case
factory_state
::
FACT_CPT
: {
920
msg
+=
"FACT_CPT"
;
921
break
;
922
}
923
924
case
factory_state
::
FACT_ENTRY
: {
925
msg
+=
"FACT_ENTRY"
;
926
break
;
927
}
928
929
default
: {
930
msg
+=
"Unknown state"
;
931
}
932
}
933
934
GUM_ERROR
(
OperationNotAllowed
,
msg
);
935
}
936
937
// Check if a variable with the given name exists, if not raise an NotFound
938
// exception.
939
template
<
typename
GUM_SCALAR
>
940
INLINE
void
941
BayesNetFactory
<
GUM_SCALAR
>::
checkVariableName__
(
const
std
::
string
&
name
) {
942
if
(!
varNameMap__
.
exists
(
name
)) {
GUM_ERROR
(
NotFound
,
name
); }
943
}
944
945
// Check if var exists and if mod is one of it's modality, if not raise an
946
// NotFound exception.
947
template
<
typename
GUM_SCALAR
>
948
INLINE
Idx
BayesNetFactory
<
GUM_SCALAR
>::
checkVariableModality__
(
949
const
std
::
string
&
name
,
950
const
std
::
string
&
mod
) {
951
checkVariableName__
(
name
);
952
const
DiscreteVariable
&
var
=
bn__
->
variable
(
varNameMap__
[
name
]);
953
954
for
(
Idx
i
= 0;
i
<
var
.
domainSize
(); ++
i
) {
955
if
(
mod
==
var
.
label
(
i
)) {
return
i
; }
956
}
957
958
GUM_ERROR
(
NotFound
,
mod
);
959
}
960
961
// Check if in stringBag__ there is no other modality with the same name.
962
template
<
typename
GUM_SCALAR
>
963
INLINE
void
964
BayesNetFactory
<
GUM_SCALAR
>::
checkModalityInBag__
(
const
std
::
string
&
mod
) {
965
for
(
size_t
i
= 3;
i
<
stringBag__
.
size
(); ++
i
) {
966
if
(
mod
==
stringBag__
[
i
]) {
967
GUM_ERROR
(
DuplicateElement
,
"Label already used: "
<<
mod
);
968
}
969
}
970
}
971
972
// Sub method of setVariableCPT() which redefine the BayesNet's DAG with
973
// respect to table.
974
template
<
typename
GUM_SCALAR
>
975
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
setCPTAndParents__
(
976
const
DiscreteVariable
&
var
,
977
Potential
<
GUM_SCALAR
>*
table
) {
978
NodeId
varId
=
varNameMap__
[
var
.
name
()];
979
bn__
->
dag_
.
eraseParents
(
varId
);
980
981
for
(
auto
v
:
table
->
variablesSequence
()) {
982
if
(
v
!= (&
var
)) {
983
checkVariableName__
(
v
->
name
());
984
bn__
->
dag_
.
addArc
(
varNameMap__
[
v
->
name
()],
varId
);
985
}
986
}
987
988
// CPT are created when a variable is added.
989
bn__
->
unsafeChangePotential_
(
varId
,
table
);
990
}
991
992
// Reset the different parts used to constructed the BayesNet.
993
template
<
typename
GUM_SCALAR
>
994
INLINE
void
BayesNetFactory
<
GUM_SCALAR
>::
resetParts__
() {
995
foo_flag__
=
false
;
996
bar_flag__
=
false
;
997
stringBag__
.
clear
();
998
}
999
}
/* namespace gum */
VERBOSITY
#define VERBOSITY(x)
Definition:
BayesNetFactory_tpl.h:32
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669