aGrUM
0.20.2
a C++ library for (probabilistic) graphical models
O3prmrInterpreter.cpp
Go to the documentation of this file.
1
/**
2
*
3
* Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4
* info_at_agrum_dot_org
5
*
6
* This library is free software: you can redistribute it and/or modify
7
* it under the terms of the GNU Lesser General Public License as published by
8
* the Free Software Foundation, either version 3 of the License, or
9
* (at your option) any later version.
10
*
11
* This library is distributed in the hope that it will be useful,
12
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
* GNU Lesser General Public License for more details.
15
*
16
* You should have received a copy of the GNU Lesser General Public License
17
* along with this library. If not, see <http://www.gnu.org/licenses/>.
18
*
19
*/
20
21
22
/**
23
* @file
24
* @brief Implementation of O3prmrInterpreter.
25
*
26
* @author Pierre-Henri WUILLEMIN(@LIP6), Ni NI, Lionel TORTI & Vincent RENAUDINEAU
27
*/
28
29
#
include
<
agrum
/
agrum
.
h
>
30
31
#
include
<
agrum
/
BN
/
BayesNet
.
h
>
32
#
include
<
agrum
/
BN
/
inference
/
lazyPropagation
.
h
>
33
#
include
<
agrum
/
BN
/
inference
/
tools
/
BayesNetInference
.
h
>
34
#
include
<
agrum
/
BN
/
inference
/
variableElimination
.
h
>
35
36
#
include
<
agrum
/
PRM
/
inference
/
SVE
.
h
>
37
#
include
<
agrum
/
PRM
/
inference
/
SVED
.
h
>
38
#
include
<
agrum
/
PRM
/
inference
/
groundedInference
.
h
>
39
#
include
<
agrum
/
PRM
/
o3prmr
/
O3prmrInterpreter
.
h
>
40
41
#
include
<
agrum
/
PRM
/
o3prmr
/
cocoR
/
Parser
.
h
>
42
43
namespace
gum
{
44
45
namespace
prm
{
46
47
namespace
o3prmr
{
48
49
/* **************************************************************************
50
*/
51
52
/// This constructor create an empty context.
53
O3prmrInterpreter
::
O3prmrInterpreter
() :
54
m_context
(
new
O3prmrContext
<
double
>()),
55
m_reader
(
new
o3prm
::
O3prmReader
<
double
>()),
m_bn
(0),
m_inf
(0),
56
m_syntax_flag
(
false
),
m_verbose
(
false
),
m_log
(
std
::
cout
),
57
m_current_line
(-1) {}
58
59
/// Destructor. Delete current context.
60
O3prmrInterpreter
::~
O3prmrInterpreter
() {
61
delete
m_context
;
62
if
(
m_bn
) {
delete
m_bn
; }
63
for
(
auto
p
:
m_inf_map
) {
64
delete
p
.
second
;
65
}
66
delete
m_reader
->
prm
();
67
delete
m_reader
;
68
}
69
70
/* **************************************************************************
71
*/
72
73
/// Getter for the context.
74
O3prmrContext
<
double
>*
O3prmrInterpreter
::
getContext
()
const
{
75
return
m_context
;
76
}
77
78
/// Setter for the context.
79
void
O3prmrInterpreter
::
setContext
(
O3prmrContext
<
double
>*
context
) {
80
delete
m_context
;
81
82
if
(
context
== 0)
83
m_context
=
new
O3prmrContext
<
double
>();
84
else
85
m_context
=
context
;
86
}
87
88
/// Root paths to search from there packages.
89
/// Default are './' and one is calculate from request package if any.
90
std
::
vector
<
std
::
string
>
O3prmrInterpreter
::
getPaths
()
const
{
91
return
m_paths
;
92
}
93
94
/// Root paths to search from there packages.
95
/// Default are './' and one is calculate from request package if any.
96
void
O3prmrInterpreter
::
addPath
(
std
::
string
path
) {
97
if
(
path
.
length
() &&
path
.
back
() !=
'/'
) {
path
=
path
+
'/'
; }
98
if
(
Directory
::
isDir
(
path
)) {
99
m_paths
.
push_back
(
path
);
100
}
else
{
101
GUM_ERROR
(
NotFound
,
"not a directory"
);
102
}
103
}
104
105
/// Root paths to search from there packages.
106
/// Default are './' and one is calculate from request package if any.
107
void
O3prmrInterpreter
::
clearPaths
() {
m_paths
.
clear
(); }
108
109
/// syntax mode don't process anything, just check syntax.
110
bool
O3prmrInterpreter
::
isInSyntaxMode
()
const
{
return
m_syntax_flag
; }
111
112
/// syntax mode don't process anything, just check syntax.
113
void
O3prmrInterpreter
::
setSyntaxMode
(
bool
f
) {
m_syntax_flag
=
f
; }
114
115
/// verbose mode show more details on the program execution.
116
bool
O3prmrInterpreter
::
isVerboseMode
()
const
{
return
m_verbose
; }
117
118
/// verbose mode show more details on the program execution.
119
void
O3prmrInterpreter
::
setVerboseMode
(
bool
f
) {
m_verbose
=
f
; }
120
121
/// Retrieve prm object.
122
const
PRM
<
double
>*
O3prmrInterpreter
::
prm
()
const
{
123
return
m_reader
->
prm
();
124
}
125
126
/// Retrieve inference motor object.
127
const
PRMInference
<
double
>*
O3prmrInterpreter
::
inference
()
const
{
128
return
m_inf
;
129
}
130
131
/// Return a std::vector of QueryResults.
132
/// Each QueryResults is a struct with query command, time and values,
133
/// a std::vector of struct SingleResult, with pair label/value.
134
const
std
::
vector
<
QueryResult
>&
O3prmrInterpreter
::
results
()
const
{
135
return
m_results
;
136
}
137
138
/**
139
* Parse the file or the command line.
140
* If errors occured, return false. Errors messages can be retrieve be
141
* getErrorsContainer() methods.
142
* If any errors occured, return true.
143
* Requests results can be retrieve be results() methods.
144
* */
145
bool
O3prmrInterpreter
::
interpretFile
(
const
std
::
string
&
filename
) {
146
m_results
.
clear
();
147
148
try
{
149
std
::
string
file_content
=
readFile__
(
filename
);
150
151
delete
m_context
;
152
m_context
=
new
O3prmrContext
<
double
>(
filename
);
153
O3prmrContext
<
double
>
c
(
filename
);
154
155
// On vérifie la syntaxe
156
unsigned
char
*
buffer
=
new
unsigned
char
[
file_content
.
length
() + 1];
157
strcpy
((
char
*)
buffer
,
file_content
.
c_str
());
158
Scanner
s
(
buffer
,
int
(
file_content
.
length
() + 1));
159
Parser
p
(&
s
);
160
p
.
setO3prmrContext
(&
c
);
161
p
.
Parse
();
162
163
m_errors
=
p
.
errors
();
164
165
if
(
errors
() > 0) {
return
false
; }
166
167
// Set paths to search from.
168
delete
m_reader
->
prm
();
169
delete
m_reader
;
170
m_reader
=
new
o3prm
::
O3prmReader
<
double
>();
171
172
for
(
size_t
i
= 0;
i
<
m_paths
.
size
();
i
++) {
173
m_reader
->
addClassPath
(
m_paths
[
i
]);
174
}
175
176
// On vérifie la sémantique.
177
if
(!
checkSemantic
(&
c
)) {
return
false
; }
178
179
if
(
isInSyntaxMode
()) {
180
return
true
;
181
}
else
{
182
return
interpret
(&
c
);
183
}
184
}
catch
(
gum
::
Exception
&) {
return
false
; }
185
}
186
187
std
::
string
O3prmrInterpreter
::
readFile__
(
const
std
::
string
&
file
) {
188
// read entire file into string
189
std
::
ifstream
istream
(
file
,
std
::
ifstream
::
binary
);
190
if
(
istream
) {
191
// get length of file:
192
istream
.
seekg
(0,
istream
.
end
);
193
int
length
=
int
(
istream
.
tellg
());
194
istream
.
seekg
(0,
istream
.
beg
);
195
196
std
::
string
str
;
197
str
.
resize
(
length
,
' '
);
// reserve space
198
char
*
begin
= &*
str
.
begin
();
199
200
istream
.
read
(
begin
,
length
);
201
istream
.
close
();
202
203
return
str
;
204
}
205
GUM_ERROR
(
OperationNotAllowed
,
"Could not open file"
);
206
}
207
208
bool
O3prmrInterpreter
::
interpretLine
(
const
std
::
string
&
line
) {
209
m_results
.
clear
();
210
211
// On vérifie la syntaxe
212
O3prmrContext
<
double
>
c
;
213
Scanner
s
((
unsigned
char
*)
line
.
c_str
(), (
int
)
line
.
length
());
214
Parser
p
(&
s
);
215
p
.
setO3prmrContext
(&
c
);
216
p
.
Parse
();
217
m_errors
=
p
.
errors
();
218
219
if
(
errors
() > 0)
return
false
;
220
221
// On vérifie la sémantique.
222
if
(!
checkSemantic
(&
c
))
return
false
;
223
224
if
(
isInSyntaxMode
())
225
return
true
;
226
else
227
return
interpret
(&
c
);
228
}
229
230
/**
231
* Crée le prm correspondant au contexte courant.
232
* Renvoie true en cas de succès, ou false en cas échéant d'échec
233
* de l'interprétation du contexte (import introuvable ou non défini,
234
* etc).
235
* */
236
bool
O3prmrInterpreter
::
interpret
(
O3prmrContext
<
double
>*
c
) {
237
if
(
isVerboseMode
())
238
m_log
<<
"## Start interpretation."
<<
std
::
endl
<<
std
::
flush
;
239
240
// Don't parse if any syntax errors.
241
if
(
errors
() > 0)
return
false
;
242
243
// For each session
244
std
::
vector
<
O3prmrSession
<
double
>* >
sessions
=
c
->
sessions
();
245
246
for
(
const
auto
session
:
sessions
)
247
for
(
auto
command
:
session
->
commands
()) {
248
// We process it.
249
bool
result
=
true
;
250
251
try
{
252
switch
(
command
->
type
()) {
253
case
O3prmrCommand
::
RequestType
::
Observe
:
254
result
=
observe
((
ObserveCommand
<
double
>*)
command
);
255
break
;
256
257
case
O3prmrCommand
::
RequestType
::
Unobserve
:
258
result
=
unobserve
((
UnobserveCommand
<
double
>*)
command
);
259
break
;
260
261
case
O3prmrCommand
::
RequestType
::
SetEngine
:
262
setEngine
((
SetEngineCommand
*)
command
);
263
break
;
264
265
case
O3prmrCommand
::
RequestType
::
SetGndEngine
:
266
setGndEngine
((
SetGndEngineCommand
*)
command
);
267
break
;
268
269
case
O3prmrCommand
::
RequestType
::
Query
:
270
query
((
QueryCommand
<
double
>*)
command
);
271
break
;
272
}
273
}
catch
(
Exception
&
err
) {
274
result
=
false
;
275
addError
(
err
.
errorContent
());
276
}
catch
(
std
::
string
&
err
) {
277
result
=
false
;
278
addError
(
err
);
279
}
280
281
// If there was a problem, skip the rest of this session,
282
// unless syntax mode is activated.
283
if
(!
result
) {
284
if
(
m_verbose
)
285
m_log
<<
"Errors : skip the rest of this session."
<<
std
::
endl
;
286
287
break
;
288
}
289
}
290
291
if
(
isVerboseMode
())
292
m_log
<<
"## End interpretation."
<<
std
::
endl
<<
std
::
flush
;
293
294
return
errors
() == 0;
295
}
296
297
/* **************************************************************************
298
*/
299
300
/**
301
* Check semantic validity of context.
302
* Import first all import, and check that systems, instances, attributes
303
*and
304
*labels exists.
305
* While checking, prepare data structures for interpretation.
306
* Return true if all is right, false otherwise.
307
*
308
* Note : Stop checking at first error unless syntax mode is activated.
309
* */
310
bool
O3prmrInterpreter
::
checkSemantic
(
O3prmrContext
<
double
>*
context
) {
311
// Don't parse if any syntax errors.
312
if
(
errors
() > 0)
return
false
;
313
314
// On importe tous les systèmes.
315
for
(
const
auto
command
:
context
->
imports
()) {
316
m_current_line
=
command
->
line
;
317
// if import doen't succed stop here unless syntax mode is activated.
318
bool
succeed
=
import
(
context
,
command
->
value
);
319
320
if
(!
succeed
&& !
isInSyntaxMode
())
return
false
;
321
322
// En cas de succès, on met à jour le contexte global
323
if
(
succeed
)
m_context
->
addImport
(*
command
);
324
}
325
326
if
(
m_verbose
)
327
m_log
<<
"## Check semantic for "
<<
context
->
sessions
().
size
()
328
<<
" sessions"
<<
std
::
endl
;
329
330
// On vérifie chaque session
331
for
(
const
auto
session
:
context
->
sessions
()) {
332
std
::
string
sessionName
=
session
->
name
();
333
O3prmrSession
<
double
>*
new_session
334
=
new
O3prmrSession
<
double
>(
sessionName
);
335
336
if
(
m_verbose
)
337
m_log
<<
"## Start session '"
<<
sessionName
<<
"'..."
<<
std
::
endl
338
<<
std
::
endl
;
339
340
for
(
const
auto
command
:
session
->
commands
()) {
341
if
(
m_verbose
)
342
m_log
<<
"# * Going to check command : "
<<
command
->
toString
()
343
<<
std
::
endl
;
344
345
// Update the current line (for warnings and errors)
346
m_current_line
=
command
->
line
;
347
348
// We check it.
349
bool
result
=
true
;
350
351
try
{
352
switch
(
command
->
type
()) {
353
case
O3prmrCommand
::
RequestType
::
SetEngine
:
354
result
=
checkSetEngine
((
SetEngineCommand
*)
command
);
355
break
;
356
357
case
O3prmrCommand
::
RequestType
::
SetGndEngine
:
358
result
=
checkSetGndEngine
((
SetGndEngineCommand
*)
command
);
359
break
;
360
361
case
O3prmrCommand
::
RequestType
::
Observe
:
362
result
=
checkObserve
((
ObserveCommand
<
double
>*)
command
);
363
break
;
364
365
case
O3prmrCommand
::
RequestType
::
Unobserve
:
366
result
=
checkUnobserve
((
UnobserveCommand
<
double
>*)
command
);
367
break
;
368
369
case
O3prmrCommand
::
RequestType
::
Query
:
370
result
=
checkQuery
((
QueryCommand
<
double
>*)
command
);
371
break
;
372
373
default
:
374
addError
(
"Error : Unknow command : "
+
command
->
toString
()
375
+
"\n -> Command not processed."
);
376
result
=
false
;
377
}
378
}
catch
(
Exception
&
err
) {
379
result
=
false
;
380
addError
(
err
.
errorContent
());
381
}
catch
(
std
::
string
&
err
) {
382
result
=
false
;
383
addError
(
err
);
384
}
385
386
// If there was a problem, skip the rest of this session,
387
// unless syntax mode is activated.
388
if
(!
result
&& !
isInSyntaxMode
()) {
389
if
(
m_verbose
)
390
m_log
<<
"Errors : skip the rest of this session."
<<
std
::
endl
;
391
392
break
;
393
}
394
395
// On l'ajoute au contexte globale
396
if
(
result
)
new_session
->
addCommand
((
const
O3prmrCommand
*)
command
);
397
}
398
399
// Ajoute la session au contexte global,
400
// ou à la dernière session.
401
if
(
sessionName
==
"default"
&&
m_context
->
sessions
().
size
() > 0)
402
*(
m_context
->
sessions
().
back
()) += *
new_session
;
403
else
404
m_context
->
addSession
(*
new_session
);
405
406
if
(
m_verbose
)
407
m_log
<<
std
::
endl
408
<<
"## Session '"
<<
sessionName
<<
"' finished."
<<
std
::
endl
409
<<
std
::
endl
410
<<
std
::
endl
;
411
412
// todo : check memory leak
413
// delete new_session; ??
414
}
415
416
if
(
isVerboseMode
() &&
errors
() != 0)
417
m_errors
.
elegantErrorsAndWarnings
(
m_log
);
418
419
return
errors
() == 0;
420
}
421
422
/// Records the PRM inference engine requested by @p command and tells whether
/// it is a known one ("SVED", "GRD" or "SVE").
/// Note that m_engine is updated even when the name is not recognised.
bool O3prmrInterpreter::checkSetEngine(SetEngineCommand* command) {
  m_engine = command->value;

  if (m_engine == "SVED") { return true; }
  if (m_engine == "GRD") { return true; }
  return m_engine == "SVE";
}
426
427
/// Records the grounded (Bayesian network) inference engine requested by
/// @p command and tells whether it is a known one ("VE", "VEBB" or "lazy").
/// Note that m_bn_engine is updated even when the name is not recognised.
bool O3prmrInterpreter::checkSetGndEngine(SetGndEngineCommand* command) {
  m_bn_engine = command->value;

  if (m_bn_engine == "VE") { return true; }
  if (m_bn_engine == "VEBB") { return true; }
  return m_bn_engine == "lazy";
}
432
433
/// Checks an "observe" command and prepares it for execution.
///
/// The left value is resolved into a system, an instance and an attribute,
/// which are stored in the command (system and chain). The command's
/// potential is then filled over the attribute's variable: 1 for the label
/// equal to the right value, 0 for every other label.
///
/// @param command the command to check; it is updated in place.
/// @return true if the right value is a label of the attribute. false
/// otherwise, or if resolving the left value failed; an error is recorded
/// in both cases.
bool O3prmrInterpreter::checkObserve(ObserveCommand<double>* command) {
  try {
    // 'target' is stripped of its system and instance parts while resolving.
    std::string       target = command->leftValue;
    const std::string label  = command->rightValue;

    // Build the (instance, attribute) pair.
    const PRMSystem<double>&   sys = system(target);
    const std::string          instance_name = findInstanceName(target, sys);
    const PRMInstance<double>& instance = sys.get(instance_name);
    const std::string          attr_name = findAttributeName(target, instance);
    const PRMAttribute<double>& attr = instance.get(attr_name);

    command->system = &sys;
    command->chain  = std::make_pair(&instance, &attr);

    // Check that the label exists for this type, filling the potential on
    // the way.
    const auto& variable = attr.type().variable();
    command->potentiel.add(variable);

    Instantiation inst(command->potentiel);
    bool          found = false;

    for (inst.setFirst(); !inst.end(); inst.inc()) {
      const bool matches = (variable.label(inst.val(variable)) == label);
      command->potentiel.set(inst, matches ? 1.0 : 0.0);
      if (matches) { found = true; }
    }

    if (!found) { addError(label + " is not a label of " + target); }

    return found;
  } catch (Exception& err) {
    addError(err.errorContent());
    return false;
  } catch (std::string& err) {
    addError(err);
    return false;
  }
}
479
480
/// Checks an "unobserve" command and prepares it for execution.
///
/// The command's value is resolved into a system, an instance and an
/// attribute, which are stored in the command (system and chain).
///
/// @param command the command to check; it is updated in place.
/// @return true on success; false, with an error recorded, if the value
/// could not be resolved.
bool O3prmrInterpreter::checkUnobserve(UnobserveCommand<double>* command) {
  try {
    // 'target' is stripped of its system and instance parts while resolving.
    std::string target = command->value;

    // Build the (instance, attribute) pair.
    const PRMSystem<double>&    sys = system(target);
    const std::string           instance_name = findInstanceName(target, sys);
    const PRMInstance<double>&  instance = sys.get(instance_name);
    const std::string           attr_name = findAttributeName(target, instance);
    const PRMAttribute<double>& attr = instance.get(attr_name);

    command->system = &sys;
    command->chain  = std::make_pair(&instance, &attr);

    return true;
  } catch (Exception& err) {
    addError(err.errorContent());
    return false;
  } catch (std::string& err) {
    addError(err);
    return false;
  }
}
504
505
/// Checks a "query" command and prepares it for execution.
///
/// The command's value is resolved into a system, an instance and an
/// attribute, which are stored in the command (system and chain).
///
/// @param command the command to check; it is updated in place.
/// @return true on success; false, with an error recorded, if the value
/// could not be resolved.
bool O3prmrInterpreter::checkQuery(QueryCommand<double>* command) {
  try {
    // 'target' is stripped of its system and instance parts while resolving.
    std::string target = command->value;

    // Build the (instance, attribute) pair.
    const PRMSystem<double>&    sys = system(target);
    const std::string           instance_name = findInstanceName(target, sys);
    const PRMInstance<double>&  instance = sys.get(instance_name);
    const std::string           attr_name = findAttributeName(target, instance);
    const PRMAttribute<double>& attr = instance.get(attr_name);

    command->system = &sys;
    command->chain  = std::make_pair(&instance, &attr);

    return true;
  } catch (Exception& err) {
    addError(err.errorContent());
    return false;
  } catch (std::string& err) {
    addError(err);
    return false;
  }
}
529
530
// Import the system o3prm file
531
// Return false if any error.
532
533
bool
O3prmrInterpreter
::
import
(
O3prmrContext
<
double
>*
context
,
534
std
::
string
import_name
) {
535
try
{
536
if
(
m_verbose
) {
537
m_log
<<
"# Loading system '"
<<
import_name
<<
"' => '"
<<
std
::
flush
;
538
}
539
540
std
::
string
import_package
=
import_name
;
541
542
std
::
replace
(
import_name
.
begin
(),
import_name
.
end
(),
'.'
,
'/'
);
543
import_name
+=
".o3prm"
;
544
545
if
(
m_verbose
) {
546
m_log
<<
import_name
<<
"' ... "
<<
std
::
endl
<<
std
::
flush
;
547
}
548
549
std
::
ifstream
file_test
;
550
bool
found
=
false
;
551
std
::
string
import_abs_filename
;
552
553
// Search in o3prmr file dir.
554
std
::
string
o3prmrFilename
=
context
->
filename
();
555
556
if
(!
o3prmrFilename
.
empty
()) {
557
size_t
index
=
o3prmrFilename
.
find_last_of
(
'/'
);
558
559
if
(
index
!=
std
::
string
::
npos
) {
560
std
::
string
dir
=
o3prmrFilename
.
substr
(0,
index
+ 1);
561
import_abs_filename
=
dir
+
import_name
;
562
563
if
(
m_verbose
) {
564
m_log
<<
"# Search from filedir '"
<<
import_abs_filename
565
<<
"' ... "
<<
std
::
flush
;
566
}
567
568
file_test
.
open
(
import_abs_filename
.
c_str
());
569
570
if
(
file_test
.
is_open
()) {
571
if
(
m_verbose
) {
m_log
<<
"found !"
<<
std
::
endl
<<
std
::
flush
; }
572
573
file_test
.
close
();
574
found
=
true
;
575
}
else
if
(
m_verbose
) {
576
m_log
<<
"not found."
<<
std
::
endl
<<
std
::
flush
;
577
}
578
}
579
}
580
581
// Deduce root path from package name.
582
std
::
string
package
=
context
->
package
();
583
584
if
(!
found
&& !
package
.
empty
()) {
585
std
::
string
root
;
586
587
// if filename is not empty, start from it.
588
std
::
string
filename
=
context
->
filename
();
589
590
if
(!
filename
.
empty
()) {
591
size_t
size
=
filename
.
find_last_of
(
'/'
);
592
593
if
(
size
!=
std
::
string
::
npos
) {
594
root
+=
filename
.
substr
(0,
size
+ 1);
// take with the '/'
595
}
596
}
597
598
//
599
root
+=
"../"
;
600
int
count
= (
int
)
std
::
count
(
package
.
begin
(),
package
.
end
(),
'.'
);
601
602
for
(
int
i
= 0;
i
<
count
;
i
++)
603
root
+=
"../"
;
604
605
import_abs_filename
=
Directory
(
root
).
absolutePath
() +
import_name
;
606
607
if
(
m_verbose
) {
608
m_log
<<
"# Search from package '"
<<
package
<<
"' => '"
609
<<
import_abs_filename
<<
"' ... "
<<
std
::
flush
;
610
}
611
612
file_test
.
open
(
import_abs_filename
.
c_str
());
613
614
if
(
file_test
.
is_open
()) {
615
if
(
m_verbose
) {
m_log
<<
"found !"
<<
std
::
endl
<<
std
::
flush
; }
616
617
file_test
.
close
();
618
found
=
true
;
619
}
else
if
(
m_verbose
) {
620
m_log
<<
"not found."
<<
std
::
endl
<<
std
::
flush
;
621
}
622
}
623
624
// Search import in all paths.
625
for
(
const
auto
&
path
:
m_paths
) {
626
import_abs_filename
=
path
+
import_name
;
627
628
if
(
m_verbose
) {
629
m_log
<<
"# Search from classpath '"
<<
import_abs_filename
630
<<
"' ... "
<<
std
::
flush
;
631
}
632
633
file_test
.
open
(
import_abs_filename
.
c_str
());
634
635
if
(
file_test
.
is_open
()) {
636
if
(
m_verbose
) {
m_log
<<
" found !"
<<
std
::
endl
<<
std
::
flush
; }
637
638
file_test
.
close
();
639
found
=
true
;
640
break
;
641
}
else
if
(
m_verbose
) {
642
m_log
<<
" not found."
<<
std
::
endl
<<
std
::
flush
;
643
}
644
}
645
646
if
(!
found
) {
647
if
(
m_verbose
) {
m_log
<<
"Finished with errors."
<<
std
::
endl
; }
648
649
addError
(
"import not found."
);
650
return
false
;
651
}
652
653
// May throw std::IOError if file does't exist
654
Size
previousO3prmError
=
m_reader
->
errors
();
655
Size
previousO3prmrError
=
errors
();
656
657
try
{
658
m_reader
->
readFile
(
import_abs_filename
,
import_package
);
659
660
// Show errors and warning
661
if
(
m_verbose
662
&& (
m_reader
->
errors
() > (
unsigned
int
)
previousO3prmError
663
||
errors
() >
previousO3prmrError
)) {
664
m_log
<<
"Finished with errors."
<<
std
::
endl
;
665
}
else
if
(
m_verbose
) {
666
m_log
<<
"Finished."
<<
std
::
endl
;
667
}
668
669
}
catch
(
const
IOError
&
err
) {
670
if
(
m_verbose
) {
m_log
<<
"Finished with errors."
<<
std
::
endl
; }
671
672
addError
(
err
.
errorContent
());
673
}
674
675
// Add o3prm errors and warnings to o3prmr errors
676
for
(;
previousO3prmError
<
m_reader
->
errorsContainer
().
count
();
677
previousO3prmError
++) {
678
m_errors
.
add
(
m_reader
->
errorsContainer
().
error
(
previousO3prmError
));
679
}
680
681
return
errors
() ==
previousO3prmrError
;
682
683
}
catch
(
const
Exception
&
err
) {
684
if
(
m_verbose
) {
m_log
<<
"Finished with exceptions."
<<
std
::
endl
; }
685
686
addError
(
err
.
errorContent
());
687
return
false
;
688
}
689
}
690
691
std
::
string
O3prmrInterpreter
::
findSystemName
(
std
::
string
&
s
) {
692
size_t
dot
=
s
.
find_first_of
(
'.'
);
693
std
::
string
name
=
s
.
substr
(0,
dot
);
694
695
// We look first for real system, next for alias.
696
if
(
prm
()->
isSystem
(
name
)) {
697
s
=
s
.
substr
(
dot
+ 1);
698
return
name
;
699
}
700
701
if
(!
m_context
->
aliasToImport
(
name
).
empty
()) {
702
s
=
s
.
substr
(
dot
+ 1);
703
return
m_context
->
aliasToImport
(
name
);
704
}
705
706
while
(
dot
!=
std
::
string
::
npos
) {
707
if
(
prm
()->
isSystem
(
name
)) {
708
s
=
s
.
substr
(
dot
+ 1);
709
return
name
;
710
}
711
712
dot
=
s
.
find
(
'.'
,
dot
+ 1);
713
name
=
s
.
substr
(0,
dot
);
714
}
715
716
throw
"could not find any system in '"
+
s
+
"'."
;
717
}
718
719
std
::
string
720
O3prmrInterpreter
::
findInstanceName
(
std
::
string
&
s
,
721
const
PRMSystem
<
double
>&
sys
) {
722
// We have found system before, so 's' has been stripped.
723
size_t
dot
=
s
.
find_first_of
(
'.'
);
724
std
::
string
name
=
s
.
substr
(0,
dot
);
725
726
if
(!
sys
.
exists
(
name
))
727
throw
"'"
+
name
+
"' is not an instance of system '"
+
sys
.
name
()
728
+
"'."
;
729
730
s
=
s
.
substr
(
dot
+ 1);
731
return
name
;
732
}
733
734
std
::
string
O3prmrInterpreter
::
findAttributeName
(
735
const
std
::
string
&
s
,
736
const
PRMInstance
<
double
>&
instance
) {
737
if
(!
instance
.
exists
(
s
))
738
throw
"'"
+
s
+
"' is not an attribute of instance '"
+
instance
.
name
()
739
+
"'."
;
740
741
return
s
;
742
}
743
744
// After this method, ident doesn't contains the system name anymore.
745
const
PRMSystem
<
double
>&
O3prmrInterpreter
::
system
(
std
::
string
&
ident
) {
746
try
{
747
return
prm
()->
getSystem
(
findSystemName
(
ident
));
748
}
catch
(
const
std
::
string
&) {}
749
750
if
((
m_context
->
mainImport
() != 0)
751
&&
prm
()->
isSystem
(
m_context
->
mainImport
()->
value
))
752
return
prm
()->
getSystem
(
m_context
->
mainImport
()->
value
);
753
754
throw
"could not find any system or alias in '"
+
ident
755
+
"' and no default alias has been set."
;
756
}
757
758
///
759
760
bool
761
O3prmrInterpreter
::
observe
(
const
ObserveCommand
<
double
>*
command
)
try
{
762
const
typename
PRMInference
<
double
>::
Chain
&
chain
=
command
->
chain
;
763
764
// Generate the inference engine if it doesn't exist.
765
if
(!
m_inf
) {
generateInfEngine
(*(
command
->
system
)); }
766
767
// Prevent from something
768
if
(
m_inf
->
hasEvidence
(
chain
))
769
addWarning
(
command
->
leftValue
+
" is already observed"
);
770
771
m_inf
->
addEvidence
(
chain
,
command
->
potentiel
);
772
773
if
(
m_verbose
)
774
m_log
<<
"# Added evidence "
<<
command
->
rightValue
<<
" over attribute "
775
<<
command
->
leftValue
<<
std
::
endl
;
776
777
return
true
;
778
779
}
catch
(
OperationNotAllowed
&
ex
) {
780
addError
(
"something went wrong when adding evidence "
+
command
->
rightValue
781
+
" over "
+
command
->
leftValue
+
" : "
+
ex
.
errorContent
());
782
return
false
;
783
784
}
catch
(
const
std
::
string
&
msg
) {
785
addError
(
msg
);
786
return
false
;
787
}
788
789
///
790
791
bool
O3prmrInterpreter
::
unobserve
(
792
const
UnobserveCommand
<
double
>*
command
)
try
{
793
std
::
string
name
=
command
->
value
;
794
typename
PRMInference
<
double
>::
Chain
chain
=
command
->
chain
;
795
796
// Prevent from something
797
if
(!
m_inf
|| !
m_inf
->
hasEvidence
(
chain
)) {
798
addWarning
(
name
+
" was not observed"
);
799
}
else
{
800
m_inf
->
removeEvidence
(
chain
);
801
802
if
(
m_verbose
)
803
m_log
<<
"# Removed evidence over attribute "
<<
name
<<
std
::
endl
;
804
}
805
806
return
true
;
807
808
}
catch
(
const
std
::
string
&
msg
) {
809
addError
(
msg
);
810
return
false
;
811
}
812
813
///
814
void
O3prmrInterpreter
::
query
(
const
QueryCommand
<
double
>*
command
)
try
{
815
const
std
::
string
&
query
=
command
->
value
;
816
817
if
(
m_inf_map
.
exists
(
command
->
system
)) {
818
m_inf
=
m_inf_map
[
command
->
system
];
819
}
else
{
820
m_inf
=
nullptr
;
821
}
822
823
// Create inference engine if it has not been already created.
824
if
(!
m_inf
) {
generateInfEngine
(*(
command
->
system
)); }
825
826
// Inference
827
if
(
m_verbose
) {
828
m_log
<<
"# Starting inference over query: "
<<
query
<<
"... "
829
<<
std
::
endl
;
830
}
831
832
Timer
timer
;
833
timer
.
reset
();
834
835
Potential
<
double
>
m
;
836
m_inf
->
posterior
(
command
->
chain
,
m
);
837
838
// Compute spent time
839
double
t
=
timer
.
step
();
840
841
if
(
m_verbose
) {
m_log
<<
"Finished."
<<
std
::
endl
; }
842
843
if
(
m_verbose
) {
844
m_log
<<
"# Time in seconds (accuracy ~0.001): "
<<
t
<<
std
::
endl
;
845
}
846
847
// Show results
848
849
if
(
m_verbose
) {
m_log
<<
std
::
endl
; }
850
851
QueryResult
result
;
852
result
.
command
=
query
;
853
result
.
time
=
t
;
854
855
Instantiation
j
(
m
);
856
const
PRMAttribute
<
double
>&
attr
= *(
command
->
chain
.
second
);
857
858
for
(
j
.
setFirst
(); !
j
.
end
();
j
.
inc
()) {
859
// auto label_value = j.val ( attr.type().variable() );
860
auto
label_value
=
j
.
val
(0);
861
std
::
string
label
=
attr
.
type
().
variable
().
label
(
label_value
);
862
float
value
=
float
(
m
.
get
(
j
));
863
864
SingleResult
singleResult
;
865
singleResult
.
label
=
label
;
866
singleResult
.
p
=
value
;
867
868
result
.
values
.
push_back
(
singleResult
);
869
870
if
(
m_verbose
) {
m_log
<<
label
<<
" : "
<<
value
<<
std
::
endl
; }
871
}
872
873
m_results
.
push_back
(
result
);
874
875
if
(
m_verbose
) {
m_log
<<
std
::
endl
; }
876
877
}
catch
(
Exception
&
e
) {
878
GUM_SHOWERROR
(
e
);
879
throw
"something went wrong while infering: "
+
e
.
errorContent
();
880
881
}
catch
(
const
std
::
string
&
msg
) {
addError
(
msg
); }
882
883
///
884
void
O3prmrInterpreter
::
setEngine
(
const
SetEngineCommand
*
command
) {
885
m_engine
=
command
->
value
;
886
}
887
888
///
889
void
O3prmrInterpreter
::
setGndEngine
(
const
SetGndEngineCommand
*
command
) {
890
m_bn_engine
=
command
->
value
;
891
}
892
893
///
894
void
O3prmrInterpreter
::
generateInfEngine
(
const
PRMSystem
<
double
>&
sys
) {
895
if
(
m_verbose
)
896
m_log
<<
"# Building the inference engine... "
<<
std
::
flush
;
897
898
//
899
if
(
m_engine
==
"SVED"
) {
900
m_inf
=
new
SVED
<
double
>(*(
prm
()),
sys
);
901
902
//
903
}
else
if
(
m_engine
==
"SVE"
) {
904
m_inf
=
new
SVE
<
double
>(*(
prm
()),
sys
);
905
906
}
else
{
907
if
(
m_engine
!=
"GRD"
) {
908
addWarning
(
"unkown engine '"
+
m_engine
+
"', use GRD insteed."
);
909
}
910
911
MarginalTargetedInference
<
double
>*
bn_inf
=
nullptr
;
912
if
(
m_bn
) {
delete
m_bn
; }
913
m_bn
=
new
BayesNet
<
double
>();
914
BayesNetFactory
<
double
>
bn_factory
(
m_bn
);
915
916
if
(
m_verbose
)
m_log
<<
"(Grounding the network... "
<<
std
::
flush
;
917
918
sys
.
groundedBN
(
bn_factory
);
919
920
if
(
m_verbose
)
m_log
<<
"Finished)"
<<
std
::
flush
;
921
922
// bn_inf = new LazyPropagation<double>( *m_bn );
923
bn_inf
=
new
VariableElimination
<
double
>(
m_bn
);
924
925
auto
grd_inf
=
new
GroundedInference
<
double
>(*(
prm
()),
sys
);
926
grd_inf
->
setBNInference
(
bn_inf
);
927
m_inf
=
grd_inf
;
928
}
929
930
m_inf_map
.
insert
(&
sys
,
m_inf
);
931
if
(
m_verbose
)
m_log
<<
"Finished."
<<
std
::
endl
;
932
}
933
934
/* **************************************************************************
935
*/
936
937
/// # of errors + warnings
938
Size
O3prmrInterpreter
::
count
()
const
{
return
m_errors
.
count
(); }
939
940
///
941
Size
O3prmrInterpreter
::
errors
()
const
{
return
m_errors
.
error_count
; }
942
943
///
944
Size
O3prmrInterpreter
::
warnings
()
const
{
return
m_errors
.
warning_count
; }
945
946
///
947
ParseError
O3prmrInterpreter
::
error
(
Idx
i
)
const
{
948
if
(
i
>=
count
())
throw
"Index out of bound."
;
949
950
return
m_errors
.
error
(
i
);
951
}
952
953
/// Return container with all errors.
954
ErrorsContainer
O3prmrInterpreter
::
errorsContainer
()
const
{
955
return
m_errors
;
956
}
957
958
///
959
void
O3prmrInterpreter
::
showElegantErrors
(
std
::
ostream
&
o
)
const
{
960
m_errors
.
elegantErrors
(
o
);
961
}
962
963
///
964
void
O3prmrInterpreter
::
showElegantErrorsAndWarnings
(
std
::
ostream
&
o
)
const
{
965
m_errors
.
elegantErrorsAndWarnings
(
o
);
966
}
967
968
///
969
void
O3prmrInterpreter
::
showErrorCounts
(
std
::
ostream
&
o
)
const
{
970
m_errors
.
syntheticResults
(
o
);
971
}
972
973
/* **************************************************************************
974
*/
975
976
///
977
void
O3prmrInterpreter
::
addError
(
std
::
string
msg
) {
978
m_errors
.
addError
(
msg
,
m_context
->
filename
(),
m_current_line
, 0);
979
980
if
(
m_verbose
)
m_log
<<
m_errors
.
last
().
toString
() <<
std
::
endl
;
981
}
982
983
///
984
void
O3prmrInterpreter
::
addWarning
(
std
::
string
msg
) {
985
m_errors
.
addWarning
(
msg
,
m_context
->
filename
(),
m_current_line
, 0);
986
987
if
(
m_verbose
)
m_log
<<
m_errors
.
last
().
toString
() <<
std
::
endl
;
988
}
989
990
}
// namespace o3prmr
991
}
// namespace prm
992
}
// namespace gum
gum::Set::emplace
INLINE void emplace(Args &&... args)
Definition:
set_tpl.h:669
gum::prm::ParamScopeData::ParamScopeData
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)
Definition:
PRMClass_tpl.h:1101
gum::prm::o3prmr
Definition:
O3prmrContext.cpp:35