aGrUM  0.20.2
a C++ library for (probabilistic) graphical models
O3prmrInterpreter.cpp
Go to the documentation of this file.
1 /**
2  *
3  * Copyright 2005-2020 Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation of O3prmrInterpreter.
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6), Ni NI, Lionel TORTI & Vincent RENAUDINEAU
27  */
28 
29 #include <agrum/agrum.h>
30 
31 #include <agrum/BN/BayesNet.h>
32 #include <agrum/BN/inference/lazyPropagation.h>
33 #include <agrum/BN/inference/tools/BayesNetInference.h>
34 #include <agrum/BN/inference/variableElimination.h>
35 
36 #include <agrum/PRM/inference/SVE.h>
37 #include <agrum/PRM/inference/SVED.h>
38 #include <agrum/PRM/inference/groundedInference.h>
39 #include <agrum/PRM/o3prmr/O3prmrInterpreter.h>
40 
41 #include <agrum/PRM/o3prmr/cocoR/Parser.h>
42 
43 namespace gum {
44 
45  namespace prm {
46 
47  namespace o3prmr {
48 
49  /* **************************************************************************
50  */
51 
52  /// This constructor creates an empty context.
54  m_context(new O3prmrContext< double >()),
55  m_reader(new o3prm::O3prmReader< double >()), m_bn(0), m_inf(0),
56  m_syntax_flag(false), m_verbose(false), m_log(std::cout),
57  m_current_line(-1) {}
58 
59  /// Destructor. Deletes the current context, the grounded Bayes net, every inference engine, and the reader together with its PRM.
61  delete m_context;
62  if (m_bn) { delete m_bn; }
63  for (auto p: m_inf_map) {
64  delete p.second;
65  }
66  delete m_reader->prm();
67  delete m_reader;
68  }
69 
70  /* **************************************************************************
71  */
72 
73  /// Getter for the context.
75  return m_context;
76  }
77 
78  /// Setter for the context.
80  delete m_context;
81 
82  if (context == 0)
83  m_context = new O3prmrContext< double >();
84  else
86  }
87 
88  /// Root paths from which packages are searched.
89  /// Defaults are './' and one computed from the requested package, if any.
91  return m_paths;
92  }
93 
94  /// Root paths from which packages are searched.
95  /// Defaults are './' and one computed from the requested package, if any.
97  if (path.length() && path.back() != '/') { path = path + '/'; }
98  if (Directory::isDir(path)) {
100  } else {
101  GUM_ERROR(NotFound, "not a directory");
102  }
103  }
104 
105  /// Root paths from which packages are searched.
106  /// Defaults are './' and one computed from the requested package, if any.
108 
109  /// In syntax mode nothing is processed; only the syntax is checked.
111 
112  /// In syntax mode nothing is processed; only the syntax is checked.
114 
115  /// Tells whether verbose mode is on (extra execution details are written to the log).
116  bool O3prmrInterpreter::isVerboseMode() const { return this->m_verbose; }
117 
118  /// Verbose mode shows more details about the program execution.
120 
121  /// Returns the reader's PRM (freed when the reader is replaced or on destruction).
122  const PRM< double >* O3prmrInterpreter::prm() const {
123  return this->m_reader->prm();
124  }
125 
126  /// Returns the current inference engine (null until one has been generated).
127  const PRMInference< double >* O3prmrInterpreter::inference() const {
128  return this->m_inf;
129  }
130 
131  /// Return a std::vector of QueryResults.
132  /// Each entry holds the query command, its computation time and its values:
133  /// a std::vector of SingleResult structs, each pairing a label with a value.
135  return m_results;
136  }
137 
138  /**
139  * Parse the file or the command line.
140  * If errors occurred, return false. Error messages can be retrieved with the
141  * getErrorsContainer() method.
142  * If no errors occurred, return true.
143  * Request results can be retrieved with the results() method.
144  * */
146  m_results.clear();
147 
148  try {
150 
151  delete m_context;
152  m_context = new O3prmrContext< double >(filename);
153  O3prmrContext< double > c(filename);
154 
155  // On vérifie la syntaxe
156  unsigned char* buffer = new unsigned char[file_content.length() + 1];
157  strcpy((char*)buffer, file_content.c_str());
158  Scanner s(buffer, int(file_content.length() + 1));
159  Parser p(&s);
160  p.setO3prmrContext(&c);
161  p.Parse();
162 
163  m_errors = p.errors();
164 
165  if (errors() > 0) { return false; }
166 
167  // Set paths to search from.
168  delete m_reader->prm();
169  delete m_reader;
170  m_reader = new o3prm::O3prmReader< double >();
171 
172  for (size_t i = 0; i < m_paths.size(); i++) {
174  }
175 
176  // On vérifie la sémantique.
177  if (!checkSemantic(&c)) { return false; }
178 
179  if (isInSyntaxMode()) {
180  return true;
181  } else {
182  return interpret(&c);
183  }
184  } catch (gum::Exception&) { return false; }
185  }
186 
188  // read entire file into string
190  if (istream) {
191  // get length of file:
192  istream.seekg(0, istream.end);
193  int length = int(istream.tellg());
194  istream.seekg(0, istream.beg);
195 
196  std::string str;
197  str.resize(length, ' '); // reserve space
198  char* begin = &*str.begin();
199 
201  istream.close();
202 
203  return str;
204  }
205  GUM_ERROR(OperationNotAllowed, "Could not open file");
206  }
207 
209  m_results.clear();
210 
211  // On vérifie la syntaxe
212  O3prmrContext< double > c;
213  Scanner s((unsigned char*)line.c_str(), (int)line.length());
214  Parser p(&s);
215  p.setO3prmrContext(&c);
216  p.Parse();
217  m_errors = p.errors();
218 
219  if (errors() > 0) return false;
220 
221  // On vérifie la sémantique.
222  if (!checkSemantic(&c)) return false;
223 
224  if (isInSyntaxMode())
225  return true;
226  else
227  return interpret(&c);
228  }
229 
230  /**
231  * Interprets the given context by executing the commands of each session.
232  * Returns true on success, or false if interpreting the context fails
233  * (import not found or not defined,
234  * etc.).
235  * */
237  if (isVerboseMode())
238  m_log << "## Start interpretation." << std::endl << std::flush;
239 
240  // Don't parse if any syntax errors.
241  if (errors() > 0) return false;
242 
243  // For each session
244  std::vector< O3prmrSession< double >* > sessions = c->sessions();
245 
246  for (const auto session: sessions)
247  for (auto command: session->commands()) {
248  // We process it.
249  bool result = true;
250 
251  try {
252  switch (command->type()) {
254  result = observe((ObserveCommand< double >*)command);
255  break;
256 
258  result = unobserve((UnobserveCommand< double >*)command);
259  break;
260 
263  break;
264 
267  break;
268 
270  query((QueryCommand< double >*)command);
271  break;
272  }
273  } catch (Exception& err) {
274  result = false;
276  } catch (std::string& err) {
277  result = false;
278  addError(err);
279  }
280 
281  // If there was a problem, skip the rest of this session,
282  // unless syntax mode is activated.
283  if (!result) {
284  if (m_verbose)
285  m_log << "Errors : skip the rest of this session." << std::endl;
286 
287  break;
288  }
289  }
290 
291  if (isVerboseMode())
292  m_log << "## End interpretation." << std::endl << std::flush;
293 
294  return errors() == 0;
295  }
296 
297  /* **************************************************************************
298  */
299 
300  /**
301  * Check the semantic validity of a context.
302  * First process every import, then check that the referenced systems, instances,
303  * attributes and
304  * labels exist.
305  * While checking, prepare data structures for interpretation.
306  * Return true if all is right, false otherwise.
307  *
308  * Note : Stop checking at first error unless syntax mode is activated.
309  * */
311  // Don't parse if any syntax errors.
312  if (errors() > 0) return false;
313 
314  // On importe tous les systèmes.
315  for (const auto command: context->imports()) {
317  // if import doen't succed stop here unless syntax mode is activated.
318  bool succeed = import(context, command->value);
319 
320  if (!succeed && !isInSyntaxMode()) return false;
321 
322  // En cas de succès, on met à jour le contexte global
324  }
325 
326  if (m_verbose)
327  m_log << "## Check semantic for " << context->sessions().size()
328  << " sessions" << std::endl;
329 
330  // On vérifie chaque session
331  for (const auto session: context->sessions()) {
333  O3prmrSession< double >* new_session
334  = new O3prmrSession< double >(sessionName);
335 
336  if (m_verbose)
337  m_log << "## Start session '" << sessionName << "'..." << std::endl
338  << std::endl;
339 
340  for (const auto command: session->commands()) {
341  if (m_verbose)
342  m_log << "# * Going to check command : " << command->toString()
343  << std::endl;
344 
345  // Update the current line (for warnings and errors)
347 
348  // We check it.
349  bool result = true;
350 
351  try {
352  switch (command->type()) {
355  break;
356 
359  break;
360 
362  result = checkObserve((ObserveCommand< double >*)command);
363  break;
364 
367  break;
368 
370  result = checkQuery((QueryCommand< double >*)command);
371  break;
372 
373  default:
374  addError("Error : Unknow command : " + command->toString()
375  + "\n -> Command not processed.");
376  result = false;
377  }
378  } catch (Exception& err) {
379  result = false;
381  } catch (std::string& err) {
382  result = false;
383  addError(err);
384  }
385 
386  // If there was a problem, skip the rest of this session,
387  // unless syntax mode is activated.
388  if (!result && !isInSyntaxMode()) {
389  if (m_verbose)
390  m_log << "Errors : skip the rest of this session." << std::endl;
391 
392  break;
393  }
394 
395  // On l'ajoute au contexte globale
397  }
398 
399  // Ajoute la session au contexte global,
400  // ou à la dernière session.
401  if (sessionName == "default" && m_context->sessions().size() > 0)
402  *(m_context->sessions().back()) += *new_session;
403  else
405 
406  if (m_verbose)
407  m_log << std::endl
408  << "## Session '" << sessionName << "' finished." << std::endl
409  << std::endl
410  << std::endl;
411 
412  // todo : check memory leak
413  // delete new_session; ??
414  }
415 
416  if (isVerboseMode() && errors() != 0)
418 
419  return errors() == 0;
420  }
421 
424  return m_engine == "SVED" || m_engine == "GRD" || m_engine == "SVE";
425  }
426 
429  return m_bn_engine == "VE" || m_bn_engine == "VEBB"
430  || m_bn_engine == "lazy";
431  }
432 
434  try {
437 
438  // Contruct the pair (instance,attribut)
439  const PRMSystem< double >& sys = system(left_val);
440  const PRMInstance< double >& instance
442  const PRMAttribute< double >& attr
444  typename PRMInference< double >::Chain chain
445  = std::make_pair(&instance, &attr);
446 
447  command->system = &sys;
449 
450  // Check label exists for this type.
451  // Potential<double> e;
454  bool found = false;
455 
456  for (i.setFirst(); !i.end(); i.inc()) {
457  if (chain.second->type().variable().label(
458  i.val(chain.second->type().variable()))
459  == right_val) {
460  command->potentiel.set(i, (double)1.0);
461  found = true;
462  } else {
463  command->potentiel.set(i, (double)0.0);
464  }
465  }
466 
467  if (!found) addError(right_val + " is not a label of " + left_val);
468 
469  // else command->potentiel = e;
470 
471  return found;
472 
473  } catch (Exception& err) {
475  } catch (std::string& err) { addError(err); }
476 
477  return false;
478  }
479 
481  try {
483 
484  // Contruct the pair (instance,attribut)
485  const PRMSystem< double >& sys = system(name);
486  const PRMInstance< double >& instance
488  const PRMAttribute< double >& attr
490  // PRMInference<double>::Chain chain = std::make_pair(&instance,
491  // &attr);
492 
493  command->system = &sys;
495 
496  return true;
497 
498  } catch (Exception& err) {
500  } catch (std::string& err) { addError(err); }
501 
502  return false;
503  }
504 
506  try {
508 
509  // Contruct the pair (instance,attribut)
510  const PRMSystem< double >& sys = system(name);
511  const PRMInstance< double >& instance
513  const PRMAttribute< double >& attr
515  // PRMInference<double>::Chain chain = std::make_pair(&instance,
516  // &attr);
517 
518  command->system = &sys;
520 
521  return true;
522 
523  } catch (Exception& err) {
525  } catch (std::string& err) { addError(err); }
526 
527  return false;
528  }
529 
530  // Import the system o3prm file
531  // Return false if any error.
532 
535  try {
536  if (m_verbose) {
537  m_log << "# Loading system '" << import_name << "' => '" << std::flush;
538  }
539 
541 
542  std::replace(import_name.begin(), import_name.end(), '.', '/');
543  import_name += ".o3prm";
544 
545  if (m_verbose) {
546  m_log << import_name << "' ... " << std::endl << std::flush;
547  }
548 
550  bool found = false;
552 
553  // Search in o3prmr file dir.
555 
556  if (!o3prmrFilename.empty()) {
558 
559  if (index != std::string::npos) {
562 
563  if (m_verbose) {
564  m_log << "# Search from filedir '" << import_abs_filename
565  << "' ... " << std::flush;
566  }
567 
569 
570  if (file_test.is_open()) {
571  if (m_verbose) { m_log << "found !" << std::endl << std::flush; }
572 
573  file_test.close();
574  found = true;
575  } else if (m_verbose) {
576  m_log << "not found." << std::endl << std::flush;
577  }
578  }
579  }
580 
581  // Deduce root path from package name.
583 
584  if (!found && !package.empty()) {
585  std::string root;
586 
587  // if filename is not empty, start from it.
589 
590  if (!filename.empty()) {
592 
593  if (size != std::string::npos) {
594  root += filename.substr(0, size + 1); // take with the '/'
595  }
596  }
597 
598  //
599  root += "../";
600  int count = (int)std::count(package.begin(), package.end(), '.');
601 
602  for (int i = 0; i < count; i++)
603  root += "../";
604 
606 
607  if (m_verbose) {
608  m_log << "# Search from package '" << package << "' => '"
609  << import_abs_filename << "' ... " << std::flush;
610  }
611 
613 
614  if (file_test.is_open()) {
615  if (m_verbose) { m_log << "found !" << std::endl << std::flush; }
616 
617  file_test.close();
618  found = true;
619  } else if (m_verbose) {
620  m_log << "not found." << std::endl << std::flush;
621  }
622  }
623 
624  // Search import in all paths.
625  for (const auto& path: m_paths) {
627 
628  if (m_verbose) {
629  m_log << "# Search from classpath '" << import_abs_filename
630  << "' ... " << std::flush;
631  }
632 
634 
635  if (file_test.is_open()) {
636  if (m_verbose) { m_log << " found !" << std::endl << std::flush; }
637 
638  file_test.close();
639  found = true;
640  break;
641  } else if (m_verbose) {
642  m_log << " not found." << std::endl << std::flush;
643  }
644  }
645 
646  if (!found) {
647  if (m_verbose) { m_log << "Finished with errors." << std::endl; }
648 
649  addError("import not found.");
650  return false;
651  }
652 
653  // May throw std::IOError if file does't exist
656 
657  try {
659 
660  // Show errors and warning
661  if (m_verbose
662  && (m_reader->errors() > (unsigned int)previousO3prmError
663  || errors() > previousO3prmrError)) {
664  m_log << "Finished with errors." << std::endl;
665  } else if (m_verbose) {
666  m_log << "Finished." << std::endl;
667  }
668 
669  } catch (const IOError& err) {
670  if (m_verbose) { m_log << "Finished with errors." << std::endl; }
671 
673  }
674 
675  // Add o3prm errors and warnings to o3prmr errors
677  previousO3prmError++) {
679  }
680 
681  return errors() == previousO3prmrError;
682 
683  } catch (const Exception& err) {
684  if (m_verbose) { m_log << "Finished with exceptions." << std::endl; }
685 
687  return false;
688  }
689  }
690 
692  size_t dot = s.find_first_of('.');
693  std::string name = s.substr(0, dot);
694 
695  // We look first for real system, next for alias.
696  if (prm()->isSystem(name)) {
697  s = s.substr(dot + 1);
698  return name;
699  }
700 
701  if (!m_context->aliasToImport(name).empty()) {
702  s = s.substr(dot + 1);
703  return m_context->aliasToImport(name);
704  }
705 
706  while (dot != std::string::npos) {
707  if (prm()->isSystem(name)) {
708  s = s.substr(dot + 1);
709  return name;
710  }
711 
712  dot = s.find('.', dot + 1);
713  name = s.substr(0, dot);
714  }
715 
716  throw "could not find any system in '" + s + "'.";
717  }
718 
719  std::string
721  const PRMSystem< double >& sys) {
722  // We have found system before, so 's' has been stripped.
723  size_t dot = s.find_first_of('.');
724  std::string name = s.substr(0, dot);
725 
726  if (!sys.exists(name))
727  throw "'" + name + "' is not an instance of system '" + sys.name()
728  + "'.";
729 
730  s = s.substr(dot + 1);
731  return name;
732  }
733 
735  const std::string& s,
736  const PRMInstance< double >& instance) {
737  if (!instance.exists(s))
738  throw "'" + s + "' is not an attribute of instance '" + instance.name()
739  + "'.";
740 
741  return s;
742  }
743 
744  // After this method, ident no longer contains the system name (if it contained one).
746  try {
747  return prm()->getSystem(findSystemName(ident));
748  } catch (const std::string&) {}
749 
750  if ((m_context->mainImport() != 0)
752  return prm()->getSystem(m_context->mainImport()->value);
753 
754  throw "could not find any system or alias in '" + ident
755  + "' and no default alias has been set.";
756  }
757 
758  ///
759 
760  bool
762  const typename PRMInference< double >::Chain& chain = command->chain;
763 
764  // Generate the inference engine if it doesn't exist.
765  if (!m_inf) { generateInfEngine(*(command->system)); }
766 
767  // Prevent from something
768  if (m_inf->hasEvidence(chain))
769  addWarning(command->leftValue + " is already observed");
770 
772 
773  if (m_verbose)
774  m_log << "# Added evidence " << command->rightValue << " over attribute "
775  << command->leftValue << std::endl;
776 
777  return true;
778 
779  } catch (OperationNotAllowed& ex) {
780  addError("something went wrong when adding evidence " + command->rightValue
781  + " over " + command->leftValue + " : " + ex.errorContent());
782  return false;
783 
784  } catch (const std::string& msg) {
785  addError(msg);
786  return false;
787  }
788 
789  ///
790 
792  const UnobserveCommand< double >* command) try {
794  typename PRMInference< double >::Chain chain = command->chain;
795 
796  // Prevent from something
797  if (!m_inf || !m_inf->hasEvidence(chain)) {
798  addWarning(name + " was not observed");
799  } else {
801 
802  if (m_verbose)
803  m_log << "# Removed evidence over attribute " << name << std::endl;
804  }
805 
806  return true;
807 
808  } catch (const std::string& msg) {
809  addError(msg);
810  return false;
811  }
812 
813  ///
814  void O3prmrInterpreter::query(const QueryCommand< double >* command) try {
815  const std::string& query = command->value;
816 
817  if (m_inf_map.exists(command->system)) {
819  } else {
820  m_inf = nullptr;
821  }
822 
823  // Create inference engine if it has not been already created.
824  if (!m_inf) { generateInfEngine(*(command->system)); }
825 
826  // Inference
827  if (m_verbose) {
828  m_log << "# Starting inference over query: " << query << "... "
829  << std::endl;
830  }
831 
832  Timer timer;
833  timer.reset();
834 
835  Potential< double > m;
837 
838  // Compute spent time
839  double t = timer.step();
840 
841  if (m_verbose) { m_log << "Finished." << std::endl; }
842 
843  if (m_verbose) {
844  m_log << "# Time in seconds (accuracy ~0.001): " << t << std::endl;
845  }
846 
847  // Show results
848 
849  if (m_verbose) { m_log << std::endl; }
850 
852  result.command = query;
853  result.time = t;
854 
855  Instantiation j(m);
856  const PRMAttribute< double >& attr = *(command->chain.second);
857 
858  for (j.setFirst(); !j.end(); j.inc()) {
859  // auto label_value = j.val ( attr.type().variable() );
860  auto label_value = j.val(0);
862  float value = float(m.get(j));
863 
866  singleResult.p = value;
867 
869 
870  if (m_verbose) { m_log << label << " : " << value << std::endl; }
871  }
872 
874 
875  if (m_verbose) { m_log << std::endl; }
876 
877  } catch (Exception& e) {
878  GUM_SHOWERROR(e);
879  throw "something went wrong while infering: " + e.errorContent();
880 
881  } catch (const std::string& msg) { addError(msg); }
882 
883  ///
886  }
887 
888  ///
891  }
892 
893  ///
895  if (m_verbose)
896  m_log << "# Building the inference engine... " << std::flush;
897 
898  //
899  if (m_engine == "SVED") {
900  m_inf = new SVED< double >(*(prm()), sys);
901 
902  //
903  } else if (m_engine == "SVE") {
904  m_inf = new SVE< double >(*(prm()), sys);
905 
906  } else {
907  if (m_engine != "GRD") {
908  addWarning("unkown engine '" + m_engine + "', use GRD insteed.");
909  }
910 
911  MarginalTargetedInference< double >* bn_inf = nullptr;
912  if (m_bn) { delete m_bn; }
913  m_bn = new BayesNet< double >();
914  BayesNetFactory< double > bn_factory(m_bn);
915 
916  if (m_verbose) m_log << "(Grounding the network... " << std::flush;
917 
919 
920  if (m_verbose) m_log << "Finished)" << std::flush;
921 
922  // bn_inf = new LazyPropagation<double>( *m_bn );
923  bn_inf = new VariableElimination< double >(m_bn);
924 
925  auto grd_inf = new GroundedInference< double >(*(prm()), sys);
927  m_inf = grd_inf;
928  }
929 
931  if (m_verbose) m_log << "Finished." << std::endl;
932  }
933 
934  /* **************************************************************************
935  */
936 
937  /// Total number of recorded errors and warnings.
938  Size O3prmrInterpreter::count() const { return this->m_errors.count(); }
939 
940  ///
942 
943  ///
945 
946  ///
948  if (i >= count()) throw "Index out of bound.";
949 
950  return m_errors.error(i);
951  }
952 
953  /// Return container with all errors.
955  return m_errors;
956  }
957 
958  ///
961  }
962 
963  ///
966  }
967 
968  ///
971  }
972 
973  /* **************************************************************************
974  */
975 
976  ///
979 
980  if (m_verbose) m_log << m_errors.last().toString() << std::endl;
981  }
982 
983  ///
986 
987  if (m_verbose) m_log << m_errors.last().toString() << std::endl;
988  }
989 
990  } // namespace o3prmr
991  } // namespace prm
992 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:669
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)